Look at it

Preliminaries

from pkg.utils import set_warnings

import datetime
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from graspologic.embed import selectSVD
from sklearn.preprocessing import normalize
from graspologic.utils import pass_to_ranks
from umap import AlignedUMAP
from giskard.plot import graphplot
from src.visualization import CLASS_COLOR_DICT

from graspologic.align import OrthogonalProcrustes, SeedlessProcrustes
from graspologic.embed import (
    AdjacencySpectralEmbed,
    OmnibusEmbed,
    select_dimension,
)
from graspologic.match import GraphMatch
from graspologic.plot import pairplot
from graspologic.utils import (
    augment_diagonal,
    binarize,
    multigraph_lcc_intersection,
    pass_to_ranks,
)
from pkg.data import (
    load_maggot_graph,
    load_palette,
    select_nice_nodes,
    load_node_palette,
    load_network_palette,
)
from pkg.io import savefig
from pkg.plot import set_theme
from giskard.utils import get_paired_inds

from src.visualization import adjplot  # TODO fix graspologic version and replace here
from giskard.plot import graphplot
from matplotlib.colors import LinearSegmentedColormap
from scipy.stats import binom


def stashfig(name, **kwargs):
    foldername = "look_at_it"
    savefig(name, foldername=foldername, **kwargs)

Load and process data

t0 = time.time()
set_theme()

network_palette, NETWORK_KEY = load_network_palette()
node_palette, NODE_KEY = load_node_palette()

mg = load_maggot_graph()
mg = select_nice_nodes(mg)
nodes = mg.nodes
left_nodes = nodes[nodes["hemisphere"] == "L"]
left_inds = left_nodes["_inds"]
right_nodes = nodes[nodes["hemisphere"] == "R"]
right_inds = right_nodes["_inds"]
left_paired_inds, right_paired_inds = get_paired_inds(
    nodes, pair_key="predicted_pair", pair_id_key="predicted_pair_id"
)
right_paired_inds_shifted = right_paired_inds - len(left_inds)
adj = mg.sum.adj
ll_adj = adj[np.ix_(left_inds, left_inds)]
rr_adj = adj[np.ix_(right_inds, right_inds)]
Removed 13 nodes when taking the largest connected component.
Removed 38 nodes when removing pendants.
Removed 0 nodes when taking the largest connected component.

Plot the ipsilateral subgraph adjacency matrices

def calculate_weighted_degrees(adj):
    return np.sum(adj, axis=0) + np.sum(adj, axis=1)


color_matrix = np.zeros_like(adj)
color_matrix[np.ix_(left_inds, left_inds)] = 0
color_matrix[np.ix_(right_inds, right_inds)] = 1
color_matrix[np.ix_(left_inds, right_inds)] = 2
color_matrix[np.ix_(right_inds, left_inds)] = 3
edge_palette = dict(zip(np.arange(4), sns.color_palette("Set2")))
nodes["degree"] = -calculate_weighted_degrees(adj)
adjplot(
    adj,
    plot_type="scattermap",
    sizes=(1, 1),
    meta=nodes,
    sort_class=["hemisphere"],
    item_order=["simple_group", "degree"],
    tick_fontsize=20,
    color_matrix=color_matrix,
    edge_palette=edge_palette,
)

stashfig("adj-degree-sort")
_images/look_at_it_6_0.png

Fit a 2-block model

n_left = len(left_inds)
n_right = len(right_inds)
# Note: graspologic has some code to do most of this but I wanted everything here to be
# above ground so to speak
# TODO ignore loops?
ll_n_edges = np.count_nonzero(adj[np.ix_(left_inds, left_inds)])
rr_n_edges = np.count_nonzero(adj[np.ix_(right_inds, right_inds)])
lr_n_edges = np.count_nonzero(adj[np.ix_(left_inds, right_inds)])
rl_n_edges = np.count_nonzero(adj[np.ix_(right_inds, left_inds)])
n_edges_matrix = np.array([[ll_n_edges, lr_n_edges], [rl_n_edges, rr_n_edges]])

