diff --git a/lib/axon.ex b/lib/axon.ex index fa73645a..52f4ba6d 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -587,7 +587,7 @@ defmodule Axon do end @doc """ - Implements an or else (e.g. an Elixir ||) + Implements an or else (e.g. an Elixir ||) """ @doc type: :special def or_else(%Axon{} = a, %Axon{} = b, opts \\ []) do @@ -3771,7 +3771,7 @@ defmodule Axon do as input and returns a function that replaces or rewrites the given node. For example, you can define a simple rewriter which replaces the `:relu` layers with `:tanh` layers: - + tanh_rewriter = fn [%Axon{} = x], _output -> Axon.relu(x) end @@ -3926,13 +3926,16 @@ defmodule Axon do end @doc """ - Compiles the given model to `{init_fn, predict_fn}`. + Compiles the given model to `{init_params, predict_fn}`. This function will compile a model specialized to the given input shapes and types. This is useful for avoiding the overhead of long compilations at program runtime. You must provide template inputs which match the expected shapes and types of inputs at - execution time. + execution time. Depending on the Nx compiler, such as EXLA v0.9.1+, + both `init_params` the `predict_fn` can be sent across nodes, as + long the node that owns them keeps a reference to the underlying + resources. This function makes use of the built-in `Nx.Defn.compile/3`. Note that passing inputs which differ in shape or type from the templates @@ -3946,7 +3949,12 @@ defmodule Axon do def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ []) when is_list(opts) do {init_fn, predict_fn} = build(model, opts) - init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts) + model_state = Axon.ModelState.new(init_params) + + # If there is a disk cache, we only want it to apply to the predict function + init_opts = if is_binary(opts[:cache]), do: Keyword.delete(opts, :cache), else: opts + init_params = Nx.Defn.jit_apply(init_fn, [template, model_state], init_opts) + predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts) {init_params, predict_compiled_fn} end