-
Notifications
You must be signed in to change notification settings - Fork 0
/
position_embeddings.py
46 lines (36 loc) · 1.72 KB
/
position_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import tensorflow as tf
class PositionEmbeddingSine(tf.keras.Model):
def __init__(self, num_pos_features=64, temperature=10000,
normalize=False, scale=None, eps=1e-6, **kwargs):
# These are the default parameters used in the original project
super().__init__(**kwargs)
self.num_pos_features = num_pos_features
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError('normalize should be "True" if scale is passed')
if scale is None:
scale = 2 * np.pi
self.scale = scale
self.eps = eps
def call(self, mask):
not_mask = tf.cast(~mask, tf.float32)
y_embed = tf.math.cumsum(not_mask, axis=1)
x_embed = tf.math.cumsum(not_mask, axis=2)
if self.normalize:
y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
dim_t = tf.range(self.num_pos_features, dtype=tf.float32)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_features)
pos_x = x_embed[..., tf.newaxis] / dim_t
pos_y = y_embed[..., tf.newaxis] / dim_t
pos_x = tf.stack([tf.math.sin(pos_x[..., 0::2]),
tf.math.cos(pos_x[..., 1::2])], axis=4)
pos_y = tf.stack([tf.math.sin(pos_y[..., 0::2]),
tf.math.cos(pos_y[..., 1::2])], axis=4)
shape = [tf.shape(pos_x)[i] for i in range(3)] + [-1]
pos_x = tf.reshape(pos_x, shape)
pos_y = tf.reshape(pos_y, shape)
pos_emb = tf.concat([pos_y, pos_x], axis=3)
return pos_emb