Source code for pertpy.preprocessing._guide_rna

from __future__ import annotations

import uuid
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Literal
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from numba import njit, prange
from rich.progress import track
from scanpy.get import _get_obs_rep, _set_obs_rep
from scipy.sparse import csr_matrix, issparse

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy._types import CSRBase
from pertpy.preprocessing._guide_rna_mixture import compute_count_thresholds, fit_poisson_gauss_mixture

if TYPE_CHECKING:
    from matplotlib.pyplot import Figure


class GuideAssignment:
    """Assign cells to guide RNAs."""

    @singledispatchmethod
    def assign_by_threshold(
        self,
        data: AnnData | np.ndarray | CSRBase,
        /,
        *,
        assignment_threshold: float,
        layer: str | None = None,
        output_layer: str = "assigned_guides",
    ) -> np.ndarray | CSRBase | None:
        """Simple threshold based gRNA assignment function.

        Each cell is assigned to gRNA with at least `assignment_threshold` counts.
        This function expects unnormalized data as input.

        Args:
            data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
                Rows correspond to cells and columns to genes.
            assignment_threshold: The count threshold that is required for an assignment to be viable.
            layer: Key to the layer containing raw count values of the gRNAs.
                adata.X is used if layer is None. Expects count data. Only used for AnnData input.
            output_layer: Assigned guides will be saved on adata.layers[output_layer].
                Only used for AnnData input.

        Returns:
            ``None`` for AnnData input (the 0/1 assignment matrix is stored in
            ``adata.layers[output_layer]``). For a dense or CSR matrix, the 0/1 assignment matrix
            itself (1 where the count is at least ``assignment_threshold``).

        Examples:
            Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
        """
        raise NotImplementedError(
            f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
        )

    @assign_by_threshold.register(AnnData)
    def _assign_by_threshold_anndata(
        self,
        adata: AnnData,
        /,
        *,
        assignment_threshold: float,
        layer: str | None = None,
        output_layer: str = "assigned_guides",
    ) -> None:
        """Threshold ``adata.X`` (or ``layer``) and store the 0/1 result in ``adata.layers[output_layer]``."""
        X = _get_obs_rep(adata, layer=layer)
        guide_assignments = self.assign_by_threshold(X, assignment_threshold=assignment_threshold)
        _set_obs_rep(adata, guide_assignments, layer=output_layer)

    @assign_by_threshold.register(np.ndarray)
    def _assign_by_threshold_numpy(self, X: np.ndarray, /, *, assignment_threshold: float) -> np.ndarray:
        """Return a dense matrix with 1 where ``X >= assignment_threshold`` and 0 elsewhere."""
        return np.where(assignment_threshold <= X, 1, 0)

    @staticmethod
    @njit(parallel=True)
    def _threshold_sparse_numba(data: np.ndarray, threshold: float) -> np.ndarray:
        """Numba kernel: map each stored value to 1 if it is at least ``threshold``, else 0 (int8)."""
        out = np.zeros_like(data, dtype=np.int8)
        for i in prange(data.shape[0]):
            if data[i] >= threshold:
                out[i] = 1
        return out

    @assign_by_threshold.register(CSRBase)
    def _assign_by_threshold_sparse(self, X: CSRBase, /, *, assignment_threshold: float) -> CSRBase:
        """Return a CSR 0/1 matrix marking stored entries of ``X`` that reach ``assignment_threshold``."""
        new_data = self._threshold_sparse_numba(X.data, assignment_threshold)
        # Copy the index arrays: the CSR constructor does not copy them, and sharing them with the
        # input would let in-place operations (sort_indices, eliminate_zeros, ...) on either matrix
        # silently corrupt the other.
        assigned = csr_matrix((new_data, X.indices.copy(), X.indptr.copy()), shape=X.shape)
        # Entries below the threshold became 0; drop them instead of storing explicit zeros.
        assigned.eliminate_zeros()
        return assigned

    @singledispatchmethod
    def assign_to_max_guide(
        self,
        data: AnnData | np.ndarray | CSRBase,
        /,
        *,
        assignment_threshold: float,
        layer: str | None = None,
        obs_key: str = "assigned_guide",
        no_grna_assigned_key: str = "Negative",
    ) -> np.ndarray | None:
        """Simple threshold based max gRNA assignment function.

        Each cell is assigned to the most expressed gRNA if it has at least `assignment_threshold` counts.
        This function expects unnormalized data as input.

        Args:
            data: The (annotated) data matrix of shape `n_obs` × `n_vars`.
                Rows correspond to cells and columns to genes.
            assignment_threshold: The count threshold that is required for an assignment to be viable.
            layer: Key to the layer containing raw count values of the gRNAs.
                adata.X is used if layer is None. Expects count data. Only used for AnnData input.
            obs_key: Assigned guide will be saved on adata.obs[obs_key]. Only used for AnnData input.
            no_grna_assigned_key: The key to return if no gRNA is expressed enough.

        Note:
            For numpy / sparse input, a ``var`` DataFrame whose index holds the guide names must be
            passed as a keyword argument.

        Returns:
            ``None`` for AnnData input (assignments are stored in ``adata.obs[obs_key]``), otherwise a
            per-cell array of assigned guide names.

        Examples:
            Each cell is assigned to the most expressed gRNA if it has at least 5 counts.

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
        """
        raise NotImplementedError(
            f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
        )

    @assign_to_max_guide.register(AnnData)
    def assign_to_max_guide_anndata(
        self,
        adata: AnnData,
        /,
        *,
        assignment_threshold: float,
        layer: str | None = None,
        obs_key: str = "assigned_guide",
        no_grna_assigned_key: str = "Negative",
    ) -> None:
        """Assign each cell its most expressed guide and store the result in ``adata.obs[obs_key]``."""
        X = _get_obs_rep(adata, layer=layer)
        guide_assignments = self.assign_to_max_guide(
            X, var=adata.var, assignment_threshold=assignment_threshold, no_grna_assigned_key=no_grna_assigned_key
        )
        adata.obs[obs_key] = guide_assignments

    @assign_to_max_guide.register(np.ndarray)
    def assign_to_max_guide_numpy(
        self,
        X: np.ndarray,
        /,
        *,
        var: pd.DataFrame,
        assignment_threshold: float,
        no_grna_assigned_key: str = "Negative",
    ) -> np.ndarray:
        """Return, per cell, the name (from ``var.index``) of the guide with the highest count.

        Cells whose highest count is below ``assignment_threshold`` get ``no_grna_assigned_key``.
        """
        # np.asarray turns an np.matrix into a plain 2-D array; reducing along axis 1 without
        # ``squeeze`` keeps a 1-D result even for a single-cell matrix.
        X = np.asarray(X)
        max_counts = X.max(axis=1)
        max_positions = X.argmax(axis=1)
        return np.where(
            max_counts >= assignment_threshold,
            var.index[max_positions],
            no_grna_assigned_key,
        )

    @staticmethod
    @njit(parallel=True)
    def _assign_max_guide_sparse(indptr, data, indices, assignment_threshold, assigned_grna):
        """Numba kernel: for each CSR row, record the column of its largest stored value.

        The column index is written into ``assigned_grna[i]`` only if that value is at least
        ``assignment_threshold``; otherwise the pre-filled value is kept. Ties resolve to the first
        stored entry of the row.
        """
        n_rows = len(indptr) - 1
        # Each iteration reads one row and writes only assigned_grna[i], so rows can run in parallel.
        for i in prange(n_rows):
            row_start = indptr[i]
            row_end = indptr[i + 1]
            if row_end > row_start:
                data_row = data[row_start:row_end]
                max_pos = np.argmax(data_row)
                if data_row[max_pos] >= assignment_threshold:
                    assigned_grna[i] = indices[row_start + max_pos]
        return assigned_grna

    @assign_to_max_guide.register(CSRBase)
    def assign_to_max_guide_sparse(
        self, X: CSRBase, /, *, var: pd.DataFrame, assignment_threshold: float, no_grna_assigned_key: str = "Negative"
    ) -> np.ndarray:
        """Sparse counterpart of :meth:`assign_to_max_guide_numpy`; returns an object array of names."""
        n_rows = X.shape[0]
        assigned_positions = np.full(n_rows, -1, dtype=np.int64)  # -1 means not assigned
        assigned_positions = self._assign_max_guide_sparse(
            X.indptr, X.data, X.indices, assignment_threshold, assigned_positions
        )

        assigned_grna = np.full(n_rows, no_grna_assigned_key, dtype=object)
        mask = assigned_positions >= 0
        if np.any(mask):
            assigned_grna[mask] = np.asarray(var.index)[assigned_positions[mask]]
        return assigned_grna

    @singledispatchmethod
    def assign_mixture_model(
        self,
        data: AnnData | np.ndarray | CSRBase,
        /,
        *,
        model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
        layer: str | None = None,
        assigned_guides_key: str = "assigned_guide",
        no_grna_assigned_key: str = "negative",
        max_assignments_per_cell: int = 5,
        multiple_grna_assigned_key: str = "multiple",
        multiple_grna_assignment_string: str = "+",
        only_return_results: bool = False,
        show_progress: bool = False,
        n_iter: int = 500,
        learning_rate: float = 0.01,
        n_init_seeds: int = 10,
        seed: int = 2024,
    ) -> np.ndarray | None:
        """Assigns gRNAs to cells using a Poisson-Gaussian mixture model.

        The model, priors, and per-guide thresholding rule reproduce ``crispat.ga_poisson_gauss``
        (Velten group, https://github.com/velten-group/crispat). MAP estimation runs for all guides
        in parallel on JAX, replacing crispat's per-guide Pyro SVI loop.

        For each guide, the model is fit only to cells with non-zero counts. Log2 counts are modelled
        as a mixture of a continuous Poisson background and a Gaussian on-target component with priors
        ``weights ~ Dirichlet([0.9, 0.1])``, ``mu ~ Normal(3, 2)``, ``scale ~ LogNormal(2, 1)``,
        ``lam ~ LogNormal(0, 1)``. A cell is assigned to a guide if its UMI count is at least the
        smallest integer ``t`` for which ``P(Normal | log2(t)) > 0.5``.

        Args:
            data: AnnData with gRNA counts, or a dense or sparse cell-by-guide count matrix.
            model: The mixture model to use; currently only ``"poisson_gauss_mixture"`` is supported.
            layer: Layer name to use when ``data`` is an AnnData (defaults to ``X``).
            assigned_guides_key: Per-cell assignment is saved on ``adata.obs[assigned_guides_key]``.
            no_grna_assigned_key: Key to use when a cell is negative for all gRNAs.
            max_assignments_per_cell: Maximum number of gRNAs that can be assigned to a cell.
            multiple_grna_assigned_key: Key to use when more than ``max_assignments_per_cell`` gRNAs are assigned.
            multiple_grna_assignment_string: Separator used to join multiple gRNAs assigned to one cell.
            only_return_results: If ``True``, do not modify ``adata`` and return the assignment array.
            show_progress: Whether to print a progress line.
            n_iter: Optimization steps for the SVI loop (crispat default: 500).
            learning_rate: Adam learning rate (crispat default: 0.01).
            n_init_seeds: Number of prior-sampled inits per guide; best is kept (crispat default: 10).
            seed: Random seed used for initialization.

        Note:
            For numpy / sparse input, a ``var`` DataFrame whose index holds the guide names must be
            passed as a keyword argument; the assignment array is always returned.

        Returns:
            The per-cell assignment array, unless ``data`` is an AnnData and ``only_return_results``
            is ``False``. In that case, fitted per-guide parameters are written to ``adata.var``,
            assignments to ``adata.obs[assigned_guides_key]``, and ``None`` is returned.

        Examples:
            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_mixture_model(gdo)
        """
        raise NotImplementedError(
            f"No implementation found for {type(data)}. Must be numpy array, sparse matrix, or AnnData object."
        )

    @assign_mixture_model.register(AnnData)
    def _assign_mixture_model_anndata(
        self,
        adata: AnnData,
        /,
        *,
        model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
        layer: str | None = None,
        assigned_guides_key: str = "assigned_guide",
        no_grna_assigned_key: str = "negative",
        max_assignments_per_cell: int = 5,
        multiple_grna_assigned_key: str = "multiple",
        multiple_grna_assignment_string: str = "+",
        only_return_results: bool = False,
        show_progress: bool = False,
        n_iter: int = 500,
        learning_rate: float = 0.01,
        n_init_seeds: int = 10,
        seed: int = 2024,
    ) -> np.ndarray | None:
        """Fit the mixture on an AnnData; store results in ``adata`` unless ``only_return_results``."""
        if model != "poisson_gauss_mixture":
            raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

        X = _get_obs_rep(adata, layer=layer)
        result = self._fit_mixture_pg(
            X,
            guide_names=list(adata.var_names),
            n_iter=n_iter,
            learning_rate=learning_rate,
            n_init_seeds=n_init_seeds,
            seed=seed,
            show_progress=show_progress,
        )

        assignments = self._binary_to_per_cell_strings(
            result["binary"],
            np.asarray(adata.var_names, dtype=object),
            no_grna_assigned_key=no_grna_assigned_key,
            max_assignments_per_cell=max_assignments_per_cell,
            multiple_grna_assigned_key=multiple_grna_assigned_key,
            multiple_grna_assignment_string=multiple_grna_assignment_string,
        )

        # Honour the documented contract: with only_return_results=True, adata is left untouched
        # (neither .var nor .obs are written).
        if only_return_results:
            return assignments

        adata.var["poisson_rate"] = result["poisson_rate"]
        adata.var["gaussian_mean"] = result["gaussian_mean"]
        adata.var["gaussian_std"] = result["gaussian_std"]
        adata.var["mix_probs_0"] = result["mix_probs_pois"]
        adata.var["mix_probs_1"] = result["mix_probs_norm"]
        adata.var["threshold"] = result["thresholds"]
        adata.var["final_loss"] = result["final_loss"]
        adata.obs[assigned_guides_key] = pd.Categorical(assignments)
        return None

    @assign_mixture_model.register(np.ndarray)
    def _assign_mixture_model_numpy(
        self,
        X: np.ndarray,
        /,
        *,
        var: pd.DataFrame,
        model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
        no_grna_assigned_key: str = "negative",
        max_assignments_per_cell: int = 5,
        multiple_grna_assigned_key: str = "multiple",
        multiple_grna_assignment_string: str = "+",
        show_progress: bool = False,
        n_iter: int = 500,
        learning_rate: float = 0.01,
        n_init_seeds: int = 10,
        seed: int = 2024,
    ) -> np.ndarray:
        """Fit the mixture on a dense matrix and return the per-cell assignment strings."""
        if model != "poisson_gauss_mixture":
            raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

        result = self._fit_mixture_pg(
            X,
            guide_names=list(var.index),
            n_iter=n_iter,
            learning_rate=learning_rate,
            n_init_seeds=n_init_seeds,
            seed=seed,
            show_progress=show_progress,
        )
        return self._binary_to_per_cell_strings(
            result["binary"],
            np.asarray(var.index, dtype=object),
            no_grna_assigned_key=no_grna_assigned_key,
            max_assignments_per_cell=max_assignments_per_cell,
            multiple_grna_assigned_key=multiple_grna_assigned_key,
            multiple_grna_assignment_string=multiple_grna_assignment_string,
        )

    @assign_mixture_model.register(CSRBase)
    def _assign_mixture_model_sparse(
        self,
        X: CSRBase,
        /,
        *,
        var: pd.DataFrame,
        model: Literal["poisson_gauss_mixture"] = "poisson_gauss_mixture",
        no_grna_assigned_key: str = "negative",
        max_assignments_per_cell: int = 5,
        multiple_grna_assigned_key: str = "multiple",
        multiple_grna_assignment_string: str = "+",
        show_progress: bool = False,
        n_iter: int = 500,
        learning_rate: float = 0.01,
        n_init_seeds: int = 10,
        seed: int = 2024,
    ) -> np.ndarray:
        """Fit the mixture on a CSR matrix and return the per-cell assignment strings."""
        if model != "poisson_gauss_mixture":
            raise ValueError("Model not implemented. Please use 'poisson_gauss_mixture'.")

        result = self._fit_mixture_pg(
            X,
            guide_names=list(var.index),
            n_iter=n_iter,
            learning_rate=learning_rate,
            n_init_seeds=n_init_seeds,
            seed=seed,
            show_progress=show_progress,
        )
        return self._binary_to_per_cell_strings(
            result["binary"],
            np.asarray(var.index, dtype=object),
            no_grna_assigned_key=no_grna_assigned_key,
            max_assignments_per_cell=max_assignments_per_cell,
            multiple_grna_assigned_key=multiple_grna_assigned_key,
            multiple_grna_assignment_string=multiple_grna_assignment_string,
        )

    def _fit_mixture_pg(
        self,
        X: np.ndarray | CSRBase,
        *,
        guide_names: list[str],
        n_iter: int,
        learning_rate: float,
        n_init_seeds: int,
        seed: int,
        show_progress: bool,
    ) -> dict[str, np.ndarray]:
        """Fit the Poisson-Gaussian mixture for every guide and produce a binary assignment matrix.

        Per-guide nonzero extraction and per-guide thresholding are dispatched on dense vs sparse
        input, so a sparse count matrix is never densified. The returned ``binary`` matrix itself
        is a dense int8 ``[cells, guides]`` array.

        Guides with fewer than 2 expressing cells or a maximum count below 2 are skipped with a
        warning; their per-guide outputs are NaN and no cell is assigned to them.

        Returns:
            Dict with ``binary`` (int8, cells × guides) and per-guide float arrays ``thresholds``,
            ``poisson_rate``, ``gaussian_mean``, ``gaussian_std``, ``mix_probs_pois``,
            ``mix_probs_norm`` and ``final_loss``.

        Raises:
            ValueError: If the count matrix contains negative values.
        """
        if issparse(X):
            # Private CSC copy: column slices give each guide's counts, and eliminate_zeros (which
            # works in place) must not modify the caller's matrix. Explicitly stored zeros would
            # otherwise be counted as expressing cells and produce log2(0) = -inf in the fit,
            # diverging from the dense path, which drops zeros.
            X_csc = X.tocsc(copy=True)
            X_csc.eliminate_zeros()
            n_cells, n_guides = X_csc.shape
            if X_csc.data.size and X_csc.data.min() < 0:
                raise ValueError(
                    "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
                )
            nonzero_counts = np.diff(X_csc.indptr).astype(np.int64)
            max_counts = np.zeros(n_guides, dtype=np.float64)
            if X_csc.data.size:
                max_counts[:] = X_csc.max(axis=0).toarray().ravel()

            def _nonzero_values(g: int) -> np.ndarray:
                start, end = X_csc.indptr[g], X_csc.indptr[g + 1]
                return np.asarray(X_csc.data[start:end])

        else:
            X_dense = np.ascontiguousarray(np.asarray(X), dtype=np.float32)
            n_cells, n_guides = X_dense.shape
            if np.any(X_dense < 0):
                raise ValueError(
                    "Data contains negative values. Please use non-negative data for guide assignment with the Mixture Model."
                )
            nonzero_counts = (X_dense != 0).sum(axis=0).astype(np.int64)
            max_counts = X_dense.max(axis=0).astype(np.float64)

            def _nonzero_values(g: int) -> np.ndarray:
                col = X_dense[:, g]
                return col[col != 0]

        # crispat fits only on cells with non-zero counts and requires max count >= 2.
        fittable_mask = (nonzero_counts >= 2) & (max_counts >= 2)
        for gene_idx in np.where(~fittable_mask)[0]:
            gene = guide_names[gene_idx]
            if nonzero_counts[gene_idx] < 2:
                warn(f"Skipping {gene} as there are less than 2 cells expressing the guide at all.", stacklevel=3)
            else:
                warn(f"Skipping {gene} as its maximum count is below 2.", stacklevel=3)

        fittable_idx = np.where(fittable_mask)[0]

        thresholds = np.full(n_guides, np.nan, dtype=np.float64)
        poisson_rate = np.full(n_guides, np.nan, dtype=np.float64)
        gaussian_mean = np.full(n_guides, np.nan, dtype=np.float64)
        gaussian_std = np.full(n_guides, np.nan, dtype=np.float64)
        mix_probs_pois = np.full(n_guides, np.nan, dtype=np.float64)
        mix_probs_norm = np.full(n_guides, np.nan, dtype=np.float64)
        final_loss = np.full(n_guides, np.nan, dtype=np.float64)

        if fittable_idx.size > 0:
            # Pack every fittable guide's log2 non-zero counts into a padded batch; mask_batch marks
            # which positions hold real data.
            n_max = int(nonzero_counts[fittable_idx].max())
            data_batch = np.zeros((fittable_idx.size, n_max), dtype=np.float32)
            mask_batch = np.zeros((fittable_idx.size, n_max), dtype=bool)
            for i, g in enumerate(fittable_idx):
                log_values = np.log2(_nonzero_values(int(g))).astype(np.float32)
                data_batch[i, : log_values.size] = log_values
                mask_batch[i, : log_values.size] = True

            if show_progress:
                print(f"Fitting Poisson-Gaussian mixture for {fittable_idx.size} guides in parallel.")

            fit = fit_poisson_gauss_mixture(
                data_batch,
                mask_batch,
                n_iter=n_iter,
                learning_rate=learning_rate,
                n_init_seeds=n_init_seeds,
                seed=seed,
            )
            thresholds[fittable_idx] = compute_count_thresholds(fit, max_counts[fittable_idx].astype(np.int64))
            poisson_rate[fittable_idx] = fit.poisson_rate
            gaussian_mean[fittable_idx] = fit.gaussian_mean
            gaussian_std[fittable_idx] = fit.gaussian_std
            mix_probs_pois[fittable_idx] = fit.mix_probs[:, 0]
            mix_probs_norm[fittable_idx] = fit.mix_probs[:, 1]
            final_loss[fittable_idx] = fit.final_loss

        # Per-cell binary assignment using the per-guide thresholds; thresholds are >=1 so only
        # the non-zero entries of each column can pass.
        binary = np.zeros((n_cells, n_guides), dtype=np.int8)
        valid_idx = np.where(~np.isnan(thresholds))[0]
        if issparse(X):
            for g in valid_idx:
                start, end = X_csc.indptr[g], X_csc.indptr[g + 1]
                data_col = X_csc.data[start:end]
                if data_col.size == 0:
                    continue
                sel = data_col >= thresholds[g]
                if sel.any():
                    binary[X_csc.indices[start:end][sel], g] = 1
        elif valid_idx.size:
            valid_thr = thresholds[valid_idx]
            binary[:, valid_idx] = (X_dense[:, valid_idx] >= valid_thr[None, :]).astype(np.int8)

        return {
            "binary": binary,
            "thresholds": thresholds,
            "poisson_rate": poisson_rate,
            "gaussian_mean": gaussian_mean,
            "gaussian_std": gaussian_std,
            "mix_probs_pois": mix_probs_pois,
            "mix_probs_norm": mix_probs_norm,
            "final_loss": final_loss,
        }

    @staticmethod
    def _binary_to_per_cell_strings(
        binary: np.ndarray,
        guide_names: np.ndarray,
        *,
        no_grna_assigned_key: str,
        max_assignments_per_cell: int,
        multiple_grna_assigned_key: str,
        multiple_grna_assignment_string: str,
    ) -> np.ndarray:
        """Turn a cells × guides 0/1 matrix into one label per cell.

        Cells with no assigned guide get ``no_grna_assigned_key``. Cells with more than
        ``max_assignments_per_cell`` guides get ``multiple_grna_assigned_key``. All other cells get
        their guide names joined by ``multiple_grna_assignment_string``.
        """
        n_cells = binary.shape[0]
        num_guides_assigned = binary.sum(axis=1)
        assignments = np.full(n_cells, no_grna_assigned_key, dtype=object)

        nameable_mask = (num_guides_assigned > 0) & (num_guides_assigned <= max_assignments_per_cell)
        for cell_idx in np.where(nameable_mask)[0]:
            assigned_guides = guide_names[binary[cell_idx] == 1]
            assignments[cell_idx] = multiple_grna_assignment_string.join(assigned_guides.tolist())

        assignments[num_guides_assigned > max_assignments_per_cell] = multiple_grna_assigned_key
        return assignments

    @_doc_params(common_plot_args=doc_common_plot_args)
    def plot_heatmap(  # pragma: no cover # noqa: D417
        self,
        adata: AnnData,
        *,
        layer: str | None = None,
        order_by: np.ndarray | str | None = None,
        key_to_save_order: str | None = None,
        return_fig: bool = False,
        **kwargs,
    ) -> Figure | None:
        """Heatmap plotting of guide RNA expression matrix.

        Assuming guides have sparse expression, this function reorders cells
        and plots guide RNA expression so that a nice sparse representation is achieved.
        The cell ordering can be stored and reused in future plots to obtain consistent
        plots before and after analysis of the guide RNA expression.
        Note: This function expects a log-normalized or binary data.

        Args:
            adata: Annotated data matrix containing gRNA values
            layer: Key to the layer containing log normalized count values of the gRNAs.
                adata.X is used if layer is None.
            order_by: The order of cells in y axis.
                If None, cells will be reordered to have a nice sparse representation.
                If a string is provided, the cells are sorted by the values of adata.obs[order_by]
                (e.g. a column previously written via key_to_save_order).
                If a numpy array is provided, it is used directly to index and order the cells.
            key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
                Each cell receives its row position in the plot (-1 if it is not shown), so passing
                the same key as order_by later reproduces this ordering.
            {common_plot_args}
            kwargs: Are passed to sc.pl.heatmap.

        Returns:
            If `return_fig` is `True`, returns the figure, otherwise `None`.
            Order of cells in the y-axis will be saved on `adata.obs[key_to_save_order]` if provided.

        Examples:
            Each cell is assigned to gRNA that occurs at least 5 times in the respective cell, which is then
            visualized using a heatmap.

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
            >>> ga.plot_heatmap(gdo)
        """
        data = adata.X if layer is None else adata.layers[layer]

        if order_by is None:
            if issparse(data):
                # .toarray()/np.asarray + ravel work for both sparse matrices and sparse arrays
                # (the ``.A`` shortcut does not exist on the latter and was removed from the former).
                max_values = data.max(axis=1).toarray().ravel()
                min_values = data.min(axis=1).toarray().ravel()
                data_argmax = np.asarray(data.argmax(axis=1)).ravel()
            else:
                dense = np.asarray(data)
                max_values = dense.max(axis=1)
                min_values = dense.min(axis=1)
                data_argmax = dense.argmax(axis=1)
            # Cells whose values are all equal (e.g. no guide detected) get -1 and sort first.
            max_guide_index = np.where(max_values != min_values, data_argmax, -1)
            order = np.argsort(max_guide_index)
        elif isinstance(order_by, str):
            order = np.asarray(np.argsort(adata.obs[order_by]))
        else:
            order = np.asarray(order_by)

        if key_to_save_order is not None:
            # Store each cell's row position in this plot rather than the permutation itself:
            # argsort of the positions (what order_by=<key> does) then reproduces `order` exactly,
            # whereas argsort of the permutation would yield its inverse.
            plotted = np.arange(adata.n_obs)[order]  # normalises int / bool indexers to positions
            positions = np.full(adata.n_obs, -1, dtype=np.int64)
            positions[plotted] = np.arange(plotted.size)
            adata.obs[key_to_save_order] = pd.Categorical(positions)

        temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
        adata.obs[temp_col_name] = pd.Categorical([""] * adata.n_obs)
        try:
            fig = sc.pl.heatmap(
                adata[order, :],
                var_names=adata.var.index.tolist(),
                groupby=temp_col_name,
                use_raw=False,
                dendrogram=False,
                layer=layer,
                show=False,
                **kwargs,
            )
        finally:
            del adata.obs[temp_col_name]

        if return_fig:
            return fig
        plt.show()
        return None