Comparing edge weight thresholds

Hide code cell source
import datetime
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Circle, FancyArrowPatch
from pkg.data import load_network_palette, load_unmatched
from pkg.io import FIG_PATH, get_environment_variables
from pkg.io import glue as default_glue
from pkg.io import savefig
from pkg.plot import (
    SmartSVG,
    merge_axes,
    rainbowarrow,
    set_theme,
    soft_axis_off,
    svg_to_pdf,
)
from pkg.stats import erdos_renyi_test, stochastic_block_test
from pkg.utils import remove_group, sample_toy_networks
from scipy.interpolate import interp1d
from svgutils.compose import Figure, Panel, Text
from tqdm import tqdm

# Only the third environment flag (DISPLAY_FIGS, read by `gluefig`) is used
# here; the first two are presumably RESAVE_DATA / RERUN_SIMS (see printout).
_, _, DISPLAY_FIGS = get_environment_variables()

# Folder name passed to pkg.io's savefig/glue for everything this notebook saves.
FILENAME = "thresholding_tests"

# Rebinds the imported FIG_PATH to this notebook's subfolder; the import lives
# in the same cell, so re-running the cell does not nest the path twice.
FIG_PATH = FIG_PATH / FILENAME


def glue(name, var, **kwargs):
    """Glue `var` under `name`, namespaced by this notebook's FILENAME.

    Thin wrapper around ``pkg.io.glue``; extra keyword arguments are forwarded.
    """
    default_glue(name, var, FILENAME, **kwargs)


def gluefig(name, fig, **kwargs):
    """Save a figure under this notebook's folder, glue it, and close it unless displaying.

    Parameters
    ----------
    name : str
        Name used for both the saved file and the glued variable.
    fig : matplotlib.figure.Figure
        Figure to glue and (when DISPLAY_FIGS is false) close.
        NOTE(review): `savefig` is not given `fig`, so it presumably saves the
        current pyplot figure -- callers should make sure `fig` is current.
        TODO confirm against pkg.io.savefig.
    **kwargs
        Forwarded to ``pkg.io.savefig``.
    """
    savefig(name, foldername=FILENAME, **kwargs)

    glue(name, fig, figure=True)

    if not DISPLAY_FIGS:
        # Close the figure we were handed rather than whatever pyplot currently
        # considers active, so an unrelated figure is never closed by mistake.
        plt.close(fig)


# Wall-clock start time; the final cell reports elapsed time relative to this.
t0 = time.time()
set_theme()
Environment variables:
   RESAVE_DATA: true
   RERUN_SIMS: true
   DISPLAY_FIGS: False
Hide code cell source
# Hemisphere colors (indexed by "Left"/"Right" further below).
network_palette, NETWORK_KEY = load_network_palette()

# Weighted adjacency matrices and node tables for each hemisphere.
left_adj, left_nodes = load_unmatched("left", weights=True)
right_adj, right_nodes = load_unmatched("right", weights=True)

# Single color used for every neuron in the schematic figures.
neutral_color = sns.color_palette("Set2")[2]

# Node-table column whose values are the group labels for the group connection tests.
GROUP_KEY = "celltype_discrete"

left_labels = left_nodes[GROUP_KEY].values
right_labels = right_nodes[GROUP_KEY].values
Hide code cell source
# Schematic of the two notions of edge weight: top row synapse count,
# bottom row input proportion.
fig, axs = plt.subplots(2, 1, figsize=(4, 5), gridspec_kw=dict(hspace=0))

# NOTE(review): set_theme runs after the subplots are created, so it presumably
# only affects artists created from here on -- confirm this ordering is intended.
set_theme(font_scale=1)
# Shared geometry (data coordinates) used by the draw_* helpers below.
source_loc = (0.25, 0.5)
target_loc = (0.75, 0.5)
radius = 0.05
dim_color = "black"
# NOTE(review): dark_color is not referenced anywhere in the visible notebook.
dark_color = "black"


def draw_synapse_end(rad_factor, color="black"):
    """Draw a small synapse dot on a ring around the target neuron.

    The dot sits at angle ``rad_factor * pi`` on a circle of radius
    ``1.6 * radius`` centred on `target_loc`, and is added to the global `ax`.
    """
    angle = rad_factor * np.pi
    ring = radius * 1.6
    center = (
        target_loc[0] + ring * np.cos(angle),
        target_loc[1] + ring * np.sin(angle),
    )
    ax.add_patch(Circle(center, radius=0.0125, color=color))


