Visualizing a priori SBMs

cd '/Users/TingshanLiu/Desktop/2020 summer/TL_maggot/maggot_connectome/'
/Users/TingshanLiu/Desktop/2020 summer/TL_maggot/maggot_connectome
from pkg.data import load_maggot_graph
from joblib import Parallel, delayed
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
import networkx as nx
from graspologic.utils import cartesian_product, import_graph, binarize, remove_loops
from graspologic.models import SBMEstimator
from graspologic.plot import heatmap
from graspologic.simulations import er_nm
from graspologic.models.sbm import _get_block_indices, _calculate_block_p, _block_to_full
from sklearn.utils import check_X_y
from scipy.stats import norm, binom_test
context = sns.plotting_context(context="poster", font_scale=1.2)
sns.set_context(context)
def model_sbm(adj, cls_name):
    # fit an SBM using a priori labels from one or two columns of the global
    # `meta`; cls_name is a list of column names
    if len(cls_name) == 1:
        sbm = SBMEstimator()
        sbm.fit(binarize(adj), y=meta[cls_name[0]].values)
    elif len(cls_name) == 2:  # crossing 'hemisphere' with another class
        # assuming the first column is 'hemisphere'
        multi_labels = meta[cls_name[0]].astype(str) + meta[cls_name[1]].astype(str)
        sbm = SBMEstimator()
        sbm.fit(binarize(adj), y=multi_labels)
    return sbm

def generate_rand_block_p(adj, y, rep=100, n_jobs=5):
    # null block densities: average over `rep` directed ER graphs with the same
    # number of nodes and edges as the observed graph
    def _run():
        g = er_nm(len(adj), np.count_nonzero(adj), directed=True)
        sbm_ = SBMEstimator()
        sbm_.fit(binarize(g), y=y)
        return sbm_.block_p_
    rand_block_p = Parallel(n_jobs=n_jobs, verbose=1)(delayed(_run)()
        for _ in range(rep))
    return np.mean(rand_block_p, axis=0)

def compute_block_sig(adj, cls_name, p_mat, n_rand=100):
    # one-sided binomial test for each block pair: is the observed edge count
    # larger than expected under an ER null with matched density?
    # (note: p_mat is not used here)
    labeled = np.where(meta[cls_name].ne(''))[0]
    adj_new = adj[np.ix_(labeled, labeled)]
    block_vert_inds, block_inds, block_inv = _get_block_indices(meta[cls_name].iloc[labeled])
    block_count = _calculate_block_p(binarize(adj_new), block_inds, block_vert_inds, True)
    rand_block_p = generate_rand_block_p(adj_new, meta[cls_name].iloc[labeled], n_rand)
    block_pairs = cartesian_product(block_inds, block_inds)
    p_val = []
    for p in block_pairs:
        from_block = p[0]
        to_block = p[1]
        from_inds = block_vert_inds[from_block]
        to_inds = block_vert_inds[to_block]
        # index into the labeled subgraph, not the full adjacency matrix
        block = binarize(adj_new)[from_inds, :][:, to_inds]
        p_val.append(round(binom_test(block_count[from_block, to_block],
            block.size, rand_block_p[from_block, to_block], alternative='greater'), 4))
    return p_val, block_pairs
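
#%%[markdown]
# For each block pair, `compute_block_sig` runs a one-sided binomial test: the
# observed edge count in the block is compared with the block's expected
# density under the ER surrogates from `generate_rand_block_p`. A minimal
# sketch of a single such test, using made-up numbers (illustrative only, not
# taken from the data):
from scipy.stats import binom_test

n_edges_obs = 40       # hypothetical number of edges observed in one block
n_possible = 50 * 60   # hypothetical block size: |from block| x |to block|
p_null = 0.01          # hypothetical null density from the ER surrogates
print(binom_test(n_edges_obs, n_possible, p_null, alternative='greater'))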

def plot_heatmap(B, new_labels, p_val, block_pairs):
    # heatmap of the block matrix B (log color scale) with significance stars,
    # plus a histogram of the block-wise p-values
    fig = plt.figure(constrained_layout=True, figsize=(38, 20))
    gs = fig.add_gridspec(12, 20)
    ax = fig.add_subplot(gs[2:, 1:12])
    heatmap(B, xticklabels=new_labels, yticklabels=new_labels, ax=ax, norm=LogNorm())
    for i in range(len(block_pairs)):
        # stars mark increasingly small p-values
        if p_val[i] <= 0.0005:
            stars = '***'
        elif p_val[i] <= 0.005:
            stars = '**'
        elif p_val[i] <= 0.05:
            stars = '*'
        else:
            continue
        ax.text(block_pairs[i, 1], block_pairs[i, 0] + 0.5, stars, color='w', fontsize=20)
    ax.set_xticklabels(new_labels, rotation=45, ha='right')
    ax.set_yticklabels(new_labels, rotation=45)
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=30)
    # dashed lines separating groups of blocks (positions given by the global `split`)
    ax.hlines(split, *ax.get_xlim(), colors='gray', linestyles='dashed', linewidth=1.5)
    ax.vlines(split, *ax.get_ylim(), colors='gray', linestyles='dashed', linewidth=1.5)
    ax = fig.add_subplot(gs[6:, 15:])
    ax.hist(p_val, bins=[0, 0.0001, 0.001, 0.01, 0.1, 1])
    ax.set_yscale('symlog')
    ax.set_xscale('symlog', linthresh=0.0001)
    ax.set_xticks([0.0001, 0.001, 0.01, 0.1, 1])
    ax.set_xticklabels(ax.get_xticks(), fontsize=30, rotation=45)
    ax.set_yticklabels(ax.get_yticks(), fontsize=30)
    ax.set_xlabel('p-value', fontsize=30)
    ax.set_ylabel('count', fontsize=30)
    # plt.savefig('cls4_wei', transparent=False, facecolor='white', bbox_inches="tight")
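
#%%[markdown]
# In the heatmaps produced by `plot_heatmap`, stars mark block pairs whose
# estimated connection probability (or mean weight) is significantly above the
# null: * for p <= 0.05, ** for p <= 0.005, *** for p <= 0.0005. The dashed
# lines (at the positions in the global `split`) mark boundaries between groups
# of blocks.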

def plot_connectivity(cls, block_pairs, B, p_val):
    # directed graph of the significant block-to-block connections:
    # node size ~ number of neurons in the block, edge width ~ B_ij
    new_keys = np.array(list(cls.keys()))
    counts = []
    for idx, k in enumerate(new_keys):
        counts.append(np.sum(meta['class_4'] == idx))
    nodes = range(len(new_keys))
    labels = {}
    for n in nodes:
        labels[n] = new_keys[n].split('_')[-1]
    # keep only block pairs with p-value below 0.005
    sig = np.where(np.array(p_val) < 0.005)[0]
    edges = block_pairs[sig]
    edges = [tuple(edges[i, :]) for i in range(edges.shape[0])]
    # rescale the corresponding B_ij's for use as edge widths
    weights = B.ravel()[sig]
    weights = [((i - np.min(weights)) / (np.max(weights) - np.min(weights))) * 1.5
        for i in weights]
    pos = {0:[0,11], 1:[2.25,11], 2:[4.5,11], 3:[6.75,11], 4:[9,11], 5:[11.25,11], 6:[13.5,11],
        7:[0,7], 8:[2,7], 9:[4,7], 10:[6,7], 11:[8,7],
        12:[0,1], 13:[2,1], 14:[3.5,3], 15:[3.5,-1.5], 16:[5,1],
        18:[8,1], 19:[10,1], 20:[12,1], 21:[15,-13],
        17:[4.5,-8], 
        22:[1,-19], 23:[4.5,-19], 
        24:[1,-23], 25:[4.5,-23], 26:[8,-23]}

    fig, ax = plt.subplots(1, figsize=(8,10))
    ax.set_xlim([-5,18])
    ax.set_ylim([-24,16.5])
    g = nx.MultiDiGraph()
    g.add_nodes_from(nodes)
    g.add_edges_from(edges)
    nx.draw_networkx_nodes(g, node_size=counts, pos=pos, ax=ax)
    label_pos = pos.copy()
    offset_size = 0.8
    # offset each label away from its node, scaled by the node size
    for i in pos.keys():
        label_pos[i] = [j + counts[i] / 1000 + offset_size for j in pos[i]]
    label_pos[13][1] = 1.5  # manual tweak for label 13
    i = -1
    for e in g.edges:
        i += 1
        ax.annotate("",
                    xy=pos[e[1]], xycoords='data',
                    xytext=pos[e[0]], textcoords='data',
                    arrowprops=dict(arrowstyle="->", color="0.5",
                                    shrinkA=5+counts[e[0]]/50,
                                    shrinkB=5+counts[e[1]]/50,
                                    patchA=None, patchB=None,
                                    connectionstyle="arc3,rad=0.1",
                                    mutation_scale=10, lw=weights[i],
                                    ),
                    )
    ax.axis('off')
    for i in range(len(edges)):
        if edges[i][0] == edges[i][1]:
            gap = counts[edges[i][0]] / 600 + 0.3
            center = tuple([pos[edges[i][0]][0], pos[edges[i][0]][1]+gap])
            l = mpl.patches.Arc(center,0.3,0.5,0,-60,240,
                lw=weights[i], color='0.5')
            ax.add_patch(l)
    nx.draw_networkx_labels(g, pos=label_pos, labels=labels, font_color='r', ax=ax)
    ypos_ = [11,7,0.5,-8,-19,-23]
    texts = ['Input', '2nd-Order', '3rd-Order', '4th-Order', '(N-1)th Order', 'Nth Order']
    for i in range(len(texts)):
        ax.text(-2.5, ypos_[i], texts[i], ha='center', va='center', fontsize=10)

    plt.show()
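
#%%[markdown]
# `plot_connectivity` lays the blocks out by hand (`pos`), roughly from sensory
# input at the top to output at the bottom, sizes each node by the number of
# neurons in its block, and draws an edge for every block pair with p < 0.005,
# with self-connections drawn as small arcs above their node.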

def _calculate_block_p_wei(graph, block_inds, block_vert_inds, return_counts=False):
    # like graspologic's _calculate_block_p, but returns the mean weight over
    # existing (nonzero) edges in each block instead of a connection probability
    n_blocks = len(block_inds)
    block_pairs = cartesian_product(block_inds, block_inds)
    block_p = np.zeros((n_blocks, n_blocks))
    for pair in block_pairs:
        from_block = pair[0]
        to_block = pair[1]
        from_inds = block_vert_inds[from_block]
        to_inds = block_vert_inds[to_block]
        block = graph[from_inds, :][:, to_inds]
        if np.count_nonzero(block) == 0:
            p = 0
        else:
            p = np.sum(block) / np.count_nonzero(block)
        block_p[from_block, to_block] = p
    return block_p

def weighted_sbm(adj, y):
    graph = import_graph(adj)
    check_X_y(graph, y)
    graph = remove_loops(graph)
    block_vert_inds, block_inds, block_inv = _get_block_indices(y)
    block_p = _calculate_block_p_wei(graph, block_inds, block_vert_inds)
    p_mat = _block_to_full(block_p, block_inv, graph.shape)
    p_mat = remove_loops(p_mat)
    return block_p, p_mat
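
#%%[markdown]
# Note that `weighted_sbm` does not return probabilities: each entry of
# `block_p` is the mean edge weight over the existing (nonzero) edges in that
# block, so the weighted heatmap later shows average edge weights rather than
# connection probabilities.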

def generate_rand_block_p_wei(adj, y, rep=100, n_jobs=5):
    # null mean block weights: the observed edge weights shuffled onto `rep`
    # random directed ER topologies with the same number of edges
    def _run():
        g = er_nm(len(adj), np.count_nonzero(adj), directed=True)
        edges = adj[np.nonzero(adj)]
        np.random.shuffle(edges)
        g[np.nonzero(g)] = edges
        r_block_p, _ = weighted_sbm(g, y)
        return r_block_p
    rand_block_p = Parallel(n_jobs=n_jobs, verbose=1)(delayed(_run)()
        for _ in range(rep))
    return rand_block_p

def compute_block_sig_wei(block_p, adj, cls_name, n_rand=100):
    # z-test of each observed mean block weight against an empirical null built
    # from edge weights shuffled onto random ER topologies
    _, block_inds, _ = _get_block_indices(meta[cls_name])
    rand_block_p = generate_rand_block_p_wei(adj, meta[cls_name], n_rand)
    block_pairs = cartesian_product(block_inds, block_inds)
    p_val = []
    for p in block_pairs:
        null = [i[p[0], p[1]] for i in rand_block_p]
        if np.mean(null) == 0 and block_p[p[0], p[1]] == 0:
            p_val.append(1)
        else:
            z_score = (block_p[p[0], p[1]] - np.mean(null)) / np.std(null)
            p_val.append(round(1 - norm.cdf(z_score), 4))
    return p_val, block_pairs
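
#%%[markdown]
# For the weighted case the null is built empirically: the observed edge
# weights are shuffled onto random ER topologies, and each observed mean block
# weight is turned into a z-score against the surrogate distribution. A minimal
# sketch for a single block pair, again with made-up numbers (illustrative
# only):
import numpy as np
from scipy.stats import norm

null = np.array([0.8, 1.1, 0.9, 1.0, 1.2])  # hypothetical null mean weights
observed = 1.6                               # hypothetical observed mean weight
z = (observed - null.mean()) / null.std()
print(round(1 - norm.cdf(z), 4))             # one-sided p-value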


#%%[markdown]
# ### load the data
mg = load_maggot_graph()
mg = mg[mg.nodes["paper_clustered_neurons"]]
mg.fix_pairs()
meta = mg.nodes.copy()
adj = mg.sum.adj.copy()

#%%[markdown]
# ### grab information on labels from different columns of `meta`
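# Each value of `class_4` below pairs a list of labels (or label columns) with
# the column of `meta` used to match them; the empty entries ('inter_unkPN',
# 'inter_unkLN') are placeholders that are filled in after the loop.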
class_4 = {'sens_AN': [["['sens_subclass_AN']"], 'all_class2'],
    'sens_olf': [["['sens_subclass_ORN']"], 'all_class2'],
    'sens_MN': [["['sens_subclass_MN']"], 'all_class2'],
    'sens_thermo': [["['sens_subclass_thermo']"], 'all_class2'],
    'sens_photo': [["['sens_subclass_photoRh5']",
        "['sens_subclass_photoRh6']"], 'all_class2'],
    'sens_vtd': [["['sens_subclass_vtd']"], 'all_class2'],
    'sens_ascending': [['ascendings'], 'simple_group'],
    'inter_uPN': [['uPN'], 'merge_class'],
    'inter_mPN': [['mPN'], 'merge_class'],
    'inter_tPN': [['tPN'], 'merge_class'],
    'inter_vPN': [['vPN'], 'merge_class'],
    'inter_unkPN': [],
    'inter_LHN': [['LHNs'], 'simple_group'],
    'inter_KC': [['KCs'], 'simple_group'],
    'inter_MBIN': [['MBINs'], 'simple_group'],
    'inter_MBON': [['MBONs'], 'simple_group'],
    'inter_FBN': [['MB-FBNs'], 'simple_group'],
    'inter_CN': [['CNs'], 'simple_group'],
    'inter_ALN': [['bLN','pLN','cLN'], 'merge_class'],
    'inter_LON': [['LON'], 'merge_class'],
    'inter_unkLN': [],
    'inter_unk': [['unk'], 'simple_group'],
    'out_pre-dVNC': [['pre-dVNCs'], 'simple_group'],
    'out_pre-dSEZ': [['pre-dSEZs'], 'simple_group'],
    'out_dVNC': [['dVNCs'], 'simple_group'],
    'out_dSEZ': [['dSEZs'], 'simple_group'],
    'out_RGN': [['RGNs'], 'simple_group'],
}

all_classes = {'class_4': class_4}
for ii, kk in enumerate(all_classes.keys()):
    meta[kk] = ''
    for idx, (k, v) in enumerate(all_classes[kk].items()):
        if len(v) == 1:
            meta.loc[meta[v[0]].eq(True).any(1), kk] = idx
        elif len(v) > 1:
            for i in range(len(v[0])):
                meta.loc[meta[v[1]].eq(v[0][i]), kk] = idx

# Some nodes labeled "PN" or "LN" in meta['simple_group'] carry labels like
# "AN2" or "MN2" in meta['merge_class'] that do not specify which kind of "PN"
# or "LN" they are, even though meta['merge_class'] gives detailed labels for
# the majority of nodes. To avoid excluding any node, label these as "unkPN"
# or "unkLN".
idx = list(class_4.keys()).index("inter_unkPN")
meta.loc[(meta['class_4'].eq('') & meta['simple_group'].eq('PNs')), 'class_4'] = idx
idx = list(class_4.keys()).index("inter_unkLN")
meta.loc[(meta['class_4'].eq('') & meta['simple_group'].eq('LNs')), 'class_4'] = idx
print(len(meta['class_4'].values))
2874
# class_1: 3-block (input, interneuron, output)
meta['class_1'] = ''
meta.loc[meta['class_4'].isin(range(7)), 'class_1'] = 0
meta.loc[meta['class_4'].isin(range(7, 22)), 'class_1'] = 1
meta.loc[meta['class_4'].isin(range(22, 27)), 'class_1'] = 2 
print(len(meta['class_1'].values))
2874

#%%[markdown]
# ### fit an SBM for each category and test significance on the B_ij's
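# For each partition, an `SBMEstimator` is fit to the binarized graph (once
# with the labels as given, once crossed with hemisphere), and each estimated
# B_ij is tested against ER surrogate graphs with the same number of nodes and
# edges (one-sided binomial test, see `compute_block_sig`).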
all_classes = ['class_1', 'simple_group', 'class_4']
sbms = []
sbms_hemi = []
for k in all_classes:
    sbms.append(model_sbm(adj, [k]))
    sbms_hemi.append(model_sbm(adj, ['hemisphere', k]))
cls_name = 'class_4'
sbm_id = all_classes.index(cls_name)
p_val, block_pairs = compute_block_sig(adj, cls_name, sbms[sbm_id].p_mat_, 500)
[Parallel(n_jobs=5)]: Using backend LokyBackend with 5 concurrent workers.
[Parallel(n_jobs=5)]: Done  40 tasks      | elapsed:   38.8s
[Parallel(n_jobs=5)]: Done 190 tasks      | elapsed:  1.8min
[Parallel(n_jobs=5)]: Done 440 tasks      | elapsed:  3.6min
[Parallel(n_jobs=5)]: Done 500 out of 500 | elapsed:  4.0min finished
# for class_4
cls = class_4
labels = list(cls.keys())
new_labels = [i.split('_')[-1] for i in labels]
B = sbms[sbm_id].block_p_
split = [7, 24]

# # for 'simple_group'
# idx = [14,9,7,2,1,5,6,4,0,3,15,13,12,11,10,8]  # indices for reordering
# labels = list(np.unique(meta['simple_group'], return_inverse=True)[0])
# new_labels = [labels[i] for i in idx]
# B = sbms[sbm_id].block_p_[np.ix_(idx, idx)]
# p_val = np.array(p_val).reshape(B.shape)[np.ix_(idx, idx)].ravel()
# split = [2, 11]
plot_heatmap(B, new_labels, p_val, block_pairs)
[Figure: block connection probability heatmap for class_4 with significance stars; histogram of block p-values]
plot_connectivity(class_4, block_pairs, B, p_val)
[Figure: connectivity diagram of significant class_4 block connections (p < 0.005)]

#%%[markdown]
# #### try to incorporate the edge weights
block_p, p_mat = weighted_sbm(adj, y=meta['class_4'].values)
p_val, block_pairs = compute_block_sig_wei(block_p, adj, 'class_4', 500)
[Parallel(n_jobs=5)]: Using backend LokyBackend with 5 concurrent workers.
[Parallel(n_jobs=5)]: Done  40 tasks      | elapsed:   20.2s
[Parallel(n_jobs=5)]: Done 190 tasks      | elapsed:  1.5min
[Parallel(n_jobs=5)]: Done 440 tasks      | elapsed:  3.5min
[Parallel(n_jobs=5)]: Done 500 out of 500 | elapsed:  4.0min finished
plot_heatmap(block_p, new_labels, p_val, block_pairs)
[Figure: mean edge-weight heatmap for class_4 with significance stars; histogram of block p-values]
plot_connectivity(cls, block_pairs, block_p, p_val)
[Figure: connectivity diagram of class_4 blocks, edge widths scaled by mean edge weight]

#%%[markdown]
# ### look at the distribution of the edge weights grouped by source & destination
fig,ax = plt.subplots(1,figsize=(36,12))
labels = ['input', 'interneuron', 'output']
for i in range(3):
    for j in range(3):
        idx1 = np.where(meta['class_1'] == i)[0]
        idx2 = np.where(meta['class_1'] == j)[0]
        adj_sub = adj[np.ix_(idx1, idx2)]
        ax.hist(adj_sub.ravel(), log=True, density=True,
        label=labels[i]+"->"+labels[j], histtype=u'step')
ax.legend()
ax.set(xlabel='Edge Weight', ylabel='Density')
[Text(0.5, 0, 'Edge Weight'), Text(0, 0.5, 'Density')]
[Figure: overlaid log-scale histograms of edge weights for each source-to-destination pair of the 3-block partition]

#%%[markdown]
# ### evaluate fitted SBMs
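# Crossing a partition with hemisphere multiplies the number of blocks, and the
# number of B_ij parameters grows quadratically with the number of blocks, so
# BIC weighs that extra complexity against the gain in log likelihood. Note
# that the first panel below plots -BIC, so higher is better in all three
# panels.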
bic = np.zeros((3,2))
lik = np.zeros((3,2))
n_params = np.zeros((3,2))
all_cls = ['class_1', 'simple_group', 'class_4']
estimators = [sbms, sbms_hemi]
for i in range(len(all_cls)):
    labeled = np.where(meta[all_cls[i]].ne(''))[0]
    adj_new = adj[np.ix_(labeled, labeled)]
    for ii,e in enumerate(estimators):
        e[i].n_verts = len(adj_new)
        bic[i,ii] = e[i].bic(binarize(adj_new))
        lik[i,ii] = e[i].score(binarize(adj_new))
        n_params[i,ii] = e[i]._n_parameters()
fig, axs = plt.subplots(3, 1, figsize=(18, 20))
# plot -BIC so that higher is better in all three panels
data = [-bic, lik, n_params]
ylabels = ['-BIC', 'Log Likelihood', '# parameters']
for i in range(len(data)):
    ax = axs[i]
    ax.plot(data[i][:,0], label='single')
    ax.plot(data[i][:,1], label='cross with hemisphere')
    ax.set_ylabel(ylabels[i])
    if i < 2:
        ax.set_xticks(range(3))
        ax.set_xticklabels([''] * 3)
    else:
        ax.set_xticks(range(3))
        ax.set_xticklabels(['3-block', 'simple_group', 'complex_group'])
        ax.legend()
[Figure: -BIC, log likelihood, and number of parameters for the three partitions, with and without crossing by hemisphere]