[PROGRESS REPORT]
Kye committed Feb 8, 2024
1 parent a85cc40 commit 530be8a
Showing 2 changed files with 64 additions and 77 deletions.
27 changes: 0 additions & 27 deletions Dockerfile

This file was deleted.

114 changes: 64 additions & 50 deletions screenai/main.py
@@ -1,4 +1,3 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -14,6 +13,7 @@
Encoder,
ViTransformerWrapper,
)
+from einops.layers.torch import Rearrange

# helper functions

@@ -26,6 +26,10 @@ def default(val, d):
return val if exists(val) else d


+def pair(val):
+    return (val, val) if not isinstance(val, tuple) else val


def divisible_by(numer, denom):
return (numer % denom) == 0

@@ -241,25 +245,25 @@ def __init__(
*args,
**kwargs,
):
-        # super(self, MultiModalEncoder).__init__(*args, **kwargs)
super().__init__()
self.dim = dim
self.depth = depth
self.heads = heads
self.dim_head = dim_head
-        # self.layers = nn.ModuleList([])

-        # for _ in range(depth):
-        #     attention_layer = Attention(dim, dim_head, heads, causal=True, qk_norm=True, flash="cuda")
-        #     feedforward_layer = FeedForward(dim, dim, 4, *args, **kwargs)

-        #     self.layers.append(attention_layer)
-        #     self.layers.append(feedforward_layer)
+        self.flash = "cuda" if torch.cuda.is_available() else "cpu"

-        self.layers = nn.ModuleList(
-            [Attention(dim, dim_head, heads, causal=True, qk_norm=True, flash="cuda"),
-            FeedForward(dim, dim, 4, *args, **kwargs)]
-            for _ in range(depth)
-        )
+        self.attn = Attention(
+            dim,
+            dim_head,
+            heads,
+            causal=True,
+            qk_norm=True,
+            flash=self.flash,
+        )
+        self.ffn = FeedForward(dim, dim, 4, *args, **kwargs)

def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the MultiModalEncoder.
@@ -271,11 +275,12 @@
Tensor: The encoded tensor.
"""
-        for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
+        skip = x
+        x, _ = self.attn(x)
+        x = x + skip
+        x = self.ffn(x) + x

-        return x
+        return x + skip
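
Note: the rewritten forward above applies a single attention step and a single feed-forward step, each wrapped in a residual add, instead of looping over a ModuleList of (attention, feed-forward) pairs. A minimal, self-contained sketch of that residual pattern, using torch.nn.MultiheadAttention and a small MLP as stand-ins for the Attention and FeedForward classes used in this file (the stand-ins, names, and sizes below are illustrative assumptions, not part of the commit):

import torch
import torch.nn as nn

class EncoderBlockSketch(nn.Module):
    # One attention + feed-forward step, each with a residual add.
    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        attn_out, _ = self.attn(x, x, x)  # MultiheadAttention returns (output, weights)
        x = attn_out + skip               # residual around attention
        x = self.ffn(x) + x               # residual around feed-forward
        return x

tokens = torch.randn(2, 16, 512)           # (batch, tokens, dim)
print(EncoderBlockSketch()(tokens).shape)  # torch.Size([2, 16, 512])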


class MultiModalDecoder(nn.Module):
@@ -314,36 +319,29 @@ def __init__(
self.depth = depth
self.heads = heads
self.dim_head = dim_head
+        self.flash = "cuda" if torch.cuda.is_available() else "cpu"
+        self.cross_attn = CrossAttention(
+            dim,
+            dim_head=dim_head,
+            heads=heads,
+            parallel_ff=True,
+        )

-        self.layers = nn.ModuleList(
-            [
-                (
-                    CrossAttention(
-                        dim,
-                        context_dim=dim,
-                        heads=heads,
-                        *args,
-                        **kwargs,
-                    ),
-                    Attention(
-                        dim,
-                        dim_head,
-                        heads,
-                        causal=True,
-                        qk_norm=True,
-                        flash="cuda",
-                    ),
-                )
-                for _ in range(depth)
-            ]
-        )
+        self.attn = Attention(
+            dim,
+            dim_head,
+            heads,
+            causal=True,
+            qk_norm=True,
+            flash=self.flash,
+        )

def forward(self, x: Tensor) -> Tensor:
-        for cross_attn, attn in self.layers:
-            x = cross_attn(x, x) + x
-            x = attn(x) + x
+        skip = x
+        x = self.cross_attn(x, x) + x
+        x, _ = self.attn(x)

-        return x
+        return x + skip
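
Note: the decoder now runs one cross-attention step with a residual add, then one causal self-attention step, around a single skip connection. A rough, self-contained sketch of that flow, with torch.nn.MultiheadAttention standing in for the CrossAttention and Attention classes used in this file (names and sizes are illustrative assumptions):

import torch
import torch.nn as nn

class DecoderBlockSketch(nn.Module):
    # Cross-attention over a context sequence, then causal self-attention.
    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        skip = x
        x = self.cross_attn(x, context, context)[0] + x   # queries attend to context
        n = x.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        x, _ = self.self_attn(x, x, x, attn_mask=causal)  # mask out future positions
        return x + skip

x = torch.randn(2, 10, 512)    # (batch, target tokens, dim)
ctx = torch.randn(2, 32, 512)  # (batch, context tokens, dim)
print(DecoderBlockSketch()(x, ctx).shape)  # torch.Size([2, 10, 512])

In the commit itself the same tensor is passed as both query and context (self.cross_attn(x, x)); the sketch keeps them separate only to make the shapes explicit.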


class ScreenAI(nn.Module):
@@ -394,6 +392,7 @@ def __init__(
multi_modal_encoder_depth: int = 4,
llm_decoder_depth: int = 4,
mm_encoder_ff_mult: int = 4,
+        channels: int = 3,
*args,
**kwargs,
):
@@ -407,7 +406,9 @@ def __init__(
self.multi_modal_encoder_depth = multi_modal_encoder_depth
self.llm_decoder_depth = llm_decoder_depth


+        image_height, image_width = pair(image_size)
+        patch_height, patch_width = pair(patch_size)
+        patch_dim = channels * patch_height * patch_width

# ViTransformerWrapper
self.vit = ViTransformerWrapper(
@@ -446,6 +447,18 @@ def __init__(
heads,
)

+        # Patch embedding
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange(
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=patch_height,
+                p2=patch_width,
+            ),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim),
+        )
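
Note: a worked shape example for the patch embedding added above. The module mirrors the commit; only the concrete sizes (channels=3, 16x16 patches, dim=512, 224x224 input) are illustrative assumptions. With those values, patch_dim = 3 * 16 * 16 = 768 and a 224x224 image yields (224/16) * (224/16) = 196 patch tokens:

import torch
from torch import nn
from einops.layers.torch import Rearrange

channels, dim = 3, 512
patch_height = patch_width = 16
patch_dim = channels * patch_height * patch_width  # 3 * 16 * 16 = 768

to_patch_embedding = nn.Sequential(
    Rearrange(
        "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
        p1=patch_height,
        p2=patch_width,
    ),
    nn.LayerNorm(patch_dim),    # normalize each flattened patch
    nn.Linear(patch_dim, dim),  # project patches to the model dimension
    nn.LayerNorm(dim),
)

img = torch.randn(2, channels, 224, 224)
print(to_patch_embedding(img).shape)  # torch.Size([2, 196, 512])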

def forward(self, text: Tensor, img: Tensor) -> Tensor:
"""
Forward pass of the ScreenAI module.
@@ -458,19 +471,20 @@ def forward(self, text: Tensor, img: Tensor) -> Tensor:
Tensor: Output tensor.
"""
# Image patch
-        img = rearrange(
-            img,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_size,
-            p2=self.patch_size,
-        )
-        print(f"Image patch shape: {img.shape}")
+        # img = rearrange(
+        #     img,
+        #     "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+        #     p1=self.patch_size,
+        #     p2=self.patch_size,
+        # )
+        # print(f"Image patch shape: {img.shape}")

# vit
img = self.vit(img, return_embeddings=True)

# Embed image
-        img = self.image_embedding(img)
+        # img = self.image_embedding(img)
+        # img = self.to_patch_embedding(img)

# Concatenate image and text
x = torch.cat((img, text), dim=1)
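
Note: the fusion step is plain sequence concatenation, stacking image patch tokens and text tokens along the token dimension before they enter the multimodal encoder. A shape-only sketch (all sizes here are illustrative assumptions):

import torch

dim = 512
img_tokens = torch.randn(2, 196, dim)   # (batch, image patch tokens, dim), e.g. ViT output
text_tokens = torch.randn(2, 77, dim)   # (batch, text tokens, dim), already embedded to dim

x = torch.cat((img_tokens, text_tokens), dim=1)
print(x.shape)  # torch.Size([2, 273, 512]) -- 196 + 77 tokens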