# Source code for connectome_manipulator.connectome_comparison.properties

# This file is part of connectome-manipulator.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024 Blue Brain Project/EPFL

"""Module for comparing connectomes based on synapse properties:

Structural comparison of two connectomes in terms of synapse properties per pathway,
as specified by the config. For each connectome, the underlying properties maps are
computed by the :func:`compute` function and will be saved to a data file first.
The individual synapse properties maps, together with a difference map between the
two connectomes, are then plotted by means of the :func:`plot` function.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import progressbar
from connectome_manipulator.access_functions import get_edges_population, get_node_ids


def compute(
    circuit,
    fct="np.mean",
    group_by=None,
    sel_src=None,
    sel_dest=None,
    per_conn=False,
    skip_empty_groups=False,
    edges_popul_name=None,
    **_,
):
    """Computes a matrix of synapse property values between groups of neurons of a given circuit's connectome.

    Args:
        circuit (bluepysnap.Circuit): Input circuit
        fct (str): Function to apply, e.g., "np.mean", "np.std"; the string is evaluated with ``eval`` and must accept an ``axis`` keyword argument
        group_by (str): Neuron property name based on which to group connections, e.g., "synapse_class", "layer", or "mtype"; if omitted, the overall average is computed
        sel_src (str/list-like/dict): Source (pre-synaptic) neuron selection; must be a dict (or omitted) when ``group_by`` is given
        sel_dest (str/list-like/dict): Target (post-synaptic) neuron selection; must be a dict (or omitted) when ``group_by`` is given
        per_conn (bool): If selected, ``fct`` is applied to the average property value per connection (i.e., average value of all synapses belonging to a connection); otherwise, ``fct`` is applied to the synapses of all connections altogether
        skip_empty_groups (bool): If selected, only group property values that exist within the given source/target selection are kept; otherwise, all group property values, even if not present in the given source/target selection, will be included
        edges_popul_name (str): Name of SONATA edges population to extract data from

    Returns:
        dict: Dictionary containing the computed data elements; see Note

    Raises:
        AssertionError: If ``group_by`` is given and ``sel_src``/``sel_dest`` is neither a dict nor omitted

    Note:
        The returned dictionary contains the data elements that can be selected for plotting through the structural comparison configuration file, together with a common dictionary containing additional information. Each data element is a dictionary with "data" (numpy.ndarray of size <source-group-size x target-group-size>), "name" (str), and "unit" (str) items. Names of these data elements correspond to the synapse properties that are present in the given SONATA edges population. Usual properties may include for example:

        * "conductance": Peak conductance
        * "decay_time": Decay time constant
        * "depression_time": Time constant for recovery from depression
        * "facilitation_time": Time constant for recovery from facilitation
        * "u_syn": Utilization of synaptic efficacy
        * ...

        Groups without any synapses are left as NaN. Without ``group_by``, each data matrix has size 1 x 1 and the "src_group_values"/"tgt_group_values" lists of the common dictionary are empty.
    """
    # Select edge population
    edges = get_edges_population(circuit, edges_popul_name)

    # Select corresponding source/target nodes populations
    src_nodes = edges.source
    tgt_nodes = edges.target

    if group_by is None:
        # No grouping: a single entry covering the whole source/target selection.
        src_group_sel = [sel_src]
        tgt_group_sel = [sel_dest]
        # There are no group labels in this case; they must still be defined since
        # they are reported in the "common" dictionary below.
        src_group_values = []
        tgt_group_values = []
    else:
        if skip_empty_groups:
            # Take only group property values that exist within given src/tgt selection
            src_group_values = np.unique(
                src_nodes.get(get_node_ids(src_nodes, sel_src), properties=group_by)
            )
            tgt_group_values = np.unique(
                tgt_nodes.get(get_node_ids(tgt_nodes, sel_dest), properties=group_by)
            )
        else:
            # Keep all group property values, even if not present in given src/tgt selection, to get the full matrix
            src_group_values = sorted(src_nodes.property_values(group_by))
            tgt_group_values = sorted(tgt_nodes.property_values(group_by))

        # Selections must be dicts here; otherwise, they cannot be merged with the group selection
        if sel_src is None:
            sel_src = {}
        else:
            assert isinstance(
                sel_src, dict
            ), "ERROR: Source node selection must be a dict or empty!"
        if sel_dest is None:
            sel_dest = {}
        else:
            assert isinstance(
                sel_dest, dict
            ), "ERROR: Target node selection must be a dict or empty!"

        # group_by will overwrite selection in case group property also exists in selection!
        src_group_sel = [{**sel_src, group_by: value} for value in src_group_values]
        tgt_group_sel = [{**sel_dest, group_by: value} for value in tgt_group_values]

    print(
        f"INFO: Extracting synapse properties (group_by={group_by}, sel_src={sel_src}, sel_dest={sel_dest}, N={len(src_group_sel)}x{len(tgt_group_sel)} groups, per_conn={per_conn})",
        flush=True,
    )

    edge_props = sorted(edges.property_names)
    print(f"INFO: Available synapse properties: \n{edge_props}", flush=True)

    # NOTE(review): eval() executes arbitrary code; "fct" must only ever come from
    # a trusted configuration, never from untrusted input.
    prop_fct = eval(fct)

    # Entries of groups without any synapses stay NaN
    prop_tables = np.full((len(src_group_sel), len(tgt_group_sel), len(edge_props)), np.nan)

    # Resolve the target node IDs once per group instead of once per (source, target)
    # pair; assumes get_node_ids is a pure lookup -- TODO confirm
    post_ids_per_group = [get_node_ids(tgt_nodes, sel_post) for sel_post in tgt_group_sel]

    pbar = progressbar.ProgressBar()
    for idx_pre, sel_pre in enumerate(pbar(src_group_sel)):
        pre_ids = get_node_ids(src_nodes, sel_pre)
        for idx_post, post_ids in enumerate(post_ids_per_group):
            e_sel = edges.pathway_edges(pre_ids, post_ids, edge_props)
            if e_sel.size == 0:
                continue
            if per_conn:
                # Apply prop_fct to average value per connection, i.e., per unique
                # (source, target) node pair. A single groupby yields a numeric table
                # and avoids one boolean mask per connection.
                conn_means = e_sel.groupby(
                    ["@source_node", "@target_node"], as_index=False
                ).mean()
                values = conn_means[edge_props].to_numpy()
            else:
                values = e_sel.to_numpy()
            prop_tables[idx_pre, idx_post, :] = prop_fct(values, axis=0)

    # Function name with upper-case first letter, e.g., "mean" -> "Mean"
    fname = prop_fct.__name__[0].upper() + prop_fct.__name__[1:]
    cname = " (per conn)" if per_conn else ""
    res_dict = {
        prop: {
            "data": prop_tables[:, :, idx],
            "name": f'"{prop}" property',
            "unit": f"{fname} {prop}{cname}",
        }
        for idx, prop in enumerate(edge_props)
    }
    res_dict["common"] = {
        "src_group_values": src_group_values,
        "tgt_group_values": tgt_group_values,
    }

    return res_dict
def plot(
    res_dict, common_dict, fig_title=None, vmin=None, vmax=None, isdiff=False, group_by=None, **_
):  # pragma:no cover
    """Plots a properties matrix or a difference matrix.

    Args:
        res_dict (dict): Results dictionary, containing selected data for plotting; must contain a "data" item with a properties matrix of type numpy.ndarray of size <#source-group-values x #target-group-values>, as well as "name" and "unit" items containing strings.
        common_dict (dict): Common dictionary, containing additional information; must contain "src_group_values" and "tgt_group_values" items containing lists of source/target values of the grouped property, matching the size of the properties matrix in ``res_dict``
        fig_title (str): Optional figure title
        vmin (float): Minimum plot range
        vmax (float): Maximum plot range
        isdiff (bool): Flag indicating that ``res_dict`` contains a difference matrix; in this case, a symmetric plot range is required and a divergent colormap will be used
        group_by (str): Neuron property name based on which to group connections, e.g., "synapse_class", "layer", or "mtype"; if omitted, the overall average is computed

    Raises:
        AssertionError: If ``isdiff`` is set but ``vmin``/``vmax`` are missing or not symmetric around zero
    """
    if isdiff:  # Difference plot
        # Missing limits are checked first, so that an unset range fails with the
        # intended message instead of a TypeError from "-1 * None".
        assert (
            vmin is not None and vmax is not None and -1 * vmin == vmax
        ), "ERROR: Symmetric plot range required!"
        cmap = "PiYG"  # Symmetric (diverging) colormap
    else:  # Regular plot
        cmap = "hot_r"  # Regular colormap

    plt.imshow(res_dict["data"], interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)

    plt.title(res_dict["name"] if fig_title is None else fig_title)

    if group_by:
        plt.xlabel(f"Postsynaptic {group_by}")
        plt.ylabel(f"Presynaptic {group_by}")

    src_values = common_dict["src_group_values"]
    tgt_values = common_dict["tgt_group_values"]

    # Font scaling: shrink tick labels with the number of groups, but never below 1
    n_grp = max(len(src_values), len(tgt_values))
    font_size = max(13 - n_grp / 6, 1)

    # Explicit length checks (instead of truthiness), since the group values may be numpy arrays
    if len(src_values) > 0:
        plt.yticks(
            range(len(src_values)),
            src_values,
            rotation=0,
            fontsize=font_size,
        )

    if len(tgt_values) > 0:
        # Rotate x tick labels as soon as any label is longer than a single character
        rot_x = 90 if max(len(str(grp)) for grp in tgt_values) > 1 else 0
        plt.xticks(
            range(len(tgt_values)),
            tgt_values,
            rotation=rot_x,
            fontsize=font_size,
        )

    cb = plt.colorbar()
    cb.set_label(res_dict["unit"])