diff --git a/minification.py b/minification.py new file mode 100644 index 00000000..cad24dda --- /dev/null +++ b/minification.py @@ -0,0 +1,204 @@ +from anndata import AnnData +from scipy.sparse import csr_matrix +import torch +from scipy import sparse +import numpy as np +import scanpy as sc +import scarches as sca + + + +#should be a method of scPoli +def get_latent(module, x, c=None, mean=False, mean_var=False): + """Map `x` in to the latent space. This function will feed data in encoder and return z for each sample in + data. + Parameters + ---------- + x: torch.Tensor + Torch Tensor to be mapped to latent space. `x` has to be in shape [n_obs, input_dim]. + c: torch.Tensor + Torch Tensor of condition labels for each sample. + mean: boolean + Returns + ------- + Returns Torch Tensor containing latent space encoding of 'x'. + """ + #compute latent representation + x_ = torch.log(1 + x) + if module.recon_loss == "mse": + x_ = x + if "encoder" in module.inject_condition: + # c = c.type(torch.cuda.LongTensor) + c = c.long() + embed_c = torch.hstack([module.embeddings[i](c[:, i]) for i in range(c.shape[1])]) + z_mean, z_log_var = module.encoder(x_, embed_c) + else: + z_mean, z_log_var = module.encoder(x_) + latent = module.sampling(z_mean, z_log_var) + if mean: + return z_mean + elif mean_var: + return (z_mean, z_log_var) + return latent + +#should be a method of scPoli +def get_latent_representation( + model, + adata, + mean: bool = False, + mean_var: bool = False + ): + """Map `x` in to the latent space. This function will feed data in encoder and return z for each sample in + data. + + Parameters + ---------- + x + Numpy nd-array to be mapped to latent space. `x` has to be in shape [n_obs, input_dim]. + c + `numpy nd-array` of original (unencoded) desired labels for each sample. + mean + return mean instead of random sample from the latent space + + Returns + ------- + Returns array containing latent space encoding of 'x'. + """ + device = next(model.model.parameters()).device + x = adata.X + c = {k: adata.obs[k].values for k in model.condition_keys_} + + if isinstance(c, dict): + label_tensor = [] + for cond in c.keys(): + query_conditions = c[cond] + if not set(query_conditions).issubset(model.conditions_[cond]): + raise ValueError("Incorrect conditions") + labels = np.zeros(query_conditions.shape[0]) + for condition, label in model.model.condition_encoders[cond].items(): + labels[query_conditions == condition] = label + label_tensor.append(labels) + c = torch.tensor(label_tensor, device=device).T + if sparse.issparse(x): + x = x.A + x = torch.tensor(x, dtype=torch.float32) + + latents = [] + # batch the latent transformation process + indices = torch.arange(x.size(0)) + subsampled_indices = indices.split(512) + for batch in subsampled_indices: + latent = get_latent(model.model, + x[batch, :].to(device), c[batch, :].to(device), mean, mean_var + ) + latent = (latent,) if not isinstance(latent, tuple) else latent + latents += [tuple(l.cpu().detach() for l in latent)] + + result = tuple(torch.cat(l) for l in zip(*latents)) + result = result[0] if len(result) == 1 else result + + return result + + +def get_minified_adata_scrna( + adata: AnnData, +) -> AnnData: + """Returns a minified adata that works for most scrna models (such as SCVI, SCANVI). + + Parameters + ---------- + adata + Original adata, of which we to create a minified version. + + """ + + all_zeros = csr_matrix(adata.X.shape) + layers = {layer: all_zeros.copy() for layer in adata.layers} + bdata = AnnData( + X=all_zeros, + layers=layers, + uns=adata.uns.copy(), + obs=adata.obs, + var=adata.var, + varm=adata.varm, + obsm=adata.obsm, + obsp=adata.obsp, + ) + + return bdata + + + +def minify_adata(model, adata): + + """ + This function is adapted from scvi-tools + https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI.minify_adata + + minify adata using latent posterior parameters: + + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + """ + + #get the latent representation and store it in the adata + qzm, qzv = model.get_latent_representation(adata, mean_var=True) + + adata.obsm["X_latent_qzm"] = qzm + adata.obsm["X_latent_qzv"] = qzv + + #we cannot minify data where we do not use observed library size for gene count generation. + #In SCVI model, the library size can be modelled as a latent variable. However in scPoli it is set + #to be observed (equal to the total UMI RNA count of a cell). + + + minified_adata = get_minified_adata_scrna(adata) + minified_adata.obsm["X_latent_qzm"] = adata.obsm["X_latent_qzm"] + minified_adata.obsm["X_latent_qzv"] = adata.obsm["X_latent_qzv"] + counts = adata.X + minified_adata.obs["observed_lib_size"] = np.squeeze( + np.asarray(counts.sum(axis=1)) + ) + + #TODO: set is_minified attribute to True + + + minified_adata.write("adata.h5ad") + +def main(): + + adata = sc.read("atlas_646ddf52fd46b85aafce28c2_data_not_minifiied.h5ad") + + model =sca.models.scPoli.load("model", adata) + + minify_adata(adata, model) + +if __name__ == "__main__": + main() + + +# import scarches as sca +# from scanpy.datasets import pbmc3k_processed, pbmc3k #replace with stored atlas trained on scPoli + +# @pytest.mark.parametrize('get_adata', ["path/to/atlas1", "path/to/atlas2"]) +# def test_minification(get_adata): +# adata = get_adata() +# model = scarches.models.scPoli.load(path = "path/to/model", adata=adata) + +# minify_adata(model, adata) + + + + + + + + + + + + + + + diff --git a/minify.py b/minify.py new file mode 100644 index 00000000..e69de29b diff --git a/run.py b/run.py new file mode 100644 index 00000000..8994e555 --- /dev/null +++ b/run.py @@ -0,0 +1,53 @@ +import numpy as np +import scanpy as sc +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import classification_report +from scarches.models.scpoli import scPoli + + +adata = sc.read('tmp/pancreas.h5ad') + +early_stopping_kwargs = { + "early_stopping_metric": "val_prototype_loss", + "mode": "min", + "threshold": 0, + "patience": 20, + "reduce_lr": True, + "lr_patience": 13, + "lr_factor": 0.1, +} + +condition_key = 'study' +cell_type_key = 'cell_type' +reference = [ + 'inDrop1', + 'inDrop2', + 'inDrop3', + 'inDrop4', + 'fluidigmc1', + 'smartseq2', + 'smarter' +] +query = ['celseq', 'celseq2'] + +adata.obs['query'] = adata.obs[condition_key].isin(query) +adata.obs['query'] = adata.obs['query'].astype('category') +source_adata = adata[adata.obs.study.isin(reference)].copy() +source_adata = source_adata[~source_adata.obs.cell_type.str.contains('alpha')].copy() +target_adata = adata[adata.obs.study.isin(query)].copy() + +scpoli_model = scPoli( + adata=source_adata, + condition_keys=condition_key, + cell_type_keys=cell_type_key, + embedding_dims=5, + recon_loss='nb', +) +scpoli_model.train( + n_epochs=50, + pretraining_epochs=40, + early_stopping_kwargs=early_stopping_kwargs, + eta=5, +) \ No newline at end of file diff --git a/scarches/models/base/_base.py b/scarches/models/base/_base.py index df5fdf7a..3e82de23 100644 --- a/scarches/models/base/_base.py +++ b/scarches/models/base/_base.py @@ -10,7 +10,7 @@ from anndata import AnnData, read from scipy.sparse import issparse -from ._utils import UnpicklerCpu, _validate_var_names +from ._utils import UnpicklerCpu, _validate_var_names, get_minified_adata_scrna class BaseMixin: @@ -193,6 +193,36 @@ def load( model.is_trained_ = attr_dict['is_trained_'] return model + + def minify_adata(self, adata=None, model_name=None): + """ + This function is adapted from scvi-tools + https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI.minify_adata + minify adata using latent posterior parameters: + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + """ + + if adata is None: + adata = self.adata + + #get the latent representation and store it in the adata + qzm, qzv = self.get_latent(adata, mean_var=True) + adata.obsm[f"X_latent_qzm_{model_name}"] = qzm + adata.obsm[f"X_latent_qzv_{model_name}"] = qzv + + minified_adata = get_minified_adata_scrna(adata) + minified_adata.obsm[f"X_latent_qzm_{model_name}"] = adata.obsm[f"X_latent_qzm_{model_name}"] + minified_adata.obsm[f"X_latent_qzv_{model_name}"] = adata.obsm[f"X_latent_qzv_{model_name}"] + counts = adata.X + minified_adata.obs["observed_lib_size"] = np.squeeze( + np.asarray(counts.sum(axis=1)) + ) + self.adata = minified_adata + + print(self.adata) + class SurgeryMixin: diff --git a/scarches/models/base/_utils.py b/scarches/models/base/_utils.py index d24c060e..3f52a669 100644 --- a/scarches/models/base/_utils.py +++ b/scarches/models/base/_utils.py @@ -62,6 +62,33 @@ def _validate_var_names(adata, source_var_names): return new_adata +def get_minified_adata_scrna( + adata: AnnData, +) -> AnnData: + + + """This function is adapted from scvi-tools + https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.utils.get_minified_adata_scrna.html + + Returns a minified adata. + Parameters + ---------- + adata + Original adata, of which we to create a minified version. + """ + all_zeros = csr_matrix(adata.X.shape) + layers = {layer: all_zeros.copy() for layer in adata.layers} + bdata = AnnData( + X=all_zeros, + layers=layers, + uns=adata.uns.copy(), + obs=adata.obs, + var=adata.var, + varm=adata.varm, + obsm=adata.obsm, + obsp=adata.obsp, + ) + return bdata class UnpicklerCpu(pickle.Unpickler): """Helps to pickle.load a model trained on GPU to CPU. @@ -72,4 +99,4 @@ def find_class(self, module, name): if module == 'torch.storage' and name == '_load_from_bytes': return lambda b: torch.load(io.BytesIO(b), map_location='cpu') else: - return super().find_class(module, name) \ No newline at end of file + return super().find_class(module, name) diff --git a/scarches/models/scpoli/scpoli.py b/scarches/models/scpoli/scpoli.py index 895c8012..a7239ccc 100644 --- a/scarches/models/scpoli/scpoli.py +++ b/scarches/models/scpoli/scpoli.py @@ -330,7 +330,7 @@ def sampling(self, mu, log_var): var = torch.exp(log_var) + 1e-4 return Normal(mu, var.sqrt()).rsample() - def get_latent(self, x, c=None, mean=False): + def get_latent(self, x, c=None, mean=False, mean_var=False): """Map `x` in to the latent space. This function will feed data in encoder and return z for each sample in data. Parameters @@ -357,6 +357,8 @@ def get_latent(self, x, c=None, mean=False): latent = self.sampling(z_mean, z_log_var) if mean: return z_mean + elif mean_var: + return (z_mean, z_log_var) return latent diff --git a/scarches/models/scpoli/scpoli_model.py b/scarches/models/scpoli/scpoli_model.py index 98457adb..1c98a6d8 100644 --- a/scarches/models/scpoli/scpoli_model.py +++ b/scarches/models/scpoli/scpoli_model.py @@ -313,6 +313,7 @@ def get_latent( self, adata, mean: bool = False, + mean_var: bool = False ): """Map `x` in to the latent space. This function will feed data in encoder and return z for each sample in data. @@ -356,11 +357,13 @@ def get_latent( x_batch = x_batch.toarray() x_batch = torch.tensor(x_batch, device=device).float() latent = self.model.get_latent( - x_batch, c[batch, :], mean + x_batch, c[batch, :], mean, mean_var ) - latents += [latent.cpu().detach()] - latents = torch.cat(latents) - return np.array(latents) + latent = (latent,) if not isinstance(latent, tuple) else latent + latents += [tuple(l.cpu().detach() for l in latent)] + result = tuple(np.array(torch.cat(l)) for l in zip(*latents)) + result = result[0] if len(result) == 1 else result + return result def get_conditional_embeddings(self): """ @@ -969,3 +972,8 @@ def _load_expand_params_from_dict(self, state_dict): load_state_dict[key] = fixed_ten self.model.load_state_dict(load_state_dict) + + + def minify_adata(self, adata=None): + super().minify_adata(adata, model_name="scpoli") +