From 5c6d01a3874cfdfd1f1d160b038d8a79b7c545c5 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Thu, 1 Jun 2023 20:01:01 +0300 Subject: [PATCH 01/37] very dirty draft --- lightly/models/modules/jepa.py | 203 +++++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 lightly/models/modules/jepa.py diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py new file mode 100644 index 000000000..0e1373d4d --- /dev/null +++ b/lightly/models/modules/jepa.py @@ -0,0 +1,203 @@ +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +from einops import rearrange +import copy +from typing import Callable, List, Optional, Tuple, Union +from .format import Format, nchw_to +from .helpers import to_2tuple +from .trace_utils import _assert + + +def to_2tuple(x): + + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + output_fmt: Format + + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + strict_img_size: bool = True, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.img_size = to_2tuple(img_size) + self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + + self.flatten = flatten + self.output_fmt = Format.NCHW + self.strict_img_size = strict_img_size + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + if self.img_size is not None: + if self.strict_img_size: + _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") + else: + _assert( + H % self.patch_size[0] == 0, + f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + ) + _assert( + W % self.patch_size[1] == 0, + f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." 
+ ) + + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) + x = self.norm(x) + return x + + +class Predictor(nn.Module): + def __init__(self, embed_dim, num_heads, depth): + super().__init__() + + self.predictor = Decoder(dim = embed_dim, depth = depth, heads = num_heads) + def forward(self, context_encoding, target_masks): + x = torch.cat((context_encoding, target_masks), dim = 1) + x = self.predictor(x) + #return last len(target_masks) tokens + l = x.shape[1] + return x[:, l - target_masks.shape[1]:, :] + + +class IJEPA_base(nn.Module): + def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_depth, num_heads, post_emb_norm=False, M = 4, mode="train", layer_dropout=0.): + super().__init__() + self.M = M + self.mode = mode + self.layer_dropout = layer_dropout + + #define the patch embedding and positional embedding + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.patch_dim = (self.patch_embed.patch_shape[0], self.patch_embed.patch_shape[1]) + self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1] + self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim)) + + #define the cls and mask tokens + self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim)) + nn.init.trunc_normal_(self.mask_token, 0.02) + + #define the encoder and decoder, as well as the layer normalization and dropout + self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity() + self.norm = nn.LayerNorm(embed_dim) + self.teacher_encoder = Encoder( + dim=embed_dim, + heads=num_heads, + depth=enc_depth, + layer_dropout=self.layer_dropout, + ) + self.student_encoder = copy.deepcopy(self.teacher_encoder).cuda() + self.predictor = Predictor(embed_dim, num_heads, pred_depth) + + @torch.no_grad() + def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M): + #get the target block + target_encoder = target_encoder.eval() + x = target_encoder(x) + x = self.norm(x) + #get the patch dimensions + patch_h, patch_w = patch_dim + #get the number of patches + num_patches = patch_h * patch_w + #get the number of patches in the target block + num_patches_block = int(patch_h * patch_w * scale) + #get the height and width of the target block with aspect ratio + block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio))) + block_w = int(aspect_ratio * block_h) + #get the patches in the target block + target_block = torch.zeros((M, x.shape[0], block_h*block_w, x.shape[2])) + target_patches = [] + all_patches = [] + for z in range(M): + #get the starting patch + start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item() + start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item() + start_patch = start_patch_h * patch_w + start_patch_w + + patches = [] + #get the patches in the target block + for i in range(block_h): + for j in range(block_w): + patches.append(start_patch + i * patch_w + j) + if start_patch + i * patch_w + j not in all_patches: + all_patches.append(start_patch + i * patch_w + j) + + #get the target block + target_patches.append(patches) + target_block[z] = x[:, patches, :] + return target_block.cuda(), target_patches, all_patches + + def get_context_block(self, x, patch_dim, aspect_ratio, scale, target_patches): + patch_h, patch_w = patch_dim + #get the number of patches in the target block + num_patches_block 
= int(patch_h * patch_w * scale) + #get the height and width of the target block with aspect ratio + block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio))) + block_w = int(aspect_ratio * block_h) + #get the starting patch + start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item() + start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item() + start_patch = start_patch_h * patch_w + start_patch_w + #get the patches in the context_block + patches = [] + for i in range(block_h): + for j in range(block_w): + if start_patch + i * patch_w + j not in target_patches: #remove the target patches + patches.append(start_patch + i * patch_w + j) + return x[:, patches, :] + + + def forward(self, x, target_aspect_ratio=1, target_scale=1, context_aspect_ratio=1, context_scale=1): + #get the patch embeddings + x = self.patch_embed(x) + b, n, e = x.shape + x = x + self.pos_embedding[:, :n] + #add the positional embeddings + x = x + self.pos_embedding + #normalize the embeddings + x = self.post_emb_norm(x) + #if mode is test, we get return full embedding: + if self.mode == 'test': + return self.student_encoder(x) + # #get target embeddings + target_blocks, target_patches, all_patches = self.get_target_block(self.teacher_encoder, x, self.patch_dim, target_aspect_ratio, target_scale, self.M) + m, b, n, e = target_blocks.shape + #get context embedding + + context_block = self.get_context_block(x, self.patch_dim, context_aspect_ratio, context_scale, all_patches) + context_encoding = self.student_encoder(context_block) + context_encoding = self.norm(context_encoding) + + + prediction_blocks = torch.zeros((m, b, n, e)).cuda() + #get the prediction blocks, predict each target block separately + for i in range(m): + target_masks = self.mask_token.repeat(b, n, 1) + target_pos_embedding = self.pos_embedding[:, target_patches[i], :] + target_masks = target_masks + target_pos_embedding + prediction_blocks[i] = self.predictor(context_encoding, target_masks) + + return prediction_blocks, target_blocks From a1c0c158d58661049ed79327f2ba5dbb687a8dc8 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Thu, 1 Jun 2023 20:06:32 +0300 Subject: [PATCH 02/37] little refactoring --- lightly/models/modules/jepa.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py index 0e1373d4d..c56b5107f 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/jepa.py @@ -5,20 +5,21 @@ from einops import rearrange import copy from typing import Callable, List, Optional, Tuple, Union -from .format import Format, nchw_to -from .helpers import to_2tuple -from .trace_utils import _assert +import collections.abc +from itertools import repeat + def to_2tuple(x): - + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - output_fmt: Format - def __init__( self, img_size: Optional[int] = 224, @@ -27,7 +28,6 @@ def __init__( embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten: bool = True, - output_fmt: Optional[str] = None, bias: bool = True, strict_img_size: bool = True, ): @@ -39,33 +39,15 @@ def __init__( self.flatten = flatten - self.output_fmt = Format.NCHW self.strict_img_size = strict_img_size self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): - B, C, H, W = x.shape - if self.img_size is not None: - if self.strict_img_size: - _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") - _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") - else: - _assert( - H % self.patch_size[0] == 0, - f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." - ) - _assert( - W % self.patch_size[1] == 0, - f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." - ) - x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC - elif self.output_fmt != Format.NCHW: - x = nchw_to(x, self.output_fmt) x = self.norm(x) return x From c85ee973ab294d5faa9082308d8767b6dcee9864 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Thu, 1 Jun 2023 20:18:05 +0300 Subject: [PATCH 03/37] refactoring --- lightly/models/modules/jepa.py | 38 ++++++++-------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py index c56b5107f..785e924e5 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/jepa.py @@ -1,22 +1,12 @@ import torch import torch.nn as nn -import math import torch.nn.functional as F -from einops import rearrange import copy -from typing import Callable, List, Optional, Tuple, Union +from typing import Optional import collections.abc from itertools import repeat - -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - return tuple(x) - return tuple(repeat(x, n)) - - - class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ @@ -26,29 +16,19 @@ def __init__( patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten: bool = True, - bias: bool = True, - strict_img_size: bool = True, + ): super().__init__() - self.patch_size = to_2tuple(patch_size) - self.img_size = to_2tuple(img_size) - self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) - self.num_patches = self.grid_size[0] * self.grid_size[1] - - - self.flatten = flatten - self.strict_img_size = strict_img_size - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + if isinstance(img_size, int): + img_size = img_size, img_size + if isinstance(patch_size, int): + patch_size = patch_size, patch_size + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) - if self.flatten: - x = x.flatten(2).transpose(1, 2) # NCHW -> NLC - x = self.norm(x) + x = x.flatten(2).transpose(1, 2) return x From 0ece1d3038d30c548c11a602c12f7f5c3a5b694d Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Thu, 1 Jun 2023 20:23:02 +0300 Subject: [PATCH 04/37] refactoring --- lightly/models/modules/jepa.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py index 785e924e5..f70844cd0 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/jepa.py @@ -3,8 +3,6 @@ import torch.nn.functional as F import copy from typing 
import Optional -import collections.abc -from itertools import repeat class PatchEmbed(nn.Module): From 6a3f7fe7d9c5fc48e8b2aecff0a2709e4cafcc64 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Fri, 2 Jun 2023 19:00:57 +0300 Subject: [PATCH 05/37] + encoder. TODO: decoder based on causal attention --- lightly/models/modules/jepa.py | 83 +++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py index f70844cd0..23385ac4a 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/jepa.py @@ -2,7 +2,10 @@ import torch.nn as nn import torch.nn.functional as F import copy -from typing import Optional +from typing import Optional, Callable, List +from torchvision.models import vision_transformer +from lightly.models import utils +import math class PatchEmbed(nn.Module): @@ -43,6 +46,82 @@ def forward(self, context_encoding, target_masks): return x[:, l - target_masks.shape[1]:, :] +class IJEPA_Encoder(vision_transformer.Encoder): + def __init__( + self, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + ): + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + def forward( + self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Encode input tokens. + + Args: + input: + Batch of token sequences. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be encoded. + + Returns: + Batch of encoded output tokens. + """ + input = input + self.interpolate_pos_encoding(input) + if idx_keep is not None: + input = utils.get_at_index(input, idx_keep) + return self.ln(self.layers(self.dropout(input))) + + def interpolate_pos_encoding(self, input: torch.Tensor): + """Returns the interpolated positional embedding for the given input. + + This function interpolates self.pos_embedding for all tokens in the input, + ignoring the class token. This allows encoding variable sized images. + + Args: + input: + Input tensor with shape (batch_size, num_sequences). 
+ + """ + # code copied from: + npatch = input.shape[1] - 1 + N = self.pos_embedding.shape[1] - 1 + if npatch == N: + return self.pos_embedding + class_emb = self.pos_embedding[:, 0] + pos_embedding = self.pos_embedding[:, 1:] + dim = input.shape[-1] + pos_embedding = nn.functional.interpolate( + pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(npatch / N), + mode="bicubic", + ) + pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + + + + class IJEPA_base(nn.Module): def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_depth, num_heads, post_emb_norm=False, M = 4, mode="train", layer_dropout=0.): super().__init__() @@ -63,7 +142,7 @@ def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_de #define the encoder and decoder, as well as the layer normalization and dropout self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity() self.norm = nn.LayerNorm(embed_dim) - self.teacher_encoder = Encoder( + self.teacher_encoder = IJEPA_Encoder( dim=embed_dim, heads=num_heads, depth=enc_depth, From 7102910542795b21cb8f5aa56e4d58925bc3368c Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Fri, 2 Jun 2023 19:32:55 +0300 Subject: [PATCH 06/37] fix imports --- lightly/models/modules/jepa.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py index 23385ac4a..d5bc96a65 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/jepa.py @@ -2,6 +2,7 @@ import torch.nn as nn import torch.nn.functional as F import copy +from functools import partial from typing import Optional, Callable, List from torchvision.models import vision_transformer from lightly.models import utils @@ -120,8 +121,6 @@ def interpolate_pos_encoding(self, input: torch.Tensor): return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) - - class IJEPA_base(nn.Module): def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_depth, num_heads, post_emb_norm=False, M = 4, mode="train", layer_dropout=0.): super().__init__() From 0c5bac6ff94a8ad7785b3bcf207193f78fad3c18 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Fri, 2 Jun 2023 21:56:28 +0300 Subject: [PATCH 07/37] add Decoder class to consistency between code --- lightly/models/modules/jepa.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/jepa.py index d5bc96a65..9a276c6f8 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/jepa.py @@ -13,11 +13,11 @@ class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__( - self, - img_size: Optional[int] = 224, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, ): super().__init__() @@ -34,11 +34,16 @@ def forward(self, x): return x -class Predictor(nn.Module): +class IJEPA_Decoder(nn.Module): + def __init__(self, embed_dim, depth, num_heads): + super().__init__() + pass + +class IJEPA_Predictor(nn.Module): def __init__(self, embed_dim, num_heads, depth): super().__init__() - self.predictor = Decoder(dim = embed_dim, depth = depth, heads = 
num_heads) + self.predictor = IJEPA_Decoder(dim = embed_dim, depth = depth, heads = num_heads) def forward(self, context_encoding, target_masks): x = torch.cat((context_encoding, target_masks), dim = 1) x = self.predictor(x) @@ -146,9 +151,9 @@ def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_de heads=num_heads, depth=enc_depth, layer_dropout=self.layer_dropout, - ) + ) self.student_encoder = copy.deepcopy(self.teacher_encoder).cuda() - self.predictor = Predictor(embed_dim, num_heads, pred_depth) + self.predictor = IJEPA_Predictor(embed_dim, num_heads, pred_depth) @torch.no_grad() def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M): From 7c41f599c79db62f1ad8008d26368fd6c14ae752 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Sat, 3 Jun 2023 13:16:35 +0300 Subject: [PATCH 08/37] change naming; change class structure --- lightly/models/modules/{jepa.py => i_jepa.py} | 63 +------------------ 1 file changed, 2 insertions(+), 61 deletions(-) rename lightly/models/modules/{jepa.py => i_jepa.py} (77%) diff --git a/lightly/models/modules/jepa.py b/lightly/models/modules/i_jepa.py similarity index 77% rename from lightly/models/modules/jepa.py rename to lightly/models/modules/i_jepa.py index 9a276c6f8..cb60ca7b5 100644 --- a/lightly/models/modules/jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -52,7 +52,7 @@ def forward(self, context_encoding, target_masks): return x[:, l - target_masks.shape[1]:, :] -class IJEPA_Encoder(vision_transformer.Encoder): +class IJEPA_Encoder(nn.Module): def __init__( self, seq_length: int, @@ -64,66 +64,7 @@ def __init__( attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): - super().__init__( - seq_length=seq_length, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - dropout=dropout, - attention_dropout=attention_dropout, - norm_layer=norm_layer, - ) - - def forward( - self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """Encode input tokens. - - Args: - input: - Batch of token sequences. - idx_keep: - Tensor with shape (batch_size, num_tokens_to_keep) where each - entry is an index of the token to keep in the respective batch. - If specified, only the indexed tokens will be encoded. - - Returns: - Batch of encoded output tokens. - """ - input = input + self.interpolate_pos_encoding(input) - if idx_keep is not None: - input = utils.get_at_index(input, idx_keep) - return self.ln(self.layers(self.dropout(input))) - - def interpolate_pos_encoding(self, input: torch.Tensor): - """Returns the interpolated positional embedding for the given input. - - This function interpolates self.pos_embedding for all tokens in the input, - ignoring the class token. This allows encoding variable sized images. - - Args: - input: - Input tensor with shape (batch_size, num_sequences). 
- - """ - # code copied from: - npatch = input.shape[1] - 1 - N = self.pos_embedding.shape[1] - 1 - if npatch == N: - return self.pos_embedding - class_emb = self.pos_embedding[:, 0] - pos_embedding = self.pos_embedding[:, 1:] - dim = input.shape[-1] - pos_embedding = nn.functional.interpolate( - pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( - 0, 3, 1, 2 - ), - scale_factor=math.sqrt(npatch / N), - mode="bicubic", - ) - pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + pass class IJEPA_base(nn.Module): From d99994db63b50d9ffc4d5889a46c9837c230efe2 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Tue, 6 Jun 2023 22:09:57 +0300 Subject: [PATCH 09/37] change module; add example --- examples/pytorch/i_jepa.py | 0 lightly/models/modules/i_jepa.py | 15 +++++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) create mode 100644 examples/pytorch/i_jepa.py diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/i_jepa.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index cb60ca7b5..06bc7c8b4 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -34,16 +34,11 @@ def forward(self, x): return x -class IJEPA_Decoder(nn.Module): - def __init__(self, embed_dim, depth, num_heads): - super().__init__() - pass - class IJEPA_Predictor(nn.Module): def __init__(self, embed_dim, num_heads, depth): super().__init__() - - self.predictor = IJEPA_Decoder(dim = embed_dim, depth = depth, heads = num_heads) + decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads) + self.predictor = nn.TransformerDecoder(decoder_layer, num_layers=depth) def forward(self, context_encoding, target_masks): x = torch.cat((context_encoding, target_masks), dim = 1) x = self.predictor(x) @@ -64,7 +59,11 @@ def __init__( attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): - pass + encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + def forward(self, src): + return self.transformer_encoder(src) class IJEPA_base(nn.Module): From 4e585e7e76e061563d973a2106c6fd0b877b135e Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Wed, 7 Jun 2023 14:58:12 +0300 Subject: [PATCH 10/37] few refactoring --- lightly/models/modules/i_jepa.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 06bc7c8b4..3fa6cc582 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -4,8 +4,6 @@ import copy from functools import partial from typing import Optional, Callable, List -from torchvision.models import vision_transformer -from lightly.models import utils import math @@ -50,17 +48,13 @@ def forward(self, context_encoding, target_masks): class IJEPA_Encoder(nn.Module): def __init__( self, - seq_length: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - dropout: float, - attention_dropout: float, - norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + dim, + heads, + depth, ): - encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) - 
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + super().__init__() + encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth) def forward(self, src): return self.transformer_encoder(src) @@ -90,10 +84,9 @@ def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_de dim=embed_dim, heads=num_heads, depth=enc_depth, - layer_dropout=self.layer_dropout, - ) + ) self.student_encoder = copy.deepcopy(self.teacher_encoder).cuda() - self.predictor = IJEPA_Predictor(embed_dim, num_heads, pred_depth) + self.predictor = IJEPA_Encoder(embed_dim, num_heads, pred_depth) @torch.no_grad() def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M): From 9034061732090aadc159c32fd65ead4a191d1e97 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Wed, 7 Jun 2023 15:09:57 +0300 Subject: [PATCH 11/37] del line --- lightly/models/modules/i_jepa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 3fa6cc582..55c597dcb 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -176,4 +176,4 @@ def forward(self, x, target_aspect_ratio=1, target_scale=1, context_aspect_ratio target_masks = target_masks + target_pos_embedding prediction_blocks[i] = self.predictor(context_encoding, target_masks) - return prediction_blocks, target_blocks + return prediction_blocks, target_blocks \ No newline at end of file From 035ed5a0d44cc72dfba13821cc262abc29e98142 Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Wed, 7 Jun 2023 15:23:33 +0300 Subject: [PATCH 12/37] del comments --- lightly/models/modules/i_jepa.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 55c597dcb..cdf6ca385 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -158,10 +158,8 @@ def forward(self, x, target_aspect_ratio=1, target_scale=1, context_aspect_ratio #if mode is test, we get return full embedding: if self.mode == 'test': return self.student_encoder(x) - # #get target embeddings target_blocks, target_patches, all_patches = self.get_target_block(self.teacher_encoder, x, self.patch_dim, target_aspect_ratio, target_scale, self.M) m, b, n, e = target_blocks.shape - #get context embedding context_block = self.get_context_block(x, self.patch_dim, context_aspect_ratio, context_scale, all_patches) context_encoding = self.student_encoder(context_block) From 3b413429ac6fc360a542f626e0845f29a1b1cc6e Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Wed, 7 Jun 2023 15:27:14 +0300 Subject: [PATCH 13/37] del comment --- lightly/models/modules/i_jepa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index cdf6ca385..81a7cc7d5 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -167,7 +167,6 @@ def forward(self, x, target_aspect_ratio=1, target_scale=1, context_aspect_ratio prediction_blocks = torch.zeros((m, b, n, e)).cuda() - #get the prediction blocks, predict each target block separately for i in range(m): target_masks = self.mask_token.repeat(b, n, 1) target_pos_embedding = self.pos_embedding[:, 
target_patches[i], :] From 464678649eae7d692834eee0f89d99958396802f Mon Sep 17 00:00:00 2001 From: georgebredis <9454-georgebredis@users.noreply.gitlab.aicrowd.com> Date: Wed, 7 Jun 2023 15:30:25 +0300 Subject: [PATCH 14/37] del comment --- lightly/models/modules/i_jepa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 81a7cc7d5..f4c244e17 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -121,7 +121,6 @@ def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M) if start_patch + i * patch_w + j not in all_patches: all_patches.append(start_patch + i * patch_w + j) - #get the target block target_patches.append(patches) target_block[z] = x[:, patches, :] return target_block.cuda(), target_patches, all_patches From 28fd7db36c6e8f943d6dd966813cfda6d9042cf0 Mon Sep 17 00:00:00 2001 From: georgebredis Date: Wed, 7 Jun 2023 15:34:27 +0300 Subject: [PATCH 15/37] del line --- lightly/models/modules/i_jepa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index f4c244e17..2936d931c 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -127,7 +127,6 @@ def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M) def get_context_block(self, x, patch_dim, aspect_ratio, scale, target_patches): patch_h, patch_w = patch_dim - #get the number of patches in the target block num_patches_block = int(patch_h * patch_w * scale) #get the height and width of the target block with aspect ratio block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio))) From de1106c37e08f65b9f91731ab6c4a03fd877f4d2 Mon Sep 17 00:00:00 2001 From: Natyren Date: Wed, 7 Jun 2023 21:07:38 +0300 Subject: [PATCH 16/37] pass --- lightly/models/modules/i_jepa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 2936d931c..186186fcf 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -114,7 +114,6 @@ def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M) start_patch = start_patch_h * patch_w + start_patch_w patches = [] - #get the patches in the target block for i in range(block_h): for j in range(block_w): patches.append(start_patch + i * patch_w + j) From 9995df68fe77e236fbfa81e97d96d3a0879b5a1f Mon Sep 17 00:00:00 2001 From: Natyren Date: Wed, 7 Jun 2023 21:15:00 +0300 Subject: [PATCH 17/37] pass --- lightly/models/modules/i_jepa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 186186fcf..fbed2037b 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -103,7 +103,6 @@ def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M) #get the height and width of the target block with aspect ratio block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio))) block_w = int(aspect_ratio * block_h) - #get the patches in the target block target_block = torch.zeros((M, x.shape[0], block_h*block_w, x.shape[2])) target_patches = [] all_patches = [] From c23de3915a83cb0796b50b46b1973f29db7bb99a Mon Sep 17 00:00:00 2001 From: Natyren Date: Sun, 9 Jul 2023 18:13:16 +0300 Subject: [PATCH 18/37] add model itself, todo: train loop and debug --- lightly/models/modules/i_jepa.py | 382 ++++++++++++++++++------------- 
lightly/models/utils.py | 24 ++ 2 files changed, 251 insertions(+), 155 deletions(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index fbed2037b..a1aee357e 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -1,172 +1,244 @@ import torch import torch.nn as nn import torch.nn.functional as F -import copy -from functools import partial -from typing import Optional, Callable, List +import numpy as np + +from torchvision.models import vision_transformer + +from lightly.models import utils +from typing import Optional, partial, Callable import math -class PatchEmbed(nn.Module): - """ 2D Image to Patch Embedding +class IJEPA_predictor(vision_transformer.Encoder): + """ + Predictor for the I-JEPA model [0]. + + Predict patch embeddings. Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + seq_length: + Token sequence length, including the class token. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + predictor_embed_dim: + Dimension of inner predicted tokens + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + """ def __init__( self, - img_size: Optional[int] = 224, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + predictor_embed_dim :int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + **kwargs ): - super().__init__() - if isinstance(img_size, int): - img_size = img_size, img_size - if isinstance(patch_size, int): - patch_size = patch_size, patch_size - self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x): - x = self.proj(x) - x = x.flatten(2).transpose(1, 2) - return x + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + self.predictor_proj = nn.Linear(predictor_embed_dim, mlp_dim, bias=True) + # self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), + # requires_grad=False) + # predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], + # int(num_patches**.5), + # cls_token=False) + # self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + + @classmethod + def from_vit_encoder(cls, vit_encoder): + """Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + encoder = cls( + seq_length=1, + num_layers=1, + num_heads=1, + hidden_dim=1, + mlp_dim=1, + dropout=0, + attention_dropout=0, + ) + encoder.layers = vit_encoder.layers + encoder.ln = 
vit_encoder.ln + return encoder + + def forward( + self, x, masks_x, masks + ): + assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices' + + if not isinstance(masks_x, list): + masks_x = [masks_x] + + if not isinstance(masks, list): + masks = [masks] + B = len(x) // len(masks_x) -class IJEPA_Predictor(nn.Module): - def __init__(self, embed_dim, num_heads, depth): - super().__init__() - decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads) - self.predictor = nn.TransformerDecoder(decoder_layer, num_layers=depth) - def forward(self, context_encoding, target_masks): - x = torch.cat((context_encoding, target_masks), dim = 1) - x = self.predictor(x) - #return last len(target_masks) tokens - l = x.shape[1] - return x[:, l - target_masks.shape[1]:, :] + x = self.predictor_embed(x) + x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += utils.apply_masks(x_pos_embed, masks_x) + _, N_ctxt, _ = x.shape + + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = utils.apply_masks(pos_embs, masks) + pos_embs = utils.repeat_interleave_batch(pos_embs, B, repeat=len(masks_x)) + pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) + + pred_tokens += pos_embs + x = x.repeat(len(masks), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + x = self.ln(self.layers(x)) + + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +class IJEPA_encoder(vision_transformer.Encoder): + """Encoder for the I-JEPA model [0]. + + Encodes patch embeddings. Code inspired by [1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + seq_length: + Token sequence length, including the class token. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. 
+ + """ -class IJEPA_Encoder(nn.Module): def __init__( self, - dim, - heads, - depth, + seq_length: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float, + attention_dropout: float, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): - super().__init__() - encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads) - self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth) - - def forward(self, src): - return self.transformer_encoder(src) - - -class IJEPA_base(nn.Module): - def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_depth, num_heads, post_emb_norm=False, M = 4, mode="train", layer_dropout=0.): - super().__init__() - self.M = M - self.mode = mode - self.layer_dropout = layer_dropout - - #define the patch embedding and positional embedding - self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - self.patch_dim = (self.patch_embed.patch_shape[0], self.patch_embed.patch_shape[1]) - self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1] - self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim)) - - #define the cls and mask tokens - self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim)) - nn.init.trunc_normal_(self.mask_token, 0.02) - - #define the encoder and decoder, as well as the layer normalization and dropout - self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity() - self.norm = nn.LayerNorm(embed_dim) - self.teacher_encoder = IJEPA_Encoder( - dim=embed_dim, - heads=num_heads, - depth=enc_depth, - ) - self.student_encoder = copy.deepcopy(self.teacher_encoder).cuda() - self.predictor = IJEPA_Encoder(embed_dim, num_heads, pred_depth) - - @torch.no_grad() - def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M): - #get the target block - target_encoder = target_encoder.eval() - x = target_encoder(x) - x = self.norm(x) - #get the patch dimensions - patch_h, patch_w = patch_dim - #get the number of patches - num_patches = patch_h * patch_w - #get the number of patches in the target block - num_patches_block = int(patch_h * patch_w * scale) - #get the height and width of the target block with aspect ratio - block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio))) - block_w = int(aspect_ratio * block_h) - target_block = torch.zeros((M, x.shape[0], block_h*block_w, x.shape[2])) - target_patches = [] - all_patches = [] - for z in range(M): - #get the starting patch - start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item() - start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item() - start_patch = start_patch_h * patch_w + start_patch_w - - patches = [] - for i in range(block_h): - for j in range(block_w): - patches.append(start_patch + i * patch_w + j) - if start_patch + i * patch_w + j not in all_patches: - all_patches.append(start_patch + i * patch_w + j) - - target_patches.append(patches) - target_block[z] = x[:, patches, :] - return target_block.cuda(), target_patches, all_patches - - def get_context_block(self, x, patch_dim, aspect_ratio, scale, target_patches): - patch_h, patch_w = patch_dim - num_patches_block = int(patch_h * patch_w * scale) - #get the height and width of the target block with aspect ratio - block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio))) - block_w = int(aspect_ratio * block_h) - #get the starting patch - 
start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item() - start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item() - start_patch = start_patch_h * patch_w + start_patch_w - #get the patches in the context_block - patches = [] - for i in range(block_h): - for j in range(block_w): - if start_patch + i * patch_w + j not in target_patches: #remove the target patches - patches.append(start_patch + i * patch_w + j) - return x[:, patches, :] - - - def forward(self, x, target_aspect_ratio=1, target_scale=1, context_aspect_ratio=1, context_scale=1): - #get the patch embeddings - x = self.patch_embed(x) - b, n, e = x.shape - x = x + self.pos_embedding[:, :n] - #add the positional embeddings - x = x + self.pos_embedding - #normalize the embeddings - x = self.post_emb_norm(x) - #if mode is test, we get return full embedding: - if self.mode == 'test': - return self.student_encoder(x) - target_blocks, target_patches, all_patches = self.get_target_block(self.teacher_encoder, x, self.patch_dim, target_aspect_ratio, target_scale, self.M) - m, b, n, e = target_blocks.shape - - context_block = self.get_context_block(x, self.patch_dim, context_aspect_ratio, context_scale, all_patches) - context_encoding = self.student_encoder(context_block) - context_encoding = self.norm(context_encoding) - - - prediction_blocks = torch.zeros((m, b, n, e)).cuda() - for i in range(m): - target_masks = self.mask_token.repeat(b, n, 1) - target_pos_embedding = self.pos_embedding[:, target_patches[i], :] - target_masks = target_masks + target_pos_embedding - prediction_blocks[i] = self.predictor(context_encoding, target_masks) - - return prediction_blocks, target_blocks \ No newline at end of file + super().__init__( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + @classmethod + def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder): + """Creates a MAEEncoder from a torchvision ViT encoder.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + encoder = cls( + seq_length=1, + num_layers=1, + num_heads=1, + hidden_dim=1, + mlp_dim=1, + dropout=0, + attention_dropout=0, + ) + encoder.dropout = vit_encoder.dropout + encoder.layers = vit_encoder.layers + encoder.ln = vit_encoder.ln + return encoder + + def forward( + self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Encode input tokens. + + Args: + input: + Batch of token sequences. + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be encoded. + + Returns: + Batch of encoded output tokens. + """ + input = input + self.interpolate_pos_encoding(input) + if idx_keep is not None: + input = utils.get_at_index(input, idx_keep) + return self.ln(self.layers(self.dropout(input))) + + def interpolate_pos_encoding(self, input: torch.Tensor): + """Returns the interpolated positional embedding for the given input. + + This function interpolates self.pos_embedding for all tokens in the input, + ignoring the class token. This allows encoding variable sized images. + + Args: + input: + Input tensor with shape (batch_size, num_sequences). 
+ + """ + # code copied from: + # https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291 + npatch = input.shape[1] - 1 + N = self.pos_embedding.shape[1] - 1 + if npatch == N: + return self.pos_embedding + class_emb = self.pos_embedding[:, 0] + pos_embedding = self.pos_embedding[:, 1:] + dim = input.shape[-1] + pos_embedding = nn.functional.interpolate( + pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=math.sqrt(npatch / N), + mode="bicubic", + ) + pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) \ No newline at end of file diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 9a1e229a1..44be05968 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -567,3 +567,27 @@ def get_weight_decay_parameters( else: params.append(param) return params, params_no_weight_decay + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat([ + torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) + for i in range(N) + ], dim=0) + return x \ No newline at end of file From ca171bfcb1512c96300989e3e99d1b4689a2d98c Mon Sep 17 00:00:00 2001 From: Natyren Date: Sun, 9 Jul 2023 18:33:05 +0300 Subject: [PATCH 19/37] add collator; todo: add trainloop, debug --- examples/pytorch_lightning/i_jepa.py | 0 .../pytorch_lightning_distributed/i_jepa.py | 0 lightly/data/collate.py | 150 ++++++++++++++++++ 3 files changed, 150 insertions(+) create mode 100644 examples/pytorch_lightning/i_jepa.py create mode 100644 examples/pytorch_lightning_distributed/i_jepa.py diff --git a/examples/pytorch_lightning/i_jepa.py b/examples/pytorch_lightning/i_jepa.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/pytorch_lightning_distributed/i_jepa.py b/examples/pytorch_lightning_distributed/i_jepa.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightly/data/collate.py b/lightly/data/collate.py index 25720171d..d7136e326 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -1345,6 +1345,156 @@ def forward( return (views_global, views_local, grids_global, grids_local), labels, fnames +class IJEPAMaskCollator(MultiViewCollateFunction): + + def __init__( + self, + input_size=(224, 224), + patch_size=16, + enc_mask_scale=(0.2, 0.8), + pred_mask_scale=(0.2, 0.8), + aspect_ratio=(0.3, 3.0), + nenc=1, + npred=2, + min_keep=4, + allow_overlap=False + ): + super(IJEPAMaskCollator, self).__init__() + if not isinstance(input_size, tuple): + input_size = (input_size, ) * 2 + self.patch_size = patch_size + self.height, self.width = input_size[0] // patch_size, input_size[1] // patch_size + self.enc_mask_scale = enc_mask_scale + self.pred_mask_scale = pred_mask_scale + self.aspect_ratio = aspect_ratio + self.nenc = nenc + self.npred = npred + self.min_keep = min_keep # minimum number of patches to keep + self.allow_overlap = allow_overlap # whether to allow overlap b/w enc 
and pred masks + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size(self, generator, scale, aspect_ratio_scale): + _rand = torch.rand(1, generator=generator).item() + # -- Sample block scale + min_s, max_s = scale + mask_scale = min_s + _rand * (max_s - min_s) + max_keep = int(self.height * self.width * mask_scale) + # -- Sample block aspect-ratio + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + # -- Compute block height and width (given scale and aspect-ratio) + h = int(round(torch.sqrt(max_keep * aspect_ratio))) + w = int(round(torch.sqrt(max_keep / aspect_ratio))) + while h >= self.height: + h -= 1 + while w >= self.width: + w -= 1 + + return (h, w) + + def _sample_block_mask(self, b_size, acceptable_regions=None): + h, w = b_size + + def constrain_mask(mask, tries=0): + """ Helper to restrict given mask to a set of acceptable regions """ + N = max(int(len(acceptable_regions)-tries), 0) + for k in range(N): + mask *= acceptable_regions[k] + # -- + # -- Loop to sample masks until we find a valid one + tries = 0 + timeout = og_timeout = 20 + valid_mask = False + while not valid_mask: + # -- Sample block top-left corner + top = torch.randint(0, self.height - h, (1,)) + left = torch.randint(0, self.width - w, (1,)) + mask = torch.zeros((self.height, self.width), dtype=torch.int32) + mask[top:top+h, left:left+w] = 1 + # -- Constrain mask to a set of acceptable regions + if acceptable_regions is not None: + constrain_mask(mask, tries) + mask = torch.nonzero(mask.flatten()) + # -- If mask too small try again + valid_mask = len(mask) > self.min_keep + if not valid_mask: + timeout -= 1 + if timeout == 0: + tries += 1 + timeout = og_timeout + mask = mask.squeeze() + # -- + mask_complement = torch.ones((self.height, self.width), dtype=torch.int32) + mask_complement[top:top+h, left:left+w] = 0 + # -- + return mask, mask_complement + + def __call__(self, batch): + ''' + Create encoder and predictor masks when collating imgs into a batch + # 1. sample enc block (size + location) using seed + # 2. sample pred block (size) using seed + # 3. sample several enc block locations for each image (w/o seed) + # 4. sample several pred block locations for each image (w/o seed) + # 5. 
return enc mask and pred mask + ''' + B = len(batch) + + collated_batch = torch.utils.data.default_collate(batch) + + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + scale=self.pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio) + e_size = self._sample_block_size( + generator=g, + scale=self.enc_mask_scale, + aspect_ratio_scale=(1., 1.)) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_pred = self.height * self.width + min_keep_enc = self.height * self.width + for _ in range(B): + + masks_p, masks_C = [], [] + for _ in range(self.npred): + mask, mask_C = self._sample_block_mask(p_size) + masks_p.append(mask) + masks_C.append(mask_C) + min_keep_pred = min(min_keep_pred, len(mask)) + collated_masks_pred.append(masks_p) + + acceptable_regions = masks_C + + if self.allow_overlap: + acceptable_regions= None + + + masks_e = [] + for _ in range(self.nenc): + mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions) + masks_e.append(mask) + min_keep_enc = min(min_keep_enc, len(mask)) + collated_masks_enc.append(masks_e) + + collated_masks_pred = [[cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_batch, collated_masks_enc, collated_masks_pred + + def _deprecation_warning_collate_functions() -> None: warn( "Collate functions are deprecated and will be removed in favor of transforms in v1.4.0.\n" From e1b97ec79ada82d3493f33707a94de9bf6947b39 Mon Sep 17 00:00:00 2001 From: Natyren Date: Sun, 9 Jul 2023 18:47:40 +0300 Subject: [PATCH 20/37] added template to train code and transforms todo: train and debug --- examples/pytorch/i_jepa.py | 30 ++++++++++++++ lightly/models/modules/i_jepa.py | 3 +- lightly/transforms/ijepa_transform.py | 57 +++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 lightly/transforms/ijepa_transform.py diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/i_jepa.py index e69de29bb..202019f33 100644 --- a/examples/pytorch/i_jepa.py +++ b/examples/pytorch/i_jepa.py @@ -0,0 +1,30 @@ +import torch +import torchvision +from torch import nn + +from lightly.models import utils +from lightly.models.modules import i_jepa +from lightly.transforms.mae_transform import IJEPATransform + + +class I_JEPA(nn.Module): + pass + + +vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) +vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) +model = I_JEPA(vit_for_predictor, vit_for_embedder) + +device = "cuda" if torch.cuda.is_available() else "cpu" +model.to(device) + +transform = IJEPATransform() +# we ignore object detection annotations by setting target_transform to return 0 +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=lambda t: 0, +) +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index a1aee357e..ebd8814e2 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -126,7 +126,8 @@ def forward( class IJEPA_encoder(vision_transformer.Encoder): - """Encoder for 
the I-JEPA model [0]. + """ + Encoder for the I-JEPA model [0]. Encodes patch embeddings. Code inspired by [1]. diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py new file mode 100644 index 000000000..a212e694b --- /dev/null +++ b/lightly/transforms/ijepa_transform.py @@ -0,0 +1,57 @@ +from typing import List, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image +from torch import Tensor + +from lightly.transforms.multi_view_transform import MultiViewTransform +from lightly.transforms.utils import IMAGENET_NORMALIZE + + +class IJEPATransform: + """Implements the augmentations for I-JEPA (IMAGENET data transforms accorgind to original code) [0, 1]. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + input_size: + Size of the input image in pixels. + min_scale: + Minimum size of the randomized crop relative to the input_size. + normalize: + Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize. + + """ + + def __init__( + self, + input_size: Union[int, Tuple[int, int]] = 224, + min_scale: float = 0.2, + normalize: dict = IMAGENET_NORMALIZE, + ): + transforms = [ + T.RandomResizedCrop( + input_size, scale=(min_scale, 1.0), interpolation=3 + ), # 3 is bicubic + T.RandomHorizontalFlip(), + T.ToTensor(), + ] + if normalize: + transforms.append(T.Normalize(mean=normalize["mean"], std=normalize["std"])) + + self.transform = T.Compose(transforms) + + def __call__(self, image: Union[Tensor, Image]) -> List[Tensor]: + """ + Applies the transforms to the input image. + + Args: + image: + The input image to apply the transforms to. + + Returns: + The transformed image. + + """ + return [self.transform(image)] From 612f7cc95f0c8755cfffccdf8778321836a23af1 Mon Sep 17 00:00:00 2001 From: Natyren Date: Wed, 12 Jul 2023 21:17:45 +0300 Subject: [PATCH 21/37] add train in pure pytorch in debug process now --- examples/pytorch/i_jepa.py | 72 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/i_jepa.py index 202019f33..488b0f73c 100644 --- a/examples/pytorch/i_jepa.py +++ b/examples/pytorch/i_jepa.py @@ -1,14 +1,48 @@ import torch import torchvision from torch import nn +from torch.nn import functional as F +import copy from lightly.models import utils from lightly.models.modules import i_jepa from lightly.transforms.mae_transform import IJEPATransform +from lightly.data.collate import IJEPAMaskCollator class I_JEPA(nn.Module): - pass + def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): + super().__init__() + self.encoder = i_jepa.IJEPA_encoder.from_vit_encoder(vit_encoder) + self.predictor = i_jepa.IJEPA_predictor.from_vit_encoder(vit_predictor) + self.target_encoder = copy.deepcopy(self.encoder) + self.momentum_scheduler = momentum_scheduler + + def forward_target(self, imgs, masks_enc, masks_pred): + with torch.no_grad(): + h = self.target_encoder(imgs) + h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim + B = len(h) + # -- create targets (masked regions of h) + h = utils.apply_masks(h, masks_pred) + h = utils.repeat_interleave_batch(h, B, repeat=len(masks_enc)) + return h + + def forward_context(self, imgs, masks_enc, masks_pred): + z = self.encoder(imgs, masks_enc) + z = self.predictor(z, masks_enc, masks_pred) + return z + + def forward(self, imgs, masks_enc, masks_pred): + z = 
self.forward_context(self, imgs, masks_enc, masks_pred) + h = self.forward_target(self, imgs, masks_enc, masks_pred) + return z, h + + def update_target_encoder(self,): + with torch.no_grad(): + m = next(self.momentum_scheduler) + for param_q, param_k in zip(self.encoder.parameters(), self.target_encoder.parameters()): + param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) @@ -18,13 +52,47 @@ class I_JEPA(nn.Module): device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) +collator = IJEPAMaskCollator( + input_size=(224,224), + patch_size=32, +) + transform = IJEPATransform() -# we ignore object detection annotations by setting target_transform to return 0 dataset = torchvision.datasets.VOCDetection( "datasets/pascal_voc", download=True, transform=transform, target_transform=lambda t: 0, ) +data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + batch_size=1, + persistent_workers=False +) + +# we ignore object detection annotations by setting target_transform to return 0 +criterion = nn.SmoothL1Loss() +optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) # or create a dataset from a folder containing images or videos: # dataset = LightlyDataset("path/to/folder") +print("Starting Training") +for epoch in range(10): + total_loss = 0 + for itr, (udata, masks_enc, masks_pred) in enumerate(data_loader): + + def load_imgs(): + # -- unsupervised imgs + imgs = udata[0].to(device, non_blocking=True) + masks_1 = [u.to(device, non_blocking=True) for u in masks_enc] + masks_2 = [u.to(device, non_blocking=True) for u in masks_pred] + return (imgs, masks_1, masks_2) + imgs, masks_enc, masks_pred = load_imgs() + z, h = model(imgs, masks_enc, masks_pred) + loss = criterion(z, h) + total_loss += loss.detach() + loss.backward() + optimizer.step() + optimizer.zero_grad() + avg_loss = total_loss / len(data_loader) + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") \ No newline at end of file From d9a9fd83bc31d0c80b45e57901045778b6c16a6a Mon Sep 17 00:00:00 2001 From: Natyren Date: Wed, 12 Jul 2023 21:50:49 +0300 Subject: [PATCH 22/37] little fix --- examples/pytorch/i_jepa.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/i_jepa.py index 488b0f73c..69f7345f9 100644 --- a/examples/pytorch/i_jepa.py +++ b/examples/pytorch/i_jepa.py @@ -45,19 +45,16 @@ def update_target_encoder(self,): param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) -vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) -vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) -model = I_JEPA(vit_for_predictor, vit_for_embedder) - -device = "cuda" if torch.cuda.is_available() else "cpu" -model.to(device) - collator = IJEPAMaskCollator( input_size=(224,224), patch_size=32, ) transform = IJEPATransform() + +# we ignore object detection annotations by setting target_transform to return 0 +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") dataset = torchvision.datasets.VOCDetection( "datasets/pascal_voc", download=True, @@ -71,13 +68,23 @@ def update_target_encoder(self,): persistent_workers=False ) -# we ignore object detection annotations by setting target_transform to return 0 +ema = (0.996, 1.0) +ipe_scale = 1.0 +ipe = len(data_loader) +num_epochs = 10 +momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) + for i in 
range(int(ipe*num_epochs*ipe_scale)+1)) +vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) +vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) +model = I_JEPA(vit_for_predictor, vit_for_embedder, momentum_scheduler) + criterion = nn.SmoothL1Loss() optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) -# or create a dataset from a folder containing images or videos: -# dataset = LightlyDataset("path/to/folder") +device = "cuda" if torch.cuda.is_available() else "cpu" +model.to(device) + print("Starting Training") -for epoch in range(10): +for epoch in range(num_epochs): total_loss = 0 for itr, (udata, masks_enc, masks_pred) in enumerate(data_loader): From adac38b19c36f95e056da8bccc26ffbbed9085bf Mon Sep 17 00:00:00 2001 From: Natyren Date: Wed, 12 Jul 2023 21:59:58 +0300 Subject: [PATCH 23/37] little fix --- examples/pytorch/i_jepa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/i_jepa.py index 69f7345f9..e4b6a7284 100644 --- a/examples/pytorch/i_jepa.py +++ b/examples/pytorch/i_jepa.py @@ -6,7 +6,7 @@ from lightly.models import utils from lightly.models.modules import i_jepa -from lightly.transforms.mae_transform import IJEPATransform +from lightly.transforms.ijepa_transform import IJEPATransform from lightly.data.collate import IJEPAMaskCollator From 800aed66e114d1797aab31b02a28cbb4994df23b Mon Sep 17 00:00:00 2001 From: Natyren Date: Wed, 12 Jul 2023 22:03:57 +0300 Subject: [PATCH 24/37] little fix --- lightly/models/modules/i_jepa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index ebd8814e2..3af266efe 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -6,7 +6,8 @@ from torchvision.models import vision_transformer from lightly.models import utils -from typing import Optional, partial, Callable +from typing import Optional, Callable +from functools import partial import math From 3a92cda6601c0f3be494b197a3194c4ca1438dff Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 10:14:12 +0300 Subject: [PATCH 25/37] fix classmethod --- lightly/models/modules/i_jepa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 3af266efe..64624f036 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -82,6 +82,7 @@ def from_vit_encoder(cls, vit_encoder): num_layers=1, num_heads=1, hidden_dim=1, + predictor_embed_dim=512, mlp_dim=1, dropout=0, attention_dropout=0, From 485f9fcc9c0a1e4e918fe3b400ade7fa5f488a70 Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 10:18:13 +0300 Subject: [PATCH 26/37] fix collator --- lightly/data/collate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lightly/data/collate.py b/lightly/data/collate.py index d7136e326..1ce786aca 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -1345,7 +1345,7 @@ def forward( return (views_global, views_local, grids_global, grids_local), labels, fnames -class IJEPAMaskCollator(MultiViewCollateFunction): +class IJEPAMaskCollator: def __init__( self, @@ -1359,7 +1359,6 @@ def __init__( min_keep=4, allow_overlap=False ): - super(IJEPAMaskCollator, self).__init__() if not isinstance(input_size, tuple): input_size = (input_size, ) * 2 self.patch_size = patch_size From 80cc07730ce99c762da4b3b1d979424a68c8280d Mon Sep 17 00:00:00 2001 From: Natyren Date: 
Thu, 13 Jul 2023 10:27:23 +0300 Subject: [PATCH 27/37] fixes --- lightly/data/collate.py | 2 ++ lightly/models/modules/i_jepa.py | 1 + 2 files changed, 3 insertions(+) diff --git a/lightly/data/collate.py b/lightly/data/collate.py index 1ce786aca..82ad63009 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -16,6 +16,7 @@ from lightly.transforms.random_crop_and_flip_with_grid import RandomResizedCropAndFlip from lightly.transforms.rotation import random_rotation_transform from lightly.transforms.utils import IMAGENET_NORMALIZE +from multiprocessing import Value imagenet_normalize = IMAGENET_NORMALIZE # Kept for backwards compatibility @@ -1370,6 +1371,7 @@ def __init__( self.npred = npred self.min_keep = min_keep # minimum number of patches to keep self.allow_overlap = allow_overlap # whether to allow overlap b/w enc and pred masks + self._itr_counter = Value('i', -1) # collator is shared across worker processes def step(self): i = self._itr_counter diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 64624f036..954fe38b2 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -88,6 +88,7 @@ def from_vit_encoder(cls, vit_encoder): attention_dropout=0, ) encoder.layers = vit_encoder.layers + encoder.predictor_pos_embed = vit_encoder.pos_embedding encoder.ln = vit_encoder.ln return encoder From 45510e2bf97a94c6002bfa9894317452f9ed92a9 Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 10:34:17 +0300 Subject: [PATCH 28/37] fix in collator --- lightly/data/collate.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightly/data/collate.py b/lightly/data/collate.py index 82ad63009..da7f90e1f 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -17,6 +17,7 @@ from lightly.transforms.rotation import random_rotation_transform from lightly.transforms.utils import IMAGENET_NORMALIZE from multiprocessing import Value +import math imagenet_normalize = IMAGENET_NORMALIZE # Kept for backwards compatibility @@ -1390,8 +1391,8 @@ def _sample_block_size(self, generator, scale, aspect_ratio_scale): min_ar, max_ar = aspect_ratio_scale aspect_ratio = min_ar + _rand * (max_ar - min_ar) # -- Compute block height and width (given scale and aspect-ratio) - h = int(round(torch.sqrt(max_keep * aspect_ratio))) - w = int(round(torch.sqrt(max_keep / aspect_ratio))) + h = int(round(math.sqrt(max_keep * aspect_ratio))) + w = int(round(math.sqrt(max_keep / aspect_ratio))) while h >= self.height: h -= 1 while w >= self.width: From 263c221a4b6af6de7462ca8d77e1dcca52eaea50 Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 11:22:37 +0300 Subject: [PATCH 29/37] fix collators, added imports, fix models --- lightly/models/modules/i_jepa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 954fe38b2..ee38612f1 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -191,6 +191,7 @@ def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder): dropout=0, attention_dropout=0, ) + encoder.pos_embedding = vit_encoder.pos_embedding encoder.dropout = vit_encoder.dropout encoder.layers = vit_encoder.layers encoder.ln = vit_encoder.ln From 6939ff91e52ed6bee357feb4ea54a4b31b69a705 Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 12:08:09 +0300 Subject: [PATCH 30/37] add ijepa backbone --- lightly/models/modules/i_jepa.py | 173 ++++++++++++++++++++++++++++++- 1 file 
changed, 171 insertions(+), 2 deletions(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index ee38612f1..c8f3ef0b0 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -4,9 +4,10 @@ import numpy as np from torchvision.models import vision_transformer +from torchvision.models.vision_transformer import ConvStemConfig from lightly.models import utils -from typing import Optional, Callable +from typing import Optional, List, Callable from functools import partial import math @@ -246,4 +247,172 @@ def interpolate_pos_encoding(self, input: torch.Tensor): mode="bicubic", ) pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) \ No newline at end of file + return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + +class MAEBackbone(vision_transformer.VisionTransformer): + """ + Encoder for the I-JEPA model [0]. + Converts images into patches and encodes them. Code inspired by [1]. + Note that this implementation uses a learned positional embedding while [0] + uses a fixed positional embedding. + + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + + Attributes: + image_size: + Input image size. + patch_size: + Width and height of the image patches. image_size must be a multiple + of patch_size. + num_layers: + Number of transformer blocks. + num_heads: + Number of attention heads. + hidden_dim: + Dimension of the input and output tokens. + mlp_dim: + Dimension of the MLP in the transformer block. + dropout: + Percentage of elements set to zero after the MLP in the transformer. + attention_dropout: + Percentage of elements set to zero after the attention head. + num_classes: + Number of classes for the classification head. Currently not used. + representation_size: + If specified, an additional linear layer is added before the + classification head to change the token dimension from hidden_dim + to representation_size. Currently not used. + norm_layer: + Callable that creates a normalization layer. 
+ + """ + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0, + attention_dropout: float = 0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + super().__init__( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_classes=num_classes, + representation_size=representation_size, + norm_layer=norm_layer, + conv_stem_configs=conv_stem_configs, + ) + self.encoder = IJEPA_encoder( + seq_length=self.seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + + @classmethod + def from_vit(cls, vit: vision_transformer.VisionTransformer): + """Creates a IJEPAbackbone from a torchvision ViT model.""" + # Create a new instance with dummy values as they will be overwritten + # by the copied vit_encoder attributes + backbone = cls( + image_size=vit.image_size, + patch_size=vit.patch_size, + num_layers=1, + num_heads=1, + hidden_dim=vit.hidden_dim, + mlp_dim=vit.mlp_dim, + dropout=vit.dropout, + attention_dropout=vit.attention_dropout, + num_classes=vit.num_classes, + representation_size=vit.representation_size, + norm_layer=vit.norm_layer, + ) + backbone.conv_proj = vit.conv_proj + backbone.class_token = vit.class_token + backbone.seq_length = vit.seq_length + backbone.heads = vit.heads + backbone.encoder = IJEPA_encoder.from_vit_encoder(vit.encoder) + return backbone + + def forward( + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns encoded class tokens from a batch of images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, hidden_dim) containing the + encoded class token for every image. + + """ + out = self.encode(images, idx_keep) + class_token = out[:, 0] + return class_token + + def encode( + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns encoded class and patch tokens from images. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). + idx_keep: + Tensor with shape (batch_size, num_tokens_to_keep) where each + entry is an index of the token to keep in the respective batch. + If specified, only the indexed tokens will be passed to the + encoder. + + Returns: + Tensor with shape (batch_size, sequence_length, hidden_dim) + containing the encoded class and patch tokens for every image. + + """ + out = self.images_to_tokens(images, prepend_class_token=True) + return self.encoder(out, idx_keep) + + def images_to_tokens( + self, images: torch.Tensor, prepend_class_token: bool + ) -> torch.Tensor: + """Converts images into patch tokens. + + Args: + images: + Tensor with shape (batch_size, channels, image_size, image_size). 
+ + Returns: + Tensor with shape (batch_size, sequence_length - 1, hidden_dim) + containing the patch tokens. + """ + x = self.conv_proj(images) + tokens = x.flatten(2).transpose(1, 2) + if prepend_class_token: + tokens = utils.prepend_class_token(tokens, self.class_token) + return tokens From 0fd4639e956d2284cf125d28e2b93f0b1a35d3e9 Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 15:48:23 +0300 Subject: [PATCH 31/37] finish pure torch impelementation --- examples/pytorch/i_jepa.py | 19 +++-- lightly/models/modules/i_jepa.py | 102 +++++++++++++++++++++----- lightly/transforms/ijepa_transform.py | 2 +- 3 files changed, 97 insertions(+), 26 deletions(-) diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/i_jepa.py index e4b6a7284..a302a031d 100644 --- a/examples/pytorch/i_jepa.py +++ b/examples/pytorch/i_jepa.py @@ -9,12 +9,14 @@ from lightly.transforms.ijepa_transform import IJEPATransform from lightly.data.collate import IJEPAMaskCollator +from tqdm import tqdm + class I_JEPA(nn.Module): def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): super().__init__() - self.encoder = i_jepa.IJEPA_encoder.from_vit_encoder(vit_encoder) - self.predictor = i_jepa.IJEPA_predictor.from_vit_encoder(vit_predictor) + self.encoder = i_jepa.IJEPA_Backbone.from_vit(vit_encoder) + self.predictor = i_jepa.IJEPA_predictor.from_vit_encoder(vit_predictor.encoder, (vit_predictor.image_size//vit_predictor.patch_size)**2) self.target_encoder = copy.deepcopy(self.encoder) self.momentum_scheduler = momentum_scheduler @@ -34,8 +36,8 @@ def forward_context(self, imgs, masks_enc, masks_pred): return z def forward(self, imgs, masks_enc, masks_pred): - z = self.forward_context(self, imgs, masks_enc, masks_pred) - h = self.forward_target(self, imgs, masks_enc, masks_pred) + z = self.forward_context(imgs, masks_enc, masks_pred) + h = self.forward_target(imgs, masks_enc, masks_pred) return z, h def update_target_encoder(self,): @@ -64,7 +66,7 @@ def update_target_encoder(self,): data_loader = torch.utils.data.DataLoader( dataset, collate_fn=collator, - batch_size=1, + batch_size=10, persistent_workers=False ) @@ -74,9 +76,10 @@ def update_target_encoder(self,): num_epochs = 10 momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) for i in range(int(ipe*num_epochs*ipe_scale)+1)) + vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) -model = I_JEPA(vit_for_predictor, vit_for_embedder, momentum_scheduler) +model = I_JEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler) criterion = nn.SmoothL1Loss() optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) @@ -86,7 +89,7 @@ def update_target_encoder(self,): print("Starting Training") for epoch in range(num_epochs): total_loss = 0 - for itr, (udata, masks_enc, masks_pred) in enumerate(data_loader): + for udata, masks_enc, masks_pred in tqdm(data_loader): def load_imgs(): # -- unsupervised imgs @@ -101,5 +104,7 @@ def load_imgs(): loss.backward() optimizer.step() optimizer.zero_grad() + model.update_target_encoder() + avg_loss = total_loss / len(data_loader) print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") \ No newline at end of file diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index c8f3ef0b0..971af45cb 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -11,6 +11,68 @@ from functools import partial import math +def get_2d_sincos_pos_embed(embed_dim, grid_size, 
cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid length + return: + pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + class IJEPA_predictor(vision_transformer.Encoder): """ @@ -46,7 +108,8 @@ def __init__( num_layers: int, num_heads: int, hidden_dim: int, - predictor_embed_dim :int, + predictor_embed_dim :int, + num_patches : int, mlp_dim: int, dropout: float, attention_dropout: float, @@ -66,15 +129,15 @@ def __init__( self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) self.predictor_proj = nn.Linear(predictor_embed_dim, mlp_dim, bias=True) - # self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), - # requires_grad=False) - # predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], - # int(num_patches**.5), - # cls_token=False) - # self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], + int(num_patches**.5), + cls_token=False) + self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) @classmethod - def from_vit_encoder(cls, vit_encoder): + def from_vit_encoder(cls, vit_encoder, num_patches): """Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder.""" # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes @@ -83,13 +146,13 @@ def 
from_vit_encoder(cls, vit_encoder): num_layers=1, num_heads=1, hidden_dim=1, - predictor_embed_dim=512, - mlp_dim=1, + predictor_embed_dim=768, + mlp_dim=768, + num_patches=num_patches, dropout=0, attention_dropout=0, ) encoder.layers = vit_encoder.layers - encoder.predictor_pos_embed = vit_encoder.pos_embedding encoder.ln = vit_encoder.ln return encoder @@ -105,10 +168,9 @@ def forward( masks = [masks] B = len(x) // len(masks_x) - x = self.predictor_embed(x) - x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += utils.apply_masks(x_pos_embed, masks_x) _, N_ctxt, _ = x.shape @@ -216,7 +278,7 @@ def forward( """ input = input + self.interpolate_pos_encoding(input) if idx_keep is not None: - input = utils.get_at_index(input, idx_keep) + input = utils.apply_masks(input, idx_keep) return self.ln(self.layers(self.dropout(input))) def interpolate_pos_encoding(self, input: torch.Tensor): @@ -249,7 +311,7 @@ def interpolate_pos_encoding(self, input: torch.Tensor): pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) -class MAEBackbone(vision_transformer.VisionTransformer): +class IJEPA_Backbone(vision_transformer.VisionTransformer): """ Encoder for the I-JEPA model [0]. Converts images into patches and encodes them. Code inspired by [1]. @@ -356,7 +418,8 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer): def forward( self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None ) -> torch.Tensor: - """Returns encoded class tokens from a batch of images. + """ + Returns encoded class tokens from a batch of images. Args: images: @@ -372,9 +435,12 @@ def forward( encoded class token for every image. """ + if idx_keep is not None: + if not isinstance(idx_keep, list): + idx_keep = [idx_keep] + out = self.encode(images, idx_keep) - class_token = out[:, 0] - return class_token + return out def encode( self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py index a212e694b..4d40cf327 100644 --- a/lightly/transforms/ijepa_transform.py +++ b/lightly/transforms/ijepa_transform.py @@ -54,4 +54,4 @@ def __call__(self, image: Union[Tensor, Image]) -> List[Tensor]: The transformed image. 
""" - return [self.transform(image)] + return self.transform(image) From 72ca5cb15ef3f84784c61c37091b3231242ffbf4 Mon Sep 17 00:00:00 2001 From: Natyren Date: Thu, 13 Jul 2023 20:03:02 +0300 Subject: [PATCH 32/37] docstring fix --- lightly/models/modules/i_jepa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/i_jepa.py index 971af45cb..9a6e0cbfa 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/i_jepa.py @@ -242,7 +242,7 @@ def __init__( @classmethod def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder): - """Creates a MAEEncoder from a torchvision ViT encoder.""" + """Creates a IJEPA encoder from a torchvision ViT encoder.""" # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes encoder = cls( From 605cdc27cc3d5d8e13bb40e34691d1eb705fdc81 Mon Sep 17 00:00:00 2001 From: Natyren Date: Fri, 14 Jul 2023 13:49:13 +0300 Subject: [PATCH 33/37] fixes of name and references to original paper --- examples/pytorch/{i_jepa.py => ijepa.py} | 10 +++++----- examples/pytorch_lightning/{i_jepa.py => ijepa.py} | 0 .../{i_jepa.py => ijepa.py} | 0 lightly/data/collate.py | 7 +++++++ lightly/models/modules/{i_jepa.py => ijepa.py} | 12 ++++++------ 5 files changed, 18 insertions(+), 11 deletions(-) rename examples/pytorch/{i_jepa.py => ijepa.py} (90%) rename examples/pytorch_lightning/{i_jepa.py => ijepa.py} (100%) rename examples/pytorch_lightning_distributed/{i_jepa.py => ijepa.py} (100%) rename lightly/models/modules/{i_jepa.py => ijepa.py} (98%) diff --git a/examples/pytorch/i_jepa.py b/examples/pytorch/ijepa.py similarity index 90% rename from examples/pytorch/i_jepa.py rename to examples/pytorch/ijepa.py index a302a031d..f9351b02d 100644 --- a/examples/pytorch/i_jepa.py +++ b/examples/pytorch/ijepa.py @@ -5,18 +5,18 @@ import copy from lightly.models import utils -from lightly.models.modules import i_jepa +from lightly.models.modules import ijepa from lightly.transforms.ijepa_transform import IJEPATransform from lightly.data.collate import IJEPAMaskCollator from tqdm import tqdm -class I_JEPA(nn.Module): +class IJEPA(nn.Module): def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): super().__init__() - self.encoder = i_jepa.IJEPA_Backbone.from_vit(vit_encoder) - self.predictor = i_jepa.IJEPA_predictor.from_vit_encoder(vit_predictor.encoder, (vit_predictor.image_size//vit_predictor.patch_size)**2) + self.encoder = ijepa.IJEPABackbone.from_vit(vit_encoder) + self.predictor = ijepa.IJEPAPredictor.from_vit_encoder(vit_predictor.encoder, (vit_predictor.image_size//vit_predictor.patch_size)**2) self.target_encoder = copy.deepcopy(self.encoder) self.momentum_scheduler = momentum_scheduler @@ -79,7 +79,7 @@ def update_target_encoder(self,): vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) -model = I_JEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler) +model = IJEPA(vit_for_embedder, vit_for_predictor, momentum_scheduler) criterion = nn.SmoothL1Loss() optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) diff --git a/examples/pytorch_lightning/i_jepa.py b/examples/pytorch_lightning/ijepa.py similarity index 100% rename from examples/pytorch_lightning/i_jepa.py rename to examples/pytorch_lightning/ijepa.py diff --git a/examples/pytorch_lightning_distributed/i_jepa.py b/examples/pytorch_lightning_distributed/ijepa.py similarity index 
100% rename from examples/pytorch_lightning_distributed/i_jepa.py rename to examples/pytorch_lightning_distributed/ijepa.py diff --git a/lightly/data/collate.py b/lightly/data/collate.py index da7f90e1f..8ec637ccc 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -1348,7 +1348,14 @@ def forward( class IJEPAMaskCollator: + """ + Collator for IJEPA model [0]. + + Include collate function. Code inspired by [1]. + - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 + - [1]: https://github.com/facebookresearch/ijepa + """ def __init__( self, input_size=(224, 224), diff --git a/lightly/models/modules/i_jepa.py b/lightly/models/modules/ijepa.py similarity index 98% rename from lightly/models/modules/i_jepa.py rename to lightly/models/modules/ijepa.py index 9a6e0cbfa..4b5c2b499 100644 --- a/lightly/models/modules/i_jepa.py +++ b/lightly/models/modules/ijepa.py @@ -74,7 +74,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): return emb -class IJEPA_predictor(vision_transformer.Encoder): +class IJEPAPredictor(vision_transformer.Encoder): """ Predictor for the I-JEPA model [0]. @@ -191,7 +191,7 @@ def forward( return x -class IJEPA_encoder(vision_transformer.Encoder): +class IJEPAEncoder(vision_transformer.Encoder): """ Encoder for the I-JEPA model [0]. @@ -311,7 +311,7 @@ def interpolate_pos_encoding(self, input: torch.Tensor): pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) -class IJEPA_Backbone(vision_transformer.VisionTransformer): +class IJEPABackbone(vision_transformer.VisionTransformer): """ Encoder for the I-JEPA model [0]. Converts images into patches and encodes them. Code inspired by [1]. @@ -379,7 +379,7 @@ def __init__( norm_layer=norm_layer, conv_stem_configs=conv_stem_configs, ) - self.encoder = IJEPA_encoder( + self.encoder = IJEPAEncoder( seq_length=self.seq_length, num_layers=num_layers, num_heads=num_heads, @@ -392,7 +392,7 @@ def __init__( @classmethod def from_vit(cls, vit: vision_transformer.VisionTransformer): - """Creates a IJEPAbackbone from a torchvision ViT model.""" + """Creates a IJEPABackbone from a torchvision ViT model.""" # Create a new instance with dummy values as they will be overwritten # by the copied vit_encoder attributes backbone = cls( @@ -412,7 +412,7 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer): backbone.class_token = vit.class_token backbone.seq_length = vit.seq_length backbone.heads = vit.heads - backbone.encoder = IJEPA_encoder.from_vit_encoder(vit.encoder) + backbone.encoder = IJEPAEncoder.from_vit_encoder(vit.encoder) return backbone def forward( From 1be645384cfab8a3c075aba02209e3980e6f4450 Mon Sep 17 00:00:00 2001 From: guarin Date: Fri, 14 Jul 2023 13:36:37 +0000 Subject: [PATCH 34/37] Format --- examples/pytorch/ijepa.py | 43 ++++++++++++++----------- lightly/data/collate.py | 57 ++++++++++++++++++++------------- lightly/models/modules/ijepa.py | 54 +++++++++++++++++-------------- lightly/models/utils.py | 16 +++++---- 4 files changed, 99 insertions(+), 71 deletions(-) diff --git a/examples/pytorch/ijepa.py b/examples/pytorch/ijepa.py index f9351b02d..b36b463ad 100644 --- a/examples/pytorch/ijepa.py +++ b/examples/pytorch/ijepa.py @@ -1,22 +1,25 @@ +import copy + import torch import torchvision from torch import nn from torch.nn import functional as F -import copy +from tqdm import tqdm +from lightly.data.collate import IJEPAMaskCollator from lightly.models import utils from 
lightly.models.modules import ijepa from lightly.transforms.ijepa_transform import IJEPATransform -from lightly.data.collate import IJEPAMaskCollator - -from tqdm import tqdm class IJEPA(nn.Module): def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): super().__init__() self.encoder = ijepa.IJEPABackbone.from_vit(vit_encoder) - self.predictor = ijepa.IJEPAPredictor.from_vit_encoder(vit_predictor.encoder, (vit_predictor.image_size//vit_predictor.patch_size)**2) + self.predictor = ijepa.IJEPAPredictor.from_vit_encoder( + vit_predictor.encoder, + (vit_predictor.image_size // vit_predictor.patch_size) ** 2, + ) self.target_encoder = copy.deepcopy(self.encoder) self.momentum_scheduler = momentum_scheduler @@ -39,16 +42,20 @@ def forward(self, imgs, masks_enc, masks_pred): z = self.forward_context(imgs, masks_enc, masks_pred) h = self.forward_target(imgs, masks_enc, masks_pred) return z, h - - def update_target_encoder(self,): + + def update_target_encoder( + self, + ): with torch.no_grad(): m = next(self.momentum_scheduler) - for param_q, param_k in zip(self.encoder.parameters(), self.target_encoder.parameters()): - param_k.data.mul_(m).add_((1.-m) * param_q.detach().data) + for param_q, param_k in zip( + self.encoder.parameters(), self.target_encoder.parameters() + ): + param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) collator = IJEPAMaskCollator( - input_size=(224,224), + input_size=(224, 224), patch_size=32, ) @@ -64,18 +71,17 @@ def update_target_encoder(self,): target_transform=lambda t: 0, ) data_loader = torch.utils.data.DataLoader( - dataset, - collate_fn=collator, - batch_size=10, - persistent_workers=False + dataset, collate_fn=collator, batch_size=10, persistent_workers=False ) ema = (0.996, 1.0) ipe_scale = 1.0 ipe = len(data_loader) num_epochs = 10 -momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale) - for i in range(int(ipe*num_epochs*ipe_scale)+1)) +momentum_scheduler = ( + ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) + for i in range(int(ipe * num_epochs * ipe_scale) + 1) +) vit_for_predictor = torchvision.models.vit_b_32(pretrained=False) vit_for_embedder = torchvision.models.vit_b_32(pretrained=False) @@ -97,6 +103,7 @@ def load_imgs(): masks_1 = [u.to(device, non_blocking=True) for u in masks_enc] masks_2 = [u.to(device, non_blocking=True) for u in masks_pred] return (imgs, masks_1, masks_2) + imgs, masks_enc, masks_pred = load_imgs() z, h = model(imgs, masks_enc, masks_pred) loss = criterion(z, h) @@ -105,6 +112,6 @@ def load_imgs(): optimizer.step() optimizer.zero_grad() model.update_target_encoder() - + avg_loss = total_loss / len(data_loader) - print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") \ No newline at end of file + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") diff --git a/lightly/data/collate.py b/lightly/data/collate.py index 8ec637ccc..bc61a9442 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -3,6 +3,8 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +import math +from multiprocessing import Value from typing import List, Optional, Tuple, Union from warnings import warn @@ -16,8 +18,6 @@ from lightly.transforms.random_crop_and_flip_with_grid import RandomResizedCropAndFlip from lightly.transforms.rotation import random_rotation_transform from lightly.transforms.utils import IMAGENET_NORMALIZE -from multiprocessing import Value -import math imagenet_normalize = IMAGENET_NORMALIZE # Kept for backwards compatibility @@ -1356,6 +1356,7 @@ class IJEPAMaskCollator: - [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 - [1]: https://github.com/facebookresearch/ijepa """ + def __init__( self, input_size=(224, 224), @@ -1366,20 +1367,25 @@ def __init__( nenc=1, npred=2, min_keep=4, - allow_overlap=False + allow_overlap=False, ): if not isinstance(input_size, tuple): - input_size = (input_size, ) * 2 + input_size = (input_size,) * 2 self.patch_size = patch_size - self.height, self.width = input_size[0] // patch_size, input_size[1] // patch_size + self.height, self.width = ( + input_size[0] // patch_size, + input_size[1] // patch_size, + ) self.enc_mask_scale = enc_mask_scale self.pred_mask_scale = pred_mask_scale self.aspect_ratio = aspect_ratio self.nenc = nenc self.npred = npred self.min_keep = min_keep # minimum number of patches to keep - self.allow_overlap = allow_overlap # whether to allow overlap b/w enc and pred masks - self._itr_counter = Value('i', -1) # collator is shared across worker processes + self.allow_overlap = ( + allow_overlap # whether to allow overlap b/w enc and pred masks + ) + self._itr_counter = Value("i", -1) # collator is shared across worker processes def step(self): i = self._itr_counter @@ -1411,10 +1417,11 @@ def _sample_block_mask(self, b_size, acceptable_regions=None): h, w = b_size def constrain_mask(mask, tries=0): - """ Helper to restrict given mask to a set of acceptable regions """ - N = max(int(len(acceptable_regions)-tries), 0) + """Helper to restrict given mask to a set of acceptable regions""" + N = max(int(len(acceptable_regions) - tries), 0) for k in range(N): mask *= acceptable_regions[k] + # -- # -- Loop to sample masks until we find a valid one tries = 0 @@ -1425,7 +1432,7 @@ def constrain_mask(mask, tries=0): top = torch.randint(0, self.height - h, (1,)) left = torch.randint(0, self.width - w, (1,)) mask = torch.zeros((self.height, self.width), dtype=torch.int32) - mask[top:top+h, left:left+w] = 1 + mask[top : top + h, left : left + w] = 1 # -- Constrain mask to a set of acceptable regions if acceptable_regions is not None: constrain_mask(mask, tries) @@ -1440,19 +1447,19 @@ def constrain_mask(mask, tries=0): mask = mask.squeeze() # -- mask_complement = torch.ones((self.height, self.width), dtype=torch.int32) - mask_complement[top:top+h, left:left+w] = 0 + mask_complement[top : top + h, left : left + w] = 0 # -- return mask, mask_complement def __call__(self, batch): - ''' + """ Create encoder and predictor masks when collating imgs into a batch # 1. sample enc block (size + location) using seed # 2. sample pred block (size) using seed # 3. sample several enc block locations for each image (w/o seed) # 4. sample several pred block locations for each image (w/o seed) # 5. 
return enc mask and pred mask - ''' + """ B = len(batch) collated_batch = torch.utils.data.default_collate(batch) @@ -1463,17 +1470,16 @@ def __call__(self, batch): p_size = self._sample_block_size( generator=g, scale=self.pred_mask_scale, - aspect_ratio_scale=self.aspect_ratio) + aspect_ratio_scale=self.aspect_ratio, + ) e_size = self._sample_block_size( - generator=g, - scale=self.enc_mask_scale, - aspect_ratio_scale=(1., 1.)) + generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0) + ) collated_masks_pred, collated_masks_enc = [], [] min_keep_pred = self.height * self.width min_keep_enc = self.height * self.width for _ in range(B): - masks_p, masks_C = [], [] for _ in range(self.npred): mask, mask_C = self._sample_block_mask(p_size) @@ -1485,20 +1491,25 @@ def __call__(self, batch): acceptable_regions = masks_C if self.allow_overlap: - acceptable_regions= None - + acceptable_regions = None masks_e = [] for _ in range(self.nenc): - mask, _ = self._sample_block_mask(e_size, acceptable_regions=acceptable_regions) + mask, _ = self._sample_block_mask( + e_size, acceptable_regions=acceptable_regions + ) masks_e.append(mask) min_keep_enc = min(min_keep_enc, len(mask)) collated_masks_enc.append(masks_e) - collated_masks_pred = [[cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred] + collated_masks_pred = [ + [cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred + ] collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) # -- - collated_masks_enc = [[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc] + collated_masks_enc = [ + [cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc + ] collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) return collated_batch, collated_masks_enc, collated_masks_pred diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py index 4b5c2b499..0a8803cdd 100644 --- a/lightly/models/modules/ijepa.py +++ b/lightly/models/modules/ijepa.py @@ -1,15 +1,16 @@ +import math +from functools import partial +from typing import Callable, List, Optional + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np - from torchvision.models import vision_transformer from torchvision.models.vision_transformer import ConvStemConfig from lightly.models import utils -from typing import Optional, List, Callable -from functools import partial -import math + def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ @@ -61,11 +62,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) @@ -102,14 +103,15 @@ class IJEPAPredictor(vision_transformer.Encoder): Percentage of elements set to zero after the attention head. 
""" + def __init__( self, seq_length: int, num_layers: int, num_heads: int, hidden_dim: int, - predictor_embed_dim :int, - num_patches : int, + predictor_embed_dim: int, + num_patches: int, mlp_dim: int, dropout: float, attention_dropout: float, @@ -129,12 +131,15 @@ def __init__( self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) self.predictor_proj = nn.Linear(predictor_embed_dim, mlp_dim, bias=True) - self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim), - requires_grad=False) - predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1], - int(num_patches**.5), - cls_token=False) - self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0)) + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False + ) + predictor_pos_embed = get_2d_sincos_pos_embed( + self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False + ) + self.predictor_pos_embed.data.copy_( + torch.from_numpy(predictor_pos_embed).float().unsqueeze(0) + ) @classmethod def from_vit_encoder(cls, vit_encoder, num_patches): @@ -148,7 +153,7 @@ def from_vit_encoder(cls, vit_encoder, num_patches): hidden_dim=1, predictor_embed_dim=768, mlp_dim=768, - num_patches=num_patches, + num_patches=num_patches, dropout=0, attention_dropout=0, ) @@ -156,10 +161,10 @@ def from_vit_encoder(cls, vit_encoder, num_patches): encoder.ln = vit_encoder.ln return encoder - def forward( - self, x, masks_x, masks - ): - assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices' + def forward(self, x, masks_x, masks): + assert (masks is not None) and ( + masks_x is not None + ), "Cannot run predictor without mask indices" if not isinstance(masks_x, list): masks_x = [masks_x] @@ -190,7 +195,7 @@ def forward( return x - + class IJEPAEncoder(vision_transformer.Encoder): """ Encoder for the I-JEPA model [0]. @@ -311,6 +316,7 @@ def interpolate_pos_encoding(self, input: torch.Tensor): pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) + class IJEPABackbone(vision_transformer.VisionTransformer): """ Encoder for the I-JEPA model [0]. 
@@ -438,7 +444,7 @@ def forward( if idx_keep is not None: if not isinstance(idx_keep, list): idx_keep = [idx_keep] - + out = self.encode(images, idx_keep) return out diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 9004cdb5b..7215395a3 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -570,7 +570,7 @@ def get_weight_decay_parameters( return params, params_no_weight_decay -def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): return _no_grad_trunc_normal(tensor, mean, std, a, b) @@ -585,10 +585,14 @@ def apply_masks(x, masks): all_x += [torch.gather(x, dim=1, index=mask_keep)] return torch.cat(all_x, dim=0) + def repeat_interleave_batch(x, B, repeat): N = len(x) // B - x = torch.cat([ - torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) - for i in range(N) - ], dim=0) - return x \ No newline at end of file + x = torch.cat( + [ + torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) + for i in range(N) + ], + dim=0, + ) + return x From d35b7a9ff1949efae7ff7c996fc610fc56fd49d4 Mon Sep 17 00:00:00 2001 From: guarin Date: Fri, 14 Jul 2023 13:42:00 +0000 Subject: [PATCH 35/37] Add note about experimental support --- lightly/data/collate.py | 8 +- lightly/models/modules/ijepa.py | 148 ++++++++++++++------------ lightly/transforms/ijepa_transform.py | 13 +-- 3 files changed, 89 insertions(+), 80 deletions(-) diff --git a/lightly/data/collate.py b/lightly/data/collate.py index bc61a9442..3a935951f 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -1348,10 +1348,12 @@ def forward( class IJEPAMaskCollator: - """ - Collator for IJEPA model [0]. + """Collator for IJEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. - Include collate function. Code inspired by [1]. + Code inspired by [1]. 
- [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 - [1]: https://github.com/facebookresearch/ijepa diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py index 0a8803cdd..3eb14a247 100644 --- a/lightly/models/modules/ijepa.py +++ b/lightly/models/modules/ijepa.py @@ -5,79 +5,17 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from torchvision.models import vision_transformer from torchvision.models.vision_transformer import ConvStemConfig from lightly.models import utils -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=float) - grid_w = np.arange(grid_size, dtype=float) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid length - return: - pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid = np.arange(grid_size, dtype=float) - pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - class IJEPAPredictor(vision_transformer.Encoder): - """ - Predictor for the I-JEPA model [0]. + """Predictor for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. Predict patch embeddings. Code inspired by [1]. @@ -134,7 +72,7 @@ def __init__( self.predictor_pos_embed = nn.Parameter( torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False ) - predictor_pos_embed = get_2d_sincos_pos_embed( + predictor_pos_embed = _get_2d_sincos_pos_embed( self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False ) self.predictor_pos_embed.data.copy_( @@ -197,8 +135,10 @@ def forward(self, x, masks_x, masks): class IJEPAEncoder(vision_transformer.Encoder): - """ - Encoder for the I-JEPA model [0]. + """Encoder for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. 
Encodes patch embeddings. Code inspired by [1]. @@ -318,8 +258,11 @@ def interpolate_pos_encoding(self, input: torch.Tensor): class IJEPABackbone(vision_transformer.VisionTransformer): - """ - Encoder for the I-JEPA model [0]. + """Encoder for the I-JEPA model [0]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. + Converts images into patches and encodes them. Code inspired by [1]. Note that this implementation uses a learned positional embedding while [0] uses a fixed positional embedding. @@ -488,3 +431,66 @@ def images_to_tokens( if prepend_class_token: tokens = utils.prepend_class_token(tokens, self.class_token) return tokens + + +def _get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = _get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def _get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def _get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid length + return: + pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = _get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def _get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py index 4d40cf327..321dba66a 100644 --- a/lightly/transforms/ijepa_transform.py +++ b/lightly/transforms/ijepa_transform.py @@ -1,15 +1,17 @@ -from typing import List, Tuple, Union +from typing import Tuple, Union import torchvision.transforms as T from PIL.Image import Image from torch import Tensor -from lightly.transforms.multi_view_transform import MultiViewTransform from lightly.transforms.utils import IMAGENET_NORMALIZE class IJEPATransform: - """Implements the augmentations for I-JEPA (IMAGENET data transforms accorgind to original code) [0, 1]. + """Implements the augmentations for I-JEPA [0, 1]. + + Experimental: Support for I-JEPA is experimental, there might be breaking changes + in the future. 
- [0]: Joint-Embedding Predictive Architecture, 2023, https://arxiv.org/abs/2301.08243 - [1]: https://github.com/facebookresearch/ijepa @@ -42,9 +44,8 @@ def __init__( self.transform = T.Compose(transforms) - def __call__(self, image: Union[Tensor, Image]) -> List[Tensor]: - """ - Applies the transforms to the input image. + def __call__(self, image: Union[Tensor, Image]) -> Tensor: + """Applies the transforms to the input image. Args: image: From d262f12153a49c4cc7423f0bfca59bb177684737 Mon Sep 17 00:00:00 2001 From: guarin Date: Fri, 14 Jul 2023 13:51:17 +0000 Subject: [PATCH 36/37] Add datasets to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 2fdc54c33..031ae8795 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ lightning_logs/ **lightning_logs/ **/__MACOSX +datasets/ docs/source/tutorials/package/* docs/source/tutorials/platform/* docs/source/tutorials_source/platform/data From 3dafc052abd294012fce5813965486341e218b36 Mon Sep 17 00:00:00 2001 From: guarin Date: Fri, 14 Jul 2023 13:51:28 +0000 Subject: [PATCH 37/37] Cleanup imports --- examples/pytorch/ijepa.py | 6 +++--- examples/pytorch_lightning/ijepa.py | 1 + examples/pytorch_lightning_distributed/ijepa.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/ijepa.py b/examples/pytorch/ijepa.py index b36b463ad..eb4730e04 100644 --- a/examples/pytorch/ijepa.py +++ b/examples/pytorch/ijepa.py @@ -8,15 +8,15 @@ from lightly.data.collate import IJEPAMaskCollator from lightly.models import utils -from lightly.models.modules import ijepa +from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor from lightly.transforms.ijepa_transform import IJEPATransform class IJEPA(nn.Module): def __init__(self, vit_encoder, vit_predictor, momentum_scheduler): super().__init__() - self.encoder = ijepa.IJEPABackbone.from_vit(vit_encoder) - self.predictor = ijepa.IJEPAPredictor.from_vit_encoder( + self.encoder = IJEPABackbone.from_vit(vit_encoder) + self.predictor = IJEPAPredictor.from_vit_encoder( vit_predictor.encoder, (vit_predictor.image_size // vit_predictor.patch_size) ** 2, ) diff --git a/examples/pytorch_lightning/ijepa.py b/examples/pytorch_lightning/ijepa.py index e69de29bb..464090415 100644 --- a/examples/pytorch_lightning/ijepa.py +++ b/examples/pytorch_lightning/ijepa.py @@ -0,0 +1 @@ +# TODO diff --git a/examples/pytorch_lightning_distributed/ijepa.py b/examples/pytorch_lightning_distributed/ijepa.py index e69de29bb..464090415 100644 --- a/examples/pytorch_lightning_distributed/ijepa.py +++ b/examples/pytorch_lightning_distributed/ijepa.py @@ -0,0 +1 @@ +# TODO
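
The patch series above introduces several moving parts: a per-image transform, a mask collator, mask utilities, fixed sine-cosine positional embeddings, and an EMA-updated target encoder. The short, self-contained sketches below exercise each piece in isolation. They assume a lightly installation that already contains the I-JEPA additions from these patches; all sizes, dummy inputs, and printed shapes are illustrative rather than normative. First, the transform: after the changes in patches 31 and 35, IJEPATransform.__call__ returns a single augmented tensor rather than a list of views.

from PIL import Image

from lightly.transforms.ijepa_transform import IJEPATransform

transform = IJEPATransform(input_size=224)

# A blank dummy image stands in for a real dataset sample.
img = Image.new("RGB", (320, 280))
x = transform(img)
print(x.shape)  # torch.Size([3, 224, 224]) -- one augmented view, not a list of views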
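
IJEPAMaskCollator can also be called directly on a list of (image, label) pairs, which makes it easier to see what the training loop above receives from the data loader: the collated batch plus lists of encoder (context) and predictor (target) masks, where each mask holds the indices of kept patches rather than a binary map. A minimal sketch with a dummy two-image batch follows; the exact number of kept indices varies with the randomly sampled blocks.

import torch

from lightly.data.collate import IJEPAMaskCollator

# 224x224 inputs with 32x32 patches give a 7x7 grid of 49 patches.
collator = IJEPAMaskCollator(input_size=(224, 224), patch_size=32)

# Dummy batch of two (image, label) pairs, as a torchvision dataset would yield.
batch = [(torch.randn(3, 224, 224), 0), (torch.randn(3, 224, 224), 0)]

collated_batch, masks_enc, masks_pred = collator(batch)
images, labels = collated_batch

print(images.shape)         # torch.Size([2, 3, 224, 224])
print(len(masks_enc))       # 1 encoder (context) mask per image with the default nenc=1
print(len(masks_pred))      # 2 predictor (target) masks per image with the default npred=2
print(masks_enc[0].shape)   # (2, n_kept): patch indices visible to the context encoder
print(masks_pred[0].shape)  # (2, n_kept): patch indices the predictor has to reconstruct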
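
The helpers added to lightly/models/utils.py connect these masks to the encoders: apply_masks gathers the indexed tokens for every mask and stacks the results along the batch dimension, and repeat_interleave_batch repeats each batch-sized chunk so the targets line up with the predictor output (in forward_target above the repeat factor is len(masks_enc)). A toy-sized sketch of the resulting shapes, with arbitrary token counts:

import torch

from lightly.models.utils import apply_masks, repeat_interleave_batch

B, N, D = 2, 49, 8  # batch size, tokens per image, token dimension (toy sizes)
x = torch.randn(B, N, D)

# Two predictor masks, each keeping 4 patch indices per image.
masks_pred = [torch.randint(0, N, (B, 4)) for _ in range(2)]

h = apply_masks(x, masks_pred)
print(h.shape)  # torch.Size([4, 4, 8]): len(masks_pred) * B images, 4 tokens each

h2 = repeat_interleave_batch(h, B, repeat=2)
print(h2.shape)  # torch.Size([8, 4, 8]): every B-sized chunk duplicated back to back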
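
IJEPAPredictor keeps a fixed, non-trainable 2-D sine-cosine positional embedding built by the _get_2d_sincos_pos_embed helper, instead of the learned embedding of the torchvision encoder. A short shape check, using ViT-B/32 geometry purely as an example:

import torch

from lightly.models.modules.ijepa import _get_2d_sincos_pos_embed

grid_size = 224 // 32  # ViT-B/32 on 224x224 images -> 7x7 patch grid
embed_dim = 768

pos_embed = _get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
print(pos_embed.shape)  # (49, 768): one fixed embedding per patch, no cls token

# IJEPAPredictor copies a tensor like this into its frozen predictor_pos_embed parameter.
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)  # (1, 49, 768)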
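
Finally, the target encoder in the example is never touched by the optimizer; it tracks the context encoder through an exponential moving average whose momentum is annealed linearly from 0.996 to 1.0 over the whole run. A minimal sketch of that schedule and update rule, substituting a toy linear layer and made-up step counts for the ViT and data loader:

import copy

import torch
from torch import nn

# Linearly anneal the EMA momentum over the run, mirroring the generator in
# examples/pytorch/ijepa.py. The step counts here are hypothetical.
ema = (0.996, 1.0)
ipe, num_epochs, ipe_scale = 100, 10, 1.0
total_steps = int(ipe * num_epochs * ipe_scale)
momentum_scheduler = (
    ema[0] + i * (ema[1] - ema[0]) / total_steps for i in range(total_steps + 1)
)

# Toy encoder / target-encoder pair to show the in-place EMA update.
encoder = nn.Linear(4, 4)
target_encoder = copy.deepcopy(encoder)

for _ in range(3):  # three dummy optimization steps
    # ... loss.backward() and optimizer.step() on `encoder` would happen here ...
    m = next(momentum_scheduler)
    with torch.no_grad():
        for param_q, param_k in zip(encoder.parameters(), target_encoder.parameters()):
            param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data)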