# Copyright 2025 ICube (University of Strasbourg - CNRS)
# author: Julien PONTABRY (ICube)
#
# This software is a computer program whose purpose is to provide a toolkit
# to model, process and analyze the longitudinal reorganization of brain
# connectivity data, as functional MRI for instance.
#
# This software is governed by the CeCILL-B license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/or redistribute the software under the terms of the CeCILL-B
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
"""Defines the tools to simulate spatio-temporal graphs and functional connectivity data."""
from dataclasses import dataclass
from functools import reduce
from itertools import combinations
from typing import Any, Iterable

import networkx as nx
import numpy as np
import pandas as pd
from networkx.classes.reportviews import NodeView

from .graph import RC5, SpatioTemporalGraph, subgraph_nodes
def _fill_matrix(connections, correlations, matrix):
    """Write correlation values symmetrically into a matrix, in-place.

    Connections and correlations are consumed in lockstep; surplus items in
    the longer of the two sequences are ignored.

    Parameters
    ----------
    connections: list[tuple[int, int]]
        Pairs of 1-based area indices to connect.
    correlations: list[float]
        The correlation value associated with each pair.
    matrix: numpy.ndarray
        The square correlation matrix, modified in-place.
    """
    for pair, value in zip(connections, correlations):
        first, second = pair
        # area identifiers are 1-based whereas matrix indices are 0-based
        row, col = first - 1, second - 1
        matrix[row, col] = value
        matrix[col, row] = value
@dataclass(frozen=True)
class _CorrelationMatrixNetworksEdgesFiller:
    """Fills intra-network (within-region) correlations in a correlation matrix.

    For each network (node) in a spatial graph, randomly samples a set of
    connected area pairs and assigns correlation values whose mean matches the
    network's internal strength. Networks with fewer than two areas have no
    intra-network pair and are left untouched.

    Parameters
    ----------
    threshold: float
        The correlation threshold; intra-network correlations are constrained
        to remain above this value in absolute terms.
    rng: numpy.random.Generator
        A NumPy random number generator used for reproducible sampling.
    """

    threshold: float
    rng: np.random.Generator

    def __choose_connections(self, network: list[int] | set[int]) -> list[tuple[int, int]]:
        """Randomly select a connected set of area pairs within a network.

        Repeatedly samples random subsets of all pairwise combinations until
        the resulting graph is connected, so that every area is reachable.

        Parameters
        ----------
        network: list[int] | set[int]
            Area identifiers (1-based) belonging to the network.

        Returns
        -------
        list[tuple[int, int]]
            A list of area-pair edges forming a connected subgraph of the
            network. The list is empty when the network has fewer than two
            areas, since no pair exists.
        """
        # A network with zero or one area has no pair to connect; sampling
        # would otherwise fail (empty integer range / connectivity of a null graph).
        if len(network) < 2:
            return []

        def build_graph(edges: list[tuple[int, int]]) -> nx.Graph:
            graph = nx.Graph()
            graph.add_nodes_from(network)
            graph.add_edges_from(edges)
            return graph

        def draw_edges() -> list[tuple[int, int]]:
            candidates = list(combinations(network, 2))
            if len(candidates) == 1:
                return candidates
            # `high` is exclusive: at least a spanning number of edges
            # (n - 1) and at most all candidate pairs but one are drawn.
            nb_edges = self.rng.integers(low=len(network) - 1, high=len(candidates))
            drawn = []
            for _ in range(nb_edges):
                drawn.append(candidates.pop(self.rng.integers(len(candidates))))
            return drawn

        trial_graph = build_graph(draw_edges())
        while not nx.is_connected(trial_graph):
            trial_graph = build_graph(draw_edges())
        return list(trial_graph.edges)

    def __mean_corr_sampler(self, size: int, mean: float, max_attempts: int = 1_000) -> list[float]:
        """Sample ``size`` correlation values whose empirical mean equals ``mean``.

        Values are sampled uniformly within a symmetric interval around ``mean``
        and then shifted so that their mean matches. Samples that leave the
        interval after shifting are rejected and redrawn. The interval is
        constrained to keep all values on the same side of ``threshold``.

        Parameters
        ----------
        size: int
            Number of correlation values to sample.
        mean: float
            Target mean correlation value.
        max_attempts: int
            The maximal number of re-sampling attempts. Default is 1000.

        Returns
        -------
        list[float]
            A list of ``size`` correlation values with the requested mean.

        Raises
        ------
        RuntimeError
            If no valid sample has been drawn after ``max_attempts`` retries.
        """
        def draw(low: float, high: float) -> np.ndarray:
            values = self.rng.uniform(low=low, high=high, size=size)
            # recentre the draw so that its empirical mean is the target
            return values + mean - values.mean()

        radius = abs(mean) - self.threshold
        lower, upper = mean - radius, mean + radius
        samples = draw(lower, upper)
        attempts = 0
        while (samples < lower).any() or (samples > upper).any():
            samples = draw(lower, upper)
            attempts += 1
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Could not sample values in [{lower}, {upper}] with mean={mean} after {max_attempts} attempts"
                )
        return samples.tolist()

    def fill(self, spatial_graph: nx.DiGraph, matrix: np.ndarray) -> None:
        """Fill intra-network correlations for all networks in the spatial graph.

        Parameters
        ----------
        spatial_graph: nx.DiGraph
            A single time-point spatial graph whose nodes represent networks
            (each carrying ``areas`` and ``internal_strength`` attributes).
        matrix: numpy.ndarray
            The square correlation matrix to update in-place.
        """
        networks = [(data['areas'], data['internal_strength'])
                    for data in spatial_graph.nodes.values()]
        for areas, mean_corr in networks:
            connections = self.__choose_connections(areas)
            if not connections:
                # nothing to fill for a network without any pair of areas
                continue
            correlations = self.__mean_corr_sampler(len(connections), mean_corr)
            _fill_matrix(connections, correlations, matrix)
