diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 14011f934..6491cf8e7 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -8,15 +8,17 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
     resize_pos_embed, get_cast_dtype
 from .coca_model import CoCa
+from .video_model import VideoCLIP  # TODO: change once the full model is implemented
 from .loss import ClipLoss, DistillClipLoss, CoCaLoss
 from .openai import load_openai_model
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
-from .transform import image_transform, AugmentationCfg
+from .transform import image_transform, video_transform, AugmentationCfg
 from .tokenizer import HFTokenizer, tokenize
 
 
@@ -100,7 +102,18 @@ def load_checkpoint(model, checkpoint_path, strict=True):
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
     resize_pos_embed(state_dict, model)
-    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+
+    incompatible_keys = []
+    # TODO: find a better way of doing this
+    if isinstance(model, VideoCLIP):
+        text_state_dict = {k[len("text."):]: v for k, v in state_dict.items() if k.startswith("text")}
+        visual_state_dict = {k[len("visual."):]: v for k, v in state_dict.items() if k.startswith("visual")}
+
+        incompatible_keys += model.text.load_state_dict(text_state_dict, strict=strict)
+        incompatible_keys += model.visual.spatial.load_state_dict(visual_state_dict, strict=strict)
+    else:
+        incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+
     return incompatible_keys
 
 
@@ -191,11 +204,20 @@ def create_model(
             else:
                 model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
         else:
-            model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+            if "ViViT" in model_name:  # TODO: better way of detecting video configs
+                model = VideoCLIP(**model_cfg)
+            else:
+                model = CLIP(**model_cfg, cast_dtype=cast_dtype)
 
         pretrained_loaded = False
         if pretrained:
             checkpoint_path = ''
+
+            # TODO: not sure how to initialize components nicely
+            # idea for now: pass "model_name:tag" (e.g. "ViT-B-32:laion2b_s34b_b79k")
+            # to initialize the spatial/text towers from an existing CLIP checkpoint
+            if ":" in pretrained:
+                model_name, pretrained = pretrained.split(":")
+
             pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
             if pretrained_cfg:
                 checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
@@ -305,21 +327,40 @@ def create_model_and_transforms(
         output_dict=output_dict,
     )
 
-    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
-    image_std = image_std or getattr(model.visual, 'image_std', None)
-    preprocess_train = image_transform(
-        model.visual.image_size,
-        is_train=True,
-        mean=image_mean,
-        std=image_std,
-        aug_cfg=aug_cfg,
-    )
-    preprocess_val = image_transform(
-        model.visual.image_size,
-        is_train=False,
-        mean=image_mean,
-        std=image_std,
-    )
+    # TODO: better way of getting modality-specific transforms
+    if "ViViT" in model_name:
+        preprocess_train = video_transform(
+            frame_size=model.visual.spatial.image_size,
+            n_frames=model.visual.context_length,
+            take_every_nth=5,
+            is_train=False,  # TODO: figure out if frame augmentations make sense
+            frame_mean=None,
+            frame_std=None,
+        )
+        preprocess_val = video_transform(
+            frame_size=model.visual.spatial.image_size,
+            n_frames=model.visual.context_length,
+            take_every_nth=5,
+            is_train=False,
+            frame_mean=None,
+            frame_std=None,
+        )
+    else:
+        image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+        image_std = image_std or getattr(model.visual, 'image_std', None)
+        preprocess_train = image_transform(
+            model.visual.image_size,
+            is_train=True,
+            mean=image_mean,
+            std=image_std,
+            aug_cfg=aug_cfg,
+        )
+        preprocess_val = image_transform(
+            model.visual.image_size,
+            is_train=False,
+            mean=image_mean,
+            std=image_std,
+        )
 
     return model, preprocess_train, preprocess_val
diff --git a/src/open_clip/model_configs/ViViT-B-32.json b/src/open_clip/model_configs/ViViT-B-32.json
new file mode 100644
index 000000000..06dd83e1e
--- /dev/null
+++ b/src/open_clip/model_configs/ViViT-B-32.json
@@ -0,0 +1,24 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    },
+    "temporal_cfg": {
+        "context_length": 32,
+        "width": 512,
+        "heads": 8,
+        "layers": 12,
+        "mlp_ratio": 4,
+        "pooler_type": "cls_pooler"
+    }
+}
diff --git a/src/open_clip/model_configs/ViViT-B-32_short.json b/src/open_clip/model_configs/ViViT-B-32_short.json
new file mode 100644
index 000000000..b84041b35
--- /dev/null
+++ b/src/open_clip/model_configs/ViViT-B-32_short.json
@@ -0,0 +1,24 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 12,
+        "width": 768,
+        "patch_size": 32
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 512,
+        "heads": 8,
+        "layers": 12
+    },
+    "temporal_cfg": {
+        "context_length": 8,
+        "width": 512,
+        "heads": 8,
+        "layers": 12,
+        "mlp_ratio": 4,
+        "pooler_type": "cls_pooler"
+    }
+}
diff --git a/src/open_clip/model_configs/ViViT-L-14_short.json b/src/open_clip/model_configs/ViViT-L-14_short.json
new file mode 100644
index 000000000..d7a8ccccb
--- /dev/null
+++ b/src/open_clip/model_configs/ViViT-L-14_short.json
@@ -0,0 +1,24 @@
+{
+    "embed_dim": 768,
+    "vision_cfg": {
+        "image_size": 224,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    },
+    "temporal_cfg": {
+        "context_length": 8,
+        "width": 768,
+        "heads": 12,
+        "layers": 12,
+        "mlp_ratio": 4,
+        "pooler_type": "cls_pooler"
+    }
+}
diff --git a/src/open_clip/transform.py b/src/open_clip/transform.py
index 748884a3c..51e72492a 100644
--- a/src/open_clip/transform.py
+++ b/src/open_clip/transform.py
@@ -7,7 +7,7 @@ import torchvision.transforms.functional as F
 from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
-    CenterCrop
+    CenterCrop, ToPILImage
 
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 
@@ -131,3 +131,70 @@ def image_transform(
         normalize,
     ])
     return Compose(transforms)
+
+
+# TODO: needs improvement
+def video_transform(
+        frame_size: int,
+        n_frames: int,
+        take_every_nth: int,
+        is_train: bool,
+        frame_mean: Optional[Tuple[float, ...]] = None,
+        frame_std: Optional[Tuple[float, ...]] = None,
+):
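+    """Build a per-sample transform for decoded webdataset videos.
+
+    The returned callable expects (video, audio, meta), where video is a [T, H, W, C] uint8
+    frame tensor (as produced by wds.torch_video in training/video_data.py). It keeps every
+    take_every_nth frame, truncates/zero-pads to n_frames, applies CLIP-style per-frame
+    preprocessing, and returns a [n_frames, 3, frame_size, frame_size] tensor.
+    """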
+
+    frame_mean = frame_mean or OPENAI_DATASET_MEAN
+    if not isinstance(frame_mean, (list, tuple)):
+        frame_mean = (frame_mean,) * 3
+
+    frame_std = frame_std or OPENAI_DATASET_STD
+    if not isinstance(frame_std, (list, tuple)):
+        frame_std = (frame_std,) * 3
+
+    normalize = Normalize(mean=frame_mean, std=frame_std)
+
+    if is_train:
+        transforms = [
+            ToPILImage(),
+            RandomResizedCrop(
+                frame_size,
+                scale=(0.9, 1.0),
+                interpolation=InterpolationMode.BICUBIC,
+            ),
+            _convert_to_rgb,
+            ToTensor(),
+            normalize,
+        ]
+    else:
+        transforms = [
+            ToPILImage(),
+            Resize(frame_size, interpolation=InterpolationMode.BICUBIC),
+            CenterCrop(frame_size),
+            _convert_to_rgb,
+            ToTensor(),
+            normalize,
+        ]
+
+    frame_transform = Compose(transforms)
+
+    def apply_frame_transform(sample):
+        video, audio, video_meta = sample
+        video = video.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
+
+        video = video[::take_every_nth]
+        video = video[:n_frames]  # TODO: maybe take the middle n frames instead
+
+        # TODO: maybe padding isn't the way to go
+        # TODO: also F.pad is acting up for some reason,
+        # it isn't letting me input a len-8 tuple for a 4D tensor???
+        # video = F.pad(video, tuple([0, 0]*len(video.shape[-3:]) + [0, n_frames - video.shape[0]]))
+        if video.shape[0] < n_frames:
+            # keep the original dtype (uint8) so the padded frames go through ToPILImage correctly
+            padded_video = torch.zeros(n_frames, *video.shape[1:], dtype=video.dtype)
+            padded_video[:video.shape[0]] = video
+            video = padded_video
+
+        # frames arrive as uint8 [C, H, W]; ToPILImage handles that directly, so no .float() cast
+        return torch.cat([frame_transform(frame)[None, ...] for frame in video])
+
+    return apply_frame_transform
diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py
index 4e0151017..1ca67b96c 100644
--- a/src/open_clip/transformer.py
+++ b/src/open_clip/transformer.py
@@ -497,7 +497,7 @@ def forward(self, x: torch.Tensor):
 
         if self.output_tokens:
             return pooled, tokens
-        
+
         return pooled
diff --git a/src/open_clip/video_model.py b/src/open_clip/video_model.py
new file mode 100644
index 000000000..75f7de5e3
--- /dev/null
+++ b/src/open_clip/video_model.py
@@ -0,0 +1,207 @@
+from typing import Callable, Optional, Sequence, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+import numpy as np
+from dataclasses import dataclass
+
+from .transformer import (
+    LayerNormFp32,
+    LayerNorm,
+    QuickGELU,
+    Transformer,
+)
+from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
+
+
+@dataclass
+class TemporalCfg:
+    context_length: int = 32  # number of input frames
+    width: int = 512
+    heads: int = 8
+    layers: int = 12
+    mlp_ratio: int = 4
+    pooler_type: str = "cls_pooler"
+
+
+# TODO: the ViViT class makes this function a bit pointless,
+# still thinking about how to organize this better
+def _build_video_tower(
+        embed_dim,
+        vision_cfg,
+        temporal_cfg,
+        quick_gelu: bool = False,
+        cast_dtype: Optional[torch.dtype] = None,
+):
+    model = ViViT(
+        embed_dim,
+        vision_cfg,
+        temporal_cfg,
+        quick_gelu=quick_gelu,
+        cast_dtype=cast_dtype,
+    )
+
+    return model
+
+
+# TODO: finish the option for mean pooling (no cls token if global_average_pool == True)
+class ViViT(nn.Module):
+    """ViViT model (https://arxiv.org/abs/2103.15691), factorised encoder variant"""
+
+    def __init__(
+            self,
+            embed_dim,
+            vision_cfg,
+            temporal_cfg,
+            global_average_pool: bool = False,
+            quick_gelu: bool = False,
+            cast_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+
+        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
+        temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) else temporal_cfg
+
+        act_layer = QuickGELU if quick_gelu else nn.GELU
+        norm_layer = (
+            LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+        )
+        self.context_length = temporal_cfg.context_length
+
+        # class embeddings and positional embeddings
+        scale = temporal_cfg.width ** -0.5
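+        # (same initialisation scheme as the class/positional embeddings of VisionTransformer in transformer.py)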
+        self.video_class_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.width))
+        self.video_positional_embedding = nn.Parameter(scale * torch.randn(temporal_cfg.context_length, temporal_cfg.width))
+
+        self.ln_pre = norm_layer(temporal_cfg.width)
+        self.ln_post = norm_layer(temporal_cfg.width)
+        self.proj = nn.Parameter(scale * torch.randn(temporal_cfg.width, embed_dim))
+
+        self.spatial = _build_vision_tower(
+            embed_dim=embed_dim,
+            vision_cfg=vision_cfg,
+            quick_gelu=quick_gelu,
+            cast_dtype=cast_dtype,
+        )
+        self.temporal = Transformer(
+            width=temporal_cfg.width,
+            layers=temporal_cfg.layers,
+            heads=temporal_cfg.heads,
+            mlp_ratio=temporal_cfg.mlp_ratio,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+        )
+
+        self.global_average_pool = global_average_pool
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.spatial.set_grad_checkpointing(enable)
+        self.temporal.grad_checkpointing = enable
+
+    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.global_average_pool:
+            return x.mean(dim=1), x
+        else:
+            return x[:, 0], x[:, 1:]
+
+    # TODO: add patch dropout as suggested by lucidrains
+    def forward(self, video):
+        video = video[:, 1:]  # drop the first frame to make space for the temporal CLS token
+        batch_size = video.shape[0]
+
+        # Flatten all frames in the batch across time and encode with the ViT
+        frames = video.flatten(start_dim=0, end_dim=1)
+        f_e = self.spatial(frames)
+        # Put frame embeddings back into the correct temporal sequences
+        f_e = f_e.view(*video.shape[:2], -1)
+
+        # class embeddings and positional embeddings
+        f_e = torch.cat(
+            [self.video_class_embedding.to(f_e.dtype) + torch.zeros(f_e.shape[0], 1, f_e.shape[-1], dtype=f_e.dtype, device=f_e.device),
+             f_e], dim=1)  # shape = [b, cl, w]
+        f_e = f_e + self.video_positional_embedding.to(f_e.dtype)
+
+        # TODO: need to look at the paper again, section 3, equations (4, 5, 6) --
+        # do we need extra residual connections here?
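+        # (For what it's worth: in the paper's factorised encoder the spatial transformer runs
+        # per frame and the temporal transformer runs over the resulting frame embeddings, which
+        # is what happens here; the residual connections already live inside each Transformer
+        # block, so no extra skip connection should be needed -- worth re-checking eqs. (4)-(6).)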
+        f_e = self.ln_pre(f_e)
+
+        f_e = f_e.permute(1, 0, 2)  # NLD -> LND
+        v_e = self.temporal(f_e)
+        v_e = v_e.permute(1, 0, 2)  # LND -> NLD
+
+        self.global_average_pool = True  # FIXME: debug override; ignores the constructor flag and the "cls_pooler" config
+        pooled, tokens = self._global_pool(v_e)
+        pooled = self.ln_post(pooled)
+        pooled = pooled @ self.proj
+
+        '''
+        print("POOLED")
+        print(pooled[:10, :10])
+        print(pooled.shape)
+        print(torch.mean(torch.var(pooled, dim=0)))
+        '''
+
+        return pooled
+
+
+# TODO: turn into VideoCoCa
+class VideoCLIP(nn.Module):
+    def __init__(
+            self,
+            embed_dim,
+            vision_cfg: CLIPVisionCfg,
+            text_cfg: CLIPTextCfg,
+            temporal_cfg: TemporalCfg,
+            quick_gelu: bool = False,
+            cast_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
+        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
+        temporal_cfg = TemporalCfg(**temporal_cfg) if isinstance(temporal_cfg, dict) else temporal_cfg
+
+        self.visual = _build_video_tower(
+            embed_dim=embed_dim,
+            vision_cfg=vision_cfg,
+            temporal_cfg=temporal_cfg,
+            quick_gelu=quick_gelu,
+            cast_dtype=cast_dtype,
+        )
+
+        self.text = _build_text_tower(
+            embed_dim=embed_dim,
+            text_cfg=text_cfg,
+            quick_gelu=quick_gelu,
+            cast_dtype=cast_dtype,
+        )
+
+        vocab_size = (
+            text_cfg.vocab_size  # for hf models
+            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
+            else text_cfg.vocab_size
+        )
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.visual.set_grad_checkpointing(enable)
+        self.text.set_grad_checkpointing(enable)
+
+    def encode_video(self, video, normalize: bool = False):
+        features = self.visual(video)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def encode_text(self, text, normalize: bool = False):
+        features = self.text(text)
+        return F.normalize(features, dim=-1) if normalize else features
+
+    def forward(self, video, text):
+        video_features = self.encode_video(video, normalize=True)
+        text_features = self.encode_text(text, normalize=True)
+        # TODO: make the loss functions generalize to all types of modality pairs,
+        # i.e. make the keys more general; for now keeping "image_features"
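+        # (ClipLoss and friends currently read "image_features" / "text_features" / "logit_scale",
+        # which is why the video embedding is returned under "image_features" for now.)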
+        return {
+            "image_features": video_features,
+            "text_features": text_features,
+            "logit_scale": self.logit_scale.exp()
+        }
diff --git a/src/test.sh b/src/test.sh
new file mode 100755
index 000000000..7fc6225f1
--- /dev/null
+++ b/src/test.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+
+# --pretrained "ViT-L-14:laion2b_s32b_b82k" \
+
+python3 -m training.main \
+    --train-data "s3://stability-west/webvid-10M/{00000..10727}.tar" \
+    --train-num-samples 10727000 \
+    --dataset-type webdataset \
+    --batch-size=128 \
+    --precision amp_bfloat16 \
+    --epochs=9 \
+    --warmup=100 \
+    --lr-scheduler "const" \
+    --lr 3e-4 \
+    --workers=12 \
+    --model "ViViT-B-32_short" \
+    --pretrained "ViT-B-32:laion2b_s34b_b79k" \
+    --ddp-static-graph \
+    --local-loss \
+    --log-every-n-steps 1 \
+    --gather-with-grad \
+    --grad-checkpointing \
+    --report-to wandb
diff --git a/src/training/main.py b/src/training/main.py
index f70c9f953..e7a1e70c3 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -29,6 +29,7 @@ from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
 from training.data import get_data
+from training.video_data import get_video_data  # TODO: maybe we don't need separate files
 from training.distributed import is_master, init_distributed_device, broadcast_object
 from training.logger import setup_logging
 from training.params import parse_args
@@ -334,7 +335,12 @@ def main(args):
             logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
 
     # initialize datasets
-    data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
+    # TODO: come up with a way of getting alternative-modality data based on the model config
+    if "ViViT" in args.model:
+        data = get_video_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
+    else:
+        data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
+
     assert len(data), 'At least one train or eval dataset must be specified.'
 
     # create scheduler if train
diff --git a/src/training/train.py b/src/training/train.py
index 830b0bf59..e413a0659 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -20,6 +20,10 @@ from .precision import get_autocast
+OPENAI_DATASET_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
+OPENAI_DATASET_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
+
+
 class AverageMeter(object):
     """Computes and stores the average and current value"""
 
@@ -88,7 +92,22 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
         if not args.skip_scheduler:
             scheduler(step)
 
-        images, texts = batch
+        # TODO: generalize the train loop to modality1, modality2 instead of image, text
+        # images, texts = batch
+        # images, texts = batch["mp4"], batch["txt"]
+        images, texts = batch["video"], batch["txt"]
+        # texts = batch['txt']
+        # images = torch.zeros((32, 8, 3, 224, 224))
+        print(images.shape)  # FIXME: debug logging, remove before merging
+
+        # hack: standardise the batch per channel, then rescale to the OpenAI dataset pixel statistics
+        original_means = images.mean(dim=(0, 1, 3, 4), keepdim=True)
+        original_stds = images.std(dim=(0, 1, 3, 4), keepdim=True)
+        images = (images - original_means) / original_stds
+        images = images * OPENAI_DATASET_STD[..., None, None] + OPENAI_DATASET_MEAN[..., None, None]
+
+        texts = data['tokenizer'](texts)
+
         images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
         texts = texts.to(device=device, non_blocking=True)
 
@@ -107,7 +126,6 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
                 total_loss = sum(losses.values())
                 losses["loss"] = total_loss
-
             backward(total_loss, scaler)
         else:
             # First, cache the features without any gradient tracking.
@@ -227,6 +245,22 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist
             # resetting batch / data time meters per log window
             batch_time_m.reset()
             data_time_m.reset()
+
+        # Saving an "epoch_latest" checkpoint every 1M steps and at the end of each epoch
+        if is_master(args) and (i_accum % 1000000 == 0 or batch_count == num_batches_per_epoch):
+            checkpoint_dict = {
+                "epoch": epoch,
+                "name": args.name,
+                "state_dict": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+            }
+            if scaler is not None:
+                checkpoint_dict["scaler"] = scaler.state_dict()
+
+            torch.save(
+                checkpoint_dict,
+                os.path.join(args.checkpoint_path, "epoch_latest.pt"),
+            )
     # end for
diff --git a/src/training/video_data.py b/src/training/video_data.py
new file mode 100644
index 000000000..c4eb29221
--- /dev/null
+++ b/src/training/video_data.py
@@ -0,0 +1,256 @@
+"""video dataset creation"""
+import io
+import logging
+import math
+import random
+import torchvision
+import tempfile
+import webdataset as wds
+
+from dataclasses import dataclass
+from multiprocessing import Value
+from pathlib import Path
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.distributed import DistributedSampler  # used by DataInfo.set_epoch below
+from webdataset.filters import _shuffle
+from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+
+from training.data import pytorch_worker_seed  # reuse the seed helper from the image data pipeline
+
+from video2dataset.dataloader import get_video_dataset
+from omegaconf import OmegaConf
+from sdata import create_dataset, create_loader
+
+
+class SharedEpoch:
+    def __init__(self, epoch: int = 0):
+        self.shared_epoch = Value('i', epoch)
+
+    def set_value(self, epoch):
+        self.shared_epoch.value = epoch
+
+    def get_value(self):
+        return self.shared_epoch.value
+
+
+@dataclass
+class DataInfo:
+    dataloader: DataLoader
+    shared_epoch: SharedEpoch = None
+    sampler = None
+
+    def set_epoch(self, epoch):
+        if self.shared_epoch is not None:
+            self.shared_epoch.set_value(epoch)
+        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
+            self.sampler.set_epoch(epoch)
+
+
+def filter_no_caption_or_no_video(sample):
+    has_caption = ('txt' in sample)
+    has_video = ('mp4' in sample)
+    return has_caption and has_video
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+    """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+    """
+    current_sample = None
+    for filesample in data:
+        assert isinstance(filesample, dict)
+        fname, value = filesample["fname"], filesample["data"]
+        prefix, suffix = keys(fname)
+        if prefix is None:
+            continue
+        if lcase:
+            suffix = suffix.lower()
+        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+        # this happening in the current LAION400m dataset if a tar ends with the same prefix as the next
+        # begins; rare, but it can happen since prefixes aren't unique across tar files in that dataset
+        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+            if valid_sample(current_sample):
+                yield current_sample
+            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+        if suffixes is None or suffix in suffixes:
+            current_sample[suffix] = value
+    if valid_sample(current_sample):
+        yield current_sample
+
+
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
+
+
+def tarfile_to_samples_nothrow(src, handler=log_and_continue):
+    # NOTE: this is a re-implementation of the webdataset version with a group_by_keys that doesn't throw
+    streams = url_opener(src, handler=handler)
+    files = tar_file_expander(streams, handler=handler)
+    samples = group_by_keys_nothrow(files, handler=handler)
+    return samples
+
+
+class detshuffle2(wds.PipelineStage):
+    def __init__(
+            self,
+            bufsize=1000,
+            initial=100,
+            seed=0,
+            epoch=-1,
+    ):
+        self.bufsize = bufsize
+        self.initial = initial
+        self.seed = seed
+        self.epoch = epoch
+
+    def run(self, src):
+        if isinstance(self.epoch, SharedEpoch):
+            epoch = self.epoch.get_value()
+        else:
+            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+            # situation as different workers may wrap at different times (or not at all).
+            self.epoch += 1
+            epoch = self.epoch
+        rng = random.Random()
+        if self.seed < 0:
+            # If the seed is negative, use the worker's seed; this will differ across all nodes/workers
+            seed = pytorch_worker_seed(epoch)
+        else:
+            # This seed is deterministic AND the same across all nodes/workers in each epoch
+            seed = self.seed + epoch
+        rng.seed(seed)
+        return _shuffle(src, self.bufsize, self.initial, rng)
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+def get_wds_dataset(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None):
+    num_samples = args.train_num_samples
+    shared_epoch = SharedEpoch(epoch=epoch)
+
+    pipeline = [wds.SimpleShardList(args.train_data)]
+    is_train = True
+
+    pipeline.extend([
+        detshuffle2(
+            bufsize=_SHARD_SHUFFLE_SIZE,
+            initial=_SHARD_SHUFFLE_INITIAL,
+            seed=args.seed,
+            epoch=shared_epoch,
+        ),
+        wds.split_by_node,
+        wds.split_by_worker,
+        tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
+        wds.shuffle(
+            bufsize=_SAMPLE_SHUFFLE_SIZE,
+            initial=_SAMPLE_SHUFFLE_INITIAL,
+        ),
+    ])
+
+    pipeline.extend([
+        wds.select(filter_no_caption_or_no_video),
+        wds.decode(wds.torch_video, handler=log_and_continue),
+        wds.rename(video="mp4", text="txt"),
+        wds.map_dict(video=preprocess_vid, text=lambda text: tokenizer(text)[0]),
+        wds.to_tuple("video", "text"),
+        wds.batched(args.batch_size, partial=not is_train)
+    ])
+
+    dataset = wds.DataPipeline(*pipeline)
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=None,
+        shuffle=False,
+        num_workers=1,  # args.workers,
+        persistent_workers=True,
+    )
+
+    round_fn = math.floor
+    global_batch_size = args.batch_size * args.world_size
+    num_batches = round_fn(num_samples / global_batch_size)
+    num_workers = max(1, args.workers)
+    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
+    num_batches = num_worker_batches * num_workers
+    num_samples = num_batches * global_batch_size
+
+    dataloader.num_batches = num_batches
+    dataloader.num_samples = num_samples
+
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_wds_dataset2(args, preprocess_vid, is_train, epoch=0, floor=False, tokenizer=None):
+    num_samples = args.train_num_samples
+    shared_epoch = SharedEpoch(epoch=epoch)
+
+    '''
+    decoder_kwargs = {  # TODO: update with params
+        "n_frames": 8,
+        "fps": 1,
+        "num_threads": 12,
+    }
+
+    custom_transforms = {
+        "mp4": lambda x: x.permute(0, 3, 1, 2),
+        "txt": lambda text: tokenizer(text)[0],
+    }
+
+    dataset = get_video_dataset(
+        urls=args.train_data,
+        batch_size=args.batch_size,
+        shuffle=1,
+        decoder_kwargs=decoder_kwargs,
+        custom_transforms=custom_transforms,
+        resize_size=224,
+        crop_size=224,
+        keys_to_remove=["m4a"],
+        handler=wds.warn_and_continue,
+    )
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=None,
+        shuffle=False,
+        num_workers=args.workers,
+        persistent_workers=True,
+    )
+    '''
+
+    # config = OmegaConf.load("/admin/home-iejmac/stable-datasets/examples/configs/debug.yaml")
+    config = OmegaConf.load("/admin/home-iejmac/stable-datasets/examples/configs/video_test.yaml")
+
+    # build config
+    datapipeline = create_dataset(**config.dataset)
+
+    # build loader
+    dataloader = create_loader(datapipeline, **config.loader)
+
+    round_fn = math.floor
+    global_batch_size = args.batch_size * args.world_size
+    num_batches = round_fn(num_samples / global_batch_size)
+    num_workers = max(1, args.workers)
+    num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
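+    # round down so the advertised num_batches is an exact multiple of the dataloader workers (mirrors training/data.py)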
+    num_batches = num_worker_batches * num_workers
+    num_samples = num_batches * global_batch_size
+
+    dataloader.num_batches = num_batches
+    dataloader.num_samples = num_samples
+
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_video_data(args, preprocess_fns, epoch=0, tokenizer=None):
+    preprocess_train, preprocess_val = preprocess_fns
+    data = {}
+
+    if args.train_data:
+        # data["train"] = get_wds_dataset(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
+        data["train"] = get_wds_dataset2(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
+    data["tokenizer"] = tokenizer
+
+    return data
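+
+
+# Rough usage sketch (not executed anywhere; `args` is the argparse namespace from training.params
+# and must carry train_data, train_num_samples, batch_size, world_size, workers and seed):
+#
+#   data = get_video_data(args, (preprocess_train, preprocess_val), epoch=0, tokenizer=tokenizer)
+#   for batch in data["train"].dataloader:
+#       videos, texts = batch["video"], batch["txt"]  # keys as consumed in training/train.py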