from itertools import combinations, combinations_with_replacement

import graspologic as gp
import hyppo
import matplotlib
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from hyppo.ksample import MANOVA, KSample
from joblib import Parallel, delayed
from matplotlib.transforms import Bbox
from scipy.spatial.distance import squareform
from scipy.stats import kruskal
from seaborn.utils import relative_luminance
from sklearn.preprocessing import LabelEncoder
from statsmodels.stats.multitest import multipletests

from pkg.data import (
    GENOTYPES,
    HEMISPHERES,
    SUB_STRUCTURES,
    SUPER_STRUCTURES,
    load_fa_corr,
    load_vertex_df,
    load_vertex_metadata,
    load_volume,
    load_volume_corr,
)
from pkg.inference import run_ksample, run_pairwise
from pkg.plot import plot_heatmaps, plot_pairwise
from pkg.utils import rank, squareize

# Render all figure text in a monospace font.
matplotlib.rcParams["font.family"] = "monospace"

import warnings

# Suppress every warning for the remainder of the session.
warnings.simplefilter("ignore")

# Load the data: the volume correlations with their labels, and the vertex
# metadata table used further down to group vertices.
volume_correlations, labels = load_volume_corr()
meta = load_vertex_df()

# Apply `rank` to the absolute value of each entry along the first axis
# (presumably one correlation matrix per genotype -- confirm against
# `load_volume_corr`), and stack the results back into a single array.
volume_correlations = np.array(list(map(rank, np.abs(volume_correlations))))

# K-sample latent position distribution test

## Embed the data simultaneously using Omnibus embedding

The purpose of omnibus embedding is to obtain a low-dimensional representation of the three correlation matrices such that the embedded correlation matrices can be compared to each other in a meaningful way {cite:p}`Athreya2018-zp,ase-consistency-1,ase-consistency-2,ase-consistency-3`. The embedding provides a low-dimensional vector per region of the brain for each correlation matrix, resulting in a matrix $\hat{X} \in \mathbb{R}^{n \times d}$ per genotype, where $n$ is the number of brain regions, $d$ is the “embedding dimension”, and $d \ll n$. The resulting vector for each region is called the latent position of that vertex. Omnibus embedding is similar in spirit to applying PCA to a dataset to obtain a dataset with a reduced number of features.

# Jointly embed all correlation matrices with a default-configured omnibus
# embedding; `latents` holds the embedding returned for the whole collection.
omnibus_embedder = gp.embed.OmnibusEmbed()
latents = omnibus_embedder.fit_transform(volume_correlations)

## Running K-sample on all vertices

# Whole-brain test: compare the latent positions of all vertices across the
# embedded matrices (one positional argument per entry of `latents`) using
# the distance-correlation ("dcorr") K-sample test.
ksample = KSample("dcorr")

# Keep the result in a variable instead of only displaying it, so the statistic
# and p-value can be reused programmatically rather than copied by hand from
# the printed output.
overall_test = ksample.test(*latents, auto=False, workers=-1)

# Still display the result as the cell output.
overall_test
IndependenceTestOutput(stat=0.0909769965981636, pvalue=0.000999000999000999)

## Hierarchical

# Group label of every vertex at each level of the hierarchy, from coarse to
# fine: hemisphere, structure, hemisphere-structure, and
# hemisphere-structure-substructure.
vertex_hemispheres = meta.Hemisphere.values
vertex_structures = meta.Level_1.values
vertex_hemisphere_structures = (meta.Hemisphere + "-" + meta.Level_1).values
vertex_hemisphere_substructures = (
    meta.Hemisphere + "-" + meta.Level_1 + "-" + meta.Level_2
).values

# res[k] holds one p-value per unique group of level k, in np.unique (sorted)
# order of the group names.
res = []

