Source code for pertpy.tools._dialogue

"""DIALOGUE: cross-cell-type multicellular program discovery.

Re-implements the algorithm of Jerby-Arnon & Regev (2022), `livnatje/DIALOGUE <https://github.com/livnatje/DIALOGUE>`__, on AnnData.

The pipeline has three phases:
    1. ``fit_programs`` — pseudobulk per sample, filter informative features by ANOVA, center + winsorize, run penalized multiple-CCA, residualize on confounders, find program gene signatures by partial Spearman correlation.
    2. ``test_celltype_pairs`` — for every ordered pair of cell types, fit a hierarchical linear model (``y ~ (1 | sample) + x + cell_quality + tme_qc``) of one cell type's program score against the partner cell type's pseudobulk expression of candidate genes, producing signed z-scores.
    3. ``refine_scores`` — aggregate per-gene HLM p-values across pairs via Fisher's method, fit a non-negative least-squares regression of CCA scores against retained genes, and write final per-cell program scores back to ``adata.obsm``.
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import singledispatch
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import statsmodels.formula.api as smf
from scipy import sparse as sp
from scipy import stats
from scipy.optimize import nnls
from sparsecca import multicca_permute, multicca_pmd
from statsmodels.stats.multitest import multipletests

from pertpy._doc import _doc_params, doc_common_plot_args

if TYPE_CHECKING:
    from collections.abc import Sequence

    from anndata import AnnData
    from matplotlib.figure import Figure

_LOG2_PI = float(np.log(2.0 * np.pi))


def _pseudobulk_per_sample(
    adata: AnnData,
    sample_key: str,
    *,
    layer: str | None = None,
    agg: str = "median",
) -> pd.DataFrame:
    """Sample-level pseudobulk of ``adata``.

    Uses :func:`scanpy.get.aggregate` so sparse ``X`` stays sparse until per-group densification.

    Returns:
        DataFrame indexed by sample, columns are ``adata.var_names``.
    """
    aggregated = sc.get.aggregate(adata, by=sample_key, func=agg, layer=layer)
    matrix = aggregated.layers[agg]
    if sp.issparse(matrix):
        matrix = matrix.toarray()
    return pd.DataFrame(matrix, index=list(aggregated.obs_names), columns=list(aggregated.var_names))


@singledispatch
def _column_anova(matrix, groups: np.ndarray) -> np.ndarray:
    """One-way ANOVA p-value for each column of ``matrix``, grouping rows by ``groups``.

    Dispatches on dense ``np.ndarray`` and ``scipy.sparse`` matrices.
    """
    raise NotImplementedError(f"Unsupported matrix type: {type(matrix)!r}")


@_column_anova.register(np.ndarray)
def _column_anova_dense(matrix: np.ndarray, groups: np.ndarray) -> np.ndarray:
    groups = np.asarray(groups)
    unique = np.unique(groups)
    n_features = matrix.shape[1]
    pvals = np.ones(n_features, dtype=np.float64)
    for j in range(n_features):
        samples = [matrix[groups == g, j] for g in unique]
        if any(s.size < 1 for s in samples) or len(samples) < 2:
            continue
        _, p = stats.f_oneway(*samples)
        pvals[j] = p if np.isfinite(p) else 1.0
    return pvals


@_column_anova.register(sp.spmatrix)
def _column_anova_sparse(matrix: sp.spmatrix, groups: np.ndarray) -> np.ndarray:
    return _column_anova_dense(matrix.toarray(), np.asarray(groups))


def _anova_filter_features(
    matrix: np.ndarray | sp.spmatrix,
    groups: np.ndarray,
    *,
    alpha: float = 0.05,
) -> np.ndarray:
    """Boolean mask of columns with one-way ANOVA p < ``alpha`` after BH adjustment."""
    raw = _column_anova(matrix, np.asarray(groups))
    if raw.size == 0:
        return np.zeros(0, dtype=bool)
    adjusted = multipletests(raw, method="fdr_bh")[1]
    return adjusted < alpha


def _center_scale_winsorize(matrix: np.ndarray, *, cap: float = 0.01) -> np.ndarray:
    """Column-wise center + unit-variance scaling, then clip to inner ``[cap, 1-cap]`` quantiles.

    Matches R's ``center.matrix`` + ``cap.mat``.
    """
    arr = np.asarray(matrix, dtype=np.float64)
    mean = arr.mean(axis=0, keepdims=True)
    std = arr.std(axis=0, ddof=1, keepdims=True)
    std = np.where(std > 0, std, 1.0)
    scaled = (arr - mean) / std
    if cap > 0:
        lo = np.quantile(scaled, cap, axis=0, keepdims=True)
        hi = np.quantile(scaled, 1.0 - cap, axis=0, keepdims=True)
        scaled = np.clip(scaled, lo, hi)
    return scaled


def _residualize(values: np.ndarray, covariates: np.ndarray) -> np.ndarray:
    """OLS residuals of ``values`` regressed on ``covariates``.

    ``values`` is ``[n_obs, n_targets]``, ``covariates`` is ``[n_obs, n_covar]``.
    An intercept column is added automatically.
    """
    values = np.asarray(values, dtype=np.float64)
    if values.ndim == 1:
        values = values[:, None]
    covariates = np.asarray(covariates, dtype=np.float64)
    if covariates.ndim == 1:
        covariates = covariates[:, None]
    n = values.shape[0]
    if covariates.shape[0] != n:
        raise ValueError("values and covariates must share the first dimension")
    design = np.column_stack([np.ones(n), covariates])
    beta, *_ = np.linalg.lstsq(design, values, rcond=None)
    return values - design @ beta


@singledispatch
def _partial_spearman(X, Y, Z, *, batch_size: int = 2048):
    """Partial Spearman correlation of every column of X against every column of Y, controlling for Z.

    Returns ``(R, P)`` arrays of shape ``[X_cols, Y_cols]``.
    Matches R's ``ppcor::pcor.mat`` with the Spearman method as used by ``DIALOGUE::pcor.mat``.
    Dispatches on dense ``np.ndarray`` and sparse ``scipy.sparse`` matrices.
    The sparse branch processes column blocks of ``batch_size`` so the full gene-by-cell matrix is never materialized.
    """
    raise NotImplementedError(f"Unsupported X type: {type(X)!r}")


def _prepare_partial_targets(Y: np.ndarray, Z: np.ndarray) -> tuple[np.ndarray, np.ndarray, int, int]:
    """Compute the rank-residualized + standardized target ``Ys`` and the design matrix used to residualize X column blocks."""
    Y = np.asarray(Y, dtype=np.float64)
    Z = np.asarray(Z, dtype=np.float64)
    if Z.ndim == 1:
        Z = Z[:, None]
    n = Y.shape[0]
    if Z.shape[0] != n:
        raise ValueError("Y and Z must have the same number of rows")
    Y_rank = pd.DataFrame(Y).rank().to_numpy()
    Z_rank = pd.DataFrame(Z).rank().to_numpy()
    design = np.column_stack([np.ones(n), Z_rank])
    Yr = Y_rank - design @ np.linalg.lstsq(design, Y_rank, rcond=None)[0]
    Y_std = np.where(Yr.std(0, ddof=1) > 0, Yr.std(0, ddof=1), 1.0)
    Ys = (Yr - Yr.mean(0)) / Y_std
    df = max(n - 2 - Z_rank.shape[1], 1)
    return Ys, design, n, df


def _partial_spearman_block(
    X_block: np.ndarray, Ys: np.ndarray, design: np.ndarray, n: int, df: int
) -> tuple[np.ndarray, np.ndarray]:
    X_rank = pd.DataFrame(X_block).rank().to_numpy()
    Xr = X_rank - design @ np.linalg.lstsq(design, X_rank, rcond=None)[0]
    Xs = (Xr - Xr.mean(0)) / np.where(Xr.std(0, ddof=1) > 0, Xr.std(0, ddof=1), 1.0)
    R = (Xs.T @ Ys) / (n - 1)
    t_stat = R * np.sqrt(df / np.clip(1 - R**2, 1e-30, None))
    P = 2.0 * stats.t.sf(np.abs(t_stat), df=df)
    return R, P


@_partial_spearman.register(np.ndarray)
def _partial_spearman_dense(X: np.ndarray, Y: np.ndarray, Z: np.ndarray, *, batch_size: int = 2048):
    Ys, design, n, df = _prepare_partial_targets(Y, Z)
    if X.shape[0] != n:
        raise ValueError("X and Y must have the same number of rows")
    return _partial_spearman_block(np.asarray(X, dtype=np.float64), Ys, design, n, df)


@_partial_spearman.register(sp.spmatrix)
def _partial_spearman_sparse(X: sp.spmatrix, Y: np.ndarray, Z: np.ndarray, *, batch_size: int = 2048):
    Ys, design, n, df = _prepare_partial_targets(Y, Z)
    if X.shape[0] != n:
        raise ValueError("X and Y must have the same number of rows")
    n_features = X.shape[1]
    R_all = np.zeros((n_features, Ys.shape[1]), dtype=np.float64)
    P_all = np.zeros((n_features, Ys.shape[1]), dtype=np.float64)
    X_csc = X.tocsc()
    for start in range(0, n_features, batch_size):
        end = min(start + batch_size, n_features)
        block = np.asarray(X_csc[:, start:end].toarray(), dtype=np.float64)
        R, P = _partial_spearman_block(block, Ys, design, n, df)
        R_all[start:end] = R
        P_all[start:end] = P
    return R_all, P_all


def _zscores_from_signed_pvalues(estimate: np.ndarray, pvalue: np.ndarray) -> np.ndarray:
    """Signed log10-style z-scores: positive when estimate>0 & p small, negative when estimate<0 & p small.

    Matches R's ``get.cor.zscores`` after converting one-sided p-values.
    """
    estimate = np.asarray(estimate, dtype=np.float64)
    pvalue = np.asarray(pvalue, dtype=np.float64)
    pos = np.where(pvalue == 0, 0.0, pvalue)
    smallest = pos[pos > 0]
    floor = smallest.min() / 2.0 if smallest.size else 1e-300
    pvalue = np.where(pvalue > 0, pvalue, floor)
    pos_half = np.where(estimate > 0, pvalue / 2.0, 1.0 - pvalue / 2.0)
    neg_half = np.where(-estimate > 0, pvalue / 2.0, 1.0 - pvalue / 2.0)
    z = np.where(pos_half > 0.5, np.log10(neg_half), -np.log10(pos_half))
    return z


def _fisher_combine_by_label(pvalues: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Fisher-combine p-values within each label group, returning one combined p-value per row.

    ``pvalues`` shape ``[n_rows, n_columns]``; ``labels`` shape ``[n_rows]`` (program label per row).
    Mirrors R's ``fisher.combine`` applied after ``p.adjust.mat.per.label`` over programs.
    """
    pvalues = np.asarray(pvalues, dtype=np.float64)
    labels = np.asarray(labels)
    adjusted = np.full_like(pvalues, np.nan)
    for label in np.unique(labels):
        mask = labels == label
        block = pvalues[mask]
        valid = ~np.isnan(block)
        for j in range(block.shape[1]):
            col = block[:, j]
            mask_j = valid[:, j]
            if mask_j.sum() < 1:
                continue
            adj = multipletests(col[mask_j], method="fdr_bh")[1]
            full = np.full(col.shape, np.nan)
            full[mask_j] = adj
            block[:, j] = full
        adjusted[mask] = block
    combined = np.empty(adjusted.shape[0], dtype=np.float64)
    for i, row in enumerate(adjusted):
        finite = row[np.isfinite(row) & (row > 0)]
        if finite.size == 0:
            combined[i] = 1.0
            continue
        stat = -2.0 * np.log(finite).sum()
        df = 2 * finite.size
        combined[i] = float(stats.chi2.sf(stat, df=df))
    return combined


