
Commit

[dynamic patching]
Kye committed Feb 8, 2024
1 parent 66320a2 commit 04cfb24
Showing 2 changed files with 72 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "screenai"
version = "0.0.4"
version = "0.0.5"
description = "Screen AI - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
107 changes: 71 additions & 36 deletions screenai/main.py
@@ -34,6 +34,20 @@ def divisible_by(numer, denom):
return (numer % denom) == 0


def dynamic_patching(x, patch_size, image_size):
# Normalize the patch size and image size to (height, width) pairs
patch_size = pair(patch_size)
image_size = pair(image_size)

# Get the height and width of the image
h, w = image_size

# Use einops to split the image into patches and flatten each patch into a vector
x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size[0], p2=patch_size[1])

return x
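
For reference, a minimal sketch of what the rearrange pattern in dynamic_patching yields, assuming an input whose height and width are divisible by the patch size; the pair helper is re-declared here so the snippet stands on its own:

import torch
from einops import rearrange

def pair(t):
    # mirror of the helper used above: promote an int to a (height, width) tuple
    return t if isinstance(t, tuple) else (t, t)

x = torch.randn(1, 3, 224, 224)   # (batch, channels, height, width)
p1, p2 = pair(16)                 # 16 x 16 patches

# each 16x16x3 patch becomes one flattened token; 224 / 16 = 14 patches per side
patches = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p1, p2=p2)
print(patches.shape)              # torch.Size([1, 196, 768])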


# distributed


@@ -122,46 +136,58 @@ def forward(self, x):

class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim=None,
dim_head=64,
heads=8,
parallel_ff=False,
ff_mult=4,
norm_context=False,
):
super().__init__()
self.heads = heads
self.scale = dim_head**-0.5
inner_dim = heads * dim_head
context_dim = default(context_dim, dim)

self.norm = nn.LayerNorm(dim)
self.context_norm = (
nn.LayerNorm(context_dim)
if norm_context
else nn.Identity()
)
self,
dim,
*,
context_dim=None,
dim_head=64,
heads=8,
parallel_ff=False,
ff_mult=4,
norm_context=False,
):
"""
Initializes the CrossAttention module.
Args:
dim (int): The input dimension.
context_dim (int, optional): The dimension of the context. Defaults to None.
dim_head (int, optional): The dimension of each head. Defaults to 64.
heads (int, optional): The number of attention heads. Defaults to 8.
parallel_ff (bool, optional): Whether to use parallel feedforward. Defaults to False.
ff_mult (int, optional): The multiplier for the feedforward inner dimension. Defaults to 4.
norm_context (bool, optional): Whether to apply layer normalization to the context. Defaults to False.
"""
super().__init__()
self.heads = heads
self.scale = dim_head**-0.5
inner_dim = heads * dim_head
context_dim = default(context_dim, dim)

self.norm = nn.LayerNorm(dim)
self.context_norm = (
nn.LayerNorm(context_dim)
if norm_context
else nn.Identity()
)

self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)

# whether to have parallel feedforward
# whether to have parallel feedforward

ff_inner_dim = ff_mult * dim
ff_inner_dim = ff_mult * dim

self.ff = (
nn.Sequential(
nn.Linear(dim, ff_inner_dim * 2, bias=False),
SwiGLU(),
nn.Linear(ff_inner_dim, dim, bias=False),
self.ff = (
nn.Sequential(
nn.Linear(dim, ff_inner_dim * 2, bias=False),
SwiGLU(),
nn.Linear(ff_inner_dim, dim, bias=False),
)
if parallel_ff
else None
)
if parallel_ff
else None
)

def forward(self, x, context):
"""
@@ -478,7 +504,16 @@ def forward(self, text: Tensor, img: Tensor) -> Tensor:
# p2=self.patch_size,
# )
# print(f"Image patch shape: {img.shape}")



# Aspect-ratio-preserving grid with at most e.g. 25 patches; the output needs to be 4
# Note: as written, this pattern maps the layout back onto itself, so the tensor is unchanged
x = rearrange(
img,
"b c (h p1) (w p2) -> b c (h p1) (w p2)",
p1=self.patch_size,
p2=self.patch_size,
)

# vit
img = self.vit(img, return_embeddings=True)

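As a companion to the aspect-ratio comment above, a sketch of one way to choose a rows x cols patch grid under a maximum patch budget while roughly preserving aspect ratio; the helper name and its rounding policy are illustrative assumptions, not taken from the repository:

import math

def grid_shape(height, width, max_patches=25):
    # pick rows x cols so that rows * cols <= max_patches while the grid
    # roughly matches the image's aspect ratio
    aspect = width / height
    rows = max(1, int(math.sqrt(max_patches / aspect)))
    cols = max(1, min(int(rows * aspect), max_patches // rows))
    return rows, cols

print(grid_shape(768, 1024))   # (4, 5): 20 patches for a 4:3 landscape image
print(grid_shape(1024, 768))   # (5, 3): 15 patches for the portrait orientation

The resulting rows and cols could then drive a patching rearrange like the one in dynamic_patching, with per-axis patch sizes of height // rows and width // cols.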
