# This file is part of connectome-manipulator.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024 Blue Brain Project/EPFL
"""Module for building synaptic delay models"""
import os.path
import matplotlib.pyplot as plt
import numpy as np
import progressbar
from sklearn.linear_model import LinearRegression
from connectome_manipulator import log
from connectome_manipulator.model_building import model_types
from connectome_manipulator.access_functions import get_node_ids, get_edges_population, get_cv_data
def build(dist_bins, dist_delays_mean, dist_delays_std, dist_delay_min, bin_size_um, **_):
    """Fits a linear distance-dependent synaptic delay model of type ``LinDelayModel`` to the data.

    The delay mean is fitted by linear regression of the per-bin means against the
    bin centers; bins whose mean is not finite (e.g., empty bins) are ignored.
    The delay std is modelled as a constant, namely the NaN-ignoring average of the
    per-bin stds. The delay minimum is a constant taken over from the data as-is.

    Args:
        dist_bins (numpy.ndarray): Distance bin edges, as returned by :func:`extract`
        dist_delays_mean (numpy.ndarray): Delay mean for all bins, as returned by :func:`extract`
        dist_delays_std (numpy.ndarray): Delay std for all bins, as returned by :func:`extract`
        dist_delay_min (float): Overall delay minimum, as returned by :func:`extract`
        bin_size_um (float): Distance bin size in um

    Returns:
        connectome_manipulator.model_building.model_types.LinDelayModel: Fitted linear distance-dependent delay model
    """
    # All bins must have width bin_size_um. The deviation has to be checked in
    # absolute value (narrower bins are a mismatch, too); the relative tolerance
    # absorbs floating-point round-off of the bin edges at large distances.
    log.log_assert(
        np.allclose(np.diff(dist_bins), bin_size_um, rtol=1e-9, atol=1e-12),
        "ERROR: Bin size mismatch!",
    )
    bin_offset = 0.5 * bin_size_um  # Shift from left bin edge to bin center

    # Mean delay model (linear), fitted only on bins with a finite mean
    valid = np.isfinite(dist_delays_mean)
    X = (dist_bins[:-1][valid] + bin_offset).reshape(-1, 1)  # sklearn expects (n_samples, n_features)
    y = dist_delays_mean[valid]
    dist_delays_mean_fit = LinearRegression().fit(X, y)
    delay_mean_coeff_a = dist_delays_mean_fit.intercept_
    delay_mean_coeff_b = dist_delays_mean_fit.coef_[0]

    # Std delay model (const)
    delay_std = np.nanmean(dist_delays_std)

    # Min delay model (const)
    delay_min = dist_delay_min

    # Create model
    model = model_types.LinDelayModel(
        delay_mean_coeff_a=float(delay_mean_coeff_a),
        delay_mean_coeff_b=float(delay_mean_coeff_b),
        delay_std=float(delay_std),
        delay_min=float(delay_min),
    )
    log.debug("Model description:\n%s", model)

    return model
def plot(
    out_dir, dist_bins, dist_delays_mean, dist_delays_std, dist_count, model, **_
):  # pragma: no cover
    """Visualizes extracted data vs. actual model output.

    Two figures are written to ``out_dir``:

    * ``data_vs_model.png``: binned data mean/std (bars) overlaid with the model's
      mean/std/min curves, plus the per-bin counts on a secondary log-scaled axis.
    * ``model_output.png``: histograms of delays drawn from the model at three
      distances (first bin center, mean of all bin edges, last bin center).

    Each figure is closed after saving, so repeated calls do not accumulate open
    matplotlib figures.

    Args:
        out_dir (str): Path to output directory where the results figures will be stored
        dist_bins (numpy.ndarray): Distance bin edges, as returned by :func:`extract`
        dist_delays_mean (numpy.ndarray): Delay mean for all bins, as returned by :func:`extract`
        dist_delays_std (numpy.ndarray): Delay std for all bins, as returned by :func:`extract`
        dist_count (numpy.ndarray): Number of data elements in each bin, as returned by :func:`extract`
        model (connectome_manipulator.model_building.model_types.LinDelayModel): Fitted linear distance-dependent delay model, as returned by :func:`build`
    """
    _plot_data_vs_model(out_dir, dist_bins, dist_delays_mean, dist_delays_std, dist_count, model)
    _plot_model_output(out_dir, dist_bins, model)


