Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Can hub paper #2379

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scvi/data/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

class _ADATA_MINIFY_TYPE_NT(NamedTuple):
    # Closed set of string identifiers for the supported minification modes.
    # It is instantiated once as the module-level ``ADATA_MINIFY_TYPE``.

    # Minification that removes the count data and keeps the parameters of
    # the latent posterior.
    LATENT_POSTERIOR: str = "latent_posterior_parameters"
    # NOTE(review): presumably the mode that stores posterior parameters
    # without dropping the original counts -- confirm against the caller.
    ADD_POSTERIOR_PARAMETERS: str = "add_posterior_parameters"


ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT()
Expand Down
8 changes: 6 additions & 2 deletions scvi/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from scvi import REGISTRY_KEYS, settings
from scvi._types import AnnOrMuData, MinifiedDataType
from scvi.data._constants import ADATA_MINIFY_TYPE

from . import _constants

Expand Down Expand Up @@ -325,10 +326,13 @@ def _get_adata_minify_type(adata: AnnData) -> Union[MinifiedDataType, None]:
def _is_minified(adata: Union[AnnData, str]) -> bool:
    """Report whether ``adata`` is minified with the latent-posterior scheme.

    Parameters
    ----------
    adata
        Either an in-memory ``AnnData`` object, or the path (as ``str``) to an
        HDF5 file holding an ``uns`` group.

    Returns
    -------
    ``True`` only when the minify-type entry stored in ``uns`` equals
    ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR``. A missing entry, or any other
    minify type, yields ``False``.

    Raises
    ------
    TypeError
        If ``adata`` is neither an ``AnnData`` nor a ``str``.
    """
    uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY
    if isinstance(adata, AnnData):
        # Compare against the specific type rather than "is not None": other
        # minify types exist and must not be reported as minified here.
        return adata.uns.get(uns_key, None) == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
    if isinstance(adata, str):
        # Read-only on purpose: this is a pure inspection and must never
        # create or modify the file on disk.
        with h5py.File(adata, "r") as fp:
            stored_type = read_elem(fp["uns"]).get(uns_key, None)
        return stored_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
    raise TypeError(f"Unsupported type: {type(adata)}")

Expand Down
23 changes: 8 additions & 15 deletions scvi/model/_condscvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField
from scvi.model.base import (
BaseModelClass,
BaseMinifiedModeModelClass,
RNASeqMixin,
UnsupervisedTrainingMixin,
VAEMixin,
Expand All @@ -24,7 +24,9 @@
logger = logging.getLogger(__name__)


class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
class CondSCVI(
RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseMinifiedModeModelClass
):
"""Conditional version of single-cell Variational Inference, used for multi-resolution deconvolution of spatial transcriptomics data :cite:p:`Lopez22`.

Parameters
Expand Down Expand Up @@ -60,6 +62,9 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass)
"""

_module_cls = VAEC
_LATENT_QZM = "_condscvi_latent_qzm"
_LATENT_QZV = "_condscvi_latent_qzv"
_OBSERVED_LIB_SIZE = "_condscvi_observed_lib_size"

def __init__(
self,
Expand Down Expand Up @@ -140,19 +145,7 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra
key = labels_state_registry.original_key
mapping = labels_state_registry.categorical_mapping

scdl = self._make_data_loader(adata=adata, batch_size=p)

mean = []
var = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
out = self.module.inference(x, y)
mean_, var_ = out["qz"].loc, (out["qz"].scale ** 2)
mean += [mean_.cpu()]
var += [var_.cpu()]

mean_cat, var_cat = torch.cat(mean).numpy(), torch.cat(var).numpy()
mean_cat, var_cat = self.get_latent_representation(adata, return_dist=True)

for ct in range(self.summary_stats["n_labels"]):
local_indices = np.where(adata.obs[key] == mapping[ct])[0]
Expand Down
102 changes: 8 additions & 94 deletions scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,21 @@
from anndata import AnnData

from scvi import REGISTRY_KEYS, settings
from scvi._types import MinifiedDataType
from scvi.data import AnnDataManager
from scvi.data._constants import (
_ADATA_MINIFY_TYPE_UNS_KEY,
_SETUP_ARGS_KEY,
ADATA_MINIFY_TYPE,
)
from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute
from scvi.data.fields import (
BaseAnnDataField,
CategoricalJointObsField,
CategoricalObsField,
LabelsWithUnlabeledObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.dataloaders import SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
from scvi.train._callbacks import SubSampleLabels
Expand All @@ -43,10 +36,6 @@
from ._scvi import SCVI
from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin

_SCANVI_LATENT_QZM = "_scanvi_latent_qzm"
_SCANVI_LATENT_QZV = "_scanvi_latent_qzv"
_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size"

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -106,6 +95,9 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):

_module_cls = SCANVAE
_training_plan_cls = SemiSupervisedTrainingPlan
_LATENT_QZM = "_scanvi_latent_qzm"
_LATENT_QZV = "_scanvi_latent_qzv"
_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size"

def __init__(
self,
Expand Down Expand Up @@ -465,7 +457,8 @@ def setup_anndata(
continuous_covariate_keys: list[str] | None = None,
**kwargs,
):
"""%(summary)s.
"""
%(summary)s.

Parameters
----------
Expand Down Expand Up @@ -498,90 +491,11 @@ def setup_anndata(
# register new fields if the adata is minified
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
anndata_fields += cls._get_fields_for_adata_minification(
cls, adata_minify_type
)
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
    minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
    """Build the extra anndata fields that a minified adata must register.

    Parameters
    ----------
    minified_data_type
        The kind of minification applied to the adata. Only
        ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR`` is handled.

    Returns
    -------
    The latent mean and variance ``obsm`` fields, the observed library size
    ``obs`` field, and the ``uns`` field recording the minify type, in that order.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not the latent-posterior type.
    """
    # Guard first so the supported case reads as a single literal below.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    return [
        ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCANVI_LATENT_QZM),
        ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCANVI_LATENT_QZV),
        NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, _SCANVI_OBSERVED_LIB_SIZE),
        StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
    ]

