# Copyright 2025 ICube (University of Strasbourg - CNRS)
# author: Julien PONTABRY (ICube)
#
# This software is a computer program whose purpose is to provide a toolkit
# to model, process and analyze the longitudinal reorganization of brain
# connectivity data, as functional MRI for instance.
#
# This software is governed by the CeCILL-B license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/or redistribute the software under the terms of the CeCILL-B
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
"""Defines spatio-temporal graphs and related structures."""
from collections.abc import Iterable
from enum import Enum, auto, unique
from math import isclose
from numbers import Number
from typing import Any
import networkx as nx
import pandas as pd
@unique
class RC5(Enum):
    """Enumerate the RC5 temporal transitions.

    NOTE(review): the member names presumably follow the RCC5 region connection
    calculus (equal, proper part, inverse proper part, partial overlap,
    disconnected) — confirm against the project documentation.
    """

    EQ = auto()
    PP = auto()
    PPi = auto()
    PO = auto()
    DC = auto()

    @staticmethod
    def from_name(name: str) -> 'RC5':
        """Look up an RC5 transition by its name.

        Parameters
        ----------
        name: str
            The name of the RC5 transition (e.g. ``"PO"``).

        Returns
        -------
        RC5
            The transition whose name equals `name`.

        Raises
        ------
        ValueError
            If no RC5 transition has the given name.

        Examples
        --------
        A transition is accessed directly as an attribute of the enumeration.

        >>> RC5.PPi
        <RC5.PPi: 3>

        Iterating over the enumeration yields every available transition.

        >>> for transition in RC5:
        ...     print(transition)
        ...
        EQ
        PP
        PPi
        PO
        DC

        When only the name of a transition is available, the transition itself
        is retrieved with the `from_name` static method.

        >>> RC5.from_name("PO")
        <RC5.PO: 4>
        >>> RC5.from_name("EQ")
        <RC5.EQ: 1>

        A name that matches no transition raises a `ValueError` exception.

        >>> RC5.from_name("NN")
        Traceback (most recent call last):
        ValueError: Unable to find a transition named "NN"!
        """
        # Scan members in declaration order and keep the first whose name matches.
        found = next((member for member in RC5 if member.name == name), None)
        if found is None:
            raise ValueError(f"Unable to find a transition named \"{name}\"!")
        return found

    @staticmethod
    def includes(name: str) -> bool:
        """Tell whether `name` designates one of the RC5 transitions.

        Parameters
        ----------
        name: str
            The candidate transition name.

        Returns
        -------
        bool
            True when :meth:`from_name` succeeds for `name`; False when it raises
            a `ValueError`.
        """
        try:
            RC5.from_name(name)
        except ValueError:
            return False
        return True

    def __str__(self) -> str:
        """Return the bare member name (e.g. ``"PPi"``)."""
        return self.name
def __check_data(data: dict[str, Any], key: str, value: Any) -> bool:
    """Check whether a data dictionary satisfies a key/value condition.

    If `key` is absent from `data` the condition is considered satisfied.
    When `value` is an iterable other than a string or bytes object, the check
    tests membership; otherwise it tests equality.

    Parameters
    ----------
    data: dict[str, Any]
        The node or edge data dictionary to inspect.
    key: str
        The attribute key to check.
    value: Any
        The expected value, or an iterable (list, tuple, set, range, ...) of
        accepted values. Strings and bytes are always treated as a single value.

    Returns
    -------
    bool
        True if the condition is satisfied; False otherwise.
    """
    if key not in data:
        return True
    # Strings and bytes are iterable, but here they denote one expected value:
    # treating them as collections would turn the test into a substring check
    # (e.g. region="R10" would accept a node whose region is "R1") and would
    # raise TypeError when the stored attribute is not itself a string.
    if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        return data[key] in value
    return data[key] == value
