# Source code for pertpy.preprocessing._guide_rna

from __future__ import annotations

import uuid
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import scanpy as sc
import scipy

if TYPE_CHECKING:
    from anndata import AnnData
    from matplotlib.axes import Axes


class GuideAssignment:
    """Offers simple guide assignment based on count thresholds."""

    @staticmethod
    def _to_dense(matrix) -> np.ndarray:
        """Return ``matrix`` as a dense ``np.ndarray``, densifying SciPy sparse inputs."""
        if scipy.sparse.issparse(matrix):
            # `.toarray()` rather than the `.A` shorthand: `.A` is deprecated/removed
            # in newer SciPy releases and is not available on every sparse type.
            return matrix.toarray()
        return np.asarray(matrix)

    @staticmethod
    def _max_guide_order(data) -> np.ndarray:
        """Return the cell order obtained by sorting cells on their most expressed guide.

        Cells whose row is constant (row maximum equals row minimum) get the
        sentinel index ``-1`` and therefore sort before all other cells.
        """
        if scipy.sparse.issparse(data):
            row_max = data.max(axis=1).toarray().ravel()
            row_min = data.min(axis=1).toarray().ravel()
            # argmax yields np.matrix or np.ndarray depending on the sparse type.
            row_argmax = np.asarray(data.argmax(axis=1)).ravel()
        else:
            dense = np.asarray(data)
            # Reductions over axis=1 are already 1D; avoiding `.squeeze()` keeps
            # a single-cell input from collapsing into a 0-d array.
            row_max = dense.max(axis=1)
            row_min = dense.min(axis=1)
            row_argmax = dense.argmax(axis=1)

        max_guide_index = np.where(row_max != row_min, row_argmax, -1)
        return np.argsort(max_guide_index)

    def assign_by_threshold(
        self,
        adata: AnnData,
        assignment_threshold: float,
        layer: str | None = None,
        output_layer: str = "assigned_guides",
        only_return_results: bool = False,
    ) -> scipy.sparse.csr_matrix | None:
        """Simple threshold based gRNA assignment function.

        Each cell is assigned to gRNA with at least `assignment_threshold` counts.
        This function expects unnormalized data as input.

        Args:
            adata: Annotated data matrix containing gRNA values
            assignment_threshold: The count threshold that is required for an assignment to be viable.
            layer: Key to the layer containing raw count values of the gRNAs.
                   adata.X is used if layer is None. Expects count data.
            output_layer: Assigned guides will be saved on adata.layers[output_layer].
                          Defaults to `assigned_guides`.
            only_return_results: If True, input AnnData is not modified and the result is returned
                                 as a sparse CSR matrix. Defaults to False.

        Returns:
            A binary (0/1) CSR matrix with the same shape as the input counts if
            `only_return_results` is True, otherwise None.

        Examples:
            Each cell is assigned to gRNA that occurs at least 5 times in the respective cell.

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
        """
        counts = self._to_dense(adata.X if layer is None else adata.layers[layer])

        assigned_grnas = scipy.sparse.csr_matrix(np.where(counts >= assignment_threshold, 1, 0))

        if only_return_results:
            return assigned_grnas

        adata.layers[output_layer] = assigned_grnas
        return None

    def assign_to_max_guide(
        self,
        adata: AnnData,
        assignment_threshold: float,
        layer: str | None = None,
        output_key: str = "assigned_guide",
        no_grna_assigned_key: str = "NT",
        only_return_results: bool = False,
    ) -> np.ndarray | None:
        """Simple threshold based max gRNA assignment function.

        Each cell is assigned to the most expressed gRNA if it has at least `assignment_threshold` counts.
        This function expects unnormalized data as input.

        Args:
            adata: Annotated data matrix containing gRNA values
            assignment_threshold: The count threshold that is required for an assignment to be viable.
            layer: Key to the layer containing raw count values of the gRNAs.
                   adata.X is used if layer is None. Expects count data.
            output_key: Assigned guide will be saved on adata.obs[output_key]. Defaults to `assigned_guide`.
            no_grna_assigned_key: The key to return if no gRNA is expressed enough.
            only_return_results: If True, input AnnData is not modified and the result is returned
                                 as an np.ndarray.

        Returns:
            A 1D array with one guide name (or `no_grna_assigned_key`) per cell if
            `only_return_results` is True, otherwise None.

        Examples:
            Each cell is assigned to the most expressed gRNA if it has at least 5 counts.

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_to_max_guide(gdo, assignment_threshold=5)
        """
        counts = self._to_dense(adata.X if layer is None else adata.layers[layer])

        # Row-wise reductions are 1D already; no `.squeeze()` so that a single
        # cell still yields a length-1 array instead of a 0-d scalar array.
        max_counts = counts.max(axis=1)
        max_guides = adata.var.index[counts.argmax(axis=1)]

        assigned_grna = np.where(max_counts >= assignment_threshold, max_guides, no_grna_assigned_key)

        if only_return_results:
            return assigned_grna

        adata.obs[output_key] = assigned_grna
        return None

    def plot_heatmap(
        self,
        adata: AnnData,
        layer: str | None = None,
        order_by: np.ndarray | str | None = None,
        key_to_save_order: str | None = None,
        **kwargs,
    ) -> list[Axes]:
        """Heatmap plotting of guide RNA expression matrix.

        Assuming guides have sparse expression, this function reorders cells
        and plots guide RNA expression so that a nice sparse representation is achieved.
        The cell ordering can be stored and reused in future plots to obtain consistent
        plots before and after analysis of the guide RNA expression.
        Note: This function expects a log-normalized or binary data.

        Args:
            adata: Annotated data matrix containing gRNA values
            layer: Key to the layer containing log normalized count values of the gRNAs.
                   adata.X is used if layer is None.
            order_by: The order of cells in y axis. Defaults to None.
                      If None, cells will be reordered to have a nice sparse representation.
                      If a string is provided, adata.obs[order_by] will be used as the order.
                      If a numpy array is provided, the array will be used for ordering.
            key_to_save_order: The obs key to save cell orders in the current plot. Only saves if not None.
            kwargs: Are passed to sc.pl.heatmap.

        Returns:
            The return value of sc.pl.heatmap (the plot axes).
            Alternatively you can pass save or show parameters as they will be passed to sc.pl.heatmap.
            Order of cells in the y-axis will be saved on adata.obs[key_to_save_order] if provided.

        Examples:
            Each cell is assigned to gRNA that occurs at least 5 times in the respective cell,
            which is then visualized using a heatmap.

            >>> import pertpy as pt
            >>> mdata = pt.dt.papalexi_2021()
            >>> gdo = mdata.mod["gdo"]
            >>> ga = pt.pp.GuideAssignment()
            >>> ga.assign_by_threshold(gdo, assignment_threshold=5)
            >>> ga.plot_heatmap(gdo)
        """
        data = adata.X if layer is None else adata.layers[layer]

        if order_by is None:
            order = self._max_guide_order(data)
        elif isinstance(order_by, str):
            # Use the stored order as-is (as documented). Arg-sorting it would
            # produce the inverse permutation and break reuse of an order that
            # was saved through `key_to_save_order`.
            order = np.asarray(adata.obs[order_by])
        else:
            order = order_by

        if key_to_save_order is not None:
            adata.obs[key_to_save_order] = pd.Categorical(order)

        # sc.pl.heatmap requires a groupby column; use a throw-away single-category
        # column with a unique name and always remove it again afterwards.
        temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
        adata.obs[temp_col_name] = pd.Categorical(["" for _ in range(adata.shape[0])])

        try:
            axis_group = sc.pl.heatmap(
                adata[order, :],
                var_names=adata.var.index.tolist(),
                groupby=temp_col_name,
                cmap="viridis",
                use_raw=False,
                dendrogram=False,
                layer=layer,
                **kwargs,
            )
        finally:
            del adata.obs[temp_col_name]

        return axis_group