From 04cfb243764fd79187efea255f2c64cb62ef828d Mon Sep 17 00:00:00 2001
From: Kye
Date: Thu, 8 Feb 2024 14:01:28 -0800
Subject: [PATCH] [dynamic patching]

---
 pyproject.toml   |   2 +-
 screenai/main.py | 107 +++++++++++++++++++++++++++++++----------
 2 files changed, 72 insertions(+), 37 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index f6ad938..e65fa9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "screenai"
-version = "0.0.4"
+version = "0.0.5"
 description = "Screen AI - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez "]
diff --git a/screenai/main.py b/screenai/main.py
index 893d5e5..2d001ac 100644
--- a/screenai/main.py
+++ b/screenai/main.py
@@ -34,6 +34,20 @@ def divisible_by(numer, denom):
     return (numer % denom) == 0
 
 
+def dynamic_patching(x, patch_size, image_size):
+    # Calculate the patch size based off the image
+    patch_size = pair(patch_size)
+    image_size = pair(image_size)
+
+    # Get the height and width of the image
+    h, w = image_size
+
+    # Use einops to rearrange the image
+    x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size[0], p2=patch_size[1])
+
+    return x
+
+
 # distributed
 
 
@@ -122,46 +136,58 @@ def forward(self, x):
 
 class CrossAttention(nn.Module):
     def __init__(
-        self,
-        dim,
-        *,
-        context_dim=None,
-        dim_head=64,
-        heads=8,
-        parallel_ff=False,
-        ff_mult=4,
-        norm_context=False,
-    ):
-        super().__init__()
-        self.heads = heads
-        self.scale = dim_head**-0.5
-        inner_dim = heads * dim_head
-        context_dim = default(context_dim, dim)
-
-        self.norm = nn.LayerNorm(dim)
-        self.context_norm = (
-            nn.LayerNorm(context_dim)
-            if norm_context
-            else nn.Identity()
-        )
+            self,
+            dim,
+            *,
+            context_dim=None,
+            dim_head=64,
+            heads=8,
+            parallel_ff=False,
+            ff_mult=4,
+            norm_context=False,
+        ):
+            """
+            Initializes the CrossAttention module.
+
+            Args:
+                dim (int): The input dimension.
+                context_dim (int, optional): The dimension of the context. Defaults to None.
+                dim_head (int, optional): The dimension of each head. Defaults to 64.
+                heads (int, optional): The number of attention heads. Defaults to 8.
+                parallel_ff (bool, optional): Whether to use parallel feedforward. Defaults to False.
+                ff_mult (int, optional): The multiplier for the feedforward inner dimension. Defaults to 4.
+                norm_context (bool, optional): Whether to apply layer normalization to the context. Defaults to False.
+ """ + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + inner_dim = heads * dim_head + context_dim = default(context_dim, dim) + + self.norm = nn.LayerNorm(dim) + self.context_norm = ( + nn.LayerNorm(context_dim) + if norm_context + else nn.Identity() + ) - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) - # whether to have parallel feedforward + # whether to have parallel feedforward - ff_inner_dim = ff_mult * dim + ff_inner_dim = ff_mult * dim - self.ff = ( - nn.Sequential( - nn.Linear(dim, ff_inner_dim * 2, bias=False), - SwiGLU(), - nn.Linear(ff_inner_dim, dim, bias=False), + self.ff = ( + nn.Sequential( + nn.Linear(dim, ff_inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(ff_inner_dim, dim, bias=False), + ) + if parallel_ff + else None ) - if parallel_ff - else None - ) def forward(self, x, context): """ @@ -478,7 +504,16 @@ def forward(self, text: Tensor, img: Tensor) -> Tensor: # p2=self.patch_size, # ) # print(f"Image patch shape: {img.shape}") - + + + # Aspect ratio preserving grid with max e.g 25 patches, output needs to be 4 + x = rearrange( + img, + "b c (h p1) (w p2) -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + # vit img = self.vit(img, return_embeddings=True)