# NOTE: the loop variable is deliberately not called `labels`, so the `labels`
# returned by `load_volume_corr()` is not overwritten.
for level_labels in [
    vertex_hemispheres,
    vertex_structures,
    vertex_hemisphere_structures,
    vertex_hemisphere_substructures,
]:
    level_pvals = []
    for group in np.unique(level_labels):
        in_group = level_labels == group

        if in_group.sum() > 3:
            # Restrict every embedding to the vertices of this group
            # (assumes axis 1 of `latents` indexes vertices -- TODO confirm)
            # and test whether they differ across the embedded matrices.
            _, pval = KSample("dcorr").test(
                *latents[:, in_group, :], auto=False, workers=-1
            )
        else:
            # Groups with three or fewer vertices are not tested; they are
            # recorded as non-significant.
            pval = 1

        level_pvals.append(pval)

    res.append(level_pvals)
# Prepend the whole-brain ("All") p-value so that it is corrected together with
# the per-group tests. The literal is the p-value printed by the K-sample test
# on all vertices earlier in this notebook.
# NOTE: re-running this cell on its own prepends the value a second time;
# re-run the hierarchical-test cell first.
res = [[0.000999000999000999]] + res

# Number of p-values contributed by each level, needed to undo the flattening.
level_sizes = [len(level) for level in res]

# Correct all p-values as one family, using statsmodels' default method.
# corrected[1] holds the corrected p-values in the order they were passed in.
corrected = multipletests(
    [pval for level in res for pval in level],
)

# log10 of the corrected p-values as a flat array (used for the colour limits).
c = np.log10(corrected[1])

# Split the flat vector back into one row per level, right-padded with NaN so
# the rows form a rectangular array. The width stays at 30 unless a level has
# more groups than that, in which case it grows instead of producing a ragged
# array.
n_cols = max([30, *level_sizes])
s = np.full((len(level_sizes), n_cols), np.nan)

start = 0
for row, size in enumerate(level_sizes):
    s[row, :size] = c[start : start + size]
    start += size

# Heatmap styling shared by all panels. The colour limits span the full range
# of corrected log10 p-values in `c`, so every panel uses the same scale.
kwags = dict(
    cmap="RdBu",
    square=True,
    cbar=False,
    vmax=c.max(),
    vmin=c.min(),
    center=0,
    xticklabels=[],
)
sns.set_context("talk")

# Vertex labels behind each row of `s`; their unique (sorted) values are the
# y tick labels of the corresponding panel.
panel_labels = [
    ["All"],
    vertex_hemispheres,
    vertex_structures,
    vertex_hemisphere_structures,
    vertex_hemisphere_substructures,
]

# Cells at or below log10(0.05) are marked with a cross.
log_alpha = np.log10(0.05)
pad = 0.2  # inset of the cross from the cell border, in cell units
lw = 20 / 30  # line width of the cross

fig, ax = plt.subplots(figsize=(6, 8), ncols=5, dpi=200, constrained_layout=True)

for idx, (column, labs) in enumerate(zip(s, panel_labels)):
    # Each row of `s` is drawn as a single-column heatmap.
    im = sns.heatmap(
        column[:, np.newaxis], **kwags, ax=ax[idx], yticklabels=np.unique(labs)
    )
    ax[idx].tick_params(
        axis="y",
        labelrotation=0,
        pad=0.5,
        length=1,
        left=False,
    )

    colors = im.get_children()[0].get_facecolors()

    for jdx, value in enumerate(column):
        # Skip non-significant cells; NaN padding also fails this comparison.
        if not value <= log_alpha:
            continue

        # REF: seaborn heatmap -- pick a marker colour that contrasts with the
        # cell's face colour.
        lum = relative_luminance(colors[jdx])
        text_color = ".15" if lum > 0.408 else "w"

        # The panel has one column, so the cell spans x in [0, 1] and
        # y in [jdx, jdx + 1]; draw its two diagonals.
        left, right = pad, 1 - pad
        ys = [jdx + pad, jdx + 1 - pad]
        ax[idx].plot([left, right], ys, color=text_color, linewidth=lw)
        ax[idx].plot([right, left], ys, color=text_color, linewidth=lw)
_images/fe697c8339bed3d71cd04d872f6ecfb2be61dae75668b033c37926aac24cd2bf.png