def subgraph_nodes(graph: nx.Graph, **conditions: Any) -> nx.Graph:
    """Take the node-induced subgraph whose nodes satisfy every condition.

    Parameters
    ----------
    graph: nx.Graph
        The initial graph.
    conditions: dict[str, Any]
        Conditions on the node data, given as keyword arguments. Each value is
        either a single expected value or an iterable of accepted values. A node
        that lacks a conditioned attribute is not excluded by that condition.

    Returns
    -------
    nx.Graph
        The subgraph (``graph.subgraph``) induced by the matching nodes.

    Examples
    --------
    >>> G = nx.Graph()
    >>> G.add_nodes_from([(1, dict(a=0, b=1)), (2, dict(a=2, b=1)), (3, dict(a=2, b=2)), (4, dict(a=2, b=1))])
    >>> G.add_edges_from([(1, 2), (3, 4), (1, 4)])
    >>> subgraph_nodes(G).nodes
    NodeView((1, 2, 3, 4))
    >>> subgraph_nodes(G, a=0).nodes
    NodeView((1,))
    >>> subgraph_nodes(G, b=1).nodes
    NodeView((1, 2, 4))
    >>> subgraph_nodes(G, a=2).nodes
    NodeView((2, 3, 4))
    >>> subgraph_nodes(G, a=2, b=2).nodes
    NodeView((3,))
    >>> subgraph_nodes(G, b=(1, 2), a=2).nodes
    NodeView((2, 3, 4))
    >>> subgraph_nodes(G, b=range(1, 3)).nodes
    NodeView((1, 2, 3, 4))
    """
    def is_selected(node_data: dict[str, Any]) -> bool:
        # A node is kept only when every keyword condition holds on its data.
        return all(__check_data(node_data, key, value) for key, value in conditions.items())

    kept_nodes = [node for node, node_data in graph.nodes(data=True) if is_selected(node_data)]
    return graph.subgraph(kept_nodes)
def subgraph_edges(graph: nx.Graph, **conditions: Any) -> nx.Graph:
    """Take the edge-induced subgraph whose edges satisfy every condition.

    Parameters
    ----------
    graph: nx.Graph
        The initial graph.
    conditions: dict[str, Any]
        Conditions on the edge data, given as keyword arguments. Each value is
        either a single expected value or an iterable of accepted values. An edge
        that lacks a conditioned attribute is not excluded by that condition.

    Returns
    -------
    nx.Graph
        The subgraph (``graph.edge_subgraph``) made of the matching edges and
        their endpoints; nodes without any matching incident edge are dropped.
    """
    def is_selected(edge_data: dict[str, Any]) -> bool:
        # An edge is kept only when every keyword condition holds on its data.
        return all(__check_data(edge_data, key, value) for key, value in conditions.items())

    kept_edges = [(source, target)
                  for source, target, edge_data in graph.edges(data=True)
                  if is_selected(edge_data)]
    return graph.edge_subgraph(kept_edges)
