Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide layer name as hook name #536

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,8 @@ defmodule Axon.Compiler do

res =
value
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> maybe_print_values(name, print_values)

{res, {state, result_cache}}
Expand Down Expand Up @@ -975,7 +975,7 @@ defmodule Axon.Compiler do
layer_input =
layer_input
|> safe_policy_cast(policy, :compute)
|> apply_hooks(:pre_forward, mode, hooks)
|> apply_hooks(name, :pre_forward, mode, hooks)

{layer_input, {state, result_cache, none?}}
end
Expand Down Expand Up @@ -1051,8 +1051,8 @@ defmodule Axon.Compiler do
%StatefulOutput{output: out, state: out_state} ->
new_out =
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> safe_policy_cast(policy, :output)

new_state = Map.put(state, name, out_state)
Expand All @@ -1061,8 +1061,8 @@ defmodule Axon.Compiler do
out ->
new_out =
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> apply_hooks(name, :forward, mode, hooks)
|> apply_hooks(name, :backward, mode, hooks)
|> safe_policy_cast(policy, :output)

{new_out, state}
Expand Down Expand Up @@ -1169,7 +1169,7 @@ defmodule Axon.Compiler do
init_param(layer_id, param, layer_params, parent_templates, dtype, keys)
end)

layer_params = apply_hooks(layer_params, :initialize, nil, hooks)
layer_params = apply_hooks(layer_params, name, :initialize, nil, hooks)

params =
if layer_params == %{} do
Expand Down Expand Up @@ -1228,7 +1228,7 @@ defmodule Axon.Compiler do

defp maybe_print_values(value, _, _), do: value

defp apply_hooks(res, event, mode, hooks) do
defp apply_hooks(res, layer_name, event, mode, hooks) do
hooks
|> Enum.reverse()
|> Enum.reduce(res, fn {on_event, on_mode, hook_fn}, expr ->
Expand All @@ -1238,11 +1238,11 @@ defmodule Axon.Compiler do
if event? and mode? do
if on_event == :backward do
Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
hooked_g = Nx.Defn.Kernel.hook(g, String.to_atom(layer_name), hook_fn)
[hooked_g]
end)
else
Nx.Defn.Kernel.hook(expr, hook_fn)
Nx.Defn.Kernel.hook(expr, String.to_atom(layer_name), hook_fn)
end
else
expr
Expand Down
20 changes: 20 additions & 0 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4726,6 +4726,26 @@
assert_receive {%Nx.Tensor{}, :from_relu}
assert_receive {%Nx.Tensor{}, :from_sigmoid}
end

test "can be overriden at jit-time with layer name", config do
model =
Axon.input("input_0", shape: {nil, 1})
|> Axon.attach_hook(fn x -> send(config.test, {x, :from_input}) end, on: :forward)
|> Axon.relu()

inp = Nx.tensor([[1.0]])
{_, predict_fn} = Axon.build(model)

hook = fn val -> send(config.test, {val, :overridden}) end

fun = Nx.Defn.jit(predict_fn, hooks: %{input_0: hook})
apply(fun, [ModelState.empty(), inp])

assert_receive {from_inp, :overridden}
refute_receive {_, :from_input}

assert_equal(from_inp, inp)
end
end

describe "integrated models" do
Expand Down Expand Up @@ -5654,7 +5674,7 @@
end

describe "inspect values" do
test "prints intermediate layer values to the screen" do

Check failure on line 5677 in test/axon/compiler_test.exs

View workflow job for this annotation

GitHub Actions / main (25.3.2.6, 1.14.5, USE_EXLA=true)

test inspect values prints intermediate layer values to the screen (CompilerTest)
model =
Axon.input("x")
|> Axon.dense(10, name: "foo")
Expand Down
Loading