Patch embedding
Kye committed Feb 8, 2024
1 parent 04cfb24 commit e5e11b7
Showing 1 changed file with 73 additions and 62 deletions.
135 changes: 73 additions & 62 deletions screenai/main.py
@@ -38,13 +38,18 @@ def dynamic_patching(x, patch_size, image_size):
    # Calculate the patch size based off the image
    patch_size = pair(patch_size)
    image_size = pair(image_size)

    # Get the height and width of the image
    h, w = image_size

    # Use einops to rearrange the image
    x = rearrange(
        x,
        "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
        p1=patch_size[0],
        p2=patch_size[1],
    )

    return x
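For reference, this is the standard ViT-style patchify: a (b, c, H, W) image becomes (H/p1)*(W/p2) flattened patches of p1*p2*c values each. A minimal, self-contained sketch of the function after this hunk, assuming pair is the usual int-to-tuple helper (it is not shown in this diff):

import torch
from einops import rearrange

def pair(t):
    # assumed helper: promote an int to a (t, t) tuple
    return t if isinstance(t, tuple) else (t, t)

def dynamic_patching(x, patch_size, image_size):
    patch_size = pair(patch_size)
    image_size = pair(image_size)
    h, w = image_size  # unpacked as in the diff, though unused by the rearrange
    return rearrange(
        x,
        "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
        p1=patch_size[0],
        p2=patch_size[1],
    )

patches = dynamic_patching(torch.randn(2, 3, 224, 224), 16, 224)
print(patches.shape)  # torch.Size([2, 196, 768]): 14*14 patches of 16*16*3 values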


@@ -136,58 +141,58 @@ def forward(self, x):

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        norm_context=False,
    ):
        """
        Initializes the CrossAttention module.

        Args:
            dim (int): The input dimension.
            context_dim (int, optional): The dimension of the context. Defaults to None.
            dim_head (int, optional): The dimension of each head. Defaults to 64.
            heads (int, optional): The number of attention heads. Defaults to 8.
            parallel_ff (bool, optional): Whether to use a parallel feedforward. Defaults to False.
            ff_mult (int, optional): The multiplier for the feedforward inner dimension. Defaults to 4.
            norm_context (bool, optional): Whether to apply layer normalization to the context. Defaults to False.
        """
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = (
            nn.LayerNorm(context_dim)
            if norm_context
            else nn.Identity()
        )

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether to have parallel feedforward

        ff_inner_dim = ff_mult * dim

        self.ff = (
            nn.Sequential(
                nn.Linear(dim, ff_inner_dim * 2, bias=False),
                SwiGLU(),
                nn.Linear(ff_inner_dim, dim, bias=False),
            )
            if parallel_ff
            else None
        )

    def forward(self, x, context):
        """
@@ -474,14 +479,20 @@ def __init__(
        )

        # Patch embedding
        # self.to_patch_embedding = nn.Sequential(
        #     Rearrange(
        #         "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
        #         p1=patch_height,
        #         p2=patch_width,
        #     ),
        #     nn.LayerNorm(patch_dim),
        #     nn.Linear(patch_dim, dim),
        #     nn.LayerNorm(dim),
        # )

        # Patch embedding for 3d image tensor
        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )
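A quick shape check on the new embedding stack, assuming patch_dim = channels * patch_height * patch_width as in the standard ViT recipe (the surrounding __init__ is not shown, so the concrete values below are illustrative):

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

patch_height = patch_width = 16
channels, dim = 3, 512
patch_dim = channels * patch_height * patch_width  # 768

to_patch_embedding = nn.Sequential(
    Rearrange(
        "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
        p1=patch_height,
        p2=patch_width,
    ),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim),
    nn.Linear(dim, dim),   # extra projection relative to the commented-out version
    nn.LayerNorm(dim),
)

print(to_patch_embedding(torch.randn(2, channels, 224, 224)).shape)
# torch.Size([2, 196, 512])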

@@ -504,22 +515,22 @@ def forward(self, text: Tensor, img: Tensor) -> Tensor:
        #     p2=self.patch_size,
        # )
        # print(f"Image patch shape: {img.shape}")

        # Aspect-ratio-preserving grid with at most e.g. 25 patches; output needs to be 4
        x = rearrange(
            img,
            "b c (h p1) (w p2) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
        )

        # vit
        img = self.vit(img, return_embeddings=True)
        print(f"Image shape: {img.shape}")

        # Embed image
        # img = self.image_embedding(img)
        # img = self.to_patch_embedding(img)
        img = self.to_patch_embedding(img)

        # Concatenate image and text
        x = torch.cat((img, text), dim=1)
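To make the fusion step concrete: after the ViT encoder and the patch embedding, the image tokens and text tokens are concatenated along the sequence axis. A toy trace, assuming the image branch yields (batch, num_patches, dim) and the text is already embedded to the same dim (both shapes are assumptions; the rest of forward is not shown in this hunk):

import torch

batch, dim = 2, 512
img_tokens = torch.randn(batch, 196, dim)   # assumed image token sequence
text_tokens = torch.randn(batch, 77, dim)   # assumed embedded text sequence

x = torch.cat((img_tokens, text_tokens), dim=1)  # fuse along the sequence axis
print(x.shape)  # torch.Size([2, 273, 512])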
