Skip to content

Commit

Permalink
增加方向参数
Browse files Browse the repository at this point in the history
  • Loading branch information
yizt committed Jan 7, 2020
1 parent c562c33 commit 07e461b
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,22 @@
"""

import argparse
import sys
import os
import sys

import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch import optim
from torch.nn import CTCLoss
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter

# from torch.utils.tensorboard import SummaryWriter
import crnn
from generator import Generator
from config import cfg
import utils
from config import cfg
from generator import Generator


# import torchvision.transforms as transforms
Expand Down Expand Up @@ -89,7 +91,7 @@ def train(args):
'cuda:{}'.format(args.local_rank) if args.device == 'cuda' and torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True
# data loader
data_set = Generator(cfg.word.get_all_words())
data_set = Generator(cfg.word.get_all_words(), args.direction)
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)
else:
Expand All @@ -114,7 +116,8 @@ def train(args):
model_without_ddp = model.module
# 加载预训练模型
if args.init_epoch > 0:
checkpoint = torch.load(os.path.join(args.output_dir, 'crnn.{:03d}.pth'.format(args.init_epoch)),
checkpoint = torch.load(os.path.join(args.output_dir,
'crnn.{}.{:03d}.pth'.format(args.direction, args.init_epoch)),
map_location='cpu')
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
Expand Down Expand Up @@ -147,14 +150,15 @@ def train(args):
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'crnn.{:03d}.pth'.format(epoch + 1)))
os.path.join(args.output_dir, 'crnn.{}.{:03d}.pth'.format(args.direction, epoch + 1)))
if utils.is_main_process():
writer.close()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default='cpu', help="cpu or cuda")
parser.add_argument("--direction", type=str, default='horizontal', help="horizontal or vertical")
parser.add_argument("--batch-size", type=int, default=64, help="batch size")
parser.add_argument("--epochs", type=int, default=90, help="epochs")
parser.add_argument("--init-epoch", type=int, default=0, help="init epoch")
Expand Down

0 comments on commit 07e461b

Please sign in to comment.