Source code for fstg_toolkit.graph

# Copyright 2025 ICube (University of Strasbourg - CNRS)
# author: Julien PONTABRY (ICube)
#
# This software is a computer program whose purpose is to provide a toolkit
# to model, process and analyze the longitudinal reorganization of brain
# connectivity data, as functional MRI for instance.
#
# This software is governed by the CeCILL-B license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/or redistribute the software under the terms of the CeCILL-B
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.

"""Defines spatio-temporal graphs and related structures."""
from collections.abc import Iterable, Iterator
from enum import Enum, auto, unique
from math import isclose
from numbers import Number
from typing import Any

import networkx as nx
import pandas as pd


@unique
class RC5(Enum):
    """Defines an RC5 temporal transition."""

    EQ = auto()
    PP = auto()
    PPi = auto()
    PO = auto()
    DC = auto()

    @staticmethod
    def from_name(name: str) -> 'RC5':
        """Find the RC5 transition from its name.

        Parameters
        ----------
        name: string
            The name of the RC5 transition.

        Returns
        -------
        RC5
            The corresponding RC5 transition if it has been found.

        Raises
        ------
        ValueError:
            if no RC5 transition can be found with the given name.

        Examples
        --------
        Accessing a transition is as simple as typing its name.

        >>> RC5.PPi
        <RC5.PPi: 3>

        To access all the available transitions, one can iterate on the RC5 enumeration.

        >>> for transition in RC5:
        ...     print(transition)
        ...
        EQ
        PP
        PPi
        PO
        DC

        When only the transition name is available, the transition itself can be retrieved
        from it with the `from_name` static method.

        >>> RC5.from_name("PO")
        <RC5.PO: 4>
        >>> RC5.from_name("EQ")
        <RC5.EQ: 1>

        When the provided name does not match any available transition, a `ValueError` is raised.

        >>> RC5.from_name("NN")
        Traceback (most recent call last):
        ValueError: Unable to find a transition named "NN"!
        """
        try:
            # Name-based lookup provided by Enum; @unique guarantees there are
            # no aliases, so this matches exactly one member or none.
            return RC5[name]
        except (KeyError, TypeError):
            # KeyError: unknown name; TypeError: unhashable name.
            raise ValueError(f"Unable to find a transition named \"{name}\"!") from None

    @staticmethod
    def includes(name: str) -> bool:
        """Check whether a name corresponds to a valid RC5 transition.

        Parameters
        ----------
        name: str
            The name to test.

        Returns
        -------
        bool
            True if `name` matches one of the RC5 transitions; False otherwise.
        """
        try:
            RC5.from_name(name)
        except ValueError:
            return False
        return True

    def __str__(self) -> str:
        """Return the bare transition name (e.g. ``"PPi"``)."""
        return self.name
def __check_data(data: dict[str, Any], key: str, value: Any) -> bool:
    """Check whether a data dictionary satisfies a key/value condition.

    If the key is absent from `data` the condition is considered satisfied.
    When `value` is a non-string iterable the check tests membership;
    otherwise (including for ``str``, ``bytes`` and ``bytearray`` values,
    which are treated as single values rather than as sequences of
    characters) it tests equality.

    Note that a one-shot iterator passed as `value` is consumed by the
    membership test, so callers evaluating many items should pass a
    re-iterable container instead.

    Parameters
    ----------
    data: dict[str, Any]
        The node or edge data dictionary to inspect.
    key: str
        The attribute key to check.
    value: Any
        The expected value or iterable of accepted values.

    Returns
    -------
    bool
        True if the condition is satisfied; False otherwise.
    """
    if key not in data:
        # Missing attributes do not exclude the item.
        return True
    # Strings are iterable, but a string condition means "equal to this
    # string", not "any substring/character of it".
    if isinstance(value, Iterable) and not isinstance(value, (str, bytes, bytearray)):
        return data[key] in value
    return data[key] == value