def draw_synapse(source_loc, connection_rad=0, end_rad=0, color="black"):
    """Draw an arc from `source_loc` toward the target neuron, ending in a synapse dot.

    `connection_rad` bends the arc (matplotlib ``arc3`` style); `end_rad`
    places the terminal dot (see ``draw_synapse_end``). Draws on the global `ax`.
    """
    arrow_kwargs = dict(
        connectionstyle=f"arc3,rad={connection_rad}",
        shrinkB=30,
        color=color,
    )
    ax.add_patch(FancyArrowPatch(posA=source_loc, posB=target_loc, **arrow_kwargs))
    draw_synapse_end(end_rad, color=color)


def draw_neurons():
    """Draw the source neuron i and target neuron j as labelled circles on the global `ax`."""
    for center, label in ((source_loc, r"$i$"), (target_loc, r"$j$")):
        neuron = Circle(
            center,
            radius=radius,
            facecolor=neutral_color,
            edgecolor="black",
            linewidth=2,
            zorder=10,
        )
        ax.add_patch(neuron)
        # One z-level above the circle so the label is drawn on top of it.
        ax.text(*center, label, zorder=11, va="center", ha="center")


def set_lims(ax):
    """Fix the schematic window so both neurons and their synapses fit."""
    x_window = (0.19, 0.81)
    y_window = (0.3, 0.7)
    ax.set_xlim(*x_window)
    ax.set_ylim(*y_window)


# Top row: synapse count. i makes two synapses onto j, so the weight is 2.
# The draw_* helpers draw on the global `ax`, so rebinding `ax` selects the panel.
ax = axs[0]
ax.text(0.93, 0.5, 2, fontsize="large", va="center", ha="center")
soft_axis_off(ax)
ax.set_ylabel("Synapse\ncount", rotation=0, ha="right", va="center", labelpad=20)

draw_neurons()
draw_synapse(source_loc, connection_rad=-0.5, end_rad=0.75)
draw_synapse(source_loc, connection_rad=0.5, end_rad=-0.75)

set_lims(ax)

ax.annotate(
    r"Synapse from $i$ to $j$",
    (0.5, 0.63),
    xytext=(40, 25),
    textcoords="offset points",
    ha="center",
    arrowprops=dict(arrowstyle="-|>", facecolor="black", relpos=(0.25, 0)),
    fontsize="small",
)


# Bottom row: input proportion. The same two synapses from i, plus three
# synapses onto j from other sources (above, left, below), give 2 / 5.
ax = axs[1]
ax.text(0.93, 0.5, "2 / 5", fontsize="large", va="center", ha="center")
soft_axis_off(ax)
ax.set_ylabel("Input\nproportion", rotation=0, ha="right", va="center", labelpad=20)

draw_neurons()


draw_synapse(source_loc, connection_rad=-0.5, end_rad=0.75)
draw_synapse(source_loc, connection_rad=0.5, end_rad=-0.75)

# Distance of the three "other" presynaptic stubs from the target neuron.
dist = 0.15
draw_synapse(
    (target_loc[0], target_loc[1] + dist),
    connection_rad=0,
    end_rad=0.5,
    color=dim_color,
)
draw_synapse(
    (target_loc[0] - dist, target_loc[1]),
    connection_rad=0,
    end_rad=1,
    color=dim_color,
)
draw_synapse(
    (target_loc[0], target_loc[1] - dist),
    connection_rad=0,
    end_rad=-0.5,
    color=dim_color,
)

set_lims(ax)

ax.annotate(
    r"Synapse from not $i$ to $j$",
    (0.75, 0.4),
    xytext=(-10, -50),
    textcoords="offset points",
    ha="right",
    arrowprops=dict(arrowstyle="-|>", facecolor="black", relpos=(0.75, 1)),
    fontsize="small",
)


fig.set_facecolor("w")

# Column headers (figure-fraction coordinates).
fig.text(0.07, 0.89, "Weight\n type", fontsize="large", ha="right")
fig.text(0.97, 0.89, "Weight\n" + r"$i \rightarrow$ j", fontsize="large")


# Light grey rules in figure coordinates: one horizontal line between the rows,
# and two vertical lines bounding the schematic column.
border_color = "lightgrey"
line1 = mpl.lines.Line2D(
    (-0.25, 1.2),
    (0.5, 0.5),
    transform=fig.transFigure,
    color=border_color,
    linewidth=1.5,
)
line2 = mpl.lines.Line2D(
    (0.95, 0.95),
    (0.15, 0.85),
    transform=fig.transFigure,
    color=border_color,
    linewidth=1.5,
)
line3 = mpl.lines.Line2D(
    (0.1, 0.1),
    (0.15, 0.85),
    transform=fig.transFigure,
    color=border_color,
    linewidth=1.5,
)

