fave fun #40

Open · wants to merge 3 commits into master
2 changes: 1 addition & 1 deletion configs/config_disp.py
@@ -10,7 +10,7 @@


#------------- disparity ---------------#
-cfg.model = 'stereonet' # ['stereonet', 'activestereonet', 'hitnet', 'sos']
+cfg.model = 'hitnet' # ['stereonet', 'activestereonet', 'hitnet', 'sos']
cfg.maxdisp = 192
cfg.mindisp = 0
cfg.loss_disp = True
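A minimal sketch of how a cfg.model key like this is typically consumed (an assumption for illustration only; the class names below are hypothetical placeholders, not this repository's actual API):

    # Hypothetical dispatch keyed on cfg.model; the constructors are placeholders.
    def build_model(cfg):
        builders = {
            'stereonet': StereoNet,
            'activestereonet': ActiveStereoNet,
            'hitnet': HITNet,
            'sos': SOSNet,
        }
        if cfg.model not in builders:
            raise ValueError('unknown cfg.model: %s' % cfg.model)
        return builders[cfg.model](maxdisp=cfg.maxdisp)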
211 changes: 192 additions & 19 deletions tools/train_net_disp.py
@@ -41,6 +41,7 @@ def get_parser():
help='datapath')
parser.add_argument('--datapath12', default='./data/kitti2012/training/',
help='datapath')
+parser.add_argument('--datapath_sf_clean', default='', help='datapath')
parser.add_argument('--epochs', type=int, default=4200, help='number of epochs to train')
parser.add_argument('--loadmodel', default=None, help='load model')
parser.add_argument('--savemodel', default=None, help='save model')
@@ -168,37 +169,60 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp):
model = model.cuda(gpu)

#------------------- Data Loader -----------------------
# all_left_img, all_right_img, all_left_disp, = ls.dataloader(args.data_path,
# args.split_file,
# depth_disp=True,
# cfg=cfg,
# is_train=True)

# KITTI 2015 + 2012 file lists
train_left_img, train_right_img, train_left_disp,train_left_norm, test_left_img, test_right_img, test_left_disp, test_left_norm = ls_2015.dataloader(args.datapath)
train_left_img12, train_right_img12, train_left_disp12,train_left_norm12, test_left_img12, test_right_img12, test_left_disp12, test_left_norm12 = ls_2012.dataloader(args.datapath12)

ImageFloader = DA.myImageFloder(train_left_img+train_left_img12, train_right_img+train_right_img12, train_left_disp+train_left_disp12,train_left_norm+train_left_norm12, True)

ImageFloader_Test = DA.myImageFloder(test_left_img+test_left_img12, test_right_img+test_right_img12, test_left_disp+test_left_disp12,test_left_norm+test_left_norm12, True)

# SceneFlow loading, kept commented out for reference
# train_left_img, train_right_img, train_left_disp, test_left_img, test_right_img, test_left_disp = lt.dataloader(
# args.datapath)
# train_left_img.sort()
# train_right_img.sort()
# train_left_disp.sort()

# test_left_img.sort()
# test_right_img.sort()
# test_left_disp.sort()


# __normalize = {'mean': [0.0, 0.0, 0.0], 'std': [1.0, 1.0, 1.0]}
# TrainImgLoader = torch.utils.data.DataLoader(
# DA.myImageFloder(train_left_img, train_right_img, train_left_disp, True, normalize=__normalize),
# batch_size=args.train_bsize, shuffle=False, num_workers=1, drop_last=False)

# TestImgLoader = torch.utils.data.DataLoader(
# DA.myImageFloder(test_left_img, test_right_img, test_left_disp, False, normalize=__normalize),
# batch_size=args.test_bsize, shuffle=False, num_workers=4, drop_last=False)



# ImageFloader = DA.myImageFloder(all_left_img, all_right_img, all_left_disp, True, split=args.split_file, cfg=cfg)

# ImageFloader = torch.utils.data.DataLoader(
# DA.myImageFloder(train_left_img+train_left_img12, train_right_img+train_right_img12, train_left_disp+train_left_disp12,train_left_norm+train_left_norm12, True),
# batch_size=args.btrain, shuffle=True, num_workers=2, drop_last=False,pin_memory=True)

if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(ImageFloader)
test_sampler = torch.utils.data.distributed.DistributedSampler(ImageFloader_Test)
else:
train_sampler = None
test_sampler = None

TrainImgLoader = torch.utils.data.DataLoader(
ImageFloader,
batch_size=args.btrain, shuffle=(train_sampler is None), num_workers=args.workers, drop_last=True,
collate_fn=BatchCollator(cfg),
sampler=train_sampler)


TestImgLoader = torch.utils.data.DataLoader(
ImageFloader_Test,
batch_size=8, shuffle=(test_sampler is None), num_workers=args.workers, drop_last=True,
collate_fn=BatchCollator(cfg),
sampler=test_sampler)
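# Note on the loaders above: DataLoader rejects shuffle=True when a sampler is
# supplied, hence shuffle=(sampler is None). Under distributed training each
# process draws a distinct shard, and sampler.set_epoch() (called at the top of
# the epoch loop below) re-seeds the shard shuffling every epoch.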

args.max_warmup_step = min(len(TrainImgLoader), 500)
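# Cap LR warm-up at 500 iterations (or one epoch, if shorter); the per-step
# adjust_learning_rate call during epoch 1 below consumes this value.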

#------------------ Logger -------------------------------------
@@ -235,32 +259,61 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp):
for epoch in range(args.start_epoch, args.epochs + 1):
if args.distributed:
train_sampler.set_epoch(epoch)
test_sampler.set_epoch(epoch)

total_train_loss = 0
adjust_learning_rate(optimizer, epoch, args=args)

# NOTE: this evaluation pass has a known speed issue
if epoch % 1 == 0:  # evaluate every epoch; raise the modulus to test less often
EPEs = AverageMeter()
Bad3s = AverageMeter()
for batch_idx, data_batch in enumerate(TestImgLoader):
output6, disp_L = test(model, **data_batch)

output6 = output6.unsqueeze(1)
# print('********')
# three_pixel = cal_3pixel_error(output6, disp_L)
# Bad3s.update(three_pixel)
mask = (disp_L<196) & (disp_L>0)
if mask.sum()>0:
EPEs.update((output6[mask] - disp_L[mask]).abs().mean().item())
# print()
# if batch_idx % 2 ==0:
# info_str = 'EPE {} = {:.2f} Bads= {:.4f}'.format(batch_idx, EPEs.avg, Bad3s.avg)
# print(info_str)
info_str = 'EPE {} = {:.2f} '.format(batch_idx, EPEs.avg)
if main_process(args):
logger.info(info_str)

for batch_idx, data_batch in enumerate(TrainImgLoader):
start_time = time.time()
if epoch == 1 and cfg.warmup and batch_idx < args.max_warmup_step:
adjust_learning_rate(optimizer, epoch, batch_idx, args=args)

-losses = train(model, cfg, args, optimizer, **data_batch)
+losses, info_S = train(model, epoch, cfg, args, optimizer, **data_batch)
loss = losses.pop('loss')

if main_process(args):
logger.info(info_S)
logger.info('%s: %s' % (args.savemodel.strip('/').split('/')[-1], args.devices))
logger.info('Epoch %d Iter %d/%d training loss = %.3f , time = %.2f; Epoch time: %.3fs, Left time: %.3fs, lr: %.6f' % (
epoch,
batch_idx, len(TrainImgLoader), loss, time.time() - start_time, (time.time() - start_time) * len(TrainImgLoader),
(time.time() - start_time) * (len(TrainImgLoader) * (args.epochs - epoch) - batch_idx), optimizer.param_groups[0]["lr"]) )
-logger.info('losses: {}'.format(list(losses.items())))
+# logger.info('losses: {}'.format(list(losses.items())))
for lk, lv in losses.items():
writer.add_scalar(lk, lv, epoch * len(TrainImgLoader) + batch_idx)
total_train_loss += loss

if batch_idx == 100 and cfg.debug:
break


if main_process(args):
logger.info('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(TrainImgLoader)))
savefilename = args.savemodel + '/finetune_' + str(epoch) + '.tar'
@@ -271,9 +324,29 @@ def main_worker(gpu, ngpus_per_node, args, cfg, exp):
'optimizer': optimizer.state_dict()
}, savefilename)
logger.info('Snapshot {} epoch in {}'.format(epoch, args.savemodel))
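# Resume sketch (an assumption: the checkpoint also stores the model weights
# under a 'state_dict' key, which the collapsed lines above elide):
#   ckpt = torch.load(savefilename)
#   model.load_state_dict(ckpt['state_dict'])
#   optimizer.load_state_dict(ckpt['optimizer'])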


-def train(model, cfg, args, optimizer, imgL, imgR, disp_L, norm_L,

# Earlier placement of the evaluation pass, kept commented out (known speed issue)
# if epoch%1 == 0:
# EPEs = AverageMeter()
# Bad3s = AverageMeter()
# for batch_idx, data_batch in enumerate(TestImgLoader):
# output6, disp_L = test(model, **data_batch)

# output6 = output6.unsqueeze(1)
# three_pixel = cal_3pixel_error(output6, disp_L)
# Bad3s.update(three_pixel)
# mask = (disp_L<196) & (disp_L>0)
# if mask.sum()>0:
# EPEs.update((output6[mask] - disp_L[mask]).abs().mean().item())
# if batch_idx % 2 ==0:
# info_str = 'EPE {} = {:.2f} Bads= {:.4f}'.format(batch_idx, EPEs.avg, Bad3s.avg)
# print(info_str)

# if main_process(args):
# logger.info(info_str)


+def train(model, epoch, cfg, args, optimizer, imgL, imgR, disp_L, norm_L,
calib=None, calib_R=None, image_indexes=None, targets=None, ious=None, labels_map=None):
get_loss = My_loss(10, 5, 2, 3)  # loss-weight hyper-parameters; their meaning is defined in My_loss (not shown in this diff)
model.train()
@@ -295,7 +368,8 @@ def train(model, cfg, args, optimizer, imgL, imgR, disp_L, norm_L,
# outputs = model(imgL, imgR, disp_L)
if cfg.model == 'hitnet':
out, h_new, w = model(imgL, imgR, disp_true)
-loss = get_loss(out, h_new, w, imgL, disp_true.squeeze(1), norm_true)
+loss, info_S = get_loss(out, h_new, w, imgL, disp_true.squeeze(1), norm_true)


if cfg.model == 'stereonet':
outputs = model(imgL, imgR)
@@ -336,8 +410,70 @@ def train(model, cfg, args, optimizer, imgL, imgR, disp_L, norm_L,
reduced_losses = {k: v.item() for k, v in zip(loss_names, all_losses)}
else:
reduced_losses = {k: v.item() for k, v in losses.items()}
if not info_S:
info_S = 'None'

return reduced_losses, info_S



def cal_3pixel_error(pred_disp, disp_true):
    pred_disp = pred_disp.data.cpu()
    disp_true = disp_true.cpu()
    # compute the KITTI bad-pixel rates over valid (disparity > 0) pixels
    true_disp = disp_true.clone()
    index = np.argwhere(true_disp > 0)
    disp_true[index[0][:], index[1][:], index[2][:]] = np.abs(
        true_disp[index[0][:], index[1][:], index[2][:]] - pred_disp[index[0][:], index[1][:], index[2][:]])
    err = disp_true[index[0][:], index[1][:], index[2][:]]  # absolute error at valid pixels
    gt = true_disp[index[0][:], index[1][:], index[2][:]]   # ground-truth disparity there
    correct1 = (err < 1) | (err < gt * 0.05)
    correct2 = (err < 2) | (err < gt * 0.05)
    correct3 = (err < 3)  # plain 3 px threshold, without the 5% relative term
    correct4 = (err < 4) | (err < gt * 0.05)
    correct5 = (err < 5) | (err < gt * 0.05)
    # only the 3 px rate is returned; the other thresholds are kept for reference
    return 1 - (float(torch.sum(correct3)) / float(len(index[0])))
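# A compact vectorized equivalent of the 3 px rate above: a sketch, assuming
# [B, H, W] torch tensors, with boolean masking replacing the argwhere indexing.
def bad3_rate(pred_disp, disp_true):
    valid = disp_true > 0  # KITTI marks missing ground truth with 0
    err = (disp_true[valid] - pred_disp[valid]).abs()
    return 1.0 - (err < 3).float().mean().item()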


def test(model, imgL, imgR, disp_L, norm_L):

# EPEs = AverageMeter()
# Bad3s = AverageMeter()
model.eval()

imgL = imgL.float().cuda()
imgR = imgR.float().cuda()
disp_L = disp_L.float().cuda()
norm_L = norm_L.float().cuda()

    with torch.no_grad():
        out, h_new, w = model(imgL, imgR, disp_L)
        dx = h_new[7][:, 1, :, :]
        dy = h_new[7][:, 2, :, :]
        output6 = to_plane(h_new[7][:, 0, :, :], dx=dx, dy=dy, size=1).float()
        output6 = torch.unsqueeze(output6, dim=1)
        output6 = F.relu(torch.squeeze(output6, dim=1))
return output6, disp_L
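# (Assumption about the model output: h_new[7] holds the finest-level
# hypothesis with channels [disparity, dx slant, dy slant]. With size=1 the
# to_plane call evaluates each plane only at its pixel center, so it returns
# the disparity channel unchanged; F.relu then clamps negatives to zero.)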

# three_pixel = cal_3pixel_error(output6, disp_L)
# Bad3s.update(three_pixel)

# mask = (disp_L<196) & (disp_L>0)
# if mask.sum()>0:
# EPEs.update((output6[mask] - disp_L[mask]).abs().mean().item())
# if batch_idx % 2 ==0:
# # info_str = 'EPE {} = {:.2f} Bad3s {} = {:.2f}'.format(batch_idx, EPEs.avg, Bad3s.avg)
# info_str = 'EPE {} = {:.2f} Bads= {:.4f}'.format(batch_idx, EPEs.avg, Bad3s.avg)
# # log.info(info_str)
# return info_str


# if batch_idx == 0:
# savename_0 = '/media/elonli/365C51B15C516C9D/HIT_new_5_15/result/' + str(epoch)+ str(batch_idx) + '.png'
# saveimage_color(output6,savename_0,disp_L[mask].max())


-return reduced_losses

class BatchCollator(object):
def __init__(self, cfg):
@@ -376,6 +512,43 @@ def adjust_learning_rate(optimizer, epoch, step=None, args=None):
for param_group in optimizer.param_groups:
param_group['lr'] = lr

class AverageMeter(object):
"""Computes and stores the average and current value"""

def __init__(self):
self.reset()

def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0

def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
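# Usage sketch: accumulating per-batch EPE, weighted by batch size.
#   epe_meter = AverageMeter()
#   epe_meter.update(0.8, n=4)  # batch of 4 images, mean EPE 0.8
#   epe_meter.update(1.2, n=2)  # batch of 2 images, mean EPE 1.2
#   epe_meter.avg               # (0.8*4 + 1.2*2) / 6 = 0.933...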

def to_plane(d, dx, dy, size=4):
    # Slanted-plane upsampling, equation (5) in the original (HITNet) paper:
    # d_plane(x + a, y + b) = d + a * dx + b * dy on a size x size sub-pixel grid
    c = np.linspace(-(size - 1) / 2, (size - 1) / 2, size, endpoint=True)
    c = torch.tensor(c).cuda()

    a = c.view([1, 1, size])
    a = torch.unsqueeze(a.repeat(1, d.shape[1] * size, d.shape[2]), dim=1).float()

    b = c.view([1, size, 1])
    b = torch.unsqueeze(b.repeat(1, d.shape[1], d.shape[2] * size), dim=1).float()

    d_float = d.float()
    d_4 = F.interpolate(torch.unsqueeze(d_float, dim=1), scale_factor=size, mode='nearest')
    dx_4 = F.interpolate(torch.unsqueeze(dx, dim=1), scale_factor=size, mode='nearest')
    dy_4 = F.interpolate(torch.unsqueeze(dy, dim=1), scale_factor=size, mode='nearest')
    d_plane = d_4 + a * dx_4 + b * dy_4
    d_plane = torch.squeeze(d_plane, dim=1)
    return d_plane
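# Usage sketch: 4x upsampling of a coarse [B, H, W] disparity map with its
# slants (assumes a CUDA device, since to_plane calls .cuda()). Each coarse
# pixel becomes a 4x4 tile following its local plane, sampled at offsets
# a, b in {-1.5, -0.5, 0.5, 1.5}.
#   d = torch.rand(2, 64, 128).cuda()    # coarse disparities
#   dx = torch.zeros_like(d)             # horizontal slants
#   dy = torch.zeros_like(d)             # vertical slants
#   full = to_plane(d, dx, dy, size=4)   # -> shape (2, 256, 512)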

if __name__ == '__main__':
main()