from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import arviz as az
import jax.numpy as jnp
import numpy as np
import numpyro as npy
import numpyro.distributions as npd
from anndata import AnnData
from jax import config, random
from lamin_utils import logger
from mudata import MuData
from numpyro.infer import Predictive
from rich import print
from pertpy.tools._coda._base_coda import CompositionalModel2, from_scanpy
if TYPE_CHECKING:
import pandas as pd
# Switch JAX to 64-bit (x64) mode at import time. The initial MCMC states and
# the model inputs built in this module are created as float64 arrays.
config.update("jax_enable_x64", True)
[docs]
class Sccoda(CompositionalModel2):
"""
Statistical model for single-cell differential composition analysis with specification of a reference cell type.
This is the standard scCODA model and recommended for all uses.
The hierarchical formulation of the model for one sample is:
.. math::
y|x &\\sim DirMult(\\phi, \\bar{y}) \\\\
\\log(\\phi) &= \\alpha + x \\beta \\\\
\\alpha_k &\\sim N(0, 5) \\quad &\\forall k \\in [K] \\\\
\\beta_{m, \\hat{k}} &= 0 &\\forall m \\in [M]\\\\
\\beta_{m, k} &= \\tau_{m, k} \\tilde{\\beta}_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
\\tau_{m, k} &= \\frac{\\exp(t_{m, k})}{1+ \\exp(t_{m, k})} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
\\frac{t_{m, k}}{50} &\\sim N(0, 1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
\\tilde{\\beta}_{m, k} &= \\sigma_m^2 \\cdot \\gamma_{m, k} \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
\\sigma_m^2 &\\sim HC(0, 1) \\quad &\\forall m \\in [M] \\\\
\\gamma_{m, k} &\\sim N(0,1) \\quad &\\forall m \\in [M], k \\in \\{[K] \\smallsetminus \\hat{k}\\} \\\\
with y being the cell counts and x the covariates.
For further information, see `scCODA is a Bayesian model for compositional single-cell data analysis`
(Büttner, Ostner et al., NatComms, 2021)
"""
def __init__(self, *args, **kwargs):
    """Forward all positional and keyword arguments unchanged to the parent class initializer."""
    super().__init__(*args, **kwargs)
[docs]
def load(
    self,
    adata: AnnData,
    type: Literal["cell_level", "sample_level"],
    generate_sample_level: bool = True,
    cell_type_identifier: str = None,
    sample_identifier: str = None,
    covariate_uns: str | None = None,
    covariate_obs: list[str] | None = None,
    covariate_df: pd.DataFrame | None = None,
    modality_key_1: str = "rna",
    modality_key_2: str = "coda",
) -> MuData:
    """Wrap the input AnnData into a two-modality MuData object for compositional analysis.

    With ``type="cell_level"``, ``adata`` is stored under ``modality_key_1`` and the
    sample-level modality is either aggregated from it via ``from_scanpy`` or left as an
    empty AnnData (``generate_sample_level=False``). In that mode ``adata.obs`` needs a
    column holding the cell type assignment and one column (or a set of columns, e.g.
    subject id, treatment, disease status) that uniquely identifies each (statistical) sample.
    Further covariates (e.g. subject age) can be given through additional column names in
    ``adata.obs``, a key in ``adata.uns``, or a separate DataFrame.

    With any other ``type`` value, ``adata`` is taken to be the aggregated data already:
    it is stored under ``modality_key_2`` next to an empty AnnData under ``modality_key_1``.

    Args:
        adata: AnnData object.
        type: Specify the input adata type, which could be either a cell-level AnnData or an aggregated sample-level AnnData.
        generate_sample_level: Whether to generate an AnnData object on the sample level or create an empty AnnData object.
        cell_type_identifier: If type is "cell_level", specify column name in adata.obs that specifies the cell types. Defaults to None.
        sample_identifier: If type is "cell_level", specify column name in adata.obs that specifies the sample. Defaults to None.
        covariate_uns: If type is "cell_level", specify key for adata.uns, where covariate values are stored. Defaults to None.
        covariate_obs: If type is "cell_level", specify list of keys for adata.obs, where covariate values are stored. Defaults to None.
        covariate_df: If type is "cell_level", specify dataFrame with covariates. Defaults to None.
        modality_key_1: Key to the cell-level AnnData in the MuData object. Defaults to "rna".
        modality_key_2: Key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".

    Returns:
        MuData: MuData object with cell-level AnnData (`mudata[modality_key_1]`) and aggregated sample-level AnnData (`mudata[modality_key_2]`).

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch", covariate_obs=["condition"])
    """
    # Guard clause: every value other than "cell_level" takes the sample-level path.
    if type != "cell_level":
        return MuData({modality_key_1: AnnData(), modality_key_2: adata})

    if not generate_sample_level:
        aggregated = AnnData()
    else:
        aggregated = from_scanpy(
            adata=adata,
            cell_type_identifier=cell_type_identifier,
            sample_identifier=sample_identifier,
            covariate_uns=covariate_uns,
            covariate_obs=covariate_obs,
            covariate_df=covariate_df,
        )
    return MuData({modality_key_1: adata, modality_key_2: aggregated})
