Skip to content

Commit

Permalink
Discard disk cache for step function
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Sep 2, 2024
1 parent f90e5ac commit 7196ddb
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ defmodule Axon.Loop do
"""
def train_step(model, loss, optimizer, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, loss_scale: :identity])

loss_scale = opts[:loss_scale] || :identity

{init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts)
Expand All @@ -341,8 +340,7 @@ defmodule Axon.Loop do
optimizer_state = init_optimizer_fn.(trainable_parameters)
loss_scale_state = init_loss_scale.()

# TODO: is this expensive? Will it compute the entire
# forward?
# TODO: is this expensive? Will it compute the entire forward?
%{prediction: output} = forward_model_fn.(model_state, inp)

%{
Expand Down Expand Up @@ -507,6 +505,7 @@ defmodule Axon.Loop do
raise_bad_training_inputs!(data, state)
end

# Pass on_conflict: :reuse as we want someone to jit it on top
{
Nx.Defn.jit(init_fn, on_conflict: :reuse),
Nx.Defn.jit(step_fn, on_conflict: :reuse)
Expand Down Expand Up @@ -1563,9 +1562,9 @@ defmodule Axon.Loop do
* `:debug` - run loop in debug mode to trace loop progress. Defaults to
false.
Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
options are set, the default options set with `Nx.Defn.default_options` are
used.
Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT
options are set, the default options set with `Nx.Defn.default_options` are
used.
"""
def run(loop, data, init_state \\ %{}, opts \\ []) do
{max_epochs, opts} = Keyword.pop(opts, :epochs, 1)
Expand Down Expand Up @@ -2263,6 +2262,10 @@ defmodule Axon.Loop do
# otherwise just applies the function with the given arguments
defp maybe_jit(fun, args, jit_compile?, jit_opts) do
if jit_compile? do
# If there is a disk cache, we only want it to apply to the batch function
jit_opts =
if is_binary(jit_opts[:cache]), do: Keyword.delete(jit_opts, :cache), else: jit_opts

apply(Nx.Defn.jit(fun, jit_opts), args)
else
apply(fun, args)
Expand Down

0 comments on commit 7196ddb

Please sign in to comment.