@dataclass(frozen=True)
class _CorrelationMatrixInterRegionEdgesFiller:
    """Fills inter-network (cross-region) correlations in a correlation matrix.

    For each spatial edge between two networks, randomly samples area pairs
    across the two networks and assigns correlation values whose maximum (in
    absolute terms) matches the edge's correlation attribute.

    Parameters
    ----------
    threshold: float
        The correlation threshold; inter-region correlations are constrained
        to remain above this value in absolute terms.
    rng: numpy.random.Generator
        A NumPy random number generator used for reproducible sampling.
    """

    threshold: float
    rng: np.random.Generator

    def __max_correlation_sampler(self, size: int, target: float, max_attempts: int = 1_000) -> list[float]:
        """Sample ``size`` correlation values whose extreme value equals ``target``.

        The extreme is the maximum for positive targets and the minimum for
        negative targets, so that the inter-region edge correlation is reproduced.
        Values are drawn between the signed threshold and the target, shifted so
        that their extreme hits the target, and redrawn when the shift pushes a
        value out of that interval.

        Parameters
        ----------
        size: int
            Number of correlation values to sample (must be at least 1).
        target: float
            The desired extreme correlation value.
        max_attempts: int
            The maximal number of re-sampling attempts. Default is 1000.

        Returns
        -------
        list[float]
            A list of ``size`` correlation values with the requested extreme.

        Raises
        ------
        RuntimeError
            If no valid sample has been drawn after ``max_attempts`` retries
            (for instance when ``abs(target)`` is below the threshold, in which
            case the shifted values can never fit in the interval).
        """
        pick_extreme = np.max if target >= 0 else np.min

        def draw(low: float, high: float) -> np.ndarray:
            values = self.rng.uniform(low=low, high=high, size=size)
            # shift the draw so that its extreme value lands on the target
            return values + target - pick_extreme(values)

        signed_threshold = np.sign(target) * self.threshold
        centre = (signed_threshold + target) / 2
        radius = abs(target - centre)
        lower, upper = centre - radius, centre + radius
        samples = draw(lower, upper)
        attempts = 0
        while (samples < lower).any() or (samples > upper).any():
            samples = draw(lower, upper)
            attempts += 1
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Could not sample values in [{lower}, {upper}] with extreme={target} "
                    f"after {max_attempts} attempts"
                )
        return samples.tolist()

    def __choose_inter_region_connections(self, network1: set, network2: set) -> list[tuple[int, int]]:
        """Randomly select area pairs connecting two distinct networks.

        For each area in ``network1``, a Poisson-distributed number of areas from
        ``network2`` are selected as targets (at least one per source area, and at
        most as many as ``network2`` contains).

        Parameters
        ----------
        network1: set[int]
            Area identifiers (1-based) of the source network.
        network2: set[int]
            Area identifiers (1-based) of the target network.

        Returns
        -------
        list[tuple[int, int]]
            A list of (source_area, target_area) pairs.
        """
        def pick_targets(source: int, targets: set, count: int) -> list[tuple[int, int]]:
            candidates = [(source, other) for other in targets]
            picked = []
            # draw without replacement, capped by the number of available targets
            for _ in range(min(count, len(candidates))):
                picked.append(candidates.pop(self.rng.integers(len(candidates))))
            return picked

        fan_outs = self.rng.poisson(lam=1, size=len(network1))
        connections = []
        for source, fan_out in zip(network1, fan_outs):
            # a Poisson draw may be zero: force at least one connection per source
            connections.extend(pick_targets(source, network2, max(1, fan_out)))
        return connections

    def fill(self, spatial_graph: nx.DiGraph, matrix: np.ndarray) -> None:
        """Fill inter-network correlations for all edges in the spatial graph.

        Parameters
        ----------
        spatial_graph: nx.DiGraph
            A single time-point spatial graph whose edges carry a ``correlation``
            attribute describing the inter-network correlation strength.
        matrix: numpy.ndarray
            The square correlation matrix to update in-place.
        """
        for (node1, node2), data in spatial_graph.edges.items():
            connections = self.__choose_inter_region_connections(
                spatial_graph.nodes[node1]['areas'], spatial_graph.nodes[node2]['areas'])
            if not connections:
                # one of the two networks has no area: nothing to connect
                continue
            correlations = self.__max_correlation_sampler(len(connections), data['correlation'])
            _fill_matrix(connections, correlations, matrix)
