from functools import partial
from typing import Optional, Sequence
import numpy as np
import scanpy as sc
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import ListedColormap
from popari._dataset_utils import (
_evaluate_classification_task,
_multigroup_heatmap,
_multireplicate_heatmap,
_plot_all_embeddings,
_plot_cell_type_to_metagene,
_plot_cell_type_to_metagene_difference,
_plot_clusters_to_categories,
_plot_confusion_matrix,
_plot_in_situ,
_plot_metagene_embedding,
_plot_metagene_signature_enrichment,
_plot_umap,
for_model,
setup_squarish_axes,
)
from popari.model import Popari
# Public, model-level plotting entry points. Every private dataset-level helper
# is passed through ``for_model`` with ``return_outputs=True``; the keyword is
# bound once here so each wrapper line only names the helper being wrapped.
# NOTE(review): ``for_model`` lives in ``popari._dataset_utils``; presumably it
# adapts a dataset-level helper to accept a trained model -- confirm there.
_model_level = partial(for_model, return_outputs=True)

in_situ = _model_level(_plot_in_situ)
metagene_embedding = _model_level(_plot_metagene_embedding)
confusion_matrix = _model_level(_plot_confusion_matrix)
umap = _model_level(_plot_umap)
multireplicate_heatmap = _model_level(_multireplicate_heatmap)
clusters_to_categories = _model_level(_plot_clusters_to_categories)
metagene_signature_enrichment = _model_level(_plot_metagene_signature_enrichment)
[docs]
def multigroup_heatmap(
    trained_model: Popari,
    title_font_size: Optional[int] = None,
    group_type: str = "metagene",
    axes: Optional[Sequence[Axes]] = None,
    key: Optional[str] = None,
    level=0,
    **heatmap_kwargs,
):
    r"""Plot 2D heatmap data for the groups of a trained model.

    Thin wrapper that looks up the datasets at the requested hierarchy level and
    the relevant group definition on the model, then forwards everything to
    ``_multigroup_heatmap``.

    Args:
        trained_model: the trained Popari model.
        title_font_size: forwarded unchanged to ``_multigroup_heatmap``.
        group_type: ``"metagene"`` selects ``trained_model.metagene_groups``; any
            other value selects ``trained_model.spatial_affinity_groups``.
        axes: A predefined set of matplotlib axes to plot on.
        key: forwarded unchanged to ``_multigroup_heatmap``.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.
        **heatmap_kwargs: extra keyword arguments forwarded to ``_multigroup_heatmap``.
    """
    datasets = trained_model.hierarchy[level].datasets

    if group_type == "metagene":
        groups = trained_model.metagene_groups
    else:
        # Every non-"metagene" value falls back to the spatial affinity groups.
        groups = trained_model.spatial_affinity_groups

    _multigroup_heatmap(
        datasets,
        title_font_size=title_font_size,
        groups=groups,
        axes=axes,
        key=key,
        **heatmap_kwargs,
    )
def spatial_affinities(
    trained_model: Popari,
    title_font_size: Optional[int] = None,
    spatial_affinity_key: Optional[str] = "Sigma_x_inv",
    axes: Optional[Sequence[Axes]] = None,
    level=0,
    **heatmap_kwargs,
):
    r"""Plot Sigma_x_inv across all datasets.

    Collects ``dataset.uns[spatial_affinity_key][dataset.name]`` for every dataset
    at the requested hierarchy level and draws them with ``_multireplicate_heatmap``
    on a shared color scale that is symmetric around zero.

    Args:
        trained_model: the trained Popari model.
        title_font_size: forwarded unchanged to ``_multireplicate_heatmap``.
        spatial_affinity_key: the key in ``.uns`` under which the per-dataset
            spatial affinities are stored.
        axes: A predefined set of matplotlib axes to plot on.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.
        **heatmap_kwargs: extra keyword arguments forwarded to
            ``_multireplicate_heatmap``. ``cmap`` (default ``"bwr"``), ``nested``
            (default ``True``), ``vmin`` and ``vmax`` (default: minus/plus the
            rounded largest absolute affinity) may be supplied here to override
            the defaults.
    """
    datasets = trained_model.hierarchy[level].datasets

    # Defaults that the caller may override through ``heatmap_kwargs``.
    cmap = heatmap_kwargs.pop("cmap", "bwr")
    nested = heatmap_kwargs.pop("nested", True)

    affinities = np.array([dataset.uns[spatial_affinity_key][dataset.name] for dataset in datasets])

    # Symmetric limits keep zero at the center of the diverging colormap.
    # NOTE(review): ``round`` goes to the nearest integer, so when the largest
    # absolute affinity is below 0.5 the default range collapses to [0, 0];
    # pass ``vmin``/``vmax`` explicitly in that case.
    max_value = round(np.max(np.abs(affinities)))

    # Popping these avoids a "multiple values for keyword argument" TypeError
    # when the caller provides their own color limits.
    vmin = heatmap_kwargs.pop("vmin", -max_value)
    vmax = heatmap_kwargs.pop("vmax", max_value)

    _multireplicate_heatmap(
        datasets,
        title_font_size=title_font_size,
        axes=axes,
        uns=spatial_affinity_key,
        nested=nested,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        **heatmap_kwargs,
    )
