batch_norm states "mean" and "var" never updated #546

Closed

robinmonjo opened this issue Nov 14, 2023 · 1 comment

@robinmonjo (Contributor)

I have noticed a really strange behaviour that appears to be a bug. Here is a Livebook to demonstrate the bug:

debugging-batch-norm

```elixir
Mix.install([
  {:nx, "~> 0.6.0"},
  {:axon, "~> 0.6.0"},
  {:kino, "~> 0.11.2"}
])
```

Batch norm layers not updated

So I have noticed this weird bug.

The state of the batch_norm layers (mean and var) is never updated! It stays at all 0s and all 1s, but only when the layer is on the "right" branch of the network.
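For reference, the running stats start out at zeros (mean) and ones (var), so seeing exactly those values after a training-mode forward pass means the state was never written back. Here is a minimal way to inspect the freshly initialized values (a sketch, assuming Axon ~> 0.6, where the running stats live alongside the trainable params):

```elixir
# Sketch: inspect the freshly initialized running stats.
# Assumes Axon ~> 0.6, where batch norm's "mean"/"var" sit in the params map
# next to "gamma"/"beta"; the layout may differ in other versions.
model = Axon.input("input") |> Axon.dense(10) |> Axon.batch_norm()
{init, _pred} = Axon.build(model, mode: :train)
params = init.(Nx.template({10, 10}, :f32), %{})

IO.inspect(params["batch_norm_0"]["mean"], label: "initial mean") # all 0.0
IO.inspect(params["batch_norm_0"]["var"], label: "initial var")   # all 1.0
```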

Example:

```elixir
model =
  Axon.input("input")
  |> Axon.dense(10)
  |> Axon.batch_norm()
```

```
#Axon<
  inputs: %{"input" => nil}
  outputs: "batch_norm_0"
  nodes: 3
>
```

```elixir
{init, pred} = Axon.build(model, mode: :train)
params = init.(Nx.template({10, 10}, :f32), %{})
%{state: state} = pred.(params, Nx.iota({10, 10}))
state["batch_norm_0"]
```

```
%{
  "mean" => #Nx.Tensor<
    f32[10]
    [-47.73035430908203, 50.59987258911133, -40.22572326660156, 57.21146774291992, -6.546464443206787, 5.084614276885986, -30.406084060668945, 34.05552291870117, -27.80076789855957, 16.147790908813477]
  >,
  "var" => #Nx.Tensor<
    f32[10]
    [963.2650756835938, 967.86474609375, 611.614013671875, 1129.3345947265625, 8.083409309387207, 20.26643943786621, 438.4486083984375, 393.705810546875, 288.500244140625, 81.58602142333984]
  >
}
```

So in this configuration, the batch_norm state is updated.

Now let's build an example where it's not:

```elixir
input = Axon.input("input")

l1 =
  Axon.dense(input, 10)

l2 =
  Axon.dense(input, 10)
  |> Axon.batch_norm()

model = Axon.add(l2, l1)
Axon.Display.as_graph(model, Nx.template({10, 10}, :f32))
```

(screenshot: rendered model graph)

```elixir
{init, pred} = Axon.build(model, mode: :train)
params = init.(Nx.template({10, 10}, :f32), %{})
%{state: state} = pred.(params, Nx.iota({10, 10}))
state["batch_norm_0"]
```

```
nil
```

So here the batch_norm_0 state is neither returned nor updated.
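To be precise, the layer appears to be missing from the returned state map altogether, rather than present with stale values. A quick (hypothetical) check, reusing the bindings from the snippet above:

```elixir
# Hypothetical sanity check (not in the original repro): list the keys of the
# returned state map to confirm "batch_norm_0" is absent entirely.
Map.keys(state) |> IO.inspect(label: "state keys")
```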

But when the batch_norm is on the "left" branch of the network, it works:

```elixir
model = Axon.add(l1, l2)
Axon.Display.as_graph(model, Nx.template({10, 10}, :f32))
```

(screenshot: rendered model graph)

```elixir
{init, pred} = Axon.build(model, mode: :train)
params = init.(Nx.template({10, 10}, :f32), %{})
%{state: state} = pred.(params, Nx.iota({10, 10}))
state["batch_norm_0"]
```

```
%{
  "mean" => #Nx.Tensor<
    f32[10]
    [38.54197311401367, -13.553568840026855, -53.8503532409668, -13.158158302307129, -38.54884338378906, 7.524081707000732, 29.492563247680664, 33.01778030395508, 19.012977600097656, 17.45140266418457]
  >,
  "var" => #Nx.Tensor<
    f32[10]
    [596.3408203125, 56.97801208496094, 1097.6343994140625, 151.06002807617188, 651.69677734375, 16.675113677978516, 378.416748046875, 327.9981689453125, 123.95314025878906, 161.44265747070312]
  >
}
```

I would be happy to help fix this, but I'm not yet very familiar with the internals of Axon and have limited time 😊

@seanmor5 (Contributor)

This has been fixed with the addition of model state in #553.

Thanks for pointing this out! I was having issues with RNNs in #553 and realized it was the same issue there!
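For anyone verifying later: re-running the failing snippet from the report on a release that includes #553 should yield a non-nil, updated state. A regression sketch (the exact init/state plumbing may differ on newer Axon versions):

```elixir
# Regression sketch: assumes an Axon release containing the fix from #553.
# On newer versions the plumbing around model state may look different.
input = Axon.input("input")
l1 = Axon.dense(input, 10)
l2 = Axon.dense(input, 10) |> Axon.batch_norm()
model = Axon.add(l2, l1)

{init, pred} = Axon.build(model, mode: :train)
params = init.(Nx.template({10, 10}, :f32), %{})
%{state: state} = pred.(params, Nx.iota({10, 10}))

# With the fix, the running stats are returned and updated even when the
# batch norm layer sits on the right branch of Axon.add/2.
false = is_nil(state["batch_norm_0"])
```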