# NOTE(review): this overwrites the figure's line container directly;
# `fig.add_artist(line)` is the public API for adding figure-level artists.
fig.lines = (line1, line2, line3)

gluefig("weight_notions", fig)
Hide code cell source
# Seeded generator consumed by `weight_adjacency` for the toy edge weights.
rng = np.random.default_rng(8888)


A1, A2, node_data = sample_toy_networks()

# Every toy node shares one label, hence a single color.
node_data["labels"] = np.ones(len(node_data), dtype=int)
palette = {1: sns.color_palette("Set2")[2]}

g1 = nx.from_numpy_array(A1)
g2 = nx.from_numpy_array(A2)

# Layouts come from the unthresholded graphs and are reused for every panel,
# so nodes stay in place as edges are removed.
pos1 = nx.kamada_kawai_layout(g1)
pos2 = nx.kamada_kawai_layout(g2)


def weight_adjacency(A, scale=6, random_state=None):
    """Return a float copy of `A` with each nonzero entry replaced by a Uniform(1, 10) draw.

    Parameters
    ----------
    A : array-like
        Adjacency matrix; only its nonzero pattern is used. It is not modified.
    scale : float, optional
        Unused. It only fed a Poisson weight draw that is no longer used, and is
        kept so existing calls remain valid.
    random_state : np.random.Generator, optional
        Generator to draw weights from. Defaults to the module-level `rng`, so
        existing calls keep consuming the same random stream.

    Returns
    -------
    np.ndarray
        Float array with the same nonzero pattern as `A`.
    """
    # `rng` is only looked up when no generator is passed in.
    generator = rng if random_state is None else random_state
    # Copy as float: writing uniform draws into an integer array would silently
    # truncate them to whole numbers.
    weighted = np.array(A, dtype=float)
    sources, targets = np.nonzero(weighted)
    # One scalar draw per edge, in row-major order, to keep the random stream
    # identical to sequential sampling.
    for source, target in zip(sources, targets):
        weighted[source, target] = generator.uniform(1, 10)
    return weighted