[docs]
def all_embeddings(
    trained_model: Popari,
    embedding_key: str = "X",
    column_names: Optional[Sequence[str]] = None,
    level=0,
    **spatial_kwargs,
):
    r"""Plot all learned metagenes in-situ across all replicates.

    Each replicate's metagenes are contained in a separate plot.

    Args:
        trained_model: the trained Popari model.
        embedding_key: the key in the ``.obsm`` dataframe for the cell/spot embeddings.
        column_names: a list of the suffixes for each latent feature. If ``None``, it is assumed
            that these suffixes are just the indices of the latent features, i.e. the names
            default to ``f"{embedding_key}_{index}"`` for every column of the first dataset's
            embedding.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.
        **spatial_kwargs: extra keyword arguments forwarded to ``_plot_all_embeddings``.
    """
    datasets = trained_model.hierarchy[level].datasets

    # Identity check: ``== None`` is evaluated element-wise for array-like
    # ``column_names`` and would make the truth test ambiguous.
    if column_names is None:
        # The number of latent features is read from the first dataset only.
        _, K = datasets[0].obsm[embedding_key].shape
        column_names = [f"{embedding_key}_{index}" for index in range(K)]

    _plot_all_embeddings(datasets, embedding_key=embedding_key, column_names=column_names, **spatial_kwargs)
def cell_type_to_metagene(trained_model: Popari, cell_type_de_genes: dict, level=0, **correspondence_kwargs):
    r"""Plot distribution of gene ranks of marker genes within each metagene.

    Only the first dataset at the requested hierarchy level is passed to
    ``_plot_cell_type_to_metagene``.

    Args:
        trained_model: the trained Popari model.
        cell_type_de_genes: dictionary mapping each cell type to a list of marker genes.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.
        **correspondence_kwargs: extra keyword arguments forwarded to
            ``_plot_cell_type_to_metagene``.

    Returns:
        A ``(medians, fig)`` tuple: the mapping from each cell type to the median rank of
        its marker genes in each metagene, followed by the figure produced by the helper.
    """
    reference_dataset = trained_model.hierarchy[level].datasets[0]
    figure, median_ranks = _plot_cell_type_to_metagene(
        reference_dataset,
        cell_type_de_genes,
        **correspondence_kwargs,
    )
    # The helper yields (figure, medians); this wrapper hands them back swapped.
    return median_ranks, figure
def cell_type_to_metagene_difference(
    trained_model: Popari,
    cell_type_de_genes: dict,
    first_metagene: int,
    second_metagene: int,
    level=0,
    **correspondence_kwargs,
):
    r"""Plot marker-gene ranks for a pair of metagenes.

    Only the first dataset at the requested hierarchy level is passed to
    ``_plot_cell_type_to_metagene_difference``, which defines the exact plot content.

    Args:
        trained_model: the trained Popari model.
        cell_type_de_genes: dictionary mapping each cell type to a list of marker genes.
        first_metagene: index of the first metagene, forwarded to the helper.
        second_metagene: index of the second metagene, forwarded to the helper.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.
        **correspondence_kwargs: extra keyword arguments forwarded to
            ``_plot_cell_type_to_metagene_difference``.

    Returns:
        A ``(medians, fig)`` tuple, in the same order as ``cell_type_to_metagene``: the
        second and first outputs of the helper, respectively.
    """
    datasets = trained_model.hierarchy[level].datasets
    first_dataset = datasets[0]
    fig, medians = _plot_cell_type_to_metagene_difference(
        first_dataset,
        cell_type_de_genes,
        first_metagene,
        second_metagene,
        **correspondence_kwargs,
    )
    # Previously the computed outputs were dropped; return them so that the
    # documented result is actually available to the caller.
    return medians, fig
