diff --git a/lib/axon.ex b/lib/axon.ex index aeb20a76..4b2c1935 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -3974,7 +3974,7 @@ defmodule Axon do """ @doc type: :debug - def trace_init(model, template, params \\ %{}, opts \\ []) do + def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do {init_fn, _} = build(model, opts) Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, params) end diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex index 44a94291..b9088a1b 100644 --- a/lib/axon/model_state.ex +++ b/lib/axon/model_state.ex @@ -171,12 +171,7 @@ defmodule Axon.ModelState do Returns an empty model state. """ def empty() do - %Axon.ModelState{ - data: %{}, - parameters: %{}, - state: %{}, - frozen_parameters: %{} - } + new(%{}) end @doc """ @@ -190,12 +185,40 @@ defmodule Axon.ModelState do def new(data) when is_map(data) do %Axon.ModelState{ data: data, - parameters: get_paths(data), + parameters: transform_to_parameters(data), state: %{}, frozen_parameters: %{} } end + defp transform_to_parameters(%Nx.Tensor{}), do: nil + + defp transform_to_parameters(map) when is_map(map) do + map + |> Enum.map(fn {k, v} -> {k, transform_to_parameters(v)} end) + |> Enum.into(%{}) + end + + defp transform_to_parameters(list) when is_list(list) do + Enum.map(list, &transform_to_parameters/1) + end + + defp transform_to_parameters(value) do + case value do + map when is_map(map) -> + keys = Map.keys(map) + + if Enum.all?(keys, &(is_map(map[&1]) or match?(%Nx.Tensor{}, map[&1]))) do + keys + else + transform_to_parameters(map) + end + + _ -> + value + end + end + # Helpers defp get_paths(map) do diff --git a/lib/axon/quantization/layers.ex b/lib/axon/quantization/layers.ex index e07cfe2c..80900a17 100644 --- a/lib/axon/quantization/layers.ex +++ b/lib/axon/quantization/layers.ex @@ -35,18 +35,20 @@ defmodule Axon.Quantization.Layers do bias, _opts ) do - x_shape = Nx.shape(x) - last_dim = Nx.axis_size(x, -1) + x_view = Nx.reshape(x, {:auto, Nx.axis_size(x, -1)}) - x_view = Nx.reshape(x, {:auto, last_dim}) - - y = Nx.dot(x_view, Nx.as_type(Nx.transpose(w_int8), Nx.type(x))) - y = Nx.multiply(y, scales) - y = reshape_output(y, x_shape) + y = Nx.dot(x_view, Nx.as_type(w_int8, Nx.type(x))) + y = Nx.multiply(y, reshape_scales(scales, y)) + y = reshape_output(y, Nx.shape(x)) Nx.add(y, bias) end + deftransformp reshape_scales(scales, y) do + ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1)) + Nx.reshape(scales, Tuple.append(ones, :auto)) + end + deftransformp reshape_output(output, x_shape) do all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1) new_shape = Tuple.append(all_but_last, :auto) diff --git a/lib/axon/quantization/q_tensor.ex b/lib/axon/quantization/q_tensor.ex index 4ef978f9..12fb6d99 100644 --- a/lib/axon/quantization/q_tensor.ex +++ b/lib/axon/quantization/q_tensor.ex @@ -59,7 +59,7 @@ defmodule Axon.Quantization.QTensor do max: opts[:max] ) - struct(__MODULE__, value: quantized_value, scale: scale, zero_point: zero_point) + struct(__MODULE__, value: Nx.transpose(quantized_value), scale: scale, zero_point: zero_point) end deftransformp quantize_affine(