def layoutplot(
    g,
    pos,
    nodes,
    ax=None,
    figsize=(10, 10),
    weight_scale=1,
    node_alpha=1,
    node_size=300,
    palette=None,
    edge_alpha=0.4,
    edge_color="black",
):
    """Draw a weighted graph at fixed positions, with edge width and shade set by weight.

    Parameters
    ----------
    g : networkx.Graph
        Graph whose edges carry a "weight" attribute.
    pos : dict
        Node -> (x, y) positions, e.g. from ``nx.kamada_kawai_layout``.
    nodes : pd.DataFrame
        Indexed by node; its "labels" column is mapped through `palette` to
        give node colors.
    ax : matplotlib Axes, optional
        Axes to draw on. If None, a new figure of size `figsize` is created.
    figsize : tuple, optional
        Only used when `ax` is None.
    weight_scale : float, optional
        Multiplier turning edge weights into line widths (and color values).
    node_alpha, node_size : optional
        Passed to the networkx drawing functions.
    palette : dict, optional
        Maps label values to node colors.
    edge_alpha : float, optional
        Edge transparency.
    edge_color : optional
        NOTE(review): currently unused. Edges are colored by their scaled weight
        through the "binary" colormap; kept for interface compatibility.

    Returns
    -------
    matplotlib Axes
        The axes drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    edgelist = g.edges()
    # Build as float so a fractional weight_scale also works when the stored
    # weights are integers (an in-place `*=` on an int array raises).
    weights = (
        np.array([g[u][v]["weight"] for u, v in edgelist], dtype=float) * weight_scale
    )

    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=nodes.index,
        node_color=nodes["labels"].map(palette),
        edgecolors="black",
        alpha=node_alpha,
        node_size=node_size,
        ax=ax,
    )

    nx.draw_networkx_edges(
        g,
        pos,
        edgelist=edgelist,
        nodelist=nodes.index,
        width=weights,
        # vmin below zero keeps even the lightest edges visibly grey.
        edge_vmin=-3,
        edge_vmax=9,
        edge_color=weights,
        edge_cmap=mpl.colormaps["binary"],
        alpha=edge_alpha,
        ax=ax,
        node_size=node_size,
    )

    soft_axis_off(ax)

    return ax


set_theme(font_scale=1.75)

# Grid, one column per threshold:
# row 0 = "increasing threshold" arrow, row 1 = left toy network,
# row 2 = "=?" comparison symbols, row 3 = right toy network.
fig, axs = plt.subplots(
    4,
    3,
    figsize=(12, 10),
    constrained_layout=True,
    gridspec_kw=dict(height_ratios=[0.5, 1, 0.25, 1], hspace=0, wspace=0),
)
# NOTE(review): rebinds A1/A2, so re-running this cell alone re-weights the
# already-weighted matrices with fresh draws from the shared `rng`.
A1 = weight_adjacency(A1)
A2 = weight_adjacency(A2)
# NOTE(review): layoutplot colors edges by weight and ignores edge_color.
kwargs = dict(
    palette=palette, edge_alpha=1, edge_color=(0.65, 0.65, 0.65), weight_scale=0.75
)
thresholds = [1, 4, 7]
for i in range(3):
    # Applied in place and cumulatively. Because the thresholds increase, each
    # panel equals thresholding the weighted matrix once at thresholds[i].
    A1[A1 < thresholds[i]] = 0
    A2[A2 < thresholds[i]] = 0
    g1 = nx.from_numpy_array(A1)
    g2 = nx.from_numpy_array(A2)

    ax = axs[1, i]
    layoutplot(g1, pos1, node_data, ax=ax, **kwargs)
    ax = axs[3, i]
    layoutplot(g2, pos2, node_data, ax=ax, **kwargs)


# Presumably merges the top row into a single axes for the gradient arrow.
ax = merge_axes(fig, axs, rows=0)

rainbowarrow(ax, start=(0.1, 0.5), end=(0.9, 0.5), cmap="Greys", n=1000, lw=30)
ax.set(ylim=(0.4, 0.8), xlim=(0, 1))
ax.set_title("Increasing edge weight threshold", fontsize="large", y=0.5)
ax.axis("off")


def draw_comparison(ax):
    """Fill a spacer axes with a large "=?" symbol and hide its frame."""
    question_equals = r"$\overset{?}{=}$"
    ax.text(0.48, 0.35, question_equals, fontsize="xx-large", ha="center", va="center")
    unit_interval = (0, 1)
    ax.set(ylim=unit_interval, xlim=unit_interval)
    ax.axis("off")


# "=?" between the left and right network in each threshold column.
ax = axs[2, 0]
draw_comparison(ax)

ax = axs[2, 1]
draw_comparison(ax)

ax = axs[2, 2]
draw_comparison(ax)

# Callout to the right of the last comparison: every test is rerun per threshold.
ax.annotate(
    "Rerun all\n tests",
    (0.6, 0.6),
    xytext=(45, 0),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="-|>", facecolor="black"),
    fontsize="medium",
    va="center",
)

# Row labels, colored by hemisphere.
axs[1, 0].set_ylabel(
    "Left",
    color=network_palette["Left"],
    size="large",
    rotation=0,
    ha="right",
    labelpad=10,
)

axs[3, 0].set_ylabel(
    "Right",
    color=network_palette["Right"],
    size="large",
    rotation=0,
    ha="right",
    labelpad=10,
)

fig.set_facecolor("w")

gluefig("thresholding_methods", fig)
Hide code cell source
def construct_weight_data(left_adj, right_adj):
    """Stack the nonzero edge weights of both hemispheres into one long table.

    Returns a DataFrame with a "weights" column (left weights first, then
    right, each in ``np.nonzero`` order) and a "labels" column holding
    "Left" or "Right" for the corresponding row.
    """
    left_weights = left_adj[np.nonzero(left_adj)]
    right_weights = right_adj[np.nonzero(right_adj)]

    hemisphere_labels = np.concatenate(
        (["Left"] * len(left_weights), ["Right"] * len(right_weights))
    )
    weight_data = pd.DataFrame(
        {"weights": np.concatenate((left_weights, right_weights))}
    )
    weight_data["labels"] = hemisphere_labels
    return weight_data


weight_data = construct_weight_data(left_adj, right_adj)
# Display only: weights of 10 or more synapses are dropped from the histogram.
weight_data = weight_data[weight_data["weights"] < 10]


set_theme(font_scale=1.25)
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

# One bar per integer synapse count, per hemisphere, on a log count axis.
sns.histplot(
    data=weight_data,
    x="weights",
    hue="labels",
    palette=network_palette,
    ax=ax,
    discrete=True,
    cumulative=False,
)
sns.move_legend(ax, loc="upper right", title="Hemisphere")
ax.set(xlabel="Weight (synapse count)")
ax.set(xticks=np.arange(1, 10))
ax.set_yscale("log")

gluefig("synapse_weight_histogram", fig)
Hide code cell source
# Test names, used both as the "method" value in the results tables and as
# legend labels in the p-value plots.
d_key = "Density"
gc_key = "Group connection"
dagc_key = "Density-adjusted\ngroup connection"


def binarize(A, threshold=None):
    """Return a copy of `A` with every entry below `threshold` set to zero.

    Despite the name, surviving entries keep their original weights. The
    result is only binary if `A` already was. `threshold` is the smallest value
    that is kept. With ``threshold=None``, an unchanged copy is returned.
    """
    thresholded = A.copy()
    if threshold is None:
        return thresholded
    thresholded[thresholded < threshold] = 0
    return thresholded


rows = []
thresholds = np.arange(1, 10)
# Edge count of the unthresholded networks: the denominator of p_edges_removed.
n_edges_total = np.count_nonzero(left_adj) + np.count_nonzero(right_adj)
for threshold in tqdm(thresholds):
    left_adj_thresh = binarize(left_adj, threshold=threshold)
    right_adj_thresh = binarize(right_adj, threshold=threshold)

    n_edges_kept = np.count_nonzero(left_adj_thresh) + np.count_nonzero(
        right_adj_thresh
    )
    p_edges_removed = 1 - n_edges_kept / n_edges_total

    # Density (Erdos-Renyi) test.
    stat, pvalue, misc = erdos_renyi_test(left_adj_thresh, right_adj_thresh)
    rows.append(
        dict(
            threshold=threshold,
            stat=stat,
            pvalue=pvalue,
            method=d_key,
            p_edges_removed=p_edges_removed,
        )
    )

    # Group connection (SBM) test, without and then with density adjustment.
    for adjusted, method in ((False, gc_key), (True, dagc_key)):
        stat, pvalue, misc = stochastic_block_test(
            left_adj_thresh,
            right_adj_thresh,
            left_labels,
            right_labels,
            density_adjustment=adjusted,
        )
        rows.append(
            dict(
                threshold=threshold,
                adjusted=adjusted,
                stat=stat,
                pvalue=pvalue,
                method=method,
                p_edges_removed=p_edges_removed,
            )
        )

integer_results = pd.DataFrame(rows)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:10<00:00,  1.15s/it]
Hide code cell source
def add_alpha_line(ax):
    """Mark the 0.05 significance level with a dotted line and a label left of the axes."""
    alpha = 0.05
    ax.axhline(alpha, color="black", linestyle=":", zorder=-1)
    left_edge = ax.get_xlim()[0]
    ax.annotate(
        "0.05",
        (left_edge, alpha),
        xytext=(-45, -15),
        textcoords="offset points",
        arrowprops={"arrowstyle": "-", "color": "black"},
        clip_on=False,
        ha="right",
    )
Hide code cell source
def plot_thresholding_pvalues(
    results, weight, figsize=(8, 6), no_reject_x=None, reject_x=None
):
    """Plot test p-values against the fraction of edges removed by thresholding.

    Parameters
    ----------
    results : pd.DataFrame
        One row per (threshold, method), with columns "threshold", "pvalue",
        "method" and "p_edges_removed", as built by the thresholding loops.
    weight : str
        "input_proportion" or "synapse_count". Selects the label (and, for
        input proportions, the ticks) of the secondary threshold axis. Any
        other value leaves that axis unlabeled.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    no_reject_x : float, optional
        Right edge (fraction of edges removed) of the shaded "reject" region.
        Defaults to the right x-limit.
    reject_x : float, optional
        x position of the "Reject symmetry" label. Defaults to the centre of
        the x-limits.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    set_theme(font_scale=1.25)
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    colors = sns.color_palette("tab20")
    palette = dict(zip([gc_key, dagc_key, d_key], [colors[0], colors[1], colors[12]]))

    sns.scatterplot(
        data=results,
        x="p_edges_removed",
        y="pvalue",
        hue="method",
        palette=palette,
        ax=ax,
        legend=True,
    )
    sns.lineplot(
        data=results,
        x="p_edges_removed",
        y="pvalue",
        hue="method",
        palette=palette,
        ax=ax,
        legend=False,
    )

    ax.set(
        yscale="log",
        ylabel="p-value",
        xlabel="Edges removed",
        yticks=np.geomspace(1, 1e-20, 5),
    )

    # threshold and p_edges_removed are identical across methods for a given
    # threshold, so the density rows alone define the threshold <-> prop mapping.
    single_results = results[results["method"] == d_key]
    x = single_results["p_edges_removed"]
    y = single_results["threshold"]

    ax.set_xlim((x.min(), x.max()))
    ax.tick_params(axis="both", length=5)
    ax.set_xticks([0, 0.25, 0.5, 0.75])
    ax.set_xticklabels(["0%", "25%", "50%", "75%"])

    # Piecewise-linear interpolation between the thresholds that were tested,
    # in both directions, to drive the secondary (threshold) x-axis.
    # NOTE(review): fill_value=(0, 1) for out-of-range inputs looks arbitrary
    # for prop_to_thresh; it is not visible because xlim is clamped to the data.
    prop_to_thresh = interp1d(
        x=x, y=y, kind="slinear", bounds_error=False, fill_value=(0, 1)
    )
    thresh_to_prop = interp1d(
        x=y, y=x, kind="slinear", bounds_error=False, fill_value=(0, 1)
    )

    ax2 = ax.secondary_xaxis(-0.2, functions=(prop_to_thresh, thresh_to_prop))

    if weight == "input_proportion":
        ax2.set_xticks([0.005, 0.01, 0.015, 0.02])
        ax2.set_xticklabels(["0.5%", "1%", "1.5%", "2%"])
        ax2.set_xlabel("Weight threshold (input percentage)")
    elif weight == "synapse_count":
        ax2.set_xlabel("Weight threshold (synapse count)")
    ax2.tick_params(axis="both", length=5)

    add_alpha_line(ax)

    sns.move_legend(
        ax,
        "lower left",
        title="Test",
        frameon=True,
        fontsize="small",
        ncol=1,
        labelspacing=0.3,
    )

    # Shade p < 0.05 (the rejection region) up to no_reject_x. Freeze the limits
    # first so the fill does not change them.
    ax.autoscale(False)
    if no_reject_x is None:
        no_reject_x = ax.get_xlim()[1]
    ax.fill_between(
        (ax.get_xlim()[0], no_reject_x),
        y1=0.05,
        y2=ax.get_ylim()[0],
        color="darkred",
        alpha=0.05,
    )

    if reject_x is None:
        reject_x = np.mean(ax.get_xlim())

    # Geometric centre of the log-scaled y-range. np.prod replaces np.product,
    # which was removed in NumPy 2.0.
    label_y = np.sqrt(np.prod(ax.get_ylim()))
    ax.text(
        reject_x,
        label_y,
        "Reject\nsymmetry",
        ha="center",
        va="center",
        color="darkred",
    )

    return fig, ax
