Add fold and unfold #444

Merged
merged 6 commits on Nov 28, 2022
Changes from 3 commits
3 changes: 3 additions & 0 deletions src/NNlib.jl
@@ -61,6 +61,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("fold.jl")
export unfold, unfold!, fold, fold!

Member

I'm a little worried these names may be too common to export. scatter collided with every plotting library...

It's not working for me right now but https://juliahub.com may be able to tell us.

Member

possible name confusion with Base too. Given these functions are somewhat domain-specific, I agree it would be better to keep them unexported.

Contributor Author

No problem, makes sense. That JuliaHub tool is very useful, thanks for showing it.
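
For context, a quick sketch of what unexported usage would look like (my illustration, not part of the diff; it assumes the names are simply left out of the `export` list):

```julia
using NNlib

# With fold/unfold unexported, callers either qualify the module name...
y = NNlib.unfold(rand(Float32, 8, 8, 1, 1), (3, 3, 1, 1))

# ...or opt in explicitly, which keeps the common names out of the global namespace.
using NNlib: unfold, fold
```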


include("ctc.jl")
export ctc_loss

137 changes: 137 additions & 0 deletions src/fold.jl
@@ -0,0 +1,137 @@

"""
    unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = false)

Places sliding windows of `x` into a container tensor of size `(num_windows, window_size, batchsize)`.
The window size is the product of the kernel's spatial dimensions and the number of input channels.
The number of sliding windows matches the number of output positions of a convolution (`conv`) with the same `kernel_size` and arguments.
Uses `NNlib.im2col!` as the backend.
"""
function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, K, N}
    stride = expand(Val(N - 2), stride)
    padding = expand(Val(N - 2), pad)
    dilation = expand(Val(N - 2), dilation)
    cdims = DenseConvDims(size(x), kernel_size; stride, padding, dilation, flipkernel=flipped)
    return unfold(x, cdims)
end
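
For orientation, a minimal usage sketch of the wrapper above (my example, not part of the PR; the shapes follow from the docstring):

```julia
using NNlib

x = reshape(Float32.(1:64), 8, 8, 1, 1)   # (width, height, channels, batch)
y = NNlib.unfold(x, (3, 3, 1, 1))         # 3×3 kernel, stride 1, no padding
size(y)                                   # (36, 9, 1): 6×6 window positions, 3*3*1 values per window, 1 batch
```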

"""
    fold(y, output_size, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = false)

Accumulates sliding windows from the output of `unfold` into a container tensor of size `output_size`.
Because overlapping windows are summed, an inverse to `unfold` is obtained by using `fold` and then dividing by the per-element overlap counts.
For example,

```jldoctest
julia> kernel_size, pad = (3, 3, 1, 1), 1;

julia> x = reshape(1:64, 8, 8, 1, 1) |> collect;

julia> y = unfold(x, kernel_size; pad=pad);

julia> size(y)
(64, 9, 1)

julia> z = fold(y, size(x), kernel_size; pad=pad);

julia> d = fold(unfold(ones(eltype(x), size(x)...), kernel_size; pad=pad), size(x), kernel_size; pad=pad)
8×8×1×1 Array{Int64, 4}:
[:, :, 1, 1] =
4 6 6 6 6 6 6 4
6 9 9 9 9 9 9 6
6 9 9 9 9 9 9 6
6 9 9 9 9 9 9 6
6 9 9 9 9 9 9 6
6 9 9 9 9 9 9 6
6 9 9 9 9 9 9 6
4 6 6 6 6 6 6 4

julia> x == z./d
true

```
Uses `NNlib.col2im!` as the backend.
"""
function fold(x::AbstractArray{T, 3}, output_size::NTuple{N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = false) where {T, K, N}
    stride = expand(Val(N - 2), stride)
    padding = expand(Val(N - 2), pad)
    dilation = expand(Val(N - 2), dilation)
    cdims = DenseConvDims(output_size, kernel_size; stride, padding, dilation, flipkernel=flipped)
    return fold(x, output_size, cdims)
end

# im2col_dims returns (numblocks, blocksize, threadnum), where the thread dim is used as a thread-local
# workspace for multithreaded conv. Here, we ultimately want to replace threadnum with batchsize.
unfold_dims(cdims::DenseConvDims) = im2col_dims(cdims)[1:2]

# auto-allocating versions
function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}
    y = similar(x, unfold_dims(cdims)..., size(x, N)) # (numblocks, blocksize, batchsize)
    return unfold!(y, x, cdims)
end

function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
    x = similar(y, output_size)
    return fold!(x, y, cdims)
end

# in-place versions for N < 5 dimensional inputs
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
    unfold!(
        y,
        insert_singleton_spatial_dimension(x, 5-N),
        insert_singleton_spatial_dimension(cdims, 5-N),
    )
    return y
end

function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
    fold!(
        insert_singleton_spatial_dimension(x, 5-N),
        y,
        insert_singleton_spatial_dimension(cdims, 5-N),
    )
    return x
end
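
To illustrate the rank promotion above (a sketch of my reading of `insert_singleton_spatial_dimension`; the exact placement of the added dims is an assumption): lower-rank inputs get size-1 spatial dims inserted before the channel and batch dims, so the 5-D methods below cover every spatial rank.

```julia
# Assumed behaviour: a 1-D-spatial input (W, C, N) is viewed as 3-D-spatial.
x1 = rand(Float32, 8, 3, 2)        # N = 3, so 5 - N = 2 singleton dims are added
x5 = reshape(x1, 8, 1, 1, 3, 2)    # expected result of insert_singleton_spatial_dimension(x1, 2)
size(x5)                           # (8, 1, 1, 3, 2)
```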

# 5-dimension in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}
    @threads for batch_idx in 1:size(x, 5)
        y_slice = view(y, :, :, batch_idx)
        im2col!(y_slice, view(x, :, :, :, :, batch_idx), cdims)
    end
    return y
end

function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {xT, yT}
    @threads for batch_idx in 1:size(x, 5)
        y_slice = view(y, :, :, batch_idx)
        col2im!(view(x, :, :, :, :, batch_idx), y_slice, cdims)
    end
    return x
end

# reverse diff rules
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
    function unfold_pullback(Δ)
        return (
            NoTangent(),
            fold(unthunk(Δ), size(x), cdims; kw...),
            NoTangent(),
        )
    end
    return unfold(x, cdims; kw...), unfold_pullback
end

function rrule(::typeof(fold), x, output_size, cdims::DenseConvDims; kw...)
    function fold_pullback(Δ)
        return (
            NoTangent(),
            unfold(unthunk(Δ), cdims; kw...),
            NoTangent(),
            NoTangent(),
        )
    end
    return fold(x, output_size, cdims; kw...), fold_pullback
end
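
To see the adjoint pairing in action, a small end-to-end sketch (assumes Zygote is available in the environment; not part of the PR):

```julia
using NNlib, Zygote

x = rand(Float32, 8, 8, 3, 2)
w = rand(Float32, 3, 3, 3, 1)            # only used to construct cdims
cdims = DenseConvDims(x, w; padding=1)

# The pullback of unfold is fold (and vice versa), so this gradient has the
# shape of x and counts how many sliding windows touch each entry.
g = gradient(x -> sum(NNlib.unfold(x, cdims)), x)[1]
size(g) == size(x)    # true
```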

38 changes: 38 additions & 0 deletions test/fold.jl
@@ -0,0 +1,38 @@
using NNlib, Test

@testset "unfold wrapper" begin
x = rand(rng, 16, 16, 3, 10)
w = rand(rng, 5, 5, 3, 2)
@test size(unfold(x, size(w))) == (144, 75, 10)
@test size(unfold(x, size(w); pad=2)) == (256, 75, 10)
@test size(unfold(x, size(w); stride=2)) == (36, 75, 10)
@test size(unfold(x, size(w); dilation=2)) == (64, 75, 10)
end

@testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([8], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w; padding=1)
y = unfold(x, cdims)
z = fold(y, size(x), cdims)
divisor = fold(unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims)
@test isapprox(z ./ divisor, x, rtol=1.0e-7)

# introduce stride
cdims = DenseConvDims(x, w; padding=1, stride=2)
y = unfold(x, cdims)
z = fold(y, size(x), cdims)
divisor = fold(unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims)
@test isapprox(z ./ divisor, x, rtol=1.0e-7)
end

@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
gradtest(x -> sum(unfold(x, cdims)), x)

y = unfold(x, cdims)
gradtest(y -> sum(fold(y, size(x), cdims)), y)

Member

Suggested change:
-    gradtest(x -> sum(unfold(x, cdims)), x)
-    y = unfold(x, cdims)
-    gradtest(y -> sum(fold(y, size(x), cdims)), y)
+    gradtest(unfold, x, cdims; check_rrule=true)
+    y = unfold(x, cdims)
+    gradtest(fold, y, size(x), cdims; check_rrule=true)

Should save a lambda and test a little more at the same time.

Contributor Author

I think FiniteDifferences is causing an error in gradtest by trying to perturb the arguments of cdims.

AutoDiff: spatial_rank=1: Error During Test at /home/nikopj/.julia/dev/NNlib/test/fold.jl:29
  Got exception outside of a @test
  TypeError: in new, expected Tuple{Int64}, got a value of type Tuple{Float64}
...

I can pass the finite-differences test by making the tested function a closure over `cdims`, so only the input array is perturbed, and I can pass the ChainRulesCore rrule test by calling `test_rrule` separately. Looking at test/conv.jl, it seems to play a similar game with gradtest. The change below passes both the finite-differences and rrule tests.

gradtest(x -> unfold(x, cdims), x)
test_rrule(unfold, x, cdims)

end

4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -52,6 +52,10 @@ include("test_utils.jl")
include("ctc.jl")
end

@testset "Fold/Unfold" begin
include("fold.jl")
end

@testset "Inference" begin
include("inference.jl")
end