def minify_adata(
    self,
    minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
    use_latent_qzm_key: str = "X_latent_qzm",
    use_latent_qzv_key: str = "X_latent_qzv",
):
    """Swap the model's adata for a minified copy.

    The minified copy carries the latent posterior mean and variance, the
    per-cell observed library size, and the minified-adata type. The matching
    anndata fields are registered, and the module is told which minification
    is in effect.

    Parameters
    ----------
    minified_data_type
        Minification scheme. Only `latent_posterior_parameters` is supported;
        with it:

        * the original count data is dropped (`adata.X`, adata.raw, and any layers)
        * the parameters of the latent representation are kept
        * nothing else is modified
    use_latent_qzm_key
        Key in `adata.obsm` holding the latent qzm params
    use_latent_qzv_key
        Key in `adata.obsm` holding the latent qzv params

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not the latent-posterior type.
    ValueError
        If the module has ``use_observed_lib_size`` set to ``False``.

    Notes
    -----
    This does not modify the existing adata in place; the model is handed a
    new, minified adata instead.
    """
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    if self.module.use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

    full_adata = self.adata
    slim_adata = get_minified_adata_scrna(full_adata, minified_data_type)

    # Copy the cached posterior parameters over under the model's own keys.
    for target_key, source_key in (
        (_SCANVI_LATENT_QZM, use_latent_qzm_key),
        (_SCANVI_LATENT_QZV, use_latent_qzv_key),
    ):
        slim_adata.obsm[target_key] = full_adata.obsm[source_key]

    # The library size must be read from the registered counts before the
    # manager is switched over to the minified adata below.
    raw_counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    per_cell_totals = np.asarray(raw_counts.sum(axis=1))
    slim_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze(per_cell_totals)

    self._update_adata_and_manager_post_minification(slim_adata, minified_data_type)
    self.module.minified_data_type = minified_data_type
101 changes: 6 additions & 95 deletions scvi/model/_scvi.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,25 @@
import logging
from typing import Literal, Optional

import numpy as np
from anndata import AnnData

from scvi import REGISTRY_KEYS
from scvi._types import MinifiedDataType
from scvi.data import AnnDataManager
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
BaseAnnDataField,
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import VAE
from scvi.utils import setup_anndata_dsp

from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin

_SCVI_LATENT_QZM = "_scvi_latent_qzm"
_SCVI_LATENT_QZV = "_scvi_latent_qzv"
_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size"

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -96,6 +85,9 @@ class SCVI(
"""

_module_cls = VAE
_LATENT_QZM = "_scvi_latent_qzm"
_LATENT_QZV = "_scvi_latent_qzv"
_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size"

def __init__(
self,
Expand Down Expand Up @@ -204,92 +196,11 @@ def setup_anndata(
# register new fields if the adata is minified
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
anndata_fields += cls._get_fields_for_adata_minification(
cls, adata_minify_type
)
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
    minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
    """Build the extra anndata fields that a minified adata must register.

    Parameters
    ----------
    minified_data_type
        The kind of minification applied to the adata. Only
        ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR`` is handled.

    Returns
    -------
    The latent mean and variance ``obsm`` fields, the observed library size
    ``obs`` field, and the ``uns`` field recording the minify type, in that order.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not the latent-posterior type.
    """
    # Guard first so the supported case reads as a single literal below.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    return [
        ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCVI_LATENT_QZM),
        ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCVI_LATENT_QZV),
        NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, _SCVI_OBSERVED_LIB_SIZE),
        StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
    ]

def minify_adata(
    self,
    minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
    use_latent_qzm_key: str = "X_latent_qzm",
    use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
    """Swap the model's adata for a minified copy.

    The minified copy carries the latent posterior mean and variance, the
    per-cell observed library size, and the minified-adata type. The matching
    anndata fields are registered, and the module is told which minification
    is in effect.

    Parameters
    ----------
    minified_data_type
        Minification scheme. Only `latent_posterior_parameters` is supported;
        with it:

        * the original count data is dropped (`adata.X`, adata.raw, and any layers)
        * the parameters of the latent representation are kept
        * nothing else is modified
    use_latent_qzm_key
        Key in `adata.obsm` holding the latent qzm params
    use_latent_qzv_key
        Key in `adata.obsm` holding the latent qzv params

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not the latent-posterior type.
    ValueError
        If the module has ``use_observed_lib_size`` set to ``False``.

    Notes
    -----
    This does not modify the existing adata in place; the model is handed a
    new, minified adata instead.
    """
    # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior
    # without removing the original counts.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    if self.module.use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

    full_adata = self.adata
    slim_adata = get_minified_adata_scrna(full_adata, minified_data_type)

    # Copy the cached posterior parameters over under the model's own keys.
    for target_key, source_key in (
        (_SCVI_LATENT_QZM, use_latent_qzm_key),
        (_SCVI_LATENT_QZV, use_latent_qzv_key),
    ):
        slim_adata.obsm[target_key] = full_adata.obsm[source_key]

    # The library size must be read from the registered counts before the
    # manager is switched over to the minified adata below.
    raw_counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    per_cell_totals = np.asarray(raw_counts.sum(axis=1))
    slim_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(per_cell_totals)

    self._update_adata_and_manager_post_minification(slim_adata, minified_data_type)
    self.module.minified_data_type = minified_data_type
Loading