def affinity_magnitude_vs_difference(
    trained_model,
    level=0,
    spatial_affinity_key: str = "Sigma_x_inv",
    spatial_affinity_bar_key: str = "spatial_affinity_bar",
    joint=False,
    figsize=(10, 10),
    n_best: int = 5,
):
    """Plot all pairwise affinities, in terms of absolute and relative
    magnitude.

    One scatter plot is drawn per dataset. Every metagene pair ``(i, j)`` with
    ``j <= i`` becomes a point whose x-coordinate is the affinity minus the
    corresponding entry of ``uns[spatial_affinity_bar_key]`` and whose
    y-coordinate is the affinity itself. Pairs are ranked by the Euclidean norm
    of that 2D point and the highest-ranked ones are highlighted and labelled.

    Side effect: ``dataset.uns["delta_Sigma"]`` is overwritten on every dataset
    with the computed difference, keyed by the dataset name.

    Args:
        trained_model: the trained Popari model.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.
        spatial_affinity_key: ``.uns`` key holding the per-dataset affinities.
        spatial_affinity_bar_key: ``.uns`` key holding the reference affinities
            that are subtracted.
        joint: not used by this function.
        figsize: figure size forwarded to ``setup_squarish_axes``.
        n_best: number of top-ranked pairs to highlight per dataset.

    Returns:
        ``(fig, all_top_pairs)`` where ``all_top_pairs`` has one array of selected
        ``(i, j)`` index pairs per dataset.
    """
    datasets = trained_model.hierarchy[level].datasets

    if level > 0:
        group_suffix = f"level_{level}"
    else:
        group_suffix = ""
    reference_name = f"_default_{group_suffix}"

    fig, axes = setup_squarish_axes(len(datasets), figsize=figsize)

    # Hide the grid cells that have no dataset to show.
    for spare_ax in axes.flat[len(datasets):]:
        spare_ax.axis("off")

    all_top_pairs = []
    for ax, dataset in zip(axes.flat, datasets):
        affinity = dataset.uns[spatial_affinity_key][dataset.name]
        deviation = affinity - dataset.uns[spatial_affinity_bar_key][reference_name]
        dataset.uns["delta_Sigma"] = {dataset.name: deviation}

        # Lower triangle (diagonal included): (i, j) -> (deviation, affinity).
        pairs = {
            (i, j): (deviation[(i, j)], affinity[(i, j)])
            for i in range(trained_model.K)
            for j in range(i + 1)
        }
        pair_keys, pair_points = zip(*pairs.items())

        # Rank by distance from the origin, largest first.
        ranking = np.argsort(np.linalg.norm(pair_points, axis=1))
        top_positions = ranking[-n_best:][::-1]

        deviations, affinities = np.array(pair_points).T
        ax.scatter(deviations, affinities, s=1, color="#D3D3D3")

        palette = sc.pl.palettes.godsnot_102[: abs(n_best)]
        best_indices = np.array(pair_keys)[top_positions]
        for (i, j), color in zip(best_indices, palette):
            point_x, point_y = pairs[(i, j)]
            ax.scatter(point_x, point_y, s=20, color=color, label=f"m{i} × m{j}")

        ax.set_title("Pairwise affinity scatter")
        ax.set_xlabel("Difference from average affinity")
        ax.set_ylabel("Pairwise affinity")
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

        all_top_pairs.append(best_indices)

    return fig, all_top_pairs