[docs]
def prepare(
    self,
    data: AnnData | MuData,
    formula: str,
    reference_cell_type: str = "automatic",
    automatic_reference_absence_threshold: float = 0.05,
    modality_key: str = "coda",
) -> AnnData | MuData:
    """Handles data preprocessing, covariate matrix creation, reference selection, and zero count replacement for scCODA.

    The heavy lifting is delegated to the parent class ``prepare``; this method then records
    the scCODA-specific parameter names, model type and selection type in
    ``uns["scCODA_params"]`` of the sample-level AnnData.

    Args:
        data: AnnData object with cell counts as sample_adata.X and covariates saved in sample_adata.obs,
            or a MuData object holding such an AnnData under ``modality_key``.
        formula: R-style formula for building the covariate matrix.
            Categorical covariates are handled automatically, with the covariate value of the first sample being used as the reference category.
            To set a different level as the base category for a categorical covariate, use "C(<CovariateName>, Treatment('<ReferenceLevelName>'))"
        reference_cell_type: Column name that sets the reference cell type.
            Reference the name of a column. If "automatic", the cell type with the lowest dispersion in relative abundance
            among those that are absent from at most ``automatic_reference_absence_threshold`` of the samples will be chosen.
            Defaults to "automatic".
        automatic_reference_absence_threshold: If using reference_cell_type = "automatic", determine the maximum fraction of zero entries for a cell type
            to be considered as a possible reference cell type. Defaults to 0.05.
        modality_key: If data is a MuData object, specify key to the aggregated sample-level AnnData object in the MuData object. Defaults to "coda".

    Returns:
        Return an AnnData (if input data is an AnnData object) or return a MuData (if input data is a MuData object)

        Specifically, parameters have been set:

        - `adata.uns["scCODA_params"]["param_names"]` or `data[modality_key].uns["scCODA_params"]["param_names"]`: List with the names of all tracked latent model parameters (through `npy.sample` or `npy.deterministic`)
        - `adata.uns["scCODA_params"]["model_type"]` or `data[modality_key].uns["scCODA_params"]["model_type"]`: String indicating the model type ("classic")
        - `adata.uns["scCODA_params"]["select_type"]` or `data[modality_key].uns["scCODA_params"]["select_type"]`: String indicating the type of spike_and_slab selection ("spikeslab")

    Raises:
        TypeError: If ``data`` is neither a MuData nor an AnnData object.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
    """
    if isinstance(data, MuData):
        adata = data[modality_key]
        is_MuData = True
    elif isinstance(data, AnnData):
        adata = data
        is_MuData = False
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise TypeError(f"data must be an AnnData or MuData object, got {type(data).__name__}.")

    adata = super().prepare(adata, formula, reference_cell_type, automatic_reference_absence_threshold)

    # All parameters that are returned for analysis
    # NOTE(review): the model in this file registers its sites as "concentrations" and
    # "counts", whereas this list says "concentration" and "prediction" - confirm which
    # spelling downstream consumers of "param_names" expect before changing either.
    adata.uns["scCODA_params"]["param_names"] = [
        "sigma_d",
        "b_offset",
        "ind_raw",
        "alpha",
        "ind",
        "b_raw",
        "beta",
        "concentration",
        "prediction",
    ]
    adata.uns["scCODA_params"]["model_type"] = "classic"
    adata.uns["scCODA_params"]["select_type"] = "spikeslab"

    if is_MuData:
        # Write the prepared modality back so the caller's MuData reflects the changes.
        data.mod[modality_key] = adata
        return data
    else:
        return adata
