Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add snapshots for model summaries #106

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Run Flake8
run: flake8
- name: Black code style
run: black . --check --target-version py36 --exclude 'build/|buck-out/|dist/|_build/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/|larq/snapshots/'
run: black . --check --target-version py36 --exclude 'build/|buck-out/|dist/|_build/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/|tests/snapshots/'
- name: Check import order with isort
run: isort --check-only --diff
- name: Type check with PyType
Expand Down
7 changes: 7 additions & 0 deletions larq_zoo/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class TrainLarqZooModel(Experiment):
# Use a per-batch progress bar (as opposed to per-epoch).
use_progress_bar: bool = Field(False)

# Whether this experiment is compilation-only (i.e. no training)
dry_run: bool = Field(False)

# How often to run validation.
validation_frequency: int = Field(1)

Expand Down Expand Up @@ -119,6 +122,10 @@ def run(self):
self.model.load_weights(str(self.model_path))
print(f"Loaded model from epoch {initial_epoch}.")

if self.dry_run:
click.secho("Dry run: Not starting training!", fg="green")
return

click.secho(str(self))

self.model.fit(
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def readme():
"pytest-xdist==1.31.0",
"Pillow==7.1.2",
"scipy==1.4.1",
"snapshottest>=0.5.1",
"tensorflow_datasets>=3.1.0",
],
},
Expand Down
126 changes: 81 additions & 45 deletions tests/models_test.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,47 @@
import functools
import os
from pathlib import Path

import larq as lq
import numpy as np
import pytest
from tensorflow import keras
import tensorflow as tf
from zookeeper import cli

import larq_zoo as lqz


