import networkx as nx
import os.path as op
from typing import Sequence
from .database import RoamFile, default_database, load
from pyvis.network import Network
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
from subprocess import run, PIPE
from typing import List, Optional
import numpy as np
def _add_roam_node(graph, roam_file):
    """Add *roam_file* to *graph* as a node keyed by its ``filename``.

    The node carries ``file``, ``title`` and ``size`` attributes.  The size
    grows by 2 per incoming/outgoing link starting at 10, capped at 30.
    """
    n_links = len(roam_file.links_to) + len(roam_file.links_from)
    graph.add_node(
        roam_file.filename,
        file=roam_file.filename,
        title=roam_file.title,
        size=min(30, 10 + 2 * n_links),
    )


def get_graph(files: Sequence[RoamFile]):
    """Build a directed multigraph of the links between org-roam files.

    Every file in *files* becomes a node (see ``_add_roam_node``), and every
    entry of ``links_to`` that is a ``RoamFile`` becomes an edge from the
    linking file to the target.  Entries that are not ``RoamFile`` instances
    are skipped (presumably unresolved links -- TODO confirm in ``.database``).

    A ``RoamFile`` target that is not itself part of *files* is added with the
    same attributes.  Otherwise ``add_edge`` would create it as a bare node
    without a ``title``, which code indexing ``nodes[n]['title']`` trips over.

    Returns:
        nx.MultiDiGraph keyed by file name.
    """
    G = nx.MultiDiGraph()
    for f in files:
        _add_roam_node(G, f)
    for f in files:
        for target in f.links_to:
            if not isinstance(target, RoamFile):
                continue
            if target.filename not in G:
                _add_roam_node(G, target)
            G.add_edge(f.filename, target.filename)
    return G
def fig_plotly(full_graph: nx.MultiGraph, ref_nodes=None, n_neighbour=1):
    """Build a Plotly figure of the neighbourhood around *ref_nodes*.

    Args:
        full_graph: graph whose nodes carry ``title`` and ``file`` attributes.
        ref_nodes: iterable of selected node keys (``None`` means none).
            Keys not present in *full_graph* are ignored.
        n_neighbour: how many times the selection is grown by adding the
            neighbours of every node already in it.

    Returns:
        go.Figure with one line trace for the edges and one marker+text trace
        for the nodes.  Selected nodes are red and the others blue.  Marker
        size is ``2 * degree + 10`` in *full_graph*, capped at 40.

    Node positions from the previous call are cached on ``fig_plotly._pos``
    and used to seed the spring layout, so the picture stays stable between
    updates.
    """
    if ref_nodes is None:
        ref_nodes = ()
    # A selection can outlive the node it names (e.g. after a database reload
    # removed a file); indexing full_graph with it would raise KeyError.
    selected = {node for node in ref_nodes if node in full_graph}

    # Grow the selection n_neighbour times by adding adjacent nodes.
    use_nodes = set(selected)
    for _ in range(n_neighbour):
        for node in list(use_nodes):
            use_nodes.update(full_graph[node])
    G = full_graph.subgraph(use_nodes)

    known_pos = getattr(fig_plotly, '_pos', {})
    pos = nx.spring_layout(G, pos=known_pos or None)
    # Merge rather than replace, so a node that drops out of view and comes
    # back later reuses its earlier position instead of a random one.
    fig_plotly._pos = {**known_pos, **pos}  # type: ignore

    # Line segments separated by None so Plotly draws them unconnected.
    edge_x: List[Optional[float]] = []
    edge_y: List[Optional[float]] = []
    for u, v in G.edges():
        (x0, y0), (x1, y1) = pos[u], pos[v]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    nodes = list(G.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        marker=dict(
            showscale=True,
            color=['red' if node in selected else 'blue' for node in nodes],
            size=[min(40, full_graph.degree(node) * 2 + 10) for node in nodes],
            line_width=2),
        text=[G.nodes[node]['title'] for node in nodes],
        # The file name travels as hovertext; the Dash callbacks read it back
        # from hoverData/clickData.
        hovertext=[G.nodes[node]['file'] for node in nodes],
        textposition="bottom center",
    )

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        # title.text / title.font replace the deprecated titlefont.
                        title=dict(text='Network graph', font=dict(size=16)),
                        showlegend=False,
                        hovermode='closest',
                        clickmode='event+select',
                        margin=dict(b=20, l=5, r=5, t=40),
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )
    return fig
def show_pyvis(G: nx.MultiDiGraph):
    """Render *G* with pyvis.

    The graph is copied into a fresh ``pyvis.network.Network``, the physics
    simulation is switched on, and the result is passed to ``Network.show``
    with the file name ``nx.html``.
    """
    output_file = "nx.html"
    network = Network()
    network.from_nx(G)
    network.toggle_physics(True)
    network.show(output_file)
