Skip to content

Commit

Permalink
Remove pydantic v2 dep
Browse files Browse the repository at this point in the history
Necessary to have the napari plugin install via GUI installer.
  • Loading branch information
bentaculum committed Jun 3, 2024
1 parent dc6645e commit b87710e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 26 deletions.
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ install_requires =
wandb
edt
joblib
pydantic >= 2.0
pydantic_numpy
python_requires = >=3.10
include_package_data = True

Expand Down
2 changes: 0 additions & 2 deletions trackastra/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@

# from .data import CTCData
import tifffile
from pydantic import validate_call
from tqdm import tqdm

logger = logging.getLogger(__name__)


@validate_call
def load_tiff_timeseries(
dir: Path,
dtype: str | type | None = None,
Expand Down
37 changes: 22 additions & 15 deletions trackastra/data/wrfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import numpy as np
import pandas as pd
from edt import edt
from pydantic import validate_call
from pydantic_numpy import NpNDArray
from skimage.measure import regionprops, regionprops_table
from tqdm import tqdm

Expand Down Expand Up @@ -89,7 +87,10 @@ def __init__(
self.timepoints = timepoints

def __repr__(self):
s = f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)}, ntimepoints={len(np.unique(self.timepoints))})\n\n"
s = (
f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
)
for k, v in self.features.items():
s += f"{k:>20} -> {v.shape}\n"
return s
Expand Down Expand Up @@ -445,10 +446,9 @@ def __call__(self, feats: WRFeatures):
return feats


@validate_call
def get_features(
detections: NpNDArray,
imgs: NpNDArray | None = None,
detections: np.ndarray,
imgs: np.ndarray | None = None,
features: Literal["none", "wrfeat"] = "wrfeat",
ndim: int = 2,
n_workers=0,
Expand All @@ -458,25 +458,34 @@ def get_features(
imgs = _check_dimensions(imgs, ndim)
logger.info(f"Extracting features from {len(detections)} detections")
if n_workers > 0:
features = joblib.Parallel(n_jobs=n_workers, backend='multiprocessing')(
features = joblib.Parallel(n_jobs=n_workers, backend="multiprocessing")(
joblib.delayed(WRFeatures.from_mask_img)(
# New axis for time component
mask=mask[np.newaxis, ...],
img=img[np.newaxis, ...],
t_start=t,
)
for t, (mask, img) in progbar_class(enumerate(zip(detections, imgs)), total=len(imgs), desc="Extracting features")
for t, (mask, img) in progbar_class(
enumerate(zip(detections, imgs)),
total=len(imgs),
desc="Extracting features",
)
)
else:
logger.info("Using single process for feature extraction")
features = tuple(WRFeatures.from_mask_img(
features = tuple(
WRFeatures.from_mask_img(
mask=mask[np.newaxis, ...],
img=img[np.newaxis, ...],
t_start=t,
)
for t, (mask, img) in progbar_class(enumerate(zip(detections, imgs)), total=len(imgs), desc="Extracting features")
)

for t, (mask, img) in progbar_class(
enumerate(zip(detections, imgs)),
total=len(imgs),
desc="Extracting features",
)
)

return features


Expand All @@ -495,9 +504,7 @@ def _check_dimensions(x: np.ndarray, ndim: int):


def build_windows(
features: list[WRFeatures],
window_size: int,
progbar_class=tqdm
features: list[WRFeatures], window_size: int, progbar_class=tqdm
) -> list[dict]:
windows = []
for t1, t2 in progbar_class(
Expand Down
3 changes: 0 additions & 3 deletions trackastra/model/model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import yaml
from pydantic import validate_call
from tqdm import tqdm

from ..data import build_windows, get_features
Expand All @@ -27,15 +26,13 @@ def __init__(self, transformer, train_args, device="cpu"):
self.device = device

@classmethod
@validate_call
def from_folder(cls, dir: Path, device: str = "cpu"):
transformer = TrackingTransformer.from_folder(dir, map_location=device)
train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
return cls(transformer=transformer, train_args=train_args, device=device)

# TODO make safer
@classmethod
@validate_call
def from_pretrained(
cls, name: str, device: str = "cpu", download_dir: Path | None = None
):
Expand Down
9 changes: 5 additions & 4 deletions trackastra/model/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from pathlib import Path

import requests
from pydantic import validate_call
from tqdm import tqdm

logger = logging.getLogger(__name__)

_MODELS = {
"ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip",
"ctc": (
"https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip"
),
"general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
}

Expand Down Expand Up @@ -53,7 +54,6 @@ def download(url: str, fname: Path):
bar.update(size)


@validate_call
def download_pretrained(name: str, download_dir: Path | None = None):
# TODO make safe, introduce versioning
if download_dir is None:
Expand All @@ -66,7 +66,8 @@ def download_pretrained(name: str, download_dir: Path | None = None):
url = _MODELS[name]
except KeyError:
raise ValueError(
f"Pretrained model `name` is not available. Choose from {list(_MODELS.keys())}"
"Pretrained model `name` is not available. Choose from"
f" {list(_MODELS.keys())}"
)
folder = download_dir / name
download_and_unzip(url=url, dst=folder)
Expand Down

0 comments on commit b87710e

Please sign in to comment.