class SpatioTemporalGraph(nx.DiGraph):
    """A spatio-temporal graph built on a directed NetworkX graph.

    Nodes carry brain area/region metadata; spatial edges carry correlation values;
    temporal edges carry :class:`RC5` transition types.

    Parameters
    ----------
    graph: nx.Graph, optional
        An existing NetworkX graph to initialise from.
    areas: pandas.DataFrame, optional
        A DataFrame describing brain areas with columns ``Name_Area`` and ``Name_Region``,
        indexed by ``Id_Area``.
    """

    def __init__(self, graph: nx.Graph = None, areas: pd.DataFrame = None) -> None:
        """Initialise the spatio-temporal graph.

        Parameters
        ----------
        graph: nx.Graph, optional
            An existing NetworkX graph, passed to the ``nx.DiGraph`` constructor to
            populate this graph.
        areas: pandas.DataFrame, optional
            A DataFrame describing brain areas (index: ``Id_Area``).
        """
        super().__init__(graph)
        self.areas = areas

    @property
    def time_range(self) -> range:
        """Get the time range covered by the spatio-temporal graph.

        Returns
        -------
        range
            ``range(max_time + 1)`` where ``max_time`` is the graph attribute
            ``self.graph['max_time']``.

        Raises
        ------
        KeyError
            If the graph has no ``max_time`` attribute.
        """
        return range(self.graph['max_time'] + 1)

    def sub(self, **conditions: Any) -> 'SpatioTemporalGraph':
        """Take the subgraph matching the given node and edge conditions.

        Conditions on ``t``, ``areas`` and ``region`` are applied to the node data
        with :func:`subgraph_nodes`; conditions on ``type`` and ``transition`` are
        then applied to the edge data with :func:`subgraph_edges` (which drops
        nodes without a matching incident edge). A condition on ``transition``
        also forces ``type='temporal'``. Any other keyword is ignored.

        Parameters
        ----------
        **conditions: Any
            Single values or iterables of accepted values, keyed by attribute name.

        Returns
        -------
        SpatioTemporalGraph
            A new graph, sharing ``self.areas``, holding the matching part of this graph.
        """
        node_keys = {'t', 'areas', 'region'}
        edge_keys = {'type', 'transition'}
        cond_nodes = {key: value for key, value in conditions.items() if key in node_keys}
        cond_edges = {key: value for key, value in conditions.items() if key in edge_keys}
        # Transitions are carried by temporal edges, so filtering on a transition
        # also restricts the edges to the temporal ones.
        if 'transition' in cond_edges:
            cond_edges['type'] = 'temporal'

        graph = self
        if cond_nodes:
            graph = subgraph_nodes(graph, **cond_nodes)
        if cond_edges:
            graph = subgraph_edges(graph, **cond_edges)
        return SpatioTemporalGraph(graph, self.areas)

    def sub_spatial(self) -> 'SpatioTemporalGraph':
        """Take the subgraph made of the spatial edges (``type='spatial'``)."""
        return self.sub(type='spatial')

    def sub_temporal(self) -> 'SpatioTemporalGraph':
        """Take the subgraph made of the temporal edges (``type='temporal'``)."""
        return self.sub(type='temporal')

    def __eq__(self, other: object) -> bool:
        """Test equality with another spatio-temporal graph.

        Two graphs are equal when their nodes, edges (including data) and graph
        attributes are equal according to ``nx.utils.graphs_equal`` and their
        areas DataFrames are equal (two missing ``areas`` compare equal).
        Since ``__eq__`` is defined without ``__hash__``, instances are unhashable.

        Parameters
        ----------
        other: object
            The object to compare against.

        Returns
        -------
        bool
            True if both graphs are structurally and data-wise identical;
            ``NotImplemented`` when `other` is not a SpatioTemporalGraph.
        """
        if not isinstance(other, SpatioTemporalGraph):
            # Let Python fall back to the reflected comparison / identity check
            # instead of failing on a missing ``areas`` attribute.
            return NotImplemented
        if not nx.utils.graphs_equal(self, other):
            return False
        if self.areas is None or other.areas is None:
            return self.areas is other.areas
        return self.areas.equals(other.areas)

    def __str__(self) -> str:
        """Return a human-readable summary of the spatio-temporal graph.

        A missing ``areas`` DataFrame is reported as zero areas and regions; edges
        without a ``type`` attribute are counted neither as spatial nor temporal.
        """
        if self.areas is None:
            nb_areas = nb_regions = 0
        else:
            nb_areas = len(self.areas)
            nb_regions = len(set(self.areas['Name_Region']))
        # ``edges(data='type')`` yields None for edges lacking the attribute.
        edge_types = [edge_type for _, _, edge_type in self.edges(data='type')]
        nb_spatial = edge_types.count('spatial')
        nb_temporal = edge_types.count('temporal')
        return f"SpatioTemporalGraph(#areas={nb_areas}, #regions={nb_regions}, "\
            f"#nodes={len(self.nodes)}, #spatial edges={nb_spatial}, "\
            f"#temporal edges={nb_temporal})"

    def __repr__(self) -> str:
        """Return the canonical string representation (same as __str__)."""
        return str(self)
