Source code for connectome_manipulator.model_building.pos_mapping
# This file is part of connectome-manipulator.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024 Blue Brain Project/EPFL
"""Module for building position mapping models based on a flatmap"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import progressbar
from scipy.interpolate import griddata
from scipy.spatial import distance_matrix
from voxcell.nexus.voxelbrain import Atlas
from connectome_manipulator import log
from connectome_manipulator.model_building import model_types
[docs]
def extract(
    circuit,
    flatmap_path,
    xy_file,
    z_file,
    xy_scale=None,
    z_scale=None,
    nodes_pop_name=None,
    NN_only=False,
    CV_dict=None,
    **_,
):
    """Extracts a position mapping from 3D atlas space (x/y/z) to 3D flat space (flat-x/flat-y/depth) of a given population of neurons.

    Flat coordinates are looked up in the flatmap/depth volumes at the neuron positions, scaled
    to um, and then interpolated at the exact (fractional) voxel position of each neuron.
    Voxel values of -1 are treated as missing and are filled in by the interpolation.

    Args:
        circuit (bluepysnap.Circuit): Input circuit
        flatmap_path (str): Base path to a flatmap
        xy_file (str): Filename of x/y mapping file as part of the flatmap
        z_file (str): Filename of z (= cortical depth) mapping file as part of the flatmap
        xy_scale (list-like): Two-element list with x/y scaling factors from a.u. (as in flatmap) to um
        z_scale (float): Scalar value with z scaling factor from a.u. (as in flatmap) to um
        nodes_pop_name (str): Name of SONATA nodes population to extract data from
        NN_only (bool): If selected, only nearest-neighbor interpolation will be used for position mapping (faster); otherwise, linear interpolation is applied, if possible (slower)
        CV_dict (dict): Cross-validation dictionary - Not supported

    Returns:
        dict: Dictionary containing the extracted data elements, i.e., neuron positions in original and flat space
    """
    log.log_assert(CV_dict is None, "ERROR: Cross-validation not supported!")

    # Get neuron positions
    if nodes_pop_name is None:
        log.log_assert(
            len(circuit.nodes.population_names) == 1,
            f"ERROR: Nodes population could not be determined (found {circuit.nodes.population_names})!",
        )
        nodes_pop_name = circuit.nodes.population_names[0]
    log.debug(f'Loading nodes population "{nodes_pop_name}"')
    nodes = circuit.nodes[nodes_pop_name]
    nrn_pos = nodes.positions()
    nrn_ids = nrn_pos.index.to_numpy()
    nrn_pos = nrn_pos.to_numpy()
    nrn_lay = nodes.get(properties="layer")

    # Load flatmap
    flatmap_atlas = Atlas.open(flatmap_path)
    flatmap = flatmap_atlas.load_data(xy_file)
    depths = flatmap_atlas.load_data(z_file)
    log.log_assert(
        flatmap.raw.shape[:3] == depths.raw.shape,
        "ERROR: Flatmap and depths map sizes inconsistent!",
    )
    log.log_assert(
        np.all(flatmap.voxel_dimensions == depths.voxel_dimensions),
        "ERROR: Flatmap and depths map voxels inconsistent!",
    )
    log.log_assert(
        np.all(flatmap.bbox == depths.bbox),
        "ERROR: Flatmap and depths map bounding boxes inconsistent!",
    )

    if xy_scale is None:  # x/y scaling from a.u. to um
        xy_scale = [
            flatmap.voxel_dimensions[0],
            flatmap.voxel_dimensions[1],
        ]  # Default: Assume same pixel size in flat space as voxel size in atlas
    else:
        log.log_assert(
            np.array(xy_scale).size == 2
            and np.all(np.isfinite(xy_scale))
            and np.all(np.array(xy_scale) != 0),
            "ERROR: XY scale error!",
        )
    if z_scale is None:  # z scaling from a.u. to um
        z_scale = 1.0  # Default: Assume depth values are already scaled to um
    else:
        log.log_assert(
            np.array(z_scale).size == 1 and np.isfinite(z_scale) and z_scale != 0,
            "ERROR: Z scale error!",
        )
    log.info(
        f'Loaded x/y flatmap ("{xy_file}"; scale={np.round(xy_scale, decimals=2)}) and z (depth) map ("{z_file}"; scale={np.round(z_scale, decimals=2)}) from "{flatmap_path}"'
    )

    # Convert cell positions to flat space [Assume: missing values set to -1]
    flat_xy = flatmap.lookup(nrn_pos)  # Single lookup shared by the x and y components
    flat_x = flat_xy[:, 0].astype(float)
    flat_y = flat_xy[:, 1].astype(float)
    flat_z = depths.lookup(nrn_pos).astype(float)

    # Validity masks are taken from the raw (unscaled) values, so that a valid value
    # which happens to scale to exactly -1 is not mistaken for a missing one
    valid_x = flat_x != -1
    valid_y = flat_y != -1
    valid_z = flat_z != -1
    flat_x[valid_x] = flat_x[valid_x] * xy_scale[0]
    flat_y[valid_y] = flat_y[valid_y] * xy_scale[1]
    flat_z[valid_z] = flat_z[valid_z] * z_scale

    # Determine map indices/positions
    map_indices = flatmap.positions_to_indices(
        nrn_pos, keep_fraction=True
    )  # Keep fractions within voxels => fraction x.5 corresponds to voxel center
    map_pos = np.floor(map_indices) + 0.5  # Voxel values assumed to correspond to voxel center

    if NN_only:
        # Nearest-neighbor interpolation only (FASTER)
        log.debug("Using nearest-neighbor interpolation only!")

    flat_x_intpl = _interpolate_flat_coord(map_pos, map_indices, flat_x, valid_x, NN_only)
    flat_y_intpl = _interpolate_flat_coord(map_pos, map_indices, flat_y, valid_y, NN_only)
    flat_z_intpl = _interpolate_flat_coord(map_pos, map_indices, flat_z, valid_z, NN_only)

    flat_pos = np.vstack((flat_x_intpl, flat_y_intpl, flat_z_intpl)).T

    return {"nrn_ids": nrn_ids, "nrn_lay": nrn_lay, "nrn_pos": nrn_pos, "flat_pos": flat_pos}


def _interpolate_flat_coord(map_pos, map_indices, values, valid, nn_only):
    """Interpolates one flat space coordinate at the fractional voxel positions of all neurons.

    Args:
        map_pos (numpy.ndarray): Voxel center positions (in index space) the values were sampled at
        map_indices (numpy.ndarray): Fractional voxel indices at which to evaluate the interpolation
        values (numpy.ndarray): Sampled coordinate values, one per row of map_pos
        valid (numpy.ndarray): Boolean mask selecting the non-missing entries of values
        nn_only (bool): If set, only nearest-neighbor interpolation is used

    Returns:
        numpy.ndarray: Interpolated coordinate values, one per row of map_indices
    """
    src_pos = map_pos[valid]
    src_val = values[valid]

    if nn_only:
        return griddata(src_pos, src_val, map_indices, method="nearest")

    # Linear interpolation, if possible
    intpl = griddata(src_pos, src_val, map_indices)

    # Nearest-neighbor interpolation, otherwise (linear griddata yields NaN where undefined)
    missing = np.isnan(intpl)
    if np.any(missing):
        intpl[missing] = griddata(src_pos, src_val, map_indices[missing], method="nearest")
    return intpl
[docs]
def build(nrn_ids, flat_pos, **_):
    """Creates a position mapping model from flat space neuron positions.

    Args:
        nrn_ids (list-like): List of mapped neuron IDs, as returned by :func:`extract`
        flat_pos (numpy.ndarray): Table of mapped neuron positions in 3D flat space of size <#neurons x 3>, as returned by :func:`extract`

    Returns:
        connectome_manipulator.model_building.model_types.PosMapModel: Resulting position mapping model
    """
    # One row per neuron ID, holding its three flat space coordinates
    coord_names = ["x", "y", "z"]
    pos_by_id = pd.DataFrame(flat_pos, index=nrn_ids, columns=coord_names)

    # Assert that no coordinate is NaN or infinite
    all_finite = np.all(np.isfinite(pos_by_id))
    log.log_assert(all_finite, "ERROR: Position error!")

    # Wrap the position table into the model
    pos_map_model = model_types.PosMapModel(pos_table=pos_by_id)
    log.debug("Model description:\n%s", pos_map_model)

    return pos_map_model
[docs]
def plot(out_dir, nrn_ids, nrn_lay, nrn_pos, model, **_):  # pragma: no cover
    """Visualizes neuron positions in original space vs. mapped space from model output.

    Three figures are written to ``out_dir``: 3D positions per layer
    ("data_vs_model_positions.png"), pairwise distances ("data_vs_model_distances.png"),
    and nearest-neighbor agreement ("data_vs_model_neighbors.png"). For more than 10000
    neurons, the distance and nearest-neighbor plots use a random subsample (without
    repetition) of 10000 neurons.

    Args:
        out_dir (str): Path to output directory where the results figures will be stored
        nrn_ids (list-like): List of mapped neuron IDs, as returned by :func:`extract`
        nrn_lay (list-like): List of layer property values for all mapped neurons, as returned by :func:`extract`
        nrn_pos (numpy.ndarray): Table of original neuron positions in 3D atlas space of size <#neurons x 3>, as returned by :func:`extract`
        model (connectome_manipulator.model_building.model_types.PosMapModel): Resulting position mapping model, as returned by :func:`build`
    """
    # Coerce list-like inputs so that element-wise comparison and fancy indexing work
    nrn_ids = np.asarray(nrn_ids)
    nrn_lay = np.asarray(nrn_lay)
    nrn_pos_model = model.apply(gids=nrn_ids)

    # 3D cell positions in atlas vs. flat space
    layers = np.unique(nrn_lay)  # Actual layer values; not assumed to be 1..N integers
    lay_colors = plt.get_cmap("jet")(np.linspace(0, 1, len(layers)))
    views = [[90, 0], [0, 0]]
    pos_list = [nrn_pos, nrn_pos_model]
    lbl_list = ["Atlas space (data)", "Flat space (model)"]
    fig = plt.figure(figsize=(10, 3 * len(views)), dpi=300)
    fig.patch.set_facecolor("w")
    for vidx, v in enumerate(views):
        for pidx, (pos, lbl) in enumerate(zip(pos_list, lbl_list)):
            ax = fig.add_subplot(
                len(views), len(pos_list), vidx * len(pos_list) + pidx + 1, projection="3d"
            )
            for lay, lay_color in zip(layers, lay_colors):
                pos_sel = pos[nrn_lay == lay, :]
                ax.plot(
                    pos_sel[:, 0],
                    pos_sel[:, 1],
                    pos_sel[:, 2],
                    ".",
                    color=lay_color,
                    markersize=1.0,
                    alpha=0.5,
                    label=f"L{lay}",
                )
            ax.view_init(*v)
            ax.set_xlabel("x [$\\mu$m]")
            ax.set_ylabel("y [$\\mu$m]")
            ax.set_zlabel("z [$\\mu$m]")
            ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), ncol=1)
            if vidx == 0:
                ax.set_title(lbl + f"\n[N={len(nrn_ids)}cells]")
    plt.tight_layout()
    out_fn = os.path.abspath(os.path.join(out_dir, "data_vs_model_positions.png"))
    log.info(f"Saving {out_fn}...")
    plt.savefig(out_fn)

    # Cell distances in atlas vs. flat space
    max_plot = 10000
    if len(nrn_ids) > max_plot:
        log.debug("Using subsampling for distance plots!")
        # Sample without replacement: a neuron drawn twice would sit at zero distance
        # from itself in both spaces and skew the distance and nearest-neighbor results
        nrn_sel = np.random.choice(len(nrn_ids), max_plot, replace=False)
        nrn_ids = nrn_ids[nrn_sel]
        nrn_pos = nrn_pos[nrn_sel, :]
        nrn_pos_model = nrn_pos_model[nrn_sel, :]
    dist_mat_data = distance_matrix(nrn_pos, nrn_pos)
    dist_mat_model = distance_matrix(nrn_pos_model, nrn_pos_model)
    triu_idx = np.triu_indices(len(nrn_ids), 1)  # Each unordered pair once, w/o self-pairs
    dist_val_data = dist_mat_data[triu_idx]
    dist_val_model = dist_mat_model[triu_idx]
    # Array reductions instead of unpacking all values into max(); initial=0.0 keeps
    # this defined if there are no pairs (distances are non-negative)
    dist_max = max(np.max(dist_val_data, initial=0.0), np.max(dist_val_model, initial=0.0))
    plt.figure(figsize=(5, 5), dpi=300)
    plt.plot(dist_val_data, dist_val_model, "b.", alpha=0.1, markersize=1.0, markeredgecolor="none")
    plt.plot([0, dist_max], [0, dist_max], "k--")
    plt.xlim((0, dist_max))
    plt.ylim((0, dist_max))
    plt.grid(True)
    plt.xlabel("Distance in atlas space (data) [$\\mu$m]")
    plt.ylabel("Distance in flat space (model) [$\\mu$m]")
    plt.title(f"Cell distances in atlas vs. flat space [N={len(nrn_ids)}cells]")
    plt.tight_layout()
    out_fn = os.path.abspath(os.path.join(out_dir, "data_vs_model_distances.png"))
    log.info(f"Saving {out_fn}...")
    plt.savefig(out_fn)

    # Nearest neighbors in atlas vs. flat space
    # Rows sorted by distance; column 0 is skipped below as it is taken to be the neuron itself
    NN_mat_data = np.argsort(dist_mat_data, axis=1)
    NN_mat_model = np.argsort(dist_mat_model, axis=1)
    num_NN_list = list(range(1, 30, 1))
    NN_match = np.full(len(num_NN_list), np.nan)
    log.debug("Computing nearest neighbors in atlas vs. flat space...")
    pbar = progressbar.ProgressBar()
    for nidx, num_NN in enumerate(pbar(num_NN_list)):
        # Fraction of the num_NN nearest neighbors shared between both spaces, averaged over neurons
        NN_match[nidx] = np.mean(
            [
                len(np.intersect1d(NN_mat_data[i, 1 : 1 + num_NN], NN_mat_model[i, 1 : 1 + num_NN]))
                / num_NN
                for i in range(len(nrn_ids))
            ]
        )
    plt.figure(figsize=(5, 4), dpi=300)
    plt.plot(num_NN_list, NN_match, ".-")
    plt.grid(True)
    plt.ylim((0, 1))
    plt.xlabel("#Nearest neighbors")
    plt.ylabel("Mean match")
    plt.title(f"Nearest neighbors in atlas vs. flat space [N={len(nrn_ids)}cells]")
    out_fn = os.path.abspath(os.path.join(out_dir, "data_vs_model_neighbors.png"))
    log.info(f"Saving {out_fn}...")
    plt.savefig(out_fn)