-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* WIP: Tensorflow MNIST use-case * UPDATE: Tensorflow MNIST version * ADD: Backend * ADD: Use-case init * FIX: Paths and downloading of the data * FIX: Paths and downloading of the data * ADD: Setup, Config update * ADD: Setup, Config update * UPDATE: File movement into itwinai * FIX: Move utils from tensorflow to global folder * FIX: Add setup into torch Executable * ADD: MNIST Torch Use-case * FIX: Formatting
- Loading branch information
Showing
64 changed files
with
2,185 additions
and
1,107 deletions.
There are no files selected for viewing
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from abc import ABCMeta, abstractmethod | ||
|
||
|
||
class Executable(metaclass=ABCMeta): | ||
@abstractmethod | ||
def execute(self, args): | ||
pass | ||
|
||
@abstractmethod | ||
def setup(self, args): | ||
pass | ||
|
||
|
||
class Trainer(Executable): | ||
@abstractmethod | ||
def train(self, data): | ||
pass | ||
|
||
|
||
class DataGetter(Executable): | ||
@abstractmethod | ||
def load(self, args): | ||
pass | ||
|
||
|
||
class DataPreproc(Executable): | ||
@abstractmethod | ||
def preproc(self, args): | ||
pass | ||
|
||
|
||
class StatGetter(Executable): | ||
@abstractmethod | ||
def stats(self, args): | ||
pass | ||
|
||
|
||
class Evaluator(Executable): | ||
@abstractmethod | ||
def evaluate(self, args): | ||
pass | ||
|
||
|
||
class Saver(Executable): | ||
@abstractmethod | ||
def save(self, args): | ||
pass | ||
|
||
|
||
class Executor(Executable): | ||
@abstractmethod | ||
def execute(self, pipeline): | ||
pass | ||
|
||
@abstractmethod | ||
def setup(self, pipeline): | ||
pass | ||
|
||
|
||
class Logger(metaclass=ABCMeta): | ||
@abstractmethod | ||
def log(self): | ||
pass |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from ..components import Executor | ||
|
||
|
||
class TensorflowExecutor(Executor): | ||
def __init__(self, args): | ||
self.args = args | ||
|
||
def execute(self, pipeline): | ||
args = None | ||
for executable in pipeline: | ||
args = executable.execute(args) | ||
|
||
def setup(self, pipeline): | ||
for executable in pipeline: | ||
executable.setup(self.args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import wandb | ||
import mlflow | ||
import mlflow.keras | ||
|
||
from ..components import Logger | ||
|
||
|
||
class WanDBLogger(Logger): | ||
def __init__(self): | ||
pass | ||
|
||
def log(self): | ||
wandb.init(config={"bs": 12}) | ||
|
||
|
||
class MLFlowLogger(Logger): | ||
def __init__(self): | ||
mlflow.set_tracking_uri("http://127.0.0.1:5000") | ||
mlflow.set_experiment("test-experiment") | ||
|
||
def log(self): | ||
mlflow.keras.autolog() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import keras | ||
import json | ||
|
||
|
||
def model_to_json(model: keras.Model, filepath: str): | ||
with open(filepath, "w") as f: | ||
json.dump(model.to_json(), f) | ||
|
||
|
||
def model_from_json(filepath: str) -> keras.Model: | ||
with open(filepath, "r") as f: | ||
config = json.load(f) | ||
return keras.models.model_from_json(config) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from ..components import Executor | ||
|
||
|
||
class TorchExecutor(Executor): | ||
def __init__(self): | ||
pass | ||
|
||
def execute(self, pipeline): | ||
args = None | ||
for executable in pipeline: | ||
args = executable.execute(args) | ||
|
||
def setup(self, pipeline): | ||
args = None | ||
for executable in pipeline: | ||
args = executable.setup(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import wandb | ||
import mlflow | ||
import mlflow.keras | ||
|
||
from ..components import Logger | ||
|
||
|
||
class WanDBLogger(Logger): | ||
def __init__(self): | ||
pass | ||
|
||
def log(self): | ||
wandb.init(config={"bs": 12}) | ||
|
||
|
||
class MLFlowLogger(Logger): | ||
def __init__(self): | ||
mlflow.set_tracking_uri("http://127.0.0.1:5000") | ||
mlflow.set_experiment("test-experiment") | ||
|
||
def log(self): | ||
mlflow.pytorch.autolog() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import yaml | ||
|
||
|
||
# Parse (part of) YAML loaded in memory | ||
def parse_pipe_config(yaml_file, parser): | ||
with open(yaml_file, "r", encoding="utf-8") as f: | ||
try: | ||
config = yaml.safe_load(f) | ||
except yaml.YAMLError as exc: | ||
print(exc) | ||
raise exc | ||
|
||
return parser.parse_object(config) |
Oops, something went wrong.