[docs]
def set_init_mcmc_states(self, rng_key: int | None, ref_index: np.ndarray, sample_adata: AnnData) -> AnnData:  # type: ignore
    """Sets initial MCMC state values for scCODA model.

    The states are written in place to
    ``sample_adata.uns["scCODA_params"]["mcmc"]["init_params"]``; the nested
    ``"scCODA_params"`` and ``"mcmc"`` dictionaries must already exist.

    Args:
        rng_key: Seed passed to ``np.random.default_rng`` for the random initial values.
        ref_index: Index of reference feature. Not used by this implementation.
        sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
            ``sample_adata.obsm["covariate_matrix"]`` must be two-dimensional.

    Returns:
        The same AnnData object that was passed in, with the initial states stored.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> adata = sccoda.set_init_mcmc_states(rng_key=42, ref_index=0, sample_adata=mdata["coda"])
    """
    # Data dimensions: D covariates, P features (columns of X). The sample count is
    # not needed here, but the 2-tuple unpack still rejects non-2-D covariate matrices.
    _, D = sample_adata.obsm["covariate_matrix"].shape
    P = sample_adata.X.shape[1]

    # Sizes of different parameter matrices
    alpha_size = [P]
    sigma_size = [D, 1]
    # One column fewer: effects are only initialised for the non-reference features.
    beta_nobl_size = [D, P - 1]

    # Initial MCMC states. "b_offset" is drawn before "alpha"; keep this order so a
    # given seed keeps producing the same initial values.
    rng = np.random.default_rng(seed=rng_key)
    sample_adata.uns["scCODA_params"]["mcmc"]["init_params"] = {
        "sigma_d": np.ones(dtype=np.float64, shape=sigma_size),
        "b_offset": rng.normal(0.0, 1.0, beta_nobl_size),
        "ind_raw": np.zeros(dtype=np.float64, shape=beta_nobl_size),
        "alpha": rng.normal(0.0, 1.0, alpha_size),
    }
    return sample_adata
[docs]
def model(  # type: ignore
    self,
    counts: np.ndarray,
    covariates: np.ndarray,
    n_total: np.ndarray,
    ref_index,
    sample_adata: AnnData,
):
    """Implements scCODA model in numpyro.

    Registered sample sites: "sigma_d", "b_offset", "ind_raw", "alpha" and "counts".
    Registered deterministic sites: "ind", "b_raw", "beta" and "concentrations".

    Args:
        counts: Count data array, used as observations of the "counts" site (``None`` lets numpyro sample it).
        covariates: Covariate matrix
        n_total: Number of counts per sample
        ref_index: Index of reference feature; a zero effect column is inserted at this position.
        sample_adata: Anndata object with cell counts as sample_adata.X and covariates saved in sample_adata.obs.
            Only the shapes of ``obsm["covariate_matrix"]`` and ``X`` are read here.

    Returns:
        predictions (see numpyro documentation for details on models)
    """
    # data dimensions: N samples, D covariates, P features
    N, D = sample_adata.obsm["covariate_matrix"].shape
    P = sample_adata.X.shape[1]

    # numpyro plates for all dimensions
    covariate_axis = npy.plate("covs", D, dim=-2)
    cell_type_axis = npy.plate("ct", P, dim=-1)
    cell_type_axis_nobl = npy.plate("ctnb", P - 1, dim=-1)
    sample_axis = npy.plate("sample", N, dim=-2)

    # Effect priors
    with covariate_axis:
        sigma_d = npy.sample("sigma_d", npd.HalfCauchy(1.0))

    with covariate_axis, cell_type_axis_nobl:
        b_offset = npy.sample("b_offset", npd.Normal(0.0, 1.0))

        # spike-and-slab: logistic transform of the scaled raw indicator
        ind_raw = npy.sample("ind_raw", npd.Normal(0.0, 1.0))
        ind_scaled = ind_raw * 50
        ind = npy.deterministic("ind", jnp.exp(ind_scaled) / (1 + jnp.exp(ind_scaled)))

        b_raw = sigma_d * b_offset

        beta_raw = npy.deterministic("b_raw", ind * b_raw)

    with cell_type_axis:
        # Intercepts
        alpha = npy.sample("alpha", npd.Normal(0.0, 5.0))

    # Add 0 effect reference feature
    with covariate_axis:
        beta_full = jnp.concatenate(
            (beta_raw[:, :ref_index], jnp.zeros(shape=[D, 1]), beta_raw[:, ref_index:]), axis=-1
        )
        beta = npy.deterministic("beta", beta_full)

    # Combine intercepts and effects
    with sample_axis:
        # The second positional parameter of nan_to_num is `copy`, so the replacement
        # value has to be passed as `nan=`; otherwise NaNs silently become 0.0.
        concentrations = npy.deterministic(
            "concentrations", jnp.nan_to_num(jnp.exp(alpha + jnp.matmul(covariates, beta)), nan=0.0001)
        )

        # Calculate DM-distributed counts
        predictions = npy.sample("counts", npd.DirichletMultinomial(concentrations, n_total), obs=counts)

    return predictions
