Hierarchical Clustering Analysis - Omnibus Embedding#

from itertools import combinations_with_replacement

import graspologic as gp
import matplotlib
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from scipy.stats import kruskal
from seaborn.utils import relative_luminance
from sklearn.preprocessing import LabelEncoder
from statsmodels.stats.multitest import multipletests

import hyppo
from hyppo.ksample import MANOVA, KSample
from pkg.data import (
    GENOTYPES,
    HEMISPHERES,
    SUB_STRUCTURES,
    SUPER_STRUCTURES,
    load_fa_corr,
    load_vertex_df,
    load_vertex_metadata,
    load_volume_corr,
)
from pkg.inference import run_ksample, run_pairwise
from pkg.plot import plot_heatmaps, plot_pairwise
from pkg.utils import squareize

matplotlib.rcParams["font.family"] = "monospace"

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
# Load the data

volume_correlations, labels = load_volume_corr()

meta = load_vertex_df()
## Plot to make sure nothing is wrong
fig, ax = plt.subplots(
    ncols=3,
    figsize=(10, 3),
    constrained_layout=True,
    dpi=300,
    gridspec_kw=dict(width_ratios=[1, 1, 1]),
)

for i in range(len(volume_correlations)):
    gp.plot.adjplot(
        volume_correlations[i],
        ax=ax[i],
        vmin=-1,
        vmax=1,
        # meta=vertex_name,
        # group=["Hemisphere_abbrev"],
    )
    ax[i].set_title(f"{GENOTYPES[i]}", pad=0, size=12)
[Figure: volume correlation matrix heatmaps, one panel per genotype]

Embed the data simultaneously using Omnibus embedding#

The purpose of omnibus embedding is to obtain a low-dimensional representation of the three correlation matrices such that the embedded correlation matrices can be compared to each other in a meaningful way [Athreya et al., 2018, Lyzinski et al., 2014, Sussman et al., 2012, Sussman et al., 2014]. The embedding provides a low-dimensional vector per brain region for each correlation matrix, resulting in a \(332\times d\) matrix per genotype, where \(d\) is the “embedding dimension” and \(d \ll 332\). The resulting vector for each region is called the latent position of that vertex. Omnibus embedding is similar in spirit to applying PCA to a dataset to obtain a reduced number of features.

We will use the embeddings to perform hierarchical clustering [Athey et al., 2019].

omni = gp.embed.OmnibusEmbed(n_components=2)
Xhats = omni.fit_transform([*volume_correlations])

print(omni.n_components_)
2

Here we fix the embedding dimension at \(d=2\); the printed value of omni.n_components_ confirms the dimension that was used. (If n_components is left unspecified, graspologic chooses \(d\) automatically.)
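As a quick sanity check (an illustrative addition, not part of the original analysis), the embedding should return one latent position per region and genotype, i.e. one \(332 \times d\) matrix per genotype:

# Illustrative check: one (n_regions x d) latent position matrix per genotype
n_regions = volume_correlations[0].shape[0]
for genotype, Xhat in zip(GENOTYPES, Xhats):
    print(genotype, Xhat.shape)
    assert Xhat.shape == (n_regions, omni.n_components_)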

Perform hierarchical clustering#

First, we concatenate the embeddings, resulting in a matrix of size \(332 \times 3d\). We then iteratively divide the regions into two clusters: the \(332\) regions are split into two clusters, each resulting cluster is split into two more clusters, and so on. The clustering algorithm used at each division is Gaussian mixture modeling [Athey et al., 2019].

The clustering tells us which regions should be grouped together; the resulting clusters define the “subgraphs” analyzed below.

max_level = 4

cluster = gp.cluster.DivisiveCluster(max_level=max_level, cluster_kws={"kmeans_n_init": 5})

cluster_labels = cluster.fit_predict(np.hstack(Xhats), fcluster=True)

cluster_label_df = pd.DataFrame(
    cluster_labels, columns=[f"cluster_level_{i}" for i in range(1, max_level + 1)]
)
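To get a feel for the hierarchy (a small illustrative check, not part of the original analysis), we can confirm the shape of the concatenated feature matrix and count how many regions fall into each cluster at every level:

# The concatenated embeddings form a (332, 3 * d) feature matrix
print(np.hstack(Xhats).shape)

# Number of regions assigned to each cluster at every level of the hierarchy
for level in range(1, max_level + 1):
    counts = cluster_label_df[f"cluster_level_{level}"].value_counts().sort_index()
    print(f"Level {level}:", counts.to_dict())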

Plotting the cluster labeling along with a priori labels using Sankey diagrams#

Sankey diagrams tell us which regions or subregions belong to which clusters.

def count_groups(label_matrix):
    levels = label_matrix.shape[1] - 1
    d = []

    for level in range(levels):
        upper_cluster_ids = np.unique(label_matrix[:, level])

        for upper_cluster_id in upper_cluster_ids:
            lower_cluster_ids, counts = np.unique(
                label_matrix[label_matrix[:, level] == upper_cluster_id][:, level + 1],
                return_counts=True,
            )

            for idx, lower_cluster_id in enumerate(lower_cluster_ids):
                d.append((upper_cluster_id, lower_cluster_id, counts[idx]))

    d = np.array(d)

    source = d[:, 0]
    target = d[:, 1]
    value = d[:, 2]

    return source, target, value


def append_apriori_labels(apriori_labels, cluster_matrix):
    encoder = LabelEncoder()
    apriori_labels_encoded = encoder.fit_transform(apriori_labels)
    apriori_labels_encoded = apriori_labels_encoded.reshape(-1, 1)

    # Increase the original cluster_matrix labels
    cluster_matrix_ = cluster_matrix + np.max(apriori_labels_encoded) + 1

    out = np.hstack([apriori_labels_encoded, cluster_matrix_])

    return out, list(encoder.classes_)
hemispheric_clusters, encoded_labels = append_apriori_labels(meta.Hemisphere, cluster_labels)

source, target, value = count_groups(hemispheric_clusters)

fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=encoded_labels
                + [f"Cluster {i}" for i in range(np.max(hemispheric_clusters))],
            ),
            link=dict(source=source, target=target, value=value),
        )
    ]
)

fig.update_layout(title_text="Hemispheric Clustering", font_size=10)
fig.show(dpi=300, width=1000, height=600)
level_1_clusters, encoded_labels = append_apriori_labels(meta.Level_1, cluster_labels)

source, target, value = count_groups(level_1_clusters)

fig = go.Figure(
    data=[
        go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=encoded_labels + [f"Cluster {i}" for i in range(np.max(level_1_clusters))],
            ),
            link=dict(source=source, target=target, value=value),
        )
    ]
)

fig.update_layout(title_text="Level 1 Clustering", font_size=10)
fig.show(dpi=300, width=1000, height=600)

Visualizing different clustering levels using heatmaps#

Heatmaps can qualitatively tell us if there are any underlying structures within the clusters.

cl = pd.DataFrame(cluster_labels, columns=[f"Cluster_{i}" for i in range(1, max_level + 1)])

meta = pd.concat([meta, cl], axis=1)

meta.head()
   Structure                        Abbreviation  Hemisphere  Level_1  Level_2  Cluster_1  Cluster_2  Cluster_3  Cluster_4
0  Cingulate_Cortex_Area_24a        A24a          L           FB       IS       0          3          3          3
1  Cingulate_Cortex_Area_24a_prime  A24aPrime     L           FB       IS       1          5          11         14
2  Cingulate_Cortex_Area_24b        A24b          L           FB       IS       1          5          10         10
3  Cingulate_Cortex_Area_24b_prime  A24bPrime     L           FB       IS       1          5          11         15
4  Cingulate_Cortex_Area_29a        A29a          L           FB       IS       0          2          7          7
## Plot to make sure nothing is wrong
fig, ax = plt.subplots(
    ncols=3,
    figsize=(20, 10),
    # constrained_layout=True,
    dpi=300,
    gridspec_kw=dict(width_ratios=[1, 1, 1]),
)

for (i, genotype) in enumerate(volume_correlations):
    gp.plot.adjplot(
        genotype,
        ax=ax[i],
        vmin=-1,
        vmax=1,
        meta=meta,
        group=[f"Cluster_{max_level}"],
    )
    ax[i].set_title(f"{labels[i]}", pad=90, size=30)

# fig.savefig(f"./figures/2022-02-02-multigraph-clustering-level-{l + 1}.png", bbox_inches='tight')
[Figure: volume correlation matrix heatmaps per genotype, with rows and columns grouped by level-4 cluster]

The cluster structure appears to be driven predominantly by the large correlations in the APOE3 genotype.

Testing for significantly different clusters at various levels#

Again, the clusters tell us which regions should be grouped together; hence, each cluster defines a subgraph. For each subgraph, we test whether the distributions of the latent positions (i.e., the embeddings) differ significantly across genotypes. Specifically, we test

\[\begin{align}
H_0:& \qquad F_{APOE2}=F_{APOE3}=F_{APOE4}\\
H_A:& \qquad \text{At least one pair of distributions is different}
\end{align}\]

using a 3-sample distance correlation. We correct for multiple hypothesis testing via Holm-Bonferroni correction.
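The tests below are run with the run_ksample helper from pkg.inference. Purely for illustration, a minimal sketch of the test described above, applied to the latent positions of each level-1 cluster with hyppo's k-sample distance correlation and a Holm-Bonferroni correction, might look like the following (the helper's actual implementation may differ):

# Illustrative sketch only; the analysis below uses pkg.inference.run_ksample.
pvalues = []
level_1_ids = np.unique(cluster_labels[:, 0])
for cluster_id in level_1_ids:
    mask = cluster_labels[:, 0] == cluster_id
    # One (n_regions_in_cluster x d) block of latent positions per genotype
    samples = [Xhat[mask] for Xhat in Xhats]
    _, pvalue = KSample("Dcorr").test(*samples)
    pvalues.append(pvalue)

# Holm-Bonferroni correction across clusters
reject, corrected, _, _ = multipletests(pvalues, alpha=0.05, method="holm")
print(dict(zip(level_1_ids, corrected)))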

import string
to_relabel = cl.astype(str)
relabels = {
    "0": "A1",
    "1": "A2",
}

for i in range(1, max_level + 1):
    col = to_relabel[f"Cluster_{i}"]
    if i > 1:
        # Deeper levels get a new letter (B, C, D, ...) and a running index;
        # level 1 only contains clusters "0" and "1", mapped to A1/A2 above
        uniques = np.unique(col)
        letter = string.ascii_uppercase[i - 1]
        for idx, u in enumerate(uniques):
            if u not in relabels:
                relabels[u] = letter + str(idx)
    to_relabel[f"Cluster_{i}"] = col.replace(relabels)

l1 = to_relabel.Cluster_1
l2 = to_relabel.Cluster_1 + "-" + to_relabel.Cluster_2
l3 = to_relabel.Cluster_1 + "-" + to_relabel.Cluster_2 + "-" + to_relabel.Cluster_3
l4 = (
    to_relabel.Cluster_1
    + "-"
    + to_relabel.Cluster_2
    + "-"
    + to_relabel.Cluster_3
    + "-"
    + to_relabel.Cluster_4
)
volume_ksample = [
    run_ksample(volume_correlations, cluster_labeling, idx, test="kruskal", absolute=True)
    for idx, cluster_labeling in enumerate([l1, l2, l3, l4])
]

volume_ksample = pd.concat(volume_ksample, ignore_index=True)

volume_ksample.to_csv("../results/outputs/omnibus_3sample_apriori.csv", index=False)
sns.set_context("talk", font_scale=0.5)
fig, _ = plot_heatmaps(volume_ksample, True)

fig.savefig("./figures/omnibus_ksample_volume.pdf")


fig, _ = plot_heatmaps(volume_ksample, cbar=False, ranked_pvalue=True)

fig.savefig("./figures/omnibus_ksample_volume_ranked_pvalue.pdf")
[Figures: k-sample test p-value heatmaps across cluster levels, shown as raw p-values and as ranked p-values]
volume_pairwise = run_pairwise(
    volume_correlations,
    GENOTYPES,
    l4,
    absolute=True,
    test="mannwhitney",
)
fig, _ = plot_pairwise(volume_pairwise, volume_ksample)
fig.savefig("./figures/omnibus_pairwise_volume.pdf")
[Figure: pairwise test results for the level-4 clusters]
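For reference, here is a rough sketch of a single pairwise comparison. It assumes, purely for illustration, that each pair of genotypes is compared on the absolute within-cluster correlation values of one level-4 cluster with a Mann-Whitney U test; the run_pairwise helper from pkg.inference may differ in its details.

from itertools import combinations

from scipy.stats import mannwhitneyu

# Rough, illustrative sketch; the analysis above uses pkg.inference.run_pairwise.
cluster_name = l4.unique()[0]
mask = (l4 == cluster_name).to_numpy()

pairwise_pvalues = {}
for (name_i, corr_i), (name_j, corr_j) in combinations(zip(GENOTYPES, volume_correlations), 2):
    # Absolute within-cluster correlations (upper triangle, diagonal excluded)
    sub_i = np.abs(corr_i[np.ix_(mask, mask)])
    sub_j = np.abs(corr_j[np.ix_(mask, mask)])
    x = sub_i[np.triu_indices_from(sub_i, k=1)]
    y = sub_j[np.triu_indices_from(sub_j, k=1)]
    _, pvalue = mannwhitneyu(x, y, alternative="two-sided")
    pairwise_pvalues[(name_i, name_j)] = pvalue

print(pairwise_pvalues)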