Merge pull request #730 from talesa/feature/more_parameterizable_xresnet1dplus

Allow for parameterizing block_szs in XResNet1dPlus
oguiza authored Apr 2, 2023
2 parents db55074 + 20ccf5a commit 881aa5a
Showing 6 changed files with 12 additions and 12 deletions.
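This PR makes two changes. First, it adds a `block_szs_base` argument to `XResNet1dPlus`, so the per-stage channel sizes, previously hard-coded as `[64,128,256,512]`, become configurable. Second, it adds a missing `else:` to the `custom_head` logic of several models, so a head passed as a ready-made `nn.Module` is no longer overwritten. A minimal usage sketch of the new argument (the argument values are illustrative, not taken from this commit):

```python
from tsai.models.XResNet1dPlus import XResNet1dPlus

# Defaults reproduce the old behavior: block_szs_base=(64,128,256,512)
model = XResNet1dPlus(expansion=4, layers=[3,4,6,3], c_in=3, c_out=10, seq_len=128)

# New in this PR: a slimmer network via smaller base block sizes
slim = XResNet1dPlus(expansion=4, layers=[3,4,6,3], c_in=3, c_out=10, seq_len=128,
                     block_szs_base=(32,64,128,256))
```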
2 changes: 1 addition & 1 deletion nbs/036_models.InceptionTimePlus.ipynb
@@ -155,7 +155,7 @@
"        self.seq_len = seq_len\n",
"        if custom_head is not None: \n",
"            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, c_out, seq_len)\n",
+"            else: head = custom_head(self.head_nf, c_out, seq_len)\n",
" else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool, \n",
" fc_dropout=fc_dropout, bn=bn, y_range=y_range)\n",
" \n",
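The hunk above is the `custom_head` fix: previously, when `custom_head` was passed as an `nn.Module`, `head` was assigned and then unconditionally overwritten on the next line, which called the module as if it were a factory and raised a `TypeError`. The added `else:` keeps the two cases separate; the same fix is applied to the MiniRocket and XResNet1dPlus files below. A minimal sketch of the fixed control flow (names as in the diff, values illustrative):

```python
import torch.nn as nn

head_nf, c_out, seq_len = 128, 5, 96  # illustrative values

# Case 1: a ready-made module is now used as-is. Before the fix, the next line
# ran unconditionally and called the module with three ints -> TypeError.
custom_head = nn.Linear(head_nf, c_out)
if isinstance(custom_head, nn.Module): head = custom_head
else: head = custom_head(head_nf, c_out, seq_len)

# Case 2: a factory callable is invoked with (head_nf, c_out, seq_len), as before.
custom_head = lambda nf, c, sl: nn.Sequential(nn.Flatten(), nn.Linear(nf * sl, c))
if isinstance(custom_head, nn.Module): head = custom_head
else: head = custom_head(head_nf, c_out, seq_len)
```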
4 changes: 2 additions & 2 deletions nbs/057_models.MINIROCKETPlus_Pytorch.ipynb
@@ -250,7 +250,7 @@
"        self.head_nf = num_features\n",
"        if custom_head is not None: \n",
"            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, c_out, 1)\n",
+"            else: head = custom_head(self.head_nf, c_out, 1)\n",
" else:\n",
" layers = [Flatten()]\n",
" if bn:\n",
@@ -506,7 +506,7 @@
"        self.head_nf = num_features\n",
"        if custom_head is not None: \n",
"            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, c_out, 1)\n",
+"            else: head = custom_head(self.head_nf, c_out, 1)\n",
" else:\n",
" layers = [Flatten()]\n",
" if bn:\n",
6 changes: 3 additions & 3 deletions nbs/059_models.XResNet1dPlus.ipynb
@@ -45,7 +45,7 @@
"class XResNet1dPlus(nn.Sequential):\n",
"    @delegates(ResBlock1dPlus)\n",
"    def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),\n",
-"                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, **kwargs):\n",
+"                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):\n",
"\n",
" store_attr('block,expansion,act_cls,ks')\n",
" n_out = c_out or n_out # added for compatibility\n",
@@ -55,14 +55,14 @@
"                          act=act_cls)\n",
"                for i in range(3)]\n",
"\n",
-"        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]\n",
+"        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*(len(layers)-4))]\n",
"        block_szs = [64//expansion] + block_szs\n",
"        blocks = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)\n",
"        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)\n",
"        self.head_nf = block_szs[-1]*expansion\n",
"        if custom_head is not None: \n",
"            if isinstance(custom_head, nn.Module): head = custom_head\n",
-"            head = custom_head(self.head_nf, n_out, seq_len)\n",
+"            else: head = custom_head(self.head_nf, n_out, seq_len)\n",
" else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(block_szs[-1]*expansion, n_out))\n",
" super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
" self._init_cnn(self)\n",
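For reference, the stage widths are now derived from `block_szs_base` instead of a literal list; when more than four layer groups are requested, the extra stages default to half the last base size (256 with the default base, matching the previously hard-coded value). A standalone re-implementation of the computation, for illustration only (not part of the commit):

```python
# Sketch of how XResNet1dPlus derives its stage widths after this PR.
def compute_block_szs(block_szs_base=(64, 128, 256, 512), widen=1.0, expansion=4, n_layers=4):
    extra = [int(block_szs_base[-1] / 2)] * (n_layers - 4)  # was hard-coded as [256]*(len(layers)-4)
    block_szs = [int(o * widen) for o in list(block_szs_base) + extra]
    return [64 // expansion] + block_szs  # prepended entry: first stage's input width (64 after expansion)

print(compute_block_szs())                                   # [16, 64, 128, 256, 512]
print(compute_block_szs(n_layers=6))                         # [16, 64, 128, 256, 512, 256, 256]
print(compute_block_szs(block_szs_base=(32, 64, 128, 256)))  # [16, 32, 64, 128, 256]
```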
2 changes: 1 addition & 1 deletion tsai/models/InceptionTimePlus.py
@@ -108,7 +108,7 @@ def __init__(self, c_in, c_out, seq_len=None, nf=32, nb_filters=None,
self.seq_len = seq_len
if custom_head is not None:
if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, seq_len)
+            else: head = custom_head(self.head_nf, c_out, seq_len)
else: head = self.create_head(self.head_nf, c_out, seq_len, flatten=flatten, concat_pool=concat_pool,
fc_dropout=fc_dropout, bn=bn, y_range=y_range)

4 changes: 2 additions & 2 deletions tsai/models/MINIROCKETPlus_Pytorch.py
@@ -205,7 +205,7 @@ def __init__(self, c_in, c_out, seq_len, num_features=10_000, max_dilations_per_
self.head_nf = num_features
if custom_head is not None:
if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, 1)
+            else: head = custom_head(self.head_nf, c_out, 1)
else:
layers = [Flatten()]
if bn:
@@ -319,7 +319,7 @@ def __init__(self, c_in, c_out, seq_len, num_features=10_000, max_dilations_per_
self.head_nf = num_features
if custom_head is not None:
if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, c_out, 1)
+            else: head = custom_head(self.head_nf, c_out, 1)
else:
layers = [Flatten()]
if bn:
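The two MiniRocket hunks apply the same `else:` fix. Note that there the head factory is called as `custom_head(self.head_nf, c_out, 1)`: MiniRocket produces a flat feature vector rather than a sequence, so the sequence length passed to the head is fixed at 1. A custom head for these models should accept that signature, e.g. this hypothetical factory:

```python
import torch.nn as nn

# Hypothetical factory for the MiniRocket models; it is called as custom_head(head_nf, c_out, 1)
def minirocket_head(nf, c_out, seq_len=1):
    return nn.Sequential(nn.Flatten(), nn.Linear(nf * seq_len, c_out))
```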
6 changes: 3 additions & 3 deletions tsai/models/XResNet1dPlus.py
@@ -14,7 +14,7 @@
class XResNet1dPlus(nn.Sequential):
@delegates(ResBlock1dPlus)
def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),
-                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, **kwargs):
+                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):

store_attr('block,expansion,act_cls,ks')
n_out = c_out or n_out # added for compatibility
@@ -24,14 +24,14 @@ def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropo
act=act_cls)
for i in range(3)]

-        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
+        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*(len(layers)-4))]
block_szs = [64//expansion] + block_szs
blocks = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)
backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)
self.head_nf = block_szs[-1]*expansion
if custom_head is not None:
if isinstance(custom_head, nn.Module): head = custom_head
-            head = custom_head(self.head_nf, n_out, seq_len)
+            else: head = custom_head(self.head_nf, n_out, seq_len)
else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(block_szs[-1]*expansion, n_out))
super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))
self._init_cnn(self)
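Taken together, the head logic in `XResNet1dPlus` now accepts either form. A minimal sketch under assumed shapes (with `expansion=1`, `layers=[2,2,2,2]` and the default `block_szs_base`, `head_nf` works out to 512):

```python
import torch
import torch.nn as nn
from tsai.models.XResNet1dPlus import XResNet1dPlus

# A factory: called with (head_nf, n_out, seq_len) and expected to return a module
def my_head(nf, n_out, seq_len):  # hypothetical example factory
    return nn.Sequential(nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(nf, n_out))

m1 = XResNet1dPlus(expansion=1, layers=[2,2,2,2], c_in=3, c_out=5, seq_len=96, custom_head=my_head)

# A ready-made nn.Module: used as-is (this path was broken before the `else:` fix)
head = nn.Sequential(nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(512, 5))
m2 = XResNet1dPlus(expansion=1, layers=[2,2,2,2], c_in=3, c_out=5, seq_len=96, custom_head=head)

x = torch.randn(2, 3, 96)
assert m1(x).shape == m2(x).shape == (2, 5)
```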