def show(database, debug=False):
    """Serve an interactive Dash app for browsing the org-roam link graph.

    The layout has a multi-select dropdown of notes, the network figure and
    an adjacency slider.  A 1 s interval polls the modification time of
    *database* and reloads it with ``load`` whenever it changes.  Clicking a
    node toggles it in the selection.  Hovering a node sends it to
    ``emacsclient`` as an org-protocol ``roam-file`` URL and shows its
    ``pandoc`` markdown conversion below the graph.

    Args:
        database: path of the file passed to ``load`` and polled with getmtime.
        debug: forwarded to ``app.run_server``.
    """
    app = dash.Dash("org-roam", external_stylesheets=['https://codepen.io/chriddyp/pen/bWLwgP.css'])
    app.layout = html.Div(children=[
        html.H1(children='Knowledge database'),
        dcc.Dropdown(options=[], multi=True, id='selection'),
        dcc.Interval(id='refresh', interval=1000),
        dcc.Graph(id='network', figure=go.Figure()),
        dcc.Slider(min=0, max=10, step=1, value=1, id='adjacency'),
        html.Br(),
        html.Div(id='text'),
    ])

    # Mutable state shared by the callbacks below.
    state = {
        'graph': None,       # undirected graph from the last successful load
        'mtime': None,       # database mtime of the last load attempt
        'options': None,     # dropdown options derived from 'graph'
        'loading': False,    # True while a (re)load is in progress
        'shown_file': None,  # file currently rendered in the text pane
    }

    @app.callback(
        Output('selection', 'options'),
        [Input('refresh', 'n_intervals')],
    )
    def read_database(_n_intervals):
        """Reload the database if its mtime changed; return dropdown options."""
        mtime = op.getmtime(database)
        if state['loading'] or mtime == state['mtime']:
            return state['options'] or []
        state['loading'] = True
        try:
            full_graph = get_graph(load(database)).to_undirected()
            options = [
                {'label': full_graph.nodes[node]['title'], 'value': node}
                for node in full_graph.nodes()
            ]
        except Exception:
            # Record the failing mtime so the same broken file is not re-parsed
            # every tick; the next modification triggers a new attempt.
            state['mtime'] = mtime
            raise
        finally:
            # Always clear the flag: a stuck 'loading' would block reloads forever.
            state['loading'] = False
        state['graph'] = full_graph
        state['options'] = options
        state['mtime'] = mtime
        return options

    @app.callback(
        Output(component_id='network', component_property='figure'),
        [Input(component_id='selection', component_property='value'),
         Input(component_id='adjacency', component_property='value')]
    )
    def update_figure(selection, n_neighbour):
        """Redraw the neighbourhood of the selected notes."""
        if state['graph'] is None:
            return go.Figure()
        return fig_plotly(state['graph'], selection, int(n_neighbour))

    @app.callback(
        Output(component_id='selection', component_property='value'),
        [Input(component_id='network', component_property='clickData')],
        [State(component_id='selection', component_property='value')],
    )
    def toggle_selection(click_data, selection):
        """Add the clicked node's file to the selection, or remove it if present."""
        if click_data is None:
            return selection
        point = click_data['points'][0]
        # Points without hovertext (e.g. on the edge trace) carry no file.
        if 'hovertext' not in point:
            return selection
        file = point['hovertext']
        selection = list(selection or [])
        if file in selection:
            selection.remove(file)
        else:
            selection.append(file)
        return selection

    @app.callback(
        Output(component_id='text', component_property='children'),
        [Input(component_id='network', component_property='hoverData')],
        [State(component_id='text', component_property='children')],
    )
    def show_text(hover_data, old_children):
        """Open the hovered file in Emacs and render it as markdown."""
        if hover_data is None:
            return None
        file = hover_data['points'][0].get('hovertext')
        if file is None:
            # Same guard as toggle_selection: edge points have no hovertext.
            return old_children
        if file == state['shown_file']:
            # Avoid re-running emacsclient/pandoc for the same file.
            return old_children
        run(['emacsclient', f'org-protocol://roam-file?file={file}'])
        state['shown_file'] = file
        res = run([
            "pandoc", "--mathjax",
            "-f", "org", "-t", "markdown",
            file
        ], stdout=PIPE)
        full_graph = state['graph']
        # The figure may still show a file removed by a later reload.
        if full_graph is not None and file in full_graph:
            title = full_graph.nodes[file]['title']
        else:
            title = file
        return [html.Pre(file), html.H2(title), dcc.Markdown(res.stdout.decode())]

    app.run_server(debug=debug)