Add inspect_values option (#581)
seanmor5 authored Jun 10, 2024
1 parent 57cd12f commit 8cee5a9
Showing 3 changed files with 63 additions and 6 deletions.
4 changes: 4 additions & 0 deletions lib/axon.ex
@@ -3761,6 +3761,10 @@ defmodule Axon do
       metrics. Also forwarded to JIT if debug mode is available
       for your chosen compiler or backend. Defaults to `false`
 
+    * `:print_values` - if `true`, will print intermediate layer
+      values to the screen for inspection. This is useful if you need
+      to debug intermediate values of a model
+
     * `:mode` - one of `:inference` or `:train`. Forwarded to layers
       to control differences in compilation at training or inference time.
       Defaults to `:inference`
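
To make the new option concrete, here is a minimal usage sketch modeled on the test added later in this commit; the layer names "foo" and "bar" and the {1, 10} input are illustrative values taken from that test:

    model =
      Axon.input("x")
      |> Axon.dense(10, name: "foo")
      |> Axon.dense(4, name: "bar")

    # Building with print_values: true makes the predict function print each
    # layer's output, labeled with the layer name, as it executes.
    {init_fn, predict_fn} = Axon.build(model, print_values: true)

    input = Nx.broadcast(1, {1, 10})
    model_state = init_fn.(input, Axon.ModelState.empty())
    predict_fn.(model_state, input)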
42 changes: 36 additions & 6 deletions lib/axon/compiler.ex
@@ -51,8 +51,15 @@ defmodule Axon.Compiler do
     raise_on_none? = Keyword.get(opts, :raise_on_none, true)
     mode = Keyword.get(opts, :mode, :inference)
     seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
+    print_values = Keyword.get(opts, :print_values, false)
     global_layer_options = Keyword.get(opts, :global_layer_options, [])
-    config = %{mode: mode, debug?: debug?, global_layer_options: global_layer_options}
+
+    config = %{
+      mode: mode,
+      debug?: debug?,
+      global_layer_options: global_layer_options,
+      print_values: print_values
+    }
 
     {time, {root_id, {cache, _op_counts, _block_cache, model_state_meta}}} =
       :timer.tc(fn ->
@@ -446,16 +453,21 @@ defmodule Axon.Compiler do
   end
 
   defp recur_model_funs(
-         %Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: policy},
+         %Axon.Node{id: id, name: name_fn, op: :constant, opts: [value: tensor], policy: policy},
          _nodes,
          {cache, op_counts, block_cache, model_state_meta},
-         _
+         %{print_values: print_values}
        ) do
+    name = name_fn.(:constant, op_counts)
     op_counts = Map.update(op_counts, :constant, 1, fn x -> x + 1 end)
     tensor = Nx.backend_copy(tensor, Nx.BinaryBackend)
 
     predict_fun = fn _params, _inputs, state, _cache, result_cache, _fn_stacktrace ->
-      out = safe_policy_cast(tensor, policy, :output)
+      out =
+        tensor
+        |> safe_policy_cast(policy, :output)
+        |> maybe_print_values(name, print_values)
+
       {out, {state, result_cache}}
     end
 
@@ -477,7 +489,7 @@
          },
          _nodes,
          {cache, op_counts, block_cache, model_state_meta},
-         %{mode: mode}
+         %{mode: mode, print_values: print_values}
        ) do
     name = name_fn.(:input, op_counts)
     op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end)
@@ -492,6 +504,7 @@
         value
         |> apply_hooks(:forward, mode, hooks)
         |> apply_hooks(:backward, mode, hooks)
+        |> maybe_print_values(name, print_values)
 
       {res, {state, result_cache}}
     end
@@ -687,6 +700,8 @@
           Map.put(state, block_name, out_state)
         end
 
+      out_result = maybe_print_values(out_result, block_name, config.print_values)
+
       {out_result, {state, result_cache}}
     end
   end
@@ -847,7 +862,12 @@
          },
          nodes,
          cache_and_counts,
-         %{mode: mode, debug?: debug?, global_layer_options: global_layer_options} = config
+         %{
+           mode: mode,
+           debug?: debug?,
+           global_layer_options: global_layer_options,
+           print_values: print_values
+         } = config
        )
        when (is_function(op) or is_atom(op)) and is_list(inputs) do
     # Traverse to accumulate cache and get parent_ids for
@@ -912,6 +932,7 @@
         hooks,
         mode,
         global_layer_options,
+        print_values,
         stacktrace
       )
 
@@ -994,6 +1015,7 @@
          hooks,
          mode,
          global_layer_options,
+         print_values,
          layer_stacktrace
        ) do
     # Recurse graph inputs and invoke cache to get parent results,
@@ -1113,6 +1135,8 @@
           {new_out, state}
         end
 
+      out = maybe_print_values(out, name, print_values)
+
       {out, {state, result_cache}}
     end
   end
@@ -1270,6 +1294,12 @@
   defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param)
   defp maybe_freeze(param, false), do: param
 
+  defp maybe_print_values(value, layer, true) do
+    Nx.Defn.Kernel.print_value(value, label: layer)
+  end
+
+  defp maybe_print_values(value, _, _), do: value
+
   defp apply_hooks(res, event, mode, hooks) do
     hooks
     |> Enum.reverse()
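
The maybe_print_values/3 helper added above delegates to Nx.Defn.Kernel.print_value/2, which prints a tensor at execution time and returns it unchanged, so it can sit in the middle of a layer's output pipeline without affecting results. A standalone sketch of that behavior outside Axon follows; the module and function names below are invented for illustration:

    defmodule PrintValueSketch do
      import Nx.Defn

      defn scaled_sum(x) do
        x
        |> Nx.multiply(2)
        # Prints the intermediate tensor with the label "doubled" when the
        # computation runs, then passes the value through unchanged.
        |> print_value(label: "doubled")
        |> Nx.sum()
      end
    end

    PrintValueSketch.scaled_sum(Nx.tensor([1, 2, 3]))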
23 changes: 23 additions & 0 deletions test/axon/compiler_test.exs
@@ -5716,4 +5716,27 @@ defmodule CompilerTest do
       assert Nx.shape(out) == {1, 20, 32}
     end
   end
+
+  describe "inspect values" do
+    test "prints intermediate layer values to the screen" do
+      model =
+        Axon.input("x")
+        |> Axon.dense(10, name: "foo")
+        |> Axon.dense(4, name: "bar")
+
+      {init_fn, predict_fn} = Axon.build(model, print_values: true)
+      input = Nx.broadcast(1, {1, 10})
+
+      model_state = init_fn.(input, ModelState.empty())
+
+      out =
+        ExUnit.CaptureIO.capture_io(fn ->
+          predict_fn.(model_state, input)
+        end)
+
+      assert out =~ "x:"
+      assert out =~ "foo:"
+      assert out =~ "bar:"
+    end
+  end
 end
