From 7196ddb5629d7173b26e45a2d7a63f716925d580 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 2 Sep 2024 13:26:24 +0200 Subject: [PATCH] Discard disk cache for step function --- lib/axon/loop.ex | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 71d9d5e8..6c148e07 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -325,7 +325,6 @@ defmodule Axon.Loop do """ def train_step(model, loss, optimizer, opts \\ []) do opts = Keyword.validate!(opts, [:seed, loss_scale: :identity]) - loss_scale = opts[:loss_scale] || :identity {init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts) @@ -341,8 +340,7 @@ defmodule Axon.Loop do optimizer_state = init_optimizer_fn.(trainable_parameters) loss_scale_state = init_loss_scale.() - # TODO: is this expensive? Will it compute the entire - # forward? + # TODO: is this expensive? Will it compute the entire forward? %{prediction: output} = forward_model_fn.(model_state, inp) %{ @@ -507,6 +505,7 @@ defmodule Axon.Loop do raise_bad_training_inputs!(data, state) end + # Pass on_conflict: :reuse as we want someone to jit it on top { Nx.Defn.jit(init_fn, on_conflict: :reuse), Nx.Defn.jit(step_fn, on_conflict: :reuse) @@ -1563,9 +1562,9 @@ defmodule Axon.Loop do * `:debug` - run loop in debug mode to trace loop progress. Defaults to false. - Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT - options are set, the default options set with `Nx.Defn.default_options` are - used. + Additional options are forwarded to `Nx.Defn.jit` as JIT-options. If no JIT + options are set, the default options set with `Nx.Defn.default_options` are + used. """ def run(loop, data, init_state \\ %{}, opts \\ []) do {max_epochs, opts} = Keyword.pop(opts, :epochs, 1) @@ -2263,6 +2262,10 @@ defmodule Axon.Loop do # otherwise just applies the function with the given arguments defp maybe_jit(fun, args, jit_compile?, jit_opts) do if jit_compile? do + # If there is a disk cache, we only want it to apply to the batch function + jit_opts = + if is_binary(jit_opts[:cache]), do: Keyword.delete(jit_opts, :cache), else: jit_opts + apply(Nx.Defn.jit(fun, jit_opts), args) else apply(fun, args)