import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from graspy.inference import LatentDistributionTest
from graspy.simulations import p_from_latent, sample_edges
from graspy.utils import symmetrize
from hyppo.discrim import DiscrimOneSample
from sklearn.metrics import pairwise_distances

np.random.seed(8888)

sns.set_context("talk")

mpl.rcParams["axes.edgecolor"] = "lightgrey"
mpl.rcParams["axes.spines.right"] = False
mpl.rcParams["axes.spines.top"] = False


def hardy_weinberg(theta):
    """
    Map values from [0, 1] onto the Hardy-Weinberg curve.

    Parameters
    ----------
    theta : scalar or array-like
        Position(s) along the curve. Lists and tuples are accepted and are
        converted to a NumPy array before any arithmetic is done.

    Returns
    -------
    numpy.ndarray
        The points (theta^2, 2 * theta * (1 - theta), (1 - theta)^2).
        A scalar input gives shape (3,); a 1-d input of length n gives
        shape (n, 3), one point per row.
    """
    # Coerce first so that plain Python sequences support ** and broadcasting.
    theta = np.asarray(theta)
    hw = [theta ** 2, 2 * theta * (1 - theta), (1 - theta) ** 2]
    return np.array(hw).T


def sample_hw_graph(thetas):
    """
    Sample one undirected, loopless graph from Hardy-Weinberg latent positions.

    The values in ``thetas`` are mapped onto the Hardy-Weinberg curve, the
    resulting positions are turned into an edge probability matrix with
    ``p_from_latent`` (no rescaling, no self-loops), and a graph is drawn
    from that matrix with ``sample_edges``.

    Returns
    -------
    tuple
        ``(graph, p_mat, latent)``: the sampled graph, the probability
        matrix it was drawn from, and the latent positions.
    """
    positions = hardy_weinberg(thetas)
    edge_probs = p_from_latent(positions, rescale=False, loops=False)
    adjacency = sample_edges(edge_probs, directed=False, loops=False)
    return adjacency, edge_probs, positions

Parameters of the experiment

# Number of time points; each time point gets its own value of delta.
n_timepoints = 5
# Number of theta values (and hence latent positions) drawn per graph.
n_verts = 100
# Number of independent graphs sampled at every time point.
n_graphs_per_timepoint = 10
# Evenly spaced delta values on [0, 2], one per time point.
deltas = np.linspace(start=0, stop=2, num=n_timepoints)

Distributions in latent space

Let $HW(\theta)$ be the Hardy-Weinberg distribution in $\mathbb{R}^3$.

Latent positions are distributed along this curve: $$X \sim HW(\theta)$$ where the position $\theta$ along the curve follows a Beta distribution: $$\theta \sim \mathrm{Beta}(1, 1 + \delta)$$ Let $\delta$ be a proxy for "time".

Below I plot the distributions of $\theta$ for each value of $\delta$, where we will use a different value of $\delta$ for each time point.

# Overlay the distribution of theta for every delta on a single axis.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))
n_display_draws = 10000  # fake # to make the distributions look cleaner
for delta in deltas:
    thetas = np.random.beta(1, 1 + delta, size=n_display_draws)
    sns.distplot(thetas, label=delta, ax=ax)
# Place the legend just outside the top-right corner of the axis.
ax.legend(title=r"$\delta$", bbox_to_anchor=(1, 1), loc="upper left")
_ = ax.set(ylabel="Frequency", yticks=[], xlabel=r"$\theta$")

Sample latent positions, and then sample graphs

To generate each graph I sample a set of latent positions from the Hardy-Weinberg curve described above. Each time point will have multiple sets of latent positions sampled i.i.d. from the same distribution in latent space, then a single graph is sampled from each set of latent positions.

# For every time point (one delta each), draw several independent sets of
# thetas and sample one graph per set. Graphs are stored grouped by time
# point, in the same order as `deltas`.
graphs = []
latents = []
for delta in deltas:
    for _ in range(n_graphs_per_timepoint):
        thetas = np.random.beta(1, 1 + delta, size=n_verts)
        adjacency, _p_mat, positions = sample_hw_graph(thetas)
        graphs.append(adjacency)
        latents.append(positions)