Hide code cell source
# Save one version with the legend (used in the composite) and one without.
fig, ax = plot_thresholding_pvalues(integer_results, "synapse_count")
gluefig("synapse_threshold_pvalues_legend", fig)
# NOTE(review): when DISPLAY_FIGS is False, the gluefig call above has already
# closed the figure. Confirm that pkg.io.savefig still saves this `fig` (rather
# than a fresh empty current figure) for the legend-free version.
ax.get_legend().remove()
gluefig("synapse_threshold_pvalues", fig)
Hide code cell source
### EDGE WEIGHTS AS INPUT PROPORTIONS
Hide code cell source
def _total_input(nodes):
    """Per-neuron total input (axon + dendrite), with zeros replaced by 1.

    Replacing zeros avoids division by zero: such neurons' adjacency columns
    are divided by 1, i.e. left unchanged.
    """
    total_input = (nodes["axon_input"] + nodes["dendrite_input"]).to_numpy()
    # np.where builds a new array rather than writing into the array pandas
    # returns, which is read-only under Copy-on-Write (default in pandas >= 3.0).
    return np.where(total_input == 0, 1, total_input)


# Each column (presumably the postsynaptic neuron) is divided by that neuron's
# total input, turning synapse counts into input proportions.
left_input = _total_input(left_nodes)
left_adj_input_norm = left_adj / left_input[None, :]

