diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl
index 53932dee1..e64c84fea 100644
--- a/src/vit-based/vit.jl
+++ b/src/vit-based/vit.jl
@@ -1,8 +1,9 @@
 """
-    transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.)
+    transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.)
 
 Transformer as used in the base ViT architecture.
-([reference](https://arxiv.org/abs/2010.11929)).
+
+See the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
 
 # Arguments
 - `planes`: number of input channels
@@ -26,7 +27,8 @@ end
         emb_dropout = 0.1, pool = :class, nclasses = 1000)
 
 Creates a Vision Transformer (ViT) model.
-([reference](https://arxiv.org/abs/2010.11929)).
+
+See the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
 
 # Arguments
 - `imsize`: image size
@@ -35,7 +37,7 @@ Creates a Vision Transformer (ViT) model.
 - `embedplanes`: the number of channels after the patch embedding
 - `depth`: number of blocks in the transformer
 - `nheads`: number of attention heads in the transformer
-- `mlpplanes`: number of hidden channels in the MLP block in the transformer
+- `mlp_ratio`: ratio of MLP layers to the number of input channels
 - `dropout`: dropout rate
 - `emb_dropout`: dropout rate for the positional embedding layer
 - `pool`: pooling type, either :class or :mean
@@ -45,8 +47,7 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} =
              embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1,
              emb_dropout = 0.1, pool = :class, nclasses = 1000)
-    @assert pool in [:class, :mean]
-    "Pool type must be either :class (class token) or :mean (mean pooling)"
+    @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)"
     npatches = prod(imsize .÷ patch_size)
     return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes),
                        ClassTokens(embedplanes),
@@ -69,8 +70,9 @@ vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
     ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3,
         patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000)
 
-Creates a Vision Transformer (ViT) model.
-([reference](https://arxiv.org/abs/2010.11929)).
+Creates a Vision Transformer (ViT) model with a standard configuration.
+
+See the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
 
 # Arguments
 - `mode`: the model configuration, one of [:tiny, :small, :base, :large, :huge, :giant, :gigantic]
@@ -80,7 +82,7 @@ Creates a Vision Transformer (ViT) model.
 - `pool`: pooling type, either :class or :mean
 - `nclasses`: number of classes in the output
 
-See also [`Metalhead.vit`](#).
+See also [`Metalhead.vit`](#) for a more flexible constructor.
 """
 struct ViT
     layers
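
For context on the `@assert` change: in Julia, the failure message must be passed as a second argument to `@assert` in the same expression. With the old two-line form, the string literal was parsed as a separate, discarded expression, so a failing assertion reported only the condition. A standalone sketch (not part of the diff) illustrating the difference:

```julia
# Check a pooling symbol the way the patched code does: the message string is
# the second argument to @assert, so it is attached to the AssertionError.
function check_pool(pool)
    @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)"
    return pool
end

try
    check_pool(:max)   # deliberately invalid to trigger the assertion
catch err
    # Prints the custom message. With the previous two-line form the string was
    # a separate no-op expression and the error carried no message at all.
    println(err.msg)
end
```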
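A minimal usage sketch based on the signatures and defaults documented above. The keyword names are taken straight from the docstrings in this diff; that `Metalhead.vit` and `ViT` accept them exactly this way, and that a `ViT` instance is callable on a WHCN batch like other Metalhead models, is assumed here rather than shown in the diff.

```julia
using Metalhead

# Flexible constructor: every keyword mirrors a docstring default, so only the
# ones being changed are strictly needed. The MLP hidden width is
# mlp_ratio * embedplanes in the standard ViT design.
model = Metalhead.vit((256, 256); inchannels = 3, patch_size = (16, 16),
                      embedplanes = 768, depth = 6, nheads = 16,
                      mlp_ratio = 4.0, pool = :class, nclasses = 1000)

# Preset configuration through the exported wrapper.
model = ViT(:base; imsize = (256, 256), patch_size = (16, 16), nclasses = 1000)

# Assumed to be callable on a WHCN batch (width x height x channels x batch).
x = rand(Float32, 256, 256, 3, 4)   # four 256x256 RGB images
y = model(x)                        # expected size: (1000, 4)
```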