# Time-point index of each graph: 0, ..., 0, 1, ..., 1, and so on.
times = np.repeat(np.arange(len(deltas)), n_graphs_per_timepoint)

Plot 2 example sets of sampled latent positions for each time point

Here I just show the first two dimensions of true latent positions. From each of these we sample a graph.

# Grid of scatter plots: one column per time point, one row per example.
n_examples = 2
fig, axs = plt.subplots(
    n_examples,
    n_timepoints,
    figsize=(n_timepoints * 4, 8),
    sharex=True,
    sharey=False,  # TODO fix sharey and labeling
)
scatter_kws = dict(linewidth=0, alpha=0.5, s=20)
for t, delta in enumerate(deltas):
    column_title = f"t = {t} (" + r"$\delta$" + f" = {delta})"
    # Index of the first graph belonging to time point t.
    offset = t * n_graphs_per_timepoint
    for i in range(n_examples):
        ax = axs[i, t]
        plot_latent = pd.DataFrame(latents[offset + i])
        # Columns 0 and 1 are the first two latent dimensions.
        sns.scatterplot(data=plot_latent, x=0, y=1, ax=ax, **scatter_kws)
        ax.set(xlabel="", ylabel="", xticks=[], yticks=[])
        if i == 0:
            ax.set_title(column_title)
        if t == 0:
            ax.set_ylabel(f"Sample {i + 1}")

plt.tight_layout()

Plot adjacency matrices for 2 graphs from each time point

# Grid of adjacency heatmaps: one column per time point, one row per example.
n_examples = 2
fig, axs = plt.subplots(n_examples, n_timepoints, figsize=(n_timepoints * 4, 8))
heatmap_kws = dict(
    cbar=False,
    xticklabels=False,
    yticklabels=False,
    cmap="RdBu_r",
    square=True,
    center=0,
)
for t, delta in enumerate(deltas):
    column_title = f"t = {t} (" + r"$\delta$" + f" = {delta})"
    # Index of the first graph belonging to time point t.
    offset = t * n_graphs_per_timepoint
    for i in range(n_examples):
        ax = axs[i, t]
        sns.heatmap(graphs[offset + i], ax=ax, **heatmap_kws)
        if i == 0:
            ax.set_title(column_title)
        if t == 0:
            ax.set_ylabel(f"Sample {i + 1}")
plt.tight_layout()

Compute the test statistics and p-values of the nonparametric Latent Distribution Test for every pair of graphs.

# Run the latent distribution test once for every unordered pair of graphs,
# storing the test statistic and p-value in symmetric (n_graphs x n_graphs)
# matrices. The diagonals are left at zero.
curr_time = time.perf_counter()  # monotonic clock, intended for elapsed time

n_graphs = len(graphs)
pval_mat = np.zeros((n_graphs, n_graphs))
tstat_mat = np.zeros((n_graphs, n_graphs))
# Visit only the strict upper triangle (i < j); the lower triangle is filled
# in by symmetrize() below, so no pair is tested twice.
for i in range(n_graphs):
    for j in range(i + 1, n_graphs):
        ldt = LatentDistributionTest(n_bootstraps=200, workers=1)
        ldt.fit(graphs[i], graphs[j])
        pval_mat[i, j] = ldt.p_value_
        tstat_mat[i, j] = ldt.sample_T_statistic_
pval_mat = symmetrize(
    pval_mat, method="triu"
)  # need to do way more bootstraps to be meaningful
tstat_mat = symmetrize(tstat_mat, method="triu")

print(f"{(time.perf_counter() - curr_time)/60:.3f} minutes elapsed")
5.981 minutes elapsed

All pairwise test statistics and p-values

Here I show test statistics for the latent distribution test between all possible pairs of graphs. Higher means more different. The test statistic being used here is the 2-sample dcorr test statistic on the estimated latent positions. Note that I'm not doing the new seedless alignment here (but I'd like to).

Then, I show the same for the p-values.

