Load the data

# collapse
import datetime
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix

from graspologic.match import GraphMatch
from graspologic.match.qap import _doubly_stochastic
from src.visualization import adjplot

t0 = time.time()

sns.set_context("talk")

meta_path = "ALPN_crossmatching/data/meta.csv"
nblast_path = "ALPN_crossmatching/data/nblast_scores.csv"

meta = pd.read_csv(meta_path, index_col=0)
meta = meta.set_index("id")
meta["label"].fillna("unk", inplace=True)
nblast_scores = pd.read_csv(nblast_path, index_col=0, header=0)
nblast_scores.columns = nblast_scores.columns.astype(int)
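
A quick sanity check on what was loaded (a minimal sketch; the shapes and counts are whatever the CSVs contain):

# collapse
print(meta.shape)  # one row of metadata per neuron
print(nblast_scores.shape)  # square matrix of pairwise NBLAST scores
print(meta["source"].value_counts())  # neurons per dataset/hemisphere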

Look at the data

# collapse
adjplot(
    nblast_scores.values,
    meta=meta,
    sort_class=["source"],
    item_order="lineage",
    colors="lineage",
    cbar_kws=dict(shrink=0.7),
)
[Figure: NBLAST score matrix, sorted by source, items ordered and colored by lineage]

# collapse
adjplot(
    nblast_scores.values,
    meta=meta,
    sort_class=["lineage"],
    item_order="source",
    colors="source",
    cbar_kws=dict(shrink=0.7),
)
[Figure: NBLAST score matrix, sorted by lineage, items ordered and colored by source]

Plot the distribution of pairwise scores

# collapse
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
sns.histplot(nblast_scores.values.ravel(), element="step", stat="density")
[Figure: distribution of pairwise NBLAST scores]

Split the NBLAST scores by dataset

# collapse
datasets = ["FAFB(L)", "FAFB(R)"]
dataset1_meta = meta[meta["source"] == datasets[0]]
dataset2_meta = meta[meta["source"] == datasets[1]]

dataset1_ids = dataset1_meta.index
dataset1_intra = nblast_scores.loc[dataset1_ids, dataset1_ids].values

dataset2_ids = dataset2_meta.index
dataset2_intra = nblast_scores.loc[dataset2_ids, dataset2_ids].values

# TODO use these also via the linear term in GMP
dataset1_to_dataset2 = nblast_scores.loc[dataset1_ids, dataset2_ids].values
dataset2_to_dataset1 = nblast_scores.loc[dataset2_ids, dataset1_ids].values
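
On the TODO above: the cross-dataset scores could enter through the linear term of the graph matching problem objective, trace(APB^T P^T) + trace(SP^T). Below is a minimal sketch of evaluating that objective for a candidate permutation; the helper gmp_objective is hypothetical, and assumes the matrices have already been padded to equal size.

# collapse
def gmp_objective(A, B, S, perm):
    # A, B: intra-dataset score matrices (equal size, e.g. after padding)
    # S: cross-dataset similarity matrix, used as the linear term
    # perm: perm[i] is the index in dataset 2 matched to node i in dataset 1
    B_perm = B[perm][:, perm]  # P B P^T written as an indexing operation
    quadratic = np.sum(A * B_perm)  # equals trace(A P B^T P^T)
    linear = S[np.arange(len(perm)), perm].sum()  # equals trace(S P^T)
    return quadratic + linear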

Plot the NBLAST scores before alignment

# collapse
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
adjplot(dataset1_intra, cbar=False, ax=axs[0])
adjplot(dataset2_intra, cbar=False, ax=axs[1])
[Figure: intra-dataset NBLAST scores, FAFB(L) (left) and FAFB(R) (right), before matching]

Run the NBLAST score matching without using any prior information

# collapse
gm = GraphMatch(
    n_init=100,  # number of restarts for the optimization
    init="barycenter",  # start from the flat doubly stochastic matrix
    max_iter=200,
    shuffle_input=True,
    eps=1e-5,
    gmp=True,  # maximize the graph matching problem objective
    padding="naive",  # the two datasets are different sizes
)

gm.fit(dataset1_intra, dataset2_intra)
perm_inds = gm.perm_inds_
print(f"Matching objective function: {gm.score_}")
Matching objective function: 6227.852163812509

# collapse
# apply the matching: permute dataset 2, then keep the rows matched to dataset 1
dataset2_intra_matched = dataset2_intra[perm_inds][:, perm_inds][: len(dataset1_ids)]
dataset2_meta_matched = dataset2_meta.iloc[perm_inds][: len(dataset1_ids)]

Plot the NBLAST scores after alignment

# collapse
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
adjplot(dataset1_intra, cbar=False, ax=axs[0])
adjplot(dataset2_intra_matched, cbar=False, ax=axs[1])
[Figure: intra-dataset NBLAST scores, FAFB(L) (left) and matched FAFB(R) (right), after matching]

Peek at the metadata after alignment

# collapse
dataset1_meta
id | name | lineage | label | is_canonical | ntype | source
---|------|---------|-------|--------------|-------|-------
12201990 | neuron 12201991 mPN mALT left GD Amy | lPN | unk | False | mPN | FAFB(L)
12201994 | neuron 12201995 mPN mALT left XZ former Dragon | lPN | unk | False | mPN | FAFB(L)
12202003 | neuron 12202004 mPN mALT left GD Avril | lPN | unk | False | mPN | FAFB(L)
27 | Uniglomerular mALT DM4 adPN DB ECM | adPN | DM4_adPN | True | uPN | FAFB(L)
2076 | Uniglomerular mALT DC1 adPN PN015 DB | adPN | DC1_adPN | True | uPN | FAFB(L)
... | ... | ... | ... | ... | ... | ...
2349018 | Uniglomerular mALT DL3 lPN 2349019 RJVR | lPN | DL3_lPN | True | uPN | FAFB(L)
2349022 | Uniglomerular mALT DL3 lPN 2349023 RJVR | lPN | DL3_lPN | True | uPN | FAFB(L)
4070 | Uniglomerular mALT VC3l adPN DB XZ | adPN | VC3l_adPN | True | uPN | FAFB(L)
2349030 | Uniglomerular mALT DL3 lPN (outlier) 2349031 RJVR | lPN | DL3_lPN | True | uPN | FAFB(L)
434153 | Uniglomerular mALT VM1 lPN 434154 | lPN | VM1_lPN | True | uPN | FAFB(L)

335 rows × 6 columns

# collapse
dataset2_meta_matched
id | name | lineage | label | is_canonical | ntype | source
---|------|---------|-------|--------------|-------|-------
57430 | Multiglomerular mALT lPN VP2+VL1+VP3+2 LTS 0.9... | lPN | unk | False | mPN | FAFB(R)
57426 | Multiglomerular mALT lPN VP1d+DM4+VC1+5 LTS 0.... | lPN | unk | False | mPN | FAFB(R)
3648956 | Multiglomerular mALT lPN VP3+VP2+VL1+1 LTS 0.9... | lPN | unk | False | mPN | FAFB(R)
39682 | Uniglomerular mALT DM4 adPN 39683 JMR | adPN | DM4_adPN | True | uPN | FAFB(R)
771242 | Uniglomerular mALT DC1 adPN 23665 AA | adPN | DC1_adPN | True | uPN | FAFB(R)
... | ... | ... | ... | ... | ... | ...
177706 | Uniglomerular mALT DL3 lPN 177707 AEB | lPN | DL3_lPN | True | uPN | FAFB(R)
57311 | Uniglomerular mALT DA1 lPN 57312 LK | lPN | DA1_lPN | True | uPN | FAFB(R)
59129 | Uniglomerular mALT VC3l adPN 59130 AJ | adPN | VC3l_adPN | True | uPN | FAFB(R)
77661 | Uniglomerular mALT DL3 lPN 77662 LK | lPN | DL3_lPN | True | uPN | FAFB(R)
24726 | Uniglomerular mALT VM1 lPN 24727 BH | lPN | VM1_lPN | True | uPN | FAFB(R)

335 rows × 6 columns
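
Since the two tables above are aligned row-for-row by the predicted matching, a quick way to score label agreement directly (a sketch):

# collapse
match_rate = (
    dataset1_meta["label"].values == dataset2_meta_matched["label"].values
).mean()
print(f"{match_rate:.2f} of matched pairs share a label")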

Plot confusion matrices for the predicted matching

# collapse


def confusionplot(
    labels1,
    labels2,
    ax=None,
    figsize=(10, 10),
    xlabel="",
    ylabel="",
    title="Confusion matrix",
    annot=True,
    add_diag_proportion=True,
    **kwargs,
):
    unique_labels = np.unique(list(labels1) + list(labels2))
    conf_mat = confusion_matrix(labels1, labels2, labels=unique_labels, normalize=None)
    conf_mat = pd.DataFrame(data=conf_mat, index=unique_labels, columns=unique_labels)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)
    sns.heatmap(
        conf_mat,
        ax=ax,
        square=True,
        cmap="RdBu_r",
        center=0,
        cbar_kws=dict(shrink=0.6),
        annot=annot,
        fmt="d",
        mask=conf_mat == 0,
        **kwargs,
    )
    ax.set(ylabel=ylabel, xlabel=xlabel)
    if add_diag_proportion:
        on_diag = np.trace(conf_mat.values) / np.sum(conf_mat.values)
        title += f" ({on_diag:0.2f} correct)"
    ax.set_title(title, fontsize="large", pad=10)
    return ax

Confusion matrix for neuron type

# collapse
confusionplot(
    dataset1_meta["ntype"],
    dataset2_meta_matched["ntype"],
    ylabel=datasets[0],
    xlabel=datasets[1],
    title="Type confusion matrix",
)
[Figure: type confusion matrix (0.96 correct), FAFB(L) vs. FAFB(R)]

Confusion matrix for lineage

# collapse
confusionplot(
    dataset1_meta["lineage"],
    dataset2_meta_matched["lineage"],
    ylabel=datasets[0],
    xlabel=datasets[1],
    title="Lineage confusion matrix",
)
[Figure: lineage confusion matrix (0.97 correct), FAFB(L) vs. FAFB(R)]

Confusion matrix for label

NB: There are many "unknown" entries in the label category, which were throwing off the color palette here, so I clipped the color range at the maximum count among the non-unknown categories. The unknowns could also be skewing the accuracy, though (e.g. unk matched to unk counts as correct).

# collapse
labels1 = dataset1_meta["label"]
dataset1_vmax = labels1.value_counts()[1:].max()  # skip the most common (unknown) label
labels2 = dataset2_meta_matched["label"]
dataset2_vmax = labels2.value_counts()[1:].max()
vmax = max(dataset1_vmax, dataset2_vmax)


confusionplot(
    labels1,
    labels2,
    ylabel=datasets[0],
    xlabel=datasets[1],
    title="Label confusion matrix",
    annot=False,
    vmax=vmax,
    xticklabels=False,
    yticklabels=False,
)
[Figure: label confusion matrix (0.90 correct), FAFB(L) vs. FAFB(R)]

Accuracy for the above, ignoring unclear/unknown

# collapse
unique_labels = np.unique(list(labels1) + list(labels2))
conf_mat = confusion_matrix(labels1, labels2, labels=unique_labels, normalize=None)
conf_mat = pd.DataFrame(data=conf_mat, index=unique_labels, columns=unique_labels)
conf_mat = conf_mat.iloc[:-5, :-5]  # hack to ignore anything "unclear"
on_diag = np.trace(conf_mat.values) / np.sum(conf_mat.values)
print(f"{on_diag:.2f}")
0.88
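
The positional slice above relies on the unclear/unknown labels sorting to the end of unique_labels; here is a sketch of a name-based alternative (assuming those labels are identifiable by substring, which may need adjusting to the actual label set):

# collapse
conf_mat_full = pd.DataFrame(
    confusion_matrix(labels1, labels2, labels=unique_labels),
    index=unique_labels,
    columns=unique_labels,
)
keep = [l for l in unique_labels if "unk" not in l.lower() and "unclear" not in l.lower()]
conf_mat_known = conf_mat_full.loc[keep, keep]
print(f"{np.trace(conf_mat_known.values) / np.sum(conf_mat_known.values):.2f}")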

Matching with a prior

Here we try to use the group label as a soft prior (not a hard constraint) on the matching procedure.

We do this by initializing from the "groupycenter" as opposed to the barycenter of the set of doubly stochastic matrices: rather than spreading mass uniformly over all entries, the initialization spreads mass uniformly only within blocks of matching lineage.
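
As a toy illustration of the construction below (a hypothetical four-neuron example with two lineages): mass is spread uniformly within matching lineage blocks and is zero elsewhere, before the small additive constant and Sinkhorn balancing are applied.

# collapse
toy_groups1 = np.array(["A", "A", "B", "B"])
toy_groups2 = np.array(["A", "A", "B", "B"])
toy = np.zeros((4, 4))
for g in ["A", "B"]:
    i1 = np.where(toy_groups1 == g)[0]
    i2 = np.where(toy_groups2 == g)[0]
    # uniform mass over the block of possible within-group matches
    toy[np.ix_(i1, i2)] = 1 / max(len(i1), len(i2))
print(toy)  # block-diagonal, with 0.5 in the A and B blocks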

Construct an initialization from the lineages

# collapse

groups1 = dataset1_meta["lineage"]
groups2 = dataset2_meta["lineage"]

unique_groups = np.unique(list(groups1) + list(groups2))

n = len(groups2)  # must be the size of the larger dataset
D = np.zeros((n, n))

for group in unique_groups:
    inds1 = np.where(groups1 == group)[0]
    inds2 = np.where(groups2 == group)[0]
    n_groups = [len(inds1), len(inds2)]
    if min(n_groups) != 0:
        # spread mass uniformly over the block of potential within-group matches
        val = 1 / max(n_groups)
        layer = np.zeros((n, n))
        layer[np.ix_(inds1, inds2)] = val
        D += layer

D += 1 / (n ** 2)  # need to add something small for Sinkhorn to converge
D0 = _doubly_stochastic(D)
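
A quick check (a sketch) that the balanced initialization is in fact doubly stochastic; the tolerance is loose since Sinkhorn is iterative:

# collapse
print(np.allclose(D0.sum(axis=0), 1, atol=1e-2))  # column sums ~ 1
print(np.allclose(D0.sum(axis=1), 1, atol=1e-2))  # row sums ~ 1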

Run matching from the informed initialization

# collapse
gm = GraphMatch(
    n_init=100,
    init=D0,
    max_iter=200,
    shuffle_input=True,
    eps=1e-5,
    gmp=True,
    padding="naive",
)

gm.fit(dataset1_intra, dataset2_intra)
perm_inds = gm.perm_inds_
print(f"Matching objective function: {gm.score_}")
Matching objective function: 6229.611251407659

# collapse
dataset2_intra_matched = dataset2_intra[perm_inds][:, perm_inds][: len(dataset1_ids)]
dataset2_meta_matched = dataset2_meta.iloc[perm_inds][: len(dataset1_ids)]

Plot confusion matrices for the predicted matching started from the prior

Confusion matrix for neuron type

# collapse
confusionplot(
    dataset1_meta["ntype"],
    dataset2_meta_matched["ntype"],
    ylabel=datasets[0],
    xlabel=datasets[1],
    title="Type confusion matrix",
)
[Figure: type confusion matrix (0.95 correct), FAFB(L) vs. FAFB(R)]

Confusion matrix for lineage

# collapse
confusionplot(
    dataset1_meta["lineage"],
    dataset2_meta_matched["lineage"],
    ylabel=datasets[0],
    xlabel=datasets[1],
    title="Lineage confusion matrix",
)
[Figure: lineage confusion matrix (0.96 correct), FAFB(L) vs. FAFB(R)]

Confusion matrix for label

# collapse
labels1 = dataset1_meta["label"]
dataset1_vmax = labels1.value_counts()[1:].max()
labels2 = dataset2_meta_matched["label"]
dataset2_vmax = labels2.value_counts()[1:].max()
vmax = max(dataset1_vmax, dataset2_vmax)


confusionplot(
    labels1,
    labels2,
    ylabel=datasets[0],
    xlabel=datasets[1],
    title="Label confusion matrix",
    annot=False,
    vmax=vmax,
    xticklabels=False,
    yticklabels=False,
)
[Figure: label confusion matrix (0.90 correct), FAFB(L) vs. FAFB(R)]

Observations/notes

  • Matching accuracy looked worse when I tried random initializations instead of the barycenter
  • Open question of what to do with the weights themselves; I was expecting to need pass-to-ranks or some other transform, but the raw scores seemed to work fairly well
  • 'VUMa2' is a lineage present in one FAFB hemisphere but not the other
  • The solution from the groupycenter initialization doesn't seem that different; it's possible that the barycenter initialization already finds a similar score/matching

End

elapsed = time.time() - t0
delta = datetime.timedelta(seconds=elapsed)
print("----")
print(f"Script took {delta}")
print(f"Completed at {datetime.datetime.now()}")
print("----")
----
Script took 0:11:49.256432
Completed at 2021-03-19 17:59:15.240449
----