Skip to content

Commit

Permalink
Extract depth range automagic setter
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick Labatut committed Sep 27, 2023
1 parent 5b0e833 commit 53eeb55
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions dinov2/hub/depthers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from enum import Enum
from functools import partial
from typing import Optional, Union
from typing import Optional, Tuple, Union

import torch

Expand All @@ -19,10 +19,26 @@ class Weights(Enum):
KITTI = "KITTI"


def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
if not pretrained: # Default
return (0.001, 10.0)

# Pretrained, set according to the training dataset for the provided weights
if weights == Weights.KITTI:
return (0.001, 80.0)

if weights == Weights.NYU:
return (0.001, 10.0)

return (0.001, 10.0)


def _make_dinov2_linear_depth_head(
*,
embed_dim: int = 1024,
layers: int = 4,
embed_dim: int,
layers: int,
min_depth: float,
max_depth: float,
**kwargs,
):
if layers not in (1, 4):
Expand All @@ -46,7 +62,7 @@ def _make_dinov2_linear_depth_head(
channels=embed_dim * len(in_index) * 2,
align_corners=False,
min_depth=0.001,
max_depth=10,
max_depth=80,
loss_decode=(),
)

Expand All @@ -57,6 +73,7 @@ def _make_dinov2_linear_depther(
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.NYU,
depth_range: Optional[Tuple[float, float]] = None,
**kwargs,
):
if layers not in (1, 4):
Expand All @@ -67,15 +84,20 @@ def _make_dinov2_linear_depther(
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")

if depth_range is None:
depth_range = _get_depth_range(pretrained, weights)
min_depth, max_depth = depth_range

backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

embed_dim = backbone.embed_dim
patch_size = backbone.patch_size
model_name = _make_dinov2_model_name(arch_name, patch_size)
linear_depth_head = _make_dinov2_linear_depth_head(
arch_name=arch_name,
embed_dim=embed_dim,
layers=layers,
min_depth=min_depth,
max_depth=max_depth,
)

layer_count = {
Expand Down Expand Up @@ -160,8 +182,7 @@ def _make_dinov2_dpt_depther(
arch_name: str = "vit_large",
pretrained: bool = True,
weights: Union[Weights, str] = Weights.NYU,
min_depth: Optional[float] = None,
max_depth: Optional[float] = None,
depth_range: Optional[Tuple[float, float]] = None,
**kwargs,
):
if isinstance(weights, str):
Expand All @@ -170,20 +191,9 @@ def _make_dinov2_dpt_depther(
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")

if max_depth is None:
if pretrained: # Set according to training dataset
if weights == Weights.KITTI:
min_depth = 0.001
max_depth = 80.0
elif weights == Weights.NYU:
min_depth = 0.001
max_depth = 10.0
else:
min_depth = 0.001
max_depth = 10.0
else:
min_depth = 0.001
max_depth = 10.0
if depth_range is None:
depth_range = _get_depth_range(pretrained, weights)
min_depth, max_depth = depth_range

backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

Expand Down

0 comments on commit 53eeb55

Please sign in to comment.