def _plot_pairwise_heatmap(matrix, title):
    """
    Draw a graph-by-graph matrix as a heatmap with time-point block borders.

    Grey lines are drawn every ``n_graphs_per_timepoint`` rows/columns so the
    blocks belonging to each pair of time points are visible, and the x-axis
    is labelled with one tick per time point, centred on its block.

    Parameters
    ----------
    matrix : 2-d array
        Square matrix with one row/column per graph, ordered by time point.
    title : str
        Title of the plot.

    Returns
    -------
    tuple
        ``(fig, ax)``: the matplotlib figure and axis that were created.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    sns.heatmap(
        matrix,
        ax=ax,
        xticklabels=False,
        yticklabels=False,
        cmap="Reds",
        square=True,
        cbar_kws=dict(shrink=0.7),
    )
    # Separate the blocks of graphs that share a time point.
    line_kws = dict(linestyle="-", linewidth=1, color="grey")
    for t in range(1, n_timepoints):
        boundary = t * n_graphs_per_timepoint
        ax.axvline(boundary, **line_kws)
        ax.axhline(boundary, **line_kws)
    # One tick per time point, placed at the middle of its block.
    tick_locs = (
        np.arange(0, n_timepoints * n_graphs_per_timepoint, n_graphs_per_timepoint)
        + n_graphs_per_timepoint / 2
    )
    ax.set(
        xticks=tick_locs,
        xticklabels=np.arange(n_timepoints),
        xlabel="Time point",
        title=title,
    )
    return fig, ax


fig, ax = _plot_pairwise_heatmap(tstat_mat, "Latent distribution test statistics")

fig, ax = _plot_pairwise_heatmap(pval_mat, "Latent distribution test p-values")

Computing discriminability

Discriminability looks at whether distances between samples from the same object (time point, in this case) are smaller than distances between samples from different objects. In a sense, it's looking at whether the values in the diagonal blocks of the matrices above are smaller than those in the rest of the matrix. Here I'm using the test statistic from above as the distance. A permutation test is used to test whether one's ability to discriminate between "multiple samples" from the same object is higher than one would expect by chance.

curr_time = time.time()

# One-sample discriminability test with the pairwise test statistics as the
# input matrix and the time-point index of each graph as its label.
# NOTE(review): is_dist=True presumably marks the input as a precomputed
# distance matrix -- confirm against the hyppo documentation.
discrim = DiscrimOneSample(is_dist=True)
discrim.test(tstat_mat, times)
discrim_pvalue = discrim.pvalue_
discrim_stat = discrim.stat
print(f"Discriminability one-sample p-value: {discrim_pvalue}")
print(f"Discriminability test statistic: {discrim_stat}")

elapsed_minutes = (time.time() - curr_time) / 60
print(f"{elapsed_minutes:.3f} minutes elapsed")
Discriminability one-sample p-value: 0.001
Discriminability test statistic: 0.8648333333333332
0.061 minutes elapsed

Test statistics and p-values as a function of time difference

Here I just play with plotting these test statistics and p-values as a function of how different in time the two graphs were. I add jitter to the time difference values just for visibility.

# Manhattan distance between the time-point indices of every pair of graphs.
time_dist_mat = pairwise_distances(times[:, np.newaxis], metric="manhattan")

# Strict upper triangle, so each unordered pair of graphs appears once.
triu_inds = np.triu_indices_from(time_dist_mat, k=1)
n_pairs = triu_inds[0].size

# Small uniform jitter so points with the same time difference do not overlap.
jitter = np.random.uniform(-0.2, 0.2, size=n_pairs)
time_dists = time_dist_mat[triu_inds] + jitter
latent_dists = tstat_mat[triu_inds]

pval_dists = pval_mat[triu_inds]

# One scatter plot per quantity, each against the jittered time difference.
scatter_kws = dict(s=10, linewidth=0, alpha=0.3)
for y_values, y_label in (
    (latent_dists, "Test statistic"),
    (pval_dists, "p-value"),
):
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    sns.scatterplot(x=time_dists, y=y_values, ax=ax, **scatter_kws)
    _ = ax.set(ylabel=y_label, xlabel="Difference in time")