# Source code for fstg_toolkit.simulation

# Copyright 2025 ICube (University of Strasbourg - CNRS)
# author: Julien PONTABRY (ICube)
#
# This software is a computer program whose purpose is to provide a toolkit
# to model, process and analyze the longitudinal reorganization of brain
# connectivity data, as functional MRI for instance.
#
# This software is governed by the CeCILL-B license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/or redistribute the software under the terms of the CeCILL-B
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.

"""Defines the tools to simulate spatio-temporal graphs and functional connectivity data."""
from dataclasses import dataclass
from functools import reduce
from itertools import combinations
from typing import Iterable

import networkx as nx
import numpy as np
import pandas as pd
from networkx.classes.reportviews import NodeView

from .graph import RC5, SpatioTemporalGraph, subgraph_nodes


def _fill_matrix(connections, correlations, matrix):
    """Write symmetric correlation values into ``matrix`` in-place.

    Every connection ``(a, b)`` uses 1-based area identifiers; its value is
    stored at both ``matrix[a - 1, b - 1]`` and ``matrix[b - 1, a - 1]`` so the
    matrix stays symmetric. Connections and correlations are consumed in
    lockstep, stopping at the shorter of the two sequences.

    Parameters
    ----------
    connections: list[tuple[int, int]]
        Pairs of area indices (1-based) to connect.
    correlations: list[float]
        Correlation value for each connection pair.
    matrix: numpy.ndarray
        The square correlation matrix to update (modified in-place).
    """
    for pair, value in zip(connections, correlations):
        first_area, second_area = pair
        # convert 1-based area identifiers to 0-based matrix indices
        row, col = first_area - 1, second_area - 1
        matrix[row, col] = value
        matrix[col, row] = value