def _iterative_nnls(
    X: np.ndarray,
    y: np.ndarray,
    feature_rank: np.ndarray,
    *,
    correlation_threshold: float = 0.95,
    minimum_features: int = 5,
) -> np.ndarray:
    """Iterative non-negative least squares matching R's ``DLG.iterative.nnls``.

    Features are bucketed by their normalized rank (``feature_rank in [0, 1]``). Within each bucket
    (largest first, down to a third), fit NNLS on that subset, accumulate the fit, then repeat with
    the residuals as the new target. Stop early when the cumulative fit correlates with the original
    target above ``correlation_threshold``.
    """
    X = np.asarray(X, dtype=np.float64)
    y0 = np.asarray(y, dtype=np.float64).ravel()
    y = y0.copy()
    coef = np.zeros(X.shape[1], dtype=np.float64)
    feature_rank = np.asarray(feature_rank, dtype=np.float64)
    y_fit = np.zeros_like(y0)
    buckets = sorted({float(r) for r in feature_rank if r >= 1.0 / 3.0}, reverse=True)

    for bucket in buckets:
        mask = feature_rank == bucket
        if mask.sum() < minimum_features:
            continue
        x_sel = X[:, mask]
        x_coef, _ = nnls(x_sel, y)
        coef[mask] = x_coef
        fit = x_sel @ x_coef
        y_fit = y_fit + fit
        y = y - fit
        if np.unique(y_fit).size > 10 and np.corrcoef(y0, y_fit)[0, 1] > correlation_threshold:
            return coef

    leftover = feature_rank < (buckets[-1] if buckets else 1.0)
    if leftover.sum() >= minimum_features:
        x_sel = X[:, leftover]
        x_coef, _ = nnls(x_sel, y)
        coef[leftover] = x_coef
    return coef


def _hlm_pvalue_per_row(
    expression: np.ndarray,
    score: np.ndarray,
    covariates: pd.DataFrame,
    sample_groups: np.ndarray,
) -> pd.DataFrame:
    """Hierarchical linear model per row: ``score ~ (1|sample) + x + covariates`` where ``x`` is each row of ``expression``.

    Returns a DataFrame with ``estimate`` and ``pvalue`` columns indexed like ``expression.index``.
    """
    if isinstance(expression, pd.DataFrame):
        gene_index = expression.index
        expression = expression.to_numpy()
    else:
        gene_index = pd.Index([f"gene_{i}" for i in range(expression.shape[0])])
    score = np.asarray(score, dtype=np.float64)
    n = score.shape[0]
    covariates = covariates.reset_index(drop=True).copy()
    if covariates.shape[0] != n:
        raise ValueError("covariates rows must match score length")
    groups = pd.Series(sample_groups, name="_sample_").reset_index(drop=True)
    base = pd.concat([covariates, groups], axis=1)
    base["_y_"] = score
    extra_terms = " + ".join(f"Q('{col}')" for col in covariates.columns)
    formula = "_y_ ~ _x_" + (f" + {extra_terms}" if extra_terms else "")
    estimates = np.full(expression.shape[0], np.nan)
    pvalues = np.full(expression.shape[0], np.nan)
    # statsmodels' mixedlm raises ConvergenceWarning whenever the optimizer doesn't
    # hit a tiny gradient tolerance; that happens routinely on degenerate genes and
    # the recorded pvalue is still usable, so silence the noise in bulk loops.
    import warnings as _warnings

    with _warnings.catch_warnings():
        _warnings.simplefilter("ignore", category=Warning)
        for i in range(expression.shape[0]):
            base["_x_"] = expression[i]
            try:
                fit = smf.mixedlm(formula, base, groups=base["_sample_"]).fit(method="bfgs", reml=False, disp=False)
                estimates[i] = float(fit.params.get("_x_", np.nan))
                pvalues[i] = float(fit.pvalues.get("_x_", np.nan))
            except Exception:  # noqa: BLE001 — model may fail on degenerate covariates; record NaN
                continue
    return pd.DataFrame({"estimate": estimates, "pvalue": pvalues}, index=gene_index)


@dataclass
class DialogueState:
    """Cached intermediates produced by ``fit_programs``."""

    cell_type_order: list[str]
    shared_samples: list[str]
    pseudobulk_features: dict[str, pd.DataFrame]
    weights: dict[str, np.ndarray]
    cca_scores: dict[str, np.ndarray]
    empirical_pvalues: pd.DataFrame
    cca_correlations: pd.DataFrame
    gene_signatures: dict[str, dict[str, dict[str, list[str]]]]


