Commit

Discard cache on init_params computation
josevalim committed Sep 2, 2024
1 parent 7196ddb commit e19edf1
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions lib/axon.ex
@@ -587,7 +587,7 @@ defmodule Axon do
end

@doc """
-Implements an or else (e.g. an Elixir ||)
+Implements an or else (e.g. an Elixir ||)
"""
@doc type: :special
def or_else(%Axon{} = a, %Axon{} = b, opts \\ []) do
@@ -3771,7 +3771,7 @@ defmodule Axon do
as input and returns a function that replaces or rewrites the given node.
For example, you can define a simple rewriter which replaces the `:relu`
layers with `:tanh` layers:
tanh_rewriter = fn [%Axon{} = x], _output ->
Axon.tanh(x)
end
@@ -3926,13 +3926,16 @@ defmodule Axon do
end

@doc """
-Compiles the given model to `{init_fn, predict_fn}`.
+Compiles the given model to `{init_params, predict_fn}`.
This function will compile a model specialized to the given
input shapes and types. This is useful for avoiding the overhead
of long compilations at program runtime. You must provide template
inputs which match the expected shapes and types of inputs at
-execution time.
+execution time. Depending on the Nx compiler, such as EXLA v0.9.1+,
+both `init_params` and `predict_fn` can be sent across nodes, as
+long as the node that owns them keeps a reference to the underlying
+resources.
This function makes use of the built-in `Nx.Defn.compile/3`. Note
that passing inputs which differ in shape or type from the templates
@@ -3946,7 +3949,12 @@ defmodule Axon do
def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ [])
when is_list(opts) do
{init_fn, predict_fn} = build(model, opts)
-init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts)
+model_state = Axon.ModelState.new(init_params)
+
+# If there is a disk cache, we only want it to apply to the predict function
+init_opts = if is_binary(opts[:cache]), do: Keyword.delete(opts, :cache), else: opts
+init_params = Nx.Defn.jit_apply(init_fn, [template, model_state], init_opts)
+
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts)
{init_params, predict_compiled_fn}
end
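
The option handling in the last hunk can be tried in isolation. The following is a minimal sketch, not code from Axon: the module name `CacheOptsSketch`, the function name `init_opts/1`, and the cache path are hypothetical and chosen only for illustration. It assumes, as the diff does, that a disk cache is signalled by a binary value under `:cache`.

```elixir
defmodule CacheOptsSketch do
  # Hypothetical helper (not part of Axon). It mirrors the conditional in the
  # diff: drop :cache only when it is a binary, i.e. a path to a disk cache.
  def init_opts(opts) when is_list(opts) do
    if is_binary(opts[:cache]), do: Keyword.delete(opts, :cache), else: opts
  end
end

# Options as they might be passed to compile/4; the path is illustrative.
opts = [compiler: EXLA, cache: "/tmp/predict.cache"]

# The options used for the init computation no longer carry the disk cache.
IO.inspect(CacheOptsSketch.init_opts(opts))

# Without a :cache entry the options pass through unchanged.
IO.inspect(CacheOptsSketch.init_opts(compiler: EXLA))
```

In the commit itself, the original `opts`, including `:cache`, are still passed to `Nx.Defn.compile/3` for the predict function, so only the parameter initialization skips the disk cache.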