@dataclass(frozen=True)
class _CorrelationMatrixNetworksEdgesFiller:
    """Fills intra-network (within-region) correlations in a correlation matrix.

    For each network (node) in a spatial graph, randomly samples a set of
    connected area pairs and assigns correlation values whose mean matches the
    network's internal strength. A network made of a single area has no pair
    of areas to connect and leaves the matrix untouched.

    Parameters
    ----------
    threshold: float
        The correlation threshold; intra-network correlations are constrained
        to remain above this value in absolute terms.
    rng: numpy.random.Generator
        A NumPy random number generator used for reproducible sampling.
    """

    threshold: float
    rng: np.random.Generator

    def __choose_connections(self, network: Iterable[int]) -> list[tuple[int, int]]:
        """Randomly select a connected set of area pairs within a network.

        Repeatedly samples random subsets of all pairwise combinations until a
        connected graph can be formed, guaranteeing that every area is reachable.

        Parameters
        ----------
        network: Iterable[int]
            Area identifiers (1-based) belonging to the network.

        Returns
        -------
        list[tuple[int, int]]
            A list of area-pair edges forming a connected subgraph of the
            network; empty when the network has fewer than two areas.
        """
        areas = list(network)

        # No pair exists: without this guard ``rng.integers(low=0, high=0)``
        # below would raise a ValueError for single-area networks.
        if len(areas) < 2:
            return []

        def build_graph(trial_connections: list[tuple[int, int]]) -> nx.Graph:
            trial = nx.Graph()
            trial.add_nodes_from(areas)
            trial.add_edges_from(trial_connections)
            return trial

        def draw_connections() -> list[tuple[int, int]]:
            candidates = list(combinations(areas, 2))

            if len(candidates) == 1:
                return candidates

            selected = []
            # At least ``len(areas) - 1`` edges are needed to span the network;
            # ``high`` is exclusive, so at most ``len(candidates) - 1`` are kept.
            nb_edges = self.rng.integers(low=len(areas) - 1, high=len(candidates))
            for _ in range(nb_edges):
                selected.append(candidates.pop(self.rng.integers(len(candidates))))

            return selected

        trial_graph = build_graph(draw_connections())

        while not nx.is_connected(trial_graph):
            trial_graph = build_graph(draw_connections())

        return list(trial_graph.edges)

    def __mean_corr_sampler(self, size: int, mean: float, max_attempts: int = 1_000) -> list[float]:
        """Sample ``size`` correlation values whose empirical mean equals ``mean``.

        Values are sampled uniformly within a symmetric interval around ``mean``
        and then shifted so that their mean matches exactly. The interval is
        constrained to keep all values on the same side of ``threshold`` and
        within the valid correlation range [-1, 1].

        Parameters
        ----------
        size: int
            Number of correlation values to sample.
        mean: float
            Target mean correlation value.
        max_attempts: int
            The maximal number of attempts to sample. Default is 1000

        Returns
        -------
        list[float]
            A list of ``size`` correlation values with the exact requested mean
            (empty when ``size`` is 0).

        Raises
        ------
        ValueError
            If ``abs(mean)`` is not within ``[threshold, 1]``.
        RuntimeError
            If no acceptable sample is found within ``max_attempts`` attempts.
        """
        if size == 0:
            return []

        def draw(low: float, high: float) -> np.ndarray:
            values = self.rng.uniform(low=low, high=high, size=size)
            # shift so that the empirical mean equals the target exactly
            return values + mean - values.mean()

        # Half-width of the sampling interval: stay beyond the threshold (on the
        # side of ``mean``) and never leave the correlation range [-1, 1].
        rad = min(abs(mean) - self.threshold, 1 - abs(mean))
        if rad < 0:
            raise ValueError(
                f"Cannot sample correlations with mean={mean}: its absolute value "
                f"must be within [{self.threshold}, 1]"
            )
        a, b = mean - rad, mean + rad

        samples = draw(a, b)
        attempts = 0
        while np.any((samples < a) | (samples > b)):
            samples = draw(a, b)
            attempts += 1
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Could not sample values in [{a}, {b}] with mean={mean} after {max_attempts} attempts"
                )

        return samples.tolist()

    def fill(self, spatial_graph: nx.DiGraph, matrix: np.ndarray) -> None:
        """Fill intra-network correlations for all networks in the spatial graph.

        Parameters
        ----------
        spatial_graph: nx.DiGraph
            A single time-point spatial graph whose nodes represent networks
            (each carrying ``areas`` and ``internal_strength`` attributes).
        matrix: numpy.ndarray
            The square correlation matrix to update in-place.
        """
        # read every network first so missing attributes fail before any write
        networks = [(data['areas'], data['internal_strength'])
                    for _, data in spatial_graph.nodes.items()]

        for network, mean_corr in networks:
            connections = self.__choose_connections(network)
            correlations = self.__mean_corr_sampler(len(connections), mean_corr)
            _fill_matrix(connections, correlations, matrix)