[docs] class Dialogue: """Multicellular program discovery (DIALOGUE). Args: celltype_key: Column of ``adata.obs`` with the cell-type assignment. sample_key: Column of ``adata.obs`` with the sample / niche identifier. cell_quality_key: Column of ``adata.obs`` with the per-cell QC value (typically log-counts) used as a confounder in residualization and in the per-pair HLM (R's ``cellQ``). n_programs: Number of multicellular programs to fit (``k`` in the paper). feature_space_key: ``adata.obsm`` key for the pre-computed feature space (typically PCA). n_components: Number of components of the feature space to use. n_genes_per_signature: Number of top-correlated genes kept per program signature (R's ``n.genes``). anova_alpha: Per-feature ANOVA significance threshold for filtering uninformative components (R's ``p.anova``). winsorize_quantile: Tail clipping fraction applied to pseudobulk components (R's ``cap.mat`` parameter). n_permutations: Permutations used to derive empirical PMD p-values (R's ``n1`` in ``DIALOGUE1.PMD.empirical``). empirical_alpha: P-value threshold below which a program is considered shared between a pair of cell types (R's implicit ``< 0.1``). use_tme_qc: If True, add ``tme_qc`` (partner-celltype per-sample average of ``cell_quality_key``) as an additional HLM covariate (R default). additional_covariates: Extra ``adata.obs`` columns to include as HLM covariates. min_cells_per_sample: Minimum cells per sample required for a cell type to be considered in the pair-level HLM (R's ``abn.c``). random_state: Reproducibility seed for permutation tests and PMD permute search. """ def __init__( self, *, celltype_key: str, sample_key: str, cell_quality_key: str = "cellQ", n_programs: int = 3, feature_space_key: str = "X_pca", n_components: int = 30, n_genes_per_signature: int = 100, anova_alpha: float = 0.05, winsorize_quantile: float = 0.01, n_permutations: int = 100, empirical_alpha: float = 0.1, use_tme_qc: bool = True, additional_covariates: Sequence[str] = (), min_cells_per_sample: int = 5, random_state: int = 1234, ) -> None: self.celltype_key = celltype_key self.sample_key = sample_key self.cell_quality_key = cell_quality_key self.n_programs = n_programs self.feature_space_key = feature_space_key self.n_components = n_components self.n_genes_per_signature = n_genes_per_signature self.anova_alpha = anova_alpha self.winsorize_quantile = winsorize_quantile self.n_permutations = n_permutations self.empirical_alpha = empirical_alpha self.use_tme_qc = use_tme_qc self.additional_covariates = tuple(additional_covariates) self.min_cells_per_sample = min_cells_per_sample self.random_state = int(random_state)
[docs] def fit_programs(self, adata: AnnData) -> AnnData: """Identify multicellular programs across cell types via penalized multiple-CCA. Phase 1 of DIALOGUE. Pseudobulks each cell type per sample, filters uninformative components by ANOVA, centers and winsorizes them, then runs penalized multiple-CCA on the cell types' pseudobulk feature spaces to obtain weights and per-cell program scores. Empirical p-values for each program-by-pair are computed by repeating the PMD on permuted matrices. Stores the following on ``adata``: - ``adata.obsm["X_dialogue_cca"]`` — per-cell CCA scores (``n_obs × n_programs``), residualized on the cell-quality confounder, NaN-padded for cells of cell types skipped during fitting. - ``adata.uns["dialogue"]["weights"][celltype]`` — PMD weights (``n_components × n_programs``). - ``adata.uns["dialogue"]["pseudobulk_features"][celltype]`` — the post-filter, post-center pseudobulk matrices (samples × retained components). - ``adata.uns["dialogue"]["empirical_pvalues"]`` — programs × cell-type pairs. - ``adata.uns["dialogue"]["cca_correlations_R"]`` / ``"_P"`` — per-pair pairwise correlation and p-value of the cell types' CCA scores. - ``adata.uns["dialogue"]["program_celltypes"]`` — mapping of each program to the cell types whose pair passed ``empirical_alpha``. - ``adata.uns["dialogue"]["program_signatures"][program][celltype]`` — initial signature genes from partial Spearman correlation of CCA scores against the cell type's expression matrix. - ``adata.uns["dialogue"]["params"]`` — recorded hyperparameters. - ``adata.uns["dialogue"]["shared_samples"]`` — samples present in all cell types. - ``adata.uns["dialogue"]["cell_type_order"]`` — cell types fit (in stable order). Examples: >>> import pertpy as pt >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) """ celltypes = self._cell_type_order(adata) pseudobulks_full, ct_views = self._per_celltype_pseudobulks(adata, celltypes) pseudobulks = self._anova_filter_per_celltype(pseudobulks_full, ct_views) for ct, pb in pseudobulks.items(): if pb.shape[1] < self.n_programs: raise ValueError( f"Cell type {ct!r} retained only {pb.shape[1]} components after the ANOVA " f"filter (need >= n_programs={self.n_programs}). Loosen anova_alpha or use more PCs." ) shared = self._shared_samples(pseudobulks) pseudobulks = {ct: pb.loc[shared] for ct, pb in pseudobulks.items()} centered = {ct: self._center(pb) for ct, pb in pseudobulks.items()} matrices = [centered[ct].to_numpy() for ct in celltypes] weights = self._fit_pmd(matrices) ws_dict = { ct: pd.DataFrame( weights[i], index=centered[ct].columns, columns=[f"MCP{j + 1}" for j in range(weights[i].shape[1])] ) for i, ct in enumerate(celltypes) } empirical_p = self._empirical_pmd_pvalues(matrices, celltypes) cca_correlations_R, cca_correlations_P = self._cca_correlations(matrices, weights, celltypes) cca_scores = { ct: ct_views[ct].obsm[self.feature_space_key][:, : self.n_components][ :, _retained_indices(pb_full, ws_dict[ct]) ] @ ws_dict[ct].to_numpy() for ct, pb_full in pseudobulks_full.items() } cca_scores = self._residualize_cca_scores(cca_scores, ct_views) adata.obsm["X_dialogue_cca"] = self._broadcast_per_celltype(adata, cca_scores, ct_views, n_cols=self.n_programs) program_celltypes = self._program_celltypes(empirical_p, celltypes) program_signatures = self._initial_program_signatures(ct_views, cca_scores, ws_dict) adata.uns["dialogue"] = { "weights": {ct: ws_dict[ct].to_numpy() for ct in celltypes}, "weights_index": {ct: list(ws_dict[ct].index) for ct in celltypes}, "pseudobulk_features": {ct: centered[ct] for ct in celltypes}, "empirical_pvalues": empirical_p, "cca_correlations_R": cca_correlations_R, "cca_correlations_P": cca_correlations_P, "program_celltypes": program_celltypes, "program_signatures": program_signatures, "cell_type_order": list(celltypes), "shared_samples": list(shared), "params": self._param_dict(), } return adata
def _cell_type_order(self, adata: AnnData) -> list[str]: col = adata.obs[self.celltype_key] if hasattr(col, "cat"): return [str(c) for c in col.cat.categories if (col == c).any()] return sorted(map(str, col.unique())) def _per_celltype_pseudobulks( self, adata: AnnData, celltypes: list[str] ) -> tuple[dict[str, pd.DataFrame], dict[str, AnnData]]: """Per cell type, return (sample-level median PCA pseudobulk, cell-level AnnData view). R uses ``colMedians`` as ``param$averaging.function``; we match. """ ct_views: dict[str, AnnData] = {} pseudobulks: dict[str, pd.DataFrame] = {} for ct in celltypes: mask = (adata.obs[self.celltype_key] == ct).to_numpy() sub = adata[mask].copy() ct_views[ct] = sub pcs = sub.obsm[self.feature_space_key][:, : self.n_components] pb_df = ( pd.DataFrame(pcs, columns=[f"PC{i + 1}" for i in range(pcs.shape[1])]) .assign(_sample=sub.obs[self.sample_key].astype(str).to_numpy()) .groupby("_sample") .median() .sort_index() ) pseudobulks[ct] = pb_df return pseudobulks, ct_views def _anova_filter_per_celltype( self, pseudobulks_full: dict[str, pd.DataFrame], ct_views: dict[str, AnnData] ) -> dict[str, pd.DataFrame]: """For each cell type, drop pseudobulk components whose per-cell ANOVA across samples is non-significant. Mirrors R's filter: the ANOVA is run on the per-cell PCA values (not the sample-level pseudobulks), restricting to samples with ``>= min_cells_per_sample`` cells, and BH-adjusted. """ out: dict[str, pd.DataFrame] = {} for ct, pb in pseudobulks_full.items(): view = ct_views[ct] pcs = view.obsm[self.feature_space_key][:, : self.n_components] samples = view.obs[self.sample_key].astype(str).to_numpy() counts = pd.Series(samples).value_counts() abundant = counts[counts >= self.min_cells_per_sample].index row_mask = np.isin(samples, abundant.to_numpy()) if row_mask.sum() == 0 or np.unique(samples[row_mask]).size < 2: out[ct] = pb continue mask = _anova_filter_features(pcs[row_mask], samples[row_mask], alpha=self.anova_alpha) if mask.sum() < self.n_programs: # R also keeps the original components when too few pass; we propagate that. out[ct] = pb continue out[ct] = pb.iloc[:, mask] return out def _shared_samples(self, pseudobulks: dict[str, pd.DataFrame]) -> list[str]: shared = set.intersection(*[set(pb.index) for pb in pseudobulks.values()]) if len(shared) < 5: raise ValueError(f"Only {len(shared)} samples are present in all cell types; DIALOGUE needs at least 5.") return sorted(shared) def _center(self, pseudobulk: pd.DataFrame) -> pd.DataFrame: scaled = _center_scale_winsorize(pseudobulk.to_numpy(), cap=self.winsorize_quantile) return pd.DataFrame(scaled, index=pseudobulk.index, columns=pseudobulk.columns) def _fit_pmd(self, matrices: list[np.ndarray]) -> list[np.ndarray]: n_samples = matrices[0].shape[0] penalties = multicca_permute( matrices, penalties=float(np.sqrt(n_samples) / 2.0), nperms=10, niter=50, standardize=True, )["bestpenalties"] weights, _ = multicca_pmd( matrices, penalties, K=self.n_programs, standardize=True, niter=100, mimic_R=True, ) return weights def _empirical_pmd_pvalues(self, matrices: list[np.ndarray], celltypes: list[str]) -> pd.DataFrame: rng = np.random.default_rng(self.random_state) baseline = self._pmd_pair_correlations(matrices, celltypes) pair_names = baseline.columns.tolist() better = np.zeros((self.n_programs, len(pair_names)), dtype=np.float64) for _ in range(self.n_permutations): permuted = [_column_shuffle(m, rng) for m in matrices] try: perm_cor = self._pmd_pair_correlations(permuted, celltypes) except Exception: # noqa: BLE001 - degenerate permutation; treat as exceeded better += 1.0 continue better += (np.abs(perm_cor.to_numpy()) >= np.abs(baseline.to_numpy())).astype(np.float64) empirical = (better + 1.0) / (self.n_permutations + 1.0) index = [f"MCP{i + 1}" for i in range(self.n_programs)] return pd.DataFrame(empirical, index=index, columns=pair_names) def _pmd_pair_correlations(self, matrices: list[np.ndarray], celltypes: list[str]) -> pd.DataFrame: weights = self._fit_pmd(matrices) scores = [matrices[i] @ weights[i] for i in range(len(matrices))] names = [f"{celltypes[i]}_{celltypes[j]}" for i, j in _pair_indices(len(matrices))] cor = np.zeros((self.n_programs, len(names))) for col, (i, j) in enumerate(_pair_indices(len(matrices))): for k in range(self.n_programs): a = scores[i][:, k] b = scores[j][:, k] denom = (a.std(ddof=1) * b.std(ddof=1)) or 1.0 cor[k, col] = float(np.cov(a, b, ddof=1)[0, 1] / denom) return pd.DataFrame(cor, columns=names, index=[f"MCP{i + 1}" for i in range(self.n_programs)]) def _cca_correlations( self, matrices: list[np.ndarray], weights: list[np.ndarray], celltypes: list[str] ) -> tuple[pd.DataFrame, pd.DataFrame]: scores = [matrices[i] @ weights[i] for i in range(len(matrices))] names = [f"{celltypes[i]}_{celltypes[j]}" for i, j in _pair_indices(len(matrices))] n = matrices[0].shape[0] df = max(n - 2, 1) R = np.zeros((self.n_programs, len(names))) P = np.zeros((self.n_programs, len(names))) for col, (i, j) in enumerate(_pair_indices(len(matrices))): for k in range(self.n_programs): a = scores[i][:, k] b = scores[j][:, k] denom = (a.std(ddof=1) * b.std(ddof=1)) or 1.0 r = float(np.cov(a, b, ddof=1)[0, 1] / denom) R[k, col] = r t_stat = r * np.sqrt(df / np.clip(1 - r**2, 1e-30, None)) P[k, col] = 2.0 * stats.t.sf(np.abs(t_stat), df=df) index = [f"MCP{i + 1}" for i in range(self.n_programs)] return ( pd.DataFrame(R, index=index, columns=names), pd.DataFrame(P, index=index, columns=names), ) def _residualize_cca_scores( self, cca_scores: dict[str, np.ndarray], ct_views: dict[str, AnnData] ) -> dict[str, np.ndarray]: out = {} for ct, scores in cca_scores.items(): conf = ct_views[ct].obs[self.cell_quality_key].to_numpy(dtype=np.float64) out[ct] = _residualize(scores, conf) return out def _broadcast_per_celltype( self, adata: AnnData, per_ct: dict[str, np.ndarray], ct_views: dict[str, AnnData], *, n_cols: int, ) -> np.ndarray: out = np.full((adata.n_obs, n_cols), np.nan, dtype=np.float64) cell_index = pd.Index(adata.obs_names) for ct, mat in per_ct.items(): view = ct_views[ct] positions = cell_index.get_indexer(view.obs_names) out[positions] = mat return out def _program_celltypes(self, empirical_p: pd.DataFrame, celltypes: list[str]) -> dict[str, list[str]]: pair_to_celltypes = { f"{celltypes[i]}_{celltypes[j]}": (celltypes[i], celltypes[j]) for i, j in _pair_indices(len(celltypes)) } out: dict[str, list[str]] = {} for program in empirical_p.index: members: set[str] = set() for col, value in empirical_p.loc[program].items(): if value < self.empirical_alpha: a, b = pair_to_celltypes[col] members.update({a, b}) out[program] = sorted(members) if members else [] return out def _initial_program_signatures( self, ct_views: dict[str, AnnData], cca_scores: dict[str, np.ndarray], ws_dict: dict[str, pd.DataFrame], ) -> dict[str, dict[str, dict[str, list[str]]]]: out: dict[str, dict[str, dict[str, list[str]]]] = {f"MCP{i + 1}": {} for i in range(self.n_programs)} for ct, scores in cca_scores.items(): view = ct_views[ct] X = view.X # may be sparse; _partial_spearman dispatches and streams cellQ = view.obs[self.cell_quality_key].to_numpy(dtype=np.float64) n_genes = view.n_vars R, P = _partial_spearman(X, scores, cellQ) for program_idx in range(scores.shape[1]): program_name = f"MCP{program_idx + 1}" col_R = R[:, program_idx] col_P = P[:, program_idx] bonferroni = 0.05 / max(n_genes, 1) ranked = np.argsort(-np.abs(col_R)) up: list[str] = [] down: list[str] = [] for gene_idx in ranked: if len(up) + len(down) >= self.n_genes_per_signature * 2: break if col_P[gene_idx] > bonferroni: continue name = view.var_names[gene_idx] if col_R[gene_idx] > 0 and len(up) < self.n_genes_per_signature: up.append(name) elif col_R[gene_idx] < 0 and len(down) < self.n_genes_per_signature: down.append(name) out[program_name][ct] = {"up": up, "down": down} return out
[docs] def test_celltype_pairs(self, adata: AnnData, *, show_progress: bool = False) -> AnnData: """For every ordered pair of cell types, fit a hierarchical linear model of one cell type's program score against the partner cell type's pseudobulk expression of candidate genes. Phase 2 of DIALOGUE. Builds per-pair, per-program tables of (estimate, pvalue, z-score) for each candidate gene from ``fit_programs``' signatures and prunes them to the top ``n_genes_per_signature`` per direction. Pair-level "shared abundant" samples are those with at least ``min_cells_per_sample`` cells in both cell types of the pair. Stores on ``adata.uns["dialogue"]["pair_results"]`` a nested dict. ``pair_results[pair_name][program][celltype]`` is a DataFrame with one row per gene tested (in that cell type's signature) and columns ``estimate, pvalue, zscore, up``. Refined per-pair signatures live at ``pair_results[pair_name][program]["refined_signatures"][celltype] = {"up": [...], "down": [...]}``. Args: adata: AnnData previously processed by :meth:`fit_programs`. show_progress: If True, print one line per pair while running. Examples: >>> import pertpy as pt >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) """ if "dialogue" not in adata.uns: raise RuntimeError("Run fit_programs(adata) before test_celltype_pairs(adata).") state = adata.uns["dialogue"] celltypes = state["cell_type_order"] ct_views = self._rebuild_celltype_views(adata, celltypes) cca_scores = self._extract_cca_scores(adata, ct_views) gene_pseudobulks = self._build_gene_pseudobulks(ct_views) per_sample_quality = self._build_per_sample_quality(ct_views) pair_results: dict[str, dict[str, dict[str, object]]] = {} for i, j in _pair_indices(len(celltypes)): ct1, ct2 = celltypes[i], celltypes[j] pair_name = f"{ct1}_{ct2}" shared = self._shared_abundant_samples(ct_views, ct1, ct2) if len(shared) < 5: if show_progress: print(f" skip {pair_name}: only {len(shared)} shared abundant samples") pair_results[pair_name] = {} continue ct1_cells = ct_views[ct1].obs[self.sample_key].astype(str).isin(shared).to_numpy() ct2_cells = ct_views[ct2].obs[self.sample_key].astype(str).isin(shared).to_numpy() ct1_scores = cca_scores[ct1][ct1_cells] ct2_scores = cca_scores[ct2][ct2_cells] ct1_samples = ct_views[ct1].obs[self.sample_key].astype(str).to_numpy()[ct1_cells] ct2_samples = ct_views[ct2].obs[self.sample_key].astype(str).to_numpy()[ct2_cells] ct1_quality = ct_views[ct1].obs[self.cell_quality_key].to_numpy(dtype=np.float64)[ct1_cells] ct2_quality = ct_views[ct2].obs[self.cell_quality_key].to_numpy(dtype=np.float64)[ct2_cells] ct1_tme_qc = per_sample_quality[ct2].reindex(ct1_samples).to_numpy() ct2_tme_qc = per_sample_quality[ct1].reindex(ct2_samples).to_numpy() shared_mcps = [ program for program, members in state["program_celltypes"].items() if ct1 in members and ct2 in members ] if show_progress: print(f" pair {pair_name}: {len(shared_mcps)} shared programs, {len(shared)} samples") pair_results[pair_name] = {} for program in shared_mcps: program_idx = int(program.replace("MCP", "")) - 1 sig1 = state["program_signatures"][program][ct1] sig2 = state["program_signatures"][program][ct2] sig1_up = self._intersect_genes(sig1["up"], gene_pseudobulks[ct1].columns) sig1_down = self._intersect_genes(sig1["down"], gene_pseudobulks[ct1].columns) sig2_up = self._intersect_genes(sig2["up"], gene_pseudobulks[ct2].columns) sig2_down = self._intersect_genes(sig2["down"], gene_pseudobulks[ct2].columns) ct1_genes_to_test = sig1_up + sig1_down ct2_genes_to_test = sig2_up + sig2_down # ct2's program score vs ct1's pseudobulk expression at ct2's cells (R's p1). ct2_tme_for_ct1_genes = gene_pseudobulks[ct1].loc[ct2_samples, ct1_genes_to_test].to_numpy() df_ct1 = self._hlm_block( ct2_scores[:, program_idx], ct2_tme_for_ct1_genes, ct1_genes_to_test, sig1_up, ct2_quality, ct2_tme_qc, ct2_samples, ) # ct1's program score vs ct2's pseudobulk expression at ct1's cells (R's p2). ct1_tme_for_ct2_genes = gene_pseudobulks[ct2].loc[ct1_samples, ct2_genes_to_test].to_numpy() df_ct2 = self._hlm_block( ct1_scores[:, program_idx], ct1_tme_for_ct2_genes, ct2_genes_to_test, sig2_up, ct1_quality, ct1_tme_qc, ct1_samples, ) refined_ct1 = self._top_by_zscore(df_ct1, n=self.n_genes_per_signature) refined_ct2 = self._top_by_zscore(df_ct2, n=self.n_genes_per_signature) pair_results[pair_name][program] = { ct1: df_ct1, ct2: df_ct2, "refined_signatures": {ct1: refined_ct1, ct2: refined_ct2}, } state["pair_results"] = pair_results state["gene_pseudobulks"] = {ct: gene_pseudobulks[ct] for ct in celltypes} state["per_sample_quality"] = {ct: per_sample_quality[ct] for ct in celltypes} return adata
def _rebuild_celltype_views(self, adata: AnnData, celltypes: list[str]) -> dict[str, AnnData]: return {ct: adata[(adata.obs[self.celltype_key] == ct).to_numpy()].copy() for ct in celltypes} def _extract_cca_scores(self, adata: AnnData, ct_views: dict[str, AnnData]) -> dict[str, np.ndarray]: full = adata.obsm["X_dialogue_cca"] idx = pd.Index(adata.obs_names) out = {} for ct, view in ct_views.items(): positions = idx.get_indexer(view.obs_names) out[ct] = np.asarray(full[positions], dtype=np.float64) return out def _build_gene_pseudobulks(self, ct_views: dict[str, AnnData]) -> dict[str, pd.DataFrame]: """Per cell type, return a sample × gene mean-pseudobulk DataFrame (matches R's ``tpmAv``).""" out: dict[str, pd.DataFrame] = {} for ct, view in ct_views.items(): out[ct] = _pseudobulk_per_sample(view, sample_key=self.sample_key, agg="mean") return out def _build_per_sample_quality(self, ct_views: dict[str, AnnData]) -> dict[str, pd.Series]: """Per cell type, per-sample mean of ``cell_quality_key`` (matches R's ``qcAv``).""" out: dict[str, pd.Series] = {} for ct, view in ct_views.items(): samples = view.obs[self.sample_key].astype(str).to_numpy() quality = view.obs[self.cell_quality_key].to_numpy(dtype=np.float64) out[ct] = pd.Series(quality).groupby(samples).mean().rename("qcAv") return out def _shared_abundant_samples(self, ct_views: dict[str, AnnData], ct1: str, ct2: str) -> list[str]: def _abundant(ct: str) -> set[str]: samples = ct_views[ct].obs[self.sample_key].astype(str) counts = samples.value_counts() return set(counts[counts >= self.min_cells_per_sample].index) return sorted(_abundant(ct1) & _abundant(ct2)) def _intersect_genes(self, candidate: list[str], present: pd.Index) -> list[str]: present_set = set(present) return [g for g in candidate if g in present_set] def _hlm_block( self, score: np.ndarray, expression: np.ndarray, gene_names: list[str], up_set: list[str], cell_quality: np.ndarray, tme_qc: np.ndarray, sample_groups: np.ndarray, ) -> pd.DataFrame: if len(gene_names) == 0: return pd.DataFrame(columns=["estimate", "pvalue", "zscore", "up"]) covariate_dict = {self.cell_quality_key: cell_quality} if self.use_tme_qc: covariate_dict["tme_qc"] = tme_qc for col in self.additional_covariates: covariate_dict[col] = np.zeros_like( cell_quality ) # placeholder; user-provided covariate handling reserved for run() covariates = pd.DataFrame(covariate_dict) # expression rows -> genes, columns -> cells. Transpose to genes-by-cells for our helper. expression_arr = pd.DataFrame(expression.T, index=gene_names).to_numpy() res = _hlm_pvalue_per_row( pd.DataFrame(expression_arr, index=gene_names), score, covariates, sample_groups, ) res["zscore"] = _zscores_from_signed_pvalues(res["estimate"].to_numpy(), res["pvalue"].to_numpy()) up_lookup = set(up_set) res["up"] = [g in up_lookup for g in res.index] return res def _top_by_zscore(self, df: pd.DataFrame, *, n: int) -> dict[str, list[str]]: if df.empty: return {"up": [], "down": []} finite = df.dropna(subset=["zscore"]) up_candidates = finite.loc[finite["zscore"] > 0].sort_values("zscore", ascending=False) down_candidates = finite.loc[finite["zscore"] < 0].sort_values("zscore", ascending=True) return { "up": up_candidates.head(n).index.tolist(), "down": down_candidates.head(n).index.tolist(), }
[docs] def refine_scores(self, adata: AnnData) -> AnnData: """Aggregate per-pair HLM evidence and fit final per-cell program scores via iterative non-negative least squares. Phase 3 of DIALOGUE. For every cell type, gather the per-gene z-scores produced by :meth:`test_celltype_pairs` across every pair the cell type appears in, BH-adjust within each program-by-direction, Fisher-combine across pairs, then run iterative NNLS to fit per-cell program scores against the resulting candidate gene set (sign-flipping down-regulated columns). The fitted scores are residualized on the cell-quality confounder and written back to ``adata.obsm["X_dialogue"]``. Stores on ``adata.uns["dialogue"]``: - ``gene_pvalues[celltype]`` — combined gene table with per-pair z-scores, Fisher-combined ``p_up``/``p_down``, support counts ``n_up``/``n_down``, fractions ``nf_up``/``nf_down``, program label, ``up`` direction, and the fitted ``coef`` from NNLS. - ``program_gene_signatures[program][celltype] = {"up": [...], "down": [...]}`` — refined gene signatures (R's ``sig1`` from ``DLG.find.scoring``). - ``program_gene_signatures_strict[program][celltype] = {"up": [...], "down": [...]}`` — stricter set (R's ``sig2``). - ``pair_refined_correlations[pair][program]`` — per-pair sample-average correlation R of the refined scores plus the HLM p-value for the same pair. Updates ``adata.obsm["X_dialogue"]`` with the refined per-cell program scores. Examples: >>> import pertpy as pt >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) >>> dl.refine_scores(adata) """ if "pair_results" not in adata.uns.get("dialogue", {}): raise RuntimeError("Run test_celltype_pairs(adata) before refine_scores(adata).") state = adata.uns["dialogue"] celltypes = state["cell_type_order"] ct_views = self._rebuild_celltype_views(adata, celltypes) gene_pvalues: dict[str, pd.DataFrame] = {} for ct in celltypes: gene_pvalues[ct] = self._aggregate_gene_pvalues_for_celltype(state["pair_results"], celltypes, ct) nnls_scores: dict[str, np.ndarray] = {} refined_signatures: dict[str, dict[str, dict[str, list[str]]]] = { f"MCP{p + 1}": {} for p in range(self.n_programs) } strict_signatures: dict[str, dict[str, dict[str, list[str]]]] = { f"MCP{p + 1}": {} for p in range(self.n_programs) } for ct in celltypes: view = ct_views[ct] cca0 = self._cca_scores_unresidualized(view, state, ct) program_columns = [f"MCP{p + 1}" for p in range(self.n_programs)] ct_scores = np.zeros((view.n_obs, self.n_programs)) gene_pval = gene_pvalues[ct] gene_pval["coef"] = 0.0 # Densify only the candidate gene columns; for a typical run that bounds the dense # working set at a few hundred columns rather than the per-celltype full gene matrix. candidate_genes_per_program: dict[str, np.ndarray] = {} all_candidates: list[str] = [] for program in program_columns: program_rows = gene_pval[gene_pval["program"] == program] if program_rows.empty: candidate_genes_per_program[program] = np.empty(0, dtype=object) continue names = program_rows["gene"].to_numpy() candidate_genes_per_program[program] = names all_candidates.extend(names.tolist()) unique_candidates = sorted(set(all_candidates)) if unique_candidates: slim = _select_dense_gene_columns(view.X, view.var_names, unique_candidates) zscored_slim = _zscore_columns(slim) slim_name_to_idx = {name: i for i, name in enumerate(unique_candidates)} else: zscored_slim = np.empty((view.n_obs, 0), dtype=np.float64) slim_name_to_idx = {} for program_idx, program in enumerate(program_columns): y_target = cca0[:, program_idx] program_rows = gene_pval[gene_pval["program"] == program] if program_rows.empty: ct_scores[:, program_idx] = y_target continue gene_names = candidate_genes_per_program[program] slim_indices = np.array([slim_name_to_idx[g] for g in gene_names], dtype=np.int64) X_program = zscored_slim[:, slim_indices].copy() down_mask = ~program_rows["up"].to_numpy(dtype=bool) X_program[:, down_mask] *= -1.0 ranks = program_rows["Nf"].to_numpy(dtype=np.float64) coefs = _iterative_nnls(X_program, y_target, ranks) ct_scores[:, program_idx] = X_program @ coefs gene_pval.loc[program_rows.index, "coef"] = coefs nnls_scores[ct] = ct_scores for program in program_columns: program_rows = gene_pval[gene_pval["program"] == program] if program_rows.empty: refined_signatures[program][ct] = {"up": [], "down": []} strict_signatures[program][ct] = {"up": [], "down": []} continue n_cells_in_program = len(state["program_celltypes"].get(program, [])) threshold_n = max(1, int(np.ceil(n_cells_in_program / 2))) strong_p = (program_rows["coef"].to_numpy() > 0) | ( ((program_rows["n_up"].to_numpy() >= threshold_n) & (program_rows["p_up"].to_numpy() < 1e-3)) | ((program_rows["n_down"].to_numpy() >= threshold_n) & (program_rows["p_down"].to_numpy() < 1e-3)) ) strict = (program_rows["Nf"].to_numpy() == 1.0) & ( (program_rows["p_up"].to_numpy() < 0.05) | (program_rows["p_down"].to_numpy() < 0.05) ) refined_signatures[program][ct] = self._split_up_down(program_rows.loc[strong_p]) strict_signatures[program][ct] = self._split_up_down(program_rows.loc[strict]) for ct in celltypes: view = ct_views[ct] cellQ = view.obs[self.cell_quality_key].to_numpy(dtype=np.float64) nnls_scores[ct] = _residualize(nnls_scores[ct], cellQ) adata.obsm["X_dialogue"] = self._broadcast_per_celltype(adata, nnls_scores, ct_views, n_cols=self.n_programs) for p in range(self.n_programs): adata.obs[f"mcp_{p}"] = adata.obsm["X_dialogue"][:, p] pair_refined = self._refined_pair_correlations(adata, nnls_scores, ct_views, celltypes) state["gene_pvalues"] = gene_pvalues state["program_gene_signatures"] = refined_signatures state["program_gene_signatures_strict"] = strict_signatures state["pair_refined_correlations"] = pair_refined return adata
def _aggregate_gene_pvalues_for_celltype( self, pair_results: dict[str, dict[str, dict[str, object]]], celltypes: list[str], ct: str, ) -> pd.DataFrame: """For a given cell type, build R's per-program-x-gene gene_pval DataFrame from pair results.""" gene_records: dict[tuple[str, str, bool], dict[str, float]] = {} partner_cols: list[str] = [] for pair_name, programs in pair_results.items(): ct1, ct2 = self._pair_split(pair_name, celltypes) if ct not in (ct1, ct2): continue partner = ct2 if ct == ct1 else ct1 colname = f"{partner}" if colname not in partner_cols: partner_cols.append(colname) for program, info in programs.items(): df = info.get(ct) if not isinstance(df, pd.DataFrame) or df.empty: continue for gene_name, row in df.iterrows(): if not np.isfinite(row["zscore"]): continue key = (program, gene_name, bool(row["up"])) gene_records.setdefault(key, {})[colname] = float(row["zscore"]) if not gene_records: return pd.DataFrame( columns=[ "gene", "program", "up", "programF", "p_up", "p_down", "n_up", "nf_up", "n_down", "nf_down", "N", "Nf", ] ) records = [] for (program, gene, up), partners in gene_records.items(): row = {"program": program, "gene": gene, "up": up, "programF": f"{program}.{'up' if up else 'down'}"} row.update({col: partners.get(col, np.nan) for col in partner_cols}) records.append(row) df = pd.DataFrame(records) df.index = [f"{r['programF']}_{r['gene']}" for _, r in df.iterrows()] z = df[partner_cols].to_numpy() # Two-sided p from z, then BH-adjust within (program, direction) and Fisher-combine. p_up_partner = self._adjust_per_label(self._pvals_from_zscores(z), df["programF"].to_numpy()) p_down_partner = self._adjust_per_label(self._pvals_from_zscores(-z), df["programF"].to_numpy()) df["p_up"] = self._fisher_per_row(p_up_partner) df["p_down"] = self._fisher_per_row(p_down_partner) df["n_up"] = (p_up_partner < self.empirical_alpha).sum(axis=1) df["nf_up"] = (p_up_partner < self.empirical_alpha).mean(axis=1) df["n_down"] = (p_down_partner < self.empirical_alpha).sum(axis=1) df["nf_down"] = (p_down_partner < self.empirical_alpha).mean(axis=1) df["N"] = np.where(df["up"], df["n_up"], df["n_down"]) df["Nf"] = np.where(df["up"], df["nf_up"], df["nf_down"]) # Override p_up/p_down: when the gene is "up" we keep p_up; otherwise the down side carries the signal. df.loc[~df["up"].astype(bool), "p_up"] = 1.0 df.loc[df["up"].astype(bool), "p_down"] = 1.0 return df.reset_index(drop=True) def _pair_split(self, pair_name: str, celltypes: list[str]) -> tuple[str, str]: for i, j in _pair_indices(len(celltypes)): if f"{celltypes[i]}_{celltypes[j]}" == pair_name: return celltypes[i], celltypes[j] raise KeyError(pair_name) def _adjust_per_label(self, pvalues: np.ndarray, labels: np.ndarray) -> np.ndarray: adjusted = np.full_like(pvalues, np.nan) for label in np.unique(labels): mask = labels == label block = pvalues[mask] for j in range(block.shape[1]): col = block[:, j] valid = np.isfinite(col) if valid.sum() < 1: continue adj = multipletests(col[valid], method="fdr_bh")[1] column_full = np.full_like(col, np.nan) column_full[valid] = adj block[:, j] = column_full adjusted[mask] = block return adjusted def _fisher_per_row(self, pvalues: np.ndarray) -> np.ndarray: out = np.full(pvalues.shape[0], 1.0) for i, row in enumerate(pvalues): finite = row[np.isfinite(row) & (row > 0)] if finite.size == 0: continue stat = -2.0 * np.log(finite).sum() out[i] = float(stats.chi2.sf(stat, df=2 * finite.size)) return out @staticmethod def _pvals_from_zscores(z: np.ndarray) -> np.ndarray: return np.where(np.isfinite(z), 2.0 * stats.norm.sf(np.abs(z)), np.nan) def _split_up_down(self, rows: pd.DataFrame) -> dict[str, list[str]]: if rows.empty: return {"up": [], "down": []} up = rows.loc[rows["up"].astype(bool), "gene"].tolist() down = rows.loc[~rows["up"].astype(bool), "gene"].tolist() return {"up": up, "down": down} def _cca_scores_unresidualized(self, view: AnnData, state: dict, ct: str) -> np.ndarray: W = state["weights"][ct] idx_names = state["weights_index"][ct] kept = [int(name[2:]) - 1 for name in idx_names] pcs = view.obsm[self.feature_space_key][:, : self.n_components][:, kept] return np.asarray(pcs, dtype=np.float64) @ W @staticmethod def _gene_indices(var_names: pd.Index, genes: np.ndarray) -> np.ndarray: lookup = {g: i for i, g in enumerate(var_names)} return np.array([lookup[g] for g in genes], dtype=np.int64) def _refined_pair_correlations( self, adata: AnnData, nnls_scores: dict[str, np.ndarray], ct_views: dict[str, AnnData], celltypes: list[str], ) -> dict[str, dict[str, dict[str, float]]]: out: dict[str, dict[str, dict[str, float]]] = {} sample_avg = { ct: pd.DataFrame(nnls_scores[ct], index=ct_views[ct].obs[self.sample_key].astype(str).to_numpy()) .groupby(level=0) .median() for ct in celltypes } for i, j in _pair_indices(len(celltypes)): ct1, ct2 = celltypes[i], celltypes[j] shared = sorted(set(sample_avg[ct1].index) & set(sample_avg[ct2].index)) if len(shared) < 3: continue a = sample_avg[ct1].loc[shared].to_numpy() b = sample_avg[ct2].loc[shared].to_numpy() pair_name = f"{ct1}_{ct2}" out[pair_name] = {} for p in range(self.n_programs): ap = a[:, p] bp = b[:, p] denom = (ap.std(ddof=1) * bp.std(ddof=1)) or 1.0 r = float(np.cov(ap, bp, ddof=1)[0, 1] / denom) out[pair_name][f"MCP{p + 1}"] = {"R": r} return out
[docs] def test_phenotype_association( self, adata: AnnData, condition_key: str, *, conditions: tuple[str, str] | None = None, ) -> pd.DataFrame: """Test each program's association with a binary phenotype using per-celltype hierarchical models. For every (program, cell type), fits ``score ~ phenotype + cell_quality + (1 | sample)`` on the cells of that cell type, where ``phenotype`` is a binary indicator coded from ``adata.obs[condition_key]``. Returns a DataFrame of signed z-scores (rows = cell types, columns = programs) plus a Fisher-combined p-value column across cell types per program. Args: adata: AnnData after :meth:`refine_scores`. condition_key: Column of ``adata.obs`` with the phenotype labels (categorical with exactly two levels, or pass ``conditions`` to pick which two to compare). conditions: Optional two-element tuple selecting which two values of ``adata.obs[condition_key]`` are compared. Returns: ``zscores`` DataFrame (rows: cell types, columns: programs). The combined p-values are stored on ``adata.uns["dialogue"]["phenotype_pvalues"]``. Examples: >>> import pertpy as pt >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) >>> dl.refine_scores(adata) >>> dl.test_phenotype_association(adata, condition_key="path_str") """ if "dialogue" not in adata.uns: raise RuntimeError("Run fit_programs/refine_scores before test_phenotype_association.") if "X_dialogue" not in adata.obsm: raise RuntimeError("Refined scores missing; run refine_scores(adata) first.") state = adata.uns["dialogue"] celltypes = state["cell_type_order"] if conditions is None: labels = pd.Series(adata.obs[condition_key]).astype("category").cat.categories.tolist() if len(labels) != 2: raise ValueError( f"adata.obs[{condition_key!r}] has {len(labels)} levels; pass `conditions` to pick two." ) conditions = (labels[0], labels[1]) scores = adata.obsm["X_dialogue"] obs = adata.obs program_cols = [f"MCP{p + 1}" for p in range(self.n_programs)] z_table = pd.DataFrame(np.nan, index=celltypes, columns=program_cols) p_table = pd.DataFrame(np.nan, index=celltypes, columns=program_cols) for ct in celltypes: mask = (obs[self.celltype_key] == ct).to_numpy() sub_scores = scores[mask] sub_obs = obs.loc[mask] condition = sub_obs[condition_key].astype(str).to_numpy() keep = np.isin(condition, list(conditions)) if keep.sum() < 5: continue x = (condition[keep] == conditions[1]).astype(float) covariates = pd.DataFrame({self.cell_quality_key: sub_obs[self.cell_quality_key].to_numpy()[keep]}) sample_groups = sub_obs[self.sample_key].astype(str).to_numpy()[keep] for program_idx, program in enumerate(program_cols): y = sub_scores[keep, program_idx] if not np.isfinite(y).any(): continue df_one = _hlm_pvalue_per_row( np.asarray(x[None, :]), y, covariates, sample_groups, ) est = float(df_one["estimate"].iloc[0]) pval = float(df_one["pvalue"].iloc[0]) z_table.loc[ct, program] = float(_zscores_from_signed_pvalues(np.array([est]), np.array([pval]))[0]) p_table.loc[ct, program] = pval # Combine across cell types for each program (the rows are cell types so we transpose). combined = self._fisher_per_row(p_table.to_numpy().T) state["phenotype_pvalues"] = pd.DataFrame({"combined_p": combined}, index=program_cols) state["phenotype_zscores"] = z_table return z_table
[docs] def get_program_genes( self, adata: AnnData, *, program: str, celltype: str | None = None, strict: bool = False, ) -> dict[str, list[str]]: """Return the refined gene signature ``{"up": [...], "down": [...]}`` for a program. Args: adata: AnnData after :meth:`refine_scores`. program: Program label (e.g. ``"MCP1"``). celltype: If given, return only that cell type's signature. Otherwise return the cross-celltype intersection of consistently up/down genes. strict: Use the strict variant from ``program_gene_signatures_strict`` (genes flagged in every pair). Examples: >>> import pertpy as pt >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) >>> dl.refine_scores(adata) >>> dl.get_program_genes(adata, program="MCP1", celltype="CD8+ IELs") """ if "dialogue" not in adata.uns: raise RuntimeError("Run refine_scores before get_program_genes.") key = "program_gene_signatures_strict" if strict else "program_gene_signatures" store = adata.uns["dialogue"][key] if program not in store: raise KeyError(f"Unknown program {program!r}; available: {sorted(store)}") per_ct = store[program] if celltype is not None: if celltype not in per_ct: raise KeyError(f"Cell type {celltype!r} not found in program {program}.") return {k: list(v) for k, v in per_ct[celltype].items()} if not per_ct: return {"up": [], "down": []} common_up = set.intersection(*(set(v["up"]) for v in per_ct.values())) common_down = set.intersection(*(set(v["down"]) for v in per_ct.values())) return {"up": sorted(common_up), "down": sorted(common_down)}
[docs] def find_extreme_score_genes( self, adata: AnnData, *, program: str = "MCP1", fraction: float = 0.1, ) -> dict[str, pd.DataFrame]: """Differential-expression scan between the highest- and lowest-scoring cells per cell type for one program. Args: adata: AnnData after :meth:`refine_scores`. program: Program to use (``"MCP1"`` by default). fraction: Fraction of cells at each tail to compare. Must lie in ``(0, 0.5)``. Examples: >>> import pertpy as pt >>> import scanpy as sc >>> adata = pt.dt.dialogue_example() >>> sc.pp.pca(adata) >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) >>> dl.refine_scores(adata) >>> dl.find_extreme_score_genes(adata, program="MCP1", fraction=0.1) """ if "X_dialogue" not in adata.obsm: raise RuntimeError("Run refine_scores(adata) first.") if not 0 < fraction < 0.5: raise ValueError("fraction must be in (0, 0.5)") idx = int(program.replace("MCP", "")) - 1 scores = adata.obsm["X_dialogue"][:, idx] out: dict[str, pd.DataFrame] = {} for ct in adata.uns["dialogue"]["cell_type_order"]: mask = (adata.obs[self.celltype_key] == ct).to_numpy() & np.isfinite(scores) if mask.sum() < int(2 / fraction): continue ct_scores = scores[mask] lo_cut = np.quantile(ct_scores, fraction) hi_cut = np.quantile(ct_scores, 1 - fraction) sub = adata[mask].copy() sub.obs["_extreme"] = pd.Categorical( np.where(ct_scores >= hi_cut, "high", np.where(ct_scores <= lo_cut, "low", "mid")), categories=["low", "mid", "high"], ) sc.tl.rank_genes_groups(sub, groupby="_extreme", groups=["high"], reference="low", use_raw=False) result = sc.get.rank_genes_groups_df(sub, group="high") out[ct] = result return out
[docs] @_doc_params(common_plot_args=doc_common_plot_args) def plot_split_violins( # pragma: no cover # noqa: D417 self, adata: AnnData, *, condition_key: str, program: str = "MCP1", conditions: tuple[str, str] | None = None, return_fig: bool = False, ) -> Figure | None: """Per-celltype split violin of program scores stratified by a binary phenotype. Args: adata: AnnData processed by :meth:`refine_scores`. condition_key: Column of ``adata.obs`` with the binary phenotype to split on. program: Program label (``"MCP1"``, ``"MCP2"``, ...). conditions: Pick which two values of ``adata.obs[condition_key]`` to plot. Required when the column has more than two levels. {common_plot_args} Returns: If ``return_fig`` is ``True``, returns the figure, otherwise ``None``. Examples: >>> import pertpy as pt >>> adata = pt.dt.dialogue_example() >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) >>> dl.refine_scores(adata) >>> dl.plot_split_violins(adata, condition_key="path_str", program="MCP1") Preview: .. image:: /_static/docstring_previews/dialogue_violin.png """ score_col = f"mcp_{int(program.replace('MCP', '')) - 1}" if score_col not in adata.obs.columns: raise RuntimeError(f"{score_col!r} not in adata.obs; run refine_scores(adata) first.") df = adata.obs[[self.celltype_key, score_col, condition_key]].copy() if conditions is not None: df = df[df[condition_key].isin(conditions)] else: unique = pd.Series(df[condition_key]).dropna().unique() if len(unique) != 2: raise ValueError(f"adata.obs[{condition_key!r}] has {len(unique)} levels; pass `conditions=(a, b)`.") df[condition_key] = pd.Categorical(df[condition_key]).remove_unused_categories() fig, ax = plt.subplots(figsize=(8, 4)) sns.violinplot(data=df, x=self.celltype_key, y=score_col, hue=condition_key, split=True, ax=ax) ax.set_ylabel(program) ax.tick_params(axis="x", rotation=90) plt.tight_layout() if return_fig: return fig plt.show() plt.close(fig) return None
[docs] @_doc_params(common_plot_args=doc_common_plot_args) def plot_pairplot( # pragma: no cover # noqa: D417 self, adata: AnnData, *, color: str, program: str = "MCP1", return_fig: bool = False, ) -> Figure | None: """Cross-celltype pairplot of sample-level program scores, colored by a phenotype. Aggregates each (sample, cell type) to the mean program score, pivots into a sample-by-celltype matrix, and runs ``seaborn.pairplot`` with the sample-level ``color`` annotation as the hue. Args: adata: AnnData processed by :meth:`refine_scores`. color: Column of ``adata.obs`` with the per-sample annotation used as the pairplot hue. program: Program label. {common_plot_args} Returns: If ``return_fig`` is ``True``, returns the figure, otherwise ``None``. Examples: >>> import pertpy as pt >>> adata = pt.dt.dialogue_example() >>> dl = pt.tl.Dialogue(celltype_key="cell.subtypes", sample_key="sample", n_programs=3) >>> dl.fit_programs(adata) >>> dl.test_celltype_pairs(adata) >>> dl.refine_scores(adata) >>> dl.plot_pairplot(adata, color="clinical.status", program="MCP1") Preview: .. image:: /_static/docstring_previews/dialogue_pairplot.png """ score_col = f"mcp_{int(program.replace('MCP', '')) - 1}" if score_col not in adata.obs.columns: raise RuntimeError(f"{score_col!r} not in adata.obs; run refine_scores(adata) first.") sample_means = ( adata.obs.groupby([self.sample_key, self.celltype_key], observed=True)[score_col].mean().unstack() ) sample_color = adata.obs.groupby(self.sample_key, observed=True)[color].first() df = sample_means.copy() df[color] = sample_color grid = sns.pairplot(df, hue=color, corner=True) if return_fig: return grid.fig plt.show() return None
def _param_dict(self) -> dict[str, object]: return { "celltype_key": self.celltype_key, "sample_key": self.sample_key, "cell_quality_key": self.cell_quality_key, "n_programs": self.n_programs, "feature_space_key": self.feature_space_key, "n_components": self.n_components, "n_genes_per_signature": self.n_genes_per_signature, "anova_alpha": self.anova_alpha, "winsorize_quantile": self.winsorize_quantile, "n_permutations": self.n_permutations, "empirical_alpha": self.empirical_alpha, "use_tme_qc": self.use_tme_qc, "additional_covariates": list(self.additional_covariates), "min_cells_per_sample": self.min_cells_per_sample, "random_state": self.random_state, }
def _retained_indices(pseudobulk_full: pd.DataFrame, weights: pd.DataFrame) -> np.ndarray: """Position indices in ``pseudobulk_full`` columns of the components retained in ``weights``.""" full_cols = list(pseudobulk_full.columns) return np.asarray([full_cols.index(c) for c in weights.index], dtype=np.int64) def _pair_indices(n: int) -> list[tuple[int, int]]: return [(i, j) for i in range(n) for j in range(i + 1, n)] def _column_shuffle(matrix: np.ndarray, rng: np.random.Generator) -> np.ndarray: out = matrix.copy() n = matrix.shape[0] for j in range(matrix.shape[1]): perm = rng.permutation(n) out[:, j] = matrix[perm, j] return out @singledispatch def _select_dense_gene_columns(X, var_names, gene_names: list[str]) -> np.ndarray: """Return the dense ``cells × len(gene_names)`` slice of ``X`` for the requested genes. Dispatches on dense ``np.ndarray`` and sparse ``scipy.sparse`` matrices so that a sparse adata never has to be densified beyond the candidate-gene subset (typically a few hundred columns). """ raise NotImplementedError(f"Unsupported X type: {type(X)!r}") @_select_dense_gene_columns.register(np.ndarray) def _select_dense_gene_columns_dense(X: np.ndarray, var_names, gene_names: list[str]) -> np.ndarray: lookup = {g: i for i, g in enumerate(var_names)} idx = np.array([lookup[g] for g in gene_names], dtype=np.int64) return np.asarray(X[:, idx], dtype=np.float64) @_select_dense_gene_columns.register(sp.spmatrix) def _select_dense_gene_columns_sparse(X: sp.spmatrix, var_names, gene_names: list[str]) -> np.ndarray: lookup = {g: i for i, g in enumerate(var_names)} idx = np.array([lookup[g] for g in gene_names], dtype=np.int64) return np.asarray(X.tocsc()[:, idx].toarray(), dtype=np.float64) def _zscore_columns(matrix: np.ndarray) -> np.ndarray: arr = np.asarray(matrix, dtype=np.float64) mean = arr.mean(axis=0, keepdims=True) std = arr.std(axis=0, ddof=1, keepdims=True) std = np.where(std > 0, std, 1.0) return (arr - mean) / std