def normalized_affinity_trends(
    trained_model,
    timepoint_values: Sequence[float],
    time_unit="Days",
    normalize: bool = False,
    spatial_affinity_key: str = "Sigma_x_inv",
    n_best: int = 5,
    highlight_metric: str = "pearson",
    figsize: tuple = None,
    margin_size: float = 0.25,
    level=0,
):
    """Plot trends for every pair of affinities; highlight top trends.

    Each metagene pair ``(i, j)`` with ``j <= i`` is drawn as a line of its affinity
    across datasets (one dataset per timepoint). Pairs listed in
    ``uns["spatial_trends"]["top_pairs"]`` of the first dataset are drawn as thick
    colored lines with a legend entry; all remaining pairs are thin dashed gray lines.

    Args:
        trained_model: the trained Popari model.
        timepoint_values: x-values against which to plot trends. The input is
            copied and never modified.
        time_unit: unit in which time is measured (used for x-axis label)
        normalize: if ``True``, divide the affinities by their per-pair standard
            deviation across datasets and the timepoints by their standard deviation.
        spatial_affinity_key: ``.uns`` key holding the per-dataset affinities.
        n_best: number of distinct colors requested from the colormap.
        highlight_metric: ``"pearson"`` labels highlighted lines with the stored
            correlation (plus the slope when not normalizing); ``"variance"``
            labels them with the stored variance.
        figsize: figure size; defaults to ``(10, 5)``.
        margin_size: padding added around the data range on both axes.
        level: index into ``trained_model.hierarchy`` from which datasets are taken.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: if ``highlight_metric`` is not ``"pearson"`` or ``"variance"``.
    """
    # Fail early instead of hitting an undefined ``label`` mid-plot.
    if highlight_metric not in ("pearson", "variance"):
        raise ValueError(
            f"Unsupported highlight_metric {highlight_metric!r}; expected 'pearson' or 'variance'.",
        )

    datasets = trained_model.hierarchy[level].datasets
    spatial_trends = datasets[0].uns["spatial_trends"]

    all_affinities = np.array([dataset.uns[spatial_affinity_key][dataset.name] for dataset in datasets])
    # Float copy: lists/integer inputs become divisible and the caller's data
    # is left untouched by the normalization below.
    timepoint_values = np.array(timepoint_values, dtype=float)

    if normalize:
        all_affinities = all_affinities / np.std(all_affinities, axis=0, keepdims=True)
        timepoint_values = timepoint_values / np.std(timepoint_values)

    timepoint_min = np.min(timepoint_values)
    timepoint_ptp = np.ptp(timepoint_values)
    affinity_min = np.min(all_affinities)
    affinity_ptp = np.ptp(all_affinities)

    if figsize is None:
        figsize = (10, 5)

    fig, ax = plt.subplots(dpi=300, figsize=figsize)
    ax.set_ylim([affinity_min - margin_size, affinity_min + affinity_ptp + margin_size])
    ax.set_xlim([timepoint_min - margin_size, timepoint_min + timepoint_ptp + margin_size])
    ax.set_xticks(timepoint_values)

    number_of_lines = abs(n_best)
    # ``matplotlib.cm.get_cmap`` is gone in recent Matplotlib; the pyplot
    # variant accepts the same ``(name, lut)`` arguments.
    colors = plt.get_cmap("rainbow", number_of_lines)

    # Plain-int tuples give an exact membership test whether the stored pairs
    # are lists, tuples or rows of an array.
    highlighted_pairs = {(int(i), int(j)) for i, j in spatial_trends["top_pairs"]}

    for i in range(trained_model.K):
        for j in range(i + 1):
            if (i, j) in highlighted_pairs:
                continue
            ax.plot(
                timepoint_values,
                all_affinities[:, i, j],
                color="#D3D3D3",
                linestyle="--",
                linewidth=0.5,
            )

    for index, (i, j) in enumerate(spatial_trends["top_pairs"]):
        affinity_values = all_affinities[:, i, j]
        if highlight_metric == "pearson":
            r = spatial_trends["pearson_correlations"][(i, j)]
            slope = spatial_trends["slopes"][(i, j)]
            # The stored slope is only shown for un-normalized data.
            slope_display = f", slope={slope:.2f}" if not normalize else ""
            label = f"m{i} × m{j}, r={r:.2}{slope_display}"
        else:
            variance = spatial_trends["variances"][(i, j)]
            label = f"m{i} × m{j}, σ={variance:.2f}"

        # zorder=2 keeps highlighted trends above the gray background lines.
        ax.plot(
            timepoint_values,
            affinity_values,
            color=colors(index),
            linestyle="-",
            linewidth=3,
            label=label,
            zorder=2,
        )

    ax.set_title("Pairwise affinity trends")
    ax.set_xlabel(f"{time_unit}")
    ax.set_ylabel("Pairwise affinity")
    ax.legend(loc="upper left", bbox_to_anchor=(1, 0))

    return fig