Final pass to fix lint / remove mmcv config objects
Patrick Labatut committed Sep 27, 2023
1 parent 53eeb55 commit 736919b
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions dinov2/hub/depth/decode_heads.py
@@ -44,7 +44,7 @@ class DepthBaseDecodeHead(nn.Module):
Default: 1e-3.
max_depth (int): Max depth in dataset setting.
Default: None.
norm_cfg (dict|None): Config of norm layers.
norm_layer (dict|None): Norm layers.
Default: None.
classify (bool): Whether predict depth in a cls.-reg. manner.
Default: False.
@@ -69,7 +69,7 @@ def __init__(
align_corners=False,
min_depth=1e-3,
max_depth=None,
norm_cfg=None,
norm_layer=None,
classify=False,
n_bins=256,
bins_strategy="UD",
@@ -86,7 +86,7 @@ def __init__(
self.align_corners = align_corners
self.min_depth = min_depth
self.max_depth = max_depth
self.norm_cfg = norm_cfg
self.norm_layer = norm_layer
self.classify = classify
self.n_bins = n_bins
self.scale_up = scale_up
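Note on the change above: the three hunks replace the mmcv-style norm_cfg config dict with a torch norm_layer argument on DepthBaseDecodeHead. A minimal hedged sketch of what this means for a caller, assuming norm_layer is a torch norm class such as nn.BatchNorm2d (the concrete call sites are not part of this diff):

    import torch
    import torch.nn as nn

    # Before (mmcv-style): the norm layer was described by a config dict,
    # e.g. norm_cfg=dict(type="BN"), and built later by mmcv helpers.
    # After this commit: a torch class/callable is passed directly, and the
    # diff below constructs it via partial(norm_layer, num_features=...).
    norm_layer = nn.BatchNorm2d          # passed where norm_cfg used to go
    norm = norm_layer(num_features=16)   # equivalent to nn.BatchNorm2d(16)
    print(norm(torch.randn(2, 16, 8, 8)).shape)  # torch.Size([2, 16, 8, 8])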
@@ -326,11 +326,11 @@ class ConvModule(nn.Module):
groups (int): Number of blocked connections from input channels to
output channels. Same as that in ``nn._ConvNd``.
bias (bool | str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
False. Default: "auto".
conv_cfg (dict): Config dict for convolution layer. Default: None,
conv_layer (nn.Module): Convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
norm_layer (nn.Module): Normalization layer. Default: None.
act_layer (nn.Module): Activation layer. Default: nn.ReLU.
inplace (bool): Whether to use inplace mode for activation.
Default: True.
@@ -359,20 +359,18 @@ def __init__(
dilation=1,
groups=1,
bias="auto",
conv_cfg=None,
norm_cfg=None,
conv_layer=nn.Conv2d,
norm_layer=None,
act_layer=nn.ReLU,
inplace=True,
with_spectral_norm=False,
padding_mode="zeros",
order=("conv", "norm", "act"),
):
super(ConvModule, self).__init__()
assert conv_cfg is None or isinstance(conv_cfg, dict)
assert norm_cfg is None or isinstance(norm_cfg, dict)
official_padding_mode = ["zeros", "circular"]
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.conv_layer = conv_layer
self.norm_layer = norm_layer
self.act_layer = act_layer
self.inplace = inplace
self.with_spectral_norm = with_spectral_norm
@@ -381,21 +379,24 @@ def __init__(
assert isinstance(self.order, tuple) and len(self.order) == 3
assert set(order) == set(["conv", "norm", "act"])

self.with_norm = norm_cfg is not None
self.with_norm = norm_layer is not None
self.with_activation = act_layer is not None
# if the conv layer is before a norm layer, bias is unnecessary.
if bias == "auto":
bias = not self.with_norm
self.with_bias = bias

if self.with_explicit_padding:
pad_cfg = dict(type=padding_mode)
self.padding_layer = build_padding_layer(pad_cfg, padding)
if padding_mode == "zeros":
padding_layer = nn.ZeroPad2d
else:
raise AssertionError(f"Unsupported padding mode: {padding_mode}")
self.pad = padding_layer(padding)

# reset padding to 0 for conv module
conv_padding = 0 if self.with_explicit_padding else padding
# build convolution layer
self.conv = nn.Conv2d(
self.conv = self.conv_layer(
in_channels,
out_channels,
kernel_size,
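Aside on the explicit-padding hunk above: when padding is handled explicitly, it is applied by a separate nn.ZeroPad2d module and the convolution is built with padding=0. A hedged, self-contained sketch (channel and tensor sizes are illustrative, not from the diff):

    import torch
    import torch.nn as nn

    pad = nn.ZeroPad2d(1)                                          # explicit padding layer
    conv = nn.Conv2d(8, 16, kernel_size=3, padding=0, bias=False)  # conv padding reset to 0
    x = torch.randn(1, 8, 32, 32)
    y = conv(pad(x))                                               # pad first, then convolve
    print(y.shape)                                                 # torch.Size([1, 16, 32, 32])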
@@ -426,9 +427,12 @@ def __init__(
norm_channels = out_channels
else:
norm_channels = in_channels
self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
self.add_module(self.norm_name, norm)
norm = partial(norm_layer, num_features=norm_channels)
self.add_module("norm", norm)
if self.with_bias:
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm

if isinstance(norm, (_BatchNorm, _InstanceNorm)):
warnings.warn("Unnecessary conv bias before batch/instance norm")
else:
@@ -483,7 +487,7 @@ def forward(self, x, activate=True, norm=True):
for layer in self.order:
if layer == "conv":
if self.with_explicit_padding:
x = self.padding_layer(x)
x = self.pad(x)
x = self.conv(x)
elif layer == "norm" and norm and self.with_norm:
x = self.norm(x)
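The forward hunk above dispatches on the order tuple. A hedged standalone sketch of the same idea (not the repository's code), using the pre-activation ordering ("act", "conv", "norm") that PreActResidualConvUnit passes further down in this diff:

    import torch
    import torch.nn as nn

    layers = {
        "conv": nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False),
        "norm": nn.BatchNorm2d(8),
        "act": nn.ReLU(inplace=True),
    }
    x = torch.randn(1, 8, 16, 16)
    for name in ("act", "conv", "norm"):  # pre-activation ordering
        x = layers[name](x)
    print(x.shape)                        # torch.Size([1, 8, 16, 16])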
@@ -598,12 +602,12 @@ class PreActResidualConvUnit(nn.Module):
Args:
in_channels (int): number of channels in the input feature map.
act_layer (nn.Module): activation layer.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_layer (nn.Module): norm layer.
stride (int): stride of the first block. Default: 1
dilation (int): dilation rate for convs layers. Default: 1.
"""

def __init__(self, in_channels, act_layer, norm_cfg, stride=1, dilation=1):
def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
super(PreActResidualConvUnit, self).__init__()

self.conv1 = ConvModule(
@@ -613,7 +617,7 @@ def __init__(self, in_channels, act_layer, norm_cfg, stride=1, dilation=1):
stride=stride,
padding=dilation,
dilation=dilation,
norm_cfg=norm_cfg,
norm_layer=norm_layer,
act_layer=act_layer,
bias=False,
order=("act", "conv", "norm"),
@@ -624,7 +628,7 @@ def __init__(self, in_channels, act_layer, norm_cfg, stride=1, dilation=1):
in_channels,
3,
padding=1,
norm_cfg=norm_cfg,
norm_layer=norm_layer,
act_layer=act_layer,
bias=False,
order=("act", "conv", "norm"),
@@ -642,14 +646,14 @@ class FeatureFusionBlock(nn.Module):
Args:
in_channels (int): Input channels.
act_layer (nn.Module): activation layer for ResidualConvUnit.
norm_cfg (dict): Config dict for normalization layer.
norm_layer (nn.Module): normalization layer.
expand (bool): Whether expand the channels in post process block.
Default: False.
align_corners (bool): align_corner setting for bilinear upsample.
Default: True.
"""

def __init__(self, in_channels, act_layer, norm_cfg, expand=False, align_corners=True):
def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
super(FeatureFusionBlock, self).__init__()

self.in_channels = in_channels
@@ -663,10 +667,10 @@ def __init__(self, in_channels, act_layer, norm_cfg, expand=False, align_corners
self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

self.res_conv_unit1 = PreActResidualConvUnit(
in_channels=self.in_channels, act_layer=act_layer, norm_cfg=norm_cfg
in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
)
self.res_conv_unit2 = PreActResidualConvUnit(
in_channels=self.in_channels, act_layer=act_layer, norm_cfg=norm_cfg
in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
)

def forward(self, *inputs):
@@ -704,7 +708,7 @@ def __init__(
readout_type="ignore",
patch_size=16,
expand_channels=False,
**kwargs
**kwargs,
):
super(DPTHead, self).__init__(**kwargs)

@@ -720,9 +724,9 @@ def __init__(
self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
self.fusion_blocks = nn.ModuleList()
for _ in range(len(self.convs)):
self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_cfg))
self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
self.fusion_blocks[0].res_conv_unit1 = None
self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
self.num_fusion_blocks = len(self.fusion_blocks)
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
self.num_post_process_channels = len(self.post_process_channels)
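For context, a hedged usage sketch of the refactored ConvModule after this commit, mirroring the call pattern DPTHead uses above (act_layer=None, bias=False, norm_layer left at its default of None); the import path follows the changed file and the tensor sizes are illustrative:

    import torch
    from dinov2.hub.depth.decode_heads import ConvModule

    m = ConvModule(8, 16, kernel_size=3, padding=1, act_layer=None, bias=False)
    y = m(torch.randn(1, 8, 32, 32))
    print(y.shape)  # expected: torch.Size([1, 16, 32, 32])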