def subgraph_nodes(graph: nx.Graph, **conditions: Any) -> nx.Graph:
    """Take the subgraph that matches the conditions on the nodes.

    Parameters
    ----------
    graph: nx.Graph
        The initial graph.
    conditions: dict[str, any]
        The conditions on the nodes of the subgraph as keywords arguments.
        As value, any single value or iterable of values is supported.
        Nodes lacking a conditioned attribute are kept.

    Returns
    -------
    nx.Graph
        The subgraph matching the specified conditions.

    Example
    -------
    >>> G = nx.Graph()
    >>> G.add_nodes_from([(1, dict(a=0, b=1)), (2, dict(a=2, b=1)), (3, dict(a=2, b=2)), (4, dict(a=2, b=1))])
    >>> G.add_edges_from([(1, 2), (3, 4), (1, 4)])
    >>> subgraph_nodes(G).nodes
    NodeView((1, 2, 3, 4))
    >>> subgraph_nodes(G, a=0).nodes
    NodeView((1,))
    >>> subgraph_nodes(G, b=1).nodes
    NodeView((1, 2, 4))
    >>> subgraph_nodes(G, a=2).nodes
    NodeView((2, 3, 4))
    >>> subgraph_nodes(G, a=2, b=2).nodes
    NodeView((3,))
    >>> subgraph_nodes(G, b=(1, 2), a=2).nodes
    NodeView((2, 3, 4))
    >>> subgraph_nodes(G, b=range(1, 3)).nodes
    NodeView((1, 2, 3, 4))
    >>> subgraph_nodes(G, a=iter([0, 2])).nodes
    NodeView((1, 2, 3, 4))
    """
    # A one-shot iterator would be progressively exhausted by the membership
    # tests on the first nodes; materialise it once so every node sees all values.
    conditions = {key: tuple(value) if isinstance(value, Iterator) else value
                  for key, value in conditions.items()}
    selected = [node for node, data in graph.nodes(data=True)
                if all(__check_data(data, key, value) for key, value in conditions.items())]
    return graph.subgraph(selected)
def subgraph_edges(graph: nx.Graph, **conditions: Any) -> nx.Graph:
    """Take the subgraph induced by edges that match the conditions.

    Parameters
    ----------
    graph: nx.Graph
        The initial graph.
    conditions: dict[str, Any]
        Conditions on the edge data as keyword arguments. As value, any single
        value or iterable of values is supported. Edges lacking a conditioned
        attribute are kept.

    Returns
    -------
    nx.Graph
        The edge-induced subgraph matching the specified conditions; only the
        nodes incident to a selected edge are part of it.

    Example
    -------
    >>> G = nx.Graph()
    >>> G.add_edges_from([(1, 2, dict(type='spatial')), (2, 3, dict(type='temporal')), (3, 4, dict(type='spatial'))])
    >>> subgraph_edges(G, type='spatial').edges
    EdgeView([(1, 2), (3, 4)])
    """
    # A one-shot iterator would be progressively exhausted by the membership
    # tests on the first edges; materialise it once so every edge sees all values.
    conditions = {key: tuple(value) if isinstance(value, Iterator) else value
                  for key, value in conditions.items()}
    selected = [(n1, n2) for n1, n2, data in graph.edges(data=True)
                if all(__check_data(data, key, value) for key, value in conditions.items())]
    return graph.edge_subgraph(selected)