def keras_test(func):
    """Function wrapper to clean up after TensorFlow tests.

    The Keras session is cleared in a ``finally`` block so that cleanup
    runs even when the wrapped test raises — otherwise a failing test
    would leak graph/layer state into subsequent tests.

    # Arguments
    func: test function to clean up after.
    # Returns
    A function wrapping the input function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            # Previously this ran only on success; clearing unconditionally
            # keeps one test's failure from polluting the next test's graph.
            keras.backend.clear_session()

    return wrapper


def parametrize(func):
    """Parametrize a test over every Zoo model and its final feature depth.

    Each case is ``(model_constructor, last_feature_dim)``, where the second
    element is the channel depth of the model's final feature map. The test
    function is additionally wrapped with `keras_test` for session cleanup.
    """
    cases = [
        (lqz.literature.BinaryAlexNet, 256),
        (lqz.literature.BiRealNet, 512),
        (lqz.literature.BinaryResNetE18, 512),
        (lqz.literature.BinaryDenseNet28, 576),
        (lqz.literature.BinaryDenseNet37, 640),
        (lqz.literature.BinaryDenseNet37Dilated, 640),
        (lqz.literature.BinaryDenseNet45, 800),
        (lqz.literature.MeliusNet22, 512),
        (lqz.literature.XNORNet, 4096),
        (lqz.literature.DoReFaNet, 256),
        (lqz.literature.RealToBinaryNet, 512),
        (lqz.sota.QuickNet, 512),
        (lqz.sota.QuickNetLarge, 512),
        (lqz.sota.QuickNetXL, 512),
    ]
    decorator = pytest.mark.parametrize("app,last_feature_dim", cases)
    return decorator(keras_test(func))
@pytest.fixture(autouse=True)
def run_around_tests():
    """Reset global Keras state before every test in this module.

    Autouse fixture: clears the backend session so layer names and graph
    state from a previous test case cannot leak into the next one.
    """
    tf.keras.backend.clear_session()
    yield


# Shared parametrization used by the tests below: every Zoo model paired
# with the channel depth of its final feature map (i.e. the last dimension
# of the `include_top=False` output).
parametrize = pytest.mark.parametrize(
    "app,last_feature_dim",
    [
        (lqz.literature.BinaryAlexNet, 256),
        (lqz.literature.BiRealNet, 512),
        (lqz.literature.BinaryResNetE18, 512),
        (lqz.literature.BinaryDenseNet28, 576),
        (lqz.literature.BinaryDenseNet37, 640),
        (lqz.literature.BinaryDenseNet37Dilated, 640),
        (lqz.literature.BinaryDenseNet45, 800),
        (lqz.literature.MeliusNet22, 512),
        (lqz.literature.XNORNet, 4096),
        (lqz.literature.DoReFaNet, 256),
        (lqz.literature.RealToBinaryNet, 512),
        (lqz.sota.QuickNet, 512),
        (lqz.sota.QuickNetLarge, 512),
        (lqz.sota.QuickNetXL, 512),
    ],
)


@parametrize
def test_prediction(app, last_feature_dim):
file = os.path.join(os.path.dirname(__file__), "fixtures", "elephant.jpg")
img = keras.preprocessing.image.load_img(file)
img = keras.preprocessing.image.img_to_array(img)
img = tf.keras.preprocessing.image.load_img(
Path() / "tests" / "fixtures" / "elephant.jpg"
)
img = tf.keras.preprocessing.image.img_to_array(img)
img = lqz.preprocess_input(img)
model = app(weights="imagenet")
preds = model.predict(np.expand_dims(img, axis=0))
Expand All @@ -74,7 +63,7 @@ def test_basic(app, last_feature_dim):

@parametrize
def test_keras_tensor_input(app, last_feature_dim):
    """Models should accept a pre-built Keras tensor via `input_tensor`."""
    # NOTE: the scraped diff showed both the removed `keras.layers.Input`
    # line and its `tf.keras` replacement; only the new line is kept here.
    input_tensor = tf.keras.layers.Input(shape=(224, 224, 3))
    model = app(weights=None, input_tensor=input_tensor)
    assert model.output_shape == (None, 1000)

Expand All @@ -97,3 +86,50 @@ def test_no_top_variable_shape_4(app, last_feature_dim):
input_shape = (None, None, 4)
model = app(weights=None, include_top=False, input_shape=input_shape)
assert model.output_shape == (None, None, None, last_feature_dim)


@parametrize
def test_model_summary(app, last_feature_dim):
    """Compare `lq.models.summary` output against a stored text snapshot.

    If no snapshot exists yet for this model, a new one is written and a
    `FileNotFoundError` is raised so the run fails once; re-running the
    tests locally will then pass against the freshly written snapshot.
    """
    input_tensor = tf.keras.layers.Input(shape=(224, 224, 3))
    model = app(weights=None, input_tensor=input_tensor)

    class PrintToVariable:
        """Minimal `print_fn` target that accumulates printed lines."""

        output = ""

        def __call__(self, x):
            self.output += f"{x}\n"

    capture = PrintToVariable()
    lq.models.summary(model, print_fn=capture)

    summary_file = (
        Path()
        / "tests"
        / "snapshots"
        / "model_summaries"
        / f"{app.__name__}_{last_feature_dim}.txt"
    )

    if summary_file.exists():
        with open(summary_file, "r") as file:
            content = file.read()
        # Case-insensitive compare — presumably summaries can differ in
        # capitalization across library versions; TODO confirm.
        assert content.lower() == capture.output.lower()
    else:
        with open(summary_file, "w") as file:
            file.write(capture.output)
        raise FileNotFoundError(
            f"Could not find snapshot {summary_file}, so generated a new summary. "
            "If this was intentional, re-running the tests locally will make them pass."
        )


@pytest.mark.parametrize("command_name", cli.commands.keys())
def test_experiments(command_name: str, snapshot, capsys):
    """Smoke-test every CLI experiment in dry-run mode on a dummy dataset."""
    args = ["dataset=DummyOxfordFlowers", "batch_size=2", "dry_run=True"]
    try:
        cli.commands[command_name](args)
    except SystemExit as exit_info:
        # The CLI raises SystemExit even on success; only a non-zero
        # exit code signals a real failure.
        if exit_info.code != 0:
            raise
Empty file added tests/snapshots/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions tests/snapshots/model_summaries/BiRealNet_512.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
+birealnet18 stats--------------------------------------------------------------------------------------------------+
| Layer Input prec. Outputs # 1-bit # 32-bit Memory 1-bit MACs 32-bit MACs |
| (bit) x 1 x 1 (kB) |
+-------------------------------------------------------------------------------------------------------------------+
| input_1 - ((None, 224, 224, 3),) 0 0 0 ? ? |
| conv2d - (-1, 112, 112, 64) 0 9408 36.75 0 118013952 |
| batch_normalization - (-1, 112, 112, 64) 0 128 0.50 0 0 |
| max_pooling2d - (-1, 56, 56, 64) 0 0 0 0 0 |
| quant_conv2d 1 (-1, 56, 56, 64) 36864 0 4.50 115605504 0 |
| batch_normalization_1 - (-1, 56, 56, 64) 0 128 0.50 0 0 |
| add - (-1, 56, 56, 64) 0 0 0 ? ? |
| quant_conv2d_1 1 (-1, 56, 56, 64) 36864 0 4.50 115605504 0 |
| batch_normalization_2 - (-1, 56, 56, 64) 0 128 0.50 0 0 |
| add_1 - (-1, 56, 56, 64) 0 0 0 ? ? |
| quant_conv2d_2 1 (-1, 56, 56, 64) 36864 0 4.50 115605504 0 |
| batch_normalization_3 - (-1, 56, 56, 64) 0 128 0.50 0 0 |
| add_2 - (-1, 56, 56, 64) 0 0 0 ? ? |
| quant_conv2d_3 1 (-1, 56, 56, 64) 36864 0 4.50 115605504 0 |
| batch_normalization_4 - (-1, 56, 56, 64) 0 128 0.50 0 0 |
| add_3 - (-1, 56, 56, 64) 0 0 0 ? ? |
| average_pooling2d - (-1, 28, 28, 64) 0 0 0 0 0 |
| quant_conv2d_4 1 (-1, 28, 28, 128) 73728 0 9.00 57802752 0 |
| conv2d_1 - (-1, 28, 28, 128) 0 8192 32.00 0 6422528 |
| batch_normalization_6 - (-1, 28, 28, 128) 0 256 1.00 0 0 |
| batch_normalization_5 - (-1, 28, 28, 128) 0 256 1.00 0 0 |
| add_4 - (-1, 28, 28, 128) 0 0 0 ? ? |
| quant_conv2d_5 1 (-1, 28, 28, 128) 147456 0 18.00 115605504 0 |
| batch_normalization_7 - (-1, 28, 28, 128) 0 256 1.00 0 0 |
| add_5 - (-1, 28, 28, 128) 0 0 0 ? ? |
| quant_conv2d_6 1 (-1, 28, 28, 128) 147456 0 18.00 115605504 0 |
| batch_normalization_8 - (-1, 28, 28, 128) 0 256 1.00 0 0 |
| add_6 - (-1, 28, 28, 128) 0 0 0 ? ? |
| quant_conv2d_7 1 (-1, 28, 28, 128) 147456 0 18.00 115605504 0 |
| batch_normalization_9 - (-1, 28, 28, 128) 0 256 1.00 0 0 |
| add_7 - (-1, 28, 28, 128) 0 0 0 ? ? |
| average_pooling2d_1 - (-1, 14, 14, 128) 0 0 0 0 0 |
| quant_conv2d_8 1 (-1, 14, 14, 256) 294912 0 36.00 57802752 0 |
| conv2d_2 - (-1, 14, 14, 256) 0 32768 128.00 0 6422528 |
| batch_normalization_11 - (-1, 14, 14, 256) 0 512 2.00 0 0 |
| batch_normalization_10 - (-1, 14, 14, 256) 0 512 2.00 0 0 |
| add_8 - (-1, 14, 14, 256) 0 0 0 ? ? |
| quant_conv2d_9 1 (-1, 14, 14, 256) 589824 0 72.00 115605504 0 |
| batch_normalization_12 - (-1, 14, 14, 256) 0 512 2.00 0 0 |
| add_9 - (-1, 14, 14, 256) 0 0 0 ? ? |
| quant_conv2d_10 1 (-1, 14, 14, 256) 589824 0 72.00 115605504 0 |
| batch_normalization_13 - (-1, 14, 14, 256) 0 512 2.00 0 0 |
| add_10 - (-1, 14, 14, 256) 0 0 0 ? ? |
| quant_conv2d_11 1 (-1, 14, 14, 256) 589824 0 72.00 115605504 0 |
| batch_normalization_14 - (-1, 14, 14, 256) 0 512 2.00 0 0 |
| add_11 - (-1, 14, 14, 256) 0 0 0 ? ? |
| average_pooling2d_2 - (-1, 7, 7, 256) 0 0 0 0 0 |
| quant_conv2d_12 1 (-1, 7, 7, 512) 1179648 0 144.00 57802752 0 |
| conv2d_3 - (-1, 7, 7, 512) 0 131072 512.00 0 6422528 |
| batch_normalization_16 - (-1, 7, 7, 512) 0 1024 4.00 0 0 |
| batch_normalization_15 - (-1, 7, 7, 512) 0 1024 4.00 0 0 |
| add_12 - (-1, 7, 7, 512) 0 0 0 ? ? |
| quant_conv2d_13 1 (-1, 7, 7, 512) 2359296 0 288.00 115605504 0 |
| batch_normalization_17 - (-1, 7, 7, 512) 0 1024 4.00 0 0 |
| add_13 - (-1, 7, 7, 512) 0 0 0 ? ? |
| quant_conv2d_14 1 (-1, 7, 7, 512) 2359296 0 288.00 115605504 0 |
| batch_normalization_18 - (-1, 7, 7, 512) 0 1024 4.00 0 0 |
| add_14 - (-1, 7, 7, 512) 0 0 0 ? ? |
| quant_conv2d_15 1 (-1, 7, 7, 512) 2359296 0 288.00 115605504 0 |
| batch_normalization_19 - (-1, 7, 7, 512) 0 1024 4.00 0 0 |
| add_15 - (-1, 7, 7, 512) 0 0 0 ? ? |
| average_pooling2d_3 - (-1, 1, 1, 512) 0 0 0 0 0 |
| flatten - (-1, 512) 0 0 0 0 0 |
| dense - (-1, 1000) 0 513000 2003.91 0 512000 |
| activation - (-1, 1000) 0 0 0 ? ? |
+-------------------------------------------------------------------------------------------------------------------+
| Total 10985472 704040 4091.16 1676279808 137793536 |
+-------------------------------------------------------------------------------------------------------------------+
+birealnet18 summary--------------------------+
| Total params 11.7 M |
| Trainable params 11.7 M |
| Non-trainable params 9.6 k |
| Model size 4.00 MiB |
| Model size (8-bit FP weights) 1.98 MiB |
| Float-32 Equivalent 44.59 MiB |
| Compression Ratio of Memory 0.09 |
| Number of MACs 1.81 B |
| Ratio of MACs that are binarized 0.9240 |
+---------------------------------------------+
40 changes: 40 additions & 0 deletions tests/snapshots/model_summaries/BinaryAlexNet_256.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
+binary_alexnet stats----------------------------------------------------------------------------------------------+
| Layer Input prec. Outputs # 1-bit # 32-bit Memory 1-bit MACs 32-bit MACs |
| (bit) x 1 x 1 (kB) |
+------------------------------------------------------------------------------------------------------------------+
| input_1 - ((None, 224, 224, 3),) 0 0 0 ? ? |
| quant_conv2d - (-1, 56, 56, 64) 23232 0 2.84 0 72855552 |
| max_pooling2d - (-1, 27, 27, 64) 0 0 0 0 0 |
| batch_normalization - (-1, 27, 27, 64) 0 128 0.50 0 0 |
| quant_conv2d_1 1 (-1, 27, 27, 192) 307200 0 37.50 223948800 0 |
| max_pooling2d_1 - (-1, 13, 13, 192) 0 0 0 0 0 |
| batch_normalization_1 - (-1, 13, 13, 192) 0 384 1.50 0 0 |
| quant_conv2d_2 1 (-1, 13, 13, 384) 663552 0 81.00 112140288 0 |
| batch_normalization_2 - (-1, 13, 13, 384) 0 768 3.00 0 0 |
| quant_conv2d_3 1 (-1, 13, 13, 384) 1327104 0 162.00 224280576 0 |
| batch_normalization_3 - (-1, 13, 13, 384) 0 768 3.00 0 0 |
| quant_conv2d_4 1 (-1, 13, 13, 256) 884736 0 108.00 149520384 0 |
| max_pooling2d_2 - (-1, 6, 6, 256) 0 0 0 0 0 |
| batch_normalization_4 - (-1, 6, 6, 256) 0 512 2.00 0 0 |
| flatten - (-1, 9216) 0 0 0 0 0 |
| quant_dense 1 (-1, 4096) 37748736 0 4608.00 37748736 0 |
| batch_normalization_5 - (-1, 4096) 0 8192 32.00 0 0 |
| quant_dense_1 1 (-1, 4096) 16777216 0 2048.00 16777216 0 |
| batch_normalization_6 - (-1, 4096) 0 8192 32.00 0 0 |
| quant_dense_2 1 (-1, 1000) 4096000 0 500.00 4096000 0 |
| batch_normalization_7 - (-1, 1000) 0 2000 7.81 0 0 |
| activation - (-1, 1000) 0 0 0 ? ? |
+------------------------------------------------------------------------------------------------------------------+
| Total 61827776 20944 7629.15 768512000 72855552 |
+------------------------------------------------------------------------------------------------------------------+
+binary_alexnet summary------------------------+
| Total params 61.8 M |
| Trainable params 61.8 M |
| Non-trainable params 20.9 k |
| Model size 7.45 MiB |
| Model size (8-bit FP weights) 7.39 MiB |
| Float-32 Equivalent 235.93 MiB |
| Compression Ratio of Memory 0.03 |
| Number of MACs 841 M |
| Ratio of MACs that are binarized 0.9134 |
+----------------------------------------------+
Loading