From d10f24ff935f2ee9419931d5f72b4660bdd9d4f1 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 12 Jul 2024 14:24:32 -0400 Subject: [PATCH 1/3] Quantization draft --- lib/axon.ex | 27 +++ lib/axon/quantization.ex | 285 ++++++++++++++++++++++++++++++++ lib/axon/quantization/layers.ex | 18 ++ 3 files changed, 330 insertions(+) create mode 100644 lib/axon/quantization.ex create mode 100644 lib/axon/quantization/layers.ex diff --git a/lib/axon.ex b/lib/axon.ex index b2e03523..68253563 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -4096,6 +4096,33 @@ defmodule Axon do end end + @doc """ + Returns a mapping of layer names to layer properties. + """ + def properties(%Axon{output: id, nodes: nodes}) do + {_, _, properties} = node_properties(id, nodes, {%{}, %{}, %{}}) + properties + end + + defp node_properties(id, nodes, {cache, op_counts, properties} = acc) do + case cache do + %{^id => _} -> + {cache, op_counts, properties} + + %{} -> + %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id] + + {cache, op_counts, properties} = + Enum.reduce(parents, acc, &node_properties(&1, nodes, &2)) + + name = name_fn.(op_name, op_counts) + op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end) + properties = Map.put(properties, name, op_name) + + {Map.put(cache, id, name), op_counts, properties} + end + end + ## Helpers @valid_initializers [:zeros, :ones, :uniform, :normal, :identity] ++ diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex new file mode 100644 index 00000000..7abb4206 --- /dev/null +++ b/lib/axon/quantization.ex @@ -0,0 +1,285 @@ +defmodule Axon.Quantization do + alias Axon.Quantization.Layers + + ## Transformation + + def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do + quantized_model = rewrite_dense(model) + quantized_model_state = quantize_dense_layers(model, model_state) + {quantized_model, quantized_model_state} + end + + defp rewrite_dense(%Axon{} = model) do + # TODO: Make this easier + Axon.map_nodes(model, fn + %{op_name: :dense, args: args, parameters: parameters} = axon_node -> + scales = Axon.param("scales", &quantized_dense_scale/1, initializer: :zeros, kind: :state) + + %{ + axon_node + | op_name: :weight_only_quantized_dense, + op: &Layers.weight_only_quantized_dense/5, + args: args ++ [:parameter], + parameters: parameters ++ [scales] + } + + axon_node -> + axon_node + end) + end + + defp quantize_dense_layers(model, model_state) do + # TODO: Make these updates easier + dense_layer_names = + model + |> Axon.properties() + |> Enum.filter(fn {_, v} -> v == :dense end) + |> Enum.map(fn {k, _} -> k end) + + Enum.reduce(dense_layer_names, model_state, fn layer_name, state -> + state + |> update_in([Access.key!(:data), layer_name], fn params -> + quantize_dense_params(params) + end) + |> update_in([Access.key!(:state), layer_name], fn _ -> + ["scales"] + end) + end) + end + + defp quantize_dense_params(%{"kernel" => dense_kernel, "bias" => dense_bias}) do + transposed_kernel = Nx.transpose(dense_kernel) + + {quant_kernel, scales, _zero} = + dynamically_quantize_per_channel(transposed_kernel, -128, 127, {:s, 8}) + + %{ + "kernel" => Nx.transpose(quant_kernel), + "bias" => dense_bias, + "scales" => scales + } + end + + ## Quantizers + + def dynamically_quantize_per_channel(%Nx.Tensor{} = x, quant_min, quant_max, target_dtype) do + unless Nx.rank(x) == 2, do: raise("expected 2d tensor") + + eps = Nx.Constants.epsilon(:f32) + block_size = {1, Nx.axis_size(x, 1)} + zero_point_dtype = {:s, 64} + + {scale, zero_point} = + 
choose_quantization_params_affine(x, :symmetric, block_size, target_dtype, + quant_min: quant_min, + quant_max: quant_max, + eps: eps, + zero_point_dtype: zero_point_dtype + ) + + quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max) + + {quant, scale, zero_point} + end + + def quantize_affine( + input, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + opts \\ [] + ) do + opts = Keyword.validate!(opts, zero_point_domain: :int) + zero_point_domain = opts[:zero_point_domain] + + {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) + + original_shape = Nx.shape(input) + input = Nx.reshape(input, shape_for_reduction) + + scale_shape = + Enum.reduce(reduction_dims, shape_for_reduction, fn i, shape -> + put_elem(shape, i, 1) + end) + + scale = Nx.reshape(scale, scale_shape) + zero_point = Nx.reshape(zero_point, scale_shape) + + quant = + case zero_point_domain do + :int -> + Nx.clip( + Nx.add(Nx.round(Nx.multiply(input, Nx.divide(1, scale))), zero_point), + quant_min, + quant_max + ) + + other -> + raise "unsupported zero point domain #{other}" + end + + Nx.as_type(Nx.reshape(quant, original_shape), target_dtype) + end + + def choose_quantization_params_affine( + input, + mapping_type, + block_size, + target_dtype, + opts \\ [] + ) do + opts = + Keyword.validate!(opts, [ + :quant_min, + :quant_max, + :eps, + :scale_dtype, + :zero_point_dtype, + :zero_point_domain, + preserve_zero: true + ]) + + preserve_zero = opts[:preserve_zero] + + {quant_min, quant_max} = + get_and_check_qmin_qmax(target_dtype, opts[:quant_min], opts[:quant_max]) + + scale_dtype = opts[:scale_dtype] || Nx.type(input) + zero_point_dtype = opts[:zero_point_dtype] || Nx.type(input) + eps = opts[:eps] || Nx.Constants.epsilon(Nx.type(input)) + + {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) + input = Nx.reshape(input, shape_for_reduction) + + min_val = Nx.reduce_min(input, axes: reduction_dims, keep_axes: false) + max_val = Nx.reduce_max(input, axes: reduction_dims, keep_axes: false) + + {min_val_neg, max_val_pos} = + if preserve_zero do + {Nx.min(min_val, Nx.broadcast(0, min_val)), Nx.max(max_val, Nx.broadcast(0, max_val))} + else + {min_val, max_val} + end + + {scale, zero_point} = + case mapping_type do + :symmetric -> + max_val_pos = Nx.max(Nx.negate(min_val_neg), max_val_pos) + scale = Nx.divide(max_val_pos, Nx.divide(Nx.subtract(quant_max, quant_min), 2)) + zero_point = Nx.broadcast(trunc((quant_max + quant_min + 1) / 2), scale) + {scale, zero_point} + + other -> + raise "unsupported mapping #{other}" + end + + scale = Nx.clip(scale, eps, Nx.reduce_max(scale)) + + {Nx.as_type(scale, scale_dtype), Nx.as_type(zero_point, zero_point_dtype)} + end + + def get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) do + {lower_bound, upper_bound} = + case target_dtype do + {:u, 8} -> {0, 255} + {:s, 8} -> {-128, 127} + {:s, 16} -> {-(2 ** 15), 2 ** 15 - 1} + {:s, 32} -> {-(2 ** 31), 2 ** 31 - 1} + end + + quant_min = + cond do + quant_min == nil -> + lower_bound + + quant_min < lower_bound -> + raise "quant_min out of bounds for target_dtype" + + true -> + quant_min + end + + quant_max = + cond do + quant_max == nil -> + upper_bound + + quant_max > upper_bound -> + raise "quant_max out of bounds for target_dtype" + + true -> + quant_max + end + + {quant_min, quant_max} + end + + def get_reduction_params(block_size, input_size) do + if tuple_size(block_size) != 
tuple_size(input_size) do + raise "block_size and input_size must have the same length" + end + + {shape_for_reduction, reduction_dims, _} = + block_size + |> Tuple.to_list() + |> Enum.zip(Tuple.to_list(input_size)) + |> Enum.with_index() + |> Enum.reduce({[], [], 0}, fn {{block, input}, i}, {shape, dims, cur_dim} -> + if block != input and block > 1 do + unless rem(input, block) == 0 do + raise "Expecting input size at #{i} dimension: #{input} to be divisible by block_size at #{i} dimension: #{block}" + end + + shape = [block, div(input, block) | shape] + dims = [cur_dim + 1 | dims] + cur_dim = cur_dim + 2 + + {shape, dims, cur_dim} + else + shape = [input | shape] + dims = if block != 1, do: [cur_dim | dims], else: dims + cur_dim = cur_dim + 1 + + {shape, dims, cur_dim} + end + end) + + {List.to_tuple(Enum.reverse(shape_for_reduction)), Enum.reverse(reduction_dims)} + end + + ## Layers + + def weight_only_quantized_dense(input, units, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :name, + :meta, + kernel_initializer: :glorot_uniform, + bias_initializer: :zeros + ]) + + kernel_shape = &Axon.Shape.dense_kernel(&1, units) + bias_shape = &Axon.Shape.dense_bias(&1, units) + scales_shape = &quantized_dense_scale/1 + + kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) + bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer]) + # TODO: This requires dependent initializers + scales = Axon.param("scales", scales_shape, initializer: :zeros) + + Axon.layer(&Layers.weight_only_quantized_dense/5, [input, kernel, bias, scales], + meta: opts[:meta], + name: opts[:name], + op_name: :weight_only_quantized_dense + ) + end + + defp quantized_dense_scale(input_shape) do + Nx.axis_size(input_shape, -1) + end + + ## Quantizers +end diff --git a/lib/axon/quantization/layers.ex b/lib/axon/quantization/layers.ex new file mode 100644 index 00000000..93c943c0 --- /dev/null +++ b/lib/axon/quantization/layers.ex @@ -0,0 +1,18 @@ +defmodule Axon.Quantization.Layers do + @moduledoc """ + Quantized Layer Implementations. 
+ """ + + import Nx.Defn + + # TODO: Make this more general + + defn weight_only_quantized_dense(x, kernel, bias, scales, _opts \\ []) do + # TODO: Flatten x if necessary + + x + |> Nx.dot(Nx.as_type(kernel, Nx.type(x))) + |> Nx.multiply(scales) + |> Nx.add(bias) + end +end From e784c1d9efc333db014d443b3873e3dec217c7c3 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Wed, 24 Jul 2024 11:26:53 -0400 Subject: [PATCH 2/3] Finish initial quantization API --- lib/axon.ex | 10 +- lib/axon/model_state.ex | 11 + lib/axon/quantization.ex | 341 +++++++++--------------------- lib/axon/quantization/layers.ex | 39 +++- lib/axon/quantization/q_tensor.ex | 233 ++++++++++++++++++++ lib/axon/shared.ex | 6 - test/axon/quantization_test.exs | 45 ++++ 7 files changed, 431 insertions(+), 254 deletions(-) create mode 100644 lib/axon/quantization/q_tensor.ex create mode 100644 test/axon/quantization_test.exs diff --git a/lib/axon.ex b/lib/axon.ex index 68253563..28b57621 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -855,6 +855,12 @@ defmodule Axon do use_bias: true ]) + meta = + opts[:meta] || + %{} + |> Map.put(:units, units) + |> Map.put(:use_bias, opts[:use_bias]) + kernel_shape = &Axon.Shape.dense_kernel(&1, units) bias_shape = &Axon.Shape.dense_bias(&1, units) @@ -868,7 +874,7 @@ defmodule Axon do {[x, kernel], :dense} end - node = layer(op, inputs, name: opts[:name], meta: opts[:meta], op_name: :dense) + node = layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense) if activation = opts[:activation] do activation(node, activation) @@ -3666,7 +3672,7 @@ defmodule Axon do """ @doc type: :graph def get_op_counts(%Axon{} = axon) do - reduce_nodes(axon, %{}, fn %Axon.Node{op: op}, op_counts -> + reduce_nodes(axon, %{}, fn %Axon.Node{op_name: op}, op_counts -> Map.update(op_counts, op, 1, fn x -> x + 1 end) end) end diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex index 8eede9a6..44a94291 100644 --- a/lib/axon/model_state.ex +++ b/lib/axon/model_state.ex @@ -206,6 +206,8 @@ defmodule Axon.ModelState do defp traverse(%Nx.Tensor{}, acc), do: [Enum.reverse(acc)] + defp traverse(%Axon.Quantization.QTensor{}, acc), do: [Enum.reverse(acc)] + defp traverse(map, acc) do Enum.flat_map(map, fn {k, value} -> traverse(value, [k | acc]) @@ -273,6 +275,10 @@ defmodule Axon.ModelState do new_val = fun.(key, val_lhs, val_rhs) Map.put(acc, key, new_val) + %Axon.Quantization.QTensor{} = val_rhs -> + new_val = fun.(key, val_lhs, val_rhs) + Map.put(acc, key, new_val) + val_rhs when is_map(val_lhs) and is_map(val_rhs) -> updated_val = tree_merge(val_lhs, val_rhs, fun) Map.put(acc, key, updated_val) @@ -321,6 +327,11 @@ defmodule Axon.ModelState do {_, %Nx.Tensor{} = tensor}, {count, size} -> {count + Nx.size(tensor), size + Nx.byte_size(tensor)} + {_, %Axon.Quantization.QTensor{value: value, scale: scale, zero_point: zero}}, + {count, size} -> + {count + Nx.size(value) + Nx.size(scale) + Nx.size(zero), + size + Nx.byte_size(value) + Nx.byte_size(scale) + Nx.byte_size(zero)} + {_, map}, {count, size} -> {inner_count, inner_size} = get_param_info(map) {count + inner_count, size + inner_size} diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex index 7abb4206..679ca605 100644 --- a/lib/axon/quantization.ex +++ b/lib/axon/quantization.ex @@ -1,285 +1,148 @@ defmodule Axon.Quantization do alias Axon.Quantization.Layers + alias Axon.Quantization.QTensor - ## Transformation + @doc """ + Quantizes a model and a model state. 
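+
+  For example, a rough sketch of the intended flow (assuming a built
+  `model`, an initialized or trained `model_state`, and a matching
+  `input` tensor):
+
+      {quantized_model, quantized_state} =
+        Axon.Quantization.quantize(model, model_state)
+
+      {_init_fn, predict_fn} = Axon.build(quantized_model)
+      predict_fn.(quantized_state, input)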
+
+  Given a model and model state, this function will rewrite all
+  of the dense layers in the model to perform weight-only 8-bit
+  integer versions of the same operation. It will also replace values
+  for all dense kernels in the given model state with quantized
+  tensors.
+  """
   def quantize(%Axon{} = model, %Axon.ModelState{} = model_state) do
-    quantized_model = rewrite_dense(model)
-    quantized_model_state = quantize_dense_layers(model, model_state)
+    quantized_model = quantize_model(model)
+    quantized_model_state = quantize_model_state(model, model_state)
     {quantized_model, quantized_model_state}
   end
 
-  defp rewrite_dense(%Axon{} = model) do
-    # TODO: Make this easier
-    Axon.map_nodes(model, fn
-      %{op_name: :dense, args: args, parameters: parameters} = axon_node ->
-        scales = Axon.param("scales", &quantized_dense_scale/1, initializer: :zeros, kind: :state)
+  @doc """
+  Replaces standard operations with quantized variants.
 
-        %{
-          axon_node
-          | op_name: :weight_only_quantized_dense,
-            op: &Layers.weight_only_quantized_dense/5,
-            args: args ++ [:parameter],
-            parameters: parameters ++ [scales]
-        }
+  The only supported conversion is from a regular dense layer
+  to a weight-only 8-bit integer variant. Note that this only replaces
+  the properties of the model. If you have a pre-trained model state
+  that you wish to quantize, refer to `Axon.Quantization.quantize_model_state/2`.
 
-      axon_node ->
-        axon_node
+  All `:dense` layers in the model are replaced with `Axon.Quantization.weight_only_quantized_dense/3`.
+  """
+  def quantize_model(%Axon{} = model) do
+    quantized_dense_rewriter = fn [%Axon{} = x], _output, units, use_bias ->
+      weight_only_quantized_dense(x, units, use_bias: use_bias)
+    end
+
+    Axon.rewrite_nodes(model, fn
+      %Axon.Node{op: :dense, meta: meta} ->
+        &quantized_dense_rewriter.(&1, &2, meta[:units], meta[:use_bias])
+
+      _ ->
+        :skip
     end)
   end
 
-  defp quantize_dense_layers(model, model_state) do
-    # TODO: Make these updates easier
+  @doc """
+  Returns a quantized model state.
+
+  Given a model and a model state, this function will replace
+  all dense layer kernels with a quantized version of the weight.
+
+  Training is not currently supported, so all quantized layers are
+  automatically frozen.
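+
+  For example, mirroring the test suite (a sketch where `model` contains
+  a single dense layer named "dense_0"):
+
+      quantized_state =
+        Axon.Quantization.quantize_model_state(model, model_state)
+
+      %Axon.Quantization.QTensor{} = quantized_state.data["dense_0"]["kernel"]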
+ """ + def quantize_model_state(model, model_state) do dense_layer_names = model |> Axon.properties() |> Enum.filter(fn {_, v} -> v == :dense end) |> Enum.map(fn {k, _} -> k end) + |> MapSet.new() - Enum.reduce(dense_layer_names, model_state, fn layer_name, state -> - state - |> update_in([Access.key!(:data), layer_name], fn params -> - quantize_dense_params(params) - end) - |> update_in([Access.key!(:state), layer_name], fn _ -> - ["scales"] - end) - end) - end - - defp quantize_dense_params(%{"kernel" => dense_kernel, "bias" => dense_bias}) do - transposed_kernel = Nx.transpose(dense_kernel) - - {quant_kernel, scales, _zero} = - dynamically_quantize_per_channel(transposed_kernel, -128, 127, {:s, 8}) - - %{ - "kernel" => Nx.transpose(quant_kernel), - "bias" => dense_bias, - "scales" => scales - } - end - - ## Quantizers - - def dynamically_quantize_per_channel(%Nx.Tensor{} = x, quant_min, quant_max, target_dtype) do - unless Nx.rank(x) == 2, do: raise("expected 2d tensor") - - eps = Nx.Constants.epsilon(:f32) - block_size = {1, Nx.axis_size(x, 1)} - zero_point_dtype = {:s, 64} - - {scale, zero_point} = - choose_quantization_params_affine(x, :symmetric, block_size, target_dtype, - quant_min: quant_min, - quant_max: quant_max, - eps: eps, - zero_point_dtype: zero_point_dtype - ) - - quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max) - - {quant, scale, zero_point} - end - - def quantize_affine( - input, - block_size, - scale, - zero_point, - target_dtype, - quant_min, - quant_max, - opts \\ [] - ) do - opts = Keyword.validate!(opts, zero_point_domain: :int) - zero_point_domain = opts[:zero_point_domain] - - {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) - - original_shape = Nx.shape(input) - input = Nx.reshape(input, shape_for_reduction) - - scale_shape = - Enum.reduce(reduction_dims, shape_for_reduction, fn i, shape -> - put_elem(shape, i, 1) + state = + Enum.reduce(dense_layer_names, model_state, fn layer_name, state -> + update_in(state, [Access.key!(:data), layer_name, "kernel"], &QTensor.from_tensor/1) end) - scale = Nx.reshape(scale, scale_shape) - zero_point = Nx.reshape(zero_point, scale_shape) - - quant = - case zero_point_domain do - :int -> - Nx.clip( - Nx.add(Nx.round(Nx.multiply(input, Nx.divide(1, scale))), zero_point), - quant_min, - quant_max - ) - - other -> - raise "unsupported zero point domain #{other}" - end - - Nx.as_type(Nx.reshape(quant, original_shape), target_dtype) - end - - def choose_quantization_params_affine( - input, - mapping_type, - block_size, - target_dtype, - opts \\ [] - ) do - opts = - Keyword.validate!(opts, [ - :quant_min, - :quant_max, - :eps, - :scale_dtype, - :zero_point_dtype, - :zero_point_domain, - preserve_zero: true - ]) - - preserve_zero = opts[:preserve_zero] - - {quant_min, quant_max} = - get_and_check_qmin_qmax(target_dtype, opts[:quant_min], opts[:quant_max]) - - scale_dtype = opts[:scale_dtype] || Nx.type(input) - zero_point_dtype = opts[:zero_point_dtype] || Nx.type(input) - eps = opts[:eps] || Nx.Constants.epsilon(Nx.type(input)) - - {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) - input = Nx.reshape(input, shape_for_reduction) - - min_val = Nx.reduce_min(input, axes: reduction_dims, keep_axes: false) - max_val = Nx.reduce_max(input, axes: reduction_dims, keep_axes: false) - - {min_val_neg, max_val_pos} = - if preserve_zero do - {Nx.min(min_val, Nx.broadcast(0, min_val)), Nx.max(max_val, Nx.broadcast(0, 
max_val))} - else - {min_val, max_val} - end - - {scale, zero_point} = - case mapping_type do - :symmetric -> - max_val_pos = Nx.max(Nx.negate(min_val_neg), max_val_pos) - scale = Nx.divide(max_val_pos, Nx.divide(Nx.subtract(quant_max, quant_min), 2)) - zero_point = Nx.broadcast(trunc((quant_max + quant_min + 1) / 2), scale) - {scale, zero_point} - - other -> - raise "unsupported mapping #{other}" - end - - scale = Nx.clip(scale, eps, Nx.reduce_max(scale)) - - {Nx.as_type(scale, scale_dtype), Nx.as_type(zero_point, zero_point_dtype)} + Axon.ModelState.freeze(state, fn [name | _] -> + MapSet.member?(dense_layer_names, name) + end) end - def get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) do - {lower_bound, upper_bound} = - case target_dtype do - {:u, 8} -> {0, 255} - {:s, 8} -> {-128, 127} - {:s, 16} -> {-(2 ** 15), 2 ** 15 - 1} - {:s, 32} -> {-(2 ** 31), 2 ** 31 - 1} - end - - quant_min = - cond do - quant_min == nil -> - lower_bound - - quant_min < lower_bound -> - raise "quant_min out of bounds for target_dtype" - - true -> - quant_min - end - - quant_max = - cond do - quant_max == nil -> - upper_bound - - quant_max > upper_bound -> - raise "quant_max out of bounds for target_dtype" - - true -> - quant_max - end - - {quant_min, quant_max} - end + ## Layers - def get_reduction_params(block_size, input_size) do - if tuple_size(block_size) != tuple_size(input_size) do - raise "block_size and input_size must have the same length" - end + @doc """ + Adds a weight-only quantized dense layer to the network. - {shape_for_reduction, reduction_dims, _} = - block_size - |> Tuple.to_list() - |> Enum.zip(Tuple.to_list(input_size)) - |> Enum.with_index() - |> Enum.reduce({[], [], 0}, fn {{block, input}, i}, {shape, dims, cur_dim} -> - if block != input and block > 1 do - unless rem(input, block) == 0 do - raise "Expecting input size at #{i} dimension: #{input} to be divisible by block_size at #{i} dimension: #{block}" - end + This is equivalent to a dense layer, but works on quantized + weights for reducing model memory footprint. - shape = [block, div(input, block) | shape] - dims = [cur_dim + 1 | dims] - cur_dim = cur_dim + 2 + Compiles to `Axon.Quantization.Layers.weight_only_quantized_dense/4`. - {shape, dims, cur_dim} - else - shape = [input | shape] - dims = if block != 1, do: [cur_dim | dims], else: dims - cur_dim = cur_dim + 1 + ## Options - {shape, dims, cur_dim} - end - end) + * `:name` - layer name. - {List.to_tuple(Enum.reverse(shape_for_reduction)), Enum.reverse(reduction_dims)} - end + * `:kernel_initializer` - initializer for `kernel` weights. + Defaults to `:glorot_uniform`. - ## Layers + * `:bias_initializer` - initializer for `bias` weights. Defaults + to `:zeros`. - def weight_only_quantized_dense(input, units, opts \\ []) do + * `:use_bias` - whether the layer should add bias to the output. + Defaults to `true`. 
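+
+  For example, a sketch of using this layer directly when building a
+  model (the input name and unit count below are illustrative):
+
+      model =
+        Axon.input("input")
+        |> Axon.Quantization.weight_only_quantized_dense(64)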
+  """
+  def weight_only_quantized_dense(x, units, opts \\ []) do
     opts =
       Keyword.validate!(opts, [
         :name,
         :meta,
+        use_bias: true,
         kernel_initializer: :glorot_uniform,
         bias_initializer: :zeros
       ])
 
+    meta =
+      opts[:meta] ||
+        %{}
+        |> Map.put(:units, units)
+        |> Map.put(:use_bias, opts[:use_bias])
+
     kernel_shape = &Axon.Shape.dense_kernel(&1, units)
     bias_shape = &Axon.Shape.dense_bias(&1, units)
-    scales_shape = &quantized_dense_scale/1
 
-    kernel = Axon.param("kernel", kernel_shape, initializer: opts[:kernel_initializer])
-    bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
-    # TODO: This requires dependent initializers
-    scales = Axon.param("scales", scales_shape, initializer: :zeros)
+    kernel =
+      Axon.param("kernel", kernel_shape,
+        initializer: fn shape, type, key ->
+          fun =
+            case opts[:kernel_initializer] do
+              init when is_atom(init) ->
+                apply(Axon.Initializers, init, [])
 
-    Axon.layer(&Layers.weight_only_quantized_dense/5, [input, kernel, bias, scales],
-      meta: opts[:meta],
-      name: opts[:name],
-      op_name: :weight_only_quantized_dense
-    )
-  end
+              fun when is_function(fun) ->
+                fun
+            end
 
-  defp quantized_dense_scale(input_shape) do
-    Nx.axis_size(input_shape, -1)
-  end
+          tensor =
+            case fun do
+              fun when is_function(fun, 2) ->
+                fun.(shape, type)
+
+              fun when is_function(fun, 3) ->
+                fun.(shape, type, key)
+            end
+
+          QTensor.from_tensor(tensor)
+        end
+      )
 
-  ## Quantizers
+    {inputs, op} =
+      if opts[:use_bias] do
+        bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer])
+        {[x, kernel, bias], &Layers.weight_only_quantized_dense/4}
+      else
+        {[x, kernel], &Layers.weight_only_quantized_dense/3}
+      end
+
+    Axon.layer(op, inputs, name: opts[:name], meta: meta, op_name: :dense)
+  end
 end
diff --git a/lib/axon/quantization/layers.ex b/lib/axon/quantization/layers.ex
index 93c943c0..d7613939 100644
--- a/lib/axon/quantization/layers.ex
+++ b/lib/axon/quantization/layers.ex
@@ -2,17 +2,42 @@ defmodule Axon.Quantization.Layers do
   @moduledoc """
   Quantized Layer Implementations.
   """
-
+  alias Axon.Quantization.QTensor
   import Nx.Defn
 
-  # TODO: Make this more general
+  @doc """
+  Weight-only quantized version of a dense layer.
+
+  It expects the input kernel to be an `Axon.Quantization.QTensor`.
+  """
+  deftransform weight_only_quantized_dense(input, kernel, bias \\ 0, opts \\ []) do
+    {bias, opts} =
+      case bias do
+        %Nx.Tensor{} = bias ->
+          {bias, opts}
+
+        bias when is_number(bias) ->
+          {bias, opts}
 
-  defn weight_only_quantized_dense(x, kernel, bias, scales, _opts \\ []) do
-    # TODO: Flatten x if necessary
+        opts when is_list(opts) ->
+          {Nx.tensor(0), opts}
+
+        other ->
+          raise ArgumentError, "invalid bias, expected a tensor, got #{inspect(other)}"
+      end
+
+    weight_only_quantized_dense_impl(input, kernel, bias, opts)
+  end
 
-    x
-    |> Nx.dot(Nx.as_type(kernel, Nx.type(x)))
-    |> Nx.multiply(scales)
+  defnp weight_only_quantized_dense_impl(
+          input,
+          %QTensor{value: kernel, scale: scale},
+          bias,
+          _opts
+        ) do
+    input
+    |> Nx.dot([Nx.rank(input) - 1], Nx.as_type(kernel, Nx.type(input)), [0])
+    |> Nx.multiply(scale)
     |> Nx.add(bias)
   end
 end
diff --git a/lib/axon/quantization/q_tensor.ex b/lib/axon/quantization/q_tensor.ex
new file mode 100644
index 00000000..a7f7749d
--- /dev/null
+++ b/lib/axon/quantization/q_tensor.ex
@@ -0,0 +1,233 @@
+defmodule Axon.Quantization.QTensor do
+  @moduledoc """
+  Representation of a quantized tensor.
+
+  A quantized tensor stores information about the quantized
+  value, scale, and zero-point.
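+  Under the symmetric `{:s, 8}` scheme used here, the zero-point is
+  always 0, so `Nx.multiply(value, scale)` approximately recovers the
+  original tensor.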
This module contains lower-level + functions for converting to and from quantized tensors. + + In most cases, you should prefer to use the public APIs in + `Axon.Quantization`. + """ + import Nx.Defn + + @derive {Nx.Container, containers: [:value, :scale, :zero_point]} + defstruct [:value, :scale, :zero_point] + + @doc """ + Converts a regular float tensor into a quantized tensor. + """ + deftransform from_tensor(x, opts \\ []) do + opts = Keyword.validate!(opts, type: {:s, 8}) + + case opts[:type] do + {:s, 8} -> + dynamically_quantize_per_channel(x, min: -128, max: 127, type: {:s, 8}) + + other -> + raise "unsupported quantization type #{inspect(other)}" + end + end + + deftransformp dynamically_quantize_per_channel(input, opts \\ []) do + opts = Keyword.validate!(opts, [:min, :max, :type]) + + unless Nx.type(input) == {:f, 32}, do: raise(ArgumentError, "expected a float tensor") + unless Nx.rank(input) == 2, do: raise(ArgumentError, "expected a 2d tensor") + + target_dtype = opts[:type] + eps = Nx.Constants.epsilon(:f32) + block_size = {1, Nx.axis_size(input, 1)} + zero_point_type = {:s, 64} + + {scale, zero_point} = + choose_quantization_params_affine(input, + mapping_type: :symmetric, + block_size: block_size, + type: opts[:type], + min: opts[:min], + max: opts[:max], + eps: eps, + zero_point_type: zero_point_type + ) + + quantized_value = + quantize_affine(input, scale, zero_point, + block_size: block_size, + type: target_dtype, + min: opts[:min], + max: opts[:max] + ) + + struct(__MODULE__, value: quantized_value, scale: scale, zero_point: zero_point) + end + + deftransformp quantize_affine( + input, + scale, + zero_point, + opts \\ [] + ) do + opts = Keyword.validate!(opts, [:block_size, :type, :min, :max, zero_point_domain: :int]) + + target_dtype = opts[:type] + quant_min = opts[:min] + quant_max = opts[:max] + block_size = opts[:block_size] + zero_point_domain = opts[:zero_point_domain] + + {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) + + original_shape = Nx.shape(input) + input = Nx.reshape(input, shape_for_reduction) + + scale_shape = + Enum.reduce(reduction_dims, shape_for_reduction, fn i, shape -> + put_elem(shape, i, 1) + end) + + scale = Nx.reshape(scale, scale_shape) + zero_point = Nx.reshape(zero_point, scale_shape) + + quant = + case zero_point_domain do + :int -> + Nx.clip( + Nx.add(Nx.round(Nx.multiply(input, Nx.divide(1, scale))), zero_point), + quant_min, + quant_max + ) + + other -> + raise "unsupported zero point domain #{other}" + end + + Nx.as_type(Nx.reshape(quant, original_shape), target_dtype) + end + + deftransformp choose_quantization_params_affine(input, opts \\ []) do + opts = + Keyword.validate!(opts, [ + :mapping_type, + :block_size, + :type, + :min, + :max, + :eps, + :scale_type, + :zero_point_type, + :zero_point_domain, + preserve_zero: true + ]) + + mapping_type = opts[:mapping_type] + block_size = opts[:block_size] + target_dtype = opts[:type] + preserve_zero = opts[:preserve_zero] + + {quant_min, quant_max} = + get_and_check_qmin_qmax(target_dtype, opts[:min], opts[:max]) + + scale_dtype = opts[:scale_type] || Nx.type(input) + zero_point_dtype = opts[:zero_point_type] || Nx.type(input) + eps = opts[:eps] || Nx.Constants.epsilon(Nx.type(input)) + + {shape_for_reduction, reduction_dims} = get_reduction_params(block_size, Nx.shape(input)) + input = Nx.reshape(input, shape_for_reduction) + + min_val = Nx.reduce_min(input, axes: reduction_dims, keep_axes: false) + max_val = Nx.reduce_max(input, axes: 
reduction_dims, keep_axes: false) + + {min_val_neg, max_val_pos} = + if preserve_zero do + {Nx.min(min_val, Nx.broadcast(0, min_val)), Nx.max(max_val, Nx.broadcast(0, max_val))} + else + {min_val, max_val} + end + + {scale, zero_point} = + case mapping_type do + :symmetric -> + max_val_pos = Nx.max(Nx.negate(min_val_neg), max_val_pos) + scale = Nx.divide(max_val_pos, Nx.divide(Nx.subtract(quant_max, quant_min), 2)) + zero_point = Nx.broadcast(trunc((quant_max + quant_min + 1) / 2), scale) + {scale, zero_point} + + other -> + raise "unsupported mapping #{other}" + end + + scale = Nx.clip(scale, eps, Nx.reduce_max(scale)) + + {Nx.as_type(scale, scale_dtype), Nx.as_type(zero_point, zero_point_dtype)} + end + + deftransformp get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) do + {lower_bound, upper_bound} = + case target_dtype do + {:u, 8} -> {0, 255} + {:s, 8} -> {-128, 127} + {:s, 16} -> {-(2 ** 15), 2 ** 15 - 1} + {:s, 32} -> {-(2 ** 31), 2 ** 31 - 1} + end + + quant_min = + cond do + quant_min == nil -> + lower_bound + + quant_min < lower_bound -> + raise "quant_min out of bounds for target_dtype" + + true -> + quant_min + end + + quant_max = + cond do + quant_max == nil -> + upper_bound + + quant_max > upper_bound -> + raise "quant_max out of bounds for target_dtype" + + true -> + quant_max + end + + {quant_min, quant_max} + end + + deftransformp get_reduction_params(block_size, input_size) do + if tuple_size(block_size) != tuple_size(input_size) do + raise "block_size and input_size must have the same length" + end + + {shape_for_reduction, reduction_dims, _} = + block_size + |> Tuple.to_list() + |> Enum.zip(Tuple.to_list(input_size)) + |> Enum.with_index() + |> Enum.reduce({[], [], 0}, fn {{block, input}, i}, {shape, dims, cur_dim} -> + if block != input and block > 1 do + unless rem(input, block) == 0 do + raise "Expecting input size at #{i} dimension: #{input} to be divisible by block_size at #{i} dimension: #{block}" + end + + shape = [block, div(input, block) | shape] + dims = [cur_dim + 1 | dims] + cur_dim = cur_dim + 2 + + {shape, dims, cur_dim} + else + shape = [input | shape] + dims = if block != 1, do: [cur_dim | dims], else: dims + cur_dim = cur_dim + 1 + + {shape, dims, cur_dim} + end + end) + + {List.to_tuple(Enum.reverse(shape_for_reduction)), Enum.reverse(reduction_dims)} + end +end diff --git a/lib/axon/shared.ex b/lib/axon/shared.ex index 6279488a..87eff5ae 100644 --- a/lib/axon/shared.ex +++ b/lib/axon/shared.ex @@ -192,9 +192,6 @@ defmodule Axon.Shared do defp recur_deep_reduce(value, acc, fun) do case value do - %Axon{} = val -> - fun.(val, acc) - %Nx.Tensor{} = val -> fun.(val, acc) @@ -217,9 +214,6 @@ defmodule Axon.Shared do defp recur_deep_map_reduce(leaf, acc, fun) do case leaf do - %Axon{} = leaf -> - fun.(leaf, acc) - %Nx.Tensor{} = leaf -> fun.(leaf, acc) diff --git a/test/axon/quantization_test.exs b/test/axon/quantization_test.exs new file mode 100644 index 00000000..4a289ce0 --- /dev/null +++ b/test/axon/quantization_test.exs @@ -0,0 +1,45 @@ +defmodule Axon.QuantizationTest do + use Axon.Case, async: true + + alias Axon.ModelState + alias Axon.Quantization.QTensor + + describe "quantize_model_state" do + test "replaces dense kernels with quantized versions" do + model = + Axon.input("input") + |> Axon.dense(10, activation: :relu) + + assert {init_fn, _} = Axon.build(model) + assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty()) + + assert %{data: %{"dense_0" => %{"kernel" => %QTensor{}}}} = + 
Axon.Quantization.quantize_model_state(model, model_state)
+    end
+  end
+
+  describe "quantize" do
+    test "returns model and state that execute properly" do
+      model =
+        Axon.input("input")
+        |> Axon.dense(10, activation: :relu)
+
+      assert {init_fn, _} = Axon.build(model)
+      assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty())
+
+      assert {quantized_model, quantized_model_state} =
+               Axon.Quantization.quantize(model, model_state)
+
+      assert {_, predict_fn} = Axon.build(quantized_model)
+
+      real_fn = fn %{data: %{"dense_0" => %{"kernel" => k, "bias" => b}}}, input ->
+        input
+        |> Axon.Quantization.Layers.weight_only_quantized_dense(k, b)
+        |> Axon.Activations.relu()
+      end
+
+      inp = Nx.broadcast(1.0, {1, 1})
+      assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp))
+    end
+  end
+end
From 60775c0f0e2e563c3845325dfb563e97867c85d1 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Wed, 24 Jul 2024 11:30:10 -0400
Subject: [PATCH 3/3] Docs

---
 lib/axon/quantization.ex | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex
index 679ca605..ef11dee3 100644
--- a/lib/axon/quantization.ex
+++ b/lib/axon/quantization.ex
@@ -1,4 +1,17 @@
 defmodule Axon.Quantization do
+  @moduledoc """
+  Model quantization.
+
+  Model quantization is a technique for reducing the memory footprint of
+  a model by converting portions of a model to use quantized representations.
+  Typically, these quantized representations are low-precision integers.
+
+  This is an **experimental** API which implements weight-only quantization.
+  The implementation in this module will convert dense layers in a large
+  model to quantized variants. The only supported quantization type is
+  `{:s, 8}`. Axon quantization is inference-only. Training is not currently
+  supported.
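+
+  A sketch of the end-to-end flow, mirroring the test suite (the model
+  shape and input below are illustrative):
+
+      model =
+        Axon.input("input")
+        |> Axon.dense(10, activation: :relu)
+
+      {init_fn, _} = Axon.build(model)
+      model_state = init_fn.(Nx.template({1, 1}, :f32), Axon.ModelState.empty())
+
+      {quantized_model, quantized_state} =
+        Axon.Quantization.quantize(model, model_state)
+
+      {_, predict_fn} = Axon.build(quantized_model)
+      predict_fn.(quantized_state, Nx.broadcast(1.0, {1, 1}))
+  """
   alias Axon.Quantization.Layers
   alias Axon.Quantization.QTensor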