# Copyright 2025 ICube (University of Strasbourg - CNRS)
# author: Julien PONTABRY (ICube)
#
# This software is a computer program whose purpose is to provide a toolkit
# to model, process and analyze the longitudinal reorganization of brain
# connectivity data, as functional MRI for instance.
#
# This software is governed by the CeCILL-B license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/or redistribute the software under the terms of the CeCILL-B
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
"""Plotting the spatio-temporal graphs."""
from dataclasses import dataclass
from functools import cache
from math import isclose
from typing import Callable, Any
import networkx as nx
import numpy as np
import pandas as pd
from matplotlib import colormaps as cm
from matplotlib import pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.backend_bases import MouseEvent, Event
from matplotlib.collections import LineCollection
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.patches import FancyArrowPatch
from matplotlib.text import Text
from matplotlib.widgets import Cursor, RangeSlider
from .graph import SpatioTemporalGraph, RC5
def __time_multipartite_layout(g: SpatioTemporalGraph, dist: float = 1.0) -> dict[int, tuple[int, float]]:
"""Create coordinates for all the nodes of the spatio-temporal graph.
The nodes are placed vertical by time point.
Parameters
----------
g: SpatioTemporalGraph
The spatio-temporal graph.
dist: float, optional
The distance factor for vertical space occupied by the nodes. Default is 1.
Returns
-------
dict[int, tuple[int, float]]
A dictionary of coordinates associated with nodes.
"""
pos = {}
for t in range(g.graph['min_time'],
g.graph['max_time'] + 1):
sub_g = g.sub(t=t)
nodes = sorted(sub_g.nodes)
half_height = dist * (len(nodes) - 1) / 2
heights = np.linspace(-half_height, half_height, len(nodes))
for n, height in zip(nodes, heights):
pos[n] = (t, height)
return pos
[docs]
def multipartite_plot(g: SpatioTemporalGraph, ax: Axes = None) -> None:
"""Draw a multipartite plot for the spatio-temporal graph.
Parameters
----------
g: SpatioTemporalGraph
The spatio-temporal graph.
ax: matplotlib.axes.Axes, optional
The axes on which to plot. If not set, the current axes will be used.
"""
if ax is None:
ax = plt.gca()
pos = __time_multipartite_layout(g)
node_color = [d['internal_strength']
for _, d in g.nodes.items()]
edge_color = []
edge_labels = {}
edge_widths = []
for e, d in g.edges.items():
if d['type'] == 'temporal':
edge_color.append('red')
edge_widths.append(2)
edge_labels[e] = d['transition']
else:
edge_color.append('limegreen')
edge_widths.append(np.abs(d['correlation']) * 4)
edge_labels[e] = d['correlation']
# draw the graph
nx.draw_networkx(g, ax=ax, pos=pos, with_labels=True, node_color=node_color,
cmap='coolwarm', vmin=-1, vmax=1, edge_color=edge_color,
connectionstyle='arc3', width=edge_widths, hide_ticks=False)
nx.draw_networkx_edge_labels(g, ax=ax, pos=pos, edge_labels=edge_labels,
connectionstyle='arc3', hide_ticks=False)
# set up the axes
ax.spines[['left', 'top', 'right']].set_visible(False)
min_time, max_time = g.graph['min_time'], g.graph['max_time']
ax.get_xaxis().tick_bottom()
ax.set_xticks(range(min_time, max_time + 1))
ax.set_xlabel("Time")
ax.set_xlim(min_time - 0.125, max_time + 0.125)
ax.set_yticks([])
def __polar2cart(angles: np.array, distance: float) -> tuple[np.array, np.array]:
"""Calculate the cartesian coordinates from polar ones.
Parameters
----------
angles: np.ndarray
The angles in radian.
distance: float
The distance in matplotlib's unit.
Returns
-------
tuple[np.ndarray, np.ndarray]
The cartesian coordinates.
"""
pts = distance * np.exp(1j * angles)
return np.real(pts), np.imag(pts)
def __readable_angled_annotation(angle: float) -> dict[str, float | str]:
"""Get annotation's properties depending on the display angle.
Parameters
----------
angle: float
The angle in degrees.
Returns
-------
dict[str, float | str]
The properties for areas annotation.
"""
if angle <= 90 or angle >= 270:
return {'rotation': angle, 'ha': 'left'}
else:
return {'rotation': angle+180, 'ha': 'right'}
def __edge_con_style(angle1: float, angle2: float, bending: float = 5) -> str:
"""Get the connection style property for edges.
Parameters
----------
angle1: float
The angle of the first node in radians.
angle2: float
The angle of the second node in radians.
bending: float
The bending coefficient (close to 0 means fully bend and close to
infinity means straight).
Returns
-------
str
The appropriate connection style property.
"""
diff = angle1 - angle2
if diff > np.pi or diff < -np.pi:
sign = -np.sign(diff)
dist = (2*np.pi - abs(diff)) / np.pi
else:
sign = np.sign(diff)
dist = abs(diff) / np.pi
return f'arc3, rad={sign * (1 - dist) ** bending}'
def __annot_con_style(angle1: float, angle2: float) -> str:
"""Get the connection style property for annotations.
Parameters
----------
angle1: float
The angle of the line on the side of the annotation text in degrees.
angle2: float
The angle of the line on the network side in degrees.
Returns
-------
str
The appropriate connection style property.
"""
abs_diff = abs(angle1 - angle2)
diff = min(abs_diff, 360 - abs_diff)
if isclose(angle1, angle2) or isclose(diff, 180) or diff > 180:
return 'arc3'
else:
return f'angle, angleA={angle1}, angleB={angle2}, rad=0'
def __angle_between(vec1: tuple[float, float], vec2: tuple[float, float]) -> float:
"""Calculate the angle between two vectors.
Parameters
----------
vec1: tuple[float, float]
The first vector.
vec2: tuple[float, float]
The second vector.
Returns
-------
float
The angle in degrees.
"""
angle = np.arctan2(vec2[1] - vec1[1], vec2[0] - vec1[0])
if angle < 0:
angle += 2 * np.pi
return np.rad2deg(angle)
def __areas_positions(graph: SpatioTemporalGraph) -> tuple[pd.DataFrame, np.ndarray, np.ndarray, np.ndarray]:
"""Calculate positions for areas' labels in a circular manner.
Parameters
----------
graph: SpatioTemporalGraph
The graph to plot.
Returns
-------
rels: pandas.DataFrame
The sorted areas names.
angles: numpy.ndarray
The areas angles around the circle.
x_areas: numpy.ndarray
The cartesian coordinates of areas along the x-axis.
y_areas: numpy.ndarray
The cartesian coordinates of areas along the y-axis.
"""
rels = graph.areas.sort_values('Name_Region')
n = len(rels)
angles = 2 * np.pi / n * np.arange(n)
x_areas, y_areas = __polar2cart(angles, 1.5)
return rels, angles, x_areas, y_areas
def _spatial_plot_artists(graph: SpatioTemporalGraph, t: float,
edges_bending: float = 3) -> tuple[list[Line2D], list[FancyArrowPatch], list[FancyArrowPatch]]:
"""Generate artists for a spatial plotting of a spatio-temporal graph.
Parameters
----------
graph: SpatioTemporalGraph
The graph to plot.
t: float
The instant to plot in the graph.
edges_bending: float, optional
Controls the bending of the edges. Close to 0 means full bending, close to
infinity means no bending (default is 3).
Returns
-------
networks_markers: list[Line2D]
The artists for the nodes representing the networks.
areas_patches: list[FancyArrowPatch]
The arrows from areas to networks nodes.
edges_patches: list[FancyArrowPatch]
The patches representing the edges between the networks.
"""
sub_g = graph.sub(t=t)
rels, angles, x_areas, y_areas = __areas_positions(graph)
n = len(rels)
# networks node
cmap = cm.get_cmap('coolwarm')
nodes_angles = {}
nodes_coords = {}
areas_network_map = {}
networks_markers = []
for node, data in sub_g.nodes.items():
network = list(data['areas'])
indices = np.argwhere(np.isin(rels.index, network)).flatten().tolist()
angle = angles[indices].mean()
closest_node = min([(on, abs(a-angle)) for on, a in nodes_angles.items()],
default=(None, -1), key=lambda e: e[1])
if isclose(closest_node[1], 0):
angle += 2*np.pi/n/2
nodes_angles[node] = angle
x, y = __polar2cart(angle, 1)
corr = data['internal_strength']
eff = data['efficiency']
l = Line2D([x], [y], marker='o', mfc=cmap(corr / 2 + 0.5),
mec='k', ms=15*eff, zorder=4)
networks_markers.append(l)
nodes_coords[node] = (x, y)
areas_network_map |= {i: (x, y) for i in indices}
# areas' links
areas_patches = []
for i, (x_area, y_area) in enumerate(zip(x_areas, y_areas)):
to_node_angle = __angle_between(areas_network_map[i], __polar2cart(angles[i], 1.2))
angle = np.rad2deg(angles[i])
a = FancyArrowPatch(posA=(x_area, y_area), posB=areas_network_map[i],
arrowstyle='-', connectionstyle=__annot_con_style(angle, to_node_angle),
linestyle=':')
areas_patches.append(a)
# edges between networks
edges_patches = []
for (n1, n2), d in sub_g.edges.items():
e = FancyArrowPatch(posA=nodes_coords[n1], posB=nodes_coords[n2],arrowstyle='-',
connectionstyle=__edge_con_style(nodes_angles[n1], nodes_angles[n2], bending=edges_bending),
linewidth=np.abs(d['correlation'])*4, color=cmap(d['correlation']/2+0.5),
alpha=np.abs(d['correlation']))
edges_patches.append(e)
return networks_markers, areas_patches, edges_patches
def _spatial_plot_background(graph: SpatioTemporalGraph, ax: Axes = None, show_regions: bool = True) -> None:
"""Plot the background of a spatial plot.
Parameters
----------
graph: SpatioTemporalGraph
The graph to plot.
ax: matplotlib.axes.Axes, optional
The axes on which to plot. If not set, the current axes will be used.
show_regions: bool, optional
Flag to show (or not) the region labels.
"""
if ax is None:
ax = plt.gca()
ax.axis('off')
rels, angles, x_areas, y_areas = __areas_positions(graph)
regions = rels['Name_Region'].unique()
n = len(rels)
# plot regions in a pie
regions_cmap = cm.get_cmap('tab20')
ax.pie([len(rels[rels['Name_Region'] == region]) / n
for region in regions],
radius=2.25, startangle=-360 / n / 2,
labels=regions if show_regions else None,
labeldistance=1.1, rotatelabels=False,
colors=[regions_cmap(i) for i in range(len(regions))],
wedgeprops={'width': 1, 'edgecolor': 'w', 'alpha': 0.2})
limit_val = 3 if show_regions else 2.5
ax.set_xlim(-limit_val, limit_val)
ax.set_ylim(-limit_val, limit_val)
# plot areas' labels
for i, (x_area, y_area, area) in enumerate(zip(x_areas, y_areas, rels['Name_Area'])):
angle = np.rad2deg(angles[i])
ax.text(x=x_area, y=y_area, s=area,
va='center', fontsize='x-small', rotation_mode='anchor',
**__readable_angled_annotation(angle))
[docs]
def spatial_plot(graph: SpatioTemporalGraph, t: float, ax: Axes = None, edges_bending: float = 3) -> None:
"""Draw a spatial plot for spatio-temporal graph.
Parameters
----------
graph: SpatioTemporalGraph
The graph to plot.
t: float
The instant to plot in the graph.
ax: matplotlib.axes.Axes, optional
The axes on which to plot. If not set, the current axes will be used.
edges_bending: float, optional
Controls the bending of the edges. Close to 0 means full bending, close to
infinity means no bending (default is 3).
"""
_spatial_plot_background(graph, ax)
networks_markers, areas_arrows, edges_patches = _spatial_plot_artists(graph, t=t, edges_bending=edges_bending)
for network_marker in networks_markers:
ax.add_line(network_marker)
for area_arrow in areas_arrows:
ax.add_patch(area_arrow)
for edge_patch in edges_patches:
ax.add_patch(edge_patch)
class __CoordinatesGenerator:
"""Utility to generate temporal paths from a spatio-temporal graph."""
def __init__(self, graph: SpatioTemporalGraph) -> None:
self.g = graph
self.__max_heights = None
self.__coords = None
@cache
def __next_temp_trans(self, node: int) -> list[int]:
return [m for m in self.g[node]
if self.g[node][m]['type'] == 'temporal']
@cache
def __time_from_node(self, node: int) -> int:
return self.g.nodes[node]['t']
def __find_height_for_path(self, node: int, y: int) -> int:
current_max_y = y
trans = self.__next_temp_trans(node)
while trans != [] and (m := trans[0]) not in self.__coords:
t = self.__time_from_node(m)
if t in self.__max_heights:
current_max_y = max(current_max_y, self.__max_heights[t] + 1)
trans = self.__next_temp_trans(m)
return current_max_y
def __generate_coords_for_node(self, node: int, base_y: int) -> tuple[int, int]:
return self.__time_from_node(node), self.__find_height_for_path(node, base_y)
def __generate_coords_for_path_rec(self, trans_list: list[int], base: int) -> None:
base_y = base
for i, m in enumerate(trans_list):
if m not in self.__coords:
t, y = self.__generate_coords_for_node(m, base_y)
base_y += 1
self.__coords[m] = (t, y)
self.__max_heights[t] = y
next_trans = self.__next_temp_trans(m)
self.__generate_coords_for_path_rec(next_trans, y)
def generate(self, nodes: list[int], base_y: int) -> dict[int, tuple[int, int]]:
"""Generate the coordinates of the temporal paths.
Parameters
----------
nodes: list[int]
The nodes starting the paths.
base_y: int
The initial height location of the path.
Returns
-------
dict[int, tuple[int, int]]
A dictionary mapping a node to its time/height coordinates.
"""
self.__max_heights = {}
self.__coords = {n: (self.g.nodes[n]['t'], base_y + i) for i, n in enumerate(nodes)}
for i, node in enumerate(nodes):
trans_list = self.__next_temp_trans(node)
self.__generate_coords_for_path_rec(trans_list, base_y + i)
return self.__coords
@cache
def _trans_color(transition: RC5) -> str:
"""Defines the color to use for a given RC5 transition.
Parameters
----------
transition: RC5
The transition between two nodes.
Returns
-------
str
The transition's color name.
"""
if transition == RC5.PP:
return 'red'
elif transition == RC5.PPi:
return 'blue'
elif transition == RC5.PO:
return 'limegreen'
else:
return 'black'
class __PathDrawer:
"""Utility to draw temporal paths of a spatio-temporal graph on an axis."""
def __init__(self, g: SpatioTemporalGraph, axe: Axes) -> None:
self.g = g
self.axe = axe
self.__done = None
self.__lines = None
self.__colors = None
@cache
def __next_temp_trans(self, node: int) -> list[tuple[int, RC5]]:
return [(m, self.g[node][m]['transition'])
for m in self.g[node]
if self.g[node][m]['type'] == 'temporal']
def __draw_rec(self, coords: dict[int, tuple[int, int]], trans_list: list[tuple[int, RC5]],
node: int, prev_t: int, prev_y: int) -> None:
for m, rc5 in trans_list:
if m in coords and (node, m) not in self.__done:
t, y = coords[m]
self.__lines.append([(prev_t, prev_y), (t, y)])
self.__colors.append(_trans_color(rc5))
self.__done.add((node, m))
next_trans = self.__next_temp_trans(m)
self.__draw_rec(coords, next_trans, m, t, y)
def draw(self, coords: dict[int, tuple[int, int]], nodes: list[int], base_y: int) -> None:
"""Draw the temporal paths.
Parameters
----------
coords: dict[int, tuple[int, int]]
The coordinates of the nodes in the temporal paths.
nodes: list[int]
The nodes starting the paths.
base_y: int
The initial height location of the path.
"""
self.__done = set()
self.__lines = []
self.__colors = []
for i, node in enumerate(nodes):
trans_list = self.__next_temp_trans(node)
self.__draw_rec(coords, trans_list, node, self.g.nodes[node]['t'], base_y + i)
self.axe.add_collection(
LineCollection(self.__lines, colors=self.__colors,
linewidths=1.5, linestyles='-'))
[docs]
def temporal_plot(graph: SpatioTemporalGraph, ax: Axes = None) -> tuple[dict[int, tuple[int, int]], dict[tuple[int, int], int]]:
"""Draw a temporal plot for a spatio-temporal graph.
Parameters
----------
graph: SpatioTemporalGraph
The spatio-temporal graph.
ax: matplotlib.axes.Axes, optional
The axes on which to plot. If not set, the current axes will be used.
Returns
-------
tuple[dict[int, tuple[int, int]], dict[tuple[int, int], int]]
The first dictionary maps node identifiers to their coordinates (time, height).
The second dictionary maps coordinates (time, height) to node identifiers.
"""
if ax is None:
ax = plt.gca()
rels = graph.areas.sort_values('Name_Region')
regions = rels['Name_Region'].unique().tolist()
times = np.unique([d['t'] for n, d in graph.nodes.items()])
# draw dynamic (nodes + transitions)
cmap = cm.get_cmap('coolwarm')
sub_g = graph.sub(t=0)
heights = []
y = 0
gen = __CoordinatesGenerator(graph)
drawer = __PathDrawer(graph, ax)
all_coord: dict[int, tuple[int, int]] = {}
rev_coord: dict[tuple[int, int], int] = {}
for r, region in enumerate(regions):
nodes = [n for n, d in sub_g.nodes.items() if d['region'] == region]
coords = gen.generate(nodes, y)
drawer.draw(coords, nodes, y)
colors = [graph.nodes[n]['internal_strength'] for n in coords.keys()]
sizes = np.array([graph.nodes[n]['efficiency'] for n in coords.keys()])
ax.scatter(*list(zip(*coords.values())), zorder=2.1,
s=10*sizes**5, c=colors, cmap=cmap, edgecolors='k',
linewidths=0.1, vmin=-1, vmax=1)
heights.append(max(coords.values(), key=lambda x: x[1])[1] + 1 - y)
y += heights[-1] + 1
# save coordinates for later
all_coord.update(coords)
for n, c in coords.items():
rev_coord[c] = n
# draw limits of regions
o = 0
regions_cmap = cm.get_cmap('tab20')
ticks = []
for r, _ in enumerate(regions):
m = heights[r]
o += m
ax.fill_between([times.min() - 0.5, times.max() + 0.5], o - m - 1, o,
fc=regions_cmap(r), alpha=0.2)
ticks.append((o - m - 1 + o) / 2)
o += 1
# set up the axes
ax.spines[['left', 'top', 'right']].set_visible(False)
ax.get_xaxis().tick_bottom()
ax.set_xlabel("Time")
ax.set_xlim(times.min()-1, times.max()+1)
ax.set_yticks(ticks, regions)
for tick in ax.get_yaxis().get_major_ticks():
tick.tick1line.set_visible(False)
ax.set_ylim(-1, sum(heights) + len(heights) - 1)
return all_coord, rev_coord
def _inch2cm(inch: float) -> float:
return inch / 2.54
[docs]
class DynamicTimeCursor(Cursor):
"""A dynamic cursor for time points."""
def __init__(self, axe: Axes, func: Callable[[int], None],
all_coord: dict[int, tuple[int, int]], rev_coord: dict[tuple[int, int], int],
graph: SpatioTemporalGraph, **lineprops):
super().__init__(ax=axe, horizOn=False, useblit=True, **lineprops)
self.__callback = func
self.__last_t = None
self.__all_coord = all_coord
self.__rev_coord = rev_coord
self.__graph = graph
self.__markers = []
def __get_connected_nodes_coord(self, n: int) -> list[tuple[int, int]] | None:
spatial_nodes = [sn for sn in self.__graph.adj[n]
if self.__graph.adj[n][sn]["type"] == "spatial"]
return [self.__all_coord[sn] for sn in spatial_nodes]
[docs]
def onmove(self, event):
if self.ignore(event):
return
if not self.canvas.widgetlock.available(self):
return
# clean the markers (if any)
for marker1 in self.__markers:
marker1.remove()
self.__markers.clear()
if isinstance(event, MouseEvent) and not self.ax.contains(event)[0]:
self.linev.set_visible(False)
self.lineh.set_visible(False)
if self.needclear:
self.canvas.draw()
self.needclear = False
return
# set up the time cursor
xdata, ydata = self.__setup_time_cursor(event)
if not (self.visible and (self.vertOn or self.horizOn)):
return
# set up the connected nodes
t, y = int(round(xdata)), int(round(ydata))
if n := self.__rev_coord.get((t, y)):
coord = self.__get_connected_nodes_coord(n)
self.__markers += self.ax.plot(*list(zip(*coord)), 'sr')
self.__markers += self.ax.plot(t, y, 'sg')
self.__redraw()
# callback with the time position
if self.__last_t != t:
self.__callback(t)
self.__last_t = t
event.xdata = t
def __redraw(self):
if self.useblit:
if self.background is not None:
self.canvas.restore_region(self.background)
# draw time cursor
self.ax.draw_artist(self.linev)
self.ax.draw_artist(self.lineh)
# draw connected nodes
for marker in self.__markers:
self.ax.draw_artist(marker)
self.canvas.blit(self.ax.bbox)
else:
self.canvas.draw_idle()
def __setup_time_cursor(self, event: Event | MouseEvent) -> tuple[Any, Any]:
self.needclear = True
xdata, ydata = self._get_data_coords(event)
self.linev.set_xdata((xdata, xdata))
self.linev.set_visible(self.visible and self.vertOn)
self.lineh.set_ydata((ydata, ydata))
self.lineh.set_visible(self.visible and self.horizOn)
return xdata, ydata
[docs]
@dataclass
class DynamicPlot:
"""A dynamic plot that contains both temporal and spatial plot with interactivity."""
graph: SpatioTemporalGraph
@property
def __networks_markers(self) -> list[Line2D]:
return list(self.spl_axe.lines)
@property
def __arrows_patches(self) -> list[FancyArrowPatch]:
return [p for p in self.spl_axe.patches if isinstance(p, FancyArrowPatch)]
@property
def __time_text(self) -> list[Text]:
return [] if self.time_text is None else [self.time_text]
def __create_figure(self, figure_setup: dict) -> None:
# use the recent toolbar
plt.rcParams['toolbar'] = 'toolmanager'
self.fig = plt.figure(layout='constrained', **figure_setup)
gs = GridSpec(nrows=1, ncols=2, figure=self.fig, width_ratios=[4, 3])
self.tpl_axe = self.fig.add_subplot(gs[0])
gs_side = gs[1].subgridspec(nrows=2, ncols=1, height_ratios=[40, 1])
self.spl_axe = self.fig.add_subplot(gs_side[0])
self.win_axe = self.fig.add_subplot(gs_side[1])
self.spl_bkd = None
self.time_text = None
def __on_window_resized(self) -> None:
self.spl_bkd = None
def __initialize_figure(self) -> None:
# define initial time to show
init_t = 0
# plot both temporal and spatial plots
self.__all_coord, self.__rev_coord = temporal_plot(self.graph, ax=self.tpl_axe)
_spatial_plot_background(self.graph, ax=self.spl_axe, show_regions=False)
# set the initial spatial plot display
self.__on_cursor_changed(init_t)
@staticmethod
def __remove_artists(artists: list[Artist]) -> None:
for a in artists:
a.remove()
def __add_artists(self, artists: list[Artist], adding_func: Callable[[Artist], None]) -> None:
for a in artists:
adding_func(a)
self.spl_axe.draw_artist(a)
def __modify_artists(self, old_artists: list[Artist], new_artists: list[Artist],
modify_func: Callable[[Artist, Artist], None]) -> None:
for ao, an in zip(old_artists, new_artists):
modify_func(ao, an)
self.spl_axe.draw_artist(ao)
@staticmethod
def __modify_networks_markers(ol: Line2D, nl: Line2D) -> None:
ol.set(data=nl.get_data(), mfc=nl.get_markerfacecolor(), ms=nl.get_markersize())
@staticmethod
def __modify_arrows_patches(olp: FancyArrowPatch, nwp: FancyArrowPatch) -> None:
vx = nwp.get_path().vertices
olp.set_positions(posA=(float(vx[0][0]), float(vx[0][1])), posB=(float(vx[-1][0]), float(vx[-1][1])))
olp.set(fc=nwp.get_fc(), ec=nwp.get_ec(), alpha=nwp.get_alpha(), lw=nwp.get_linewidth(),
connectionstyle=nwp.get_connectionstyle(), arrowstyle=nwp.get_arrowstyle(),
ls=nwp.get_linestyle())
def __update_artists(self, old_artists_lists: list[list[Artist]], new_artists_lists: list[list[Artist]],
modifiers: list[Callable[[Artist, Artist], None]], adders: list[Callable[[Artist], None]]) -> None:
# remove excess of old artists (and remember the initial count)
no = []
for oal, nal in zip(old_artists_lists, new_artists_lists):
no.append(len(oal))
self.__remove_artists(oal[len(nal):])
# restore the background
self.fig.canvas.restore_region(self.spl_bkd)
# modify reusable artists
for oal, nal, mod in zip(old_artists_lists, new_artists_lists, modifiers):
self.__modify_artists(oal, nal, mod)
# add excess of new artists
for nal, n, add in zip(new_artists_lists, no, adders):
self.__add_artists(nal[n:], add)
def __recreate_artists(self, old_artists_lists: list[list[Artist]], new_artists_lists: list[list[Artist]],
adders: list[Callable[[Artist], None]]) -> None:
# remove all old artists
for oal in old_artists_lists:
self.__remove_artists(oal)
# save a new background for spatial axe
self.fig.canvas.draw()
self.spl_bkd = self.fig.canvas.copy_from_bbox(self.spl_axe.bbox)
# add all new artists
for nal, add in zip(new_artists_lists, adders):
self.__add_artists(nal, add)
def __update_time_text(self, t: int) -> Text:
text = f"$t={t}$"
if self.time_text is None:
self.time_text = Text(x=0.5, y=0.99, text=text, fontsize='xx-large',
horizontalalignment='center', verticalalignment='top',
transform=self.spl_axe.transAxes)
else:
self.time_text.set_text(text)
return self.time_text
def __on_cursor_changed(self, t: int) -> None:
# get old/new artists
old_networks_markers = self.__networks_markers
old_arrows_patches = self.__arrows_patches
old_time_text = self.__time_text
new_networks_markers, new_areas_patches, new_edges_patches = _spatial_plot_artists(self.graph, t=t)
new_time_text = [self.__update_time_text(t)]
# update or recreate artists (if the background has been invalidated)
old_artists = [old_arrows_patches, old_networks_markers, old_time_text]
new_artists = [new_edges_patches + new_areas_patches, new_networks_markers, new_time_text]
adders = [self.spl_axe.add_patch, self.spl_axe.add_line, self.spl_axe.add_artist]
if self.spl_bkd is None:
self.__recreate_artists(old_artists, new_artists, adders)
else:
modifiers = [self.__modify_arrows_patches, self.__modify_networks_markers,
lambda o, n: None] # time text is already updated by __update_time_text
self.__update_artists(old_artists, new_artists, modifiers, adders)
# blit the spatial axes to draw only changed artists
self.fig.canvas.blit(self.spl_axe.bbox)
def __on_range_changed(self, vals: tuple[float, float]) -> None:
low, high = vals
self.tpl_axe.set_xlim(low-0.5, high+0.5)
def __on_tpl_limits_changed(self):
limits = self.tpl_axe.get_xlim()
l_min, l_max = round(limits[0]+0.5), round(limits[1])
r_min, r_max = self.fig.w_slider.val
if r_min != l_min or r_max != l_max:
self.fig.w_slider.set_val((l_min, l_max))
def __initialize_widgets(self) -> None:
self.fig.t_cursor = DynamicTimeCursor(self.tpl_axe, self.__on_cursor_changed,
self.__all_coord, self.__rev_coord, self.graph,
color='k', lw=0.8, ls='--')
time_range = self.graph.graph['min_time'], self.graph.graph['max_time']
self.fig.w_slider = RangeSlider(ax=self.win_axe, label="Window", valstep=1,
valinit=time_range, valmin=time_range[0], valmax=time_range[1])
self.fig.w_slider.on_changed(self.__on_range_changed)
self.tpl_axe.callbacks.connect('xlim_changed', lambda _: self.__on_tpl_limits_changed())
def __initialize_events(self) -> None:
# capture resize events (update background of spatial plot)
# triggers the reset of the spatial axe background
self.fig.canvas.mpl_connect('resize_event', lambda e: self.__on_window_resized())
# Keep only the home, zoom and help elements in the toolbar
tool_mgr = self.fig.canvas.manager.toolmanager
tool_mgr.remove_tool('back')
tool_mgr.remove_tool('forward')
tool_mgr.remove_tool('pan')
tool_mgr.remove_tool('subplots')
tool_mgr.remove_tool('help')
[docs]
def plot(self, figure_setup: dict):
self.__create_figure(figure_setup)
self.__initialize_figure()
self.__initialize_widgets()
self.__initialize_events()