diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md
index 4ff68e445d..9e4d0783dd 100644
--- a/docs/src/models/basics.md
+++ b/docs/src/models/basics.md
@@ -218,25 +218,4 @@ Flux.@functor Affine

 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

-For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md).
-
-## Utility functions
-
-Flux provides some utility functions to help you generate models in an automated fashion.
-
-`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size.
-Currently limited to the following layers:
-- `Chain`
-- `Dense`
-- `Conv`
-- `Diagonal`
-- `Maxout`
-- `ConvTranspose`
-- `DepthwiseConv`
-- `CrossCor`
-- `MaxPool`
-- `MeanPool`
-
-```@docs
-Flux.outdims
-```
+For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).
\ No newline at end of file
diff --git a/docs/src/utilities.md b/docs/src/utilities.md
index 95ef098ea5..d4d2e5ae93 100644
--- a/docs/src/utilities.md
+++ b/docs/src/utilities.md
@@ -35,6 +35,54 @@ Flux.glorot_uniform
 Flux.glorot_normal
 ```

+## Model Building
+
+Flux provides some utility functions to help you generate models in an automated fashion.
+
+[`outputsize`](@ref) enables you to calculate the output sizes of layers like [`Conv`](@ref)
+when applied to input samples of a given size. This is achieved by passing a "dummy" array into
+the model that preserves size information without running any computation.
+`outputsize(f, inputsize)` works for all layers (including custom layers) out of the box.
+By default, `inputsize` is expected to include the batch dimension.
+Alternatively, you can omit the batch size and call `outputsize(f, inputsize; padbatch=true)`, which appends a batch dimension of one.
+
+Using this utility function lets you automate model building for various inputs like so:
+```julia
+"""
+    make_model(width, height, inchannels, nclasses;
+               layer_config = [16, 16, 32, 32, 64, 64])
+
+Create a CNN for a given set of configuration parameters.
+
+# Arguments
+- `width`: the input image width
+- `height`: the input image height
+- `inchannels`: the number of channels in the input image
+- `nclasses`: the number of output classes
+- `layer_config`: a vector of the number of filters for each conv layer
+"""
+function make_model(width, height, inchannels, nclasses;
+                    layer_config = [16, 16, 32, 32, 64, 64])
+  # construct a vector of conv layers programmatically
+  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
+  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
+    push!(conv_layers, Conv((3, 3), infilters => outfilters))
+  end
+
+  # compute the output dimensions for the conv layers
+  # use padbatch=true to set the batch dimension to 1
+  conv_outsize = Flux.outputsize(conv_layers, (width, height, inchannels); padbatch=true)
+
+  # the input dimension to Dense is programmatically calculated from
+  # width, height, and inchannels
+  return Chain(conv_layers..., Dense(prod(conv_outsize), nclasses))
+end
+```
+
+```@docs
+Flux.outputsize
+```
+
 ## Model Abstraction

 ```@docs
diff --git a/src/Flux.jl b/src/Flux.jl
index c1646b5296..b7851138d3 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -41,6 +41,8 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")

+include("outputsize.jl")
+
 include("data/Data.jl")

 include("losses/Losses.jl")
diff --git a/src/deprecations.jl b/src/deprecations.jl
index ea0073922a..018fd1e63c 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -3,3 +3,4 @@
 @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
 @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
+@deprecate outdims(f, inputsize) outputsize(f, inputsize)
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index ec5127cb60..911d747233 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -45,22 +45,6 @@ function Base.show(io::IO, c::Chain)
   print(io, ")")
 end

-"""
-    outdims(c::Chain, isize)
-
-Calculate the output dimensions given the input dimensions, `isize`.
-
-```jldoctest
-julia> using Flux: outdims
-
-julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));
-
-julia> outdims(m, (10, 10)) == (6, 6)
-true
-```
-"""
-outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize)
-
 # This is a temporary and naive implementation
 # it might be replaced in the future for better performance
 # see issue https://github.com/FluxML/Flux.jl/issues/702
@@ -158,28 +142,6 @@ end
 (a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
   a(T.(x))

-"""
-    outdims(l::Dense, isize)
-
-Calculate the output dimensions given the input dimensions, `isize`.
- -```jldoctest -julia> using Flux: outdims - -julia> m = Dense(10, 5); - -julia> outdims(m, (10, 100)) == (5,) -true - -julia> outdims(m, (10,)) == (5,) -true -``` -""" -function outdims(l::Dense, isize) - first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)),), got $isize")) - return (size(l.W, 1),) -end - """ Diagonal(in::Integer) @@ -209,8 +171,6 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end -outdims(l::Diagonal, isize) = (length(l.α),) - """ Maxout(over) @@ -254,8 +214,6 @@ function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end -outdims(l::Maxout, isize) = outdims(first(l.over), isize) - """ SkipConnection(layer, connection) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index a3e76b2556..9248d68c92 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -3,8 +3,6 @@ using NNlib: conv, ∇conv_data, depthwiseconv, output_size # pad dims of x with dims of y until ndims(x) == ndims(y) _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...) -_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end]) - expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -188,21 +186,6 @@ end (a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -""" - outdims(l::Conv, isize::Tuple) - -Calculate the output dimensions given the input dimensions `isize`. -Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl). - -```julia -m = Conv((3, 3), 3 => 16) -outdims(m, (10, 10)) == (8, 8) -outdims(m, (10, 10, 1, 3)) == (8, 8) -``` -""" -outdims(l::Conv, isize) = - output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) - """ ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) @@ -311,8 +294,6 @@ end (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad) - function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T} calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride) end @@ -425,9 +406,6 @@ end (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -outdims(l::DepthwiseConv, isize) = - output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) - """ CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) @@ -521,9 +499,6 @@ end (a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -outdims(l::CrossCor, isize) = - output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) - """ AdaptiveMaxPool(out::NTuple) @@ -744,8 +719,6 @@ end _maybetuple_string(pad) = string(pad) _maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? 
string(pad[1]) : string(pad) -outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) - """ MeanPool(window::NTuple; pad=0, stride=window) @@ -798,5 +771,3 @@ function Base.show(io::IO, m::MeanPool) m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) print(io, ")") end - -outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index fbb3221e3e..5f9e116f29 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -420,4 +420,4 @@ function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") print(io, ")") -end +end \ No newline at end of file diff --git a/src/outputsize.jl b/src/outputsize.jl new file mode 100644 index 0000000000..88172c3ec5 --- /dev/null +++ b/src/outputsize.jl @@ -0,0 +1,128 @@ +module NilNumber + +using NNlib + +""" + Nil <: Number + +Nil is a singleton type with a single instance `nil`. +Unlike `Nothing` and `Missing` it subtypes `Number`. +""" +struct Nil <: Number end + +const nil = Nil() + +Nil(::T) where T<:Number = nil +(::Type{T})(::Nil) where T<:Number = nil +Base.convert(::Type{Nil}, ::Number) = nil + +Base.float(::Type{Nil}) = Nil + +for f in [:copy, :zero, :one, :oneunit, + :+, :-, :abs, :abs2, :inv, + :exp, :log, :log1p, :log2, :log10, + :sqrt, :tanh, :conj] + @eval Base.$f(::Nil) = nil +end + +for f in [:+, :-, :*, :/, :^, :mod, :div, :rem] + @eval Base.$f(::Nil, ::Nil) = nil +end + +Base.isless(::Nil, ::Nil) = true +Base.isless(::Nil, ::Number) = true +Base.isless(::Number, ::Nil) = true + +Base.isnan(::Nil) = false + +Base.typemin(::Type{Nil}) = nil +Base.typemax(::Type{Nil}) = nil + +Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil + +end # module + +using .NilNumber: Nil, nil + +""" + outputsize(m, inputsize::Tuple; padbatch=false) + +Calculate the output size of model `m` given the input size. +Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`. +Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and +returns the final size including this extra batch dimension. + +This should be faster than calling `size(m(x))`. It uses a trivial number type, +and thus should work out of the box for custom layers. + +If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`. + +# Examples +```jldoctest +julia> using Flux: outputsize + +julia> outputsize(Dense(10, 4), (10,); padbatch=true) +(4, 1) + +julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); + +julia> m(randn(Float32, 10, 10, 3, 64)) |> size +(6, 6, 32, 64) + +julia> outputsize(m, (10, 10, 3); padbatch=true) +(6, 6, 32, 1) + +julia> outputsize(m, (10, 10, 3, 64)) +(6, 6, 32, 64) + +julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end +DimensionMismatch("Input channels must match! (7 vs. 3)") + +julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1)) +(2, 1) + +julia> using LinearAlgebra: norm + +julia> f(x) = x ./ norm.(eachcol(x)); + +julia> outputsize(f, (10, 1)) # manually specify batch size as 1 +(10, 1) + +julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size +(10, 1) +``` +""" +function outputsize(m, inputsize::Tuple; padbatch=false) + inputsize = padbatch ? 
(inputsize..., 1) : inputsize + + return size(m(fill(nil, inputsize))) +end + +## make tuples and vectors be like Chains + +outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) +outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) + +## bypass statistics in normalization layers + +for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm) + @eval (l::$layer)(x::AbstractArray{Nil}) = x +end + +## fixes for layers that don't work out of the box + +for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) + @eval begin + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims) + fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end]) + end + + function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) + NNlib.$fn(fill(nil, size(a)), b, dims) + end + + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) + NNlib.$fn(a, fill(nil, size(b)), dims) + end + end +end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 40afee5668..e1660812f0 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -106,25 +106,4 @@ import Flux: activations @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) end end - - @testset "output dimensions" begin - m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test Flux.outdims(m, (10, 10)) == (6, 6) - - m = Dense(10, 5) - @test_throws DimensionMismatch Flux.outdims(m, (5, 2)) == (5,) - @test Flux.outdims(m, (10,)) == (5,) - - m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) - @test Flux.outdims(m, (10,)) == (2,) - - m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) - @test_throws DimensionMismatch Flux.outdims(m, (10,)) - - m = Flux.Diagonal(10) - @test Flux.outdims(m, (10,)) == (10,) - - m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test Flux.outdims(m, (10, 10)) == (8, 8) - end -end +end \ No newline at end of file diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 76dab3c68d..458ce69191 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -160,58 +160,6 @@ end end end -@testset "conv output dimensions" begin - m = Conv((3, 3), 3 => 16) - @test Flux.outdims(m, (10, 10)) == (8, 8) - m = Conv((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (5, 5)) == (2, 2) - m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) - m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5)) == (4, 4) - - m = ConvTranspose((3, 3), 3 => 16) - @test Flux.outdims(m, (8, 8)) == (10, 10) - m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (2, 2)) == (5, 5) - m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) - m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (4, 4)) == (5, 5) - - m = DepthwiseConv((3, 3), 3 => 6) - @test Flux.outdims(m, (10, 10)) == (8, 8) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test Flux.outdims(m, (5, 5)) == (2, 2) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5)) == (4, 4) - - m = CrossCor((3, 3), 3 => 16) - @test Flux.outdims(m, (10, 10)) == (8, 8) - m = CrossCor((3, 3), 3 => 16; stride = 2) - @test 
Flux.outdims(m, (5, 5)) == (2, 2) - m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) - m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5)) == (4, 4) - - m = MaxPool((2, 2)) - @test Flux.outdims(m, (10, 10)) == (5, 5) - m = MaxPool((2, 2); stride = 1) - @test Flux.outdims(m, (5, 5)) == (4, 4) - m = MaxPool((2, 2); stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) - - m = MeanPool((2, 2)) - @test Flux.outdims(m, (10, 10)) == (5, 5) - m = MeanPool((2, 2); stride = 1) - @test Flux.outdims(m, (5, 5)) == (4, 4) - m = MeanPool((2, 2); stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) -end - @testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8)) data = ones(Float32, (k .+ 3)..., 1,1) l = ltype(k, 1=>1, pad=SamePad()) diff --git a/test/outputsize.jl b/test/outputsize.jl new file mode 100644 index 0000000000..dc8ad3023b --- /dev/null +++ b/test/outputsize.jl @@ -0,0 +1,134 @@ +@testset "basic" begin + m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) + @test outputsize(m, (10, 10, 3, 1)) == (6, 6, 32, 1) + + m = Dense(10, 5) + @test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1) + @test outputsize(m, (10,); padbatch=true) == (5, 1) + + m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) + @test outputsize(m, (10,); padbatch=true) == (2, 1) + @test outputsize(m, (10, 30)) == (2, 30) + + m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) + @test_throws DimensionMismatch outputsize(m, (10,)) + + m = Flux.Diagonal(10) + @test outputsize(m, (10, 1)) == (10, 1) + + m = Maxout(() -> Conv((3, 3), 3 => 16), 2) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + + m = flatten + @test outputsize(m, (5, 5, 3, 10)) == (75, 10) + + m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) + @test outputsize(m, (10, 10, 3, 50)) == (10, 50) + @test outputsize(m, (10, 10, 3, 2)) == (10, 2) + + m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) + @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1) +end + +@testset "activations" begin + @testset for f in [celu, elu, gelu, hardsigmoid, hardtanh, + leakyrelu, lisht, logcosh, logσ, mish, + relu, relu6, rrelu, selu, σ, softplus, + softshrink, softsign, swish, tanhshrink, trelu] + @test outputsize(Dense(10, 5, f), (10, 1)) == (5, 1) + end +end + +@testset "conv" begin + m = Conv((3, 3), 3 => 16) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + m = Conv((3, 3), 3 => 16; stride = 2) + @test outputsize(m, (5, 5, 3, 1)) == (2, 2, 16, 1) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (5, 5, 3, 1)) == (4, 4, 16, 1) + @test_throws DimensionMismatch outputsize(m, (5, 5, 2)) + @test outputsize(m, (5, 5, 3, 100)) == (4, 4, 16, 100) + + m = ConvTranspose((3, 3), 3 => 16) + @test outputsize(m, (8, 8, 3, 1)) == (10, 10, 16, 1) + m = ConvTranspose((3, 3), 3 => 16; stride = 2) + @test outputsize(m, (2, 2, 3, 1)) == (5, 5, 16, 1) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (4, 4, 3, 1)) == (5, 5, 16, 1) + + m = DepthwiseConv((3, 3), 3 => 6) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 6, 1) + m = 
DepthwiseConv((3, 3), 3 => 6; stride = 2) + @test outputsize(m, (5, 5, 3, 1)) == (2, 2, 6, 1) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 6, 1) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (5, 5, 3, 1)) == (4, 4, 6, 1) + + m = CrossCor((3, 3), 3 => 16) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + m = CrossCor((3, 3), 3 => 16; stride = 2) + @test outputsize(m, (5, 5, 3, 1)) == (2, 2, 16, 1) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (5, 5, 3, 1)) == (4, 4, 16, 1) + + m = AdaptiveMaxPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (2, 2, 3, 1) + + m = AdaptiveMeanPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (2, 2, 3, 1) + + m = GlobalMaxPool() + @test outputsize(m, (10, 10, 3, 1)) == (1, 1, 3, 1) + + m = GlobalMeanPool() + @test outputsize(m, (10, 10, 3, 1)) == (1, 1, 3, 1) + + m = MaxPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (5, 5, 3, 1) + m = MaxPool((2, 2); stride = 1) + @test outputsize(m, (5, 5, 4, 1)) == (4, 4, 4, 1) + m = MaxPool((2, 2); stride = 2, pad = 3) + @test outputsize(m, (5, 5, 2, 1)) == (5, 5, 2, 1) + + m = MeanPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (5, 5, 3, 1) + m = MeanPool((2, 2); stride = 1) + @test outputsize(m, (5, 5, 4, 1)) == (4, 4, 4, 1) + m = MeanPool((2, 2); stride = 2, pad = 3) + @test outputsize(m, (5, 5, 2, 1)) == (5, 5, 2, 1) +end + +@testset "normalisation" begin + m = Dropout(0.1) + @test outputsize(m, (10, 10)) == (10, 10) + @test outputsize(m, (10,); padbatch=true) == (10, 1) + + m = AlphaDropout(0.1) + @test outputsize(m, (10, 10)) == (10, 10) + @test outputsize(m, (10,); padbatch=true) == (10, 1) + + m = LayerNorm(32) + @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1) + + m = BatchNorm(3) + @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1) + + m = InstanceNorm(3) + @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1) + + if VERSION >= v"1.1" + m = GroupNorm(16, 4) + @test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16) + @test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 65bc635072..84ee994b45 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,11 @@ end include("layers/conv.jl") end +@testset "outputsize" begin + using Flux: outputsize + include("outputsize.jl") +end + @testset "CUDA" begin if Flux.use_cuda[] include("cuda/runtests.jl")
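
To illustrate the API these changes introduce, here is a minimal standalone sketch of sizing a `Dense` head with `Flux.outputsize`. It is not part of the patch; the layer shapes, the 28×28×3 input, and the names `backbone`, `featsize`, and `model` are arbitrary choices for the example.

```julia
using Flux

# a small convolutional feature extractor (sizes chosen arbitrarily)
backbone = Chain(
  Conv((3, 3), 3 => 8, relu),
  MaxPool((2, 2)),
  Conv((3, 3), 8 => 16, relu),
  Flux.flatten)

# infer the flattened feature length for a 28×28 RGB input without running a
# real forward pass; padbatch=true appends a batch dimension of one
featsize = Flux.outputsize(backbone, (28, 28, 3); padbatch=true)  # (1936, 1)

# use the inferred size to build the classifier head
model = Chain(backbone, Dense(first(featsize), 10))

# outputsize agrees with the size of an actual forward pass
Flux.outputsize(model, (28, 28, 3, 5)) == size(model(rand(Float32, 28, 28, 3, 5)))  # true
```

This is the same pattern as the `make_model` helper in the documentation changes above, just small enough to try at the REPL.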