diff --git a/vectorization/models/fully_conv_net.py b/vectorization/models/fully_conv_net.py index 41422ca..10d46c3 100644 --- a/vectorization/models/fully_conv_net.py +++ b/vectorization/models/fully_conv_net.py @@ -45,14 +45,16 @@ def __init__(self, hidden_dim=128, input_channels=1, pooling='max'): nn.LeakyReLU(), self.pooling((2, 2)), - nn.Conv2d(128, 128, kernel_size=(3, 3)), + nn.Conv2d(128, 128, kernel_size=(3, 3), bias=False), nn.Dropout(p=0.2), nn.BatchNorm2d(128), nn.LeakyReLU(), nn.Conv2d(128, 128, kernel_size=(3, 3)), + nn.LeakyReLU(), self.pooling((2, 2)), nn.Conv2d(128, 64, kernel_size=(3, 3)), - nn.Conv2d(64, 64, kernel_size=(3, 3)), + nn.LeakyReLU(), + nn.Conv2d(64, 64, kernel_size=(3, 3), bias=False), nn.BatchNorm2d(64), nn.LeakyReLU(), self.adpooling((10, 6)), #FIXME (n_primitives, n_predicted_params) diff --git a/vectorization/modules/conv_modules.py b/vectorization/modules/conv_modules.py index 67cbfe2..e39e703 100644 --- a/vectorization/modules/conv_modules.py +++ b/vectorization/modules/conv_modules.py @@ -6,6 +6,7 @@ from vectorization.modules.maybe_module import MaybeModule from .base import ParameterizedModule +# from vectorization.models.common import resnet_model_creator, vgg_model_creator class ResnetBlock(nn.Module): def __init__(self, resample=None):