def _plot_data_vs_model(
    out_dir, dist_bins, dist_delays_mean, dist_delays_std, dist_count, model
):  # pragma: no cover
    """Plots binned delay data against the model curves and saves ``data_vs_model.png``."""
    bin_width = np.diff(dist_bins[:2])[0]
    bin_centers = dist_bins[:-1] + 0.5 * bin_width
    n_syn = np.sum(dist_count)

    model_params = model.get_param_dict()
    mean_model_str = f'f(x) = {model_params["delay_mean_coeff_b"]:.3f} * x + {model_params["delay_mean_coeff_a"]:.3f}'
    std_model_str = f'f(x) = {model_params["delay_std"]:.3f}'
    min_model_str = f'f(x) = {model_params["delay_min"]:.3f}'

    # The model getters are evaluated along the bin edges for the pathway given by
    # model.default_types (passed as src_type/tgt_type)
    model_kwargs = dict(zip(("src_type", "tgt_type"), model.default_types)) | {
        "distance": dist_bins
    }

    fig = plt.figure(figsize=(8, 4), dpi=300)
    plt.bar(
        bin_centers,
        dist_delays_mean,
        width=0.95 * bin_width,
        facecolor="tab:blue",
        label=f"Data mean: N = {n_syn} synapses",
    )
    plt.bar(
        bin_centers,
        dist_delays_std,
        width=0.5 * bin_width,
        facecolor="tab:red",
        label=f"Data std: N = {n_syn} synapses",
    )
    plt.plot(
        dist_bins,
        model.get_mean(**model_kwargs),
        "--",
        color="tab:brown",
        label="Model mean: " + mean_model_str,
    )
    plt.plot(
        dist_bins,
        model.get_std(**model_kwargs),
        "--",
        color="tab:olive",
        label="Model std: " + std_model_str,
    )
    plt.plot(
        dist_bins,
        model.get_min(**model_kwargs),
        "--",
        color="tab:gray",
        label="Model min: " + min_model_str,
    )
    plt.xlim((dist_bins[0], dist_bins[-1]))
    plt.xlabel("Distance [um]")
    plt.ylabel("Delay [ms]")
    plt.title("Distance-dependent synaptic delays", fontweight="bold")
    plt.legend(loc="upper left", bbox_to_anchor=(1.1, 1.0))

    # Add second axis with bin counts
    count_color = "tab:orange"
    ax_count = plt.gca().twinx()
    ax_count.set_yscale("log")
    # step() defaults to where="pre", so value i is drawn over (x[i-1], x[i]];
    # prepending the first count gives one value per bin edge
    ax_count.step(dist_bins, np.concatenate((dist_count[:1], dist_count)), color=count_color)
    ax_count.set_ylabel("Count", color=count_color)
    ax_count.tick_params(axis="y", which="both", colors=count_color)
    ax_count.spines["right"].set_color(count_color)
    plt.tight_layout()

    out_fn = os.path.abspath(os.path.join(out_dir, "data_vs_model.png"))
    log.info(f"Saving {out_fn}...")
    plt.savefig(out_fn)
    plt.close(fig)  # Release the figure; otherwise it stays open in pyplot's state


def _plot_model_output(out_dir, dist_bins, model, n_samples=1000):  # pragma: no cover
    """Plots histograms of delays sampled from the model and saves ``model_output.png``."""
    # First bin center, mean of all bin edges, last bin center
    dist_centers = [np.mean(dist_bins[:2]), np.mean(dist_bins), np.mean(dist_bins[-2:])]

    fig = plt.figure(figsize=(8, 4), dpi=300)
    for didx, d in enumerate(dist_centers):
        plt.subplot(1, len(dist_centers), didx + 1)
        # NOTE(review): unlike get_mean/get_std/get_min in _plot_data_vs_model, apply()
        # is called without src_type/tgt_type -- confirm the model does not need them here
        plt.hist(model.apply(distance=np.full(n_samples, d)), bins=50)
        plt.ylim(plt.ylim())  # Freeze limit
        plt.title(f"{d:.0f} um")
        plt.xlabel("Delay [ms]")
        plt.ylabel("Count")
    plt.suptitle("Delay distributions")
    plt.tight_layout()

    out_fn = os.path.abspath(os.path.join(out_dir, "model_output.png"))
    log.info(f"Saving {out_fn}...")
    plt.savefig(out_fn)
    plt.close(fig)  # Release the figure; otherwise it stays open in pyplot's state