[docs]
class CorrelationMatrixSequenceSimulator:
    """Simulate a sequence of correlation matrices from a spatio-temporal graph.

    Examples
    --------
    >>> graph = nx.DiGraph()
    >>> graph.add_node(1, t=0, areas={1, 2}, region='Region 1', internal_strength=0.98)
    >>> graph.add_node(2, t=0, areas={3, 4}, region='Region 2', internal_strength=-0.98)
    >>> graph.add_edge(1, 2, correlation=0.94, t=0, type='spatial')
    >>> graph.add_edge(2, 1, correlation=0.94, t=0, type='spatial')
    >>> graph.graph['min_time'] = 0
    >>> graph.graph['max_time'] = 0
    >>> areas = pd.DataFrame({'Id_Area': [1, 2, 3, 4],
    ...                       'Name_Area': ['A1', 'A2', 'A3', 'A4'],
    ...                       'Name_Region': ['R1', 'R1', 'R2', 'R2']})
    >>> areas.set_index('Id_Area', inplace=True)
    >>> simulator = CorrelationMatrixSequenceSimulator(SpatioTemporalGraph(graph, areas), threshold=0.4,
    ...                                                rng=np.random.default_rng(40))
    >>> matrix = simulator.simulate()
    >>> matrix.shape
    (1, 4, 4)
    >>> matrix
    array([[[ 1.        ,  0.98      ,  0.65453818,  0.94      ],
            [ 0.98      ,  1.        ,  0.85381682,  0.61873641],
            [ 0.65453818,  0.85381682,  1.        , -0.98      ],
            [ 0.94      ,  0.61873641, -0.98      ,  1.        ]]])
    """

    def __init__(self, graph: SpatioTemporalGraph, threshold: float = 0.4,
                 rng: np.random.Generator | None = None) -> None:
        """Initialise the simulator.

        Parameters
        ----------
        graph: SpatioTemporalGraph
            The reference spatio-temporal graph from which correlation
            matrices will be back-generated.
        threshold: float, optional
            Minimum absolute correlation required for spatial edges to be
            considered significant (default 0.4, must be in [0, 1]).
        rng: numpy.random.Generator, optional
            Random number generator for reproducible results. When omitted,
            a fresh unseeded generator is created for this instance.

        Raises
        ------
        ValueError
            If ``threshold`` is outside [0, 1].
        """
        self.graph = graph
        self.threshold = threshold
        # validate before building anything that depends on the threshold
        self.__init_validation__()
        # The generator is created here rather than as a default argument,
        # which would be evaluated once and shared by every instance.
        self.__rng = np.random.default_rng() if rng is None else rng
        self.__network_edges_filler = _CorrelationMatrixNetworksEdgesFiller(self.threshold, self.__rng)
        self.__inter_region_edges_filler = _CorrelationMatrixInterRegionEdgesFiller(self.threshold, self.__rng)

    def __init_validation__(self):
        """Validate constructor parameters.

        Raises
        ------
        ValueError
            If ``threshold`` is not in [0, 1].
        """
        if self.threshold < 0 or self.threshold > 1:
            raise ValueError("The threshold must be within range [0, 1]!")

    def __simulate_corr_matrix(self, spatial_graph: nx.DiGraph) -> np.ndarray:
        """Generate a single correlation matrix from a one-time-point spatial graph.

        Fills intra-network and inter-network correlations, then randomises
        any remaining zero entries with sub-threshold uniform noise. The noise
        is drawn once per unordered pair of areas and mirrored, so that the
        matrix stays symmetric.

        Parameters
        ----------
        spatial_graph: nx.DiGraph
            The spatial subgraph at a single time point.

        Returns
        -------
        numpy.ndarray
            A square symmetric correlation matrix of shape ``(n_areas, n_areas)``.
        """
        matrix = np.eye(len(self.graph.areas))
        self.__network_edges_filler.fill(spatial_graph, matrix)
        self.__inter_region_edges_filler.fill(spatial_graph, matrix)

        # Both fillers write (i, j) and (j, i) together, hence the unfilled
        # entries are symmetric: it is enough to look at the upper triangle.
        rows, cols = np.nonzero(np.triu(matrix == 0, k=1))
        bound = self.threshold * 0.99  # keep the noise strictly below the threshold
        noise = self.__rng.uniform(low=-bound, high=bound, size=rows.size)
        matrix[rows, cols] = noise
        matrix[cols, rows] = noise
        return matrix

    def simulate(self) -> np.ndarray:
        """Simulate the sequence of correlation matrices.

        Returns
        -------
        numpy.ndarray
            A 3D-shaped array that contains the correlation matrices for each time.
        """
        return np.array([self.__simulate_corr_matrix(self.graph.sub(t=t))
                         for t in self.graph.time_range])