@dataclass(frozen=True)
class _CorrelationMatrixInterRegionEdgesFiller:
    """Fills inter-network (cross-region) correlations in a correlation matrix.

    For each spatial edge between two networks, randomly samples area pairs
    across the two networks and assigns correlation values whose maximum (in
    absolute terms) matches the edge's correlation attribute.

    Parameters
    ----------
    threshold: float
        The correlation threshold; inter-region correlations are constrained
        to remain above this value in absolute terms.
    rng: numpy.random.Generator
        A NumPy random number generator used for reproducible sampling.
    """

    threshold: float
    rng: np.random.Generator

    def __max_correlation_sampler(self, size: int, target: float, max_attempts: int = 1_000) -> list[float]:
        """Sample ``size`` correlation values whose extreme value equals ``target``.

        The extreme is the maximum for positive targets and the minimum for
        negative targets, so that the inter-region edge correlation is reproduced.
        Values are drawn between ``sign(target) * threshold`` and ``target``.

        Parameters
        ----------
        size: int
            Number of correlation values to sample.
        target: float
            The desired extreme correlation value.
        max_attempts: int
            The maximal number of attempts to sample. Default is 1000

        Returns
        -------
        list[float]
            A list of ``size`` correlation values with the exact requested extreme
            (empty when ``size`` is 0).

        Raises
        ------
        RuntimeError
            If no acceptable sample is found within ``max_attempts`` attempts
            (typically when ``abs(target)`` is below ``threshold`` and
            ``size > 1``, where shifting always pushes values out of range).
        """
        if size == 0:
            return []

        extreme = np.max if target >= 0 else np.min

        def draw(low: float, high: float) -> np.ndarray:
            values = self.rng.uniform(low=low, high=high, size=size)
            # shift so that the extreme value lands exactly on the target
            return values + target - extreme(values)

        thr = np.sign(target) * self.threshold
        mean = (thr + target) / 2
        rad = abs(target - mean)
        a, b = mean - rad, mean + rad

        samples = draw(a, b)
        attempts = 0
        while np.any((samples < a) | (samples > b)):
            samples = draw(a, b)
            attempts += 1
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Could not sample values in [{a}, {b}] with extreme={target} after {max_attempts} attempts"
                )

        return samples.tolist()

    def __choose_inter_region_connections(self, network1: set, network2: set) -> list[tuple[int, int]]:
        """Randomly select area pairs connecting two distinct networks.

        For each area in ``network1``, a Poisson-distributed (lambda = 1) number
        of areas from ``network2`` are selected as targets (at least one per
        source area, at most all areas of ``network2``).

        Parameters
        ----------
        network1: set[int]
            Area identifiers (1-based) of the source network.
        network2: set[int]
            Area identifiers (1-based) of the target network.

        Returns
        -------
        list[tuple[int, int]]
            A list of (source_area, target_area) pairs.
        """
        def choose_targets(source: int, targets: set, n: int) -> list[tuple[int, int]]:
            candidates = [(source, target) for target in targets]
            selected = []

            for _ in range(min(n, len(candidates))):
                selected.append(candidates.pop(self.rng.integers(len(candidates))))

            return selected

        connections_sizes = self.rng.poisson(lam=1, size=len(network1))
        connections = []

        for node, nb_connections in zip(network1, connections_sizes):
            connections += choose_targets(node, network2, max(1, nb_connections))

        return connections

    def fill(self, spatial_graph: nx.DiGraph, matrix: np.ndarray) -> None:
        """Fill inter-network correlations for all edges in the spatial graph.

        Parameters
        ----------
        spatial_graph: nx.DiGraph
            A single time-point spatial graph whose edges carry a ``correlation``
            attribute describing the inter-network correlation strength.
        matrix: numpy.ndarray
            The square correlation matrix to update in-place.
        """
        for (node1, node2), data in spatial_graph.edges.items():
            connections = self.__choose_inter_region_connections(
                spatial_graph.nodes[node1]['areas'], spatial_graph.nodes[node2]['areas'])
            correlations = self.__max_correlation_sampler(len(connections), data['correlation'])
            _fill_matrix(connections, correlations, matrix)


