From 3ec0ed8643099994bb7f9b213e5125c36d3622cb Mon Sep 17 00:00:00 2001 From: iejMac Date: Wed, 15 Feb 2023 11:07:59 +0000 Subject: [PATCH 01/26] Add video support --- src/open_clip/vivit.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 src/open_clip/vivit.py diff --git a/src/open_clip/vivit.py b/src/open_clip/vivit.py new file mode 100644 index 000000000..102234b37 --- /dev/null +++ b/src/open_clip/vivit.py @@ -0,0 +1,3 @@ +""" +ViViT model (https://arxiv.org/abs/2103.15691) +""" From 1cd33e417fab5bd8d0513660d7d8a54003c944ba Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 18 Feb 2023 09:04:32 +0000 Subject: [PATCH 02/26] data loading: correct shapes in training loop (crappy code) --- src/training/main.py | 8 +- src/training/train.py | 6 +- src/training/video_data.py | 185 +++++++++++++++++++++++++++++++++++++ 3 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 src/training/video_data.py diff --git a/src/training/main.py b/src/training/main.py index f70c9f953..fddd8b263 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -29,6 +29,7 @@ from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from training.data import get_data +from training.video_data import get_data # TODO: maybe we don't need separate files from training.distributed import is_master, init_distributed_device, broadcast_object from training.logger import setup_logging from training.params import parse_args @@ -334,7 +335,12 @@ def main(args): logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") # initialize datasets - data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + # TODO: come up with a way of getting alternative modality data based on model config + # if "ViViT" in model_name: + if True: + data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + else: + data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) assert len(data), 'At least one train or eval dataset must be specified.' 
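The "if True:" above is a stopgap for the TODO about selecting the data pipeline from the model config rather than the model name. A minimal sketch of one possible dispatch, assuming the video loader is imported under its own name (it becomes get_video_data later in this series) and using open_clip's get_model_config helper, with the presence of a "temporal_cfg" block (as in the ViViT-B-32.json added in a later patch) marking a video model:

    from open_clip.factory import get_model_config

    # Hypothetical dispatch, not part of the patch itself.
    model_cfg = get_model_config(args.model) or {}
    get_data_fn = get_video_data if "temporal_cfg" in model_cfg else get_data
    data = get_data_fn(
        args, (preprocess_train, preprocess_val),
        epoch=start_epoch, tokenizer=get_tokenizer(args.model),
    )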
# create scheduler if train diff --git a/src/training/train.py b/src/training/train.py index 830b0bf59..3a1ee2b1a 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -88,7 +88,11 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist if not args.skip_scheduler: scheduler(step) - images, texts = batch + # TODO: adapt dataloaders to fit open_clip format + # TODO: generalize train loop to modality1, modality2 instead of image,text maybe + images, texts = batch["video_tensor"], batch["text_tokens"] + print(images.shape, texts.shape) + continue images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) diff --git a/src/training/video_data.py b/src/training/video_data.py new file mode 100644 index 000000000..e546d2648 --- /dev/null +++ b/src/training/video_data.py @@ -0,0 +1,185 @@ +"""video dataset creation""" +import io +import torchvision +import tempfile +import webdataset as wds + +from dataclasses import dataclass +from multiprocessing import Value +from pathlib import Path +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + +@dataclass +class DataInfo: + dataloader: DataLoader + shared_epoch: SharedEpoch = None + sampler = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + + +def create_webdataset( + urls, + video_transform, + enable_text=True, + enable_video=True, + video_key="mp4", + caption_key="txt", + enable_metadata=False, + cache_path=None, + input_sampler=lambda a: a, + tokenizer=None, +): + """Create a WebDataset reader, it can read a webdataset of video, text and json""" + + urls = input_sampler(urls) + + dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10, handler=wds.handlers.warn_and_continue) + + def filter_dataset(item): + if enable_text and caption_key not in item: + return False + if enable_video and video_key not in item: + return False + if enable_metadata and "json" not in item: + return False + return True + + filtered_dataset = dataset.select(filter_dataset) + + def preprocess_dataset(item): + output = {} + if enable_video: + video_data = item[video_key] + with tempfile.NamedTemporaryFile() as f: + f.write(video_data) + # video = torchvision.io.read_video(video_data) + video, audio, meta = torchvision.io.read_video(f.name) + video_tensor = video_transform(video) + output["video_filename"] = item["__key__"] + output["video_tensor"] = video_tensor + + if enable_text: + text = item[caption_key] + caption = text.decode("utf-8") + tokenized_text = tokenizer(caption) + output["text_tokens"] = tokenized_text + output["text"] = caption + + if enable_metadata: + metadata_file = item["json"] + metadata = metadata_file.decode("utf-8") + output["metadata"] = metadata + return output + + transformed_dataset = filtered_dataset.map(preprocess_dataset, handler=wds.handlers.warn_and_continue) + return transformed_dataset + + +def dataset_to_dataloader(dataset, batch_size, num_prepro_workers, input_format): + """Create a pytorch dataloader from a dataset""" + + def collate_fn(batch): + batch = list(filter(lambda x: x is 
not None, batch)) + return default_collate(batch) + + data = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_prepro_workers, + pin_memory=True, + prefetch_factor=2, + collate_fn=collate_fn if input_format == "files" else None, + ) + return data + + +class VideoDatasetReader: + """WebdatasetReader is a reader that reads samples from a webdataset""" + + def __init__( + self, + sampler, + preprocess, + input_dataset, + batch_size, + num_prepro_workers, + enable_text=True, + enable_video=True, + enable_metadata=False, + wds_video_key="mp4", + wds_caption_key="txt", + cache_path=None, + tokenizer=None, + ): + self.batch_size = batch_size + dataset = create_webdataset( + input_dataset, + preprocess, + enable_text=enable_text, + enable_video=enable_video, + video_key=wds_video_key, + caption_key=wds_caption_key, + enable_metadata=enable_metadata, + cache_path=cache_path, + input_sampler=sampler, + tokenizer=tokenizer, + ) + self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "webdataset") + self.num_batches = 0 + self.num_samples = 0 + + + def __iter__(self): + for batch in self.dataloader: + yield batch + + +def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): + # TODO: get this from the model + preprocess_vid = lambda vid: vid[:32, :128, :128, :] # TODO: adjust this so it works for testing + num_samples = args.train_num_samples + batch_size = args.batch_size + + wds = VideoDatasetReader( + sampler=lambda a: a, + preprocess=preprocess_vid, + input_dataset=args.train_data, + batch_size=batch_size, + num_prepro_workers=96, + enable_metadata=True, + tokenizer=tokenizer, + ) + + shared_epoch = SharedEpoch(epoch=epoch) + + return DataInfo(dataloader=wds, shared_epoch=shared_epoch) + + +def get_data(args, preprocess_fns, epoch=0, tokenizer=None): + preprocess_train, preprocess_val = preprocess_fns + data = {} + + if args.train_data: + data["train"] = get_wds_dataset(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) + + return data + From b99575064e46caaef5ce5f5a69dbd2a9c95ad2e4 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 18 Feb 2023 11:35:40 +0000 Subject: [PATCH 03/26] update model progress --- src/open_clip/factory.py | 6 +- src/open_clip/model_configs/ViViT-B-32.json | 29 +++++++++ src/open_clip/vivit.py | 69 +++++++++++++++++++++ src/training/train.py | 2 - 4 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 src/open_clip/model_configs/ViViT-B-32.json diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 14011f934..5aa975e5d 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -13,6 +13,7 @@ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype from .coca_model import CoCa +from .vivit import VideoCLIP # TODO: change once full model is implemented from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf @@ -191,7 +192,10 @@ def create_model( else: model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) + if "ViViT" in model_name: # TODO better way of detecting video configs + model = VideoCLIP(**model_cfg) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) pretrained_loaded = False if pretrained: 
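A quick usage sketch of the new create_model branch (hedged: "ViViT-B-32" refers to the config added just below, VideoCLIP is still largely a stub at this point in the series, and no pretrained tags exist for it):

    import open_clip

    # Construction only; weights are randomly initialized.
    model = open_clip.create_model("ViViT-B-32", pretrained=None)
    print(type(model).__name__)  # expected: VideoCLIP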
diff --git a/src/open_clip/model_configs/ViViT-B-32.json b/src/open_clip/model_configs/ViViT-B-32.json new file mode 100644 index 000000000..b5ce5369b --- /dev/null +++ b/src/open_clip/model_configs/ViViT-B-32.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "temporal_cfg": { + "context_length": 32, + "width": 512, + "heads": 8, + "layers": 12, + "mlp_ratio": 4, + "pooler_type": "cls_pooler" + } +} diff --git a/src/open_clip/vivit.py b/src/open_clip/vivit.py index 102234b37..891aa6575 100644 --- a/src/open_clip/vivit.py +++ b/src/open_clip/vivit.py @@ -1,3 +1,72 @@ """ ViViT model (https://arxiv.org/abs/2103.15691) """ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +from dataclasses import dataclass + +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + + +@dataclass +class TemporalCfg: + context_length: int = 32 # number of input frames + width: int = 512 + heads: int = 8 + layers: int = 12 + mlp_ratio: int = 4 + pooler_type: str = "cls_pooler" + + +def _build_video_tower( + embed_dim, + vision_cfg, + temporal_cfg, + cast_dtype: Optional[torch.dtype] = None, + ): + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) else temporal_cfg + + +# TODO: implement +class ViViT(nn.Module): + def __init__(self): + pass + + +# TODO: turn into VideoCoCa +# TODO: implement +# TODO: do we need quickgelu? 
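For reference, how the temporal_cfg block above is consumed (a sketch only; it mirrors the dict-to-dataclass pattern already used for vision_cfg and text_cfg):

    from open_clip.vivit import TemporalCfg  # module is renamed to video_model later in the series

    cfg = {"context_length": 32, "width": 512, "heads": 8,
           "layers": 12, "mlp_ratio": 4, "pooler_type": "cls_pooler"}
    temporal_cfg = TemporalCfg(**cfg)
    assert temporal_cfg.context_length == 32  # frames per clip fed to the temporal transformer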
+class VideoCLIP(nn.Module): + def __init__( + self, + embed_dim, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + temporal_cfg: TemporalCfg, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) else temporal_cfg + + ''' + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + ''' + + diff --git a/src/training/train.py b/src/training/train.py index 3a1ee2b1a..e39e814f9 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -91,8 +91,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # TODO: adapt dataloaders to fit open_clip format # TODO: generalize train loop to modality1, modality2 instead of image,text maybe images, texts = batch["video_tensor"], batch["text_tokens"] - print(images.shape, texts.shape) - continue images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) From be04c066ff555703f1d1a651c33485562cce0e34 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 00:26:48 +0000 Subject: [PATCH 04/26] rename file + create_model loads something --- src/open_clip/factory.py | 2 +- src/open_clip/{vivit.py => video_model.py} | 67 +++++++++++++++++++--- 2 files changed, 60 insertions(+), 9 deletions(-) rename src/open_clip/{vivit.py => video_model.py} (55%) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 5aa975e5d..ccc0bc7ee 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -13,7 +13,7 @@ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype from .coca_model import CoCa -from .vivit import VideoCLIP # TODO: change once full model is implemented +from .video_model import VideoCLIP # TODO: change once full model is implemented from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf diff --git a/src/open_clip/vivit.py b/src/open_clip/video_model.py similarity index 55% rename from src/open_clip/vivit.py rename to src/open_clip/video_model.py index 891aa6575..24480abc0 100644 --- a/src/open_clip/vivit.py +++ b/src/open_clip/video_model.py @@ -8,6 +8,12 @@ from torch.nn import functional as F from dataclasses import dataclass +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + Transformer, +) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower @@ -21,20 +27,60 @@ class TemporalCfg: pooler_type: str = "cls_pooler" +# TODO: ViViT class makes this function a bit pointless +# still thinking about how to organize this better def _build_video_tower( embed_dim, vision_cfg, temporal_cfg, + quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): - vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg - temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) 
else temporal_cfg + model = ViViT( + embed_dim, + vision_cfg, + temporal_cfg, + quick_gelu, + cast_dtype, + ) + + return model # TODO: implement class ViViT(nn.Module): - def __init__(self): - pass + def __init__( + self, + embed_dim, + vision_cfg, + temporal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) else temporal_cfg + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + self.spatial = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + self.temporal = Transformer( + width=temporal_cfg.width, + layers=temporal_cfg.layers, + heads=temporal_cfg.heads, + mlp_ratio=temporal_cfg.mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + ) # TODO: turn into VideoCoCa @@ -47,6 +93,7 @@ def __init__( vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, temporal_cfg: TemporalCfg, + quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -54,7 +101,14 @@ def __init__( text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) else temporal_cfg - ''' + self.visual = _build_video_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + temporal_cfg=temporal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + self.text = _build_text_tower( embed_dim=embed_dim, text_cfg=text_cfg, @@ -67,6 +121,3 @@ def __init__( if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None else text_cfg.vocab_size ) - ''' - - From 0ad716816f120b8207c758f4a8c94dacf4082d97 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 00:57:09 +0000 Subject: [PATCH 05/26] update --- src/open_clip/factory.py | 36 +++++++++++++++++++++--------------- src/open_clip/video_model.py | 16 +++++++++++++--- src/training/video_data.py | 2 -- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index ccc0bc7ee..61ef7a5e8 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -309,21 +309,27 @@ def create_model_and_transforms( output_dict=output_dict, ) - image_mean = image_mean or getattr(model.visual, 'image_mean', None) - image_std = image_std or getattr(model.visual, 'image_std', None) - preprocess_train = image_transform( - model.visual.image_size, - is_train=True, - mean=image_mean, - std=image_std, - aug_cfg=aug_cfg, - ) - preprocess_val = image_transform( - model.visual.image_size, - is_train=False, - mean=image_mean, - std=image_std, - ) + # TODO: better way of getting modality specific transforms + if "ViViT" in model_name: + # TODO: make better preprocessing functions + preprocess_train = lambda vid: vid[:32, :128, :128, :] + preprocess_val = preprocess_train + else: + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) return model, 
preprocess_train, preprocess_val diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 24480abc0..cb510838a 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -1,6 +1,3 @@ -""" -ViViT model (https://arxiv.org/abs/2103.15691) -""" from typing import Optional import torch @@ -49,6 +46,7 @@ def _build_video_tower( # TODO: implement class ViViT(nn.Module): + """ViViT model (https://arxiv.org/abs/2103.15691)""" def __init__( self, embed_dim, @@ -82,6 +80,9 @@ def __init__( norm_layer=norm_layer, ) + def forward(self, x): + return x + # TODO: turn into VideoCoCa # TODO: implement @@ -121,3 +122,12 @@ def __init__( if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None else text_cfg.vocab_size ) + + def encode_video(self, video, normalize: bool = False): + features = self.visual(video) + return F.normalize(features, dim=-1) if normalize else features + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + diff --git a/src/training/video_data.py b/src/training/video_data.py index e546d2648..f220b439c 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -154,8 +154,6 @@ def __iter__(self): def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): - # TODO: get this from the model - preprocess_vid = lambda vid: vid[:32, :128, :128, :] # TODO: adjust this so it works for testing num_samples = args.train_num_samples batch_size = args.batch_size From f9dfd02438e30db109f3d935cd5aca7700b76114 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 01:39:37 +0000 Subject: [PATCH 06/26] update --- src/open_clip/video_model.py | 21 +++++++++++++++------ src/training/main.py | 3 +-- src/training/train.py | 8 ++++++-- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index cb510838a..d92133976 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -80,13 +80,12 @@ def __init__( norm_layer=norm_layer, ) - def forward(self, x): - return x + def forward(self, video): + return torch.randn((video.shape[0], 512)) # TODO: turn into VideoCoCa -# TODO: implement -# TODO: do we need quickgelu? +# TODO: set_grad_checkpointing class VideoCLIP(nn.Module): def __init__( self, @@ -128,6 +127,16 @@ def encode_video(self, video, normalize: bool = False): return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) + print(features.shape) return F.normalize(features, dim=-1) if normalize else features - - + def forward(self, video, text): + video_features = self.encode_video(video, normalize=True) + text_features = self.encode_text(text, normalize=True) + print(video_features.shape, text_features.shape) + # TODO: make loss funcitons generalize to all types of modality pairs + # i.e. 
make keys more general, for now keeping as image_features + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } diff --git a/src/training/main.py b/src/training/main.py index fddd8b263..cb58183cc 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -336,8 +336,7 @@ def main(args): # initialize datasets # TODO: come up with a way of getting alternative modality data based on model config - # if "ViViT" in model_name: - if True: + if "ViViT" in args.model: data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) else: data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) diff --git a/src/training/train.py b/src/training/train.py index e39e814f9..11c6ca911 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -90,16 +90,20 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # TODO: adapt dataloaders to fit open_clip format # TODO: generalize train loop to modality1, modality2 instead of image,text maybe - images, texts = batch["video_tensor"], batch["text_tokens"] + # images, texts = batch["video_tensor"], batch["text_tokens"] + images, texts = batch images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) data_time_m.update(time.time() - end) optimizer.zero_grad() + print(images.shape, texts.shape) + if args.accum_freq == 1: with autocast(): model_out = model(images, texts) + print(model_out.keys()) logit_scale = model_out["logit_scale"] if args.distill: with torch.no_grad(): @@ -108,8 +112,8 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist losses = loss(**model_out, output_dict=True) total_loss = sum(losses.values()) + print(total_loss) losses["loss"] = total_loss - backward(total_loss, scaler) else: # First, cache the features without any gradient tracking. 
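The dict returned by the model is shaped to match what open_clip's existing ClipLoss consumes in the loop above; a minimal sketch of that contract with dummy tensors (logit_scale is passed already exponentiated, as forward() does):

    import torch
    import torch.nn.functional as F
    from open_clip.loss import ClipLoss

    loss_fn = ClipLoss()
    model_out = {
        "image_features": F.normalize(torch.randn(4, 512), dim=-1),  # video embeddings reuse the image key
        "text_features": F.normalize(torch.randn(4, 512), dim=-1),
        "logit_scale": torch.tensor(1 / 0.07),  # = exp(log(1 / 0.07))
    }
    losses = loss_fn(**model_out, output_dict=True)  # {"contrastive_loss": ...}
    total_loss = sum(losses.values())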
From df1c6983ec467f5ca6a8a0d87d349e756b48680a Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 02:11:34 +0000 Subject: [PATCH 07/26] embeddings get to loss, time to implement video encoding --- src/open_clip/model_configs/ViViT-B-32.json | 11 +++-------- src/open_clip/video_model.py | 7 ++++--- src/training/main.py | 4 ++-- src/training/train.py | 8 ++------ src/training/video_data.py | 4 ++-- 5 files changed, 13 insertions(+), 21 deletions(-) diff --git a/src/open_clip/model_configs/ViViT-B-32.json b/src/open_clip/model_configs/ViViT-B-32.json index b5ce5369b..06dd83e1e 100644 --- a/src/open_clip/model_configs/ViViT-B-32.json +++ b/src/open_clip/model_configs/ViViT-B-32.json @@ -4,19 +4,14 @@ "image_size": 224, "layers": 12, "width": 768, - "patch_size": 32, - "attentional_pool": true, - "attn_pooler_heads": 8, - "output_tokens": true + "patch_size": 32 }, "text_cfg": { - "context_length": 76, + "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, - "layers": 12, - "embed_cls": true, - "output_tokens": true + "layers": 12 }, "temporal_cfg": { "context_length": 32, diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index d92133976..53aaccbc5 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -3,6 +3,7 @@ import torch from torch import nn from torch.nn import functional as F +import numpy as np from dataclasses import dataclass from .transformer import ( @@ -122,21 +123,21 @@ def __init__( else text_cfg.vocab_size ) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + def encode_video(self, video, normalize: bool = False): features = self.visual(video) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) - print(features.shape) return F.normalize(features, dim=-1) if normalize else features def forward(self, video, text): video_features = self.encode_video(video, normalize=True) text_features = self.encode_text(text, normalize=True) - print(video_features.shape, text_features.shape) # TODO: make loss funcitons generalize to all types of modality pairs # i.e. 
make keys more general, for now keeping as image_features return { - "image_features": image_features, + "image_features": video_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } diff --git a/src/training/main.py b/src/training/main.py index cb58183cc..b4f2c08b5 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -29,7 +29,7 @@ from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from training.data import get_data -from training.video_data import get_data # TODO: maybe we don't need separate files +from training.video_data import get_video_data # TODO: maybe we don't need separate files from training.distributed import is_master, init_distributed_device, broadcast_object from training.logger import setup_logging from training.params import parse_args @@ -337,7 +337,7 @@ def main(args): # initialize datasets # TODO: come up with a way of getting alternative modality data based on model config if "ViViT" in args.model: - data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + data = get_video_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) else: data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) assert len(data), 'At least one train or eval dataset must be specified.' diff --git a/src/training/train.py b/src/training/train.py index 11c6ca911..28a19510e 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -90,20 +90,17 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # TODO: adapt dataloaders to fit open_clip format # TODO: generalize train loop to modality1, modality2 instead of image,text maybe - # images, texts = batch["video_tensor"], batch["text_tokens"] - images, texts = batch + images, texts = batch["video_tensor"], batch["text_tokens"] + # images, texts = batch images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) data_time_m.update(time.time() - end) optimizer.zero_grad() - print(images.shape, texts.shape) - if args.accum_freq == 1: with autocast(): model_out = model(images, texts) - print(model_out.keys()) logit_scale = model_out["logit_scale"] if args.distill: with torch.no_grad(): @@ -112,7 +109,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist losses = loss(**model_out, output_dict=True) total_loss = sum(losses.values()) - print(total_loss) losses["loss"] = total_loss backward(total_loss, scaler) else: diff --git a/src/training/video_data.py b/src/training/video_data.py index f220b439c..77ef59ded 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -79,7 +79,7 @@ def preprocess_dataset(item): if enable_text: text = item[caption_key] caption = text.decode("utf-8") - tokenized_text = tokenizer(caption) + tokenized_text = tokenizer(caption)[0] output["text_tokens"] = tokenized_text output["text"] = caption @@ -172,7 +172,7 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni return DataInfo(dataloader=wds, shared_epoch=shared_epoch) -def get_data(args, preprocess_fns, epoch=0, tokenizer=None): +def get_video_data(args, preprocess_fns, epoch=0, tokenizer=None): preprocess_train, preprocess_val = preprocess_fns data = {} From 7643ce3d3f4c82d438276f09561189d7be750c53 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 
03:15:29 +0000 Subject: [PATCH 08/26] update, set num_samples --- src/open_clip/video_model.py | 26 ++++++++++++++++++++++++-- src/training/video_data.py | 16 ++++++++++++++-- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 53aaccbc5..392ad2fdc 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -46,8 +46,9 @@ def _build_video_tower( return model # TODO: implement +# TODO: maybe add option for mean pooling class ViViT(nn.Module): - """ViViT model (https://arxiv.org/abs/2103.15691)""" + """ViViT model (https://arxiv.org/abs/2103.15691), factorised encoder variant""" def __init__( self, embed_dim, @@ -66,6 +67,11 @@ def __init__( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) + # class embeddings and positional embeddings + scale = temporal_cfg.width ** -0.5 + self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width)) + self.video_positional_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.context_length, temporal_cfg.width)) + self.spatial = _build_vision_tower( embed_dim=embed_dim, vision_cfg=vision_cfg, @@ -81,8 +87,24 @@ def __init__( norm_layer=norm_layer, ) + # TODO: add patch dropout as suggested by lucidrains def forward(self, video): - return torch.randn((video.shape[0], 512)) + video = video[:, 1:] # make space for temporal CLS token + # TODO: make this happen + f_e = torch.randn((video.shape[0], video.shape[1], 512)).to(video.device) # shape = [b, cl-1, w] + + # class embeddings and positional embeddings + f_e = torch.cat( + [self.video_class_embedding.to(f_e.dtype) + torch.zeros(f_e.shape[0], 1, f_e.shape[-1], dtype=f_e.dtype, device=f_e.device), + f_e], dim=1) # shape = [b, cl, w] + f_e = f_e + self.video_positional_embedding.to(f_e.dtype) + + return f_e.mean(dim=1) + + + + + # TODO: turn into VideoCoCa diff --git a/src/training/video_data.py b/src/training/video_data.py index 77ef59ded..a3198b006 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -1,5 +1,6 @@ """video dataset creation""" import io +import math import torchvision import tempfile import webdataset as wds @@ -155,18 +156,29 @@ def __iter__(self): def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): num_samples = args.train_num_samples - batch_size = args.batch_size wds = VideoDatasetReader( sampler=lambda a: a, preprocess=preprocess_vid, input_dataset=args.train_data, - batch_size=batch_size, + batch_size=args.batch_size, num_prepro_workers=96, enable_metadata=True, tokenizer=tokenizer, ) + + round_fn = math.floor + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + wds.num_batches = num_batches + wds.num_samples = num_samples + shared_epoch = SharedEpoch(epoch=epoch) return DataInfo(dataloader=wds, shared_epoch=shared_epoch) From cb12acdccf3437fbf47ee81ab1107438f862d42e Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 08:04:56 +0000 Subject: [PATCH 09/26] more filling in --- src/open_clip/factory.py | 11 +++++++++-- src/open_clip/video_model.py | 29 ++++++++++++++++++++++++----- src/training/video_data.py | 3 +-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git 
a/src/open_clip/factory.py b/src/open_clip/factory.py index 61ef7a5e8..7024d21e3 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.nn.functional as F from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ @@ -312,8 +313,14 @@ def create_model_and_transforms( # TODO: better way of getting modality specific transforms if "ViViT" in model_name: # TODO: make better preprocessing functions - preprocess_train = lambda vid: vid[:32, :128, :128, :] - preprocess_val = preprocess_train + def preprocess_video(video): + video = video[:32, :, :224, :224] + h, w = video.shape[-2:] + video = F.pad(video, (0, 224-w, 0, 224-h)) + return video.float() + + preprocess_train = preprocess_video + preprocess_val = preprocess_video else: image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 392ad2fdc..85b2184ef 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional, Sequence, Tuple import torch from torch import nn @@ -54,6 +54,7 @@ def __init__( embed_dim, vision_cfg, temporal_cfg, + global_average_pool: bool = False, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): @@ -72,6 +73,8 @@ def __init__( self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width)) self.video_positional_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.context_length, temporal_cfg.width)) + self.ln_pre = norm_layer(temporal_cfg.width) + self.spatial = _build_vision_tower( embed_dim=embed_dim, vision_cfg=vision_cfg, @@ -87,24 +90,40 @@ def __init__( norm_layer=norm_layer, ) + self.global_average_pool = global_average_pool + + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + # TODO: add patch dropout as suggested by lucidrains def forward(self, video): video = video[:, 1:] # make space for temporal CLS token # TODO: make this happen - f_e = torch.randn((video.shape[0], video.shape[1], 512)).to(video.device) # shape = [b, cl-1, w] + batch_size = video.shape[0] + frames = video.flatten(start_dim=0, end_dim=1) + f_e = self.spatial(frames) + f_e = f_e.view(*video.shape[:2], -1) + # class embeddings and positional embeddings f_e = torch.cat( [self.video_class_embedding.to(f_e.dtype) + torch.zeros(f_e.shape[0], 1, f_e.shape[-1], dtype=f_e.dtype, device=f_e.device), f_e], dim=1) # shape = [b, cl, w] f_e = f_e + self.video_positional_embedding.to(f_e.dtype) - return f_e.mean(dim=1) - + # TODO: need to look at paper again, section 3, equations (4,5,6) + # do we need the residual connections? 
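    # A sketch of the factorised-encoder shapes being assembled here, assuming
    # embed_dim = temporal width = 512 and 32-frame clips at 224x224 (not part of the diff):
    #   video   [b, 32, 3, 224, 224] -> drop one frame slot for the temporal CLS token -> [b, 31, 3, 224, 224]
    #   frames  [b*31, 3, 224, 224]  -> self.spatial (ViT)                             -> [b*31, 512]
    #   f_e     [b, 31, 512]         -> prepend video_class_embedding, add positional  -> [b, 32, 512]
    #   v_e     self.temporal(f_e)                                                     -> [b, 32, 512]
    #   pooled  v_e[:, 0] (cls) or v_e.mean(dim=1) when global_average_pool is set
    # Note: open_clip's Transformer has historically expected sequence-first (LND) input,
    # so a permute around self.temporal may be needed.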
+ f_e = self.ln_pre(f_e) + v_e = self.temporal(f_e) - + pooled, tokens = self._global_pool(v_e) + return pooled # TODO: turn into VideoCoCa diff --git a/src/training/video_data.py b/src/training/video_data.py index a3198b006..76ffca944 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -71,8 +71,7 @@ def preprocess_dataset(item): video_data = item[video_key] with tempfile.NamedTemporaryFile() as f: f.write(video_data) - # video = torchvision.io.read_video(video_data) - video, audio, meta = torchvision.io.read_video(f.name) + video, audio, meta = torchvision.io.read_video(f.name, output_format="TCHW") video_tensor = video_transform(video) output["video_filename"] = item["__key__"] output["video_tensor"] = video_tensor From 67a2d33643be81a8a4ae4934e2a7ed042412c7b7 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 10:59:23 +0000 Subject: [PATCH 10/26] slightly improved preprocessing --- src/open_clip/factory.py | 23 +++++++++++++---------- src/open_clip/transform.py | 32 ++++++++++++++++++++++++++++++++ src/training/train.py | 7 +++++-- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 7024d21e3..f78b43dc3 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -18,7 +18,7 @@ from .loss import ClipLoss, DistillClipLoss, CoCaLoss from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf -from .transform import image_transform, AugmentationCfg +from .transform import image_transform, video_transform, AugmentationCfg from .tokenizer import HFTokenizer, tokenize @@ -312,15 +312,18 @@ def create_model_and_transforms( # TODO: better way of getting modality specific transforms if "ViViT" in model_name: - # TODO: make better preprocessing functions - def preprocess_video(video): - video = video[:32, :, :224, :224] - h, w = video.shape[-2:] - video = F.pad(video, (0, 224-w, 0, 224-h)) - return video.float() - - preprocess_train = preprocess_video - preprocess_val = preprocess_video + preprocess_train = video_transform( + frame_size=model.visual.spatial.image_size, + n_frames=32, + take_every_nth=2, + is_train=False, # TODO: figre out if frame augmentations make sense + ) + preprocess_val = video_transform( + frame_size=model.visual.spatial.image_size, + n_frames=32, + take_every_nth=2, + is_train=False, + ) else: image_mean = image_mean or getattr(model.visual, 'image_mean', None) image_std = image_std or getattr(model.visual, 'image_std', None) diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 748884a3c..e593dc8a9 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -131,3 +131,35 @@ def image_transform( normalize, ]) return Compose(transforms) + + +# TODO: needs improvmenet +def video_transform( + frame_size: int, + n_frames: int, + take_every_nth: int, + is_train: bool, + ): + + if is_train: + transforms = [ + RandomResizedCrop( + frame_size, + scale=(0.9, 0.1), + interpolation=InterpolationMode.BICUBIC, + ), + ] + else: + transforms = [ + Resize(frame_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(frame_size), + ] + + frame_transform = Compose(transforms) + def apply_frame_transform(video): + video = video[::take_every_nth] + video = video[:n_frames] # TODO: maybe make this middle n frames + # TODO: this .float() is weird, look how this is done in other places + return 
torch.cat([frame_transform(frame)[None, ...].float() for frame in video]) + + return apply_frame_transform diff --git a/src/training/train.py b/src/training/train.py index 28a19510e..59ff723ff 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -90,8 +90,11 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # TODO: adapt dataloaders to fit open_clip format # TODO: generalize train loop to modality1, modality2 instead of image,text maybe - images, texts = batch["video_tensor"], batch["text_tokens"] - # images, texts = batch + if "ViViT" in args.model: + images, texts = batch["video_tensor"], batch["text_tokens"] + else: + images, texts = batch + images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) From f0615f9ffa54b47a51c5e3bcb34244826407d706 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 12:12:20 +0000 Subject: [PATCH 11/26] update --- src/open_clip/factory.py | 4 ++-- src/open_clip/video_model.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index f78b43dc3..ee259c27f 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -315,13 +315,13 @@ def create_model_and_transforms( preprocess_train = video_transform( frame_size=model.visual.spatial.image_size, n_frames=32, - take_every_nth=2, + take_every_nth=1, is_train=False, # TODO: figre out if frame augmentations make sense ) preprocess_val = video_transform( frame_size=model.visual.spatial.image_size, n_frames=32, - take_every_nth=2, + take_every_nth=1, is_train=False, ) else: diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 85b2184ef..464d71a0b 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -92,6 +92,10 @@ def __init__( self.global_average_pool = global_average_pool + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.spatial.set_grad_checkpointing(enable) + self.temporal.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_average_pool: @@ -166,6 +170,11 @@ def __init__( self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + def encode_video(self, video, normalize: bool = False): features = self.visual(video) return F.normalize(features, dim=-1) if normalize else features From 7bd848e2c0ca3720130eca409d27fd0cc224dd75 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 19 Feb 2023 16:20:06 +0000 Subject: [PATCH 12/26] update weird lag --- src/open_clip/transform.py | 10 ++++++++++ src/open_clip/video_model.py | 2 ++ src/training/train.py | 10 ++++++++++ src/training/video_data.py | 9 ++++----- 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index e593dc8a9..bbdf19ee5 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -159,6 +159,16 @@ def video_transform( def apply_frame_transform(video): video = video[::take_every_nth] video = video[:n_frames] # TODO: maybe make this middle n frames + + # TODO: maybe padding isn't the way to go + # TODO: also F.pad is acting up for some reason + # isn't letting me input a len 8 tuple for 4d tnesor??? 
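On the F.pad confusion above, a hedged aside: torch.nn.functional.pad reads the pad tuple from the last dimension backwards, and in the default constant mode it accepts an 8-element tuple for a 4D tensor, so padding only the time axis of a [T, C, H, W] clip looks like this (other padding modes are limited to fewer dimensions):

    import torch
    import torch.nn.functional as F

    video = torch.zeros(20, 3, 224, 224)  # e.g. only 20 frames decoded
    n_frames = 32
    # pad pairs apply to W, H, C, T in that order; only T is padded here
    video = F.pad(video, (0, 0, 0, 0, 0, 0, 0, n_frames - video.shape[0]))
    assert video.shape == (32, 3, 224, 224)  # zero-padded in time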
+ # video = F.pad(video, tuple([0, 0]*len(video.shape[-3:]) + [0, n_frames - video.shape[0]])) + if video.shape[0] < n_frames: + padded_video = torch.zeros(n_frames, *video.shape[1:]) + padded_video[:video.shape[0]] = video + video = padded_video + # TODO: this .float() is weird, look how this is done in other places return torch.cat([frame_transform(frame)[None, ...].float() for frame in video]) diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 464d71a0b..1b701b324 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -109,8 +109,10 @@ def forward(self, video): # TODO: make this happen batch_size = video.shape[0] + # Flatten all frames in batch across time and encode with ViT frames = video.flatten(start_dim=0, end_dim=1) f_e = self.spatial(frames) + # Put frame embeddings back into correct temporal sequences f_e = f_e.view(*video.shape[:2], -1) # class embeddings and positional embeddings diff --git a/src/training/train.py b/src/training/train.py index 59ff723ff..b0e0a8a79 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -82,6 +82,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist data_time_m = AverageMeter() end = time.time() for i, batch in enumerate(dataloader): + print(batch) i_accum = i // args.accum_freq step = num_batches_per_epoch * epoch + i_accum @@ -97,6 +98,15 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) + print("OUTOUTOUT") + print("OUTOUTOUT") + print("OUTOUTOUT") + print("OUTOUTOUT") + print("OUTOUTOUT") + print(images.shape) + print(images.shape) + print(images.shape) + print(images.shape) data_time_m.update(time.time() - end) optimizer.zero_grad() diff --git a/src/training/video_data.py b/src/training/video_data.py index 76ffca944..f84079295 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -73,6 +73,7 @@ def preprocess_dataset(item): f.write(video_data) video, audio, meta = torchvision.io.read_video(f.name, output_format="TCHW") video_tensor = video_transform(video) + print(video_tensor.shape) output["video_filename"] = item["__key__"] output["video_tensor"] = video_tensor @@ -100,13 +101,12 @@ def collate_fn(batch): batch = list(filter(lambda x: x is not None, batch)) return default_collate(batch) + # pin_memory=True, + # prefetch_factor=2, data = DataLoader( dataset, batch_size=batch_size, - shuffle=False, num_workers=num_prepro_workers, - pin_memory=True, - prefetch_factor=2, collate_fn=collate_fn if input_format == "files" else None, ) return data @@ -161,12 +161,11 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni preprocess=preprocess_vid, input_dataset=args.train_data, batch_size=args.batch_size, - num_prepro_workers=96, + num_prepro_workers=args.workers, enable_metadata=True, tokenizer=tokenizer, ) - round_fn = math.floor global_batch_size = args.batch_size * args.world_size num_batches = round_fn(num_samples / global_batch_size) From 80f41a068cf5c2db8b30ebf682a6ec188aea516b Mon Sep 17 00:00:00 2001 From: iejMac Date: Mon, 20 Feb 2023 08:20:37 +0000 Subject: [PATCH 13/26] simpler dataloader same results --- src/open_clip/transform.py | 5 +- src/training/train.py | 16 +--- src/training/video_data.py | 192 ++++++++++++++----------------------- 3 files changed, 77 insertions(+), 136 deletions(-) diff --git a/src/open_clip/transform.py 
b/src/open_clip/transform.py index bbdf19ee5..7e6ecc41a 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -156,7 +156,10 @@ def video_transform( ] frame_transform = Compose(transforms) - def apply_frame_transform(video): + def apply_frame_transform(sample): + video, audio, video_meta = sample + video = video.permute(0, 3, 1, 2) + video = video[::take_every_nth] video = video[:n_frames] # TODO: maybe make this middle n frames diff --git a/src/training/train.py b/src/training/train.py index b0e0a8a79..3dbf2f643 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -82,31 +82,17 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist data_time_m = AverageMeter() end = time.time() for i, batch in enumerate(dataloader): - print(batch) i_accum = i // args.accum_freq step = num_batches_per_epoch * epoch + i_accum if not args.skip_scheduler: scheduler(step) - # TODO: adapt dataloaders to fit open_clip format # TODO: generalize train loop to modality1, modality2 instead of image,text maybe - if "ViViT" in args.model: - images, texts = batch["video_tensor"], batch["text_tokens"] - else: - images, texts = batch + images, texts = batch images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) - print("OUTOUTOUT") - print("OUTOUTOUT") - print("OUTOUTOUT") - print("OUTOUTOUT") - print("OUTOUTOUT") - print(images.shape) - print(images.shape) - print(images.shape) - print(images.shape) data_time_m.update(time.time() - end) optimizer.zero_grad() diff --git a/src/training/video_data.py b/src/training/video_data.py index f84079295..8770c5dc8 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -10,6 +10,7 @@ from pathlib import Path from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate +from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample class SharedEpoch: @@ -36,136 +37,88 @@ def set_epoch(self, epoch): self.sampler.set_epoch(epoch) +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + logging.warning(f'Handling webdataset error ({repr(exn)}). 
Ignoring.') + return True + +def tarfile_to_samples_nothrow(src, handler=log_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + def create_webdataset( - urls, + args, video_transform, - enable_text=True, - enable_video=True, - video_key="mp4", - caption_key="txt", - enable_metadata=False, - cache_path=None, - input_sampler=lambda a: a, tokenizer=None, ): - """Create a WebDataset reader, it can read a webdataset of video, text and json""" - - urls = input_sampler(urls) - - dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10, handler=wds.handlers.warn_and_continue) - - def filter_dataset(item): - if enable_text and caption_key not in item: - return False - if enable_video and video_key not in item: - return False - if enable_metadata and "json" not in item: - return False - return True - - filtered_dataset = dataset.select(filter_dataset) - - def preprocess_dataset(item): - output = {} - if enable_video: - video_data = item[video_key] - with tempfile.NamedTemporaryFile() as f: - f.write(video_data) - video, audio, meta = torchvision.io.read_video(f.name, output_format="TCHW") - video_tensor = video_transform(video) - print(video_tensor.shape) - output["video_filename"] = item["__key__"] - output["video_tensor"] = video_tensor - - if enable_text: - text = item[caption_key] - caption = text.decode("utf-8") - tokenized_text = tokenizer(caption)[0] - output["text_tokens"] = tokenized_text - output["text"] = caption - - if enable_metadata: - metadata_file = item["json"] - metadata = metadata_file.decode("utf-8") - output["metadata"] = metadata - return output - - transformed_dataset = filtered_dataset.map(preprocess_dataset, handler=wds.handlers.warn_and_continue) - return transformed_dataset - - -def dataset_to_dataloader(dataset, batch_size, num_prepro_workers, input_format): - """Create a pytorch dataloader from a dataset""" - - def collate_fn(batch): - batch = list(filter(lambda x: x is not None, batch)) - return default_collate(batch) - - # pin_memory=True, - # prefetch_factor=2, - data = DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_prepro_workers, - collate_fn=collate_fn if input_format == "files" else None, - ) - return data + pipeline = [wds.SimpleShardList(args.train_data)] + is_train = True + + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + wds.decode(wds.torch_video, handler=log_and_continue), + wds.rename(video="mp4", text="txt"), + wds.map_dict(video=video_transform, text=lambda text: tokenizer(text)[0]), + wds.to_tuple("video", "text"), + wds.batched(args.batch_size, partial=not is_train) + ]) -class VideoDatasetReader: - """WebdatasetReader is a reader that reads samples from a webdataset""" - - def __init__( - self, - sampler, - preprocess, - input_dataset, - batch_size, - num_prepro_workers, - enable_text=True, - enable_video=True, - enable_metadata=False, - wds_video_key="mp4", - wds_caption_key="txt", - cache_path=None, - tokenizer=None, - ): - self.batch_size = batch_size - dataset = create_webdataset( - input_dataset, - preprocess, - enable_text=enable_text, - enable_video=enable_video, - video_key=wds_video_key, - caption_key=wds_caption_key, - enable_metadata=enable_metadata, - 
cache_path=cache_path, - input_sampler=sampler, - tokenizer=tokenizer, - ) - self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "webdataset") - self.num_batches = 0 - self.num_samples = 0 - - - def __iter__(self): - for batch in self.dataloader: - yield batch + dataset = wds.DataPipeline(*pipeline) + return dataset def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): num_samples = args.train_num_samples - wds = VideoDatasetReader( - sampler=lambda a: a, - preprocess=preprocess_vid, - input_dataset=args.train_data, - batch_size=args.batch_size, - num_prepro_workers=args.workers, - enable_metadata=True, + dataset = create_webdataset( + args, + preprocess_vid, tokenizer=tokenizer, ) + dataloader = wds.WebLoader( + dataset, + batch_size=None, + num_workers=args.workers, + persistent_workers=True, + prefetch_factor=8, + pin_memory=True, + ) + round_fn = math.floor global_batch_size = args.batch_size * args.world_size num_batches = round_fn(num_samples / global_batch_size) @@ -174,12 +127,12 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni num_batches = num_worker_batches * num_workers num_samples = num_batches * global_batch_size - wds.num_batches = num_batches - wds.num_samples = num_samples + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples shared_epoch = SharedEpoch(epoch=epoch) - return DataInfo(dataloader=wds, shared_epoch=shared_epoch) + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) def get_video_data(args, preprocess_fns, epoch=0, tokenizer=None): @@ -190,4 +143,3 @@ def get_video_data(args, preprocess_fns, epoch=0, tokenizer=None): data["train"] = get_wds_dataset(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) return data - From f5af600df04f56cb09fa567922eab8791a889697 Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 21 Feb 2023 01:32:36 +0000 Subject: [PATCH 14/26] properly normalize frames --- src/open_clip/factory.py | 4 ++ src/open_clip/transform.py | 21 ++++++++++- src/open_clip/video_model.py | 11 +++++- src/training/main.py | 1 + src/training/video_data.py | 72 +++++++++++++++++++++++++++--------- 5 files changed, 88 insertions(+), 21 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index ee259c27f..aab02a301 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -317,12 +317,16 @@ def create_model_and_transforms( n_frames=32, take_every_nth=1, is_train=False, # TODO: figre out if frame augmentations make sense + frame_mean=None, + frame_std=None, ) preprocess_val = video_transform( frame_size=model.visual.spatial.image_size, n_frames=32, take_every_nth=1, is_train=False, + frame_mean=None, + frame_std=None, ) else: image_mean = image_mean or getattr(model.visual, 'image_mean', None) diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 7e6ecc41a..1266a3554 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -7,7 +7,7 @@ import torchvision.transforms.functional as F from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ - CenterCrop + CenterCrop, ToPILImage from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD @@ -139,7 +139,19 @@ def video_transform( n_frames: int, take_every_nth: int, is_train: bool, + frame_mean: Optional[Tuple[float, ...]] = None, + frame_std: Optional[Tuple[float, ...]] = None, ): + + frame_mean = frame_mean or OPENAI_DATASET_MEAN 
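Back in get_wds_dataset above, the num_batches / num_samples bookkeeping mirrors open_clip's image webdataset path; a worked example of the floor rounding, assuming 1,000,000 training samples, batch_size 64, world_size 8 and 4 dataloader workers (illustrative numbers only):

    import math

    num_samples, batch_size, world_size, workers = 1_000_000, 64, 8, 4
    global_batch_size = batch_size * world_size                 # 512
    num_batches = math.floor(num_samples / global_batch_size)   # 1953
    num_worker_batches = math.floor(num_batches / workers)      # 488 batches per dataloader worker
    num_batches = num_worker_batches * workers                  # 1952
    num_samples = num_batches * global_batch_size               # 999424 samples actually seen per epoch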
+ if not isinstance(frame_mean, (list, tuple)): + frame_mean = (frame_mean,) * 3 + + frame_std = frame_std or OPENAI_DATASET_STD + if not isinstance(frame_std, (list, tuple)): + frame_std = (frame_std,) * 3 + + normalize = Normalize(mean=frame_mean, std=frame_std) if is_train: transforms = [ @@ -148,11 +160,16 @@ def video_transform( scale=(0.9, 0.1), interpolation=InterpolationMode.BICUBIC, ), + normalize, ] else: transforms = [ + ToPILImage(), Resize(frame_size, interpolation=InterpolationMode.BICUBIC), CenterCrop(frame_size), + _convert_to_rgb, + ToTensor(), + normalize, ] frame_transform = Compose(transforms) @@ -173,6 +190,6 @@ def apply_frame_transform(sample): video = padded_video # TODO: this .float() is weird, look how this is done in other places - return torch.cat([frame_transform(frame)[None, ...].float() for frame in video]) + return torch.cat([frame_transform(frame.float())[None, ...] for frame in video]) return apply_frame_transform diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 1b701b324..2b8ff3227 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -90,7 +90,8 @@ def __init__( norm_layer=norm_layer, ) - self.global_average_pool = global_average_pool + # self.global_average_pool = global_average_pool + self.global_average_pool = True @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -105,6 +106,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: add patch dropout as suggested by lucidrains def forward(self, video): + # print(video[:, :10, 0, 0, 0]) video = video[:, 1:] # make space for temporal CLS token # TODO: make this happen batch_size = video.shape[0] @@ -114,6 +116,9 @@ def forward(self, video): f_e = self.spatial(frames) # Put frame embeddings back into correct temporal sequences f_e = f_e.view(*video.shape[:2], -1) + + # print("FRAME EMBS") + # print(f_e[:, :10, 0]) # class embeddings and positional embeddings f_e = torch.cat( @@ -129,6 +134,9 @@ def forward(self, video): pooled, tokens = self._global_pool(v_e) + # print("POOOOLED") + # print(pooled[:, :10]) + return pooled @@ -186,6 +194,7 @@ def encode_text(self, text, normalize: bool = False): def forward(self, video, text): video_features = self.encode_video(video, normalize=True) text_features = self.encode_text(text, normalize=True) + # print(text_features[:, :10]) # TODO: make loss funcitons generalize to all types of modality pairs # i.e. make keys more general, for now keeping as image_features return { diff --git a/src/training/main.py b/src/training/main.py index b4f2c08b5..e7a1e70c3 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -340,6 +340,7 @@ def main(args): data = get_video_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) else: data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model)) + assert len(data), 'At least one train or eval dataset must be specified.' 
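For reference, what the per-frame pipeline in video_transform amounts to for a single decoded frame (a sketch using the CLIP normalization constants; note that ToPILImage expects uint8 CHW input or floats already in [0, 1], so raw 0-255 float frames may need a divide by 255 first):

    import torch
    from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
                                        Normalize, Resize, ToPILImage, ToTensor)
    from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

    frame = torch.randint(0, 256, (3, 360, 640), dtype=torch.uint8)  # one decoded frame, CHW
    frame_transform = Compose([
        ToPILImage(),
        Resize(224, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(224),
        ToTensor(),  # back to float in [0, 1]
        Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
    ])
    out = frame_transform(frame)  # -> [3, 224, 224], CLIP-normalized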
# create scheduler if train diff --git a/src/training/video_data.py b/src/training/video_data.py index 8770c5dc8..4267dfc60 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -1,6 +1,8 @@ """video dataset creation""" import io +import logging import math +import random import torchvision import tempfile import webdataset as wds @@ -10,6 +12,7 @@ from pathlib import Path from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate +from webdataset.filters import _shuffle from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample @@ -75,40 +78,74 @@ def tarfile_to_samples_nothrow(src, handler=log_and_continue): samples = group_by_keys_nothrow(files, handler=handler) return samples -def create_webdataset( - args, - video_transform, - tokenizer=None, -): + +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + rng = random.Random() + if self.seed < 0: + # If seed is negative, we use the worker's seed, this will be different across all nodes/workers + seed = pytorch_worker_seed(epoch) + else: + # This seed to be deterministic AND the same across all nodes/workers in each epoch + seed = self.seed + epoch + rng.seed(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): + num_samples = args.train_num_samples + shared_epoch = SharedEpoch(epoch=epoch) + pipeline = [wds.SimpleShardList(args.train_data)] is_train = True pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), wds.split_by_node, wds.split_by_worker, tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), ]) + pipeline.extend([ wds.decode(wds.torch_video, handler=log_and_continue), wds.rename(video="mp4", text="txt"), - wds.map_dict(video=video_transform, text=lambda text: tokenizer(text)[0]), + wds.map_dict(video=preprocess_vid, text=lambda text: tokenizer(text)[0]), wds.to_tuple("video", "text"), wds.batched(args.batch_size, partial=not is_train) ]) dataset = wds.DataPipeline(*pipeline) - return dataset - - -def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): - num_samples = args.train_num_samples - - dataset = create_webdataset( - args, - preprocess_vid, - tokenizer=tokenizer, - ) dataloader = wds.WebLoader( dataset, @@ -130,7 +167,6 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni dataloader.num_batches = num_batches dataloader.num_samples = num_samples - shared_epoch = SharedEpoch(epoch=epoch) return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) From 982ee88769c6dc6e037db3befff6edf328fcec2d Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 21 Feb 2023 21:34:44 +0000 Subject: [PATCH 15/26] update no temporal --- src/open_clip/transform.py | 3 +++ 
src/open_clip/video_model.py | 21 +++++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 1266a3554..1045da5da 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -155,11 +155,14 @@ def video_transform( if is_train: transforms = [ + ToPILImage(), RandomResizedCrop( frame_size, scale=(0.9, 0.1), interpolation=InterpolationMode.BICUBIC, ), + _convert_to_rgb, + ToTensor(), normalize, ] else: diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 2b8ff3227..58c704e43 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -68,12 +68,14 @@ def __init__( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) + ''' # class embeddings and positional embeddings scale = temporal_cfg.width ** -0.5 self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width)) self.video_positional_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.context_length, temporal_cfg.width)) + ''' - self.ln_pre = norm_layer(temporal_cfg.width) + # self.ln_pre = norm_layer(temporal_cfg.width) self.spatial = _build_vision_tower( embed_dim=embed_dim, @@ -81,6 +83,7 @@ def __init__( quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) + ''' self.temporal = Transformer( width=temporal_cfg.width, layers=temporal_cfg.layers, @@ -89,14 +92,15 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, ) + ''' - # self.global_average_pool = global_average_pool - self.global_average_pool = True + + self.global_average_pool = global_average_pool @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.spatial.set_grad_checkpointing(enable) - self.temporal.grad_checkpointing = enable + # self.temporal.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_average_pool: @@ -117,6 +121,7 @@ def forward(self, video): # Put frame embeddings back into correct temporal sequences f_e = f_e.view(*video.shape[:2], -1) + ''' # print("FRAME EMBS") # print(f_e[:, :10, 0]) @@ -133,9 +138,13 @@ def forward(self, video): v_e = self.temporal(f_e) pooled, tokens = self._global_pool(v_e) + ''' + + pooled = torch.mean(f_e, dim=1) - # print("POOOOLED") - # print(pooled[:, :10]) + print("POOOOLED") + print(pooled[:10, :10]) + print(torch.mean(torch.var(pooled, dim=0))) return pooled From 3f4fde7c76ff7bc95654360c36ad41c17295ac73 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 26 Feb 2023 05:27:16 +0000 Subject: [PATCH 16/26] filter no mp4 samples --- src/training/video_data.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/training/video_data.py b/src/training/video_data.py index 4267dfc60..e6b1eda99 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -40,6 +40,12 @@ def set_epoch(self, epoch): self.sampler.set_epoch(epoch) +def filter_no_caption_or_no_video(sample): + has_caption = ('txt' in sample) + has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) + return has_caption and has_image + + def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): """Return function over iterator that groups key, value pairs into samples. 
:param keys: function that splits the key into key and extension (base_plus_ext) @@ -138,6 +144,7 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni pipeline.extend([ + wds.select(filter_no_caption_or_no_video), wds.decode(wds.torch_video, handler=log_and_continue), wds.rename(video="mp4", text="txt"), wds.map_dict(video=preprocess_vid, text=lambda text: tokenizer(text)[0]), From 7176d95af6c49ff95507ec5f0dda8cf73fe451e5 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 26 Feb 2023 05:36:55 +0000 Subject: [PATCH 17/26] update --- src/training/video_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/video_data.py b/src/training/video_data.py index e6b1eda99..eaf17788a 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -42,8 +42,8 @@ def set_epoch(self, epoch): def filter_no_caption_or_no_video(sample): has_caption = ('txt' in sample) - has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) - return has_caption and has_image + has_video = ('mp4' in sample) + return has_caption and has_video def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): From 3f64b62c949293357ae4a5b6c8e45e13428b0425 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 26 Feb 2023 08:44:04 +0000 Subject: [PATCH 18/26] adding projection removes weird const loss bug but training doesn't go very well --- src/open_clip/video_model.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 58c704e43..03fc6a523 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -68,14 +68,14 @@ def __init__( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) - ''' # class embeddings and positional embeddings scale = temporal_cfg.width ** -0.5 self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width)) self.video_positional_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.context_length, temporal_cfg.width)) - ''' - # self.ln_pre = norm_layer(temporal_cfg.width) + self.ln_pre = norm_layer(temporal_cfg.width) + self.ln_post = norm_layer(temporal_cfg.width) + self.proj = nn.Parameter(scale * torch.randn(temporal_cfg.width, temporal_cfg.width)) self.spatial = _build_vision_tower( embed_dim=embed_dim, @@ -83,7 +83,6 @@ def __init__( quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) - ''' self.temporal = Transformer( width=temporal_cfg.width, layers=temporal_cfg.layers, @@ -93,14 +92,20 @@ def __init__( norm_layer=norm_layer, ) ''' - + self.temporal = nn.Sequential( + nn.Linear(temporal_cfg.width, temporal_cfg.width*temporal_cfg.mlp_ratio), + act_layer(), + nn.Linear(temporal_cfg.width*temporal_cfg.mlp_ratio, temporal_cfg.width), + ) + ''' self.global_average_pool = global_average_pool + # self.global_average_pool = True @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.spatial.set_grad_checkpointing(enable) - # self.temporal.grad_checkpointing = enable + self.temporal.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_average_pool: @@ -110,9 +115,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: add patch dropout as suggested by lucidrains def forward(self, video): - # print(video[:, :10, 0, 0, 0]) video = video[:, 1:] # make space for temporal CLS token - # 
TODO: make this happen batch_size = video.shape[0] # Flatten all frames in batch across time and encode with ViT @@ -121,10 +124,6 @@ def forward(self, video): # Put frame embeddings back into correct temporal sequences f_e = f_e.view(*video.shape[:2], -1) - ''' - # print("FRAME EMBS") - # print(f_e[:, :10, 0]) - # class embeddings and positional embeddings f_e = torch.cat( [self.video_class_embedding.to(f_e.dtype) + torch.zeros(f_e.shape[0], 1, f_e.shape[-1], dtype=f_e.dtype, device=f_e.device), @@ -135,12 +134,13 @@ def forward(self, video): # do we need the residual connections? f_e = self.ln_pre(f_e) + f_e = f_e.permute(1, 0, 2) v_e = self.temporal(f_e) + v_e = v_e.permute(1, 0, 2) pooled, tokens = self._global_pool(v_e) - ''' - - pooled = torch.mean(f_e, dim=1) + pooled = self.ln_post(pooled) + pooled = pooled @ self.proj print("POOOOLED") print(pooled[:10, :10]) @@ -203,7 +203,6 @@ def encode_text(self, text, normalize: bool = False): def forward(self, video, text): video_features = self.encode_video(video, normalize=True) text_features = self.encode_text(text, normalize=True) - # print(text_features[:, :10]) # TODO: make loss funcitons generalize to all types of modality pairs # i.e. make keys more general, for now keeping as image_features return { From 8d3cc4895daa99eb677bf2606bf7f7c0e70f5d26 Mon Sep 17 00:00:00 2001 From: iejMac Date: Mon, 27 Feb 2023 01:08:00 +0000 Subject: [PATCH 19/26] some updates --- src/open_clip/transformer.py | 6 +++++- src/open_clip/video_model.py | 21 ++++++--------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 4e0151017..2205cb93a 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -495,9 +495,13 @@ def forward(self, x: torch.Tensor): if self.proj is not None: pooled = pooled @ self.proj + # print("POOOOLED") + # print(pooled[:10, :10]) + # print(torch.mean(torch.var(pooled, dim=0))) + if self.output_tokens: return pooled, tokens - + return pooled diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 03fc6a523..a947f01be 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -45,8 +45,7 @@ def _build_video_tower( return model -# TODO: implement -# TODO: maybe add option for mean pooling +# TODO: finish option for mean pooling (no cls token if global_average_pool == True) class ViViT(nn.Module): """ViViT model (https://arxiv.org/abs/2103.15691), factorised encoder variant""" def __init__( @@ -75,7 +74,7 @@ def __init__( self.ln_pre = norm_layer(temporal_cfg.width) self.ln_post = norm_layer(temporal_cfg.width) - self.proj = nn.Parameter(scale * torch.randn(temporal_cfg.width, temporal_cfg.width)) + self.proj = nn.Parameter(scale * torch.randn(temporal_cfg.width, embed_dim)) self.spatial = _build_vision_tower( embed_dim=embed_dim, @@ -91,16 +90,8 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, ) - ''' - self.temporal = nn.Sequential( - nn.Linear(temporal_cfg.width, temporal_cfg.width*temporal_cfg.mlp_ratio), - act_layer(), - nn.Linear(temporal_cfg.width*temporal_cfg.mlp_ratio, temporal_cfg.width), - ) - ''' self.global_average_pool = global_average_pool - # self.global_average_pool = True @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -142,9 +133,9 @@ def forward(self, video): pooled = self.ln_post(pooled) pooled = pooled @ self.proj - print("POOOOLED") - print(pooled[:10, :10]) - print(torch.mean(torch.var(pooled, dim=0))) + # print("POOLED") + # 
print(pooled[:10, :10]) + # print(torch.mean(torch.var(pooled, dim=0))) return pooled @@ -203,7 +194,7 @@ def encode_text(self, text, normalize: bool = False): def forward(self, video, text): video_features = self.encode_video(video, normalize=True) text_features = self.encode_text(text, normalize=True) - # TODO: make loss funcitons generalize to all types of modality pairs + # TODO: make loss functions generalize to all types of modality pairs # i.e. make keys more general, for now keeping as image_features return { "image_features": video_features, From 42304280e684676c8a0b2e0817c1a1c913258ea2 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 11 Mar 2023 23:32:56 +0000 Subject: [PATCH 20/26] save changes --- src/open_clip/factory.py | 8 ++++---- src/open_clip/transform.py | 2 ++ src/open_clip/video_model.py | 17 +++++++++++++++-- src/training/video_data.py | 10 ++++++---- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index aab02a301..3f902cc8e 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -314,16 +314,16 @@ def create_model_and_transforms( if "ViViT" in model_name: preprocess_train = video_transform( frame_size=model.visual.spatial.image_size, - n_frames=32, - take_every_nth=1, + n_frames=model.visual.context_length, + take_every_nth=5, is_train=False, # TODO: figre out if frame augmentations make sense frame_mean=None, frame_std=None, ) preprocess_val = video_transform( frame_size=model.visual.spatial.image_size, - n_frames=32, - take_every_nth=1, + n_frames=model.visual.context_length, + take_every_nth=5, is_train=False, frame_mean=None, frame_std=None, diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py index 1045da5da..51e72492a 100644 --- a/src/open_clip/transform.py +++ b/src/open_clip/transform.py @@ -187,6 +187,7 @@ def apply_frame_transform(sample): # TODO: also F.pad is acting up for some reason # isn't letting me input a len 8 tuple for 4d tnesor??? # video = F.pad(video, tuple([0, 0]*len(video.shape[-3:]) + [0, n_frames - video.shape[0]])) + if video.shape[0] < n_frames: padded_video = torch.zeros(n_frames, *video.shape[1:]) padded_video[:video.shape[0]] = video @@ -195,4 +196,5 @@ def apply_frame_transform(sample): # TODO: this .float() is weird, look how this is done in other places return torch.cat([frame_transform(frame.float())[None, ...] 
for frame in video]) + return apply_frame_transform diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index a947f01be..818cd295b 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -66,7 +66,9 @@ def __init__( norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) + self.context_length = temporal_cfg.context_length + ''' # class embeddings and positional embeddings scale = temporal_cfg.width ** -0.5 self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width)) @@ -75,6 +77,7 @@ def __init__( self.ln_pre = norm_layer(temporal_cfg.width) self.ln_post = norm_layer(temporal_cfg.width) self.proj = nn.Parameter(scale * torch.randn(temporal_cfg.width, embed_dim)) + ''' self.spatial = _build_vision_tower( embed_dim=embed_dim, @@ -82,6 +85,7 @@ def __init__( quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) + ''' self.temporal = Transformer( width=temporal_cfg.width, layers=temporal_cfg.layers, @@ -90,13 +94,15 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, ) + ''' self.global_average_pool = global_average_pool + self.global_average_pool = True @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.spatial.set_grad_checkpointing(enable) - self.temporal.grad_checkpointing = enable + # self.temporal.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_average_pool: @@ -106,7 +112,9 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: add patch dropout as suggested by lucidrains def forward(self, video): - video = video[:, 1:] # make space for temporal CLS token + print("VIDEO SHAPE") + print(video.shape) + # video = video[:, 1:] # make space for temporal CLS token batch_size = video.shape[0] # Flatten all frames in batch across time and encode with ViT @@ -114,7 +122,10 @@ def forward(self, video): f_e = self.spatial(frames) # Put frame embeddings back into correct temporal sequences f_e = f_e.view(*video.shape[:2], -1) + print("FRAME EMBEDDINGS SHAPE") + print(f_e.shape) + ''' # class embeddings and positional embeddings f_e = torch.cat( [self.video_class_embedding.to(f_e.dtype) + torch.zeros(f_e.shape[0], 1, f_e.shape[-1], dtype=f_e.dtype, device=f_e.device), @@ -136,6 +147,8 @@ def forward(self, video): # print("POOLED") # print(pooled[:10, :10]) # print(torch.mean(torch.var(pooled, dim=0))) + ''' + pooled, tokens = self._global_pool(f_e) return pooled diff --git a/src/training/video_data.py b/src/training/video_data.py index eaf17788a..b115edf11 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -140,9 +140,12 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni wds.split_by_node, wds.split_by_worker, tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), ]) - pipeline.extend([ wds.select(filter_no_caption_or_no_video), wds.decode(wds.torch_video, handler=log_and_continue), @@ -157,10 +160,9 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni dataloader = wds.WebLoader( dataset, batch_size=None, - num_workers=args.workers, + shuffle=False, + num_workers=1,# args.workers, persistent_workers=True, - prefetch_factor=8, - pin_memory=True, ) round_fn = math.floor From 9413c5e9737f7106425b931bc677a1187fbe4c13 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sun, 26 Mar 
2023 09:16:59 +0000 Subject: [PATCH 21/26] update dataloader to use video2dataset --- src/training/train.py | 3 ++- src/training/video_data.py | 52 +++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 3dbf2f643..0d3834d78 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -89,7 +89,8 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist scheduler(step) # TODO: generalize train loop to modality1, modality2 instead of image,text maybe - images, texts = batch + # images, texts = batch + images, texts = batch["mp4"], batch["txt"] images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) diff --git a/src/training/video_data.py b/src/training/video_data.py index b115edf11..4da03ed8b 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -15,6 +15,8 @@ from webdataset.filters import _shuffle from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample +from video2dataset.dataloader import get_video_dataset + class SharedEpoch: def __init__(self, epoch: int = 0): @@ -180,11 +182,59 @@ def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokeni return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) +def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None): + num_samples = args.train_num_samples + shared_epoch = SharedEpoch(epoch=epoch) + + decoder_kwargs = { # TODO: update with params + "n_frames": 32, + "fps": 10, + "num_threads": 12, + } + + custom_transforms = { + "mp4": lambda x: x.permute(0, 3, 1, 2), + "txt": lambda text: tokenizer(text)[0], + } + + dataset = get_video_dataset( + urls=args.train_data, + batch_size=args.batch_size, + decoder_kwargs=decoder_kwargs, + custom_transforms=custom_transforms, + resize_size=224, + crop_size=224, + ) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=1,# args.workers, + persistent_workers=True, + ) + + round_fn = math.floor + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + def get_video_data(args, preprocess_fns, epoch=0, tokenizer=None): preprocess_train, preprocess_val = preprocess_fns data = {} if args.train_data: - data["train"] = get_wds_dataset(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) + # data["train"] = get_wds_dataset(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) + data["train"] = get_wds_dataset2(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) return data From 5c65a52dad3537c86a5456c2521ae2f4fa233ab6 Mon Sep 17 00:00:00 2001 From: iejMac Date: Mon, 27 Mar 2023 07:27:44 +0000 Subject: [PATCH 22/26] update --- src/open_clip/transformer.py | 4 ---- src/open_clip/video_model.py | 25 ++++++++----------------- src/training/video_data.py | 6 ++++-- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py 
index 2205cb93a..1ca67b96c 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -495,10 +495,6 @@ def forward(self, x: torch.Tensor): if self.proj is not None: pooled = pooled @ self.proj - # print("POOOOLED") - # print(pooled[:10, :10]) - # print(torch.mean(torch.var(pooled, dim=0))) - if self.output_tokens: return pooled, tokens diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py index 818cd295b..75f7de5e3 100644 --- a/src/open_clip/video_model.py +++ b/src/open_clip/video_model.py @@ -68,7 +68,6 @@ def __init__( ) self.context_length = temporal_cfg.context_length - ''' # class embeddings and positional embeddings scale = temporal_cfg.width ** -0.5 self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width)) @@ -77,7 +76,6 @@ def __init__( self.ln_pre = norm_layer(temporal_cfg.width) self.ln_post = norm_layer(temporal_cfg.width) self.proj = nn.Parameter(scale * torch.randn(temporal_cfg.width, embed_dim)) - ''' self.spatial = _build_vision_tower( embed_dim=embed_dim, @@ -85,7 +83,6 @@ def __init__( quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) - ''' self.temporal = Transformer( width=temporal_cfg.width, layers=temporal_cfg.layers, @@ -94,15 +91,13 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, ) - ''' self.global_average_pool = global_average_pool - self.global_average_pool = True @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.spatial.set_grad_checkpointing(enable) - # self.temporal.grad_checkpointing = enable + self.temporal.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.global_average_pool: @@ -112,9 +107,7 @@ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: add patch dropout as suggested by lucidrains def forward(self, video): - print("VIDEO SHAPE") - print(video.shape) - # video = video[:, 1:] # make space for temporal CLS token + video = video[:, 1:] # make space for temporal CLS token batch_size = video.shape[0] # Flatten all frames in batch across time and encode with ViT @@ -122,10 +115,7 @@ def forward(self, video): f_e = self.spatial(frames) # Put frame embeddings back into correct temporal sequences f_e = f_e.view(*video.shape[:2], -1) - print("FRAME EMBEDDINGS SHAPE") - print(f_e.shape) - ''' # class embeddings and positional embeddings f_e = torch.cat( [self.video_class_embedding.to(f_e.dtype) + torch.zeros(f_e.shape[0], 1, f_e.shape[-1], dtype=f_e.dtype, device=f_e.device), @@ -140,21 +130,22 @@ def forward(self, video): v_e = self.temporal(f_e) v_e = v_e.permute(1, 0, 2) + self.global_average_pool = True pooled, tokens = self._global_pool(v_e) pooled = self.ln_post(pooled) pooled = pooled @ self.proj - # print("POOLED") - # print(pooled[:10, :10]) - # print(torch.mean(torch.var(pooled, dim=0))) ''' - pooled, tokens = self._global_pool(f_e) + print("POOLED") + print(pooled[:10, :10]) + print(pooled.shape) + print(torch.mean(torch.var(pooled, dim=0))) + ''' return pooled # TODO: turn into VideoCoCa -# TODO: set_grad_checkpointing class VideoCLIP(nn.Module): def __init__( self, diff --git a/src/training/video_data.py b/src/training/video_data.py index 4da03ed8b..392196ca1 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -187,8 +187,8 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token shared_epoch = SharedEpoch(epoch=epoch) decoder_kwargs = { # TODO: update with params - "n_frames": 32, - "fps": 10, + 
"n_frames": 8, + "fps": 1, "num_threads": 12, } @@ -200,6 +200,8 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token dataset = get_video_dataset( urls=args.train_data, batch_size=args.batch_size, + shuffle=True, + repeat=True, decoder_kwargs=decoder_kwargs, custom_transforms=custom_transforms, resize_size=224, From f2fa5bd32ed77d242afa96e656301cb65ecd31c2 Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 28 Mar 2023 00:46:31 +0000 Subject: [PATCH 23/26] repeat is bad --- src/training/video_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/training/video_data.py b/src/training/video_data.py index 392196ca1..90f071f13 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -201,7 +201,6 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token urls=args.train_data, batch_size=args.batch_size, shuffle=True, - repeat=True, decoder_kwargs=decoder_kwargs, custom_transforms=custom_transforms, resize_size=224, From 31251717bf45fb44ae1c32aab3e85b385b4ee6ac Mon Sep 17 00:00:00 2001 From: iejMac Date: Tue, 4 Apr 2023 01:23:34 +0000 Subject: [PATCH 24/26] enable loading CLIP weights to spatial and text encoders --- src/open_clip/factory.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 3f902cc8e..6491cf8e7 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -102,7 +102,18 @@ def load_checkpoint(model, checkpoint_path, strict=True): if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) resize_pos_embed(state_dict, model) - incompatible_keys = model.load_state_dict(state_dict, strict=strict) + + incompatible_keys = [] + # TODO: find better way of doing this + if isinstance(model, VideoCLIP): + text_state_dict = dict([(k[len("text."):], v) for (k, v) in state_dict.items() if k.startswith("text")]) + visual_state_dict = dict([(k[len("visual."):], v) for (k, v) in state_dict.items() if k.startswith("visual")]) + + incompatible_keys += model.text.load_state_dict(text_state_dict, strict=strict) + incompatible_keys += model.visual.spatial.load_state_dict(visual_state_dict, strict=strict) + else: + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys @@ -201,6 +212,12 @@ def create_model( pretrained_loaded = False if pretrained: checkpoint_path = '' + + # TODO: not sure how to initialize components nicely + # idea for now: model_name:pretrained + if ":" in pretrained: + model_name, pretrained = pretrained.split(":") + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) if pretrained_cfg: checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) From b093a4c4083a1e23508371c907a8ca390f2d38d2 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 15 Apr 2023 13:53:40 +0000 Subject: [PATCH 25/26] update --- src/training/video_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/training/video_data.py b/src/training/video_data.py index 90f071f13..3527f1bd3 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -200,18 +200,19 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token dataset = get_video_dataset( urls=args.train_data, batch_size=args.batch_size, - shuffle=True, + shuffle=1, decoder_kwargs=decoder_kwargs, custom_transforms=custom_transforms, resize_size=224, crop_size=224, 
+ keys_to_remove=["m4a"], ) dataloader = wds.WebLoader( dataset, batch_size=None, shuffle=False, - num_workers=1,# args.workers, + num_workers=args.workers, persistent_workers=True, ) From 9db7425a447dd650fe1fa623c1a0507562eb1544 Mon Sep 17 00:00:00 2001 From: iejmac Date: Fri, 9 Jun 2023 01:28:20 +0000 Subject: [PATCH 26/26] update --- .../model_configs/ViViT-B-32_short.json | 24 +++++++++++++ .../model_configs/ViViT-L-14_short.json | 24 +++++++++++++ src/test.sh | 24 +++++++++++++ src/training/train.py | 34 ++++++++++++++++++- src/training/video_data.py | 16 ++++++++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 src/open_clip/model_configs/ViViT-B-32_short.json create mode 100644 src/open_clip/model_configs/ViViT-L-14_short.json create mode 100755 src/test.sh diff --git a/src/open_clip/model_configs/ViViT-B-32_short.json b/src/open_clip/model_configs/ViViT-B-32_short.json new file mode 100644 index 000000000..b84041b35 --- /dev/null +++ b/src/open_clip/model_configs/ViViT-B-32_short.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "temporal_cfg": { + "context_length": 8, + "width": 512, + "heads": 8, + "layers": 12, + "mlp_ratio": 4, + "pooler_type": "cls_pooler" + } +} diff --git a/src/open_clip/model_configs/ViViT-L-14_short.json b/src/open_clip/model_configs/ViViT-L-14_short.json new file mode 100644 index 000000000..d7a8ccccb --- /dev/null +++ b/src/open_clip/model_configs/ViViT-L-14_short.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "temporal_cfg": { + "context_length": 8, + "width": 768, + "heads": 12, + "layers": 12, + "mlp_ratio": 4, + "pooler_type": "cls_pooler" + } +} diff --git a/src/test.sh b/src/test.sh new file mode 100755 index 000000000..7fc6225f1 --- /dev/null +++ b/src/test.sh @@ -0,0 +1,24 @@ +#!/bin/bash + + +# --pretrained "ViT-L-14:laion2b_s32b_b82k" \ + +python3 -m training.main \ + --train-data "s3://stability-west/webvid-10M/{00000..10727}.tar" \ + --train-num-samples 10727000 \ + --dataset-type webdataset \ + --batch-size=128 \ + --precision amp_bfloat16 \ + --epochs=9 \ + --warmup=100 \ + --lr-scheduler "const" \ + --lr 3e-4 \ + --workers=12 \ + --model "ViViT-B-32_short" \ + --pretrained "ViT-B-32:laion2b_s34b_b79k" \ + --ddp-static-graph \ + --local-loss \ + --log-every-n-steps 1 \ + --gather-with-grad \ + --grad-checkpointing \ + --report-to wandb diff --git a/src/training/train.py b/src/training/train.py index 0d3834d78..e413a0659 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -20,6 +20,10 @@ from .precision import get_autocast +OPENAI_DATASET_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]) +OPENAI_DATASET_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]) + + class AverageMeter(object): """Computes and stores the average and current value""" @@ -90,7 +94,19 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # TODO: generalize train loop to modality1, modality2 instead of image,text maybe # images, texts = batch - images, texts = batch["mp4"], batch["txt"] + # images, texts = batch["mp4"], batch["txt"] + images, texts = batch["video"], batch["txt"] + # 
texts = batch['txt'] + # images = torch.zeros((32, 8, 3, 224, 224)) + print(images.shape) + + original_means = images.mean(dim=(0, 1, 3, 4), keepdim=True) + original_stds = images.std(dim=(0, 1, 3, 4), keepdim=True) + images = (images - original_means) / original_stds + + images = images * OPENAI_DATASET_STD[..., None, None] + OPENAI_DATASET_MEAN[..., None, None] + + texts = data['tokenizer'](texts) images = images.to(device=device, dtype=cast_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) @@ -229,6 +245,22 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist # resetting batch / data time meters per log window batch_time_m.reset() data_time_m.reset() + + # Saving checkpoints every 1M steps + if is_master(args) and (i_accum % 1000000 == 0 or batch_count == num_batches_per_epoch): + checkpoint_dict = { + "epoch": epoch, + "name": args.name, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) # end for diff --git a/src/training/video_data.py b/src/training/video_data.py index 3527f1bd3..c4eb29221 100644 --- a/src/training/video_data.py +++ b/src/training/video_data.py @@ -16,6 +16,8 @@ from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample from video2dataset.dataloader import get_video_dataset +from omegaconf import OmegaConf +from sdata import create_dataset, create_loader class SharedEpoch: @@ -186,6 +188,7 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token num_samples = args.train_num_samples shared_epoch = SharedEpoch(epoch=epoch) + ''' decoder_kwargs = { # TODO: update with params "n_frames": 8, "fps": 1, @@ -206,6 +209,7 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token resize_size=224, crop_size=224, keys_to_remove=["m4a"], + handler=wds.warn_and_continue, ) dataloader = wds.WebLoader( @@ -215,6 +219,16 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token num_workers=args.workers, persistent_workers=True, ) + ''' + + # config = OmegaConf.load("/admin/home-iejmac/stable-datasets/examples/configs/debug.yaml") + config = OmegaConf.load("/admin/home-iejmac/stable-datasets/examples/configs/video_test.yaml") + + # build config + datapipeline = create_dataset(**config.dataset) + + # build loader + dataloader = create_loader(datapipeline, **config.loader) round_fn = math.floor global_batch_size = args.batch_size * args.world_size @@ -227,7 +241,6 @@ def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, token dataloader.num_batches = num_batches dataloader.num_samples = num_samples - return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) @@ -238,5 +251,6 @@ def get_video_data(args, preprocess_fns, epoch=0, tokenizer=None): if args.train_data: # data["train"] = get_wds_dataset(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) data["train"] = get_wds_dataset2(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) + data["tokenizer"] = tokenizer return data
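
Taken together, the patches above settle on ViViT's factorised-encoder layout: the CLIP ViT (model.visual.spatial) embeds every frame independently, a temporal transformer with its own CLS token and positional embeddings mixes the per-frame embeddings, and the pooled output is layer-normalised and projected into the shared video/text embedding space for the contrastive loss. The sketch below only illustrates that data flow and is not the repository code: the class names (SpatialStub, FactorisedVideoEncoder) are illustrative, SpatialStub stands in for the real vision tower, nn.TransformerEncoder replaces open_clip's Transformer, and the two temporal layers are chosen purely to keep the example small (the ViViT-B-32_short config above uses width 512, 8 temporal positions, 8 heads, and 12 layers).

import torch
import torch.nn as nn


class SpatialStub(nn.Module):
    """Stand-in for the per-frame ViT tower; returns one embedding per frame."""

    def __init__(self, width: int):
        super().__init__()
        self.proj = nn.Linear(3, width)

    def forward(self, frames: torch.Tensor) -> torch.Tensor:   # (B*T, 3, H, W)
        return self.proj(frames.mean(dim=(2, 3)))              # (B*T, width)


class FactorisedVideoEncoder(nn.Module):
    """Spatial encoder per frame -> temporal transformer -> pooled, projected embedding."""

    def __init__(self, width: int = 512, context_length: int = 8,
                 heads: int = 8, layers: int = 2, embed_dim: int = 512):
        super().__init__()
        scale = width ** -0.5
        self.spatial = SpatialStub(width)
        # temporal CLS token and positional embeddings; context_length counts the CLS slot
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(context_length, width))
        self.ln_pre = nn.LayerNorm(width)
        layer = nn.TransformerEncoderLayer(d_model=width, nhead=heads, batch_first=True)
        self.temporal = nn.TransformerEncoder(layer, num_layers=layers)
        self.ln_post = nn.LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, embed_dim))

    def forward(self, video: torch.Tensor) -> torch.Tensor:     # (B, T, 3, H, W)
        b, t = video.shape[:2]
        f_e = self.spatial(video.flatten(0, 1)).view(b, t, -1)  # frame embeddings (B, T, width)
        cls = self.class_embedding.expand(b, 1, -1).to(f_e.dtype)
        f_e = torch.cat([cls, f_e], dim=1) + self.positional_embedding[: t + 1]
        v_e = self.temporal(self.ln_pre(f_e))                   # mix frame embeddings over time
        pooled = self.ln_post(v_e[:, 0])                         # CLS-token ("cls_pooler") summary
        return pooled @ self.proj                                # (B, embed_dim), ready for the CLIP loss


if __name__ == "__main__":
    clips = torch.randn(2, 7, 3, 224, 224)         # 7 frames; one slot is reserved for the CLS token
    print(FactorisedVideoEncoder()(clips).shape)   # torch.Size([2, 512])

As in the patches, the temporal context length counts the CLS slot, which is why the model's forward pass drops one frame (video[:, 1:]) before prepending the class embedding; the sketch instead simply expects T + 1 <= context_length from the caller.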