Skip to content

Commit

Permalink
Fix depth range for KITTI
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick Labatut committed Sep 27, 2023
1 parent 7e40f8b commit 5b0e833
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions dinov2/hub/depthers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from enum import Enum
from functools import partial
from typing import Union
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -142,15 +142,15 @@ def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union
)


def _make_dinov2_dpt_depth_head(*, embed_dim: int = 1024):
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
return DPTHead(
in_channels=[embed_dim] * 4,
channels=256,
embed_dims=embed_dim,
post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
readout_type="project",
min_depth=0.001,
max_depth=10,
min_depth=min_depth,
max_depth=max_depth,
loss_decode=(),
)

Expand All @@ -160,6 +160,8 @@ def _make_dinov2_dpt_depther(
arch_name: str = "vit_large",
pretrained: bool = True,
weights: Union[Weights, str] = Weights.NYU,
min_depth: Optional[float] = None,
max_depth: Optional[float] = None,
**kwargs,
):
if isinstance(weights, str):
Expand All @@ -168,10 +170,25 @@ def _make_dinov2_dpt_depther(
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")

if max_depth is None:
if pretrained: # Set according to training dataset
if weights == Weights.KITTI:
min_depth = 0.001
max_depth = 80.0
elif weights == Weights.NYU:
min_depth = 0.001
max_depth = 10.0
else:
min_depth = 0.001
max_depth = 10.0
else:
min_depth = 0.001
max_depth = 10.0

backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim)
dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

out_index = {
"vit_small": [2, 5, 8, 11],
Expand Down

0 comments on commit 5b0e833

Please sign in to comment.