class CorrelationMatrixSequenceSimulator:
    """Simulate a sequence of correlation matrices from a spatio-temporal graph.

    Examples
    --------
    >>> graph = nx.DiGraph()
    >>> graph.add_node(1, t=0, areas={1, 2}, region='Region 1', internal_strength=0.98)
    >>> graph.add_node(2, t=0, areas={3, 4}, region='Region 2', internal_strength=-0.98)
    >>> graph.add_edge(1, 2, correlation=0.94, t=0, type='spatial')
    >>> graph.add_edge(2, 1, correlation=0.94, t=0, type='spatial')
    >>> graph.graph['min_time'] = 0
    >>> graph.graph['max_time'] = 0
    >>> areas = pd.DataFrame({'Id_Area': [1, 2, 3, 4],
    ...                       'Name_Area': ['A1', 'A2', 'A3', 'A4'],
    ...                       'Name_Region': ['R1', 'R1', 'R2', 'R2']})
    >>> areas.set_index('Id_Area', inplace=True)
    >>> simulator = CorrelationMatrixSequenceSimulator(SpatioTemporalGraph(graph, areas), threshold=0.4,
    ...                                                rng=np.random.default_rng(40))
    >>> matrix = simulator.simulate()
    >>> matrix.shape
    (1, 4, 4)
    >>> matrix  # doctest: +NORMALIZE_WHITESPACE
    array([[[ 1.        ,  0.98      ,  0.65453818,  0.94      ],
            [ 0.98      ,  1.        ,  0.85381682,  0.61873641],
            [ 0.65453818,  0.85381682,  1.        , -0.98      ],
            [ 0.94      ,  0.61873641, -0.98      ,  1.        ]]])
    """

    def __init__(self, graph: SpatioTemporalGraph, threshold: float = 0.4,
                 rng: np.random.Generator | None = None) -> None:
        """Initialise the simulator.

        Parameters
        ----------
        graph: SpatioTemporalGraph
            The reference spatio-temporal graph from which correlation matrices
            will be back-generated.
        threshold: float, optional
            Minimum absolute correlation required for spatial edges to be
            considered significant (default 0.4, must be in [0, 1]).
        rng: numpy.random.Generator, optional
            Random number generator for reproducible results. When omitted, a
            fresh generator from ``numpy.random.default_rng()`` is created for
            this instance.

        Raises
        ------
        ValueError
            If ``threshold`` is outside [0, 1].
        """
        self.graph = graph
        self.threshold = threshold
        # validate before building the helpers that depend on the threshold
        self.__init_validation__()

        # A ``default_rng()`` default argument would be evaluated once at import
        # time and silently shared by every simulator; create one per instance.
        self.__rng = np.random.default_rng() if rng is None else rng
        self.__network_edges_filler = _CorrelationMatrixNetworksEdgesFiller(self.threshold, self.__rng)
        self.__inter_region_edges_filler = _CorrelationMatrixInterRegionEdgesFiller(self.threshold, self.__rng)

    def __init_validation__(self):
        """Validate constructor parameters.

        Raises
        ------
        ValueError
            If ``threshold`` is not in [0, 1].
        """
        if self.threshold < 0 or self.threshold > 1:
            raise ValueError("The threshold must be within range [0, 1]!")

    def __simulate_corr_matrix(self, spatial_graph: nx.DiGraph) -> np.ndarray:
        """Generate a single correlation matrix from a one-time-point spatial graph.

        Fills intra-network and inter-network correlations, then randomises any
        remaining zero entries with sub-threshold uniform noise, drawn once per
        pair so that the matrix stays symmetric.

        Parameters
        ----------
        spatial_graph: nx.DiGraph
            The spatial subgraph at a single time point.

        Returns
        -------
        numpy.ndarray
            A square symmetric correlation matrix of shape ``(n_areas, n_areas)``.
        """
        matrix = np.eye(len(self.graph.areas))
        self.__network_edges_filler.fill(spatial_graph, matrix)
        self.__inter_region_edges_filler.fill(spatial_graph, matrix)

        # The fillers write symmetric pairs, so the unfilled entries are
        # symmetric too: draw noise for the strict upper triangle only and
        # mirror it (independent draws for (i, j) and (j, i) broke symmetry).
        rows, cols = np.nonzero(np.triu(matrix == 0, k=1))
        bound = self.threshold * 0.99
        noise = self.__rng.uniform(low=-bound, high=bound, size=rows.size)
        matrix[rows, cols] = noise
        matrix[cols, rows] = noise

        return matrix

    def simulate(self) -> np.ndarray:
        """Simulate the sequence of correlation matrices.

        Returns
        -------
        numpy.ndarray
            A 3D-shaped array that contains the correlation matrices for each
            time of the graph's time range.
        """
        return np.array([self.__simulate_corr_matrix(self.graph.sub(t=t))
                         for t in self.graph.time_range])
