-
Notifications
You must be signed in to change notification settings - Fork 5
/
transformer_utils.py
220 lines (187 loc) · 9.42 KB
/
transformer_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# This file is largely from timm
# The functions from timm (https://github.com/huggingface/pytorch-image-models/tree/main) adhere to the original license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.layers import DropPath
from timm.layers.helpers import to_2tuple
torch_version = torch.__version__
# True on torch 2.x *and later*. A plain `startswith('2.')` check would report
# False for a future 3.x release and silently fall back to the slower explicit
# attention path below. The leading component of strings such as
# "2.1.0+cu118" or "2.2.0a0+git..." is always the integer major version.
is_torch2 = int(torch_version.split('.')[0]) >= 2
class Mlp(nn.Module):
    """Two-layer feed-forward network (fc1 -> act -> dropout -> fc2 -> dropout).

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer; falls back to ``in_features``
            when falsy (``None`` or 0).
        out_features: size of the output; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, used after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        output_dim = out_features if out_features else in_features
        # Creation order (fc1, act, fc2, drop) fixes parameter init order and
        # state_dict key names.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # A single dropout module is shared by both dropout sites.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Map ``x`` (last dim ``in_features``) to a tensor with last dim ``out_features``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention (adapted from timm).

    When ``is_torch2`` is set, the attention product goes through
    ``F.scaled_dot_product_attention``. Otherwise it is computed explicitly as
    ``softmax(q @ k^T * scale) @ v``.

    Args:
        dim: input/output embedding size; must be divisible by ``num_heads``
            (forward reshapes to ``(..., num_heads, dim // num_heads)``).
        num_heads: number of attention heads.
        qkv_bias: add a bias to the fused q/k/v projection.
        qk_scale: optional override for the attention-logit scale. When falsy,
            the scale is ``(dim // num_heads) ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if is_torch2:
            # Kept as a plain float: forward() passes it to SDPA as `dropout_p`.
            self.attn_drop = attn_drop
        else:
            self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, dim)``; returns ``(B, N, dim)``."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        if is_torch2:
            # SDPA always scales logits by head_dim ** -0.5. Fold a custom
            # qk_scale into q so both branches honour self.scale. SDPA's own
            # `scale=` kwarg needs torch >= 2.1, and the flash-attn wrapper
            # installed by handle_flash_attn does not accept it.
            default_scale = head_dim ** -0.5
            if self.scale != default_scale:
                q = q * (self.scale / default_scale)
            # SDPA applies dropout_p regardless of module mode, so it must be
            # zeroed outside training to keep eval/inference deterministic.
            attn = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop if self.training else 0.,
            )
            x = attn.transpose(1, 2).reshape(B, N, C)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from the decoder, keys/values from the encoder.

    When ``is_torch2`` is set, the attention product goes through
    ``F.scaled_dot_product_attention``. Otherwise it is computed explicitly as
    ``softmax(q @ k^T * scale) @ v``.

    Args:
        encoder_dim: last-dim size of the encoder tokens ``y``.
        decoder_dim: last-dim size of the decoder tokens ``x`` and of the output;
            must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the q and fused k/v projections.
        qk_scale: optional override for the attention-logit scale. When falsy,
            the scale is ``(decoder_dim // num_heads) ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, encoder_dim, decoder_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = decoder_dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias)
        # Encoder tokens are projected straight into the decoder width for k and v.
        self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias)
        if is_torch2:
            # Kept as a plain float: forward() passes it to SDPA as `dropout_p`.
            self.attn_drop = attn_drop
        else:
            self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(decoder_dim, decoder_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        """
        query from decoder (x), key and value from encoder (y)

        Args:
            x: decoder tokens, unpacked as ``(B, N, C)`` with ``C == decoder_dim``.
            y: encoder tokens with last dim ``encoder_dim``. ``y.shape[1]`` is
                the key/value length, and the batch size is assumed to match ``x``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        Ny = y.shape[1]
        head_dim = C // self.num_heads
        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        # (B, Ny, 2*C) -> (2, B, heads, Ny, head_dim)
        kv = self.kv(y).reshape(B, Ny, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        if is_torch2:
            # SDPA always scales logits by head_dim ** -0.5. Fold a custom
            # qk_scale into q so both branches honour self.scale. SDPA's own
            # `scale=` kwarg needs torch >= 2.1, and the flash-attn wrapper
            # installed by handle_flash_attn does not accept it.
            default_scale = head_dim ** -0.5
            if self.scale != default_scale:
                q = q * (self.scale / default_scale)
            # SDPA applies dropout_p regardless of module mode, so it must be
            # zeroed outside training to keep eval/inference deterministic.
            attn = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop if self.training else 0.,
            )
            x = attn.transpose(1, 2).reshape(B, N, C)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Both residual branches pass through the same ``drop_path`` module. That
    module is ``DropPath`` when ``drop_path > 0`` and an identity otherwise.

    Args:
        dim: token embedding size.
        num_heads: attention heads passed to ``Attention``.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim`` (truncated to int).
        qkv_bias, qk_scale, attn_drop: forwarded to ``Attention``.
        drop: used both as the attention output-projection dropout and the MLP dropout.
        drop_path: stochastic-depth rate.
        act_layer: activation class for the MLP.
        norm_layer: normalization class, called with ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Sub-modules are created in a fixed order (norm1, attn, drop_path, norm2, mlp),
        # which determines parameter initialization order.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Run the attention sub-layer, then the MLP sub-layer, each with a residual connection."""
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
class CrossAttentionBlock(nn.Module):
    """Pre-norm decoder block: optional self-attention, then cross-attention, then MLP.

    Each sub-layer is residual and passes through the shared ``drop_path`` module.

    Args:
        encoder_dim: last-dim size of the encoder tokens ``y``.
        decoder_dim: last-dim size of the decoder tokens ``x``.
        num_heads: attention heads for both attention sub-layers.
        mlp_ratio: MLP hidden width as a multiple of ``decoder_dim`` (truncated to int).
        qkv_bias, qk_scale, attn_drop: forwarded to the attention sub-layers.
        drop: output-projection dropout for attention and dropout for the MLP.
        drop_path: stochastic-depth rate.
        act_layer: activation class for the MLP.
        norm_layer: normalization class, called with ``decoder_dim``.
        self_attn: when truthy, adds ``norm0`` and a self-attention sub-layer. When
            falsy, ``self.self_attn`` simply keeps the passed flag value.
    """

    def __init__(self, encoder_dim, decoder_dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, self_attn=False):
        super().__init__()
        # `self.self_attn` first holds the raw flag. When enabled it is replaced
        # by the Attention module, so forward() can test its truthiness. The
        # attribute name is also the state_dict prefix for those weights.
        self.self_attn = self_attn
        if self_attn:
            self.norm0 = norm_layer(decoder_dim)
            self.self_attn = Attention(
                decoder_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
        self.norm1 = norm_layer(decoder_dim)
        self.cross_attn = CrossAttention(
            encoder_dim,
            decoder_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(decoder_dim)
        self.mlp = Mlp(
            in_features=decoder_dim,
            hidden_features=int(decoder_dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, y):
        """
        x: decoder feature; y: encoder feature (after layernorm)
        """
        if self.self_attn:
            x = x + self.drop_path(self.self_attn(self.norm0(x)))
        cross_out = self.cross_attn(self.norm1(x), y)
        x = x + self.drop_path(cross_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Splits the image into non-overlapping patches using a Conv2d whose kernel
    size and stride both equal ``patch_size``. Each patch is projected to
    ``embed_dim`` channels.

    Args:
        img_size: expected input height/width (int or pair, via ``to_2tuple``).
        patch_size: patch height/width (int or pair, via ``to_2tuple``).
        in_chans: number of input channels.
        embed_dim: output channels per patch.
        norm_layer: optional normalization class, called with ``embed_dim``.
        flatten: if True, the output is ``(B, num_patches, embed_dim)``; if False,
            the conv feature map is kept.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x, random_sample=False):
        """Embed a 4-D image batch ``x``.

        The spatial size must equal ``img_size`` unless ``random_sample`` is truthy.
        """
        _, _, H, W = x.shape
        if not random_sample:
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            # BCHW -> BNC
            x = x.flatten(2).transpose(1, 2)
        return self.norm(x)
def handle_flash_attn(args):
    """Configure which scaled-dot-product-attention implementation is used.

    The compute capability is read from CUDA device 0.

    If ``args.enable_flash_attention2`` is truthy, this replaces
    ``F.scaled_dot_product_attention`` process-wide with a wrapper.
    - Eligible calls are routed to ``flash_attn.flash_attn_func``.
    - All other calls go to the original PyTorch implementation.
    - PyTorch's own backends are then restricted to the memory-efficient kernel.

    Otherwise, PyTorch's built-in flash kernel is enabled when the GPU's compute
    capability allows it, and the memory-efficient kernel is enabled when it does
    not. The math backend is disabled in both cases.

    Args:
        args: object exposing an ``enable_flash_attention2`` attribute.

    Raises:
        AssertionError: if flash attention 2 is requested on a GPU with compute
            capability below 8.0.
    """
    sm = torch.cuda.get_device_capability(0)
    # https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/
    enable_flashattn = sm[0] >= 8 or (sm[0] == 7 and sm[1] >= 5)
    print(f"enable_flashattn: {enable_flashattn}")
    if args.enable_flash_attention2:
        print("Flash attention 2 enabled")
        # This requires installing https://github.com/Dao-AILab/flash-attention/tree/v2.2.3
        # FlashAttention-2 kernels support Ampere (sm80) and newer only; Turing
        # (sm75) is accepted by `enable_flashattn` above but fails inside flash_attn.
        assert sm[0] >= 8, "Flash attn 2 requires compute capability >= 8.0"
        from flash_attn import flash_attn_func
        # If an earlier call already installed the wrapper, unwrap it so that
        # repeated calls don't stack wrappers on top of each other.
        torch_scaled_dot_product_attention = getattr(
            F.scaled_dot_product_attention, "_torch_sdpa", F.scaled_dot_product_attention)

        def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
            # torch convention: B, num heads, seq len, C
            # flash_attn_func convention: B, seq len, num heads, C
            # Fall back to PyTorch for inputs flash_attn_func cannot handle, instead
            # of crashing or (for attn_mask under `python -O`, where an assert
            # would be stripped) silently ignoring the mask.
            use_torch = (
                attn_mask is not None
                or query.shape[-1] > 256
                or query.dtype not in (torch.float16, torch.bfloat16)
                or not query.is_cuda
            )
            if use_torch:
                return torch_scaled_dot_product_attention(
                    query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
            out = flash_attn_func(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                dropout_p=dropout_p,
                causal=is_causal,
            )
            return out.permute(0, 2, 1, 3)

        scaled_dot_product_attention._torch_sdpa = torch_scaled_dot_product_attention
        F.scaled_dot_product_attention = scaled_dot_product_attention
        # Use memory efficient attention as a fallback
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    else:
        print("Flash attention 2 is not enabled. Using built-in attention implementation.")
        torch.backends.cuda.enable_flash_sdp(enable_flashattn)
        torch.backends.cuda.enable_mem_efficient_sdp(not enable_flashattn)
        torch.backends.cuda.enable_math_sdp(False)