[docs]
def make_arviz(  # type: ignore
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    rng_key=None,
    num_prior_samples: int = 500,
    use_posterior_predictive: bool = True,
) -> az.InferenceData:
    """Creates arviz object from model results for MCMC diagnosis.

    Args:
        data: AnnData object or MuData object.
        modality_key: If data is a MuData object, specify which modality to use. Defaults to "coda".
        rng_key: The rng state used for the prior and posterior predictive simulation. If None, a random state will be selected. Defaults to None.
        num_prior_samples: Number of prior samples calculated. No prior is simulated if this is not positive. Defaults to 500.
        use_posterior_predictive: If True, the posterior predictive will be calculated. Defaults to True.

    Returns:
        az.InferenceData: arviz_data with all MCMC information

    Raises:
        TypeError: If ``data`` is neither a MuData nor an AnnData object.
        KeyError: If ``data`` is a MuData object without the modality ``modality_key``.
        ValueError: If no MCMC result is stored on this object yet.

    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> arviz_data = sccoda.make_arviz(mdata, num_prior_samples=100)
    """
    if isinstance(data, MuData):
        try:
            sample_adata = data[modality_key]
        except (KeyError, IndexError):
            # A missing modality name surfaces as KeyError from the modality mapping;
            # IndexError is still caught for backward compatibility.
            logger.error("When data is a MuData object, modality_key must be specified!")
            raise
    elif isinstance(data, AnnData):
        sample_adata = data
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise TypeError(f"data must be an AnnData or MuData object, got {type(data).__name__}.")

    if not self.mcmc:
        raise ValueError("No MCMC sampling found. Please run a sampler first!")

    # feature names
    cell_types = sample_adata.var.index.to_list()

    # arviz dimensions
    dims = {
        "alpha": ["cell_type"],
        "sigma_d": ["covariate", "0"],
        "b_offset": ["covariate", "cell_type_nb"],
        "ind_raw": ["covariate", "cell_type_nb"],
        "ind": ["covariate", "cell_type_nb"],
        "b_raw": ["covariate", "cell_type_nb"],
        "beta": ["covariate", "cell_type"],
        "concentrations": ["sample", "cell_type"],
        "predictions": ["sample", "cell_type"],
        "counts": ["sample", "cell_type"],
    }

    # arviz coordinates; "cell_type_nb" is the feature list without the reference feature
    reference_index = sample_adata.uns["scCODA_params"]["reference_index"]
    cell_types_nb = cell_types[:reference_index] + cell_types[reference_index + 1 :]
    coords = {
        "cell_type": cell_types,
        "cell_type_nb": cell_types_nb,
        "covariate": sample_adata.uns["scCODA_params"]["covariate_names"],
        "sample": sample_adata.obs.index,
    }

    dtype = "float64"

    # Prior and posterior predictive simulation
    numpyro_covariates = jnp.array(sample_adata.obsm["covariate_matrix"], dtype=dtype)
    numpyro_n_total = jnp.array(sample_adata.obsm["sample_counts"], dtype=dtype)
    ref_index = jnp.array(sample_adata.uns["scCODA_params"]["reference_index"])

    if rng_key is None:
        rng = np.random.default_rng()
        rng_key = random.key(rng.integers(0, 10000))

    # NOTE: the same key drives both simulations below; splitting it would change
    # the samples obtained for a given key, so it is deliberately left as is.
    if use_posterior_predictive:
        posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
            rng_key,
            counts=None,
            covariates=numpyro_covariates,
            n_total=numpyro_n_total,
            ref_index=ref_index,
            sample_adata=sample_adata,
        )
    else:
        posterior_predictive = None

    if num_prior_samples > 0:
        prior = Predictive(self.model, num_samples=num_prior_samples)(
            rng_key,
            counts=None,
            covariates=numpyro_covariates,
            n_total=numpyro_n_total,
            ref_index=ref_index,
            sample_adata=sample_adata,
        )
    else:
        prior = None

    # Create arviz object
    arviz_data = az.from_numpyro(
        self.mcmc, prior=prior, posterior_predictive=posterior_predictive, dims=dims, coords=coords
    )

    return arviz_data
