# Source code for pertpy.tools._scgen._scgen

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from adjustText import adjust_text
from anndata import AnnData
from jax import Array
from lamin_utils import logger
from scipy import stats
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField
from scvi.model.base import BaseModelClass, JaxTrainingMixin
from scvi.utils import setup_anndata_dsp

from ._scgenvae import JaxSCGENVAE
from ._utils import balancer, extractor

if TYPE_CHECKING:
    from collections.abc import Sequence

# Font family and size settings kept as a module-level constant.
# NOTE(review): not referenced anywhere in the visible part of this module.
font = {
    "family": "Arial",
    "size": 14,
}


class SCGEN(JaxTrainingMixin, BaseModelClass):
    """Jax Implementation of scGen model for batch removal and perturbation prediction.

    Args:
        adata: AnnData object handed to the base model constructor.
        n_hidden: Forwarded to `JaxSCGENVAE` as `n_hidden`.
        n_latent: Forwarded to `JaxSCGENVAE` as `n_latent`.
        n_layers: Forwarded to `JaxSCGENVAE` as `n_layers`.
        dropout_rate: Forwarded to `JaxSCGENVAE` as `dropout_rate`.
        **model_kwargs: Additional keyword arguments forwarded to `JaxSCGENVAE`.
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 800,
        n_latent: int = 100,
        n_layers: int = 2,
        dropout_rate: float = 0.2,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = JaxSCGENVAE(
            n_input=self.summary_stats.n_vars,
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCGEN Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
            f"{dropout_rate}"
        )
        self.init_params_ = self._get_init_params(locals())

    def predict(
        self,
        ctrl_key=None,
        stim_key=None,
        adata_to_predict=None,
        celltype_to_predict=None,
        restrict_arithmetic_to="all",
    ) -> tuple[AnnData, Any]:
        """Predicts the cell type provided by the user in stimulated condition.

        The latent difference vector between stimulated and control cells is added to the
        latent representation of the cells to predict, and the result is decoded.

        Args:
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            adata_to_predict: Adata for unperturbed cells you want to be predicted.
                Mutually exclusive with `celltype_to_predict`.
            celltype_to_predict: The cell type you want to be predicted.
                Mutually exclusive with `adata_to_predict`.
            restrict_arithmetic_to: Either `"all"` (use all cells) or a dictionary
                `{obs_key: values}`; only cells whose `adata.obs[obs_key]` is in `values` are
                used to compute the difference vector. Only the first key of the dictionary is used.

        Returns:
            A tuple `(predicted_adata, delta)`. `predicted_adata` is an AnnData whose `X` holds the
            decoded predicted cells and whose `obs`, `var` and `obsm` are copied from the cells that
            were predicted. `delta` is the difference between the average stimulated and the average
            control latent vectors.

        Raises:
            ValueError: If both or neither of `celltype_to_predict` and `adata_to_predict` are given.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = model.predict(ctrl_key="ctrl", stim_key="stim", celltype_to_predict="CD4 T cells")
        """
        # Fail fast on contradictory/missing targets before doing any subsetting work.
        if celltype_to_predict is not None and adata_to_predict is not None:
            raise ValueError("Please provide either a cell type or adata not both!")
        if celltype_to_predict is None and adata_to_predict is None:
            raise ValueError("Please provide a cell type name or adata for your unperturbed cells")

        # use keys registered from `setup_anndata()`
        cell_type_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        condition_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        if restrict_arithmetic_to == "all":
            ctrl_x = self.adata[self.adata.obs[condition_key] == ctrl_key, :]
            stim_x = self.adata[self.adata.obs[condition_key] == stim_key, :]
            ctrl_x = balancer(ctrl_x, cell_type_key)
            stim_x = balancer(stim_x, cell_type_key)
        else:
            key = list(restrict_arithmetic_to.keys())[0]
            values = restrict_arithmetic_to[key]
            subset = self.adata[self.adata.obs[key].isin(values)]
            ctrl_x = subset[subset.obs[condition_key] == ctrl_key, :]
            stim_x = subset[subset.obs[condition_key] == stim_key, :]
            # Balancing is only applied when more than one value was selected.
            if len(values) > 1:
                ctrl_x = balancer(ctrl_x, cell_type_key)
                stim_x = balancer(stim_x, cell_type_key)

        if celltype_to_predict is not None:
            # NOTE(review): element [1] of `extractor`'s result is presumably the control cells
            # of the requested cell type -- confirm against `._utils.extractor`.
            ctrl_pred = extractor(
                self.adata,
                celltype_to_predict,
                condition_key,
                cell_type_key,
                ctrl_key,
                stim_key,
            )[1]
        else:
            ctrl_pred = adata_to_predict

        # Randomly subsample both conditions to the same number of cells before averaging.
        eq = min(ctrl_x.X.shape[0], stim_x.X.shape[0])
        rng = np.random.default_rng()
        cd_ind = rng.choice(ctrl_x.shape[0], size=eq, replace=False)
        stim_ind = rng.choice(stim_x.shape[0], size=eq, replace=False)
        ctrl_adata = ctrl_x[cd_ind, :]
        stim_adata = stim_x[stim_ind, :]

        latent_ctrl = self._avg_vector(ctrl_adata)
        latent_stim = self._avg_vector(stim_adata)

        delta = latent_stim - latent_ctrl

        latent_cd = self.get_latent_representation(ctrl_pred)

        stim_pred = delta + latent_cd
        predicted_cells = self.module.as_bound().generative(stim_pred)["px"]

        predicted_adata = AnnData(
            X=np.array(predicted_cells),
            obs=ctrl_pred.obs.copy(),
            var=ctrl_pred.var.copy(),
            obsm=ctrl_pred.obsm.copy(),
        )
        return predicted_adata, delta

    def _avg_vector(self, adata):
        """Return the mean over cells (axis 0) of the latent representation of `adata`."""
        return np.mean(self.get_latent_representation(adata), axis=0)

    def get_decoded_expression(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        batch_size: int | None = None,
    ) -> Array:
        """Get decoded expression.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData.
                If `None`, defaults to the AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Decoded expression for each cell

        Raises:
            RuntimeError: If the model has not been trained.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> decoded_X = model.get_decoded_expression()
        """
        if self.is_trained_ is False:
            raise RuntimeError("Please train the model first.")

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        decoded = []
        for tensors in scdl:
            _, generative_outputs = self.module.as_bound()(tensors, compute_loss=False)
            px = generative_outputs["px"]
            decoded.append(px)

        return jnp.concatenate(decoded)

    def batch_removal(self, adata: AnnData | None = None) -> AnnData:
        """Removes batch effects.

        For every cell type that occurs in at least two batches, the latent vectors of each batch
        are shifted so that their mean matches the mean of the largest batch of that cell type.
        Cell types found in a single batch are left unchanged. The resulting latent vectors are
        decoded to obtain the corrected expression.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData.
                If `None`, defaults to the AnnData object used to initialize the model.
                Must have been setup with `batch_key` and `labels_key`,
                corresponding to batch and cell type metadata, respectively.

        Returns:
            AnnData of corrected gene expression in `adata.X`, ordered like the input.
            `adata.obsm["latent"]` holds the shifted latent vectors that were decoded and
            `adata.obsm["corrected_latent"]` the latent representation of the corrected expression.
            If the input adata has a `raw` attribute, its `X` and `var` are copied into `corrected.raw`.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> corrected_adata = model.batch_removal()
        """
        adata = self._validate_anndata(adata)
        latent_all = self.get_latent_representation(adata)
        # use keys registered from `setup_anndata()`
        cell_label_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY).original_key
        batch_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key

        adata_latent = AnnData(latent_all)
        adata_latent.obs = adata.obs.copy(deep=True)
        unique_cell_types = np.unique(adata_latent.obs[cell_label_key])
        shared_ct = []
        not_shared_ct = []
        for cell_type in unique_cell_types:
            temp_cell = adata_latent[adata_latent.obs[cell_label_key] == cell_type].copy()
            if len(np.unique(temp_cell.obs[batch_key])) < 2:
                # Cell type seen in a single batch only: nothing to align, keep as is.
                not_shared_ct.append(temp_cell)
                continue
            batch_list = {}
            batch_ind = {}
            max_batch = 0
            max_batch_ind = ""
            batches = np.unique(temp_cell.obs[batch_key])
            for i in batches:
                temp = temp_cell[temp_cell.obs[batch_key] == i]
                temp_ind = temp_cell.obs[batch_key] == i
                if max_batch < len(temp):
                    max_batch = len(temp)
                    max_batch_ind = i
                batch_list[i] = temp
                batch_ind[i] = temp_ind
            # The largest batch of this cell type is the reference all others are shifted onto.
            max_batch_ann = batch_list[max_batch_ind]
            for study in batch_list:
                delta = np.average(max_batch_ann.X, axis=0) - np.average(batch_list[study].X, axis=0)
                batch_list[study].X = delta + batch_list[study].X
                temp_cell[batch_ind[study]].X = batch_list[study].X
            shared_ct.append(temp_cell)

        all_shared_ann = AnnData.concatenate(*shared_ct, batch_key="concat_batch", index_unique=None)
        if "concat_batch" in all_shared_ann.obs.columns:
            del all_shared_ann.obs["concat_batch"]

        if len(not_shared_ct) < 1:
            all_corrected_data = all_shared_ann
        else:
            all_not_shared_ann = AnnData.concatenate(*not_shared_ct, batch_key="concat_batch", index_unique=None)
            all_corrected_data = AnnData.concatenate(
                all_shared_ann,
                all_not_shared_ann,
                batch_key="concat_batch",
                index_unique=None,
            )
            # Drop the helper column from the object that was just concatenated.
            if "concat_batch" in all_corrected_data.obs.columns:
                del all_corrected_data.obs["concat_batch"]

        corrected = AnnData(
            np.array(self.module.as_bound().generative(all_corrected_data.X)["px"]),
            obs=all_corrected_data.obs,
        )
        corrected.var_names = adata.var_names.tolist()
        # Concatenation grouped the cells by cell type; restore the input's cell order.
        corrected = corrected[adata.obs_names]
        if adata.raw is not None:
            adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var)
            adata_raw.obs_names = adata.obs_names
            corrected.raw = adata_raw
        # Apply the same reordering to the latent vectors so rows line up with `corrected.obs`.
        corrected.obsm["latent"] = all_corrected_data[adata.obs_names].X
        corrected.obsm["corrected_latent"] = self.get_latent_representation(corrected)
        return corrected

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        labels_key: str | None = None,
        **kwargs,
    ):
        """%(summary)s.

        scGen expects the expression data to come from `adata.X`

        %(param_batch_key)s
        %(param_labels_key)s

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
        """
        # Must stay the first statement: it captures exactly the call arguments via `locals()`.
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=False),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    def to_device(self, device):
        """No-op: the `device` argument is ignored by this implementation."""
        pass

    @property
    def device(self):
        """Device reported by the underlying module (`self.module.device`)."""
        return self.module.device

    def get_latent_representation(
        self,
        adata: AnnData | None = None,
        indices: Sequence[int] | None = None,
        give_mean: bool = True,
        n_samples: int = 1,
        batch_size: int | None = None,
    ) -> np.ndarray:
        """Return the latent representation for each cell.

        Args:
            adata: AnnData object with equivalent structure to initial AnnData.
                If `None`, defaults to the AnnData object used to initialize the model.
            indices: Indices of cells in adata to use. If `None`, all cells are used.
            give_mean: If `True`, use the mean of the `qz` inference output,
                otherwise use the `z` inference output.
            n_samples: Passed to the inference function as `n_samples`.
            batch_size: Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

        Returns:
            Low-dimensional representation for each cell

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> model = pt.tl.SCGEN(data)
            >>> model.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> latent_X = model.get_latent_representation()
        """
        self._check_if_trained(warn=False)

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True)

        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"n_samples": n_samples})

        latent = []
        for array_dict in scdl:
            out = jit_inference_fn(self.module.rngs, array_dict)
            if give_mean:
                z = out["qz"].mean
            else:
                z = out["z"]
            latent.append(z)
        # Minibatches are joined along axis 0 unless several samples were drawn without averaging.
        concat_axis = 0 if ((n_samples == 1) or give_mean) else 1
        latent = jnp.concatenate(latent, axis=concat_axis)  # type: ignore

        return self.module.as_numpy_array(latent)

    def plot_reg_mean_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        show: bool = False,
        top_100_genes: list[str] | None = None,
        verbose: bool = False,
        legend: bool = True,
        title: str | None = None,
        x_coeff: float = 0.30,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots mean matching for a set of specified genes.

        Args:
            adata: AnnData object containing the cells of both conditions named in `axis_keys`.
            condition_key: The key for the condition
            axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: If truthy, the figure is written with `plt.savefig(save, bbox_inches="tight")`.
            gene_list: list of gene names to be plotted.
            show: if `True`: will show to the plot after saving it.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
            legend: if `True`: plots a legend, defaults to `True`.
            title: Set if you want the plot to display a title.
            x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
            y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
            fontsize: Fontsize used for text in the plot, defaults to 14.
            **kwargs: Optional `range=(start, stop, step)` to set the ticks of both axes and
                `textsize` to override the font size of the R^2 annotations.

        Returns:
            The R^2 value over all genes, or a tuple of the R^2 values over all genes and over
            `top_100_genes` if the latter is given.

        Examples:
            >>> import pertpy as pt
            >>> data = pt.dt.kang_2018()
            >>> pt.tl.SCGEN.setup_anndata(data, batch_key="label", labels_key="cell_type")
            >>> scg = pt.tl.SCGEN(data)
            >>> scg.train(max_epochs=10, batch_size=64, early_stopping=True, early_stopping_patience=5)
            >>> pred, delta = scg.predict(ctrl_key='ctrl', stim_key='stim', celltype_to_predict='CD4 T cells')
            >>> pred.obs['label'] = 'pred'
            >>> eval_adata = data[data.obs['cell_type'] == 'CD4 T cells'].copy().concatenate(pred)
            >>> r2_value = scg.plot_reg_mean_plot(
            ...     eval_adata,
            ...     condition_key='label',
            ...     axis_keys={"x": "pred", "y": "stim"},
            ...     labels={"x": "predicted", "y": "ground truth"},
            ...     save=False,
            ...     show=True,
            ... )

        Preview:
            .. image:: /_static/docstring_previews/scgen_reg_mean.png
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = np.asarray(np.mean(ctrl_diff.X, axis=0)).ravel()
            y_diff = np.asarray(np.mean(stim_diff.X, axis=0)).ravel()
            _, _, r_value_diff, _, _ = stats.linregress(x_diff, y_diff)
            if verbose:
                logger.info(f"top_100 DEGs mean: {r_value_diff**2}")

        # Per-gene means of both conditions; np.asarray(...).ravel() flattens matrix-typed results.
        x = np.asarray(np.mean(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.mean(stim.X, axis=0)).ravel()
        _, _, r_value, _, _ = stats.linregress(x, y)
        if verbose:
            logger.info(f"All genes mean: {r_value**2}")

        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)

        if gene_list is not None:
            var_names = adata.var_names.tolist()
            texts = []
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                texts.append(plt.text(x_bar, y_bar, gene, fontsize=11, color="black"))
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
            # Move the gene labels apart so they do not overlap each other or the points.
            adjust_text(
                texts,
                x=x,
                y=y,
                arrowprops={"arrowstyle": "->", "color": "grey", "lw": 0.5},
                force_static=(0.0, 0.0),
            )

        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title("" if title is None else title, fontsize=fontsize)

        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()

        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        else:
            return r_value**2

    def plot_reg_var_plot(
        self,
        adata,
        condition_key: str,
        axis_keys: dict[str, str],
        labels: dict[str, str],
        save: str | bool | None = None,
        gene_list: list[str] | None = None,
        top_100_genes: list[str] | None = None,
        show: bool = False,
        legend: bool = True,
        title: str | None = None,
        verbose: bool = False,
        x_coeff: float = 0.3,
        y_coeff: float = 0.8,
        fontsize: float = 14,
        **kwargs,
    ) -> tuple[float, float] | float:
        """Plots variance matching for a set of specified genes.

        Args:
            adata: AnnData object containing the cells of the conditions named in `axis_keys`.
            condition_key: Key of the condition.
            axis_keys: Dictionary of `adata.obs` keys that are used by the axes of the plot. Has to be in the following form:
                `{"x": "Key for x-axis", "y": "Key for y-axis"}`. An optional `"y1"` entry adds a
                second scatter of the x condition against the `"y1"` condition.
            labels: Dictionary of axes labels of the form `{"x": "x-axis-name", "y": "y-axis name"}`.
            save: If truthy, the figure is written with `plt.savefig(save, bbox_inches="tight")`.
            gene_list: list of gene names to be plotted.
            top_100_genes: List of the top 100 differentially expressed genes. Specify if you want the top 100 DEGs to be assessed extra.
            show: if `True`: will show to the plot after saving it.
            legend: if `True`: plots a legend, defaults to `True`.
            title: Set if you want the plot to display a title.
            verbose: Specify if you want information to be printed while creating the plot, defaults to `False`.
            x_coeff: Offset to print the R^2 value in x-direction, defaults to 0.3.
            y_coeff: Offset to print the R^2 value in y-direction, defaults to 0.8.
            fontsize: Fontsize used for text in the plot, defaults to 14.
            **kwargs: Optional `range=(start, stop, step)` to set the ticks of both axes and
                `textsize` to override the font size of the R^2 annotations.

        Returns:
            The R^2 value over all genes, or a tuple of the R^2 values over all genes and over
            `top_100_genes` if the latter is given.
        """
        import seaborn as sns

        sns.set_theme(color_codes=True)

        # NOTE(review): the ranking result is not read anywhere in this method, but the call
        # writes to the caller's `adata`; kept so that side effect is unchanged -- confirm whether
        # it can be dropped.
        sc.tl.rank_genes_groups(adata, groupby=condition_key, n_genes=100, method="wilcoxon")
        diff_genes = top_100_genes
        stim = adata[adata.obs[condition_key] == axis_keys["y"]]
        ctrl = adata[adata.obs[condition_key] == axis_keys["x"]]
        if diff_genes is not None:
            if hasattr(diff_genes, "tolist"):
                diff_genes = diff_genes.tolist()
            adata_diff = adata[:, diff_genes]
            stim_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["y"]]
            ctrl_diff = adata_diff[adata_diff.obs[condition_key] == axis_keys["x"]]
            x_diff = np.asarray(np.var(ctrl_diff.X, axis=0)).ravel()
            y_diff = np.asarray(np.var(stim_diff.X, axis=0)).ravel()
            _, _, r_value_diff, _, _ = stats.linregress(x_diff, y_diff)
            if verbose:
                logger.info(f"Top 100 DEGs var: {r_value_diff**2}")

        has_y1 = "y1" in axis_keys
        if has_y1:
            real_stim = adata[adata.obs[condition_key] == axis_keys["y1"]]

        # Per-gene variances of both conditions.
        x = np.asarray(np.var(ctrl.X, axis=0)).ravel()
        y = np.asarray(np.var(stim.X, axis=0)).ravel()
        _, _, r_value, _, _ = stats.linregress(x, y)
        if verbose:
            logger.info(f"All genes var: {r_value**2}")

        df = pd.DataFrame({axis_keys["x"]: x, axis_keys["y"]: y})
        ax = sns.regplot(x=axis_keys["x"], y=axis_keys["y"], data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel(labels["x"], fontsize=fontsize)
        ax.set_ylabel(labels["y"], fontsize=fontsize)

        if has_y1:
            y1 = np.asarray(np.var(real_stim.X, axis=0)).ravel()
            _ = plt.scatter(
                x,
                y1,
                marker="*",
                c="grey",
                alpha=0.5,
                label=f"{axis_keys['x']}-{axis_keys['y1']}",
            )

        if gene_list is not None:
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = x[j]
                y_bar = y[j]
                plt.text(x_bar, y_bar, gene, fontsize=11, color="black")
                plt.plot(x_bar, y_bar, "o", color="red", markersize=5)
                if has_y1:
                    y1_bar = y1[j]
                    plt.text(x_bar, y1_bar, "*", color="black", alpha=0.5)

        if legend:
            plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.title("" if title is None else title, fontsize=12)

        ax.text(
            max(x) - max(x) * x_coeff,
            max(y) - y_coeff * max(y),
            r"$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= " + f"{r_value ** 2:.2f}",
            fontsize=kwargs.get("textsize", fontsize),
        )
        if diff_genes is not None:
            ax.text(
                max(x) - max(x) * x_coeff,
                max(y) - (y_coeff + 0.15) * max(y),
                r"$\mathrm{R^2_{\mathrm{\mathsf{top\ 100\ DEGs}}}}$= " + f"{r_value_diff ** 2:.2f}",
                fontsize=kwargs.get("textsize", fontsize),
            )

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()

        if diff_genes is not None:
            return r_value**2, r_value_diff**2
        else:
            return r_value**2

    def plot_binary_classifier(
        self,
        scgen: SCGEN,
        adata: AnnData | None,
        delta: np.ndarray,
        ctrl_key: str,
        stim_key: str,
        show: bool = False,
        save: str | bool | None = None,
        fontsize: float = 14,
    ) -> plt.Axes | None:
        """Plots the dot product between delta and latent representation of a linear classifier.

        Builds a linear classifier based on the dot product between
        the difference vector and the latent representation of each
        cell and plots the dot product results between delta and latent representation.

        Args:
            scgen: ScGen object that was trained.
            adata: AnnData object with equivalent structure to initial AnnData.
                If `None`, defaults to the AnnData object used to initialize the model.
                Must have been setup with `batch_key` and `labels_key`, corresponding to batch and cell type metadata, respectively.
            delta: Difference between stimulated and control cells in latent space
            ctrl_key: Key for `control` part of the `data` found in `condition_key`.
            stim_key: Key for `stimulated` part of the `data` found in `condition_key`.
            show: if `True`: will show the plot after saving it.
            save: If truthy, the figure is written with `plt.savefig(save, bbox_inches="tight")`.
            fontsize: Set the font size of the plot.

        Returns:
            The current matplotlib axes if the plot is neither shown nor saved, otherwise `None`.
        """
        # Start from a clean slate; note that this closes every open matplotlib figure.
        plt.close("all")
        adata = scgen._validate_anndata(adata)
        condition_key = scgen.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
        cd = adata[adata.obs[condition_key] == ctrl_key, :]
        stim = adata[adata.obs[condition_key] == stim_key, :]
        # `get_latent_representation` validates its input as AnnData, so pass the subsets
        # themselves rather than their raw `.X` matrices.
        all_latent_cd = scgen.get_latent_representation(cd)
        all_latent_stim = scgen.get_latent_representation(stim)
        # One dot product with `delta` per cell (row).
        delta_vec = np.asarray(delta)
        dot_cd = np.asarray(all_latent_cd) @ delta_vec
        dot_sal = np.asarray(all_latent_stim) @ delta_vec
        plt.hist(
            dot_cd,
            label=ctrl_key,
            bins=50,
        )

        plt.hist(dot_sal, label=stim_key, bins=50)
        plt.axvline(0, color="k", linestyle="dashed", linewidth=1)
        plt.title("  ", fontsize=fontsize)
        plt.xlabel("  ", fontsize=fontsize)
        plt.ylabel("  ", fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.yticks(fontsize=fontsize)
        ax = plt.gca()
        ax.grid(False)

        if save:
            plt.savefig(save, bbox_inches="tight")
        if show:
            plt.show()
        if not (show or save):
            return ax
        return None