# example.py
# Generated from kyegomez/Python-Package-Template.
import torch
from screenai.main import ScreenAI
# Example: run a single forward pass of ScreenAI on random image/text inputs.
#
# Sizes shared between the dummy inputs and the model configuration are named
# once here so the two cannot drift apart.
BATCH_SIZE = 1
IMAGE_CHANNELS = 3
IMAGE_SIZE = 224
PATCH_SIZE = 16
NUM_TOKENS = 20000  # presumably the vocabulary size; token ids are drawn below it
MAX_SEQ_LEN = 1028

# Create a random image tensor; shape is (BATCH_SIZE, IMAGE_CHANNELS,
# IMAGE_SIZE, IMAGE_SIZE) — assumed NCHW layout, TODO confirm with ScreenAI.
image = torch.rand(BATCH_SIZE, IMAGE_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)

# Create a random text tensor of token ids in [0, NUM_TOKENS) (randint's upper
# bound is exclusive), one full-length sequence per batch element.
text = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, MAX_SEQ_LEN))

# Create an instance of the ScreenAI model; num_tokens, max_seq_len,
# patch_size and image_size reuse the constants above so they match the inputs.
model = ScreenAI(
    num_tokens=NUM_TOKENS,
    max_seq_len=MAX_SEQ_LEN,
    patch_size=PATCH_SIZE,
    image_size=IMAGE_SIZE,
    dim=512,
    depth=6,
    heads=8,
    vit_depth=4,
    multi_modal_encoder_depth=4,
    llm_decoder_depth=4,
    mm_encoder_ff_mult=4,
)

# Perform a forward pass with the text and image tensors. This example only
# inspects the result, so no_grad avoids recording an autograd graph.
with torch.no_grad():
    out = model(text, image)

# Print the shape of the output tensor
print(out.shape)