right_input = _total_input(right_nodes)
right_adj_input_norm = right_adj / right_input[None, :]
Hide code cell source
rows = []
thresholds = np.linspace(0, 0.03, 31)
# Normalization keeps the nonzero pattern, so the raw matrices give the
# unthresholded edge count used as the denominator of p_edges_removed.
n_edges_total = np.count_nonzero(left_adj) + np.count_nonzero(right_adj)
for threshold in tqdm(thresholds):
    left_adj_thresh = binarize(left_adj_input_norm, threshold=threshold)
    right_adj_thresh = binarize(right_adj_input_norm, threshold=threshold)

    n_edges_kept = np.count_nonzero(left_adj_thresh) + np.count_nonzero(
        right_adj_thresh
    )
    p_edges_removed = 1 - n_edges_kept / n_edges_total

    # Density (Erdos-Renyi) test.
    stat, pvalue, misc = erdos_renyi_test(left_adj_thresh, right_adj_thresh)
    rows.append(
        dict(
            threshold=threshold,
            stat=stat,
            pvalue=pvalue,
            method=d_key,
            p_edges_removed=p_edges_removed,
        )
    )

    # Group connection (SBM) test, without and then with density adjustment.
    for adjusted, method in ((False, gc_key), (True, dagc_key)):
        stat, pvalue, misc = stochastic_block_test(
            left_adj_thresh,
            right_adj_thresh,
            left_labels,
            right_labels,
            density_adjustment=adjusted,
        )
        rows.append(
            dict(
                threshold=threshold,
                adjusted=adjusted,
                stat=stat,
                pvalue=pvalue,
                method=method,
                p_edges_removed=p_edges_removed,
            )
        )
