diff --git a/pyphi/visualize.py b/pyphi/visualize.py index 740048f03..414c79a43 100644 --- a/pyphi/visualize.py +++ b/pyphi/visualize.py @@ -10,6 +10,11 @@ from plotly import express as px from plotly import graph_objs as go from umap import UMAP +from tqdm.notebook import tqdm + +import networkx as nx +from networkx.drawing.nx_agraph import graphviz_layout, to_agraph +from IPython.display import Image from . import relations as rel @@ -42,9 +47,13 @@ def feature_matrix(ces, relations): return features -def get_coords(data, y=None, **params): +def get_coords(data, y=None, n_components=3, **params): umap = UMAP( - n_components=3, metric="euclidean", n_neighbors=30, min_dist=0.5, **params, + n_components=n_components, + metric="euclidean", + n_neighbors=30, + min_dist=0.5, + **params, ) return umap.fit_transform(data, y=y) @@ -78,29 +87,158 @@ def label_purview(mice): return make_label(mice.purview, node_labels=mice.node_labels) -def vertex_sizes(min_size, max_size, ces): - phis = np.array( - [(distinction.cause.phi, distinction.effect.phi) for distinction in ces] +def label_state(mice): + return [rel.maximal_state(mice)[0][node] for node in mice.purview] + + +def label_relation(relation): + relata = relation.relata + + relata_info = "
".join( + [ + f"{label_mechanism(mice)} / {label_purview(mice)} [{mice.direction.name}]" + for n, mice in enumerate(relata) + ] + ) + + relation_info = f"
Relation purview: {make_label(relation.purview, relation.subsystem.node_labels)}
Relation φ = {phi_round(relation.phi)}
" + + return relata_info + relation_info + + +def hovertext_mechanism(distinction): + return f"Distinction: {label_mechanism(distinction.cause)}
Cause: {label_purview(distinction.cause)}
Cause φ = {phi_round(distinction.cause.phi)}
Cause state: {[rel.maximal_state(distinction.cause)[0][i] for i in distinction.cause.purview]}
Effect: {label_purview(distinction.effect)}
Effect φ = {phi_round(distinction.effect.phi)}
Effect state: {[rel.maximal_state(distinction.effect)[0][i] for i in distinction.effect.purview]}" + + +def hovertext_purview(mice): + return f"Distinction: {label_mechanism(mice)}
Direction: {mice.direction.name}
Purview: {label_purview(mice)}
φ = {phi_round(mice.phi)}
State: {[rel.maximal_state(mice)[0][i] for i in mice.purview]}" + + +def hovertext_relation(relation): + relata = relation.relata + + relata_info = "".join( + [ + f"
Distinction {n}: {label_mechanism(mice)}
Direction: {mice.direction.name}
Purview: {label_purview(mice)}
φ = {phi_round(mice.phi)}
State: {[rel.maximal_state(mice)[0][i] for i in mice.purview]}
" + for n, mice in enumerate(relata) + ] ) + + relation_info = f"
Relation purview: {make_label(relation.purview, relation.subsystem.node_labels)}
Relation φ = {phi_round(relation.phi)}
" + + return f"
={len(relata)}-Relation=
" + relata_info + relation_info + + +def normalize_sizes(min_size, max_size, elements): + phis = np.array([element.phi for element in elements]) min_phi = phis.min() max_phi = phis.max() return min_size + (((phis - min_phi) * (max_size - min_size)) / (max_phi - min_phi)) -def plot_relations( +def phi_round(phi): + return np.round(phi, 4) + + +def chunk_list(my_list, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(my_list), n): + yield my_list[i : i + n] + + +def format_node(n, subsystem): + node_format = { + "label": subsystem.node_labels[n], + "style": "filled" if subsystem.state[n] == 1 else "", + "fillcolor": "black" if subsystem.state[n] == 1 else "", + "fontcolor": "white" if subsystem.state[n] == 1 else "black", + } + return node_format + + +def save_digraph( + subsystem, digraph_filename="digraph.png", plot_digraph=False, layout="dot" +): + + G = nx.DiGraph() + + for n in range(subsystem.size): + node_info = format_node(n, subsystem) + G.add_node( + node_info["label"], + style=node_info["style"], + fillcolor=node_info["fillcolor"], + fontcolor=node_info["fontcolor"], + ) + + edges = [subsystem.indices2nodes(indices) for indices in np.argwhere(subsystem.cm)] + + G.add_edges_from(edges) + G.graph["node"] = {"shape": "circle"} + + A = to_agraph(G) + A.layout(layout) + A.draw(digraph_filename) + if plot_digraph: + return Image(digraph_filename) + + +def get_edge_color(relation): + p0 = list(relation.relata.purviews)[0] + p1 = list(relation.relata.purviews)[1] + rp = relation.purview + # Isotext (mutual full-overlap) + if p0 == p1 == rp: + return "fuchsia" + # Sub/Supertext (inclusion / full-overlap) + elif p0 != p1 and (all(n in p1 for n in p0) or all(n in p0 for n in p1)): + return "indigo" + # Paratext (connection / partial-overlap) + elif (p0 == p1 != rp) or ( + any(n in p1 for n in p0) and not all(n in p1 for n in p0) + ): + return "cyan" + else: + raise ValueError("Unexpected relation type, check function to cover all cases") + + +def plot_ces( + subsystem, ces, relations, max_order=3, - cause_effect_offset=(0.5, 0, 0), - vertex_size_range=(10, 30), + cause_effect_offset=(0.3, 0, 0), + vertex_size_range=(10, 40), + edge_size_range=(0.5, 4), + surface_size_range=(0.005, 0.1), + plot_dimentions=(1000, 1600), + mechanism_labels_size=20, + purview_labels_size=15, + show_mechanism_labels=True, + show_purview_labels="legendonly", + show_vertices_mechanisms=True, + show_vertices_purviews=True, + show_edges="legendonly", + show_mesh="legendonly", + show_node_qfolds=False, + show_mechanism_qfolds=True, + show_grid=False, + network_name="", + eye_coordinates=(0.5, 0.5, 0.5), + hovermode="x", + digraph_filename="digraph.png", + digraph_layout="dot", + save_plot_to_html=True, + show_causal_model=True, + order_on_z_axis=True, ): # Select only relations <= max_order relations = list(filter(lambda r: len(r.relata) <= max_order, relations)) # Separate CES into causes and effects separated_ces = rel.separate_ces(ces) - # Initialize figure data - figure_data = [] + # Initialize figure + fig = go.Figure() # Dimensionality reduction # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -115,7 +253,12 @@ def plot_relations( # NOTE: This depends on the implementation of `separate_ces`; causes and # effects are assumed to be adjacent in the returned list umap_features = features[0::2] + features[1::2] - distinction_coords = get_coords(umap_features) + if order_on_z_axis: + distinction_coords = get_coords(umap_features, n_components=2) + cause_effect_offset = cause_effect_offset[:2] + + else: + distinction_coords = get_coords(umap_features) # Duplicate causes and effects so they can be plotted separately coords = np.empty( (distinction_coords.shape[0] * 2, distinction_coords.shape[1]), @@ -129,103 +272,428 @@ def plot_relations( # Purviews # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Extract vertex indices for plotly - x, y, z = coords[:, 0], coords[:, 1], coords[:, 2] + x, y = coords[:, 0], coords[:, 1] + if order_on_z_axis: + z = np.array([len(c.mechanism) for c in separated_ces]) + else: + z = coords[:, 2] + + # Get node labels and indices for future use: + node_labels = subsystem.node_labels + node_indices = subsystem.node_indices + # Get mechanism and purview labels - mechanism_labels = list(map(label_mechanism, separated_ces)) + mechanism_labels = list(map(label_mechanism, ces)) + mechanism_labels_x2 = list(map(label_mechanism, separated_ces)) purview_labels = list(map(label_purview, separated_ces)) - # Make the labels - labels = go.Scatter3d( - x=x, y=y, z=z, mode="text", text=purview_labels, name="Labels", showlegend=True, + + mechanism_hovertext = list(map(hovertext_mechanism, ces)) + vertices_hovertext = list(map(hovertext_purview, separated_ces)) + + # Make mechanism labels + xm, ym, zm = ( + [c + cause_effect_offset[0] / 2 for c in x[::2]], + y[::2], + z[::2], + # [n + (vertex_size_range[1] / 10 ** 3) for n in z[::2]], + ) + labels_mechanisms_trace = go.Scatter3d( + visible=show_mechanism_labels, + x=xm, + y=ym, + z=[n + (vertex_size_range[1] / 10 ** 3) for n in zm], + mode="text", + text=mechanism_labels, + name="Mechanism Labels", + showlegend=True, + textfont=dict(size=mechanism_labels_size, color="black"), + hoverinfo="text", + hovertext=mechanism_hovertext, + hoverlabel=dict(bgcolor="black", font_color="white"), + ) + fig.add_trace(labels_mechanisms_trace) + + # Compute purview and mechanism marker sizes + purview_sizes = normalize_sizes( + vertex_size_range[0], vertex_size_range[1], separated_ces ) - figure_data.append(labels) - # Compute size and color - size = vertex_sizes(vertex_size_range[0], vertex_size_range[1], ces) + mechanism_sizes = [min(phis) for phis in chunk_list(purview_sizes, 2)] + + # Make mechanisms trace + vertices_mechanisms_trace = go.Scatter3d( + visible=show_vertices_mechanisms, + x=xm, + y=ym, + z=zm, + mode="markers", + name="Mechanisms", + text=mechanism_labels, + showlegend=True, + marker=dict(size=mechanism_sizes, color="black"), + hoverinfo="text", + hovertext=mechanism_hovertext, + hoverlabel=dict(bgcolor="black", font_color="white"), + ) + fig.add_trace(vertices_mechanisms_trace) + + # Make purview labels trace color = list(flatten([("red", "green")] * len(ces))) - vertices = go.Scatter3d( + labels_purviews_trace = go.Scatter3d( + visible=show_purview_labels, + x=x, + y=y, + z=[n + (vertex_size_range[1] / 10 ** 3) for n in z], + mode="text", + text=purview_labels, + name="Purview Labels", + showlegend=True, + textfont=dict(size=purview_labels_size, color=color), + hoverinfo="text", + hovertext=vertices_hovertext, + hoverlabel=dict(bgcolor=color), + ) + fig.add_trace(labels_purviews_trace) + + # Make purviews trace + purview_phis = [purview.phi for purview in separated_ces] + direction_labels = list(flatten([["Cause", "Effect"] for c in ces])) + vertices_purviews_trace = go.Scatter3d( + visible=show_vertices_purviews, x=x, y=y, z=z, mode="markers", name="Purviews", + text=purview_labels, showlegend=True, - marker=dict(size=size, color=color), + marker=dict(size=purview_sizes, color=color), + hoverinfo="text", + hovertext=vertices_hovertext, + hoverlabel=dict(bgcolor=color), ) - figure_data.append(vertices) + fig.add_trace(vertices_purviews_trace) + + # Initialize lists for legend + legend_nodes = [] + legend_mechanisms = [] # 2-relations # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Get edges from all relations - edges = list( - flatten( - relation_vertex_indices(features, j) - for j in range(features.shape[1]) - if features[:, j].sum() == 2 - ) - ) - if edges: - # Convert to DataFrame - edges = pd.DataFrame( - dict( - x=x[edges], - y=y[edges], - z=z[edges], - line_group=flatten(zip(range(len(edges) // 2), range(len(edges) // 2))), + if show_edges: + # Get edges from all relations + edges = list( + flatten( + relation_vertex_indices(features, j) + for j in range(features.shape[1]) + if features[:, j].sum() == 2 ) ) - # Plot edges - edge_figure = px.line_3d(edges, x="x", y="y", z="z", line_group="line_group") - figure_data.extend(edge_figure.data) + if edges: + # Convert to DataFrame + edges = pd.DataFrame( + dict( + x=x[edges], + y=y[edges], + z=z[edges], + line_group=flatten( + zip(range(len(edges) // 2), range(len(edges) // 2)) + ), + ) + ) + + # Plot edges separately: + two_relations = list(filter(lambda r: len(r.relata) == 2, relations)) + two_relations_sizes = normalize_sizes( + edge_size_range[0], edge_size_range[1], two_relations + ) + + two_relations_coords = [ + list(chunk_list(list(edges["x"]), 2)), + list(chunk_list(list(edges["y"]), 2)), + list(chunk_list(list(edges["z"]), 2)), + ] + + for r, relation in tqdm( + enumerate(two_relations), + desc="Computing edges", + total=len(two_relations), + ): + relation_nodes = list(flatten(relation.mechanisms)) + relation_color = get_edge_color(relation) + + # Make node contexts traces and legendgroups + if show_node_qfolds: + for node in node_indices: + node_label = make_label([node], node_labels) + if node in relation_nodes: + + edge_two_relation_trace = go.Scatter3d( + visible=show_edges, + legendgroup=f"Node {node_label} q-fold", + showlegend=True if node not in legend_nodes else False, + x=two_relations_coords[0][r], + y=two_relations_coords[1][r], + z=two_relations_coords[2][r], + mode="lines", + name=f"Node {node_label} q-fold", + line_width=two_relations_sizes[r], + line_color=relation_color, + hoverinfo="text", + hovertext=hovertext_relation(relation), + ) + fig.add_trace(edge_two_relation_trace) + + if node not in legend_nodes: + + legend_nodes.append(node) + + # Make nechanism contexts traces and legendgroups + if show_mechanism_qfolds: + mechanisms_list = [distinction.mechanism for distinction in ces] + for mechanism in mechanisms_list: + mechanism_label = make_label(mechanism, node_labels) + if mechanism in relation.mechanisms: + + edge_two_relation_trace = go.Scatter3d( + visible=show_edges, + legendgroup=f"Mechanism {mechanism_label} q-fold", + showlegend=True + if mechanism_label not in legend_mechanisms + else False, + x=two_relations_coords[0][r], + y=two_relations_coords[1][r], + z=two_relations_coords[2][r], + mode="lines", + name=f"Mechanism {mechanism_label} q-fold", + line_width=two_relations_sizes[r], + line_color=relation_color, + hoverinfo="text", + hovertext=hovertext_relation(relation), + ) + fig.add_trace(edge_two_relation_trace) + + if mechanism_label not in legend_mechanisms: + + legend_mechanisms.append(mechanism_label) + + # Make all 2-relations traces and legendgroup + edge_two_relation_trace = go.Scatter3d( + visible=show_edges, + legendgroup="All 2-Relations", + showlegend=True if r == 0 else False, + x=two_relations_coords[0][r], + y=two_relations_coords[1][r], + z=two_relations_coords[2][r], + mode="lines", + # name=label_relation(relation), + name="All 2-Relations", + line_width=two_relations_sizes[r], + line_color=relation_color, + hoverinfo="text", + hovertext=hovertext_relation(relation), + ) + + fig.add_trace(edge_two_relation_trace) # 3-relations # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Get triangles from all relations - triangles = [ - relation_vertex_indices(features, j) - for j in range(features.shape[1]) - if features[:, j].sum() == 3 - ] - if triangles: - # Extract triangle indices - i, j, k = zip(*triangles) - mesh = go.Mesh3d( - # x, y, and z are the coordinates of vertices - x=x, - y=y, - z=z, - # i, j, and k are the vertices of triangles - i=i, - j=j, - k=k, - # Intensity of each vertex, which will be interpolated and color-coded - intensity=np.linspace(0, 1, len(x), endpoint=True), - opacity=0.2, - colorscale="viridis", - showscale=False, - name="3-Relations", - showlegend=True, - ) - figure_data.append(mesh) + if show_mesh: + triangles = [ + relation_vertex_indices(features, j) + for j in range(features.shape[1]) + if features[:, j].sum() == 3 + ] - # Create figure + if triangles: + three_relations = list(filter(lambda r: len(r.relata) == 3, relations)) + three_relations_sizes = normalize_sizes( + surface_size_range[0], surface_size_range[1], three_relations + ) + # Extract triangle indices + i, j, k = zip(*triangles) + for r, triangle in tqdm( + enumerate(triangles), desc="Computing triangles", total=len(triangles) + ): + relation = three_relations[r] + relation_nodes = list(flatten(relation.mechanisms)) + + if show_node_qfolds: + for node in node_indices: + node_label = make_label([node], node_labels) + if node in relation_nodes: + triangle_three_relation_trace = go.Mesh3d( + visible=show_mesh, + legendgroup=f"Node {node_label} q-fold", + showlegend=True if node not in legend_nodes else False, + # x, y, and z are the coordinates of vertices + x=x, + y=y, + z=z, + # i, j, and k are the vertices of triangles + i=[i[r]], + j=[j[r]], + k=[k[r]], + # Intensity of each vertex, which will be interpolated and color-coded + intensity=np.linspace(0, 1, len(x), endpoint=True), + opacity=three_relations_sizes[r], + colorscale="viridis", + showscale=False, + name=f"Node {node_label} q-fold", + hoverinfo="text", + hovertext=hovertext_relation(relation), + ) + fig.add_trace(triangle_three_relation_trace) + + if node not in legend_nodes: + + legend_nodes.append(node) + + if show_mechanism_qfolds: + mechanisms_list = [distinction.mechanism for distinction in ces] + for mechanism in mechanisms_list: + mechanism_label = make_label(mechanism, node_labels) + if mechanism in relation.mechanisms: + triangle_three_relation_trace = go.Mesh3d( + visible=show_mesh, + legendgroup=f"Mechanism {mechanism_label} q-fold", + showlegend=True + if mechanism_label not in legend_mechanisms + else False, + # x, y, and z are the coordinates of vertices + x=x, + y=y, + z=z, + # i, j, and k are the vertices of triangles + i=[i[r]], + j=[j[r]], + k=[k[r]], + # Intensity of each vertex, which will be interpolated and color-coded + intensity=np.linspace(0, 1, len(x), endpoint=True), + opacity=three_relations_sizes[r], + colorscale="viridis", + showscale=False, + name=f"Mechanism {mechanism_label} q-fold", + hoverinfo="text", + hovertext=hovertext_relation(relation), + ) + fig.add_trace(triangle_three_relation_trace) + if mechanism_label not in legend_mechanisms: + legend_mechanisms.append(mechanism_label) + + triangle_three_relation_trace = go.Mesh3d( + visible=show_mesh, + legendgroup="All 3-Relations", + showlegend=True if r == 0 else False, + # x, y, and z are the coordinates of vertices + x=x, + y=y, + z=z, + # i, j, and k are the vertices of triangles + i=[i[r]], + j=[j[r]], + k=[k[r]], + # Intensity of each vertex, which will be interpolated and color-coded + intensity=np.linspace(0, 1, len(x), endpoint=True), + opacity=three_relations_sizes[r], + colorscale="viridis", + showscale=False, + name="All 3-Relations", + hoverinfo="text", + hovertext=hovertext_relation(relation), + ) + fig.add_trace(triangle_three_relation_trace) + + # Create figure # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - axis = dict( - showbackground=True, - showline=True, - zeroline=True, - showgrid=True, - gridcolor="lightgray", - showticklabels=False, - title="", - showspikes=True, - autorange=True, - backgroundcolor="white", - ) + axes_range = [(min(d) - 1, max(d) + 1) for d in (x, y, z)] + + axes = [ + dict( + showbackground=False, + showline=False, + zeroline=False, + showgrid=show_grid, + gridcolor="lightgray", + showticklabels=False, + showspikes=True, + autorange=False, + range=axes_range[dimension], + backgroundcolor="white", + title="", + ) + for dimension in range(3) + ] + layout = go.Layout( showlegend=True, - scene=dict(xaxis=dict(axis), yaxis=dict(axis), zaxis=dict(axis)), - hovermode="closest", - title="", - autosize=True, + scene_xaxis=axes[0], + scene_yaxis=axes[1], + scene_zaxis=axes[2], + scene_camera=dict( + eye=dict(x=eye_coordinates[0], y=eye_coordinates[1], z=eye_coordinates[2]) + ), + hovermode=hovermode, + title=f"{network_name} Q-STRUCTURE", + title_font_size=30, + legend=dict( + title=dict( + text="Trace legend (click trace to show/hide):", + font=dict(color="black", size=15), + ) + ), + autosize=False, + height=plot_dimentions[0], + width=plot_dimentions[1], ) - # Merge figures - return go.Figure(data=figure_data, layout=layout) + + # Apply layout + fig.layout = layout + + if show_causal_model: + # Create system image + # TODO check why it doesn't show if you write the img to html + save_digraph(subsystem, digraph_filename, layout=digraph_layout) + digraph_coords = (-0.35, 1) + digraph_size = (0.3, 0.4) + + fig.add_layout_image( + dict( + name="Causal model", + source=digraph_filename, + # xref="paper", yref="paper", + x=digraph_coords[0], + y=digraph_coords[1], + sizex=digraph_size[0], + sizey=digraph_size[1], + xanchor="left", + yanchor="top", + ) + ) + + draft_template = go.layout.Template() + draft_template.layout.annotations = [ + dict( + name="Causal model", + text="Causal model", + opacity=1, + font=dict(color="black", size=20), + xref="paper", + yref="paper", + x=digraph_coords[0], + y=digraph_coords[1] + 0.05, + xanchor="left", + yanchor="bottom", + showarrow=False, + ) + ] + fig.update_layout( + margin=dict(l=400), + template=draft_template, + annotations=[dict(templateitemname="Causal model", visible=True)], + ) + + if save_plot_to_html: + plotly.io.write_html(fig, f"{network_name}_CES.html") + + return fig