Skip to content

Commit

Permalink
Finish initial quantization API
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Jul 24, 2024
1 parent d10f24f commit e784c1d
Show file tree
Hide file tree
Showing 7 changed files with 431 additions and 254 deletions.
10 changes: 8 additions & 2 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,12 @@ defmodule Axon do
use_bias: true
])

# Always record the layer's units and bias setting in the node metadata,
# whether or not the caller supplied their own `:meta`. `|>` binds tighter
# than `||`, so the fallback has to be parenthesised; without the parentheses
# a caller-supplied value is returned untouched and the two `Map.put/3` calls
# only ever apply to the `%{}` default.
# NOTE(review): assumes `:meta`, when given, is a map — confirm against callers.
meta =
  (opts[:meta] || %{})
  |> Map.put(:units, units)
  |> Map.put(:use_bias, opts[:use_bias])

kernel_shape = &Axon.Shape.dense_kernel(&1, units)
bias_shape = &Axon.Shape.dense_bias(&1, units)

Expand All @@ -868,7 +874,7 @@ defmodule Axon do
{[x, kernel], :dense}
end

node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :dense)
node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)

if activation = opts[:activation] do
activation(node, activation)
Expand Down Expand Up @@ -3666,7 +3672,7 @@ defmodule Axon do
"""
@doc type: :graph
# Folds over the graph with `reduce_nodes/3`, starting from an empty map; the
# callback increments the count stored under the key matched from each node,
# inserting 1 the first time that key is seen.
#
# NOTE(review): this is a rendered diff, so the two consecutive `reduce_nodes`
# lines below appear to be the removed/added pair — the first keys the counts
# by the node's `:op` field, the second by its `:op_name` field.
def get_op_counts(%Axon{} = axon) do
reduce_nodes(axon, %{}, fn %Axon.Node{op: op}, op_counts ->
reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts ->
Map.update(op_counts, op, 1, fn x -> x + 1 end)
end)
end
Expand Down
11 changes: 11 additions & 0 deletions lib/axon/model_state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ defmodule Axon.ModelState do

# Leaf clause: reaching a tensor ends the walk. `acc` holds the keys visited so
# far in reverse order (the catch-all clause prepends each key), so it is
# reversed and returned as a single-element list.
defp traverse(%Nx.Tensor{}, acc), do: [Enum.reverse(acc)]

# Quantized tensors are treated as leaves in the same way. This clause must
# come before the catch-all clause, which would otherwise match the struct and
# try to enumerate it.
defp traverse(%Axon.Quantization.QTensor{}, acc), do: [Enum.reverse(acc)]

defp traverse(map, acc) do
Enum.flat_map(map, fn {k, value} ->
traverse(value, [k | acc])
Expand Down Expand Up @@ -273,6 +275,10 @@ defmodule Axon.ModelState do
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

%Axon.Quantization.QTensor{} = val_rhs ->
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

val_rhs when is_map(val_lhs) and is_map(val_rhs) ->
updated_val = tree_merge(val_lhs, val_rhs, fun)
Map.put(acc, key, updated_val)
Expand Down Expand Up @@ -321,6 +327,11 @@ defmodule Axon.ModelState do
{_, %Nx.Tensor{} = tensor}, {count, size} ->
{count + Nx.size(tensor), size + Nx.byte_size(tensor)}

{_, %Axon.Quantization.QTensor{value: value, scale: scale, zero_point: zero}},
{count, size} ->
{count + Nx.size(value) + Nx.size(scale) + Nx.size(zero),
size + Nx.byte_size(value) + Nx.byte_size(scale) + Nx.byte_size(zero)}

{_, map}, {count, size} ->
{inner_count, inner_size} = get_param_info(map)
{count + inner_count, size + inner_size}
Expand Down
Loading

0 comments on commit e784c1d

Please sign in to comment.