input_results = pd.DataFrame(rows)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:38<00:00,  1.23s/it]
Hide code cell source
# The density rows are one per threshold, in the order of `thresholds` above.
# threshold and p_edges_removed are shared by all methods at a given threshold.
density_input_results = input_results[input_results["method"] == d_key]
# Critical point highlighted in the figure (about 0.012 on the 31-point grid above).
# NOTE(review): index picked by hand; re-check it if the threshold grid changes.
critical_index = 12
x = density_input_results.iloc[critical_index]["p_edges_removed"]
x_threshold = density_input_results.iloc[critical_index]["threshold"]

fig, ax = plot_thresholding_pvalues(
    input_results, "input_proportion", no_reject_x=x, reject_x=np.mean((0, x))
)

# axvline's ymin/ymax are axes fractions (not data values): draw from the
# bottom of the axes up to 95% of its height.
ax.axvline(
    x,
    0,
    0.95,
    color="black",
    linestyle="--",
    zorder=0,
)
ax.text(
    x + 0.005,
    # Geometric centre of the log-scaled y-range (np.product was removed in NumPy 2.0).
    np.sqrt(np.prod(ax.get_ylim())),
    r"$\rightarrow$"
    + "\n\n\n  All tests fail to\n  reject symmetry\n\n\n"
    + r"$\rightarrow$",
    ha="left",
    va="center",
)

gluefig("input_threshold_pvalues_legend", fig)
# NOTE(review): as with the synapse-count figure, the figure may already be
# closed here when DISPLAY_FIGS is False -- confirm the second save is intact.
ax.get_legend().remove()
gluefig("input_threshold_pvalues", fig)
Hide code cell source
# Look at the histogram of input-proportion weights
Hide code cell source
weight_data = construct_weight_data(left_adj_input_norm, right_adj_input_norm)
# NOTE(review): `median` is not used in this cell.
median = np.median(weight_data["weights"])

set_theme(font_scale=1.25)
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

# Log-binned histogram (plus KDE) of input-proportion weights per hemisphere.
sns.histplot(
    data=weight_data,
    x="weights",
    hue="labels",
    palette=network_palette,
    ax=ax,
    discrete=False,
    cumulative=False,
    common_norm=True,
    log_scale=True,
    stat="count",
    bins=50,
    kde=True,
)
sns.move_legend(ax, loc="upper right", title="Hemisphere")
ax.set(xlabel="Weight (input percentage)")
ax.set_xticks([1e-3, 1e-2, 1e-1, 1])
ax.set_xticklabels(["0.1%", "1%", "10%", "100%"])
ax.set_yticks([1000, 2000])
ax.tick_params(pad=10)
# Mark the critical threshold chosen for the input p-value figure (composite panel F).
ax.axvline(x_threshold, 0, 0.99, color="black", linestyle="--", linewidth=4, alpha=1)
ax.annotate(
    "Critical point\n(Panel F)",
    (x_threshold, ax.get_ylim()[1] * 0.95),
    xytext=(-50, 10),
    textcoords="offset points",
    ha="right",
    va="top",
    color="black",
    arrowprops=dict(arrowstyle="-|>", facecolor="black", relpos=(1, 0.7)),
)
gluefig("input_proportion_histogram", fig)
Hide code cell source
# Assemble the saved SVGs into a 2-column x 3-row composite (panels A-F).
# The 0.85 / 0.9 / 0.88 / 0.92 factors shrink the offsets so neighbouring panels
# overlap slightly; they appear to be tuned by eye.
fontsize = 9