class SpatioTemporalGraph(nx.DiGraph):
    """A spatio-temporal graph wrapping a directed NetworkX graph.

    Nodes carry brain area/region metadata; spatial edges carry correlation
    values; temporal edges carry :class:`RC5` transition types.

    Parameters
    ----------
    graph: nx.Graph, optional
        An existing NetworkX graph to initialise from.
    areas: pandas.DataFrame, optional
        A DataFrame describing brain areas with columns ``Name_Area`` and
        ``Name_Region``, indexed by ``Id_Area``.
    """

    def __init__(self, graph: nx.Graph = None, areas: pd.DataFrame = None) -> None:
        """Initialise the spatio-temporal graph.

        Parameters
        ----------
        graph: nx.Graph, optional
            An existing NetworkX graph to initialise from (its nodes, edges
            and graph attributes are copied by ``nx.DiGraph``).
        areas: pandas.DataFrame, optional
            A DataFrame describing brain areas (index: ``Id_Area``).
        """
        super().__init__(graph)
        self.areas = areas

    @property
    def time_range(self) -> range:
        """Get the time range covered by the spatio-temporal graph.

        Returns
        -------
        range
            ``range(max_time + 1)``, where ``max_time`` is the graph-level
            attribute of the same name.

        Raises
        ------
        KeyError
            If the graph has no ``max_time`` attribute.
        """
        return range(self.graph['max_time'] + 1)

    def sub(self, **conditions) -> 'SpatioTemporalGraph':
        """Take the subgraph matching the specified conditions.

        Node conditions (keys ``t``, ``areas`` and ``region``) are applied first
        with :func:`subgraph_nodes`, then edge conditions (keys ``type`` and
        ``transition``) with :func:`subgraph_edges`; see those functions for the
        accepted values. Any other keyword is ignored. Giving a ``transition``
        condition forces ``type='temporal'``.

        Parameters
        ----------
        conditions: dict[str, Any]
            The node and/or edge conditions as keyword arguments.

        Returns
        -------
        SpatioTemporalGraph
            A new spatio-temporal graph sharing this graph's ``areas``. When an
            edge condition is given, only nodes incident to a selected edge remain.
        """
        node_keys = {'t', 'areas', 'region'}
        edge_keys = {'type', 'transition'}
        cond_nodes = {key: value for key, value in conditions.items() if key in node_keys}
        cond_edges = {key: value for key, value in conditions.items() if key in edge_keys}
        if 'transition' in cond_edges:
            # Transitions only exist on temporal edges.
            cond_edges['type'] = 'temporal'

        graph = self
        if cond_nodes:
            graph = subgraph_nodes(graph, **cond_nodes)
        if cond_edges:
            graph = subgraph_edges(graph, **cond_edges)

        return SpatioTemporalGraph(graph, self.areas)

    def sub_spatial(self) -> 'SpatioTemporalGraph':
        """Take the subgraph induced by the spatial edges (``type='spatial'``)."""
        return self.sub(type='spatial')

    def sub_temporal(self) -> 'SpatioTemporalGraph':
        """Take the subgraph induced by the temporal edges (``type='temporal'``)."""
        return self.sub(type='temporal')

    def __eq__(self, other: object) -> bool:
        """Test equality with another spatio-temporal graph.

        Two graphs are equal when their nodes, edges (including data) and areas
        DataFrames are all equal. Two missing (``None``) areas are equal.

        Parameters
        ----------
        other: SpatioTemporalGraph
            The graph to compare against.

        Returns
        -------
        bool
            True if both graphs are structurally and data-wise identical.
            ``NotImplemented`` is returned for objects that are not
            spatio-temporal graphs, so Python falls back to identity.
        """
        if not isinstance(other, SpatioTemporalGraph):
            return NotImplemented
        if not nx.utils.graphs_equal(self, other):
            return False
        if self.areas is None or other.areas is None:
            return self.areas is None and other.areas is None
        return self.areas.equals(other.areas)

    def __str__(self) -> str:
        """Return a human-readable summary of the spatio-temporal graph."""
        if self.areas is None:
            n_areas = n_regions = 0
        else:
            n_areas = len(self.areas)
            n_regions = len(set(self.areas['Name_Region']))
        # Edges without a 'type' attribute yield None and are not counted.
        edge_types = [edge_type for _, _, edge_type in self.edges(data='type')]
        n_spatial = edge_types.count('spatial')
        n_temporal = edge_types.count('temporal')
        return f"SpatioTemporalGraph(#areas={n_areas}, #regions={n_regions}, "\
               f"#nodes={len(self.nodes)}, #spatial edges={n_spatial}, "\
               f"#temporal edges={n_temporal})"

    def __repr__(self) -> str:
        """Return the canonical string representation (same as __str__)."""
        return str(self)
