diff --git a/lib/axon.ex b/lib/axon.ex index 258f2b19..6489f1dd 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -3761,6 +3761,10 @@ defmodule Axon do metrics. Also forwarded to JIT if debug mode is available for your chosen compiler or backend. Defaults to `false` + * `:print_values` - if `true`, will print intermediate layer + values to the screen for inspection. This is useful if you need + to debug intermediate values of a model. Defaults to `false` + * `:mode` - one of `:inference` or `:train`. Forwarded to layers to control differences in compilation at training or inference time. Defaults to `:inference` diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 6c25e35c..ebd36543 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -51,8 +51,15 @@ defmodule Axon.Compiler do raise_on_none? = Keyword.get(opts, :raise_on_none, true) mode = Keyword.get(opts, :mode, :inference) seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end) + print_values = Keyword.get(opts, :print_values, false) global_layer_options = Keyword.get(opts, :global_layer_options, []) - config = %{mode: mode, debug?: debug?, global_layer_options: global_layer_options} + + config = %{ + mode: mode, + debug?: debug?, + global_layer_options: global_layer_options, + print_values: print_values + } {time, {root_id, {cache, _op_counts, _block_cache, model_state_meta}}} = :timer.tc(fn -> @@ -446,16 +453,21 @@ defmodule Axon.Compiler do end defp recur_model_funs( - %Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: policy}, + %Axon.Node{id: id, name: name_fn, op: :constant, opts: [value: tensor], policy: policy}, _nodes, {cache, op_counts, block_cache, model_state_meta}, - _ + %{print_values: print_values} ) do + name = name_fn.(:constant, op_counts) op_counts = Map.update(op_counts, :constant, 1, fn x -> x + 1 end) tensor = Nx.backend_copy(tensor, Nx.BinaryBackend) predict_fun = fn _params, _inputs, state, _cache, result_cache, _fn_stacktrace -> - out = safe_policy_cast(tensor, 
policy, :output) + out = + tensor + |> safe_policy_cast(policy, :output) + |> maybe_print_values(name, print_values) + {out, {state, result_cache}} end @@ -477,7 +489,7 @@ defmodule Axon.Compiler do }, _nodes, {cache, op_counts, block_cache, model_state_meta}, - %{mode: mode} + %{mode: mode, print_values: print_values} ) do name = name_fn.(:input, op_counts) op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end) @@ -492,6 +504,7 @@ defmodule Axon.Compiler do value |> apply_hooks(:forward, mode, hooks) |> apply_hooks(:backward, mode, hooks) + |> maybe_print_values(name, print_values) {res, {state, result_cache}} end @@ -687,6 +700,8 @@ defmodule Axon.Compiler do Map.put(state, block_name, out_state) end + out_result = maybe_print_values(out_result, block_name, config.print_values) + {out_result, {state, result_cache}} end end @@ -847,7 +862,12 @@ defmodule Axon.Compiler do }, nodes, cache_and_counts, - %{mode: mode, debug?: debug?, global_layer_options: global_layer_options} = config + %{ + mode: mode, + debug?: debug?, + global_layer_options: global_layer_options, + print_values: print_values + } = config ) when (is_function(op) or is_atom(op)) and is_list(inputs) do # Traverse to accumulate cache and get parent_ids for @@ -912,6 +932,7 @@ defmodule Axon.Compiler do hooks, mode, global_layer_options, + print_values, stacktrace ) @@ -994,6 +1015,7 @@ defmodule Axon.Compiler do hooks, mode, global_layer_options, + print_values, layer_stacktrace ) do # Recurse graph inputs and invoke cache to get parent results, @@ -1113,6 +1135,8 @@ defmodule Axon.Compiler do {new_out, state} end + out = maybe_print_values(out, name, print_values) + {out, {state, result_cache}} end end @@ -1270,6 +1294,12 @@ defmodule Axon.Compiler do defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param) defp maybe_freeze(param, false), do: param + defp maybe_print_values(value, layer, true) do + Nx.Defn.Kernel.print_value(value, label: layer) + end + + defp 
maybe_print_values(value, _, _), do: value + defp apply_hooks(res, event, mode, hooks) do hooks |> Enum.reverse() diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index 11fc9108..4b6026e8 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -5716,4 +5716,27 @@ defmodule CompilerTest do assert Nx.shape(out) == {1, 20, 32} end end + + describe "inspect values" do + test "prints intermediate layer values to the screen" do + model = + Axon.input("x") + |> Axon.dense(10, name: "foo") + |> Axon.dense(4, name: "bar") + + {init_fn, predict_fn} = Axon.build(model, print_values: true) + input = Nx.broadcast(1, {1, 10}) + + model_state = init_fn.(input, ModelState.empty()) + + out = + ExUnit.CaptureIO.capture_io(fn -> + predict_fn.(model_state, input) + end) + + assert out =~ "x:" + assert out =~ "foo:" + assert out =~ "bar:" + end + end end