def __trans(sources: Iterable[int] | int, targets: Iterable[int] | int, kind: str) -> list[tuple[int, int, RC5]]:
    """Turn a compact transition description into ``(source, target, RC5)`` triples.

    Parameters
    ----------
    sources: int or Iterable[int]
        Source node(s). Must be an iterable for a ``'merge'`` transition.
    targets: int or Iterable[int]
        Target node(s). Must be an iterable for a ``'split'`` transition.
    kind: str
        Case-insensitive transition type: ``'split'`` maps to :attr:`RC5.PPi`,
        ``'merge'`` to :attr:`RC5.PP`, ``'eq'`` to :attr:`RC5.EQ`; any other
        string maps to :attr:`RC5.PO`.

    Returns
    -------
    list[tuple[int, int, RC5]]
        The expanded ``(source_node, target_node, transition)`` tuples.
    """
    normalized_kind = kind.lower()

    if normalized_kind == 'split':
        # one source spreading over several targets
        return [(sources, target, RC5.PPi) for target in targets]

    if normalized_kind == 'merge':
        # several sources collapsing onto a single target
        return [(source, targets, RC5.PP) for source in sources]

    transition = RC5.EQ if normalized_kind == 'eq' else RC5.PO
    return [(sources, targets, transition)]


def __def2areas(areas_def: tuple[int, int] | Iterable[int] | int) -> set[int]:
    """Expand a compact area definition into a set of area identifiers.

    Parameters
    ----------
    areas_def: tuple[int, int] | Iterable[int] | int
        - A 2-tuple ``(start, end)`` denotes the inclusive range
          ``{start, start+1, ..., end}``.
        - Any other iterable is turned into a set as is.
        - A single integer becomes a singleton set.

    Returns
    -------
    set[int]
        The matching set of area identifiers.
    """
    if isinstance(areas_def, tuple) and len(areas_def) == 2:
        lower, upper = areas_def
        return {*range(lower, upper + 1)}

    return set(areas_def) if isinstance(areas_def, Iterable) else {areas_def}
def generate_pattern(networks_list: list[list[tuple[tuple[int, int], int, float]]],
                     spatial_edges: list[tuple[int, int, float]],
                     temporal_edges: list[tuple[Iterable[int] | int, Iterable[int] | int, str]]) -> SpatioTemporalGraph:
    """Build a small spatio-temporal graph to be used as a pattern.

    Nodes are numbered from 1 in the order they appear in ``networks_list``
    (time instant by time instant). Each spatial edge is added in both
    directions; each temporal edge is expanded through its transition kind.

    Parameters
    ----------
    networks_list: list[list[tuple[tuple[int, int], int, float]]]
        For each time instant, the list of its networks, each given as a tuple
        of area definition, region id and internal strength.
    spatial_edges: list[tuple[int, int, float]]
        Spatial edges given as ``(source_node, target_node, correlation)``.
    temporal_edges: list[tuple[Iterable[int] | int, Iterable[int] | int, str]]
        Temporal edges given as ``(source(s), target(s), transition_kind)``.

    Returns
    -------
    SpatioTemporalGraph
        A spatio-temporal graph that can be used as a pattern.

    Example
    -------
    >>> pattern = generate_pattern(
    ...     networks_list=[[((1, 5), 1, -0.2), ((6, 7), 2, 0.3), ((8, 10), 2, 0.6)],
    ...                    [((1, 5), 1, 0.6), ((6, 10), 2, -0.5)]],
    ...     spatial_edges=[(1, 2, 0.45), (4, 5, 0.8)],
    ...     temporal_edges=[(1, 4, 'eq'), ((2, 3), 5, 'merge')])
    >>> pattern.nodes
    NodeView((1, 2, 3, 4, 5))
    >>> pattern.edges
    OutEdgeView([(1, 2), (1, 4), (2, 1), (2, 5), (3, 5), (4, 5), (5, 4)])
    >>> pattern.areas  # doctest: +NORMALIZE_WHITESPACE
            Name_Area Name_Region
    Id_Area
    1          Area 1    Region 1
    2          Area 2    Region 1
    3          Area 3    Region 1
    4          Area 4    Region 1
    5          Area 5    Region 1
    6          Area 6    Region 2
    7          Area 7    Region 2
    8          Area 8    Region 2
    9          Area 9    Region 2
    10        Area 10    Region 2
    """
    pattern_graph = nx.DiGraph()
    pattern_graph.graph['min_time'] = 0
    pattern_graph.graph['max_time'] = len(networks_list) - 1

    # area id -> region name (the last network mentioning an area wins)
    region_of_area = {}
    node_id = 0
    for t, networks in enumerate(networks_list):
        for areas_def, region_id, strength in networks:
            node_id += 1
            node_areas = __def2areas(areas_def)
            region_name = f"Region {region_id}"
            region_of_area.update(dict.fromkeys(node_areas, region_name))
            pattern_graph.add_node(node_id, t=t, areas=node_areas, region=region_name,
                                   internal_strength=strength)

    for source, target, corr in spatial_edges:
        for u, v in ((source, target), (target, source)):
            pattern_graph.add_edge(u, v, correlation=corr, type='spatial')

    for transition in temporal_edges:
        pattern_graph.add_edges_from(
            (source, target, {'transition': rc5, 'type': 'temporal'})
            for source, target, rc5 in __trans(*transition))

    area_ids = sorted(region_of_area)
    areas = pd.DataFrame({'Id_Area': area_ids,
                          'Name_Area': [f"Area {a}" for a in area_ids],
                          'Name_Region': [region_of_area[a] for a in area_ids]}).set_index('Id_Area')

    return SpatioTemporalGraph(pattern_graph, areas)
