Skip to content

Commit

Permalink
✅ test automodel (#40)
Browse files Browse the repository at this point in the history
* test automodel

* updates

* add CI autolabel

* min deps

* fix
  • Loading branch information
aniketmaurya authored Aug 29, 2021
1 parent 3e611a0 commit 8f20162
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
name: Bug report
name: 🐛 Bug report
about: Create a report to help us improve
title: ''
labels: ''
Expand Down
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/feature_request.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
name: Feature request
name: 🚀 Feature request
about: Suggest an idea for this project
title: ''
labels: ''
Expand Down
5 changes: 5 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ example:

test:
- tests/**/*

CI:
- .github/**/*
- "*.yaml"
- "*.yml"
12 changes: 7 additions & 5 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ Fixes # (issue)

#### Type of change
<!-- Please delete options that are not relevant. -->
- [ ] Documentation Update
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] 📚 Documentation Update
- [ ] 🧪 Tests Cases
- [ ] 🐞 Bug fix (non-breaking change which fixes an issue)
- [ ] 🔬 New feature (non-breaking change which adds functionality)
- [ ] 🚨 Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] 📝 This change requires a documentation update


#### Checklist
Expand All @@ -22,3 +23,4 @@ Fixes # (issue)
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] Did you update CHANGELOG in case of a major change?
2 changes: 1 addition & 1 deletion gradsflow/core/autoclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def forward(self, x):
return self.model(x)

# noinspection PyTypeChecker
def _create_hparam_config(self) -> Dict[str, str]:
def _create_search_space(self) -> Dict[str, str]:
"""Create hyperparameter config from `ray.tune`
Returns:
Expand Down
4 changes: 2 additions & 2 deletions gradsflow/core/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
)

@abstractmethod
def _create_hparam_config(self) -> Dict[str, str]:
def _create_search_space(self) -> Dict[str, str]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -144,7 +144,7 @@ def hp_tune(
trainer_config = trainer_config or {}
ray_config = ray_config or {}

search_space = self._create_hparam_config()
search_space = self._create_search_space()
trainable = self.objective

analysis = tune.run(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
requires-python = ">=3.7"
requires = [
"smart_open==5.1",
"lightning-flash[all]==0.4.0",
"lightning-flash[image,text]==0.4.0",
"pytorch-lightning==1.4.0",
"ray[tune]==1.6",
"loguru~=0.5"
Expand Down
20 changes: 20 additions & 0 deletions tests/core/test_automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import torch
from flash.image import ImageClassificationData

from gradsflow.core.automodel import AutoModel
Expand All @@ -34,3 +36,21 @@ def test_build_model():
model = AutoModel(datamodule)
with pytest.raises(NotImplementedError):
model.build_model({"lr": 1})


def test_create_search_space():
model = AutoModel(datamodule)
with pytest.raises(NotImplementedError):
model._create_search_space()


@patch("gradsflow.core.automodel.pl")
def test_objective(mock_pl):
optimization_metric = "val_accuracy"
model = AutoModel(datamodule, optimization_metric=optimization_metric)

model.build_model = MagicMock()
trainer = mock_pl.Trainer = MagicMock()
trainer.callback_metrics = {optimization_metric: torch.as_tensor([1])}

model.objective({}, {})

0 comments on commit 8f20162

Please sign in to comment.