Matching the two sides of the brain

Using techniques for graph matching, we can try to uncover a one-to-one correspondence between the nodes of the left side of the brain and the nodes of the right side.

Preliminaries

from pkg.utils import set_warnings

import datetime
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


from graspologic.match import GraphMatch
from graspologic.utils import to_laplacian

from pkg.data import load_maggot_graph
from pkg.io import savefig, get_out_dir
from pkg.match import quadratic_assignment
from pkg.match.qap import _doubly_stochastic
from pkg.plot import set_theme
from src.visualization import adjplot  # TODO fix graspologic version and replace here
from pkg.data import load_network_palette

t0 = time.time()


def stashfig(name, **kwargs):
    """Forward to ``savefig`` with this notebook's output folder filled in.

    Parameters
    ----------
    name :
        Passed through to ``savefig`` as its first positional argument.
    **kwargs :
        Any further keyword arguments, forwarded to ``savefig`` unchanged.
    """
    savefig(name, foldername="matching_illustration", **kwargs)


# Seaborn "Set1" colors; in the visible code these are only referenced by the
# commented-out palette on the next line.
colors = sns.color_palette("Set1")
# palette = dict(zip(["Left", "Right"], colors))
# Palette loaded from pkg.data; plot_adjs indexes it with "Left" and "Right".
# NETWORK_KEY is not used anywhere in the visible code.
palette, NETWORK_KEY = load_network_palette()
# Apply the plotting theme from pkg.plot before any figure is drawn.
set_theme()

Load and process data

# Load the graph and keep only the nodes whose "paper_clustered_neurons"
# entry is truthy.
mg = load_maggot_graph()
mg = mg[mg.nodes["paper_clustered_neurons"]]

# Split into four sub-graphs; only the first two (ll_mg, rr_mg) are used below.
# NOTE(review): presumably left->left, right->right and the two cross-hemisphere
# blocks, with paired=True aligning node i on each side -- the later cells rely
# on that alignment when they score matches against np.arange(n).
ll_mg, rr_mg, lr_mg, rl_mg = mg.bisect(paired=True)

ll_adj = ll_mg.sum.adj
rr_adj = rr_mg.sum.adj

# Remember each node's current position so that the row order after sorting by
# "simple_group" can be used as an integer index into the adjacency matrices.
nodes = ll_mg.nodes
# NOTE(review): if ll_mg.nodes is not a copy, this adds an "_inds" column to
# the graph's own node table -- confirm that is acceptable.
nodes["_inds"] = range(len(nodes))
sorted_nodes = nodes.sort_values(["simple_group"])
sort_inds = sorted_nodes["_inds"]

# Apply the same row/column reordering to both sides so position i still refers
# to the same pair on the left and on the right.
ll_adj = ll_adj[np.ix_(sort_inds, sort_inds)]
rr_adj = rr_adj[np.ix_(sort_inds, sort_inds)]
# NOTE: normalizing both matrices with the R-DAD Laplacian lowered the match
# ratio to roughly 0.7, so it is left disabled.
# ll_adj = to_laplacian(ll_adj, form="R-DAD")
# rr_adj = to_laplacian(rr_adj, form="R-DAD")


def plot_adjs(left, right, title=""):
    """Draw two adjacency matrices side by side with ``adjplot``.

    Parameters
    ----------
    left, right :
        Matrices drawn in the first and second panel respectively.
    title : str
        Figure-level title placed with ``fig.suptitle``.

    Returns
    -------
    fig, axs
        The figure and the pair of axes created by ``plt.subplots``.
    """
    fig, axs = plt.subplots(1, 2, figsize=(15, 7))
    # (matrix, panel title, palette key) for each panel, in drawing order.
    panels = (
        (left, r"L $\to$ L", "Left"),
        (right, r"R $\to$ R", "Right"),
    )
    for ax, (adj, panel_title, side) in zip(axs, panels):
        adjplot(
            adj,
            plot_type="scattermap",
            sizes=(2, 2),
            ax=ax,
            title=panel_title,
            color=palette[side],
        )
    fig.suptitle(title, ha="center", x=0.51)
    return fig, axs


