📊 Wandb Implementation (#168)
* wandb implementation

* wandb implementation

* update

* style

* update

* refactor

* define metrics

* fixes

* add wandb
aniketmaurya authored Jan 14, 2022
1 parent 43762fa commit dd81035
Showing 17 changed files with 221 additions and 61 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -12,3 +12,4 @@ mkdocs-meta-descriptions-plugin
jupyter_contrib_nbextensions
comet_ml
lightning-flash[image,text]>=0.5.1
wandb
5 changes: 3 additions & 2 deletions gradsflow/callbacks/__init__.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .comet import CometCallback
from .gpu import EmissionTrackerCallback
from .logger import CometCallback, CSVLogger
from .logger.logger import ModelCheckpoint
from .logger import CSVLogger, ModelCheckpoint
from .progress import ProgressCallback
from .raytune import report_checkpoint_callback
from .runner import CallbackRunner
from .training import TrainEvalCallback
from .wandb import WandbCallback
@@ -1,4 +1,16 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
# Copyright (c) 2022 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,16 @@
# Copyright (c) 2021 GradsFlow. All rights reserved.
# Copyright (c) 2022 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
106 changes: 106 additions & 0 deletions gradsflow/callbacks/wandb.py
@@ -0,0 +1,106 @@
# Copyright (c) 2022 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict, List, Optional

import wandb

from gradsflow.callbacks.base import Callback
from gradsflow.utility.imports import requires

CURRENT_FILE = os.path.dirname(os.path.realpath(__file__))


def define_metrics():
min_max_def: Dict[str, List[str]] = {
"min": ["train/step_loss", "train/epoch_loss", "val/epoch_loss"],
"max": ["train/acc*", "val/acc*"],
}
for summary, metric_list in min_max_def.items():
for metric in metric_list:
if "epoch" in metric or "val" in metric:
wandb.define_metric(metric, summary=summary, step_metric="epoch")
wandb.define_metric("*", step_metric="global_step")


class WandbCallback(Callback):
"""
[Weights & Biases](https://www.wandb.com/) Logging callback. To use this callback `pip install wandb`.
Args:
log_model: Whether to upload model artifact to Wandb
code_file: path of the code you want to upload as artifact to Wandb
"""

@requires("wandb", "WandbCallback requires wandb to be installed!")
def __init__(
self,
log_model: bool = False,
code_file: Optional[str] = None,
):
super().__init__()
if wandb.run is None:
raise ValueError("You must call wandb.init() before WandbCallback()")
self._code_file = code_file
self._train_prefix = "train"
self._val_prefix = "val"
self._log_model = log_model
self._setup()

def _setup(self):
define_metrics()

def on_fit_start(self):
if self._log_model:
wandb.log_artifact(self.model.learner)
if self._code_file:
wandb.log_artifact(self._code_file)

def _apply_prefix(self, data: dict, prefix: str):
data = {f"{prefix}/{k}": v for k, v in data.items()}
return data

def on_train_step_end(self, outputs: dict = None, **_):
# self._step(prefix=self._train_prefix, outputs=outputs)
prefix = "train"
global_step = self.model.tracker.global_step
loss = outputs["loss"].item()
# log train step loss
wandb.log({f"{prefix}/step_loss": loss, "train_step": global_step}, commit=False)

# log train step metrics
metrics = outputs.get("metrics", {})
metrics = self._apply_prefix(metrics, prefix)
wandb.log(metrics, commit=False)

# https://docs.wandb.ai/guides/track/log#how-do-i-use-custom-x-axes
wandb.log({"global_step": global_step})

def on_epoch_end(self):
epoch = self.model.tracker.current_epoch
train_loss = self.model.tracker.train_loss
train_metrics = self.model.tracker.train_metrics.to_dict()
val_loss = self.model.tracker.val_loss
val_metrics = self.model.tracker.val_metrics.to_dict()

train_metrics = self._apply_prefix(train_metrics, prefix=self._train_prefix)
val_metrics = self._apply_prefix(val_metrics, prefix=self._val_prefix)
train_metrics.update({"epoch": epoch})
val_metrics.update({"epoch": epoch})

wandb.log({"train/epoch_loss": train_loss, "epoch": epoch}, commit=False)
wandb.log({"val/epoch_loss": val_loss, "epoch": epoch}, commit=False)
wandb.log(train_metrics, commit=False)
wandb.log(val_metrics, commit=False)
wandb.log({})
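For context, a minimal usage sketch of the new callback (not part of this commit's diff). It assumes the usual GradsFlow entry points (`Model`, `AutoDataset`, and `Model.fit(..., callbacks=[...])`) with assumed string identifiers for loss and optimizer; `wandb.init()` must be called before constructing `WandbCallback`, as enforced in `__init__` above.

import torch
import wandb
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from gradsflow import AutoDataset, Model   # assumed top-level imports
from gradsflow.callbacks import WandbCallback

wandb.init(project="gradsflow-demo")       # must run before WandbCallback() is created

# Toy data and model, purely for illustration
data = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
train_dl = DataLoader(data, batch_size=8)
val_dl = DataLoader(data, batch_size=8)
net = nn.Sequential(nn.Linear(10, 2))

autodataset = AutoDataset(train_dataloader=train_dl, val_dataloader=val_dl)
model = Model(net)
model.compile(loss="crossentropyloss", optimizer="adam")   # assumed compile signature
model.fit(autodataset, max_epochs=2, callbacks=[WandbCallback(log_model=False)])
# Step losses are logged as train/step_loss against global_step;
# epoch losses and metrics are logged against epoch, per define_metrics() above.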
20 changes: 7 additions & 13 deletions gradsflow/core/base.py
@@ -13,13 +13,13 @@
# limitations under the License.
#
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Dict, Optional

import numpy as np
import torch

from gradsflow.utility.common import AverageMeter, module_to_cls_index
from gradsflow.utility.common import AverageMeter, GDict, module_to_cls_index


class BaseAutoModel(ABC):
@@ -44,12 +44,12 @@ def build_model(self, search_space: dict):
@dataclass(init=False)
class TrackingValues:
loss: Optional[AverageMeter] = None # Average loss in a single Epoch
steps: Optional[int] = None
steps: Optional[int] = None # Step per epoch
step_loss: Optional[float] = None
metrics: Optional[Dict[str, AverageMeter]] = None # Average value in a single Epoch

def __init__(self):
self.metrics = {}
self.metrics = GDict()
self.loss = AverageMeter(name="loss")

def update_loss(self, loss: float):
@@ -67,17 +67,11 @@ def update_metrics(self, metrics: Dict[str, float]):
self.metrics[key] = AverageMeter(name=key)
self.metrics[key].update(value)

def to_dict(self) -> dict:
return asdict(self)

def reset(self):
"""Values are Reset on start of each `on_*_epoch_start`"""
self.loss.reset()
for _, metric in self.metrics.items():
metric.reset()


@dataclass(init=False)
class BaseTracker:
max_epochs: int = 0
current_epoch: int = 0 # current train current_epoch
steps_per_epoch: Optional[int] = None
train: TrackingValues = TrackingValues()
val: TrackingValues = TrackingValues()
1 change: 1 addition & 0 deletions gradsflow/models/model.py
@@ -146,6 +146,7 @@ def train_one_epoch(self, train_dataloader):
steps_per_epoch = tracker.steps_per_epoch

for step, batch in enumerate(train_dataloader):
tracker.global_step += 1
tracker.train.steps = step
# ----- TRAIN STEP -----
self.callback_runner.on_train_step_start()
25 changes: 18 additions & 7 deletions gradsflow/models/tracker.py
@@ -11,14 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from dataclasses import dataclass
from typing import Dict, List, Optional

from loguru import logger
from rich import box
from rich.table import Table

from gradsflow.core.base import BaseTracker, TrackingValues
from gradsflow.utility.common import to_item
from gradsflow.core.base import TrackingValues
from gradsflow.utility.common import GDict, to_item


@dataclass(init=False)
class BaseTracker:
global_step: int = 0 # Global training steps
max_epochs: int = 0
current_epoch: int = 0 # current train current_epoch
steps_per_epoch: Optional[int] = None
train: TrackingValues = TrackingValues()
val: TrackingValues = TrackingValues()


class Tracker(BaseTracker):
@@ -27,8 +38,8 @@ class Tracker(BaseTracker):
"""

def __init__(self):
self.train.metrics = {}
self.val.metrics = {}
self.train.metrics = GDict()
self.val.metrics = GDict()
self.logs: List[Dict] = []

def __getitem__(self, key: str): # skipcq: PYL-R1705
@@ -60,11 +71,11 @@ def val_loss(self):
return self.val.loss.avg

@property
def train_metrics(self):
def train_metrics(self) -> GDict:
return self.train.metrics

@property
def val_metrics(self):
def val_metrics(self) -> GDict:
return self.val.metrics

def mode(self, mode) -> TrackingValues:
13 changes: 12 additions & 1 deletion gradsflow/utility/common.py
@@ -75,7 +75,6 @@ class AverageMeter:
`val` is the running value, `avg` is the average value over an epoch.
"""

name: Optional[str]
avg: Optional[float] = 0

def __init__(self, name=None):
@@ -138,3 +137,15 @@ def filter_list(arr: List[str], pattern: Optional[str] = None) -> List[str]:

p = re.compile(pattern)
return [s for s in arr if p.match(s)]


class GDict(dict):
def to_dict(self):
clone = self.copy()
for k in clone.keys():
value = clone[k]
try:
clone[k] = dataclasses.asdict(value)
except TypeError:
continue
return clone
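A small illustrative sketch of how `GDict.to_dict` behaves (not part of the diff): values that are dataclasses, such as `AverageMeter`, are converted with `dataclasses.asdict`, while everything else is left untouched because `asdict` raises `TypeError` for non-dataclass values and the key is skipped.

from gradsflow.utility.common import AverageMeter, GDict

metrics = GDict(accuracy=AverageMeter(name="accuracy"), epoch=3)
metrics["accuracy"].update(0.9)

plain = metrics.to_dict()
# plain["accuracy"] is now a plain dict of the AverageMeter fields
# (assuming AverageMeter is a dataclass, as its field annotations suggest);
# plain["epoch"] stays 3, since dataclasses.asdict(3) raises TypeError and the value is kept as-is.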
2 changes: 1 addition & 1 deletion setup.cfg
@@ -38,7 +38,7 @@ install_requires =
torchmetrics >=0.5.0

[options.extras_require]
dev = lightning-flash[image,text] ==0.5.1; codecarbon >=1.2.0; comet_ml
dev = lightning-flash[image,text] ==0.5.1; codecarbon >=1.2.0; comet_ml; wandb
test = pytest; coverage; pytest-sugar

[options.packages.find] #optional
13 changes: 0 additions & 13 deletions tests/autotasks/__init__.py

This file was deleted.

13 changes: 0 additions & 13 deletions tests/callbacks/__init__.py

This file was deleted.

8 changes: 6 additions & 2 deletions tests/callbacks/test_logger.py
@@ -17,8 +17,12 @@
import pytest

from gradsflow import AutoDataset
from gradsflow.callbacks import EmissionTrackerCallback, ModelCheckpoint
from gradsflow.callbacks.logger import CometCallback, CSVLogger
from gradsflow.callbacks import (
CometCallback,
CSVLogger,
EmissionTrackerCallback,
ModelCheckpoint,
)
from gradsflow.data.image import image_dataset_from_directory
from gradsflow.utility.imports import is_installed
from tests.dummies import DummyModel
3 changes: 0 additions & 3 deletions tests/callbacks/test_runner.py
@@ -18,14 +18,12 @@


def test_init(dummy_model):

assert isinstance(CallbackRunner(dummy_model, "training").callbacks["TrainEvalCallback"], TrainEvalCallback)
with pytest.raises(NotImplementedError):
CallbackRunner(dummy_model, "random")


def test_append(dummy_model):

cb = CallbackRunner(dummy_model)
with pytest.raises(NotImplementedError):
cb.append("random")
@@ -39,7 +37,6 @@ def test_clean(dummy_model):


def test_clean(dummy_model):

cb = CallbackRunner(dummy_model, TrainEvalCallback())
cb.clean(keep="TrainEvalCallback")
assert cb.callbacks.get("TrainEvalCallback") is not None