diff --git a/lib/axon.ex b/lib/axon.ex index 2ddfa204..09501709 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -3405,7 +3405,7 @@ defmodule Axon do """ @doc type: :graph def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do - {init_fn, forward_fn} = build(axon, opts) + {init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false]) out = Nx.Defn.jit( diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 07b9109a..45ec0710 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -48,6 +48,7 @@ defmodule Axon.Compiler do @doc false def build(%Axon{output: id, nodes: nodes}, opts) do debug? = Keyword.get(opts, :debug, false) + raise_on_none? = Keyword.get(opts, :raise_on_none, true) mode = Keyword.get(opts, :mode, :inference) seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end) global_layer_options = Keyword.get(opts, :global_layer_options, []) @@ -105,10 +106,12 @@ defmodule Axon.Compiler do end with %Axon.None{} <- result do - raise ArgumentError, - "the compiled model will always result in %Axon.None{}." <> - " This most likely means you specified optional output and " <> - " did not handle the case when it is missing" + if raise_on_none? do + raise ArgumentError, + "the compiled model will always result in %Axon.None{}." <> + " This most likely means you specified optional output and " <> + " did not handle the case when it is missing" + end end result diff --git a/test/axon_test.exs b/test/axon_test.exs index 835aadd3..eb4b8aef 100644 --- a/test/axon_test.exs +++ b/test/axon_test.exs @@ -1076,5 +1076,19 @@ defmodule AxonTest do assert shape = Axon.get_output_shape(model, Nx.template({1, 1}, :f32)) assert shape == {{1, 2}, {1, 2}} end + + test "doesn't raise on none output" do + values = Axon.input("values") + mask = Axon.input("mask", optional: true) + + model = + values + |> Axon.dense(10) + |> Axon.multiply(mask) + |> Axon.dense(1) + |> Axon.sigmoid() + + assert %Axon.None{} = Axon.get_output_shape(model, %{"values" => Nx.template({1, 1}, :f32)}) + end end end