# Show both sides in their sorted order, before any shuffling is applied.
plot_adjs(ll_adj, rr_adj, title="Known alignment")
(<Figure size 1080x504 with 2 Axes>,
 array([<AxesSubplot:title={'center':'L $\\to$ L'}>,
        <AxesSubplot:title={'center':'R $\\to$ R'}>], dtype=object))
_images/matching_illustration_4_1.png

Run several restarts of graph matching to understand the variance in solutions

def generate_random_doubly_stochastic(n, alpha=0.5):
    """Blend the flat n-by-n matrix with a randomly generated balanced one.

    Parameters
    ----------
    n : int
        Side length of the square matrix.
    alpha : float
        Weight on the random component; the flat component gets ``1 - alpha``.
        With ``alpha=0`` the result is exactly ``ones((n, n)) / n``.

    Returns
    -------
    numpy.ndarray
        ``flat * (1 - alpha) + random * alpha`` of shape ``(n, n)``.
    """
    # Uniform random entries, then Sinkhorn balancing via _doubly_stochastic.
    random_part = _doubly_stochastic(np.random.uniform(size=(n, n)))
    # Every entry equal to 1/n.
    flat_part = np.ones((n, n)) / n
    return flat_part * (1 - alpha) + random_part * alpha


# Solver settings shared by every restart; "P0" is filled in per restart below.
options = dict(maximize=True, maxiter=30, shuffle_input=False, tol=1e-14)
n = len(ll_adj)
n_restarts = 10
# With alpha = 0 the initialization reduces to the flat matrix ones((n, n)) / n,
# so the only randomness introduced in this cell is the shuffle of the right side.
alpha = 0
initial_rows = []
# Track the restart with the highest objective value for the next section.
best_res = None
best_fun = -np.inf
for i in range(n_restarts):
    currtime = time.time()
    # Randomly permute the right side; argsort gives the permutation that
    # undoes the shuffle, which is the correct answer for the matching.
    shuffle_inds = np.random.permutation(n)
    unshuffle_inds = np.argsort(shuffle_inds)
    P0 = generate_random_doubly_stochastic(n, alpha)
    # Hand the solver a copy so the P0 stored on the result stays as generated.
    options["P0"] = P0.copy()
    res = quadratic_assignment(
        ll_adj, rr_adj[shuffle_inds][:, shuffle_inds], options=options
    )
    # Keep what is needed to replay this restart later.
    res["input_shuffle"] = shuffle_inds
    res["P0"] = P0
    initial_rows.append(res)
    if res["fun"] > best_fun:
        best_fun = res["fun"]
        best_res = res
    # Fraction of nodes assigned to their known partner.
    match_ratio = (unshuffle_inds == res["col_ind"]).mean()
    print(f"{time.time() - currtime:.3f} seconds elapsed for restart {i+1}.")
    print(f"Restart {i+1} did {res['nit']} iterations.")
    print(f"Final score: {res['fun']}")
    print(f"Match ratio: {match_ratio}")
    print()
10.431 seconds elapsed for restart 1.
Restart 1 did 30 iterations.
Final score: 826668.0
Match ratio: 0.8018166804293972

10.279 seconds elapsed for restart 2.
Restart 2 did 26 iterations.
Final score: 835306.0
Match ratio: 0.7745664739884393

13.261 seconds elapsed for restart 3.
Restart 3 did 30 iterations.
Final score: 834678.0
Match ratio: 0.791907514450867

13.670 seconds elapsed for restart 4.
Restart 4 did 30 iterations.
Final score: 813866.0
Match ratio: 0.7720891824938068

12.878 seconds elapsed for restart 5.
Restart 5 did 30 iterations.
Final score: 824108.0
Match ratio: 0.7745664739884393

11.989 seconds elapsed for restart 6.
Restart 6 did 28 iterations.
Final score: 822726.0
Match ratio: 0.7861271676300579

11.302 seconds elapsed for restart 7.
Restart 7 did 26 iterations.
Final score: 827739.0
Match ratio: 0.7795210569777044

11.782 seconds elapsed for restart 8.
Restart 8 did 30 iterations.
Final score: 822873.0
Match ratio: 0.7927332782824112

11.304 seconds elapsed for restart 9.
Restart 9 did 30 iterations.
Final score: 830558.0
Match ratio: 0.7803468208092486

