From 07e461bdde2c77db74a27084513f0190fd1cf09c Mon Sep 17 00:00:00 2001 From: yizt Date: Tue, 7 Jan 2020 17:26:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=96=B9=E5=90=91=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 40b4335..7447424 100644 --- a/train.py +++ b/train.py @@ -7,20 +7,22 @@ """ import argparse -import sys import os +import sys + import numpy as np import torch +from tensorboardX import SummaryWriter from torch import optim from torch.nn import CTCLoss from torch.utils.data.dataloader import DataLoader from tqdm import tqdm -from tensorboardX import SummaryWriter + # from torch.utils.tensorboard import SummaryWriter import crnn -from generator import Generator -from config import cfg import utils +from config import cfg +from generator import Generator # import torchvision.transforms as transforms @@ -89,7 +91,7 @@ def train(args): 'cuda:{}'.format(args.local_rank) if args.device == 'cuda' and torch.cuda.is_available() else 'cpu') torch.backends.cudnn.benchmark = True # data loader - data_set = Generator(cfg.word.get_all_words()) + data_set = Generator(cfg.word.get_all_words(), args.direction) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(data_set) else: @@ -114,7 +116,8 @@ def train(args): model_without_ddp = model.module # 加载预训练模型 if args.init_epoch > 0: - checkpoint = torch.load(os.path.join(args.output_dir, 'crnn.{:03d}.pth'.format(args.init_epoch)), + checkpoint = torch.load(os.path.join(args.output_dir, + 'crnn.{}.{:03d}.pth'.format(args.direction, args.init_epoch)), map_location='cpu') optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) @@ -147,7 +150,7 @@ def train(args): 'args': args} utils.save_on_master( checkpoint, - os.path.join(args.output_dir, 'crnn.{:03d}.pth'.format(epoch + 1))) + os.path.join(args.output_dir, 'crnn.{}.{:03d}.pth'.format(args.direction, epoch + 1))) if utils.is_main_process(): writer.close() @@ -155,6 +158,7 @@ def train(args): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default='cpu', help="cpu or cuda") + parser.add_argument("--direction", type=str, default='horizontal', help="horizontal or vertical") parser.add_argument("--batch-size", type=int, default=64, help="batch size") parser.add_argument("--epochs", type=int, default=90, help="epochs") parser.add_argument("--init-epoch", type=int, default=0, help="init epoch")