diff --git a/docs/requirements.txt b/docs/requirements.txt index 9d46ce14..f0d4d53e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -12,3 +12,4 @@ mkdocs-meta-descriptions-plugin jupyter_contrib_nbextensions comet_ml lightning-flash[image,text]>=0.5.1 +wandb diff --git a/gradsflow/callbacks/__init__.py b/gradsflow/callbacks/__init__.py index bd82249e..41d93a05 100644 --- a/gradsflow/callbacks/__init__.py +++ b/gradsflow/callbacks/__init__.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .comet import CometCallback from .gpu import EmissionTrackerCallback -from .logger import CometCallback, CSVLogger -from .logger.logger import ModelCheckpoint +from .logger import CSVLogger, ModelCheckpoint from .progress import ProgressCallback from .raytune import report_checkpoint_callback from .runner import CallbackRunner from .training import TrainEvalCallback +from .wandb import WandbCallback diff --git a/gradsflow/callbacks/logger/comet.py b/gradsflow/callbacks/comet.py similarity index 88% rename from gradsflow/callbacks/logger/comet.py rename to gradsflow/callbacks/comet.py index 81705417..5448028c 100644 --- a/gradsflow/callbacks/logger/comet.py +++ b/gradsflow/callbacks/comet.py @@ -1,4 +1,16 @@ -# Copyright (c) 2021 GradsFlow. All rights reserved. +# Copyright (c) 2022 GradsFlow. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/gradsflow/callbacks/logger/logger.py b/gradsflow/callbacks/logger.py similarity index 83% rename from gradsflow/callbacks/logger/logger.py rename to gradsflow/callbacks/logger.py index 6cbded09..8f4c8909 100644 --- a/gradsflow/callbacks/logger/logger.py +++ b/gradsflow/callbacks/logger.py @@ -1,4 +1,16 @@ -# Copyright (c) 2021 GradsFlow. All rights reserved. +# Copyright (c) 2022 GradsFlow. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/gradsflow/callbacks/wandb.py b/gradsflow/callbacks/wandb.py new file mode 100644 index 00000000..56be47cc --- /dev/null +++ b/gradsflow/callbacks/wandb.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022 GradsFlow. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Dict, List, Optional + +import wandb + +from gradsflow.callbacks.base import Callback +from gradsflow.utility.imports import requires + +CURRENT_FILE = os.path.dirname(os.path.realpath(__file__)) + + +def define_metrics(): + min_max_def: Dict[str, List[str]] = { + "min": ["train/step_loss", "train/epoch_loss", "val/epoch_loss"], + "max": ["train/acc*", "val/acc*"], + } + for summary, metric_list in min_max_def.items(): + for metric in metric_list: + if "epoch" in metric or "val" in metric: + wandb.define_metric(metric, summary=summary, step_metric="epoch") + wandb.define_metric("*", step_metric="global_step") + + +class WandbCallback(Callback): + """ + [Weights & Biases](https://www.wandb.com/) Logging callback. To use this callback `pip install wandb`. + Args: + log_model: Whether to upload model artifact to Wandb + code_file: path of the code you want to upload as artifact to Wandb + """ + + @requires("wandb", "WandbCallback requires wandb to be installed!") + def __init__( + self, + log_model: bool = False, + code_file: Optional[str] = None, + ): + super().__init__() + if wandb.run is None: + raise ValueError("You must call wandb.init() before WandbCallback()") + self._code_file = code_file + self._train_prefix = "train" + self._val_prefix = "val" + self._log_model = log_model + self._setup() + + def _setup(self): + define_metrics() + + def on_fit_start(self): + if self._log_model: + wandb.log_artifact(self.model.learner) + if self._code_file: + wandb.log_artifact(self._code_file) + + def _apply_prefix(self, data: dict, prefix: str): + data = {f"{prefix}/{k}": v for k, v in data.items()} + return data + + def on_train_step_end(self, outputs: dict = None, **_): + # self._step(prefix=self._train_prefix, outputs=outputs) + prefix = "train" + global_step = self.model.tracker.global_step + loss = outputs["loss"].item() + # log train step loss + wandb.log({f"{prefix}/step_loss": loss, "train_step": global_step}, commit=False) + + # log train step metrics + metrics = outputs.get("metrics", {}) + metrics = self._apply_prefix(metrics, prefix) + wandb.log(metrics, commit=False) + + # https://docs.wandb.ai/guides/track/log#how-do-i-use-custom-x-axes + wandb.log({"global_step": global_step}) + + def on_epoch_end(self): + epoch = self.model.tracker.current_epoch + train_loss = self.model.tracker.train_loss + train_metrics = self.model.tracker.train_metrics.to_dict() + val_loss = self.model.tracker.val_loss + val_metrics = self.model.tracker.val_metrics.to_dict() + + train_metrics = self._apply_prefix(train_metrics, prefix=self._train_prefix) + val_metrics = self._apply_prefix(val_metrics, prefix=self._val_prefix) + train_metrics.update({"epoch": epoch}) + val_metrics.update({"epoch": epoch}) + + wandb.log({"train/epoch_loss": train_loss, "epoch": epoch}, commit=False) + wandb.log({"val/epoch_loss": val_loss, "epoch": epoch}, commit=False) + wandb.log(train_metrics, commit=False) + wandb.log(val_metrics, commit=False) + wandb.log({}) diff --git a/gradsflow/core/base.py b/gradsflow/core/base.py index 4fe6b8fc..2ae85dd8 100644 --- a/gradsflow/core/base.py +++ b/gradsflow/core/base.py @@ -13,13 +13,13 @@ # limitations under the License. # from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Dict, Optional import numpy as np import torch -from gradsflow.utility.common import AverageMeter, module_to_cls_index +from gradsflow.utility.common import AverageMeter, GDict, module_to_cls_index class BaseAutoModel(ABC): @@ -44,12 +44,12 @@ def build_model(self, search_space: dict): @dataclass(init=False) class TrackingValues: loss: Optional[AverageMeter] = None # Average loss in a single Epoch - steps: Optional[int] = None + steps: Optional[int] = None # Step per epoch step_loss: Optional[float] = None metrics: Optional[Dict[str, AverageMeter]] = None # Average value in a single Epoch def __init__(self): - self.metrics = {} + self.metrics = GDict() self.loss = AverageMeter(name="loss") def update_loss(self, loss: float): @@ -67,17 +67,11 @@ def update_metrics(self, metrics: Dict[str, float]): self.metrics[key] = AverageMeter(name=key) self.metrics[key].update(value) + def to_dict(self) -> dict: + return asdict(self) + def reset(self): """Values are Reset on start of each `on_*_epoch_start`""" self.loss.reset() for _, metric in self.metrics.items(): metric.reset() - - -@dataclass(init=False) -class BaseTracker: - max_epochs: int = 0 - current_epoch: int = 0 # current train current_epoch - steps_per_epoch: Optional[int] = None - train: TrackingValues = TrackingValues() - val: TrackingValues = TrackingValues() diff --git a/gradsflow/models/model.py b/gradsflow/models/model.py index 6c021ea8..cb985e12 100644 --- a/gradsflow/models/model.py +++ b/gradsflow/models/model.py @@ -146,6 +146,7 @@ def train_one_epoch(self, train_dataloader): steps_per_epoch = tracker.steps_per_epoch for step, batch in enumerate(train_dataloader): + tracker.global_step += 1 tracker.train.steps = step # ----- TRAIN STEP ----- self.callback_runner.on_train_step_start() diff --git a/gradsflow/models/tracker.py b/gradsflow/models/tracker.py index faf03c35..eebf13f4 100644 --- a/gradsflow/models/tracker.py +++ b/gradsflow/models/tracker.py @@ -11,14 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from dataclasses import dataclass +from typing import Dict, List, Optional from loguru import logger from rich import box from rich.table import Table -from gradsflow.core.base import BaseTracker, TrackingValues -from gradsflow.utility.common import to_item +from gradsflow.core.base import TrackingValues +from gradsflow.utility.common import GDict, to_item + + +@dataclass(init=False) +class BaseTracker: + global_step: int = 0 # Global training steps + max_epochs: int = 0 + current_epoch: int = 0 # current train current_epoch + steps_per_epoch: Optional[int] = None + train: TrackingValues = TrackingValues() + val: TrackingValues = TrackingValues() class Tracker(BaseTracker): @@ -27,8 +38,8 @@ class Tracker(BaseTracker): """ def __init__(self): - self.train.metrics = {} - self.val.metrics = {} + self.train.metrics = GDict() + self.val.metrics = GDict() self.logs: List[Dict] = [] def __getitem__(self, key: str): # skipcq: PYL-R1705 @@ -60,11 +71,11 @@ def val_loss(self): return self.val.loss.avg @property - def train_metrics(self): + def train_metrics(self) -> GDict: return self.train.metrics @property - def val_metrics(self): + def val_metrics(self) -> GDict: return self.val.metrics def mode(self, mode) -> TrackingValues: diff --git a/gradsflow/utility/common.py b/gradsflow/utility/common.py index 9ae86017..22a7d224 100644 --- a/gradsflow/utility/common.py +++ b/gradsflow/utility/common.py @@ -75,7 +75,6 @@ class AverageMeter: `val` is the running value, `avg` is the average value over an epoch. """ - name: Optional[str] avg: Optional[float] = 0 def __init__(self, name=None): @@ -138,3 +137,15 @@ def filter_list(arr: List[str], pattern: Optional[str] = None) -> List[str]: p = re.compile(pattern) return [s for s in arr if p.match(s)] + + +class GDict(dict): + def to_dict(self): + clone = self.copy() + for k in clone.keys(): + value = clone[k] + try: + clone[k] = dataclasses.asdict(value) + except TypeError: + continue + return clone diff --git a/setup.cfg b/setup.cfg index 164735bc..166ac04b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ install_requires = torchmetrics >=0.5.0 [options.extras_require] -dev = lightning-flash[image,text] ==0.5.1; codecarbon >=1.2.0; comet_ml +dev = lightning-flash[image,text] ==0.5.1; codecarbon >=1.2.0; comet_ml; wandb test = pytest; coverage; pytest-sugar [options.packages.find] #optional diff --git a/tests/autotasks/__init__.py b/tests/autotasks/__init__.py deleted file mode 100644 index f775d5ed..00000000 --- a/tests/autotasks/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 GradsFlow. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py deleted file mode 100644 index f775d5ed..00000000 --- a/tests/callbacks/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 GradsFlow. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/callbacks/test_logger.py b/tests/callbacks/test_logger.py index bc2755a1..b7655ab1 100644 --- a/tests/callbacks/test_logger.py +++ b/tests/callbacks/test_logger.py @@ -17,8 +17,12 @@ import pytest from gradsflow import AutoDataset -from gradsflow.callbacks import EmissionTrackerCallback, ModelCheckpoint -from gradsflow.callbacks.logger import CometCallback, CSVLogger +from gradsflow.callbacks import ( + CometCallback, + CSVLogger, + EmissionTrackerCallback, + ModelCheckpoint, +) from gradsflow.data.image import image_dataset_from_directory from gradsflow.utility.imports import is_installed from tests.dummies import DummyModel diff --git a/tests/callbacks/test_runner.py b/tests/callbacks/test_runner.py index 6d3d47a6..e5c82434 100644 --- a/tests/callbacks/test_runner.py +++ b/tests/callbacks/test_runner.py @@ -18,14 +18,12 @@ def test_init(dummy_model): - assert isinstance(CallbackRunner(dummy_model, "training").callbacks["TrainEvalCallback"], TrainEvalCallback) with pytest.raises(NotImplementedError): CallbackRunner(dummy_model, "random") def test_append(dummy_model): - cb = CallbackRunner(dummy_model) with pytest.raises(NotImplementedError): cb.append("random") @@ -39,7 +37,6 @@ def test_append(dummy_model): def test_clean(dummy_model): - cb = CallbackRunner(dummy_model, TrainEvalCallback()) cb.clean(keep="TrainEvalCallback") assert cb.callbacks.get("TrainEvalCallback") is not None diff --git a/gradsflow/callbacks/logger/__init__.py b/tests/callbacks/test_wandb.py similarity index 50% rename from gradsflow/callbacks/logger/__init__.py rename to tests/callbacks/test_wandb.py index c25c5f2e..72dd3026 100644 --- a/gradsflow/callbacks/logger/__init__.py +++ b/tests/callbacks/test_wandb.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 GradsFlow. All rights reserved. +# Copyright (c) 2022 GradsFlow. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,5 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .comet import CometCallback -from .logger import CSVLogger +from unittest.mock import Mock, patch + +import pytest + +from gradsflow.callbacks.wandb import WandbCallback +from gradsflow.utility.imports import is_installed + + +@pytest.mark.skipif(not is_installed("wandb"), reason="requires `wandb` installed") +@patch("gradsflow.callbacks.wandb.wandb") +def test_wandbcallback(mock_wandb: Mock, cnn_model, auto_dataset): + model = cnn_model + cb = WandbCallback() + model.compile() + model.fit(auto_dataset, callbacks=cb) + mock_wandb.log.assert_called() diff --git a/tests/conftest.py b/tests/conftest.py index 9dc65105..6899f945 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,13 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # Arrange +from pathlib import Path + import pytest import timm from torch import nn -from gradsflow import Model +from gradsflow import AutoDataset, Model +from gradsflow.data import image_dataset_from_directory from gradsflow.models.tracker import Tracker +data_dir = Path.cwd() +folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/" +data = image_dataset_from_directory(folder, transform=True, ray_data=False) + + +@pytest.fixture +def auto_dataset(): + return AutoDataset(train_dataloader=data.dataloader, val_dataloader=data.dataloader) + @pytest.fixture def resnet18(): diff --git a/tests/utility/test_common.py b/tests/utility/test_common.py index 42ca07fc..b06fbe54 100644 --- a/tests/utility/test_common.py +++ b/tests/utility/test_common.py @@ -15,6 +15,7 @@ import torch from gradsflow.utility.common import ( + GDict, default_device, filter_list, get_file_extension, @@ -73,3 +74,12 @@ def test_filter_list(): ] assert filter_list(arr, ".*entropy") == arr[:2] assert filter_list(arr) == arr + + +def test_gdict(): + gdict: GDict[str, str] = GDict() + gdict["hi"] = "hello" + assert gdict["hi"] == "hello" + assert list(gdict.items())[0][0] == "hi" + assert list(gdict.items())[0][1] == "hello" + assert gdict.to_dict() == {"hi": "hello"}