Add simple quantization API #586

Merged 3 commits on Jul 25, 2024
37 changes: 35 additions & 2 deletions lib/axon.ex
@@ -855,6 +855,12 @@ defmodule Axon do
use_bias: true
])

meta =
opts[:meta] ||
%{}
|> Map.put(:units, units)
|> Map.put(:use_bias, opts[:use_bias])

kernel_shape = &Axon.Shape.dense_kernel(&1, units)
bias_shape = &Axon.Shape.dense_bias(&1, units)

@@ -868,7 +874,7 @@ defmodule Axon do
{[x, kernel], :dense}
end

-    node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :dense)
+    node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)

if activation = opts[:activation] do
activation(node, activation)
@@ -3666,7 +3672,7 @@ defmodule Axon do
"""
@doc type: :graph
def get_op_counts(%Axon{} = axon) do
-    reduce_nodes(axon, %{}, fn %Axon.Node{op: op}, op_counts ->
+    reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts ->
Map.update(op_counts, op, 1, fn x -> x + 1 end)
end)
end
@@ -4096,6 +4102,33 @@ defmodule Axon do
end
end

@doc """
Returns a mapping of layer names to layer properties.
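
  ## Examples

  A sketch of the returned mapping; the input and layer names here are
  illustrative, not taken from this PR:

      model =
        Axon.input("features")
        |> Axon.dense(32, name: "hidden")
        |> Axon.dense(1, name: "output")

      Axon.properties(model)
      #=> %{"features" => :input, "hidden" => :dense, "output" => :dense}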
"""
def properties(%Axon{output: id, nodes: nodes}) do
{_, _, properties} = node_properties(id, nodes, {%{}, %{}, %{}})
properties
end

defp node_properties(id, nodes, {cache, op_counts, properties} = acc) do
case cache do
%{^id => _} ->
{cache, op_counts, properties}

%{} ->
%Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id]

{cache, op_counts, properties} =
Enum.reduce(parents, acc, &node_properties(&1, nodes, &2))

name = name_fn.(op_name, op_counts)
op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)
properties = Map.put(properties, name, op_name)

{Map.put(cache, id, name), op_counts, properties}
end
end

## Helpers

@valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++
11 changes: 11 additions & 0 deletions lib/axon/model_state.ex
@@ -206,6 +206,8 @@ defmodule Axon.ModelState do

defp traverse(%Nx.Tensor{}, acc), do: [Enum.reverse(acc)]

defp traverse(%Axon.Quantization.QTensor{}, acc), do: [Enum.reverse(acc)]

defp traverse(map, acc) do
Enum.flat_map(map, fn {k, value} ->
traverse(value, [k | acc])
@@ -273,6 +275,10 @@ defmodule Axon.ModelState do
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

%Axon.Quantization.QTensor{} = val_rhs ->
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

val_rhs when is_map(val_lhs) and is_map(val_rhs) ->
updated_val = tree_merge(val_lhs, val_rhs, fun)
Map.put(acc, key, updated_val)
@@ -321,6 +327,11 @@ defmodule Axon.ModelState do
{_, %Nx.Tensor{} = tensor}, {count, size} ->
{count + Nx.size(tensor), size + Nx.byte_size(tensor)}

{_, %Axon.Quantization.QTensor{value: value, scale: scale, zero_point: zero}},
{count, size} ->
{count + Nx.size(value) + Nx.size(scale) + Nx.size(zero),
size + Nx.byte_size(value) + Nx.byte_size(scale) + Nx.byte_size(zero)}

{_, map}, {count, size} ->
{inner_count, inner_size} = get_param_info(map)
{count + inner_count, size + inner_size}
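Taken together, these three additions make `Axon.ModelState` treat a
`QTensor` as a leaf: traversal stops at it, `tree_merge/3` hands it to the
merge function like a plain tensor, and parameter counting sums its
component tensors. As a rough size check (assuming a single per-tensor
scale and zero point): a `{256, 256}` kernel stored as `{:s, 8}`
contributes 256 * 256 = 65,536 elements and about 64 KiB through `:value`,
plus a few bytes for `:scale` and `:zero_point`, versus 256 KiB for the
same kernel in `{:f, 32}`.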
161 changes: 161 additions & 0 deletions lib/axon/quantization.ex
@@ -0,0 +1,161 @@
defmodule Axon.Quantization do
@moduledoc """
Model quantization.

Model quantization is a technique for reducing the memory footprint of
a model by converting portions of a model to use quantized representations.
Typically, these quantized representations are low-precision integers.

This is an **experimental** API which implements weight-only quantization.
The implementation in this module will convert dense layers in a large
model to quantized variants. The only supported quantization type is
`{:s, 8}`. Axon quantization is inference-only. Training is not currently
supported.
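
The expected workflow, as a sketch (`model`, `model_state`, and `input`
are placeholders for your own trained model and data):

    {q_model, q_state} = Axon.Quantization.quantize(model, model_state)

    {_init_fn, predict_fn} = Axon.build(q_model)
    predict_fn.(q_state, input)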
"""
alias Axon.Quantization.Layers
alias Axon.Quantization.QTensor

@doc """
Quantizes a model and a model state.

Given a model and model state, this function rewrites all
of the dense layers in the model to use weight-only 8-bit
integer versions of the same operation. It also replaces the
values of all dense kernels in the given model state with
quantized tensors.
"""
def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
quantized_model = quantize_model(model)
quantized_model_state = quantize_model_state(model, model_state)
{quantized_model, quantized_model_state}
end

@doc """
Replaces standard operations with quantized variants.

The only supported conversion is to convert regular dense layers
to a weight-only 8-bit integer variant. Note that this only replaces
the properties of the model. If you have a pre-trained model state
that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/2`.

All `:dense` layers in the model are replaced with `Axon.Quantization.weight_only_quantized_dense/3`.
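
A sketch of the rewrite; the replacement layer keeps `op_name: :dense`,
so graph introspection such as `Axon.get_op_counts/1` reports the same
counts before and after:

    model = Axon.input("x") |> Axon.dense(4)
    q_model = Axon.Quantization.quantize_model(model)

    Axon.get_op_counts(q_model)
    #=> %{input: 1, dense: 1}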
"""
def quantize_model(%Axon{} = model) do
quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
weight_only_quantized_dense(x, units, use_bias: use_bias)
end

Axon.rewrite_nodes(model, fn
%Axon.Node{op: :dense, meta: meta} ->
&quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])

_ ->
:skip
end)
end

@doc """
Returns a quantized model state.

Given a model and a model state, this function will replace
all dense layer kernels with a quantized version of the weight.

Training is not currently supported, so all quantized layers are
automatically frozen.
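
A sketch of the effect, where `"dense_0"` stands in for whatever your
dense layers are named:

    q_state = Axon.Quantization.quantize_model_state(model, model_state)

    %Axon.Quantization.QTensor{} = get_in(q_state.data, ["dense_0", "kernel"])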
"""
def quantize_model_state(model, model_state) do
dense_layer_names =
model
|> Axon.properties()
|> Enum.filter(fn {_, v} -> v == :dense end)
|> Enum.map(fn {k, _} -> k end)
|> MapSet.new()

state =
Enum.reduce(dense_layer_names, model_state, fn layer_name, state ->
update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1)
end)

Axon.ModelState.freeze(state, fn [name | _] ->
MapSet.member?(dense_layer_names, name)
end)
end

## Layers

@doc """
Adds a weight-only quantized dense layer to the network.

This is equivalent to a dense layer, but it operates on quantized
weights, reducing the model's memory footprint.

Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`.

## Options

* `:name` - layer name.

* `:kernel_initializer` - initializer for `kernel` weights.
Defaults to `:glorot_uniform`.

* `:bias_initializer` - initializer for `bias` weights. Defaults
to `:zeros`.

* `:use_bias` - whether the layer should add bias to the output.
Defaults to `true`.
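
  A usage sketch:

      model =
        Axon.input("features")
        |> Axon.Quantization.weight_only_quantized_dense(64, use_bias: false)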
"""
def weight_only_quantized_dense(x, units, opts \\ []) do
opts =
Keyword.validate!(opts, [
:name,
:meta,
use_bias: true,
kernel_initializer: :glorot_uniform,
bias_initializer: :zeros
])

meta =
opts[:meta] ||
%{}
|> Map.put(:units, units)
|> Map.put(:use_bias, opts[:use_bias])

kernel_shape = &Axon.Shape.dense_kernel(&1, units)
bias_shape = &Axon.Shape.dense_bias(&1, units)

kernel =
Axon.param("kernel", kernel_shape,
initializer: fn shape, type, key ->
fun =
case opts[:kernel_initializer] do
init when is_atom(init) ->
              apply(Axon.Initializers, init, [])

fun when is_function(fun) ->
fun
end

tensor =
case fun do
fun when is_function(fun, 2) ->
fun.(shape, type)

fun when is_function(fun, 3) ->
fun.(shape, type, key)
end

QTensor.from_tensor(tensor)
end
)

{inputs, op} =
if opts[:use_bias] do
bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
{[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
else
{[x, kernel], &Layers.weight_only_quantized_dense/3}
end

Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
end
end
43 changes: 43 additions & 0 deletions lib/axon/quantization/layers.ex
@@ -0,0 +1,43 @@
defmodule Axon.Quantization.Layers do
@moduledoc """
Quantized Layer Implementations.
"""
alias Axon.Quantization.QTensor
import Nx.Defn

@doc """
Weight-only quantized version of a dense layer.

It expects the input kernel to be an `Axon.Quantization.QTensor`.
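
  A direct-call sketch with illustrative shapes (a `{2, 4}` input against
  a quantized `{4, 3}` kernel):

      input = Nx.iota({2, 4}, type: :f32)
      kernel = Axon.Quantization.QTensor.from_tensor(Nx.broadcast(0.5, {4, 3}))

      Axon.Quantization.Layers.weight_only_quantized_dense(input, kernel)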
"""
deftransform weight_only_quantized_dense(input, kernel, bias \\ 0, opts \\ []) do
{bias, opts} =
case bias do
%Nx.Tensor{} = bias ->
{bias, opts}

bias when is_number(bias) ->
{bias, opts}

opts when is_list(opts) ->
{Nx.tensor(0), opts}

other ->
raise ArgumentError, "invalid bias, expected a tensor, got #{inspect(other)}"
end

weight_only_quantized_dense_impl(input, kernel, bias, opts)
end

defnp weight_only_quantized_dense_impl(
input,
%QTensor{value: kernel, scale: scale},
bias,
_opts
) do
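    # weight-only scheme: contract the input's last axis with the int8
    # kernel cast to the input type, rescale by the quantization scale,
    # then add the bias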
input
|> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0])
|> Nx.multiply(scale)
|> Nx.add(bias)
end
end