diff --git a/nbs/036_models.InceptionTimePlus.ipynb b/nbs/036_models.InceptionTimePlus.ipynb
index 86556f84e..38ea50eac 100644
--- a/nbs/036_models.InceptionTimePlus.ipynb
+++ b/nbs/036_models.InceptionTimePlus.ipynb
@@ -155,7 +155,7 @@
     "        self.seq_len = seq_len\n",
     "        if custom_head is not None: \n",
     "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-    "            head = custom_head(self.head_nf, c_out, seq_len)\n",
+    "            else: head = custom_head(self.head_nf, c_out, seq_len)\n",
     "        else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool, \n",
     "                                      fc_dropout=fc_dropout, bn=bn, y_range=y_range)\n",
     "        \n",
diff --git a/nbs/057_models.MINIROCKETPlus_Pytorch.ipynb b/nbs/057_models.MINIROCKETPlus_Pytorch.ipynb
index 52c90aaa8..74c123ada 100644
--- a/nbs/057_models.MINIROCKETPlus_Pytorch.ipynb
+++ b/nbs/057_models.MINIROCKETPlus_Pytorch.ipynb
@@ -250,7 +250,7 @@
     "        self.head_nf = num_features\n",
     "        if custom_head is not None: \n",
     "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-    "            head = custom_head(self.head_nf, c_out, 1)\n",
+    "            else: head = custom_head(self.head_nf, c_out, 1)\n",
     "        else:\n",
     "            layers = [Flatten()]\n",
     "            if bn:\n",
@@ -506,7 +506,7 @@
     "        self.head_nf = num_features\n",
     "        if custom_head is not None: \n",
     "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-    "            head = custom_head(self.head_nf, c_out, 1)\n",
+    "            else: head = custom_head(self.head_nf, c_out, 1)\n",
     "        else:\n",
     "            layers = [Flatten()]\n",
     "            if bn:\n",
diff --git a/nbs/059_models.XResNet1dPlus.ipynb b/nbs/059_models.XResNet1dPlus.ipynb
index e3c302dac..255332c50 100644
--- a/nbs/059_models.XResNet1dPlus.ipynb
+++ b/nbs/059_models.XResNet1dPlus.ipynb
@@ -45,7 +45,7 @@
     "class XResNet1dPlus(nn.Sequential):\n",
     "    @delegates(ResBlock1dPlus)\n",
     "    def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),\n",
-    "                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, **kwargs):\n",
+    "                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):\n",
     "\n",
     "        store_attr('block,expansion,act_cls,ks')\n",
     "        n_out = c_out or n_out # added for compatibility\n",
@@ -55,14 +55,14 @@
     "                               act=act_cls)\n",
     "                for i in range(3)]\n",
     "\n",
-    "        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]\n",
+    "        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*(len(layers)-4))]\n",
     "        block_szs = [64//expansion] + block_szs\n",
     "        blocks = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)\n",
     "        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)\n",
     "        self.head_nf = block_szs[-1]*expansion\n",
     "        if custom_head is not None: \n",
     "            if isinstance(custom_head, nn.Module): head = custom_head\n",
-    "            head = custom_head(self.head_nf, n_out, seq_len)\n",
+    "            else: head = custom_head(self.head_nf, n_out, seq_len)\n",
     "        else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(block_szs[-1]*expansion, n_out))\n",
     "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
     "        self._init_cnn(self)\n",
diff --git a/tsai/models/InceptionTimePlus.py b/tsai/models/InceptionTimePlus.py
index 1f5feeffc..31722f3af 100644
--- a/tsai/models/InceptionTimePlus.py
+++ b/tsai/models/InceptionTimePlus.py
@@ -108,7 +108,7 @@ def __init__(self, c_in, c_out, seq_len=None, nf=32, nb_filters=None,
         self.seq_len = seq_len
         if custom_head is not None: 
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, seq_len)
+            else: head = custom_head(self.head_nf, c_out, seq_len)
         else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool, 
                                       fc_dropout=fc_dropout, bn=bn, y_range=y_range)
 
diff --git a/tsai/models/MINIROCKETPlus_Pytorch.py b/tsai/models/MINIROCKETPlus_Pytorch.py
index dabd6f0d0..7cba68239 100644
--- a/tsai/models/MINIROCKETPlus_Pytorch.py
+++ b/tsai/models/MINIROCKETPlus_Pytorch.py
@@ -205,7 +205,7 @@ def __init__(self, c_in, c_out, seq_len, num_features=10_000, max_dilations_per_
         self.head_nf = num_features
         if custom_head is not None: 
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, 1)
+            else: head = custom_head(self.head_nf, c_out, 1)
         else:
             layers = [Flatten()]
             if bn:
@@ -319,7 +319,7 @@ def __init__(self, c_in, c_out, seq_len, num_features=10_000, max_dilations_per_
         self.head_nf = num_features
         if custom_head is not None: 
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, 1)
+            else: head = custom_head(self.head_nf, c_out, 1)
         else:
             layers = [Flatten()]
             if bn:
diff --git a/tsai/models/XResNet1dPlus.py b/tsai/models/XResNet1dPlus.py
index 45784b922..647e0d14b 100644
--- a/tsai/models/XResNet1dPlus.py
+++ b/tsai/models/XResNet1dPlus.py
@@ -14,7 +14,7 @@ class XResNet1dPlus(nn.Sequential):
     @delegates(ResBlock1dPlus)
     def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),
-                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, **kwargs):
+                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):
 
         store_attr('block,expansion,act_cls,ks')
         n_out = c_out or n_out # added for compatibility
@@ -24,14 +24,14 @@ def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropo
                                act=act_cls)
                 for i in range(3)]
 
-        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
+        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*(len(layers)-4))]
         block_szs = [64//expansion] + block_szs
         blocks = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)
         backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)
         self.head_nf = block_szs[-1]*expansion
         if custom_head is not None: 
             if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, n_out, seq_len)
+            else: head = custom_head(self.head_nf, n_out, seq_len)
         else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(block_szs[-1]*expansion, n_out))
         super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))
         self._init_cnn(self)
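All four `else:` fixes above address the same bug: when `custom_head` was passed as an `nn.Module` instance, the following line unconditionally overwrote `head` by calling the module with three positional arguments, which fails for most modules. The sketch below is a minimal usage example, not part of this patch (model sizes and shapes are illustrative), showing the two `custom_head` forms the patched constructors accept:

# Minimal sketch (assumes the patched tsai code above; shapes/values are illustrative).
import torch
import torch.nn as nn
from tsai.models.InceptionTimePlus import InceptionTimePlus

xb = torch.randn(16, 3, 128)  # (batch, c_in, seq_len)

# 1) custom_head as a ready-made nn.Module: with the fix it is used as-is;
#    before the fix it was immediately overwritten by calling it with
#    (head_nf, c_out, seq_len), raising a TypeError.
head = nn.Sequential(nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(128, 5))  # head_nf = nf*4 = 128 by default
m1 = InceptionTimePlus(3, 5, seq_len=128, custom_head=head)

# 2) custom_head as a factory callable(head_nf, c_out, seq_len), the other supported form
def make_head(head_nf, c_out, seq_len):
    return nn.Sequential(nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(head_nf, c_out))
m2 = InceptionTimePlus(3, 5, seq_len=128, custom_head=make_head)

assert m1(xb).shape == m2(xb).shape == (16, 5)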
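The XResNet1dPlus change additionally exposes the previously hard-coded stage widths [64,128,256,512] as a `block_szs_base` argument; any stage beyond the fourth now defaults to half of the last base width (still 256 with the default base). A hedged sketch of how the new argument might be used (again not part of the patch; values are illustrative):

# Minimal sketch (assumes the patched tsai code above; values are illustrative).
import torch
from tsai.models.XResNet1dPlus import XResNet1dPlus

xb = torch.randn(8, 3, 96)  # (batch, c_in, seq_len)

# Previous behaviour, now the default: block_szs_base=(64, 128, 256, 512)
m_default = XResNet1dPlus(expansion=1, layers=[2, 2, 2, 2], c_in=3, c_out=10, seq_len=96)

# Half-width variant via the new argument
m_narrow = XResNet1dPlus(expansion=1, layers=[2, 2, 2, 2], c_in=3, c_out=10, seq_len=96,
                         block_szs_base=(32, 64, 128, 256))

assert m_default(xb).shape == m_narrow(xb).shape == (8, 10)
assert m_default.head_nf == 512 and m_narrow.head_nf == 256  # head_nf = block_szs[-1]*expansion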