
Updates to outdims #1305

Merged · 38 commits · Dec 30, 2020

Conversation

@darsnack (Member) commented Aug 5, 2020

Since #1253 stalled, I tried committing to the author's branch, but I have not received a response. So, I am creating a new PR with the following changes from the previous one:

  • `outdims` for generic functions
  • Size checking for `outdims(::Dense, isize)`

I also added the following additional changes:

  • Removed type signature restrictions on `outdims` for generic functions
  • Added `outdims` for normalization layers
    • This is helpful since `BatchNorm` etc. frequently show up in a chain or array of layers when building models
    • Right now there is a method error
    • Generic functions would address this, but I think we should avoid actually evaluating the function as much as possible
  • Updated docs for `outdims` changes

PR Checklist

  • [x] Tests are added
  • [ ] Entry in NEWS.md
  • [x] Documentation, if applicable
  • [x] Final review from @dhairyagandhi96 (for API changes)

@darsnack (Member, Author) commented Aug 5, 2020

Wanted to get the ball rolling on this. As you can see, the tests for the generic `outdims` fail. I think we should decide what API we want to support across all layers and functions, as discussed in #1086.

I think the API should expect `isize` to have the same number of dimensions as the input to the function (with the exception that the batch dimension can be left out). If the batch dimension is left out, I think we should assume it is 1 for generic function evaluation. The returned tuple should match the dimensions of `isize` (so if someone is ignoring batch, then we ignore it too).
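
For concreteness, here is how that convention would look on `Dense` (hypothetical calls; the return values are what the proposal implies, not output from the merged code):

```julia
using Flux

outdims(Dense(10, 5), (10,))     # batch omitted  → (5,),    batch omitted in the result
outdims(Dense(10, 5), (10, 32))  # batch included → (5, 32), batch preserved
```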

@DhairyaLGandhi (Member)

Thank you for looking into this!

The implementation should ideally not care whether the batch size has been provided, but should return the correct batch size if one is provided. It's a bit tricky to ignore the last dimension: future layers (e.g. for 3D data) may need more than 4 dimensions, so ignoring the last dimension needs to be threaded carefully through the preceding transforming layers.

@darsnack (Member, Author) commented Aug 5, 2020

> It's a bit tricky to ignore the last dimension: future layers (e.g. for 3D data) may need more than 4 dimensions, so ignoring the last dimension needs to be threaded carefully through the preceding transforming layers.

My thoughts on implementing this would be a `_handle_batch(f, isize, ndims)` function, where `f` is the standard output dimension calculation as a function of `isize`, and `ndims` is the expected number of dimensions. For example, `Dense` might look like:

```julia
function outdims(l::Dense, isize)
  calc_dims = isize -> begin
    first(isize) == size(l.W, 2) ||
      throw(DimensionMismatch("input size should equal ($(size(l.W, 2)), ...), got $isize"))
    return (size(l.W, 1), Base.tail(isize)...)
  end

  return _handle_batch(calc_dims, isize, 2)
end
```

`_handle_batch` would then know that `Dense` expects a 2D input (including batch) and accordingly pad `isize` with a 1 for the batch dimension, then drop the batch (if necessary) after `calc_dims` returns.
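
A minimal sketch of what `_handle_batch` could look like under those assumptions (hypothetical helper, not necessarily the merged code):

```julia
function _handle_batch(f, isize, expected)
  if length(isize) == expected            # batch dimension supplied by the caller
    return f(isize)
  elseif length(isize) == expected - 1    # batch omitted: pad with 1, compute, drop it
    return Base.front(f((isize..., 1)))
  else
    throw(DimensionMismatch("expected $(expected - 1) or $expected dimensions, got $(length(isize))"))
  end
end
```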

@DhairyaLGandhi (Member)

Possibly something like this would work. This would just amount to an `ntuple` call for `_handle_batch`.

@darsnack (Member, Author) commented Aug 5, 2020

Yeah, though I guess `ndims` is unknown for generic functions. In this case, I think we can enforce that the `isize` passed to generic functions includes the batch dimension. I don't foresee that being a pain point, since people should only really be relying on the generic case when it's nested inside a chain. We can modify the `Chain` implementation to pad the batch at the start of the chain, drop it at the end, and maintain it through the middle (see the sketch below).
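
A minimal sketch of that `Chain` idea, assuming the caller omitted the batch dimension (illustrative only, not the code in this PR):

```julia
function outdims(c::Chain, isize)
  padded = (isize..., 1)  # pad a batch dimension of 1 at the start of the chain
  osize = foldl((s, layer) -> outdims(layer, s), c.layers; init = padded)
  return Base.front(osize)  # drop the padded batch dimension again at the end
end
```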

@darsnack (Member, Author) commented Aug 7, 2020

What conv-style layers can work on >3D arrays? I'll need to adjust the expected dimensions based on the weight array. Currently, some layers are hardcoded to expect 4 dimensions (including batch).

@darsnack darsnack changed the title Updates to outdims for normalisation and generic functions Updates to outdims Aug 9, 2020
@darsnack (Member, Author) commented Aug 9, 2020

@DhairyaLGandhi I think this is ready for review, then merge. Summary of changes:

  • Added graceful handling of the batch dimension
  • Added `outdims` for adaptive and global pooling
  • Added `outdims` for `SkipConnection`
  • Updated generic function `outdims` to allow functions with multiple arguments
  • Updated `Chain` `outdims` to also work on vectors/tuples of layers (examples below)
  • Added more thorough testing
  • Updated docs with an example
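
To illustrate a few of these (hypothetical calls; the shapes are what the new methods are intended to return):

```julia
using Flux

m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), MaxPool((2, 2)))

outdims(m, (32, 32, 3))         # batch omitted: sizes threaded through the whole chain
outdims(m.layers, (32, 32, 3))  # also works on the underlying tuple of layers
outdims(SkipConnection(Conv((3, 3), 3 => 3, pad = 1), +), (32, 32, 3))
```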

@darsnack (Member, Author)

Bump

@DhairyaLGandhi (Member)

What do you think can be done to remove the dependency on having methods defined for every new layer?

@darsnack (Member, Author)

Referencing #1086, I don't think this is avoidable for a certain subset of primitives. For example, I don't think we can avoid manually specifying this for `Conv`. There is, of course, the fallback method of evaluating the function on a sample input, which exists in this PR.

We could mix these two versions of `outdims` to automatically generate the output size for higher-order layers like `SkipConnection`. Right now, how to evaluate the sub-layers within `SkipConnection` is specified manually. But this is essentially similar to AD, in that we evaluate the trace until we hit a primitive rule. So we could have a "check and evaluate" mode which replaces all the function calls in the AST with calls to `outdims`. Then, when we evaluate the transformed AST, we hit the primitive `outdims` definitions for layers like `Conv` and the fallback definition for higher-order layers like `SkipConnection`, similar to how Zygote effectively replaces the entire tree with calls to `pullback` in place of the normal function calls.

Personally, I don't think such an effort is worth it right now. `outdims` needs a usability update to be functional, so we should merge this PR and save an automated `outdims` for another PR.

@lorenzoh (Member)

See this Zulip thread. @darsnack said it could replace the default case for `outdims`.

@darsnack (Member, Author)

@CarloLucibello I think this is ready to merge, and it should fix many issues related to `outdims`. I will submit a separate PR with @lorenzoh's changes that will give us a more flexible implementation of `outdims`. The API will remain the same, so I see no issue with multiple PRs.

@CarloLucibello (Member)

This needs a rebase.

@mcabbott (Member) left a comment

I tried to write a clearer docstring and added some doctests.

Not quite sure what you intended to do in which PR, so no real comment on the changes.

@darsnack (Member, Author)

Thanks @mcabbott, I think the "spatial output" phrasing makes things clearer. There were some inconsistencies between the doctests' expectations and what this PR actually does, so I added your changes with some modifications.

The intent of the PR is to refactor the `outdims` implementation to squash some long-standing bugs and make it more intuitive to use. Now the `isize` argument includes every dimension a user would pass if they were actually calling the model. The only exception is the batch dimension, which can optionally be left off. The implementation will detect whether the batch dimension was specified by the user and keep/drop it accordingly in the returned result.

@CarloLucibello If the rebase was the only requested change, then I think this is good to go once CI passes (it passed locally for me).

@darsnack (Member, Author)

Okay, so under a separate branch (darsnack/better-outdims), I have implemented what Lorenz and I discussed in that Zulip thread. The basic idea is to create a `Nil` number type that voids all scalar operations. The advantage of this method is that every custom layer is automatically supported, since there is only one `outdims` definition, which does `size(m(fill(nil, isize)))`. It is probably easier to maintain. The downside is that there are some issues with the warnings from NNlib for the fallback convolution, which requires a dependency on Logging.jl. It also includes the batch handling as a special case, and any custom layer can support batch handling by defining `dimhint(layer)` (this does not need to be defined for `outdims` to work, but then you will need to include the batch in `isize`).
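
For reference, the core of the idea fits in a few lines. A minimal sketch (a paraphrase, not the exact code on that branch):

```julia
struct Nil <: Number end
const nil = Nil()

# Void the primitive scalar operations: anything computed from nils is nil.
for f in (:-, :abs, :exp, :log, :tanh, :sqrt, :inv)
  @eval Base.$f(::Nil) = nil
end
for f in (:+, :-, :*, :/)
  @eval begin
    Base.$f(::Nil, ::Nil) = nil
    Base.$f(::Nil, ::Number) = nil
    Base.$f(::Number, ::Nil) = nil
  end
end
Base.isless(::Nil, ::Nil) = false      # comparisons, e.g. inside max-pooling
Base.isless(::Nil, ::Number) = false
Base.isless(::Number, ::Nil) = false

# The single generic definition: run the model on an array of nils and read
# off the size of the result (here `isize` includes the batch dimension).
outdims(m, isize) = size(m(fill(nil, isize)))
```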

Here is a performance comparison:

```julia
julia> m = Chain(Conv((3, 3), 3 => 16), flatten, Dense(30*30*16, 10))
Chain(Conv((3, 3), 3=>16), flatten, Dense(14400, 10))

# outdims w/ current implementation
julia> @benchmark outdims($m, (32, 32, 3))
BenchmarkTools.Trial:
  memory estimate:  59.88 KiB
  allocs estimate:  80
  --------------
  minimum time:     12.662 μs (0.00% GC)
  median time:      41.310 μs (0.00% GC)
  mean time:        53.773 μs (17.89% GC)
  maximum time:     8.441 ms (98.75% GC)
  --------------
  samples:          10000
  evals/sample:     1

# outdims w/ nil implementation
julia> @benchmark outdims($m, (32, 32, 3))
BenchmarkTools.Trial:
  memory estimate:  3.41 KiB
  allocs estimate:  66
  --------------
  minimum time:     200.629 μs (0.00% GC)
  median time:      236.349 μs (0.00% GC)
  mean time:        258.500 μs (0.00% GC)
  maximum time:     7.229 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1
```

Using the nil implementation is slower, but it is still fast, and much faster than actually evaluating the model.

@mcabbott (Member)

That's a neat idea. Any chance `missing` would work in place of `nil`? `@btime fill(missing, 1000,1000) * fill(missing, 1000);` is pretty quick, and `tanh(missing) === missing`, while `tanh(nil)` would have to be written... sadly `fill(missing, 10,10) * fill(missing, 10,10)` is an error, but perhaps that could be fixed?
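
Spelling out the behavior referenced above (as observed at the time):

```julia
tanh(missing) === missing                        # true: Base math functions propagate missing
fill(missing, 1000, 1000) * fill(missing, 1000)  # matrix-vector: works, and quickly
fill(missing, 10, 10) * fill(missing, 10, 10)    # matrix-matrix: throws a MethodError
```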

@darsnack (Member, Author)

Yeah, it looks like `Missing` is an option:

```julia
julia> @benchmark outdims($m, (32, 32, 3))
BenchmarkTools.Trial:
  memory estimate:  3.41 KiB
  allocs estimate:  66
  --------------
  minimum time:     194.253 μs (0.00% GC)
  median time:      201.545 μs (0.00% GC)
  mean time:        210.417 μs (0.00% GC)
  maximum time:     505.264 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1
```

Still need to check the test cases, but it looks like an option! @lorenzoh, I just wanted to double-check that you didn't try `missing` and rule it out for some reason.

@lorenzoh (Member)

The reason it doesn't use `missing` is that it is not a subtype of `Number`, and some functions may restrict their input arguments to something like `AbstractArray{<:Number}`.

It also avoids type piracy for those edge cases where you have to overwrite a function.
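
For example (an illustrative function, not from any package), such a restriction rejects `missing` but accepts a `Number` stand-in like `nil`:

```julia
f(x::AbstractArray{<:Number}) = sum(x)

f(fill(missing, 3))  # MethodError: Missing is not a subtype of Number
f(fill(nil, 3))      # works, since Nil <: Number (given the Nil methods)
```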

@mcabbott (Member)

I guess that's a concern; Flux's own layers don't restrict, but others might. Are there examples? Transformers.jl doesn't seem to restrict.

If you do go with `nil`, then perhaps it should be a special zero `<: Integer`, or perhaps better `<: AbstractFloat`, since `Real` would also be a plausible type restriction.

It probably wouldn't be hard to get a method for `*` on matrices of `missing`s added to Julia, to avoid piracy there.

@darsnack (Member, Author)

After trying `Nil <: AbstractFloat`, I think it might be better to do `Nil <: Number`. There are a lot of primitive number-related functions (e.g. `*(::Bool, ::AbstractFloat)` and `<(::AbstractFloat, ::AbstractFloat)`) that are specialized for floats. In contrast, most primitive functions on `Number` fall back to a smaller set of Base functions. There are also layers in Flux that specialize on `AbstractFloat` inputs to hit BLAS.

It might be easier to override a method early, specifically for `Nil` inputs, when the default signature has `AbstractFloat` restrictions. Currently, there are lots of little cases that need to be overridden for `AbstractFloat`, and I don't think the tests based on Flux layers will capture all the little cases users will encounter.
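
To illustrate that fallback behavior (a sketch, not code from this PR): with `Nil <: Number`, one promotion rule plus one conversion covers the mixed-type arithmetic cases, because Base's two-argument `Number` operations fall back to promotion:

```julia
Base.promote_rule(::Type{Nil}, ::Type{<:Number}) = Nil
Base.convert(::Type{Nil}, ::Number) = nil

# Now `nil + 1` or `2.0 * nil` funnels through `promote` and lands on the
# Nil-Nil methods, with no per-AbstractFloat overloads required.
```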

@darsnack (Member, Author)

Any last comments on the nil approach? If not, I am going to update darsnack#master to use it.

@CarloLucibello (Member)

I'll wait for green CI and then merge. Thanks @darsnack for the patience to see this through!

@gxyd (Contributor) commented Dec 26, 2020

Since that'll be a deprecation of `outdims`, shouldn't the deprecation raise a warning (when the deprecated function is used for the first time) telling the user to move to the corresponding new function?

@darsnack (Member, Author)

> Since that'll be a deprecation of `outdims`, shouldn't the deprecation raise a warning (when the deprecated function is used for the first time) telling the user to move to the corresponding new function?

I think the issue is that we can't replicate the old batch handling behavior.

Sorry, there appear to be some remaining test failures. Fixed them and am testing locally before pushing so we get green CI 🤞🏾

@darsnack (Member, Author)

Okay, all tests passed locally, so hopefully we will be good to merge.

@CarloLucibello (Member)

bors r+

bors bot added a commit that referenced this pull request Dec 26, 2020

1305: Updates to outdims r=CarloLucibello a=darsnack
(the commit message repeats the PR description above)
bors bot (Contributor) commented Dec 26, 2020

Timed out.

@CarloLucibello (Member)

bors r+

bors bot added a commit that referenced this pull request Dec 26, 2020

1305: Updates to outdims r=CarloLucibello a=darsnack
(the commit message repeats the PR description above)
bors bot (Contributor) commented Dec 27, 2020

Timed out.

@CarloLucibello (Member)

@DhairyaLGandhi @maleadt is bors down or just busy?

bors r+

@CarloLucibello CarloLucibello added this to the v0.12 milestone Dec 27, 2020
@CarloLucibello CarloLucibello linked an issue Dec 27, 2020 that may be closed by this pull request
bors bot added a commit that referenced this pull request Dec 27, 2020

1305: Updates to outdims r=CarloLucibello a=darsnack
(the commit message repeats the PR description above)
bors bot (Contributor) commented Dec 27, 2020

Timed out.

@CarloLucibello CarloLucibello linked an issue Dec 28, 2020 that may be closed by this pull request
@CarloLucibello (Member)

bors r+

bors bot (Contributor) commented Dec 30, 2020

Build succeeded:

@bors bors bot merged commit 12281a2 into FluxML:master Dec 30, 2020
@DhairyaLGandhi (Member)

I haven't followed this closely in the past week, but it seems that the new type was merged recently? I am still uncomfortable about the extra burden this puts on code bloat, so I'd definitely expect that to be addressed.

```julia
for f in [:copy, :zero, :one, :oneunit,
          :+, :-, :abs, :abs2, :inv,
          :exp, :log, :log1p, :log2, :log10,
          :sqrt, :tanh, :conj]
```
Member:

What happens to new layers, and to new functions that layers would need? That would need a wider catch-all than adding them to a list like this.

@darsnack (Member, Author) commented Dec 31, 2020

New layers don't need any modifications to work. New functions might, but most functions are built from primitives such as the ones in this list. There might be some Base primitive functions we need to add (e.g. `mod`), but most functions shouldn't need changes.
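
For instance (a hypothetical layer, not from the PR), a new layer whose forward pass bottoms out in already-covered primitives needs nothing extra:

```julia
struct Scale
  s::Float32
end
(l::Scale)(x) = l.s .* x  # uses only `*`, which is already forwarded for Nil

outdims(Scale(2f0), (10, 10))  # works through the generic fill-with-nil definition
```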

@darsnack (Member, Author):

I guess, what do you mean by a wider catch-all? If it's the ability to adapt to functions that aren't defined here, then I think we would need to resort to metaprogramming for that.

Member:

Yeah, I mean the functions; the correct answer would be that operations on numbers need to be forwarded with a macro call.

@darsnack (Member, Author):

To me, that kind of dispatch sounds like Cassette.jl.

Successfully merging this pull request may close these issues:

  • ConvTranspose same padding and outdims errors
  • outdims function doesn't work properly for chained layers