Merge #1305
1305: Updates to outdims r=CarloLucibello a=darsnack

Since #1253 stalled, I tried committing to the author's branch, but I have not received a response. So, I am creating a new PR with the following changes from the previous one:
- `outdims` for generic functions
- Size checking for `outdims(::Dense, isize)`

I also added the following changes:
- Removed type signature restrictions on `outdims` for generic functions
- Added `outdims` for normalization layers
    - This is helpful since `BatchNorm` etc. frequently show up in a chain or array of layers when building models
    - Right now, calling `outdims` on these layers throws a method error
    - Generic functions would address this, but I think we should avoid actually evaluating the function as much as possible (see the sketch below)
- Updated docs for `outdims` changes
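
For illustration, a sketch of what these changes enable (a hypothetical REPL session; note that in the merged code `outdims` became `outputsize`, with a deprecation for the old name):

```julia
julia> using Flux

julia> m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dense(1024, 10));

julia> Flux.outputsize(m, (10, 10, 3); padbatch=true)  # BatchNorm no longer throws a method error
(10, 1)
```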

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [x] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: lorenzoh <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
4 people authored Dec 30, 2020
2 parents e032be1 + 438db24 commit 12281a2
Showing 12 changed files with 321 additions and 168 deletions.
23 changes: 1 addition & 22 deletions docs/src/models/basics.md
@@ -218,25 +218,4 @@ Flux.@functor Affine

This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).

## Utility functions

Flux provides some utility functions to help you generate models in an automated fashion.

`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size.
Currently limited to the following layers:
- `Chain`
- `Dense`
- `Conv`
- `Diagonal`
- `Maxout`
- `ConvTranspose`
- `DepthwiseConv`
- `CrossCor`
- `MaxPool`
- `MeanPool`

```@docs
Flux.outdims
```
For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).
48 changes: 48 additions & 0 deletions docs/src/utilities.md
@@ -35,6 +35,54 @@ Flux.glorot_uniform
Flux.glorot_normal
```

## Model Building

Flux provides some utility functions to help you generate models in an automated fashion.

[`outputsize`](@ref) enables you to calculate the output sizes of layers like [`Conv`](@ref)
when applied to input samples of a given size. This is achieved by passing a "dummy" array into
the model that preserves size information without running any computation.
`outputsize(f, inputsize)` works for all layers (including custom layers) out of the box.
By default, `inputsize` should include the batch dimension. Alternatively, you can omit it
and call `outputsize(f, inputsize; padbatch=true)`, which assumes a batch size of one.
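
For example, a minimal sketch (the output size follows from a 3×3 convolution with no padding):
```julia
julia> using Flux

julia> Flux.outputsize(Conv((3, 3), 3 => 16), (32, 32, 3); padbatch=true)
(30, 30, 16, 1)
```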

Using this utility function lets you automate model building for various inputs like so:
```julia
"""
make_model(width, height, inchannels, nclasses;
layer_config = [16, 16, 32, 32, 64, 64])
Create a CNN for a given set of configuration parameters.
# Arguments
- `width`: the input image width
- `height`: the input image height
- `inchannels`: the number of channels in the input image
- `nclasses`: the number of output classes
- `layer_config`: a vector of the number of filters per each conv layer
"""
function make_model(width, height, inchannels, nclasses;
                    layer_config = [16, 16, 32, 32, 64, 64])
  # construct a vector of conv layers programmatically
  conv_layers = [Conv((3, 3), inchannels => layer_config[1])]
  for (infilters, outfilters) in zip(layer_config, layer_config[2:end])
    push!(conv_layers, Conv((3, 3), infilters => outfilters))
  end

  # compute the output dimensions for the conv layers
  # use padbatch=true to set the batch dimension to 1
  conv_outsize = Flux.outputsize(conv_layers, (width, height, inchannels); padbatch=true)

  # the input dimension to Dense is programmatically calculated from
  # width, height, and inchannels; flatten collapses the conv output
  # to a (features, batch) matrix before the Dense layer
  return Chain(conv_layers..., Flux.flatten, Dense(prod(conv_outsize), nclasses))
end
```
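
As a quick check, a hypothetical session with the function above (the sizes follow from the default `layer_config` and 32×32 RGB inputs):
```julia
julia> m = make_model(32, 32, 3, 10);

julia> Flux.outputsize(m, (32, 32, 3); padbatch=true)
(10, 1)
```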

```@docs
Flux.outputsize
```

## Model Abstraction

```@docs
2 changes: 2 additions & 0 deletions src/Flux.jl
@@ -41,6 +41,8 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")

include("outputsize.jl")

include("data/Data.jl")

include("losses/Losses.jl")
1 change: 1 addition & 0 deletions src/deprecations.jl
@@ -3,3 +3,4 @@
@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
@deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
@deprecate outdims(f, inputsize) outputsize(f, inputsize)
42 changes: 0 additions & 42 deletions src/layers/basic.jl
@@ -45,22 +45,6 @@ function Base.show(io::IO, c::Chain)
print(io, ")")
end

"""
outdims(c::Chain, isize)
Calculate the output dimensions given the input dimensions, `isize`.
```jldoctest
julia> using Flux: outdims
julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));
julia> outdims(m, (10, 10)) == (6, 6)
true
```
"""
outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize)

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
@@ -158,28 +142,6 @@ end
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
outdims(l::Dense, isize)
Calculate the output dimensions given the input dimensions, `isize`.
```jldoctest
julia> using Flux: outdims
julia> m = Dense(10, 5);
julia> outdims(m, (10, 100)) == (5,)
true
julia> outdims(m, (10,)) == (5,)
true
```
"""
function outdims(l::Dense, isize)
  first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal ($(size(l.W, 2)),), got $isize"))
return (size(l.W, 1),)
end

"""
Diagonal(in::Integer)
@@ -209,8 +171,6 @@ function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end

outdims(l::Diagonal, isize) = (length(l.α),)

"""
Maxout(over)
@@ -254,8 +214,6 @@ function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

outdims(l::Maxout, isize) = outdims(first(l.over), isize)

"""
SkipConnection(layer, connection)
29 changes: 0 additions & 29 deletions src/layers/conv.jl
@@ -3,8 +3,6 @@ using NNlib: conv, ∇conv_data, depthwiseconv, output_size
# pad dims of x with dims of y until ndims(x) == ndims(y)
_paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...)

_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end])

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

@@ -188,21 +186,6 @@ end
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
outdims(l::Conv, isize::Tuple)
Calculate the output dimensions given the input dimensions `isize`.
Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl).
```julia
m = Conv((3, 3), 3 => 16)
outdims(m, (10, 10)) == (8, 8)
outdims(m, (10, 10, 1, 3)) == (8, 8)
```
"""
outdims(l::Conv, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)
@@ -311,8 +294,6 @@ end
(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad)

function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride)
end
@@ -425,9 +406,6 @@ end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

outdims(l::DepthwiseConv, isize) =
output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)
@@ -521,9 +499,6 @@ end
(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

outdims(l::CrossCor, isize) =
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation))

"""
AdaptiveMaxPool(out::NTuple)
@@ -744,8 +719,6 @@ end
_maybetuple_string(pad) = string(pad)
_maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad)

outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))

"""
MeanPool(window::NTuple; pad=0, stride=window)
@@ -798,5 +771,3 @@ function Base.show(io::IO, m::MeanPool)
m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
print(io, ")")
end

outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad))
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
@@ -420,4 +420,4 @@ function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(join(size(l.β), ", "))")
(l.λ == identity) || print(io, ", λ = $(l.λ)")
print(io, ")")
end
end
128 changes: 128 additions & 0 deletions src/outputsize.jl
@@ -0,0 +1,128 @@
module NilNumber

using NNlib

"""
Nil <: Number
Nil is a singleton type with a single instance `nil`.
Unlike `Nothing` and `Missing` it subtypes `Number`.
"""
struct Nil <: Number end

const nil = Nil()

Nil(::T) where T<:Number = nil
(::Type{T})(::Nil) where T<:Number = nil
Base.convert(::Type{Nil}, ::Number) = nil

Base.float(::Type{Nil}) = Nil

for f in [:copy, :zero, :one, :oneunit,
:+, :-, :abs, :abs2, :inv,
:exp, :log, :log1p, :log2, :log10,
:sqrt, :tanh, :conj]
@eval Base.$f(::Nil) = nil
end

for f in [:+, :-, :*, :/, :^, :mod, :div, :rem]
@eval Base.$f(::Nil, ::Nil) = nil
end

Base.isless(::Nil, ::Nil) = true
Base.isless(::Nil, ::Number) = true
Base.isless(::Number, ::Nil) = true

Base.isnan(::Nil) = false

Base.typemin(::Type{Nil}) = nil
Base.typemax(::Type{Nil}) = nil

Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil

end # module

using .NilNumber: Nil, nil
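
# A hypothetical REPL illustration (not part of the source) of how `nil`
# swallows values while sizes still propagate:
#
#   julia> nil + 1.0
#   nil
#
#   julia> size(fill(nil, 10, 3) .* rand(10, 3))
#   (10, 3)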

"""
outputsize(m, inputsize::Tuple; padbatch=false)
Calculate the output size of model `m` given the input size.
Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`.
Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
returns the final size including this extra batch dimension.
This should be faster than calling `size(m(x))`. It uses a trivial number type,
and thus should work out of the box for custom layers.
If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.
# Examples
```jldoctest
julia> using Flux: outputsize
julia> outputsize(Dense(10, 4), (10,); padbatch=true)
(4, 1)
julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));
julia> m(randn(Float32, 10, 10, 3, 64)) |> size
(6, 6, 32, 64)
julia> outputsize(m, (10, 10, 3); padbatch=true)
(6, 6, 32, 1)
julia> outputsize(m, (10, 10, 3, 64))
(6, 6, 32, 64)
julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
DimensionMismatch("Input channels must match! (7 vs. 3)")
julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
(2, 1)
julia> using LinearAlgebra: norm
julia> f(x) = x ./ norm.(eachcol(x));
julia> outputsize(f, (10, 1)) # manually specify batch size as 1
(10, 1)
julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
(10, 1)
```
"""
function outputsize(m, inputsize::Tuple; padbatch=false)
inputsize = padbatch ? (inputsize..., 1) : inputsize

return size(m(fill(nil, inputsize)))
end

## make tuples and vectors be like Chains

outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
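# For example (a hypothetical check): outputsize((Dense(10, 4), Dense(4, 2)), (10, 1)) == (2, 1)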

## bypass statistics in normalization layers

for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
@eval (l::$layer)(x::AbstractArray{Nil}) = x
end

## fixes for layers that don't work out of the box

for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
@eval begin
function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end])
end

function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims)
NNlib.$fn(fill(nil, size(a)), b, dims)
end

function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims)
NNlib.$fn(a, fill(nil, size(b)), dims)
end
end
end