[docs]
def run_nuts(
    self,
    data: AnnData | MuData,
    modality_key: str = "coda",
    num_samples: int = 10000,
    num_warmup: int = 1000,
    rng_key: int = 0,
    copy: bool = False,
    *args,
    **kwargs,
):
    """
    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
    """
    # Pure pass-through to the parent implementation; this override exists so the
    # docstring below can carry an scCODA-specific example.
    return super().run_nuts(data, modality_key, num_samples, num_warmup, rng_key, copy, *args, **kwargs)


# Prepend the parent documentation to the example above. `__doc__` is None when
# docstrings are stripped (python -OO), so fall back to "" instead of raising a
# TypeError while the class body is executed.
run_nuts.__doc__ = (CompositionalModel2.run_nuts.__doc__ or "") + (run_nuts.__doc__ or "")
[docs]
def credible_effects(
    self, data: AnnData | MuData, modality_key: str = "coda", est_fdr: float | None = None
) -> pd.Series:
    """
    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> credible_effects = sccoda.credible_effects(mdata)
    """
    # Pure pass-through to the parent implementation; this override exists so the
    # docstring below can carry an scCODA-specific example.
    return super().credible_effects(data, modality_key, est_fdr)


# Prepend the parent documentation to the example above. `__doc__` is None when
# docstrings are stripped (python -OO), so fall back to "" instead of raising a
# TypeError while the class body is executed.
credible_effects.__doc__ = (CompositionalModel2.credible_effects.__doc__ or "") + (credible_effects.__doc__ or "")
[docs]
def summary(self, data: AnnData | MuData, extended: bool = False, modality_key: str = "coda", *args, **kwargs):
    """
    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.summary(mdata)
    """
    # Pure pass-through to the parent implementation; this override exists so the
    # docstring below can carry an scCODA-specific example.
    return super().summary(data, extended, modality_key, *args, **kwargs)


# Prepend the parent documentation to the example above. `__doc__` is None when
# docstrings are stripped (python -OO), so fall back to "" instead of raising a
# TypeError while the class body is executed.
summary.__doc__ = (CompositionalModel2.summary.__doc__ or "") + (summary.__doc__ or "")
[docs]
def set_fdr(self, data: AnnData | MuData, est_fdr: float, modality_key: str = "coda", *args, **kwargs):
    """
    Examples:
        >>> import pertpy as pt
        >>> haber_cells = pt.dt.haber_2017_regions()
        >>> sccoda = pt.tl.Sccoda()
        >>> mdata = sccoda.load(haber_cells,
        >>>                     type="cell_level",
        >>>                     generate_sample_level=True,
        >>>                     cell_type_identifier="cell_label",
        >>>                     sample_identifier="batch",
        >>>                     covariate_obs=["condition"])
        >>> mdata = sccoda.prepare(mdata, formula="condition", reference_cell_type="Endocrine")
        >>> sccoda.run_nuts(mdata, num_warmup=100, num_samples=1000, rng_key=42)
        >>> sccoda.set_fdr(mdata, est_fdr=0.4)
    """
    # Pure pass-through to the parent implementation; this override exists so the
    # docstring below can carry an scCODA-specific example.
    return super().set_fdr(data, est_fdr, modality_key, *args, **kwargs)


# Prepend the parent documentation to the example above. `__doc__` is None when
# docstrings are stripped (python -OO), so fall back to "" instead of raising a
# TypeError while the class body is executed.
set_fdr.__doc__ = (CompositionalModel2.set_fdr.__doc__ or "") + (set_fdr.__doc__ or "")