ll_p_edge = ll_n_edges / (n_left ** 2)
rr_p_edge = rr_n_edges / (n_right ** 2)
lr_p_edge = lr_n_edges / (n_left * n_right)
rl_p_edge = rl_n_edges / (n_left * n_right)
p_edge_matrix = np.array([[ll_p_edge, lr_p_edge], [rl_p_edge, rr_p_edge]])

Plot the connection probability matrix

colors = sns.color_palette("Set2")


def make_custom_cmap(to_rgb, from_rgb=(1, 1, 1)):
    # REF: https://stackoverflow.com/questions/16267143/matplotlib-single-colored-colormap-with-saturation
    # from color r,g,b
    r1, g1, b1 = from_rgb

    # to color r,g,b
    r2, g2, b2 = to_rgb

    cdict = {
        "red": ((0, r1, r1), (1, r2, r2)),
        "green": ((0, g1, g1), (1, g2, g2)),
        "blue": ((0, b1, b1), (1, b2, b2)),
    }

    cmap = LinearSegmentedColormap("custom_cmap", cdict)
    return cmap


# forgive me god for what I'm about to do
fig, axs = plt.subplots(2, 2, figsize=(4, 4), gridspec_kw=dict(hspace=0, wspace=0))
heatmap_kws = dict(
    vmin=0, vmax=0.02, annot=True, cbar=False, xticklabels=False, yticklabels=False
)


def plot_heatmap_element(value, color, ax):
    cmap = make_custom_cmap(color)
    sns.heatmap(np.array([[value]]), cmap=cmap, ax=ax, **heatmap_kws)


values = [ll_p_edge, lr_p_edge, rl_p_edge, rr_p_edge]
ordered_colors = [colors[0], colors[2], colors[3], colors[1]]

for i, (val, col) in enumerate(zip(values, ordered_colors)):
    plot_heatmap_element(val, col, axs.flat[i])
axs[0, 0].set_ylabel("L", rotation=0, labelpad=10)
axs[1, 0].set_ylabel("R", rotation=0, labelpad=10)
axs[1, 0].set_xlabel("L")
axs[1, 1].set_xlabel("R")
fig.suptitle("Connection probabilities")
fig.text(0, 0.385, "Presynaptic", ha="center", rotation=90)
fig.text(0.52, -0.02, "Postsynaptic", ha="center")
stashfig("connection-probabilities")
_images/look_at_it_10_0.png

Plot the number of edges for each lateral type with 99% confidence intervals

fig, ax = plt.subplots(1, 1, figsize=(6, 3))
ns = [n_left ** 2, n_right ** 2, n_left * n_right, n_left * n_right]
edge_counts = [ll_n_edges, rr_n_edges, lr_n_edges, rl_n_edges]
for i, (n_edges, n) in enumerate(zip(edge_counts, ns)):
    p_edge = n_edges / n
    err = binom(n, p_edge).interval(0.99)
    err = np.array(err)
    ax.bar(i, n_edges, color=colors[i])
    ax.plot([i, i], err, color="black", zorder=90)
names = [r"L $\to$ L", r"R $\to$ R", r"L $\to$ R", r"R $\to$ L"]
ax.xaxis.set_major_locator(plt.FixedLocator(np.arange(4)))
ax.set(ylabel="# of edges", xticklabels=names)
stashfig("n_edges_plot")
_images/look_at_it_12_0.png

Plot a network layout of the whole network

ax = graphplot(
    network=adj,
    meta=nodes,
    n_components=64,
    n_neighbors=64,
    min_dist=0.8,
    hue=NODE_KEY,
    node_palette=node_palette,
    random_state=888888,
    sizes=(10, 40),
    hue_labels="radial",
    hue_label_fontsize="xx-small",
    adjust_labels=True,
    # verbose=True,
)
stashfig("whole-network-layout")
_images/look_at_it_14_0.png

Plot the network split out by left/right

