-
Notifications
You must be signed in to change notification settings - Fork 36
/
train_timesformer_deduped.py
628 lines (512 loc) · 23.1 KB
/
train_timesformer_deduped.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
import os.path as osp
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from timesformer_pytorch import TimeSformer
import gc
import random
import yaml
import numpy as np
import pandas as pd
import wandb
from torch.utils.data import DataLoader
import pandas as pd
import os
import random
from contextlib import contextmanager
import cv2
import scipy as sp
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import argparse
import torch
import torch.nn as nn
from torch.optim import AdamW
import datetime
import segmentation_models_pytorch as smp
import numpy as np
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from i3dallnl import InceptionI3d
import torch.nn as nn
import torch
from warmup_scheduler import GradualWarmupScheduler
from scipy import ndimage
import PIL.Image
PIL.Image.MAX_IMAGE_PIXELS = 933120000
class CFG:
    """Experiment-wide settings, accessed as class attributes throughout this file.

    Evaluating the class body prints ``'set dataset path'`` once and builds the
    albumentations transform lists used for the train and validation datasets.
    """

    # ============== comp exp name =============
    comp_name = 'vesuvius'
    comp_dir_path = './'
    comp_folder_name = './'
    # Alternative layout: f'{comp_dir_path}datasets/{comp_folder_name}/'
    comp_dataset_path = './'
    exp_name = 'pretraining_all'

    # ============== pred target =============
    target_size = 1

    # ============== model cfg =============
    model_name = 'Unet'
    # Other backbones tried: 'efficientnet-b0', 'se_resnext50_32x4d'
    backbone = 'resnet3d'
    in_chans = 26  # 65
    encoder_depth = 5

    # ============== training cfg =============
    size = 64
    tile_size = 256
    stride = tile_size // 8
    train_batch_size = 256  # 32
    valid_batch_size = train_batch_size
    use_amp = True
    scheduler = 'GradualWarmupSchedulerV2'  # or 'CosineAnnealingLR'
    epochs = 10  # 30
    # AdamW with warmup
    warmup_factor = 10
    # Previously: lr = 1e-4 / warmup_factor
    lr = 6e-5

    # ============== fold =============
    valid_id = '20230820203112'
    metric_direction = 'maximize'  # 'maximize' or 'minimize'

    # ============== fixed =============
    pretrained = True
    inf_weight = 'best'
    min_lr = 1e-6
    weight_decay = 1e-6
    max_grad_norm = 100
    print_freq = 50
    num_workers = 16
    seed = 0

    # ============== set dataset path =============
    print('set dataset path')
    outputs_path = f'./outputs/{comp_name}/{exp_name}/'
    submission_dir = f'{outputs_path}submissions/'
    submission_path = f'{submission_dir}submission_{exp_name}.csv'
    model_dir = f'{outputs_path}{comp_name}-models/'
    figures_dir = f'{outputs_path}figures/'
    log_dir = f'{outputs_path}logs/'
    log_path = f'{log_dir}{exp_name}.txt'

    # ============== augmentation =============
    # Training pipeline: resize, random flips, photometric and geometric
    # jitter, one-of noise/blur, coarse dropout, then normalisation
    # (mean 0 / std 1 per channel) and conversion to tensors.
    train_aug_list = [
        A.Resize(size, size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(rotate_limit=360, shift_limit=0.15, scale_limit=0.15, p=0.75),
        A.OneOf(
            [
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
            ],
            p=0.4,
        ),
        A.CoarseDropout(
            max_holes=2,
            max_width=int(size * 0.2),
            max_height=int(size * 0.2),
            mask_fill_value=0,
            p=0.5,
        ),
        A.Normalize(mean=[0] * in_chans, std=[1] * in_chans),
        ToTensorV2(transpose_mask=True),
    ]

    # Validation pipeline: deterministic resize + normalise + to-tensor only.
    valid_aug_list = [
        A.Resize(size, size),
        A.Normalize(mean=[0] * in_chans, std=[1] * in_chans),
        ToTensorV2(transpose_mask=True),
    ]

    # Small random rotation (limit 8, always applied); CustomDataset applies
    # it to transposed views of the training volume.
    rotate = A.Compose([A.Rotate(8, p=1)])
def init_logger(log_file):
    """Return this module's logger, emitting INFO+ messages to a stream and to ``log_file``.

    Calling the function more than once is safe: a stream handler is attached
    only if the logger does not have one yet, and a file handler is attached
    only if none already targets ``log_file``. Without that guard every call
    would add another pair of handlers and each message would be duplicated.

    Args:
        log_file: Path of the text file that receives the log messages.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler

    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # FileHandler stores its target as an absolute path in ``baseFilename``.
    target = os.path.abspath(os.fspath(log_file))
    has_stream = False
    has_file = False
    for existing in logger.handlers:
        # FileHandler subclasses StreamHandler, so test for it first.
        if isinstance(existing, FileHandler):
            if existing.baseFilename == target:
                has_file = True
        elif isinstance(existing, StreamHandler):
            has_stream = True

    if not has_stream:
        stream_handler = StreamHandler()
        stream_handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(stream_handler)
    if not has_file:
        file_handler = FileHandler(filename=log_file)
        file_handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(file_handler)
    return logger
def set_seed(seed=None, cudnn_deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Integer seed; ``None`` is replaced by 42.
        cudnn_deterministic: Value assigned to
            ``torch.backends.cudnn.deterministic``.

    Also exports the seed as ``PYTHONHASHSEED`` and disables cuDNN benchmark
    mode.
    """
    seed = 42 if seed is None else seed
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same order as before: NumPy, stdlib random, torch CPU, torch CUDA.
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = cudnn_deterministic
    cudnn.benchmark = False
def make_dirs(cfg):
    """Create the model, figures, submission and log directories named on ``cfg``.

    Existing directories are left alone (``exist_ok=True``).
    """
    targets = (cfg.model_dir, cfg.figures_dir, cfg.submission_dir, cfg.log_dir)
    for path in targets:
        os.makedirs(path, exist_ok=True)
def cfg_init(cfg, mode='train'):
    """Seed the RNGs from ``cfg.seed`` and, in train mode, create output directories.

    Args:
        cfg: Configuration object providing ``seed`` and the directory
            attributes read by ``make_dirs``.
        mode: Only the value ``'train'`` triggers directory creation.
    """
    set_seed(cfg.seed)
    if mode != 'train':
        return
    make_dirs(cfg)
# Runs at import time: seeds the RNGs and (mode defaults to 'train') creates
# the output directories configured on CFG.
cfg_init(CFG)
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def read_image_mask(fragment_id, start_idx=17, end_idx=43):
    """Load the layer stack, ink labels and fragment mask for one segment.

    Layers ``start_idx`` .. ``end_idx - 1`` are read as grayscale images from
    ``train_scrolls/<fragment_id>/layers/`` (``.tif`` if present, else
    ``.jpg``), zero-padded on the bottom/right by
    ``tile_size - dim % tile_size`` pixels, clipped to ``[0, 200]`` and
    stacked along the last axis.

    Args:
        fragment_id: Directory name of the segment. The part before the first
            ``_`` is used for the label/mask file names and for the
            per-segment special cases below.
        start_idx: First layer index (inclusive).
        end_idx: Last layer index (exclusive).

    Returns:
        ``(images, mask, fragment_mask)`` where ``images`` has the layers on
        axis 2, ``mask`` is the ink label image as float32 divided by 255
        (not padded), and ``fragment_mask`` is the padded segment mask.

    Raises:
        FileNotFoundError: If a layer, label or mask image cannot be read.
    """
    fragment_id_ = fragment_id.split("_")[0]
    fragment_dir = CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/"
    halve = 'frag' in fragment_id
    vertical_flip = fragment_id_ == '20230827161846'

    def _load_gray(path):
        # cv2.imread signals failure by returning None rather than raising;
        # turn that into an explicit, descriptive error.
        loaded = cv2.imread(path, 0)
        if loaded is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return loaded

    images = []
    for i in range(start_idx, end_idx):
        layer_path = fragment_dir + f"layers/{i:02}.tif"
        if not os.path.exists(layer_path):
            layer_path = fragment_dir + f"layers/{i:02}.jpg"
        image = _load_gray(layer_path)

        # Note: a dimension that is already a multiple of tile_size still
        # receives one full extra tile of padding.
        pad0 = (CFG.tile_size - image.shape[0] % CFG.tile_size)
        pad1 = (CFG.tile_size - image.shape[1] % CFG.tile_size)
        image = np.pad(image, [(0, pad0), (0, pad1)], constant_values=0)

        if halve:
            image = cv2.resize(image, (image.shape[1] // 2, image.shape[0] // 2), interpolation=cv2.INTER_AREA)
        image = np.clip(image, 0, 200)
        if vertical_flip:
            # NOTE(review): the layer is flipped after padding whereas the
            # fragment mask below is flipped before padding, and the ink
            # labels are not flipped at all - confirm this is intended.
            image = cv2.flip(image, 0)
        images.append(image)
    images = np.stack(images, axis=2)

    # These segments get their layer order reversed.
    if fragment_id_ in ['20230701020044', 'verso', '20230901184804', '20230901234823', '20230531193658', '20231007101615', '20231005123333', '20231011144857', '20230522215721', '20230919113918', '20230625171244', '20231022170900', '20231012173610', '20231016151000']:
        images = images[:, :, ::-1]

    if fragment_id_ in ['20231022170901', '20231022170900']:
        mask = _load_gray(fragment_dir + f"{fragment_id_}_inklabels.tiff")
    else:
        mask = _load_gray(fragment_dir + f"{fragment_id_}_inklabels.png")

    fragment_mask = _load_gray(fragment_dir + f"{fragment_id_}_mask.png")
    if vertical_flip:
        fragment_mask = cv2.flip(fragment_mask, 0)
    # Re-uses the padding computed for the last layer read above.
    fragment_mask = np.pad(fragment_mask, [(0, pad0), (0, pad1)], constant_values=0)

    if halve:
        fragment_mask = cv2.resize(fragment_mask, (fragment_mask.shape[1] // 2, fragment_mask.shape[0] // 2), interpolation=cv2.INTER_AREA)
        mask = cv2.resize(mask, (mask.shape[1] // 2, mask.shape[0] // 2), interpolation=cv2.INTER_AREA)

    mask = mask.astype('float32')
    mask /= 255
    return images, mask, fragment_mask
def get_train_valid_dataset():
    """Cut every configured segment into ``CFG.size`` crops for training/validation.

    A ``CFG.tile_size`` window slides over each segment with step
    ``CFG.stride``. Windows that touch a zero pixel of the fragment mask are
    skipped. For training segments, windows whose ink labels are all below
    0.01 are skipped as well and each crop position is emitted at most once
    per segment. Crops of the segment whose id equals ``CFG.valid_id`` go to
    the validation lists together with their ``[x1, y1, x2, y2]`` coordinates
    (overlapping windows are not de-duplicated there).

    A segment that fails to load is reported and skipped. Previously execution
    fell through and re-tiled the arrays left over from the preceding segment
    (or raised ``NameError`` if the first segment failed).

    Returns:
        ``(train_images, train_masks, valid_images, valid_masks, valid_xyxys)``
        as Python lists.
    """
    train_images = []
    train_masks = []
    valid_images = []
    valid_masks = []
    valid_xyxys = []

    # BIG 6: '20231005123333','20231022170900','20231012173610','20230702185753','20230929220924','20231007101615'
    fragment_ids = ['20231210121321', '20231106155350', '20231005123336', '20230820203112', '20230620230619', '20230826170124', '20230702185753', '20230522215721', '20230531193658', '20230520175435', '20230903193206', '20230902141231', '20231007101615', '20230929220924', 'recto', 'verso', '20231016151000', '20231012184423', '20231031143850']

    tile = CFG.tile_size
    crop = CFG.size
    expected_shape = (crop, crop, CFG.in_chans)

    for fragment_id in fragment_ids:
        # Fall back to the "<id>_superseded" directory when the plain one is absent.
        if not os.path.exists(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}"):
            fragment_id = fragment_id + "_superseded"
        print('reading ', fragment_id)
        try:
            image, mask, fragment_mask = read_image_mask(fragment_id)
        except Exception:
            print(f"couldnt load {fragment_id}!")
            continue

        # NOTE(review): a "_superseded" id can never equal CFG.valid_id unless
        # valid_id carries the suffix too - confirm that is intended.
        is_valid_fragment = fragment_id == CFG.valid_id
        x1_list = list(range(0, image.shape[1] - tile + 1, CFG.stride))
        y1_list = list(range(0, image.shape[0] - tile + 1, CFG.stride))
        seen_windows = set()

        for a in y1_list:
            for b in x1_list:
                # Both checks depend only on the tile, so evaluate them once
                # per tile instead of once per crop.
                if np.any(fragment_mask[a:a + tile, b:b + tile] == 0):
                    continue
                if not is_valid_fragment and np.all(mask[a:a + tile, b:b + tile] < 0.01):
                    continue

                for yi in range(0, tile, crop):
                    for xi in range(0, tile, crop):
                        y1 = a + yi
                        x1 = b + xi
                        y2 = y1 + crop
                        x2 = x1 + crop
                        if is_valid_fragment:
                            valid_images.append(image[y1:y2, x1:x2])
                            valid_masks.append(mask[y1:y2, x1:x2, None])
                            valid_xyxys.append([x1, y1, x2, y2])
                            assert image[y1:y2, x1:x2].shape == expected_shape
                        elif (y1, y2, x1, x2) not in seen_windows:
                            train_images.append(image[y1:y2, x1:x2])
                            train_masks.append(mask[y1:y2, x1:x2, None])
                            assert image[y1:y2, x1:x2].shape == expected_shape
                            seen_windows.add((y1, y2, x1, x2))
    return train_images, train_masks, valid_images, valid_masks, valid_xyxys
def get_transforms(data, cfg):
    """Build the albumentations pipeline for the requested split.

    Args:
        data: ``'train'`` or ``'valid'``.
        cfg: Configuration object providing ``train_aug_list`` and
            ``valid_aug_list``.

    Returns:
        An ``A.Compose`` wrapping the matching augmentation list.

    Raises:
        ValueError: If ``data`` is neither ``'train'`` nor ``'valid'``
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if data == 'train':
        return A.Compose(cfg.train_aug_list)
    if data == 'valid':
        return A.Compose(cfg.valid_aug_list)
    raise ValueError(f"unknown data split {data!r}; expected 'train' or 'valid'")
class CustomDataset(Dataset):
    """Dataset of image crops and label crops for training or validation.

    When ``xyxys`` is given the dataset acts as a validation set: items are
    ``(image, label, xy)`` and no random augmentation besides ``transform`` is
    applied. Without ``xyxys`` items are ``(image, label)`` and each image is
    additionally rotated and passed through ``fourth_augment`` before
    ``transform``.

    Args:
        images: Sequence of image arrays.
        cfg: Configuration object; ``size`` and ``in_chans`` are read.
        xyxys: Optional per-item coordinates returned alongside each item.
        labels: Sequence of label arrays aligned with ``images``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'`` tensors.
    """

    def __init__(self, images, cfg, xyxys=None, labels=None, transform=None):
        self.images = images
        self.cfg = cfg
        self.labels = labels
        self.transform = transform
        self.xyxys = xyxys
        # Taken from the global CFG (not from ``cfg``), as in the original code.
        self.rotate = CFG.rotate

    def __len__(self):
        return len(self.images)

    def cubeTranslate(self, y):
        # NOTE(review): not called anywhere in this file. The resized mask is
        # 128x128, so the reshape below only works when cfg.size == 128, and
        # the loop tests ``mask == 0`` on every iteration (possibly meant
        # ``mask == i``). Logic left unchanged - confirm intent before use.
        x = np.random.uniform(0, 1, 4).reshape(2, 2)
        x[x < .4] = 0
        x[x > .633] = 2
        x[(x > .4) & (x < .633)] = 1
        mask = cv2.resize(x, (x.shape[1] * 64, x.shape[0] * 64), interpolation=cv2.INTER_AREA)
        x = np.zeros((self.cfg.size, self.cfg.size, self.cfg.in_chans)).astype(np.uint8)
        for i in range(3):
            x = np.where(np.repeat((mask == 0).reshape(self.cfg.size, self.cfg.size, 1), self.cfg.in_chans, axis=2), y[:, :, i:self.cfg.in_chans + i], x)
        return x

    def fourth_augment(self, image):
        """Randomly crop a run of channels, paste it at a random offset and cut out up to two.

        A contiguous run of channels (18-26 long, clamped to
        ``cfg.in_chans``) is copied into an all-zero array at a random channel
        offset; with probability 0.6 up to two of the candidate channels are
        then zeroed.

        The clamp makes the method usable when ``cfg.in_chans < 26`` (the
        hard-coded 18..26 range used to raise ``ValueError`` there); for
        ``in_chans >= 26`` the random draws are unchanged.
        """
        n_chans = self.cfg.in_chans
        image_tmp = np.zeros_like(image)
        cropping_num = random.randint(min(18, n_chans), min(26, n_chans))

        start_idx = random.randint(0, n_chans - cropping_num)
        crop_indices = np.arange(start_idx, start_idx + cropping_num)

        start_paste_idx = random.randint(0, n_chans - cropping_num)

        # NOTE(review): candidates span [start_paste_idx, cropping_num), not
        # the whole pasted range - kept as-is; confirm whether
        # start_paste_idx + cropping_num was intended.
        tmp = np.arange(start_paste_idx, cropping_num)
        np.random.shuffle(tmp)

        cutout_idx = random.randint(0, 2)
        temporal_random_cutout_idx = tmp[:cutout_idx]

        image_tmp[..., start_paste_idx:start_paste_idx + cropping_num] = image[..., crop_indices]

        if random.random() > 0.4:
            image_tmp[..., temporal_random_cutout_idx] = 0
        return image_tmp

    def _apply_transform(self, image, label):
        # Shared by both branches of __getitem__: run the transform, add a
        # leading dimension to the image and resize the label to
        # (size // 16, size // 16). Without a transform, inputs pass through.
        if self.transform:
            data = self.transform(image=image, mask=label)
            image = data['image'].unsqueeze(0)
            label = data['mask']
            label = F.interpolate(label.unsqueeze(0), (self.cfg.size // 16, self.cfg.size // 16)).squeeze(0)
        return image, label

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.xyxys is not None:
            xy = self.xyxys[idx]
            image, label = self._apply_transform(image, label)
            return image, label, xy

        # 3d rotate: apply the small rotation to two transposed views of the
        # volume, then restore the original axis order.
        image = image.transpose(2, 1, 0)  # (c,w,h)
        image = self.rotate(image=image)['image']
        image = image.transpose(0, 2, 1)  # (c,h,w)
        image = self.rotate(image=image)['image']
        image = image.transpose(0, 2, 1)  # (c,w,h)
        image = image.transpose(2, 1, 0)  # (h,w,c)

        image = self.fourth_augment(image)
        image, label = self._apply_transform(image, label)
        return image, label
class CustomDatasetTest(Dataset):
    """Inference dataset yielding ``(image, xy)`` pairs without labels.

    Args:
        images: Sequence of image arrays.
        xyxys: Per-item coordinates returned next to each image.
        cfg: Configuration object (stored, not read by this class).
        transform: Optional callable invoked as ``transform(image=...)``
            returning a dict whose ``'image'`` entry is a tensor.
    """

    def __init__(self, images, xyxys, cfg, transform=None):
        self.images = images
        self.xyxys = xyxys
        self.cfg = cfg
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample, coords = self.images[idx], self.xyxys[idx]
        if not self.transform:
            return sample, coords
        # Transformed images gain a leading singleton dimension.
        transformed = self.transform(image=sample)
        return transformed['image'].unsqueeze(0), coords
# from resnetall import generate_model
def init_weights(m):
    """Kaiming-normal initialise the weight of ``nn.Conv2d`` modules.

    Intended for use with ``module.apply(init_weights)``; modules of any other
    type are left untouched.

    The initialiser operates on a tensor, so it must receive ``m.weight``;
    passing the module itself (as the code did before) raises.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
class Decoder(nn.Module):
    """Top-down decoder that fuses a pyramid of 2D feature maps into one logit map.

    Args:
        encoder_dims: Channel count of each feature map, ordered so that map
            ``i`` has half the spatial resolution of map ``i - 1`` (each
            deeper map is upsampled by 2 before being concatenated with the
            shallower one).
        upscale: Scale factor of the final bilinear upsampling applied to the
            logits.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i - 1] fuses level i (upsampled) with level i - 1 and outputs
        # encoder_dims[i - 1] channels.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(encoder_dims[i] + encoder_dims[i - 1], encoder_dims[i - 1], 3, 1, 1, bias=False),
                nn.BatchNorm2d(encoder_dims[i - 1]),
                nn.ReLU(inplace=True)
            ) for i in range(1, len(encoder_dims))])

        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear")

    def forward(self, feature_maps):
        """Return the upsampled single-channel logit map for ``feature_maps``.

        ``feature_maps`` may be any sequence of tensors. It is copied first so
        the caller's container is not modified (the previous implementation
        overwrote its entries in place and failed on tuples).
        """
        feature_maps = list(feature_maps)
        for i in range(len(feature_maps) - 1, 0, -1):
            f_up = F.interpolate(feature_maps[i], scale_factor=2, mode="bilinear")
            f = torch.cat([feature_maps[i - 1], f_up], dim=1)
            feature_maps[i - 1] = self.convs[i - 1](f)

        x = self.logit(feature_maps[0])
        mask = self.up(x)
        return mask
class RegressionPLModel(pl.LightningModule):
    """Lightning module wrapping a TimeSformer that predicts a 4x4 ink map per crop.

    Args:
        pred_shape: Shape of the full-size arrays used to accumulate
            validation predictions.
        size: Side length of the window added to the accumulators for each
            validation crop.
        enc: Stored as a hyperparameter; not otherwise read by this class.
        with_norm: When true, inputs pass through a ``BatchNorm3d`` layer
            first.
        total_steps: Passed to the LR scheduler as ``steps_per_epoch``.
    """

    def __init__(self, pred_shape, size=256, enc='', with_norm=False, total_steps=500):
        super().__init__()
        self.save_hyperparameters()

        # Running sum of predictions and hit counts for the validation image.
        self.mask_pred = np.zeros(self.hparams.pred_shape)
        self.mask_count = np.zeros(self.hparams.pred_shape)

        self.loss_func1 = smp.losses.DiceLoss(mode='binary')
        self.loss_func2 = smp.losses.SoftBCEWithLogitsLoss(smooth_factor=0.25)

        self.backbone = TimeSformer(
            dim=512,
            image_size=64,
            patch_size=16,
            num_frames=26,
            num_classes=16,
            channels=1,
            depth=8,
            heads=6,
            dim_head=64,
            attn_dropout=0.1,
            ff_dropout=0.1
        )
        if self.hparams.with_norm:
            self.normalization = nn.BatchNorm3d(num_features=1)

    def loss_func(self, x, y):
        """Equal-weight blend of the Dice loss and the smoothed BCE-with-logits loss."""
        return 0.5 * self.loss_func1(x, y) + 0.5 * self.loss_func2(x, y)

    def forward(self, x):
        # 4-D input gets a singleton dimension inserted at position 1.
        if x.ndim == 4:
            x = x.unsqueeze(1)
        if self.hparams.with_norm:
            x = self.normalization(x)
        # Swap dims 1 and 2 before the backbone, then view its 16 outputs as
        # a 1x4x4 map.
        logits = self.backbone(x.permute(0, 2, 1, 3, 4))
        return logits.view(-1, 1, 4, 4)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss1 = self.loss_func(self(x), y)
        if torch.isnan(loss1):
            print("Loss nan encountered")
        self.log("train/total_loss", loss1.item(), on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss1}

    def validation_step(self, batch, batch_idx):
        x, y, xyxys = batch
        outputs = self(x)
        loss1 = self.loss_func(outputs, y)
        y_preds = torch.sigmoid(outputs).to('cpu')
        window = (self.hparams.size, self.hparams.size)
        for pred, (x1, y1, x2, y2) in zip(y_preds, xyxys):
            # Upsample the 4x4 prediction by 16 and add it to the accumulators.
            upsampled = F.interpolate(pred.unsqueeze(0).float(), scale_factor=16, mode='bilinear')
            self.mask_pred[y1:y2, x1:x2] += upsampled.squeeze(0).squeeze(0).numpy()
            self.mask_count[y1:y2, x1:x2] += np.ones(window)

        self.log("val/total_loss", loss1.item(), on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss1}

    def on_validation_epoch_end(self):
        # Average overlapping predictions; pixels never hit stay at zero.
        self.mask_pred = np.divide(self.mask_pred, self.mask_count, out=np.zeros_like(self.mask_pred), where=self.mask_count != 0)
        # NOTE(review): relies on the module-level ``wandb_logger`` created by
        # the training script rather than on ``self.logger``.
        wandb_logger.log_image(key="masks", images=[np.clip(self.mask_pred, 0, 1)], caption=["probs"])

        # reset mask
        self.mask_pred = np.zeros(self.hparams.pred_shape)
        self.mask_count = np.zeros(self.hparams.pred_shape)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=CFG.lr)
        # NOTE(review): OneCycleLR is sized for total_steps * 150 scheduler
        # steps, but the scheduler is returned without an explicit interval -
        # confirm how often Lightning steps it in this configuration.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=3e-4,
            pct_start=0.15,
            steps_per_epoch=self.hparams.total_steps,
            epochs=150,
            final_div_factor=1e2,
        )
        return [optimizer], [scheduler]
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """Warm-up scheduler with an overridden ``get_lr``.

    Source:
    https://www.kaggle.com/code/underwearfitting/single-fold-training-of-resnet200d-lb0-965
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        """Return the per-group learning rates for the current ``last_epoch``."""
        if self.last_epoch <= self.total_epoch:
            # Still warming up: scale the base rates linearly.
            if self.multiplier == 1.0:
                scale = float(self.last_epoch) / self.total_epoch
            else:
                scale = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
            return [base_lr * scale for base_lr in self.base_lrs]

        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if not self.finished:
            # First step after warm-up: hand the scaled base rates to the
            # follow-up scheduler exactly once.
            self.after_scheduler.base_lrs = [
                base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()
def get_scheduler(cfg, optimizer):
    """Return a one-epoch warm-up scheduler that hands over to cosine annealing.

    The cosine schedule uses ``T_max=30`` and ``eta_min=1e-6``; ``cfg`` is
    accepted but not read.
    """
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 30, eta_min=1e-6)
    return GradualWarmupSchedulerV2(
        optimizer, multiplier=1.0, total_epoch=1, after_scheduler=cosine)
def scheduler_step(scheduler, avg_val_loss, epoch):
    """Advance ``scheduler`` to ``epoch``; ``avg_val_loss`` is accepted but unused."""
    scheduler.step(epoch)
# ---------------------------------------------------------------------------
# Training script: one run per entry in ``fragments``, each entry acting as
# the held-out validation segment (CFG.valid_id) for that run.
# ---------------------------------------------------------------------------
torch.set_float32_matmul_precision('medium')
fragments = ['20231005123336']
enc, fold = 'i3d', 0
for fid in fragments:
    CFG.valid_id = fid
    fragment_id = CFG.valid_id
    run_slug = f'training_scrolls_valid={fragment_id}_{CFG.size}x{CFG.size}_submissionlabels'

    # The label image of the validation segment defines the shape of the
    # model's prediction accumulators. Fail with a clear message when it is
    # missing instead of an AttributeError on ``None.shape``.
    label_path = CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.png"
    valid_mask_gt = cv2.imread(label_path, 0)
    if valid_mask_gt is None:
        raise FileNotFoundError(f"could not read validation labels: {label_path}")
    pred_shape = valid_mask_gt.shape

    train_images, train_masks, valid_images, valid_masks, valid_xyxys = get_train_valid_dataset()
    valid_xyxys = np.stack(valid_xyxys)

    train_dataset = CustomDataset(
        train_images, CFG, labels=train_masks, transform=get_transforms(data='train', cfg=CFG))
    valid_dataset = CustomDataset(
        valid_images, CFG, xyxys=valid_xyxys, labels=valid_masks, transform=get_transforms(data='valid', cfg=CFG))

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.train_batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True,
                              )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.valid_batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)

    # Must stay a module-level name: RegressionPLModel.on_validation_epoch_end
    # refers to ``wandb_logger`` as a global.
    wandb_logger = WandbLogger(project="vesivus", name=run_slug + f'timesformer_big6_finetune')

    model = RegressionPLModel(enc=enc, pred_shape=pred_shape, size=CFG.size, total_steps=len(train_loader))
    print('FOLD : ', fold)
    wandb_logger.watch(model, log="all", log_freq=100)

    trainer = pl.Trainer(
        # check_val_every_n_epoch=10,
        max_epochs=150,
        accelerator="gpu",
        devices=8,
        logger=wandb_logger,
        default_root_dir="./models",
        accumulate_grad_batches=1,
        precision='16-mixed',
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        strategy='ddp_find_unused_parameters_true',
        callbacks=[
            ModelCheckpoint(
                filename=f'timesformer_wild15_{fid}_{fold}_fr_{enc}' + '{epoch}',
                dirpath=CFG.model_dir,
                monitor='train/total_loss',
                mode='min',
                save_top_k=12,
                every_n_epochs=20,
            ),
        ],
    )
    # Only the training loader is passed to fit(); valid_loader is built above
    # but not used for validation in this script.
    trainer.fit(model=model, train_dataloaders=train_loader)

    # Release the large arrays and GPU memory before the next run.
    del train_images, train_loader, train_masks, valid_loader, model
    gc.collect()
    torch.cuda.empty_cache()
    wandb.finish()