Skip to content

Commit

Permalink
Merge branch 'jhong/decipher' of github.com:scverse/scvi-tools into j…
Browse files Browse the repository at this point in the history
…hong/decipher
  • Loading branch information
justjhong committed Oct 21, 2024
2 parents e3562e2 + e2267ec commit 3bd9fab
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 31 deletions.
18 changes: 17 additions & 1 deletion .github/workflows/test_linux_cuda.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
name: test (cuda)

on:
push:
branches: [main, "[0-9]+.[0-9]+.x"] #this is new
pull_request:
branches: [main, "[0-9]+.[0-9]+.x"]
types: [labeled, synchronize, opened]
Expand Down Expand Up @@ -31,6 +33,7 @@ jobs:

container:
image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base
#image: ghcr.io/scverse/scvi-tools:py3.12-cu12-${{ env.BRANCH_NAME }}-base
options: --user root --gpus all

name: integration
Expand All @@ -40,11 +43,24 @@ jobs:
PYTHON: ${{ matrix.python }}

steps:
#- name: Get the current branch name
# id: vars
# run: echo "BRANCH_NAME=$(echo $GITHUB_REF | awk -F'/' '{print $3}')" >> $GITHUB_ENV

- uses: actions/checkout@v4

- run: |
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: "pip"
cache-dependency-path: "**/pyproject.toml"

- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel uv
python -m uv pip install --system "scvi-tools[tests] @ ."
python -m pip install jax[cuda]
python -m pip install nvidia-nccl-cu12
- name: Run pytest
env:
Expand Down
16 changes: 4 additions & 12 deletions src/scvi/external/decipher/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def __init__(

# The multiple outputs are computed as a single output layer, and then split
indices = np.concatenate(([0], np.cumsum(self.output_dims)))
self.output_slices = [
slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)
]
self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)]

# Create masked layers
deep_context_dim = self.context_dim if self.deep_context_injection else 0
Expand All @@ -63,21 +61,15 @@ def __init__(
batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
for i in range(1, len(hidden_dims)):
layers.append(
torch.nn.Linear(
hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]
)
torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
)
batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))

layers.append(
torch.nn.Linear(
hidden_dims[-1] + deep_context_dim, self.output_total_dim
)
torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
)
else:
layers.append(
torch.nn.Linear(input_dim + context_dim, self.output_total_dim)
)
layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))

self.layers = torch.nn.ModuleList(layers)

Expand Down
11 changes: 4 additions & 7 deletions src/scvi/external/decipher/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def setup_anndata(
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
]
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

Expand Down Expand Up @@ -142,17 +140,16 @@ def get_latent_representation(
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent_locs = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
x = torch.log1p(x)
x = x.to(self.module.device)
z_loc, _ = self.module.encoder_x_to_z(x)
if give_z:
latent_locs.append(z_loc)
else:
v_loc, _ = self.module.encoder_zx_to_v(torch.cat([z_loc, x], dim=-1))
latent_locs.append(v_loc)
return torch.cat(latent_locs).detach().numpy()
return torch.cat(latent_locs).detach().cpu().numpy()
12 changes: 3 additions & 9 deletions src/scvi/external/decipher/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def device(self):
return self._dummy_param.device

@staticmethod
def _get_fn_args_from_batch(
tensor_dict: dict[str, torch.Tensor]
) -> Iterable | dict:
def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict:
x = tensor_dict[REGISTRY_KEYS.X_KEY]
return (x,), {}

Expand Down Expand Up @@ -125,9 +123,7 @@ def model(self, x: torch.Tensor):
self.theta + self._epsilon
)
# noinspection PyUnresolvedReferences
x_dist = dist.NegativeBinomial(
total_count=self.theta + self._epsilon, logits=logit
)
x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
pyro.sample("x", x_dist.to_event(1), obs=x)

@auto_move_data
Expand Down Expand Up @@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
model_trace = poutine.trace(
poutine.replay(self.model, trace=guide_trace)
).get_trace(x)
log_weights.append(
model_trace.log_prob_sum() - guide_trace.log_prob_sum()
)
log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

finally:
self.beta = old_beta
Expand Down
7 changes: 5 additions & 2 deletions src/scvi/external/poissonvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,11 @@ def get_accessibility_estimates(

@torch.inference_mode()
def get_region_factors(self):
"""Return region-specific factors."""
region_factors = self.module.decoder.px_scale_decoder[-2].bias.numpy()
"""Return region-specific factors. CPU/GPU dependent"""
if self.device.type == "cpu":
region_factors = self.module.decoder.px_scale_decoder[-2].bias.numpy()
else:
region_factors = self.module.decoder.px_scale_decoder[-2].bias.cpu().numpy() # gpu
if region_factors is None:
raise RuntimeError("region factors were not included in this model")
return region_factors
Expand Down
3 changes: 3 additions & 0 deletions tests/model/test_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pyro
import pyro.distributions as dist
import pytest
import torch
from pyro import clear_param_store
from pyro.infer.autoguide import AutoNormal, init_to_mean
Expand Down Expand Up @@ -215,6 +216,7 @@ def test_pyro_bayesian_regression_low_level(
]


@pytest.mark.optional
def test_pyro_bayesian_regression(accelerator: str, devices: list | str | int, save_path: str):
adata = synthetic_iid()
adata_manager = _create_indices_adata_manager(adata)
Expand Down Expand Up @@ -277,6 +279,7 @@ def test_pyro_bayesian_regression(accelerator: str, devices: list | str | int, s
np.testing.assert_array_equal(linear_median_new, linear_median)


@pytest.mark.optional
def test_pyro_bayesian_regression_jit(
accelerator: str,
devices: list | str | int,
Expand Down

0 comments on commit 3bd9fab

Please sign in to comment.