-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_cifar100_kaggle_model.py
127 lines (106 loc) · 3.96 KB
/
train_cifar100_kaggle_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# model from https://www.kaggle.com/pankajj/fashion-mnist-with-pytorch-93-accuracy
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, TensorDataset
import torch.nn.functional as F
from torch import nn, optim
import torchvision
import numpy as np
import sys
import torch
class FashionCNN(nn.Module):
    """Small CNN: two conv/batch-norm/ReLU/max-pool stages, then three linear layers.

    Takes 3-channel images and returns 100 raw class scores (no softmax).
    ``fc1`` expects 3136 input features, i.e. 64 channels * 7 * 7, which is
    what the two stages below produce for a 32x32 input:
    32 -> (conv, padding=1) 32 -> (pool) 16 -> (conv, no padding) 14 -> (pool) 7.

    No dropout or activation is applied between the linear layers.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: the padded 3x3 conv keeps the spatial size; the pool halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: the unpadded 3x3 conv trims 2 pixels per axis; the pool halves it.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Classifier head: 3136 -> 600 -> 120 -> 100.
        self.fc1 = nn.Linear(in_features=3136, out_features=600)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=100)

    def forward(self, x):
        """Return a (batch, 100) tensor of class scores for the image batch ``x``."""
        features = self.layer2(self.layer1(x))
        # Collapse channel and spatial dimensions into one vector per sample.
        flat = features.view(features.size(0), -1)
        return self.fc3(self.fc2(self.fc1(flat)))
# --- Command line ------------------------------------------------------------
# The one required argument is the checkpoint base name; the script later saves
# to "models/<name>.pth".
if len(sys.argv) < 2:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise SystemExit("usage: {} <save_name>".format(sys.argv[0]))
save_path = sys.argv[1]

# --- Hyper-parameters --------------------------------------------------------
n_epochs = 50
log_interval = 100  # print the training loss every `log_interval` batches

# --- Device ------------------------------------------------------------------
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.empty_cache()

# --- Data --------------------------------------------------------------------
# Only tensor conversion is applied; augmentations (rotation, random resized
# crop, horizontal flip) are currently disabled.
preprocess_train = transforms.Compose([
    transforms.ToTensor(),
])
preprocess_test = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR100(
    root="cifar100", train=True, download=True, transform=preprocess_train)
testset = torchvision.datasets.CIFAR100(
    root="cifar100", train=False, download=True, transform=preprocess_test)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=1280, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)

# --- Model -------------------------------------------------------------------
net = FashionCNN()
if use_cuda:
    net.cuda()
    # Use at most the first two GPUs (as before), but do not reference a device
    # index that does not exist on single-GPU machines.
    # NOTE(review): source indentation was lost; this assumes the DataParallel
    # wrap belongs inside the `if use_cuda` branch -- confirm.
    gpu_ids = list(range(min(2, torch.cuda.device_count())))
    net = nn.DataParallel(net, device_ids=gpu_ids)

# --- Optimisation ------------------------------------------------------------
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
# Multiplies the learning rate by 0.1 once the monitored metric has not improved
# by more than 1e-4 (absolute) for more than 15 consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=15,
    threshold=0.0001, threshold_mode='abs')

criterion = nn.CrossEntropyLoss()
if use_cuda:
    # Guarded like every other CUDA call so CPU-only hosts never touch CUDA.
    criterion = criterion.cuda()
def train(epoch):
    """Run one training epoch over ``train_loader`` and step the LR scheduler.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``scheduler``,
    ``train_loader``, ``use_cuda`` and ``log_interval``.

    Args:
        epoch: Epoch number, used only in the progress message.

    Returns:
        The sample-weighted mean training loss for the epoch, or ``None`` if
        the loader yielded no batches.
    """
    net.train()
    loss_sum = 0.0
    n_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = data.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss * batch_size
        n_seen += batch_size

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss))

    if n_seen == 0:
        return None

    epoch_loss = loss_sum / n_seen
    # ReduceLROnPlateau.step() takes the monitored metric. Passing the epoch
    # number (as before) made the metric grow every call, so the LR decayed on
    # a fixed timetable regardless of the loss. Step once per epoch instead.
    scheduler.step(epoch_loss)
    return epoch_loss
def test():
    """Evaluate ``net`` on ``test_loader`` and print the average loss and accuracy.

    Uses the module-level ``net``, ``criterion``, ``test_loader`` and ``use_cuda``.

    Returns:
        A ``(average_loss, accuracy_percent)`` tuple of floats.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    n_seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = net(data)
            batch_size = target.size(0)
            # The criterion returns a per-batch mean (default reduction), so
            # scale it back to a sum before accumulating. Summing the means and
            # dividing by the dataset size under-reported the loss.
            loss_sum += criterion(output, target).item() * batch_size
            pred = output.argmax(dim=1)
            # .item() keeps `correct` a plain int so it prints as a number
            # rather than as "tensor(...)".
            correct += (pred == target).sum().item()
            n_seen += batch_size

    denom = max(n_seen, 1)  # avoid dividing by zero on an empty loader
    test_loss = loss_sum / denom
    accuracy = 100. * correct / denom
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_seen, accuracy))
    return test_loss, accuracy
def save_model(net, path):
    """Save ``net.state_dict()`` to ``path``, creating the parent directory if needed.

    Args:
        net: Module whose ``state_dict()`` is written.
        path: Destination file path; any other object accepted by
            ``torch.save`` is passed through unchanged.
    """
    # NOTE(review): if `net` is wrapped in nn.DataParallel, the saved keys carry
    # a "module." prefix -- confirm that whatever loads this file expects that.
    import os  # local import so this function stands alone

    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            # torch.save does not create missing directories itself.
            os.makedirs(parent, exist_ok=True)
    torch.save(net.state_dict(), path)
# Only the model weights are checkpointed; optimizer state is not saved.
# NOTE(review): source indentation was lost. This assumes evaluation runs after
# every epoch and the checkpoint is written once, after the last epoch -- confirm.
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

save_model(net, "".join(("models/", save_path, ".pth")))