side_palette = dict(zip(["L", "R"], sns.color_palette("Set2")))
# edge_palette = dict(
#     zip([("L", "L"), ("R", "R"), ("L", "R"), ("R", "L")], sns.color_palette("Set2"))
# )
ax = graphplot(
    network=adj,
    meta=nodes,
    n_components=64,
    n_neighbors=64,
    min_dist=0.8,
    hue="hemisphere",
    node_palette=side_palette,
    random_state=888888,
    sizes=(10, 40),
    hue_label_fontsize="xx-small",
    tile="hemisphere",
    tile_layout=[["L", "R"]],
    figsize=(20, 10),
    edge_linewidth=0.2,
    edge_alpha=0.2,
    subsample_edges=0.2,
    # edge_hue="prepost",
    # edge_palette=edge_palette,
    # hue_labels="radial",
    # adjust_labels=True,
    # verbose=True,
)
ax.text(0, 1, "Left", fontsize="x-large")
ax.text(2, 1, "Right", fontsize="x-large")
stashfig("split-network-layout")
_images/look_at_it_16_0.png

Simple statistics for the left hemisphere induced subgraph

ll_mg, rr_mg, _, _ = mg.bisect(paired=False)
ll_mg.sum
n_nodes n_edges sum_edge_weights
edge_type
sum 1481.0 36447.0 117068.0

Simple statistics for the right hemisphere induced subgraph

rr_mg.sum
n_nodes n_edges sum_edge_weights
edge_type
sum 1481.0 39096.0 129527.0

[Experimental] Plot a graph layout for each hemisphere using aligned UMAP

def ase(adj, n_components=None):
    U, S, Vt = selectSVD(adj, n_components=n_components, algorithm="full")
    S_sqrt = np.diag(np.sqrt(S))
    X = U @ S_sqrt
    Y = Vt.T @ S_sqrt
    return X, Y


def prescale_for_embed(adjs):
    norms = [np.linalg.norm(adj, ord="fro") for adj in adjs]
    mean_norm = np.mean(norms)
    adjs = [adjs[i] * mean_norm / norms[i] for i in range(len(adjs))]
    return adjs


n_components = 16  # 24 looked fine
power = True
normed = False
if power:
    ll_adj_for_umap = pass_to_ranks(ll_adj)
    rr_adj_for_umap = pass_to_ranks(rr_adj)
    if normed:
        ll_adj_for_umap = normalize(ll_adj_for_umap, axis=1)
        rr_adj_for_umap = normalize(rr_adj_for_umap, axis=1)
    ll_adj_for_umap = ll_adj_for_umap @ ll_adj_for_umap
    rr_adj_for_umap = rr_adj_for_umap @ rr_adj_for_umap
else:
    ll_adj_for_umap = pass_to_ranks(ll_adj)
    rr_adj_for_umap = pass_to_ranks(rr_adj)

X_ll, Y_ll = ase(ll_adj_for_umap, n_components=n_components)
X_rr, Y_rr = ase(rr_adj_for_umap, n_components=n_components)

Z_ll = np.concatenate((X_ll, Y_ll), axis=1)
Z_rr = np.concatenate((X_rr, Y_rr), axis=1)
# Z_ll, _ = ase(Z_ll, n_components=n_components)
# Z_rr, _ = ase(Z_rr, n_components=n_components)

relation_dict = dict(zip(left_paired_inds, right_paired_inds_shifted))

aumap = AlignedUMAP(
    random_state=88888,
    n_neighbors=32,
    min_dist=0.95,
    metric="cosine",
    alignment_regularisation=1e-2,
)
umap_embeds = aumap.fit_transform([Z_ll, Z_rr], relations=[relation_dict])

graphplot_kws = dict(sizes=(30, 60))
fig, axs = plt.subplots(1, 2, figsize=(20, 10))
graphplot(
    network=ll_adj,
    embedding=umap_embeds[0],
    meta=left_nodes,
    hue="merge_class",
    node_palette=CLASS_COLOR_DICT,
    ax=axs[0],
    **graphplot_kws,
)
graphplot(
    network=rr_adj,
    embedding=umap_embeds[1],
    meta=right_nodes,
    hue="merge_class",
    node_palette=CLASS_COLOR_DICT,
    ax=axs[1],
    **graphplot_kws,
)
stashfig("aligned-umap-layout")
_images/look_at_it_22_0.png
elapsed = time.time() - t0
delta = datetime.timedelta(seconds=elapsed)
print("----")
print(f"Script took {delta}")
print(f"Completed at {datetime.datetime.now()}")
print("----")
----
Script took 0:02:39.102026
Completed at 2021-05-12 18:58:23.461493
----