class SpatioTemporalGraphSimulator:
    """Simulator for spatio-temporal graphs.

    The simulator needs predefined patterns (created either manually or
    automatically) to generate a full spatio-temporal graph with those patterns
    included as instructed, optionally with in-between repeats.

    Examples
    --------
    >>> pattern1 = generate_pattern(
    ...     networks_list=[
    ...         [((1, 2), 1, 0.7), (3, 1, 1), ((4, 5), 2, -0.8)],
    ...         [((1, 3), 1, 0.8), ((4, 5), 2, -0.8)]],
    ...     spatial_edges=[(1, 3, 0.5), (4, 5, 0.6)],
    ...     temporal_edges=[((1, 2), 4, 'merge'), (3, 5, 'eq')])
    >>> pattern2 = generate_pattern(
    ...     networks_list=[
    ...         [((1, 3), 1, 0.8), ((4, 5), 2, -0.8)],
    ...         [((1, 2), 1, 0.7), (3, 1, 1), ((4, 5), 2, -0.8)]],
    ...     spatial_edges=[(1, 2, 0.6), (3, 5, 0.5)],
    ...     temporal_edges=[(1, (3, 4), 'split'), (2, 5, 'eq')])
    >>> simulator = SpatioTemporalGraphSimulator(p1=pattern1, p2=pattern2)
    >>> graph = simulator.simulate('p2', 3, 'p1')
    >>> graph.nodes
    NodeView((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
    >>> graph.edges
    OutEdgeView([(1, 2), (1, 3), (1, 4), (2, 1), (2, 5), (3, 5), (3, 6), (4, 7), (5, 3), (5, 8), (6, 8), (6, 9), (7, 10), (8, 6), (8, 11), (9, 11), (9, 12), (10, 13), (11, 9), (11, 14), (12, 14), (12, 15), (13, 16), (14, 12), (14, 17), (15, 17), (15, 18), (16, 18), (17, 15), (17, 19), (18, 19), (19, 18)])
    """

    def __init__(self, **patterns: SpatioTemporalGraph) -> None:
        """Initialise the simulator with a set of named patterns.

        Parameters
        ----------
        **patterns: SpatioTemporalGraph
            Keyword arguments mapping pattern names (str) to their
            :class:`~graph.SpatioTemporalGraph` definitions.
        """
        self.__patterns = patterns

    def _simulate_areas_descriptions(self, patterns: list[str | int]) -> pd.DataFrame:
        """Merge the areas DataFrames of all named patterns in the sequence.

        Integer entries (repeat counts) are ignored; only string entries that
        reference registered patterns contribute their areas.

        Parameters
        ----------
        patterns: list[str | int]
            The pattern sequence (same format accepted by :meth:`simulate`).

        Returns
        -------
        pandas.DataFrame
            Area descriptions from all referenced patterns, with one row per
            area identifier (the first description met is kept).
        """
        areas_descriptions = [self.__patterns[pattern].areas
                              for pattern in patterns if isinstance(pattern, str)]
        merged = pd.concat(areas_descriptions)
        # Deduplicate on the area identifier (the index): ``drop_duplicates``
        # only compares the name columns, which would drop distinct areas that
        # share names and keep conflicting rows for the same identifier.
        return merged[~merged.index.duplicated(keep='first')]

    @staticmethod
    def __shift_node_data(data: dict[str, object], dt: int) -> dict[str, object]:
        """Return a copy of a node's data dictionary with the time attribute shifted.

        Parameters
        ----------
        data: dict[str, object]
            Original node data (must contain a ``'t'`` key).
        dt: int
            Time offset to add.

        Returns
        -------
        dict[str, object]
            A shallow copy of ``data`` with ``t`` incremented by ``dt`` (other
            values, e.g. the ``areas`` set, are shared with the original).
        """
        tmp = dict(data)
        tmp['t'] += dt
        return tmp

    @staticmethod
    def __shift_nodes(nodes: NodeView, dt: int, k: int) -> list[tuple[int, dict[str, object]]]:
        """Shift node identifiers and time attributes for graph concatenation.

        Parameters
        ----------
        nodes: NodeView
            The nodes to shift.
        dt: int
            Time offset added to each node's ``t`` attribute.
        k: int
            Integer offset added to each node identifier to avoid collisions
            with existing nodes.

        Returns
        -------
        list[tuple[int, dict[str, object]]]
            A list of ``(new_node_id, shifted_data)`` pairs sorted by node id.
        """
        return [(n + k, SpatioTemporalGraphSimulator.__shift_node_data(d, dt))
                for n, d in sorted(nodes.items(), key=lambda x: x[0])]

    @staticmethod
    def __append_repeats(g: nx.DiGraph, last_out, count: int) -> None:
        """Append ``count`` EQ-linked copies of the last time-point of ``g`` (in-place).

        Parameters
        ----------
        g: nx.DiGraph
            The graph being assembled; its ``max_time`` is increased by ``count``.
        last_out:
            The subgraph of ``g`` at its last time point (as returned by
            ``subgraph_nodes``).
        count: int
            Number of repeats to append.

        Raises
        ------
        ValueError
            If ``count`` is negative (it would otherwise decrease ``max_time``).
        """
        if count < 0:
            raise ValueError(f"The number of repeats must be non-negative, got {count}!")

        source_nodes = sorted(last_out.nodes)
        spatial_edges = [(n1, n2, d) for (n1, n2), d in last_out.edges.items()
                         if d['type'] == 'spatial']

        previous_offset = 0
        for i in range(count):
            # Place the copies right after the current largest node id so they
            # never collide with existing nodes. This equals
            # ``(i + 1) * len(source_nodes)`` when the last time-point holds the
            # highest contiguous identifiers (as produced by ``generate_pattern``).
            offset = max(g.nodes) - source_nodes[0] + 1 if source_nodes else 0

            g.add_nodes_from(SpatioTemporalGraphSimulator.__shift_nodes(last_out.nodes, i + 1, offset))
            g.add_edges_from([edge
                              for n1, n2, d in spatial_edges
                              for edge in ((n1 + offset, n2 + offset, d), (n2 + offset, n1 + offset, d))])
            # link each node of the previous copy (or of the original time-point)
            # to its copy with an EQ transition
            g.add_edges_from([(n + previous_offset, n + offset, {'transition': RC5.EQ, 'type': 'temporal'})
                              for n in source_nodes])
            previous_offset = offset

        g.graph['max_time'] += count

    @staticmethod
    def __append_pattern(g: nx.DiGraph, last_out, pattern: SpatioTemporalGraph) -> None:
        """Append ``pattern`` after the last time-point of ``g`` (in-place).

        Parameters
        ----------
        g: nx.DiGraph
            The graph being assembled; its ``max_time`` is extended accordingly.
        last_out:
            The subgraph of ``g`` at its last time point (as returned by
            ``subgraph_nodes``).
        pattern: SpatioTemporalGraph
            The pattern to append.
        """
        # Shift identifiers so the pattern's smallest node comes right after the
        # current largest one (i.e. ``max(g.nodes)`` for 1-based patterns).
        k = max(g.nodes) - min(pattern.nodes, default=1) + 1
        dt = g.graph['max_time'] - g.graph['min_time'] + 1

        # add pattern (with time and nodes shifted appropriately)
        g.add_nodes_from(SpatioTemporalGraphSimulator.__shift_nodes(pattern.nodes, dt, k))
        g.add_edges_from([(e1 + k, e2 + k, d) for (e1, e2), d in pattern.edges.items()])
        g.graph['max_time'] += pattern.graph['max_time'] + 1

        # make the connection between last pattern and next one (nodes are
        # paired by sorted identifier)
        next_in = subgraph_nodes(pattern, t=pattern.graph['min_time'])
        for nout, nin in zip(sorted(last_out.nodes), sorted(next_in.nodes)):
            g.add_edge(nout, nin + k, transition=RC5.EQ, type='temporal')

    def _simulate_graph_from_patterns(self, patterns: list[str | int]) -> nx.DiGraph:
        """Build a directed graph by concatenating patterns and optional repeats.

        Patterns are chained sequentially: node identifiers and time steps are
        shifted to avoid collisions. An integer element inserts that many
        EQ-linked copies of the last time-point of the preceding pattern.

        Parameters
        ----------
        patterns: list[str | int]
            Sequence of pattern names (str) and repeat counts (int).

        Returns
        -------
        nx.DiGraph
            The assembled directed graph with updated ``min_time`` /
            ``max_time`` graph attributes.

        Raises
        ------
        ValueError
            If a sequence element is neither a ``str`` nor an ``int``, or if a
            repeat count is negative.
        """
        g = nx.DiGraph(self.__patterns[patterns[0]])

        for next_pattern in patterns[1:]:
            last_t = g.graph['max_time']
            last_out = subgraph_nodes(g, t=last_t)

            if isinstance(next_pattern, int):
                SpatioTemporalGraphSimulator.__append_repeats(g, last_out, next_pattern)
            elif isinstance(next_pattern, str):
                SpatioTemporalGraphSimulator.__append_pattern(g, last_out, self.__patterns[next_pattern])
            else:
                raise ValueError(f"pattern type {type(next_pattern)} "
                                 "is not recognized! It must be either "
                                 "int or str.")

        return g

    def simulate(self, *patterns: str | int) -> SpatioTemporalGraph:
        """Simulate the given sequence of patterns.

        Parameters
        ----------
        patterns: tuple[str | int]
            The sequence of patterns. A string references a pattern registered
            at the creation of the simulator and an integer is the number of
            times the last time-point of the previous pattern is repeated
            in-between patterns.

        Returns
        -------
        SpatioTemporalGraph
            The built spatio-temporal graph.
        """
        return SpatioTemporalGraph(self._simulate_graph_from_patterns(list(patterns)),
                                   self._simulate_areas_descriptions(list(patterns)))