def __data_almost_equal(data1: dict[str, Any], data2: dict[str, Any],
                        rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool:
    """Check if two data dictionaries are almost equal, with tolerance on numeric values.

    Parameters
    ----------
    data1: dict[str, Any]
        First data dictionary.
    data2: dict[str, Any]
        Second data dictionary.
    rel_tol: float, optional
        Relative tolerance for numeric comparisons (default 1e-9).
    abs_tol: float, optional
        Absolute tolerance for numeric comparisons (default 0.0).

    Returns
    -------
    bool
        True if both dicts have the same keys and their values are equal, using
        ``math.isclose`` when both values are real numbers. Values of other
        types (including a number paired with a non-number, or complex numbers)
        are compared with ``==``.
    """
    if len(data1) != len(data2):
        return False
    for key, value1 in data1.items():
        if key not in data2:
            return False
        value2 = data2[key]
        if isinstance(value1, Number) and isinstance(value2, Number):
            try:
                close = isclose(value1, value2, rel_tol=rel_tol, abs_tol=abs_tol)
            except TypeError:
                # math.isclose only accepts real numbers (it rejects complex):
                # fall back to exact equality for the other numeric types.
                close = value1 == value2
            if not close:
                return False
        elif value1 != value2:
            # Also covers a number paired with a non-number, which previously
            # made math.isclose raise TypeError instead of returning False.
            return False
    return True
def are_st_graphs_close(graph1: SpatioTemporalGraph, graph2: SpatioTemporalGraph) -> bool:
    """Test if two spatio-temporal graphs are equal with some tolerance on numerical values.

    Node and edge sets are compared regardless of insertion order. Node and edge
    data are compared with a tolerance on real numeric values, and the areas
    DataFrames must be equal (two missing ``areas`` compare equal).

    Parameters
    ----------
    graph1: SpatioTemporalGraph
        The first spatio-temporal graph to compare.
    graph2: SpatioTemporalGraph
        The second spatio-temporal graph to compare.

    Returns
    -------
    bool
        True if graphs are almost equal; False otherwise.

    Examples
    --------
    >>> g1 = nx.DiGraph()
    >>> g1.add_nodes_from([(1, dict(t=0, areas={1, 2}, region="R1", internal_strength=0.8)),
    ...                    (2, dict(t=0, areas={3}, region="R1", internal_strength=1)),
    ...                    (3, dict(t=1, areas={1, 2, 3}, region="R1", internal_strength=0.9))])
    >>> g1.add_edges_from([(1, 3, dict(transition=RC5.PP, type='temporal'))])
    >>> a1 = pd.DataFrame({'Id_Area': [1, 2, 3], 'Name_Area': ["A1", "A2", "A3"], 'Name_Region': ["R1", "R1", "R1"]})
    >>> st_g1 = SpatioTemporalGraph(g1, a1)
    >>> g2 = nx.DiGraph()
    >>> g2.add_nodes_from([(1, dict(t=0, areas={1, 2}, region="R1", internal_strength=0.7999999999999)),
    ...                    (2, dict(t=0, areas={3}, region="R1", internal_strength=1.000000000001)),
    ...                    (3, dict(t=1, areas={1, 2, 3}, region="R1", internal_strength=0.8999999999999))])
    >>> g2.add_edges_from([(1, 3, dict(transition=RC5.PP, type='temporal'))])
    >>> a2 = pd.DataFrame({'Id_Area': [1, 2, 3], 'Name_Area': ["A1", "A2", "A3"], 'Name_Region': ["R1", "R1", "R1"]})
    >>> st_g2 = SpatioTemporalGraph(g2, a2)
    >>> g3 = nx.DiGraph()
    >>> g3.add_nodes_from([(1, dict(t=0, areas={1, 2}, region="R1", internal_strength=0.8)),
    ...                    (2, dict(t=1, areas={1, 2}, region="R1", internal_strength=0.8))])
    >>> g3.add_edges_from([(1, 2, dict(transition=RC5.EQ, type='temporal'))])
    >>> a3 = pd.DataFrame({'Id_Area': [1, 2], 'Name_Area': ["A1", "A2"], 'Name_Region': ["R1", "R1"]})
    >>> st_g3 = SpatioTemporalGraph(g3, a3)
    >>> are_st_graphs_close(st_g1, st_g1)
    True
    >>> are_st_graphs_close(st_g1, st_g2)
    True
    >>> are_st_graphs_close(st_g1, st_g3)
    False

    The insertion order of the nodes does not matter.

    >>> g4 = nx.DiGraph()
    >>> g4.add_nodes_from(reversed(list(g1.nodes(data=True))))
    >>> g4.add_edges_from(g1.edges(data=True))
    >>> are_st_graphs_close(st_g1, SpatioTemporalGraph(g4, a1))
    True
    """
    # Graphs are sets of nodes and edges: compare them order-independently, as
    # SpatioTemporalGraph.__eq__ (nx.utils.graphs_equal) does.
    if set(graph1.nodes) != set(graph2.nodes):
        return False
    if set(graph1.edges) != set(graph2.edges):
        return False
    if not all(__data_almost_equal(graph1.nodes[node], graph2.nodes[node])
               for node in graph1.nodes):
        return False
    if not all(__data_almost_equal(graph1.edges[edge], graph2.edges[edge])
               for edge in graph1.edges):
        return False
    if graph1.areas is None or graph2.areas is None:
        return graph1.areas is graph2.areas
    return graph1.areas.equals(graph2.areas)