def __data_almost_equal(data1: dict[str, Any], data2: dict[str, Any],
                        rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool:
    """Check if two data dictionaries are almost equal, with tolerance on numeric values.

    Parameters
    ----------
    data1: dict[str, any]
        First data dictionary.
    data2: dict[str, any]
        Second data dictionary.
    rel_tol: float, optional
        Relative tolerance for numeric comparisons (default 1e-9).
    abs_tol: float, optional
        Absolute tolerance for numeric comparisons (default 0.0).

    Returns
    -------
    bool
        True if both dicts have the same keys and their values are equal
        (with numeric tolerance when both values are numbers). A number
        compared with a non-number is simply not equal.
    """
    if data1.keys() != data2.keys():
        return False
    for key, value1 in data1.items():
        value2 = data2[key]
        # Tolerance only makes sense when both sides are numeric; otherwise
        # fall back to plain equality instead of letting isclose raise.
        if isinstance(value1, Number) and isinstance(value2, Number):
            if not isclose(value1, value2, rel_tol=rel_tol, abs_tol=abs_tol):
                return False
        elif value1 != value2:
            return False
    return True
def are_st_graphs_close(graph1: SpatioTemporalGraph, graph2: SpatioTemporalGraph,
                        rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool:
    """Test if two spatio-temporal graphs are equal with some tolerance on numerical values.

    Nodes and edges are compared as sets (insertion order does not matter),
    their data dictionaries are compared with tolerance on numeric values, and
    the areas DataFrames must be exactly equal (two missing areas are equal).

    Parameters
    ----------
    graph1: SpatioTemporalGraph
        The first spatio-temporal graph to compare.
    graph2: SpatioTemporalGraph
        The second spatio-temporal graph to compare.
    rel_tol: float, optional
        Relative tolerance for numeric data comparisons (default 1e-9).
    abs_tol: float, optional
        Absolute tolerance for numeric data comparisons (default 0.0).

    Returns
    -------
    bool
        True if graphs are almost equal; false otherwise.

    Examples
    --------
    >>> g1 = nx.DiGraph()
    >>> g1.add_nodes_from([(1, dict(t=0, areas={1, 2}, region="R1", internal_strength=0.8)),
    ...                    (2, dict(t=0, areas={3}, region="R1", internal_strength=1)),
    ...                    (3, dict(t=1, areas={1, 2, 3}, region="R1", internal_strength=0.9))])
    >>> g1.add_edges_from([(1, 3, dict(transition=RC5.PP, type='temporal'))])
    >>> a1 = pd.DataFrame({'Id_Area': [1, 2, 3], 'Name_Area': ["A1", "A2", "A3"], 'Name_Region': ["R1", "R1", "R1"]})
    >>> st_g1 = SpatioTemporalGraph(g1, a1)
    >>> g2 = nx.DiGraph()
    >>> g2.add_nodes_from([(1, dict(t=0, areas={1, 2}, region="R1", internal_strength=0.7999999999999)),
    ...                    (2, dict(t=0, areas={3}, region="R1", internal_strength=1.000000000001)),
    ...                    (3, dict(t=1, areas={1, 2, 3}, region="R1", internal_strength=0.8999999999999))])
    >>> g2.add_edges_from([(1, 3, dict(transition=RC5.PP, type='temporal'))])
    >>> a2 = pd.DataFrame({'Id_Area': [1, 2, 3], 'Name_Area': ["A1", "A2", "A3"], 'Name_Region': ["R1", "R1", "R1"]})
    >>> st_g2 = SpatioTemporalGraph(g2, a2)
    >>> g3 = nx.DiGraph()
    >>> g3.add_nodes_from([(1, dict(t=0, areas={1, 2}, region="R1", internal_strength=0.8)),
    ...                    (2, dict(t=1, areas={1, 2}, region="R1", internal_strength=0.8))])
    >>> g3.add_edges_from([(1, 2, dict(transition=RC5.EQ, type='temporal'))])
    >>> a3 = pd.DataFrame({'Id_Area': [1, 2], 'Name_Area': ["A1", "A2"], 'Name_Region': ["R1", "R1"]})
    >>> st_g3 = SpatioTemporalGraph(g3, a3)
    >>> are_st_graphs_close(st_g1, st_g1)
    True
    >>> are_st_graphs_close(st_g1, st_g2)
    True
    >>> are_st_graphs_close(st_g1, st_g3)
    False

    The insertion order of nodes does not matter.

    >>> g4 = nx.DiGraph()
    >>> g4.add_nodes_from(reversed(list(g1.nodes(data=True))))
    >>> g4.add_edges_from(g1.edges(data=True))
    >>> are_st_graphs_close(st_g1, SpatioTemporalGraph(g4, a1))
    True
    """
    # Same node set, regardless of insertion order (as nx.utils.graphs_equal).
    if len(graph1) != len(graph2) or any(node not in graph2 for node in graph1):
        return False

    # Same edge set, regardless of insertion order.
    if graph1.number_of_edges() != graph2.number_of_edges() or \
            any(not graph2.has_edge(u, v) for u, v in graph1.edges):
        return False

    if not all(__data_almost_equal(graph1.nodes[node], graph2.nodes[node],
                                   rel_tol=rel_tol, abs_tol=abs_tol)
               for node in graph1):
        return False

    if not all(__data_almost_equal(graph1.edges[u, v], graph2.edges[u, v],
                                   rel_tol=rel_tol, abs_tol=abs_tol)
               for u, v in graph1.edges):
        return False

    areas1, areas2 = graph1.areas, graph2.areas
    if areas1 is None or areas2 is None:
        return areas1 is None and areas2 is None
    return areas1.equals(areas2)