diff --git a/docs/gradsflow/core.md b/docs/gradsflow/core.md index 38b84d0f..d7886b3c 100644 --- a/docs/gradsflow/core.md +++ b/docs/gradsflow/core.md @@ -1,3 +1,3 @@ -Core Building blocks for AutoML Tasks +Core Building blocks -::: gradsflow.core.base +::: gradsflow.core.base.BaseAutoModel diff --git a/docs/gradsflow/models/utils.md b/docs/gradsflow/models/utils.md new file mode 100644 index 00000000..1806e22b --- /dev/null +++ b/docs/gradsflow/models/utils.md @@ -0,0 +1,5 @@ +::: gradsflow.models.utils.available_losses + +--- + +::: gradsflow.models.utils.available_metrics diff --git a/gradsflow/callbacks/__init__.py b/gradsflow/callbacks/__init__.py index 1df45f08..bd82249e 100644 --- a/gradsflow/callbacks/__init__.py +++ b/gradsflow/callbacks/__init__.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .callbacks import Callback -from .export import ModelCheckpoint from .gpu import EmissionTrackerCallback from .logger import CometCallback, CSVLogger +from .logger.logger import ModelCheckpoint from .progress import ProgressCallback from .raytune import report_checkpoint_callback from .runner import CallbackRunner diff --git a/gradsflow/callbacks/export.py b/gradsflow/callbacks/export.py deleted file mode 100644 index 3ab1c43c..00000000 --- a/gradsflow/callbacks/export.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2021 GradsFlow. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import os -from pathlib import Path -from typing import Optional - -from gradsflow.callbacks.callbacks import Callback - - -class ModelCheckpoint(Callback): - def __init__(self, filename: Optional[str] = None, path: str = os.getcwd(), save_extra: bool = False): - """ - Saves Model checkpoint - Args: - filename: name of checkpoint - path: folder path location of the model checkpoint - save_extra: whether to save extra details like tracker - """ - super().__init__(model=None) - filename = filename or "model" - self.path = path - self._dst = Path(path) / Path(filename) - self.save_extra = save_extra - - def on_epoch_end(self): - epoch = self.model.tracker.current_epoch - path = f"{self._dst}_epoch={epoch}_.pt" - self.model.save(path, save_extra=self.save_extra) diff --git a/gradsflow/callbacks/gpu.py b/gradsflow/callbacks/gpu.py index 8902931b..b3551088 100644 --- a/gradsflow/callbacks/gpu.py +++ b/gradsflow/callbacks/gpu.py @@ -14,7 +14,7 @@ from loguru import logger -from gradsflow.callbacks import Callback +from gradsflow.core.callbacks import Callback from gradsflow.utility.imports import requires diff --git a/gradsflow/callbacks/logger/comet.py b/gradsflow/callbacks/logger/comet.py index 0b527d8a..9beea673 100644 --- a/gradsflow/callbacks/logger/comet.py +++ b/gradsflow/callbacks/logger/comet.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from comet_ml import BaseExperiment -from gradsflow.callbacks import Callback +from gradsflow.core.callbacks import Callback from gradsflow.utility.imports import requires CURRENT_FILE = os.path.dirname(os.path.realpath(__file__)) diff --git a/gradsflow/callbacks/logger/logger.py b/gradsflow/callbacks/logger/logger.py index 2349e6f1..6fd4544c 100644 --- a/gradsflow/callbacks/logger/logger.py +++ b/gradsflow/callbacks/logger/logger.py @@ -18,7 +18,7 @@ import pandas as pd from loguru import logger -from gradsflow.callbacks.callbacks 
import Callback +from gradsflow.core.callbacks import Callback from gradsflow.utility.common import to_item diff --git a/gradsflow/callbacks/progress.py b/gradsflow/callbacks/progress.py index 2bb4354e..ce202de2 100644 --- a/gradsflow/callbacks/progress.py +++ b/gradsflow/callbacks/progress.py @@ -15,7 +15,7 @@ from rich.progress import BarColumn, Progress, RenderableColumn, TimeRemainingColumn -from .callbacks import Callback +from gradsflow.core.callbacks import Callback class ProgressCallback(Callback): diff --git a/gradsflow/callbacks/raytune.py b/gradsflow/callbacks/raytune.py index 7a2ceeec..a8818a4f 100644 --- a/gradsflow/callbacks/raytune.py +++ b/gradsflow/callbacks/raytune.py @@ -17,7 +17,7 @@ import torch from ray import tune -from .callbacks import Callback +from gradsflow.core.callbacks import Callback _METRICS = { "val_accuracy": "val_accuracy", diff --git a/gradsflow/callbacks/runner.py b/gradsflow/callbacks/runner.py index 441bab15..9e6b89e1 100644 --- a/gradsflow/callbacks/runner.py +++ b/gradsflow/callbacks/runner.py @@ -13,12 +13,13 @@ # limitations under the License. 
import typing from collections import OrderedDict -from typing import Any, Dict, Union +from typing import Any, Dict, List, Optional, Union -from gradsflow.callbacks.callbacks import Callback from gradsflow.callbacks.progress import ProgressCallback from gradsflow.callbacks.raytune import TorchTuneCheckpointCallback, TorchTuneReport from gradsflow.callbacks.training import TrainEvalCallback +from gradsflow.core.callbacks import Callback +from gradsflow.utility import listify if typing.TYPE_CHECKING: from gradsflow.models.model import Model @@ -40,6 +41,7 @@ def __init__(self, model: "Model", *callbacks: Union[str, Callback]): for callback in callbacks: self.append(callback) + # skipcq: W0212 def append(self, callback: Union[str, Callback]): try: if isinstance(callback, str): @@ -114,8 +116,11 @@ def on_forward_end(self): for _, callback in self.callbacks.items(): callback.on_forward_end() - def clean(self): - """Remove all the callbacks except `TrainEvalCallback` added during `model.fit`""" + def clean(self, keep: Optional[Union[List[str], str]] = None): + """Remove all the callbacks except callback names provided in keep""" for _, callback in self.callbacks.items(): callback.clean() - self.callbacks = OrderedDict(list(self.callbacks.items())[0:1]) + not_keep = set(self.callbacks.keys()) - set(listify(keep)) + for key in not_keep: + self.callbacks.pop(key) + # self.callbacks = OrderedDict(list(self.callbacks.items())[0:1]) diff --git a/gradsflow/callbacks/training.py b/gradsflow/callbacks/training.py index 9567b1e8..3c86256c 100644 --- a/gradsflow/callbacks/training.py +++ b/gradsflow/callbacks/training.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .callbacks import Callback +from gradsflow.core.callbacks import Callback class TrainEvalCallback(Callback): diff --git a/gradsflow/core/__init__.py b/gradsflow/core/__init__.py index b10996df..a5e33f00 100644 --- a/gradsflow/core/__init__.py +++ b/gradsflow/core/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. """Core Building blocks for Auto Tasks""" +from .callbacks import Callback diff --git a/gradsflow/core/base.py b/gradsflow/core/base.py index 68118941..85efcbe9 100644 --- a/gradsflow/core/base.py +++ b/gradsflow/core/base.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Dict, Optional diff --git a/gradsflow/callbacks/callbacks.py b/gradsflow/core/callbacks.py similarity index 85% rename from gradsflow/callbacks/callbacks.py rename to gradsflow/core/callbacks.py index 6c813650..58e5471f 100644 --- a/gradsflow/callbacks/callbacks.py +++ b/gradsflow/core/callbacks.py @@ -11,6 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import typing from abc import ABC from typing import Callable, Optional @@ -19,7 +31,7 @@ from gradsflow.models.model import Model -def dummy(x=None, *_, **__): +def dummy(x=None, **__): return x @@ -29,7 +41,7 @@ class Callback(ABC): _events = ("forward", "step", "train_epoch", "val_epoch", "epoch", "fit") _name: str = "Callback" - def __init__(self, model: Optional["Model"]): + def __init__(self, model: Optional["Model"] = None): self.model = model def with_event(self, event_type: str, func: Callable, exception, final_fn: Callable = dummy): diff --git a/gradsflow/data/autodata.py b/gradsflow/data/autodata.py index 6e289f4e..c45ea19c 100644 --- a/gradsflow/data/autodata.py +++ b/gradsflow/data/autodata.py @@ -17,7 +17,7 @@ from loguru import logger from torch.utils.data import DataLoader, Dataset -from gradsflow.core.data import BaseAutoDataset +from gradsflow.data.base import BaseAutoDataset from gradsflow.utility.imports import is_installed from ..utility.common import default_device @@ -39,11 +39,25 @@ def __init__( val_dataset: Optional[Dataset] = None, datamodule: Optional["pl.LightningDataModule"] = None, num_classes: Optional[int] = None, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: Optional[bool] = False, **kwargs ): super().__init__() self.device = default_device() - self.setup(train_dataloader, val_dataloader, train_dataset, val_dataset, datamodule, num_classes, **kwargs) + self.setup( + train_dataloader, + val_dataloader, + train_dataset, + val_dataset, + datamodule, + num_classes, + **kwargs, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory + ) def setup( self, @@ -53,7 +67,9 @@ def setup( val_dataset: Optional[Dataset] = None, datamodule: Optional["pl.LightningDataModule"] = None, num_classes: Optional[int] = None, - **kwargs + batch_size: int = 1, + num_workers: int = 0, + pin_memory: Optional[bool] = False, ): self.datamodule = datamodule @@ -66,17 +82,17 @@ def setup( if not train_dataloader and 
train_dataset: self._train_dataloader = DataLoader( train_dataset, - batch_size=kwargs.get("batch_size", 8), shuffle=True, - num_workers=kwargs.get("num_workers", 0), - pin_memory=kwargs.get("pin_memory"), + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, ) if not val_dataloader and val_dataset: self._val_dataloader = DataLoader( val_dataset, - batch_size=kwargs.get("batch_size", 8), - num_workers=kwargs.get("num_workers", 0), - pin_memory=kwargs.get("pin_memory"), + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, ) if (datamodule or train_dataloader or train_dataset) is None: diff --git a/gradsflow/core/data.py b/gradsflow/data/base.py similarity index 95% rename from gradsflow/core/data.py rename to gradsflow/data/base.py index 91736f4c..31abd76f 100644 --- a/gradsflow/core/data.py +++ b/gradsflow/data/base.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import dataclasses -import logging from typing import Union from torch.utils.data import DataLoader, Dataset from gradsflow.data.ray_dataset import RayDataset -logger = logging.getLogger("core.data") - @dataclasses.dataclass(init=False) class Data: diff --git a/gradsflow/data/image.py b/gradsflow/data/image.py index a2f6a4cc..d19ea694 100644 --- a/gradsflow/data/image.py +++ b/gradsflow/data/image.py @@ -21,7 +21,7 @@ from torchvision import transforms as T from torchvision.datasets import ImageFolder -from gradsflow.core.data import Data +from gradsflow.data.base import Data from gradsflow.data.ray_dataset import RayImageFolder logger = logging.getLogger("data.image") diff --git a/gradsflow/models/model.py b/gradsflow/models/model.py index 802b5d50..ff10fc5d 100644 --- a/gradsflow/models/model.py +++ b/gradsflow/models/model.py @@ -19,12 +19,8 @@ from torch import nn from torchmetrics import Metric -from gradsflow.callbacks import ( - Callback, - CallbackRunner, - ProgressCallback, 
- TrainEvalCallback, -) +from gradsflow.callbacks import CallbackRunner, ProgressCallback, TrainEvalCallback +from gradsflow.core import Callback from gradsflow.data import AutoDataset from gradsflow.data.mixins import DataMixin from gradsflow.models.base import BaseModel @@ -258,6 +254,6 @@ def fit( except KeyboardInterrupt: logger.info("Keyboard interruption detected") finally: - self.callback_runner.clean() + self.callback_runner.clean(keep="TrainEvalCallback") return self.tracker diff --git a/gradsflow/models/tracker.py b/gradsflow/models/tracker.py index ab867e81..dceab55e 100644 --- a/gradsflow/models/tracker.py +++ b/gradsflow/models/tracker.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict, List +from loguru import logger from rich import box from rich.table import Table @@ -97,6 +98,7 @@ def create_table(self) -> Table: return table def reset(self): + logger.info("Reset Tracker") self.max_epochs = 0 self.current_epoch = 0 self.steps_per_epoch = None diff --git a/gradsflow/models/utils.py b/gradsflow/models/utils.py index 1480a8a3..f36a53dd 100644 --- a/gradsflow/models/utils.py +++ b/gradsflow/models/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Dict, List, Union +from typing import Callable, Dict, List, Optional, Union import numpy as np import torch @@ -19,7 +19,7 @@ from torch import nn from torchmetrics import Metric -from gradsflow.utility.common import module_to_cls_index +from gradsflow.utility.common import filter_list, module_to_cls_index SCALAR = Union[torch.Tensor, np.float, float, int] _nn_classes = module_to_cls_index(nn) @@ -31,9 +31,31 @@ metrics = {k.lower(): v for k, v in metrics.items()} -def available_losses() -> List[str]: - return list(losses.keys()) - - -def available_metrics() -> List[str]: - return list(metrics.keys()) +def available_losses(pattern: Optional[str] = None) -> List[str]: + """Get available loss functions + ```python + >> available_losses() + >> # crossentropy, binarycrossentropy, mae, ... + + # Filter available losses with regex pattern + >> available_losses("m.e") + >> # ["mae", "mse"] + ``` + """ + loss_keys = list(losses.keys()) + return filter_list(loss_keys, pattern) + + +def available_metrics(pattern: Optional[str] = None) -> List[str]: + """Get available Metrics + ```python + >> available_metrics() + >> # accuracy, F1, RMSE, ... 
+ + # Filter available metrics with regex pattern + >> available_metrics("acc.*") + >> # ["accuracy"] + ``` + """ + metric_keys = list(metrics.keys()) + return filter_list(metric_keys, pattern) diff --git a/gradsflow/utility/common.py b/gradsflow/utility/common.py index f9db87c0..b9067b4b 100644 --- a/gradsflow/utility/common.py +++ b/gradsflow/utility/common.py @@ -14,6 +14,7 @@ import dataclasses import inspect import os +import re import sys import warnings from glob import glob @@ -121,3 +122,18 @@ def to_item(data: Union[torch.Tensor, Iterable, Dict]) -> Union[int, float, str, warnings.warn("to_item didn't convert any value.") return data + + +def filter_list(arr: List[str], pattern: Optional[str] = None) -> List[str]: + """Filter a list of strings with given pattern + ```python + >> arr = ['crossentropy', 'binarycrossentropy', 'softmax', 'mae',] + >> filter_list(arr, ".*entropy") + >> # ["crossentropy", "binarycrossentropy"] + ``` + """ + if pattern is None: + return arr + + p = re.compile(pattern) + return [s for s in arr if p.match(s)] diff --git a/mkdocs.yml b/mkdocs.yml index 53bf5d4b..8f14fa43 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,16 +75,17 @@ nav: - Pix2Pix GAN Code Explanation: 'examples/nbs/Pix2Pix_explained_with_code.ipynb' - 🤗 HuggingFace Training Example: 'examples/nbs/2021-10-3-huggingface-training.ipynb' - API References: - - Core: gradsflow/core - Model: - gradsflow/models/base.md - gradsflow/models/model.md - gradsflow/models/tracker.md + - gradsflow/models/utils.md - Tuner: gradsflow/tuner - AutoTasks: - gradsflow/autotasks/autotasks.md - gradsflow/autotasks/engine.md - Data: gradsflow/data - Callbacks: gradsflow/callbacks - - utility: gradsflow/utility.md + - Core: gradsflow/core + - utility: gradsflow/utils.md - Release Notes: 'CHANGELOG.md' diff --git a/tests/callbacks/test_runner.py b/tests/callbacks/test_runner.py index 2f7aa345..b42d7194 100644 --- a/tests/callbacks/test_runner.py +++ b/tests/callbacks/test_runner.py @@ -13,7 
+13,8 @@ # limitations under the License. import pytest -from gradsflow.callbacks import Callback, CallbackRunner, TrainEvalCallback +from gradsflow.callbacks import CallbackRunner, TrainEvalCallback +from gradsflow.core.callbacks import Callback def test_init(): @@ -41,3 +42,13 @@ def forward(self): for cb_name, cb in cb.callbacks.items(): assert isinstance(cb_name, str) assert isinstance(cb, Callback) + + +def test_clean(): + class DummyModel: + def forward(self): + return 1 + + cb = CallbackRunner(DummyModel(), TrainEvalCallback()) + cb.clean(keep="TrainEvalCallback") + assert cb.callbacks.get("TrainEvalCallback") is not None diff --git a/tests/data/test_image_data.py b/tests/data/test_image_data.py index 42efd8fc..48cf9181 100644 --- a/tests/data/test_image_data.py +++ b/tests/data/test_image_data.py @@ -13,7 +13,7 @@ # limitations under the License. from pathlib import Path -from gradsflow.core.data import Data +from gradsflow.data.base import Data from gradsflow.data.image import image_dataset_from_directory data_dir = Path.cwd() diff --git a/tests/utility/test_common.py b/tests/utility/test_common.py index 65c0bfca..42ca07fc 100644 --- a/tests/utility/test_common.py +++ b/tests/utility/test_common.py @@ -16,6 +16,7 @@ from gradsflow.utility.common import ( default_device, + filter_list, get_file_extension, get_files, listify, @@ -61,3 +62,14 @@ def test_to_item(): x = {"input": torch.rand(10)} assert isinstance(to_item(x), dict) assert isinstance(to_item(x)["input"], np.ndarray) + + +def test_filter_list(): + arr = [ + "crossentropy", + "binarycrossentropy", + "softmax", + "mae", + ] + assert filter_list(arr, ".*entropy") == arr[:2] + assert filter_list(arr) == arr