def __trans(sources: Iterable[int] | int, targets: Iterable[int] | int, kind: str) -> list[tuple[int, int, RC5]]:
"""Expand a human-readable transition description into a list of (source, target, RC5) triples.
Parameters
----------
sources: int or Iterable[int]
Source node(s). For a ``'merge'`` transition this must be an iterable.
targets: int or Iterable[int]
Target node(s). For a ``'split'`` transition this must be an iterable.
kind: str
Transition type: ``'split'`` (:attr:`RC5.PPi`), ``'merge'`` (:attr:`RC5.PP`),
``'eq'`` (:attr:`RC5.EQ`), or any other string for :attr:`RC5.PO`.
Returns
-------
list[tuple[int, int, RC5]]
A list of ``(source_node, target_node, transition)`` tuples.
"""
if kind.lower() == 'split':
return [(sources, target, RC5.PPi) for target in targets]
elif kind.lower() == 'merge':
return [(source, targets, RC5.PP) for source in sources]
else:
trans = RC5.EQ if kind.lower() == 'eq' else RC5.PO
return [(sources, targets, trans)]
def __def2areas(areas_def: tuple[int, int] | Iterable[int] | int) -> set[int]:
"""Convert a compact area definition to a set of area identifiers.
Parameters
----------
areas_def: tuple[int, int] | Iterable[int] | int
- A 2-tuple ``(start, end)`` is expanded to the inclusive range
``{start, start+1, ..., end}``.
- Any other iterable is converted directly to a set.
- A single integer is wrapped in a singleton set.
Returns
-------
set[int]
The corresponding set of area identifiers.
"""
if isinstance(areas_def, tuple) and len(areas_def) == 2:
start, end = areas_def
return set(range(start, end + 1))
elif isinstance(areas_def, Iterable):
return set(areas_def)
else:
return {areas_def}
[docs]
def generate_pattern(networks_list: list[list[tuple[tuple[int, int], int, float]]],
                     spatial_edges: list[tuple[int, int, float]],
                     temporal_edges: list[tuple[Iterable[int] | int, Iterable[int] | int, str]]) -> SpatioTemporalGraph:
    """Generate a pattern with the specified properties.

    Nodes are numbered from 1, in the order they appear in ``networks_list``
    (time after time). Each spatial edge is inserted in both directions.

    Parameters
    ----------
    networks_list: list[list[tuple[tuple[int, int], int, float]]]
        A list of nodes per time instant, each node being defined by a tuple of
        area definition (range, iterable or single area), region id and internal
        strength.
    spatial_edges: list[tuple[int, int, float]]
        A list of spatial edges defined by a tuple of source/target nodes and a correlation.
    temporal_edges: list[tuple[Iterable[int] | int, Iterable[int] | int, str]]
        A list of temporal edges, defined by a tuple of source(s)/target(s) nodes and a transition.

    Returns
    -------
    SpatioTemporalGraph
        A spatio-temporal graph that can be used as a pattern.

    Example
    -------
    >>> pattern = generate_pattern(
    ...     networks_list=[[((1, 5), 1, -0.2), ((6, 7), 2, 0.3), ((8, 10), 2, 0.6)],
    ...                    [((1, 5), 1, 0.6), ((6, 10), 2, -0.5)]],
    ...     spatial_edges=[(1, 2, 0.45), (4, 5, 0.8)],
    ...     temporal_edges=[(1, 4, 'eq'), ((2, 3), 5, 'merge')])
    >>> pattern.nodes
    NodeView((1, 2, 3, 4, 5))
    >>> pattern.edges
    OutEdgeView([(1, 2), (1, 4), (2, 1), (2, 5), (3, 5), (4, 5), (5, 4)])
    >>> pattern.areas
             Name_Area  Name_Region
    Id_Area
    1           Area 1     Region 1
    2           Area 2     Region 1
    3           Area 3     Region 1
    4           Area 4     Region 1
    5           Area 5     Region 1
    6           Area 6     Region 2
    7           Area 7     Region 2
    8           Area 8     Region 2
    9           Area 9     Region 2
    10         Area 10     Region 2
    """
    graph = nx.DiGraph()
    graph.graph['min_time'] = 0
    graph.graph['max_time'] = len(networks_list) - 1

    # one node per network; a later definition of an area overrides its region
    region_of_area = {}
    node_id = 0
    for t, networks in enumerate(networks_list):
        for areas_def, region_id, strength in networks:
            node_id += 1
            node_areas = __def2areas(areas_def)
            region_name = f"Region {region_id}"
            region_of_area.update(dict.fromkeys(node_areas, region_name))
            graph.add_node(node_id, t=t, areas=node_areas, region=region_name, internal_strength=strength)

    # spatial edges are symmetric: store both directions
    for source, target, corr in spatial_edges:
        for head, tail in ((source, target), (target, source)):
            graph.add_edge(head, tail, correlation=corr, type='spatial')

    for temporal_link in temporal_edges:
        for source, target, rc5 in __trans(*temporal_link):
            graph.add_edge(source, target, transition=rc5, type='temporal')

    area_ids = sorted(region_of_area)
    description = pd.DataFrame({'Id_Area': area_ids,
                                'Name_Area': [f"Area {area}" for area in area_ids],
                                'Name_Region': [region_of_area[area] for area in area_ids]})
    return SpatioTemporalGraph(graph, description.set_index('Id_Area'))
