Generalize quantization layers
seanmor5 committed Jul 30, 2024
1 parent a54ee13 commit 8e0a6d9
Showing 4 changed files with 41 additions and 16 deletions.
2 changes: 1 addition & 1 deletion lib/axon.ex
@@ -3974,7 +3974,7 @@ defmodule Axon do
   """
   @doc type: :debug
-  def trace_init(model, template, params \\ %{}, opts \\ []) do
+  def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do
     {init_fn, _} = build(model, opts)
     Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, params)
   end
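
Note: trace_init/4 now defaults params to an empty Axon.ModelState rather than a bare map. A minimal usage sketch of the new default (the model and template below are hypothetical, not part of the commit):

    model = Axon.input("x", shape: {nil, 2}) |> Axon.dense(4)
    template = Nx.template({1, 2}, :f32)

    # Equivalent to Axon.trace_init(model, template, Axon.ModelState.empty())
    state = Axon.trace_init(model, template)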
37 changes: 30 additions & 7 deletions lib/axon/model_state.ex
@@ -171,12 +171,7 @@ defmodule Axon.ModelState do
   Returns an empty model state.
   """
   def empty() do
-    %Axon.ModelState{
-      data: %{},
-      parameters: %{},
-      state: %{},
-      frozen_parameters: %{}
-    }
+    new(%{})
   end
 
   @doc """
@@ -190,12 +185,40 @@
   def new(data) when is_map(data) do
     %Axon.ModelState{
       data: data,
-      parameters: get_paths(data),
+      parameters: transform_to_parameters(data),
       state: %{},
       frozen_parameters: %{}
     }
   end
 
+  defp transform_to_parameters(%Nx.Tensor{}), do: nil
+
+  defp transform_to_parameters(map) when is_map(map) do
+    map
+    |> Enum.map(fn {k, v} -> {k, transform_to_parameters(v)} end)
+    |> Enum.into(%{})
+  end
+
+  defp transform_to_parameters(list) when is_list(list) do
+    Enum.map(list, &transform_to_parameters/1)
+  end
+
+  defp transform_to_parameters(value) do
+    case value do
+      map when is_map(map) ->
+        keys = Map.keys(map)
+
+        if Enum.all?(keys, &(is_map(map[&1]) or match?(%Nx.Tensor{}, map[&1]))) do
+          keys
+        else
+          transform_to_parameters(map)
+        end
+
+      _ ->
+        value
+    end
+  end
+
   # Helpers
 
   defp get_paths(map) do
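
Note: a sketch of how the new constructor is reached (the parameter map below is hypothetical). new/1 now derives the parameters tree from the given data via transform_to_parameters/1 instead of get_paths/1, and empty/0 reuses the same code path:

    params = %{
      "dense_0" => %{
        "kernel" => Nx.broadcast(0.0, {2, 4}),
        "bias" => Nx.broadcast(0.0, {4})
      }
    }

    state = Axon.ModelState.new(params)
    # state.data holds the tensors as given; state.parameters is derived
    # from the same tree by transform_to_parameters/1

    empty = Axon.ModelState.empty()  # now simply new(%{})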
16 changes: 9 additions & 7 deletions lib/axon/quantization/layers.ex
@@ -35,18 +35,20 @@ defmodule Axon.Quantization.Layers do
         bias,
         _opts
       ) do
-    x_shape = Nx.shape(x)
-    last_dim = Nx.axis_size(x, -1)
-
-    x_view = Nx.reshape(x, {:auto, last_dim})
+    x_view = Nx.reshape(x, {:auto, Nx.axis_size(x, -1)})
 
-    y = Nx.dot(x_view, Nx.as_type(Nx.transpose(w_int8), Nx.type(x)))
-    y = Nx.multiply(y, scales)
-    y = reshape_output(y, x_shape)
+    y = Nx.dot(x_view, Nx.as_type(w_int8, Nx.type(x)))
+    y = Nx.multiply(y, reshape_scales(scales, y))
+    y = reshape_output(y, Nx.shape(x))
 
     Nx.add(y, bias)
   end
 
+  deftransformp reshape_scales(scales, y) do
+    ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
+    Nx.reshape(scales, Tuple.append(ones, :auto))
+  end
+
   deftransformp reshape_output(output, x_shape) do
     all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
     new_shape = Tuple.append(all_but_last, :auto)
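
Note: the new reshape_scales/2 exists so that the per-output-feature scales broadcast against the last axis of y regardless of y's rank. A standalone sketch of the same reshape (the shapes here are assumed for illustration):

    y = Nx.iota({4, 3}, type: :f32)          # e.g. {batch, out_features}
    scales = Nx.tensor([0.5, 0.25, 0.125])   # one scale per output feature

    ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
    scales_view = Nx.reshape(scales, Tuple.append(ones, :auto))  # shape {1, 3}

    Nx.multiply(y, scales_view)              # broadcasts along the last axis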
2 changes: 1 addition & 1 deletion lib/axon/quantization/q_tensor.ex
@@ -59,7 +59,7 @@ defmodule Axon.Quantization.QTensor do
       max: opts[:max]
     )
 
-    struct(__MODULE__, value: quantized_value, scale: scale, zero_point: zero_point)
+    struct(__MODULE__, value: Nx.transpose(quantized_value), scale: scale, zero_point: zero_point)
   end
 
   deftransformp quantize_affine(
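
Note: this transpose pairs with the layers.ex change above: the quantized kernel is stored pre-transposed, so the dense forward pass can feed w_int8 straight into Nx.dot without a per-call Nx.transpose. A sketch under an assumed {out, in} kernel orientation (shapes are mine, not from the diff):

    w = Nx.iota({2, 4}, type: :f32)       # assumed {out = 2, in = 4} kernel
    w_int8 = Nx.as_type(w, :s8)           # stand-in for the quantized values
    stored = Nx.transpose(w_int8)         # {4, 2}, as QTensor now stores it

    x_view = Nx.iota({3, 4}, type: :f32)  # {batch, in}
    y = Nx.dot(x_view, Nx.as_type(stored, :f32))  # {3, 2}, no runtime transpose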
