Skip to content

Commit

Permalink
cleanup APIs (#137)
Browse files Browse the repository at this point in the history
* reformat

* refactor callback core

* update

* fix tests

* fixes

* update

* refactor callback_runner.clean

* refactor callbakc_runner.clean
  • Loading branch information
aniketmaurya authored Dec 8, 2021
1 parent 3fe670e commit ddc1c1b
Show file tree
Hide file tree
Showing 25 changed files with 145 additions and 88 deletions.
4 changes: 2 additions & 2 deletions docs/gradsflow/core.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Core Building blocks for AutoML Tasks
Core Building blocks

::: gradsflow.core.base
::: gradsflow.core.base.BaseAutoModel
5 changes: 5 additions & 0 deletions docs/gradsflow/models/utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
::: gradsflow.models.utils.available_losses

---

::: gradsflow.models.utils.available_metrics
3 changes: 1 addition & 2 deletions gradsflow/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .callbacks import Callback
from .export import ModelCheckpoint
from .gpu import EmissionTrackerCallback
from .logger import CometCallback, CSVLogger
from .logger.logger import ModelCheckpoint
from .progress import ProgressCallback
from .raytune import report_checkpoint_callback
from .runner import CallbackRunner
Expand Down
39 changes: 0 additions & 39 deletions gradsflow/callbacks/export.py

This file was deleted.

2 changes: 1 addition & 1 deletion gradsflow/callbacks/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from loguru import logger

from gradsflow.callbacks import Callback
from gradsflow.core.callbacks import Callback
from gradsflow.utility.imports import requires


Expand Down
2 changes: 1 addition & 1 deletion gradsflow/callbacks/logger/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if TYPE_CHECKING:
from comet_ml import BaseExperiment

from gradsflow.callbacks import Callback
from gradsflow.core.callbacks import Callback
from gradsflow.utility.imports import requires

CURRENT_FILE = os.path.dirname(os.path.realpath(__file__))
Expand Down
2 changes: 1 addition & 1 deletion gradsflow/callbacks/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pandas as pd
from loguru import logger

from gradsflow.callbacks.callbacks import Callback
from gradsflow.core.callbacks import Callback
from gradsflow.utility.common import to_item


Expand Down
2 changes: 1 addition & 1 deletion gradsflow/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from rich.progress import BarColumn, Progress, RenderableColumn, TimeRemainingColumn

from .callbacks import Callback
from gradsflow.core.callbacks import Callback


class ProgressCallback(Callback):
Expand Down
2 changes: 1 addition & 1 deletion gradsflow/callbacks/raytune.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
from ray import tune

from .callbacks import Callback
from gradsflow.core.callbacks import Callback

_METRICS = {
"val_accuracy": "val_accuracy",
Expand Down
15 changes: 10 additions & 5 deletions gradsflow/callbacks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.
import typing
from collections import OrderedDict
from typing import Any, Dict, Union
from typing import Any, Dict, List, Optional, Union

from gradsflow.callbacks.callbacks import Callback
from gradsflow.callbacks.progress import ProgressCallback
from gradsflow.callbacks.raytune import TorchTuneCheckpointCallback, TorchTuneReport
from gradsflow.callbacks.training import TrainEvalCallback
from gradsflow.core.callbacks import Callback
from gradsflow.utility import listify

if typing.TYPE_CHECKING:
from gradsflow.models.model import Model
Expand All @@ -40,6 +41,7 @@ def __init__(self, model: "Model", *callbacks: Union[str, Callback]):
for callback in callbacks:
self.append(callback)

# skipcq: W0212
def append(self, callback: Union[str, Callback]):
try:
if isinstance(callback, str):
Expand Down Expand Up @@ -114,8 +116,11 @@ def on_forward_end(self):
for _, callback in self.callbacks.items():
callback.on_forward_end()

def clean(self):
"""Remove all the callbacks except `TrainEvalCallback` added during `model.fit`"""
def clean(self, keep: Optional[Union[List[str], str]] = None):
"""Remove all the callbacks except callback names provided in keep"""
for _, callback in self.callbacks.items():
callback.clean()
self.callbacks = OrderedDict(list(self.callbacks.items())[0:1])
not_keep = set(self.callbacks.keys()) - set(listify(keep))
for key in not_keep:
self.callbacks.pop(key)
# self.callbacks = OrderedDict(list(self.callbacks.items())[0:1])
2 changes: 1 addition & 1 deletion gradsflow/callbacks/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .callbacks import Callback
from gradsflow.core.callbacks import Callback


class TrainEvalCallback(Callback):
Expand Down
1 change: 1 addition & 0 deletions gradsflow/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.

"""Core Building blocks for Auto Tasks"""
from .callbacks import Callback
1 change: 1 addition & 0 deletions gradsflow/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional
Expand Down
16 changes: 14 additions & 2 deletions gradsflow/callbacks/callbacks.py → gradsflow/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from abc import ABC
from typing import Callable, Optional
Expand All @@ -19,7 +31,7 @@
from gradsflow.models.model import Model


def dummy(x=None, *_, **__):
def dummy(x=None, **__):
return x


Expand All @@ -29,7 +41,7 @@ class Callback(ABC):
_events = ("forward", "step", "train_epoch", "val_epoch", "epoch", "fit")
_name: str = "Callback"

def __init__(self, model: Optional["Model"]):
def __init__(self, model: Optional["Model"] = None):
self.model = model

def with_event(self, event_type: str, func: Callable, exception, final_fn: Callable = dummy):
Expand Down
34 changes: 25 additions & 9 deletions gradsflow/data/autodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from loguru import logger
from torch.utils.data import DataLoader, Dataset

from gradsflow.core.data import BaseAutoDataset
from gradsflow.data.base import BaseAutoDataset
from gradsflow.utility.imports import is_installed

from ..utility.common import default_device
Expand All @@ -39,11 +39,25 @@ def __init__(
val_dataset: Optional[Dataset] = None,
datamodule: Optional["pl.LightningDataModule"] = None,
num_classes: Optional[int] = None,
batch_size: int = 1,
num_workers: int = 0,
pin_memory: Optional[bool] = False,
**kwargs
):
super().__init__()
self.device = default_device()
self.setup(train_dataloader, val_dataloader, train_dataset, val_dataset, datamodule, num_classes, **kwargs)
self.setup(
train_dataloader,
val_dataloader,
train_dataset,
val_dataset,
datamodule,
num_classes,
**kwargs,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory
)

def setup(
self,
Expand All @@ -53,7 +67,9 @@ def setup(
val_dataset: Optional[Dataset] = None,
datamodule: Optional["pl.LightningDataModule"] = None,
num_classes: Optional[int] = None,
**kwargs
batch_size: int = 1,
num_workers: int = 0,
pin_memory: Optional[bool] = False,
):

self.datamodule = datamodule
Expand All @@ -66,17 +82,17 @@ def setup(
if not train_dataloader and train_dataset:
self._train_dataloader = DataLoader(
train_dataset,
batch_size=kwargs.get("batch_size", 8),
shuffle=True,
num_workers=kwargs.get("num_workers", 0),
pin_memory=kwargs.get("pin_memory"),
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
)
if not val_dataloader and val_dataset:
self._val_dataloader = DataLoader(
val_dataset,
batch_size=kwargs.get("batch_size", 8),
num_workers=kwargs.get("num_workers", 0),
pin_memory=kwargs.get("pin_memory"),
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
)

if (datamodule or train_dataloader or train_dataset) is None:
Expand Down
3 changes: 0 additions & 3 deletions gradsflow/core/data.py → gradsflow/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import logging
from typing import Union

from torch.utils.data import DataLoader, Dataset

from gradsflow.data.ray_dataset import RayDataset

logger = logging.getLogger("core.data")


@dataclasses.dataclass(init=False)
class Data:
Expand Down
2 changes: 1 addition & 1 deletion gradsflow/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchvision import transforms as T
from torchvision.datasets import ImageFolder

from gradsflow.core.data import Data
from gradsflow.data.base import Data
from gradsflow.data.ray_dataset import RayImageFolder

logger = logging.getLogger("data.image")
Expand Down
10 changes: 3 additions & 7 deletions gradsflow/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@
from torch import nn
from torchmetrics import Metric

from gradsflow.callbacks import (
Callback,
CallbackRunner,
ProgressCallback,
TrainEvalCallback,
)
from gradsflow.callbacks import CallbackRunner, ProgressCallback, TrainEvalCallback
from gradsflow.core import Callback
from gradsflow.data import AutoDataset
from gradsflow.data.mixins import DataMixin
from gradsflow.models.base import BaseModel
Expand Down Expand Up @@ -258,6 +254,6 @@ def fit(
except KeyboardInterrupt:
logger.info("Keyboard interruption detected")
finally:
self.callback_runner.clean()
self.callback_runner.clean(keep="TrainEvalCallback")

return self.tracker
2 changes: 2 additions & 0 deletions gradsflow/models/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Dict, List

from loguru import logger
from rich import box
from rich.table import Table

Expand Down Expand Up @@ -97,6 +98,7 @@ def create_table(self) -> Table:
return table

def reset(self):
logger.info("Reset Tracker")
self.max_epochs = 0
self.current_epoch = 0
self.steps_per_epoch = None
Expand Down
38 changes: 30 additions & 8 deletions gradsflow/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Union
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torchmetrics
from torch import nn
from torchmetrics import Metric

from gradsflow.utility.common import module_to_cls_index
from gradsflow.utility.common import filter_list, module_to_cls_index

SCALAR = Union[torch.Tensor, np.float, float, int]
_nn_classes = module_to_cls_index(nn)
Expand All @@ -31,9 +31,31 @@
metrics = {k.lower(): v for k, v in metrics.items()}


def available_losses() -> List[str]:
return list(losses.keys())


def available_metrics() -> List[str]:
return list(metrics.keys())
def available_losses(pattern: Optional[str] = None) -> List[str]:
"""Get available loss functions
```python
>> available_losses()
>> # crossentropy, binarycrossentropy, mae, ...
# Filter available losses with regex pattern
>> available_losses("m.e)
>> # ["mae", "mse"]
```
"""
loss_keys = list(losses.keys())
return filter_list(loss_keys, pattern)


def available_metrics(pattern: Optional[str] = None) -> List[str]:
"""Get available Metrics
```python
>> available_metrics()
>> # accuracy, F1, RMSE, ...
# Filter available metrics with regex pattern
>> available_metrics("acc.*")
>> # ["accuracy"]
```
"""
metric_keys = list(metrics.keys())
return filter_list(metric_keys, pattern)
Loading

0 comments on commit ddc1c1b

Please sign in to comment.