weight_notions = SmartSVG(FIG_PATH / "weight_notions.svg")
weight_notions.set_width(200)
weight_notions.move(10, 15)
weight_notions_panel = Panel(
    weight_notions,
    Text("A) Notions of edge weight", 5, 10, size=fontsize, weight="bold"),
)

methods = SmartSVG(FIG_PATH / "thresholding_methods.svg")
methods.set_width(200)
methods.move(10, 15)
methods_panel = Panel(
    methods, Text("B) Thresholding methods", 5, 10, size=fontsize, weight="bold")
)
methods_panel.move(weight_notions.width * 0.85, 0)

synapse_hist = SmartSVG(FIG_PATH / "synapse_weight_histogram.svg")
synapse_hist.set_width(200)
synapse_hist.move(10, 15)
synapse_hist_panel = Panel(
    synapse_hist,
    Text("C) Synapse count distribution", 5, 10, size=fontsize, weight="bold"),
)
synapse_hist_panel.move(0, methods.height * 0.9)

input_hist = SmartSVG(FIG_PATH / "input_proportion_histogram.svg")
input_hist.set_width(200)
input_hist.move(10, 15)
input_hist_panel = Panel(
    input_hist,
    Text("D) Input proportion distribution", 5, 10, size=fontsize, weight="bold"),
)
input_hist_panel.move(synapse_hist.width * 0.85, methods.height * 0.9)

# The p-value panels use the versions saved with legends.
synapse_pvalues = SmartSVG(FIG_PATH / "synapse_threshold_pvalues_legend.svg")
synapse_pvalues.set_width(200)
synapse_pvalues.move(10, 15)
synapse_pvalues_panel = Panel(
    synapse_pvalues,
    Text("E) Synapse thresholding p-values", 5, 10, size=fontsize, weight="bold"),
)
synapse_pvalues_panel.move(0, (methods.height + synapse_hist.height) * 0.9)

input_pvalues = SmartSVG(FIG_PATH / "input_threshold_pvalues_legend.svg")
input_pvalues.set_width(200)
input_pvalues.move(10, 15)
input_pvalues_panel = Panel(
    input_pvalues,
    Text("F) Input thresholding p-values", 5, 10, size=fontsize, weight="bold"),
)
input_pvalues_panel.move(
    synapse_pvalues.width * 0.85, (methods.height + synapse_hist.height) * 0.9
)

fig = Figure(
    methods.width * 2 * 0.88,
    (methods.height + synapse_hist.height + synapse_pvalues.height) * 0.92,
    weight_notions_panel,
    methods_panel,
    synapse_hist_panel,
    input_hist_panel,
    synapse_pvalues_panel,
    input_pvalues_panel,
)
fig.save(FIG_PATH / "thresholding_composite.svg")

svg_to_pdf(
    FIG_PATH / "thresholding_composite.svg", FIG_PATH / "thresholding_composite.pdf"
)

fig
_images/e0e345eab0eaed42ce180fba143d43c38f8ce52143a48c47be9ffdeb439d1b04.svg
Hide code cell source
# Report total wall-clock runtime, measured from `t0` set in the first cell.
elapsed = time.time() - t0
delta = datetime.timedelta(seconds=elapsed)
print("Script took {}".format(delta))
print("Completed at {}".format(datetime.datetime.now()))
Script took 0:01:05.383410
Completed at 2023-03-10 13:44:29.894814
Hide code cell source
# Same input-proportion histogram with the "KCs" group removed from both
# hemispheres (title "KC-"). This cell runs after the timing report above.
(
    left_adj_input_norm_sub,
    right_adj_input_norm_sub,
    left_nodes_sub,
    right_nodes_sub,
) = remove_group(
    left_adj_input_norm, right_adj_input_norm, left_nodes, right_nodes, "KCs"
)


weight_data = construct_weight_data(left_adj_input_norm_sub, right_adj_input_norm_sub)
# NOTE(review): `median` is not used in this cell.
median = np.median(weight_data["weights"])
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
sns.histplot(
    data=weight_data,
    x="weights",
    hue="labels",
    palette=network_palette,
    ax=ax,
    discrete=False,
    cumulative=False,
    common_norm=True,
    log_scale=True,
    stat="count",
    bins=50,
    kde=True,
)
sns.move_legend(ax, loc="upper right", title="Hemisphere")
ax.set(xlabel="Weight (input proportion)")
# Same critical threshold as in the full-network histogram.
ax.axvline(x_threshold, color="black", linestyle="--", linewidth=4, alpha=1)

ax.set_title("KC-")
gluefig("input_proportion_histogram_kc_minus", fig)