Implementation of ReprojectedDistanceLoss #239

Open · wants to merge 1 commit into base: master
96 changes: 95 additions & 1 deletion packnet_sfm/losses/supervised_loss.py
@@ -6,8 +6,102 @@
from packnet_sfm.utils.image import match_scales
from packnet_sfm.losses.loss_base import LossBase, ProgressiveScaling

from packnet_sfm.utils.depth import inv2depth
from packnet_sfm.utils.image import image_grid
########################################################################################################################

class ReprojectedDistanceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def warp(self, depth, K, Kinv, T):
        """
        Warps pixels from one image plane to another one.

        Parameters
        ----------
        depth : torch.Tensor [B,1,H,W]
            Depth map for the camera
        K : torch.Tensor [B,3,3]
            Camera intrinsics matrix
        Kinv : torch.Tensor [B,3,3]
            Inverse of the camera intrinsics matrix
        T : torch.Tensor [B,4,4]
            Relative pose transformation

        Returns
        -------
        points : torch.Tensor [B,H,W,2]
            2D projected points in normalized [-1, 1] image coordinates
        """
        B, _, H, W = depth.shape

        # Create flat homogeneous pixel grid
        grid = image_grid(B, H, W, depth.dtype, depth.device, normalized=False)  # [B,3,H,W]
        flat_grid = grid.view(B, 3, -1)  # [B,3,HW]

        # Estimate the outward rays in the camera frame
        xnorm = (Kinv.bmm(flat_grid)).view(B, 3, H, W)
        # Scale rays to metric depth
        Xc = xnorm * depth

        # Transform the 3D points into the other camera frame (rotation + translation)
        Xs = T[:, :3, :3].bmm(Xc.view(B, 3, -1)) + T[:, :3, -1:]
        # Project onto the image plane
        Xs = K.bmm(Xs)
        # Normalize pixel coordinates to [-1, 1] (grid_sample convention)
        X = Xs[:, 0]
        Y = Xs[:, 1]
        Z = Xs[:, 2].clamp(min=1e-5)
        Xnorm = 2 * (X / Z) / (W - 1) - 1.
        Ynorm = 2 * (Y / Z) / (H - 1) - 1.

        return torch.stack([Xnorm, Ynorm], dim=-1).view(B, H, W, 2)

    def forward(self, inv_depth_gt, inv_depth_est, K, Kinv, T_ts):
        """
        Calculates the reprojected distance loss.

        Parameters
        ----------
        inv_depth_gt : torch.Tensor [B,1,H,W]
            Ground-truth inverse depth map
        inv_depth_est : torch.Tensor [B,1,H,W]
            Predicted inverse depth map
        K : torch.Tensor [B,3,3]
            Camera intrinsics matrix
        Kinv : torch.Tensor [B,3,3]
            Inverse of the camera intrinsics matrix
        T_ts : torch.Tensor [B,4,4]
            Relative pose transformation from the source image camera frame
            to the target image camera frame

        Returns
        -------
        loss : torch.Tensor [1]
            Reprojected distance loss
        """
        depth_gt = inv2depth(inv_depth_gt)
        depth_est = inv2depth(inv_depth_est)

        # Convert the pose from source-to-target into target-to-source
        T_st = torch.linalg.inv(T_ts)
        # Warp pixels from the target frame to the source frame
        p_gt = self.warp(depth_gt, K, Kinv, T_st)
        p_est = self.warp(depth_est, K, Kinv, T_st)

        # Mean L1 distance between ground-truth and predicted reprojections
        loss_tmp = (p_est - p_gt).abs()
        return loss_tmp.mean()



class BerHuLoss(nn.Module):
    """Class implementing the BerHu loss."""
    def __init__(self, threshold=0.2):
@@ -179,4 +273,4 @@ def forward(self, inv_depths, gt_inv_depth,
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
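
For reviewers who want to exercise the new loss in isolation, here is a minimal sketch using dummy tensors. The batch size, resolution, focal length, and identity relative pose below are assumptions chosen for illustration only; `ReprojectedDistanceLoss` is the only piece taken from this PR.

import torch
from packnet_sfm.losses.supervised_loss import ReprojectedDistanceLoss

B, H, W = 2, 192, 640  # assumed batch size and image resolution

# Dummy inverse depth maps (strictly positive so inv2depth stays well-behaved)
inv_depth_gt = torch.rand(B, 1, H, W) * 0.9 + 0.1
inv_depth_est = torch.rand(B, 1, H, W) * 0.9 + 0.1

# Assumed pinhole intrinsics: focal length 700, principal point at the center
K = torch.tensor([[700., 0., W / 2.],
                  [0., 700., H / 2.],
                  [0., 0., 1.]]).expand(B, 3, 3)
Kinv = torch.linalg.inv(K)

# Identity relative pose: both camera frames coincide
T_ts = torch.eye(4).expand(B, 4, 4)

loss_fn = ReprojectedDistanceLoss()
loss = loss_fn(inv_depth_gt, inv_depth_est, K, Kinv, T_ts)
print(loss)  # scalar tensor: mean L1 distance between reprojections

With an identity relative pose, back-projecting a pixel and re-projecting it lands on the same pixel regardless of depth, so the loss should come out near zero here; substituting a non-trivial T_ts makes the loss sensitive to depth errors, which makes this a quick sanity check for the implementation.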