class ReprojectedDistanceLoss(nn.Module):
    """Reprojected distance loss between predicted and ground-truth depth.

    Each pixel of the target image is lifted to 3D twice -- once with the
    predicted depth and once with the ground-truth depth -- and both 3D points
    are projected into the source camera. The loss is the mean absolute
    difference between the two sets of normalized projected coordinates.

    Relies on ``image_grid`` (packnet_sfm.utils.image) and ``inv2depth``
    (packnet_sfm.utils.depth), imported at file level.
    """

    def __init__(self):
        super().__init__()

    def warp(self, depth, K, Kinv, T):
        """
        Warps pixels from one image plane to another one.

        Parameters
        ----------
        depth : torch.Tensor [B,1,H,W]
            Depth map for the camera
        K : torch.Tensor [B,3,3]
            Intrinsics matrix (applied with ``bmm``, so it must be batched)
        Kinv : torch.Tensor [B,3,3]
            Inverse of the intrinsics matrix
        T : torch.Tensor [B,4,4]
            Relative pose as a homogeneous transform; only its top 3x4 part
            (rotation and translation) is used

        Returns
        -------
        points : torch.Tensor [B,H,W,2]
            Projected 2D points, scaled so that pixel 0 maps to -1 and pixel
            W-1 (resp. H-1) maps to +1. Points are NOT clipped: projections
            that land outside the image have coordinates outside [-1, 1].
        """
        B, _, H, W = depth.shape

        # Homogeneous pixel coordinates (u, v, 1) for every pixel
        grid = image_grid(B, H, W, depth.dtype, depth.device, normalized=False)  # [B,3,H,W]
        flat_grid = grid.reshape(B, 3, -1)  # [B,3,HW]

        # Outward rays in the camera frame, using the intrinsics passed in
        # (the module itself stores no K / Kinv attributes)
        rays = Kinv.bmm(flat_grid)  # [B,3,HW]
        # Scale rays to metric depth
        Xc = rays * depth.reshape(B, 1, -1)  # [B,3,HW]

        # Rigid transform into the other camera frame: X' = R X + t.
        # Points are kept flattened as [B,3,HW] so that the batched matrix
        # product acts on the 3-vector axis rather than on the (H, W) axes.
        R = T[:, :3, :3]  # [B,3,3]
        t = T[:, :3, 3:]  # [B,3,1], broadcast over all HW points
        Xs = K.bmm(R.bmm(Xc) + t)  # [B,3,HW]

        # Perspective division; Z is clamped to avoid dividing by ~0 for
        # points that fall on (or behind) the camera plane
        Z = Xs[:, 2].clamp(min=1e-5)
        Xnorm = 2 * (Xs[:, 0] / Z) / (W - 1) - 1.
        Ynorm = 2 * (Xs[:, 1] / Z) / (H - 1) - 1.

        return torch.stack([Xnorm, Ynorm], dim=-1).view(B, H, W, 2)

    def forward(self, inv_depth_gt, inv_depth_est, K, Kinv, T_ts):
        """
        Calculates the reprojected distance loss.

        Parameters
        ----------
        inv_depth_gt : torch.Tensor [B,1,H,W]
            Ground-truth inverse depth map
        inv_depth_est : torch.Tensor [B,1,H,W]
            Predicted inverse depth map
        K : torch.Tensor [B,3,3]
            Intrinsics matrix
        Kinv : torch.Tensor [B,3,3]
            Inverse of the intrinsics matrix
        T_ts : torch.Tensor [B,4,4]
            Relative pose transformation from the source image camera frame
            to the target image camera frame

        Returns
        -------
        loss : torch.Tensor
            Scalar (0-dim) reprojected distance loss
        """
        depth_gt = inv2depth(inv_depth_gt)
        depth_est = inv2depth(inv_depth_est)

        # Pixels are warped from the target frame to the source frame, so the
        # source-to-target pose has to be inverted into target-to-source
        T_st = torch.linalg.inv(T_ts)
        p_gt = self.warp(depth_gt, K, Kinv, T_st)
        p_est = self.warp(depth_est, K, Kinv, T_st)

        # NOTE(review): every pixel contributes to the mean; if the ground
        # truth is sparse, invalid pixels presumably need to be masked by the
        # caller -- confirm against how this loss is invoked.
        return (p_est - p_gt).abs().mean()