11.620 seconds elapsed for restart 10.
Restart 10 did 28 iterations.
Final score: 825977.0
Match ratio: 0.7902559867877786

Illustrate graph matching progress by iteration

currtime = time.time()
# maxiter=1 makes each solver call stop after a single step, so the state can
# be recorded and plotted between steps.
options = dict(maximize=True, maxiter=1, shuffle_input=False, tol=1e-4)
n = len(ll_adj)
# Upper bound on the number of single-step calls; the loop exits earlier once
# the objective value stops changing.
max_iter = 30
# Replay the best restart: same shuffle of the right side, same starting point.
shuffle_inds = best_res["input_shuffle"]
P = best_res["P0"]
unshuffle_inds = np.argsort(shuffle_inds)
# The shuffled right-side matrix does not change between steps, so build it once.
rr_adj_shuffle = rr_adj[shuffle_inds][:, shuffle_inds]
plot_adjs(ll_adj, rr_adj_shuffle)
rows = []
fun = 0
for i in range(max_iter):
    # Warm-start from the previous step's final matrix.
    options["P0"] = P
    res = quadratic_assignment(ll_adj, rr_adj_shuffle, options=options)
    P = res["P_final"]
    res["iteration"] = i
    perm_inds = res["col_ind"]
    # Fraction of nodes assigned to their known partner after this step.
    match_ratio = (perm_inds == unshuffle_inds).mean()
    res["match_ratio"] = match_ratio
    plot_adjs(ll_adj, rr_adj_shuffle[perm_inds][:, perm_inds], title=f"Iteration {i}")
    rows.append(res)
    # Stop as soon as a step leaves the objective value unchanged.
    if res["fun"] == fun:
        break
    fun = res["fun"]
print(f"{time.time() - currtime:.3f} seconds elapsed.")

results = pd.DataFrame(rows)

# Objective value (top) and match ratio (bottom) against the step index.
fig, axs = plt.subplots(2, 1, figsize=(8, 8))
sns.lineplot(data=results, x="iteration", y="fun", ax=axs[0])
sns.lineplot(data=results, x="iteration", y="match_ratio", ax=axs[1])
12.034 seconds elapsed.
<AxesSubplot:xlabel='iteration', ylabel='match_ratio'>
_images/matching_illustration_8_2.png _images/matching_illustration_8_3.png _images/matching_illustration_8_4.png _images/matching_illustration_8_5.png _images/matching_illustration_8_6.png _images/matching_illustration_8_7.png _images/matching_illustration_8_8.png _images/matching_illustration_8_9.png _images/matching_illustration_8_10.png _images/matching_illustration_8_11.png _images/matching_illustration_8_12.png _images/matching_illustration_8_13.png _images/matching_illustration_8_14.png _images/matching_illustration_8_15.png _images/matching_illustration_8_16.png _images/matching_illustration_8_17.png

Start graph matching from the known pairs to see if we can improve

# Initialize the matcher at the identity matrix, i.e. at the known pairing.
gm = GraphMatch(n_init=25, init=np.eye(n), max_iter=100, shuffle_input=True, eps=1e-4)
gm.fit(ll_adj, rr_adj)
perm_inds = gm.perm_inds_
# Fraction of nodes whose matched index equals their own position, i.e. nodes
# that stayed with their known partner.
match_ratio = (perm_inds == np.arange(n)).mean()
print("Match ratio when initializing from known pairs:")
print(match_ratio)
print("GMP objective function after initializing from known pairs:")
print(gm.score_)
print("GMP objective function from known pairs:")
# Sum of the elementwise product of the two matrices with no permutation
# applied (assumes ndarray inputs, where * is elementwise -- TODO confirm).
print((ll_adj * rr_adj).sum())
Match ratio when initializing from known pairs:
0.8860445912469034
GMP objective function after initializing from known pairs:
868447.0
GMP objective function from known pairs:
865969.0

End

# Report the wall-clock time since t0 was recorded, framed by rule lines.
elapsed = time.time() - t0
delta = datetime.timedelta(seconds=elapsed)
print(
    "----",
    f"Script took {delta}",
    f"Completed at {datetime.datetime.now()}",
    "----",
    sep="\n",
)
----
Script took 0:15:31.255236
Completed at 2021-05-13 14:58:07.099221
----