Skip to content

Commit

Permalink
fix path (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghang1989 authored Jun 15, 2018
1 parent 32e382b commit 2dd88e5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 10 deletions.
6 changes: 3 additions & 3 deletions encoding/models/encnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **
>>> model = get_encnet_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet50', pretrained, aux=False, **kwargs)
return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)

def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
Expand All @@ -185,7 +185,7 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', *
>>> model = get_encnet_resnet101_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet101', pretrained, aux=False, **kwargs)
return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=False, **kwargs)

def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
Expand All @@ -204,4 +204,4 @@ def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwarg
>>> model = get_encnet_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_encnet('ade20k', 'resnet50', pretrained, aux=True, **kwargs)
return get_encnet('ade20k', 'resnet50', pretrained, root=root, aux=True, **kwargs)
7 changes: 0 additions & 7 deletions encoding/nn/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,6 @@ def __init__(self, in_channels, norm_layer, up_kwargs):
# bilinear upsample options
self._up_kwargs = up_kwargs

def _cat_each(self, x, feat1, feat2, feat3, feat4):
assert(len(x) == len(feat1))
z = []
for i in range(len(x)):
z.append(torch.cat((x[i], feat1[i], feat2[i], feat3[i], feat4[i]), 1))
return z

def forward(self, x):
_, _, h, w = x.size()
feat1 = F.upsample(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs)
Expand Down

0 comments on commit 2dd88e5

Please sign in to comment.