[docs]
class SpatioTemporalGraphSimulator:
    """Simulator for spatio-temporal graphs.

    The simulator needs predefined patterns (created either manually or
    automatically) to generate a full spatio-temporal graph with those patterns
    included as instructed, optionally with in-between repeats.

    Examples
    --------
    >>> pattern1 = generate_pattern(
    ...     networks_list=[
    ...         [((1, 2), 1, 0.7), (3, 1, 1), ((4, 5), 2, -0.8)],
    ...         [((1, 3), 1, 0.8), ((4, 5), 2, -0.8)]],
    ...     spatial_edges=[(1, 3, 0.5), (4, 5, 0.6)],
    ...     temporal_edges=[((1, 2), 4, 'merge'), (3, 5, 'eq')])
    >>> pattern2 = generate_pattern(
    ...     networks_list=[
    ...         [((1, 3), 1, 0.8), ((4, 5), 2, -0.8)],
    ...         [((1, 2), 1, 0.7), (3, 1, 1), ((4, 5), 2, -0.8)]],
    ...     spatial_edges=[(1, 2, 0.6), (3, 5, 0.5)],
    ...     temporal_edges=[(1, (3, 4), 'split'), (2, 5, 'eq')])
    >>> simulator = SpatioTemporalGraphSimulator(p1=pattern1, p2=pattern2)
    >>> graph = simulator.simulate('p2', 3, 'p1')
    >>> graph.nodes
    NodeView((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
    >>> graph.edges
    OutEdgeView([(1, 2), (1, 3), (1, 4), (2, 1), (2, 5), (3, 5), (3, 6), (4, 7), (5, 3), (5, 8),
    (6, 8), (6, 9), (7, 10), (8, 6), (8, 11), (9, 11), (9, 12), (10, 13), (11, 9), (11, 14), (12, 14),
    (12, 15), (13, 16), (14, 12), (14, 17), (15, 17), (15, 18), (16, 18), (17, 15), (17, 19), (18, 19), (19, 18)])
    """

    def __init__(self, **patterns: SpatioTemporalGraph) -> None:
        """Initialise the simulator with a set of named patterns.

        Parameters
        ----------
        **patterns: SpatioTemporalGraph
            Keyword arguments mapping pattern names (str) to their
            :class:`~graph.SpatioTemporalGraph` definitions.
        """
        self.__patterns = patterns

    def _simulate_areas_descriptions(self, patterns: list[str | int]) -> pd.DataFrame:
        """Merge the areas DataFrames of all named patterns in the sequence.

        Integer entries (repeat counts) are ignored; only string entries that
        reference registered patterns contribute their areas.

        Parameters
        ----------
        patterns: list[str | int]
            The pattern sequence (same format accepted by :meth:`simulate`).

        Returns
        -------
        pandas.DataFrame
            A deduplicated DataFrame containing area descriptions from all
            referenced patterns.
        """
        areas_descriptions = [self.__patterns[pattern].areas
                              for pattern in patterns
                              if isinstance(pattern, str)]
        return pd.concat(areas_descriptions).drop_duplicates()

    @staticmethod
    def __shift_node_data(data: dict[str, Any], dt: int) -> dict[str, Any]:
        """Return a copy of a node's data dictionary with the time attribute shifted.

        Parameters
        ----------
        data: dict[str, Any]
            Original node data (must contain a ``'t'`` key).
        dt: int
            Time offset to add.

        Returns
        -------
        dict[str, Any]
            A shallow copy of ``data`` with ``t`` incremented by ``dt``.
        """
        shifted = dict(data)
        shifted['t'] += dt
        return shifted

    @staticmethod
    def __shift_nodes(nodes: NodeView, dt: int, k: int) -> list[tuple[int, dict[str, Any]]]:
        """Shift node identifiers and time attributes for graph concatenation.

        Parameters
        ----------
        nodes: NodeView
            The nodes to shift.
        dt: int
            Time offset added to each node's ``t`` attribute.
        k: int
            Integer offset added to each node identifier to avoid collisions
            with existing nodes.

        Returns
        -------
        list[tuple[int, dict[str, Any]]]
            A list of ``(new_node_id, shifted_data)`` pairs, sorted by node
            identifier.
        """
        return [(n + k, SpatioTemporalGraphSimulator.__shift_node_data(d, dt))
                for n, d in sorted(nodes.items(), key=lambda x: x[0])]

    def _simulate_graph_from_patterns(self, patterns: list[str | int]) -> nx.DiGraph:
        """Build a directed graph by concatenating patterns and optional repeats.

        Patterns are chained sequentially: node identifiers and time steps are
        shifted to avoid collisions. An integer element inserts that many EQ-linked
        copies of the last time-point of the graph built so far. The first
        element is looked up directly in the registered patterns, hence it must
        be a pattern name.

        Parameters
        ----------
        patterns: list[str | int]
            Sequence of pattern names (str) and repeat counts (int).

        Returns
        -------
        nx.DiGraph
            The assembled directed graph with updated ``max_time`` graph attribute.

        Raises
        ------
        ValueError
            If a sequence element after the first one is neither a ``str``
            nor an ``int``.
        """
        g = nx.DiGraph(self.__patterns[patterns[0]])
        for next_pattern in patterns[1:]:
            last_t = g.graph['max_time']
            last_out = subgraph_nodes(g, t=last_t)

            if isinstance(next_pattern, int):
                for i in range(next_pattern):
                    k = len(last_out.nodes)
                    m = (i + 1) * k  # identifier offset of the (i+1)-th copy
                    g.add_nodes_from(SpatioTemporalGraphSimulator.__shift_nodes(last_out.nodes, i + 1, m))
                    # copy the spatial edges of the last time-point, in both directions
                    g.add_edges_from([edge
                                      for (n1, n2), d in last_out.edges.items()
                                      if d['type'] == 'spatial'
                                      for edge in ((n1 + m, n2 + m, d), (n2 + m, n1 + m, d))])
                    # link every node to its copy at the next time-point
                    g.add_edges_from([(n + i * k, n + m, {'transition': RC5.EQ, 'type': 'temporal'})
                                      for n in sorted(last_out.nodes)])
                g.graph['max_time'] += next_pattern
            elif isinstance(next_pattern, str):
                pattern = self.__patterns[next_pattern]
                k = max(g.nodes)
                dt = g.graph['max_time'] - g.graph['min_time'] + 1

                # add pattern (with time and nodes shifted appropriately)
                g.add_nodes_from(SpatioTemporalGraphSimulator.__shift_nodes(pattern.nodes, dt, k))
                g.add_edges_from([(e1 + k, e2 + k, d)
                                  for (e1, e2), d in pattern.edges.items()])
                g.graph['max_time'] += pattern.graph['max_time'] + 1

                # make the connection between last pattern and next one
                next_in = subgraph_nodes(pattern, t=pattern.graph['min_time'])
                for nout, nin in zip(sorted(last_out.nodes), sorted(next_in.nodes)):
                    g.add_edge(nout, nin + k, transition=RC5.EQ, type='temporal')
            else:
                raise ValueError(f"pattern type {type(next_pattern)} "
                                 "is not recognized! It must be either "
                                 "int or str.")
        return g

    def simulate(self, *patterns: str | int) -> SpatioTemporalGraph:
        """Simulate the given sequence of patterns.

        Parameters
        ----------
        patterns: tuple[str | int]
            The sequence of patterns. A string references a pattern registered at the
            creation of the simulator and an integer references a number of time repeats
            in-between patterns. A repeat is a copy of the last time-point of the graph
            built so far.

        Returns
        -------
        SpatioTemporalGraph
            The built spatio-temporal graph.
        """
        sequence = list(patterns)
        return SpatioTemporalGraph(self._simulate_graph_from_patterns(sequence),
                                   self._simulate_areas_descriptions(sequence))