diff --git a/configs/config_disp.py b/configs/config_disp.py index 980b8ff..b2c8040 100644 --- a/configs/config_disp.py +++ b/configs/config_disp.py @@ -10,7 +10,7 @@ #------------- disparity ---------------# -cfg.model = 'stereonet' # ['stereonet', 'activestereonet', 'hitnet', 'sos'] +cfg.model = 'hitnet' # ['stereonet', 'activestereonet', 'hitnet', 'sos'] cfg.maxdisp = 192 cfg.mindisp = 0 cfg.loss_disp = True diff --git a/tools/train_net_disp.py b/tools/train_net_disp.py index 9d1c4f9..6039e5c 100644 --- a/tools/train_net_disp.py +++ b/tools/train_net_disp.py @@ -41,6 +41,7 @@ def get_parser(): help='datapath') parser.add_argument('--datapath12', default='./data/kitti2012/training/', help='datapath') + parser.add_argument('--datapath_sf_clean', default='', help='datapath') parser.add_argument('--epochs', type=int, default=4200, help='number of epochs to train') parser.add_argument('--loadmodel', default=None, help='load model') parser.add_argument('--savemodel', default=None, help='save model') @@ -168,30 +169,46 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp): model = model.cuda(gpu) #------------------- Data Loader ----------------------- - # all_left_img, all_right_img, all_left_disp, = ls.dataloader(args.data_path, - # args.split_file, - # depth_disp=True, - # cfg=cfg, - # is_train=True) - + #kitti train_left_img, train_right_img, train_left_disp,train_left_norm, test_left_img, test_right_img, test_left_disp, test_left_norm = ls_2015.dataloader(args.datapath) train_left_img12, train_right_img12, train_left_disp12,train_left_norm12, test_left_img12, test_right_img12, test_left_disp12, test_left_norm12 = ls_2012.dataloader(args.datapath12) + ImageFloader = DA.myImageFloder(train_left_img+train_left_img12, train_right_img+train_right_img12, train_left_disp+train_left_disp12,train_left_norm+train_left_norm12, True) + + ImageFloader_Test = DA.myImageFloder(test_left_img+test_left_img12, test_right_img+test_right_img12, test_left_disp+test_left_disp12,test_left_norm+test_left_norm12, True) + + #sceneflow + # train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = lt.dataloader( + # args.datapath) + # train_left_img.sort() + # train_right_img.sort() + # train_left_disp.sort() + + # test_left_img.sort() + # test_right_img.sort() + # test_left_disp.sort() + + + # __normalize = {'mean': [0.0, 0.0, 0.0], 'std': [1.0, 1.0, 1.0]} + # TrainImgLoader = torch.utils.data.DataLoader( + # DA.myImageFloder(train_left_img, train_right_img, train_left_disp, True, normalize=__normalize), + # batch_size=args.train_bsize, shuffle=False, num_workers=1, drop_last=False) + + # TestImgLoader = torch.utils.data.DataLoader( + # DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False, normalize=__normalize), + # batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False) + - # ImageFloader = DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True, split=args.split_file, cfg=cfg) - # ImageFloader = torch.utils.data.DataLoader( - # DA.myImageFloder(train_left_img+train_left_img12, train_right_img+train_right_img12, train_left_disp+train_left_disp12,train_left_norm+train_left_norm12, True), - # batch_size=args.btrain, shuffle=True, num_workers=2, drop_last=False,pin_memory=True) - ImageFloader = DA.myImageFloder(train_left_img+train_left_img12, train_right_img+train_right_img12, train_left_disp+train_left_disp12,train_left_norm+train_left_norm12, True) - if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(ImageFloader) + test_sampler = torch.utils.data.distributed.DistributedSampler(ImageFloader_Test) else: train_sampler = None + test_sampler = None TrainImgLoader = torch.utils.data.DataLoader( ImageFloader, @@ -199,6 +216,13 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp): collate_fn=BatchCollator(cfg), sampler=train_sampler) + + TestImgLoader = torch.utils.data.DataLoader( + ImageFloader_Test, + batch_size=8, shuffle=(test_sampler is None), num_workers=args.workers, drop_last=True, + collate_fn=BatchCollator(cfg), + sampler=test_sampler) + args.max_warmup_step = min(len(TrainImgLoader), 500) #------------------ Logger ------------------------------------- @@ -235,25 +259,53 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp): for epoch in range(args.start_epoch, args.epochs + 1): if args.distributed: train_sampler.set_epoch(epoch) + test_sampler.set_epoch(epoch) total_train_loss = 0 adjust_learning_rate(optimizer, epoch, args=args) + # test have some speed bug + if epoch%1 == 0: + EPEs = AverageMeter() + Bad3s = AverageMeter() + for batch_idx, data_batch in enumerate(TestImgLoader): + output6, disp_L = test(model, **data_batch) + + output6 = output6.unsqueeze(1) + # print('********') + # three_pixel = cal_3pixel_error(output6, disp_L) + # Bad3s.update(three_pixel) + mask = (disp_L<196) & (disp_L>0) + if mask.sum()>0: + EPEs.update((output6[mask] - disp_L[mask]).abs().mean().item()) + # print() + # if batch_idx % 2 ==0: + # info_str = 'EPE {} = {:.2f} Bads= {:.4f}'.format(batch_idx, EPEs.avg, Bad3s.avg) + # print(info_str) + info_str = 'EPE {} = {:.2f} '.format(batch_idx, EPEs.avg) + if main_process(args): + logger.info(info_str) + + + + + for batch_idx, data_batch in enumerate(TrainImgLoader): start_time = time.time() if epoch == 1 and cfg.warmup and batch_idx < args.max_warmup_step: adjust_learning_rate(optimizer, epoch, batch_idx, args=args) - losses = train(model, cfg, args, optimizer, **data_batch) + losses, info_S = train(model, epoch, cfg, args, optimizer, **data_batch) loss = losses.pop('loss') if main_process(args): + logger.info(info_S) logger.info('%s: %s' % (args.savemodel.strip('/').split('/')[-1], args.devices)) logger.info('Epoch %d Iter %d/%d training loss = %.3f , time = %.2f; Epoch time: %.3fs, Left time: %.3fs, lr: %.6f' % ( epoch, batch_idx, len(TrainImgLoader), loss, time.time() - start_time, (time.time() - start_time) * len(TrainImgLoader), (time.time() - start_time) * (len(TrainImgLoader) * (args.epochs - epoch) - batch_idx), optimizer.param_groups[0]["lr"]) ) - logger.info('losses: {}'.format(list(losses.items()))) + # logger.info('losses: {}'.format(list(losses.items()))) for lk, lv in losses.items(): writer.add_scalar(lk, lv, epoch * len(TrainImgLoader) + batch_idx) total_train_loss += loss @@ -261,6 +313,7 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp): if batch_idx == 100 and cfg.debug: break + if main_process(args): logger.info('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(TrainImgLoader))) savefilename = args.savemodel + '/finetune_' + str(epoch) + '.tar' @@ -271,9 +324,29 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp): 'optimizer': optimizer.state_dict() }, savefilename) logger.info('Snapshot {} epoch in {}'.format(epoch, args.savemodel)) - - -def train(model, cfg, args, optimizer, imgL, imgR, disp_L, norm_L, + + # test have some speed bug + # if epoch%1 == 0: + # EPEs = AverageMeter() + # Bad3s = AverageMeter() + # for batch_idx, data_batch in enumerate(TestImgLoader): + # output6, disp_L = test(model, **data_batch) + + # output6 = output6.unsqueeze(1) + # three_pixel = cal_3pixel_error(output6, disp_L) + # Bad3s.update(three_pixel) + # mask = (disp_L<196) & (disp_L>0) + # if mask.sum()>0: + # EPEs.update((output6[mask] - disp_L[mask]).abs().mean().item()) + # if batch_idx % 2 ==0: + # info_str = 'EPE {} = {:.2f} Bads= {:.4f}'.format(batch_idx, EPEs.avg, Bad3s.avg) + # print(info_str) + + # if main_process(args): + # logger.info(info_str) + + +def train(model, epoch, cfg, args, optimizer, imgL, imgR, disp_L, norm_L, calib=None, calib_R=None, image_indexes=None, targets=None, ious=None, labels_map=None): get_loss= My_loss(10, 5, 2, 3) model.train() @@ -295,7 +368,8 @@ def train(model, cfg, args, optimizer, imgL, imgR, disp_L, norm_L, # outputs = model(imgL, imgR, disp_L) if cfg.model == 'hitnet': out, h_new, w = model(imgL, imgR, disp_true) - loss = get_loss(out, h_new, w, imgL, disp_true.squeeze(1), norm_true) + loss, info_S = get_loss(out, h_new, w, imgL, disp_true.squeeze(1), norm_true) + if cfg.model == 'stereonet': outputs = model(imgL, imgR) @@ -336,8 +410,70 @@ def train(model, cfg, args, optimizer, imgL, imgR, disp_L, norm_L, reduced_losses = {k: v.item() for k, v in zip(loss_names, all_losses)} else: reduced_losses = {k: v.item() for k, v in losses.items()} + if not info_S: + info_S = 'None' + + return reduced_losses, info_S + + + +def cal_3pixel_error(pred_disp, disp_true): + pred_disp = pred_disp.data.cpu() + disp_true = disp_true.cpu() + #computing 3-px error# + true_disp = disp_true.clone() + index = np.argwhere(true_disp>0) + disp_true[index[0][:], index[1][:], index[2][:]] = np.abs(true_disp[index[0][:], index[1][:], index[2][:]]-pred_disp[index[0][:], index[1][:], index[2][:]]) + correct1 = (disp_true[index[0][:], index[1][:], index[2][:]] < 1)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) + correct2 = (disp_true[index[0][:], index[1][:], index[2][:]] < 2)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) + # correct3 = (disp_true[index[0][:], index[1][:], index[2][:]] < 3)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) + correct3 = (disp_true[index[0][:], index[1][:], index[2][:]] < 3) + correct4 = (disp_true[index[0][:], index[1][:], index[2][:]] < 4)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) + correct5 = (disp_true[index[0][:], index[1][:], index[2][:]] < 5)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) + # return 1-(float(torch.sum(correct1))/float(len(index[0]))),1-(float(torch.sum(correct2))/float(len(index[0]))),\ + # 1-(float(torch.sum(correct3))/float(len(index[0]))),1-(float(torch.sum(correct4))/float(len(index[0]))),\ + # 1-(float(torch.sum(correct5))/float(len(index[0]))) + return 1-(float(torch.sum(correct3))/float(len(index[0]))) + + +def test(model, imgL, imgR, disp_L, norm_L): + + # EPEs = AverageMeter() + # Bad3s = AverageMeter() + model.eval() + + imgL = imgL.float().cuda() + imgR = imgR.float().cuda() + disp_L = disp_L.float().cuda() + norm_L = norm_L.float().cuda() + + with torch.no_grad(): + out,h_new,w= model(imgL, imgR,disp_L) + dx = h_new[7][:, 1, :, :] + dy = h_new[7][:, 2, :, :] + output6 = to_plane(h_new[7][:, 0, :, :], dx=dx, dy=dy, size=1).float() + output6 = torch.unsqueeze(output6,dim=1) + output6 = F.relu(torch.squeeze(output6,dim=1)) + return output6, disp_L + + # three_pixel = cal_3pixel_error(output6, disp_L) + # Bad3s.update(three_pixel) + + # mask = (disp_L<196) & (disp_L>0) + # if mask.sum()>0: + # EPEs.update((output6[mask] - disp_L[mask]).abs().mean().item()) + # if batch_idx % 2 ==0: + # # info_str = 'EPE {} = {:.2f} Bad3s {} = {:.2f}'.format(batch_idx, EPEs.avg, Bad3s.avg) + # info_str = 'EPE {} = {:.2f} Bads= {:.4f}'.format(batch_idx, EPEs.avg, Bad3s.avg) + # # log.info(info_str) + # return info_str + + + # if batch_idx == 0: + # savename_0 = '/media/elonli/365C51B15C516C9D/HIT_new_5_15/result/' + str(epoch)+ str(batch_idx) + '.png' + # saveimage_color(output6,savename_0,disp_L[mask].max()) + - return reduced_losses class BatchCollator(object): def __init__(self, cfg): @@ -376,6 +512,43 @@ def adjust_learning_rate(optimizer, epoch, step=None, args=None): for param_group in optimizer.param_groups: param_group['lr'] = lr +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def to_plane(d,dx,dy,size=4): + #equation (5) in original paper + c = np.linspace(-(size - 1) / 2, (size - 1) / 2, size, endpoint=True) + c = torch.tensor(c).cuda() + + a = c.view([1, 1, size]) + a = torch.unsqueeze(a.repeat(1, d.shape[1] * size, d.shape[2]),dim=1).float() + + b = c.view([1,size,1]) + b = torch.unsqueeze(b.repeat(1, d.shape[1], d.shape[2]*size), dim=1).float() + + d_float = d.float() + d_4 = F.interpolate(torch.unsqueeze(d_float,dim=1),scale_factor=size,mode='nearest') + dx_4 = F.interpolate(torch.unsqueeze(dx,dim=1),scale_factor=size,mode='nearest') + dy_4 = F.interpolate(torch.unsqueeze(dy,dim=1),scale_factor=size,mode='nearest') + d_plane = d_4 + a*dx_4 + b*dy_4 + d_plane = torch.squeeze(d_plane,dim=1) + return d_plane + if __name__ == '__main__': main()