From 3ac8928ef6641e0ea78f9a5f0524b574a835463e Mon Sep 17 00:00:00 2001 From: zhangqi3 Date: Wed, 8 Sep 2021 16:39:13 +0800 Subject: [PATCH] Update version to 0.0.2. --- README.md | 10 ++ mqbench/__init__.py | 1 + mqbench/adaround.py | 18 +--- mqbench/convert_deploy.py | 21 ++-- mqbench/convert_onnx.py | 105 ++++++++++--------- mqbench/custom_quantizer.py | 180 ++++++++++++++++++++++++++++++--- mqbench/fake_quantize/nnie.py | 3 +- mqbench/observer.py | 42 +++++++- mqbench/prepare_by_platform.py | 60 +++++++---- mqbench/utils/logger.py | 6 +- mqbench/utils/utils.py | 24 ++++- requirements.txt | 1 + test/observer/test_observer.py | 30 ++++-- 13 files changed, 379 insertions(+), 122 deletions(-) create mode 100644 requirements.txt diff --git a/README.md b/README.md index f215531..2bb64ce 100644 --- a/README.md +++ b/README.md @@ -1 +1,11 @@ [MQBench](http://mqbench.tech/assets/docs/html/) + +# Update V0.0.2 + +- Fix academic prepare setting. +- More deployable prepare process. +- Fix setup.py. +- Fix deploy on SNPE. +- Fix convert_deploy bug. +- Add Quantile observer. +- Other update. diff --git a/mqbench/__init__.py b/mqbench/__init__.py index e69de29..d18f409 100644 --- a/mqbench/__init__.py +++ b/mqbench/__init__.py @@ -0,0 +1 @@ +__version__ = '0.0.2' diff --git a/mqbench/adaround.py b/mqbench/adaround.py index 50c4d96..fe38308 100644 --- a/mqbench/adaround.py +++ b/mqbench/adaround.py @@ -1,4 +1,3 @@ -import copy import os import numpy as np from typing import Callable, Dict @@ -8,7 +7,8 @@ import torch.nn.functional as F from torch.fx import GraphModule, Node -from .observer import MinMaxObserver, ObserverBase +from mqbench.observer import MinMaxObserver, ObserverBase +from mqbench.utils import deepcopy_graphmodule _ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear, ) @@ -95,7 +95,7 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128, cali_data = get_cali_samples(train_data, n_samples) # apply rewritten deepcopy of GraphModule - quant_model = _deepcopy_graphmodule(model) + quant_model = deepcopy_graphmodule(model) quant_model.eval() model.eval() @@ -175,18 +175,6 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128, return quant_model -def _deepcopy_graphmodule(gm: GraphModule): - """Rewrite the deepcopy of GraphModule. (Copy its 'graph'.) - - Args: - gm (GraphModule): - - Returns: - GraphModule: A deepcopied gm. - """ - copied_gm = copy.deepcopy(gm) - copied_gm.graph = copy.deepcopy(gm.graph) - return copied_gm def _insert_observer(gm: GraphModule, insert_type="input"): """Insert observers to record the input and output of target layers. 
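Note: the private `_deepcopy_graphmodule` helper removed from adaround.py above is consolidated into `mqbench.utils` (added later in this patch as `deepcopy_graphmodule`), so adaround.py and convert_deploy.py now share one implementation. For reference, the shared helper is:

```python
import copy
from torch.fx import GraphModule

def deepcopy_graphmodule(gm: GraphModule) -> GraphModule:
    """Deepcopy a GraphModule and explicitly deepcopy its 'graph' so the
    copy can be rewritten without mutating the original module's graph."""
    copied_gm = copy.deepcopy(gm)
    copied_gm.graph = copy.deepcopy(gm.graph)
    return copied_gm
```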
diff --git a/mqbench/convert_deploy.py b/mqbench/convert_deploy.py index afde652..75043df 100644 --- a/mqbench/convert_deploy.py +++ b/mqbench/convert_deploy.py @@ -6,6 +6,7 @@ import mqbench.custom_symbolic_opset # noqa: F401 import mqbench.fusion_method # noqa: F401 from mqbench.prepare_by_platform import BackendType +from mqbench.utils import deepcopy_graphmodule from mqbench.utils.logger import logger from mqbench.utils.registry import ( BACKEND_DEPLOY_FUNCTION, @@ -37,12 +38,16 @@ def convert_merge_bn(model: GraphModule, **kwargs): @register_deploy_function(BackendType.PPLW8A16) @register_deploy_function(BackendType.Tensorrt) @register_deploy_function(BackendType.NNIE) -def convert_onnx(model: GraphModule, input_shape_dict, onnx_model_path='./test.onnx', **kwargs): +def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_path, **kwargs): logger.info("Export to onnx.") - device = next(model.parameters()).device - dummy_input = {name: torch.rand(shape).to(device) for name, shape in input_shape_dict.items()} - torch.onnx.export(model, tuple(dummy_input.values()), onnx_model_path, - input_names=list(dummy_input.keys()), + input_names = None + if dummy_input is None: + device = next(model.parameters()).device + dummy_input = {name: torch.rand(shape).to(device) for name, shape in input_shape_dict.items()} + input_names = list(dummy_input.keys()) + dummy_input = tuple(dummy_input.values()) + torch.onnx.export(model, dummy_input, onnx_model_path, + input_names=input_names, opset_version=11, enable_onnx_checker=False) @@ -72,7 +77,7 @@ def deploy_qparams_pplw8a16(model: GraphModule, onnx_model_path, **kwargs): def convert_deploy(model: GraphModule, backend_type: BackendType, - input_shape_dict, output_path='./', + input_shape_dict=None, dummy_input=None, output_path='./', model_name='mqbench_model_quantized.onnx'): r"""Convert model to onnx model and quantization params depends on backend. 
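With the new `dummy_input` parameter, `convert_deploy` can be driven either by an input-shape dict or by a prebuilt input tuple. A usage sketch (model name and shapes are illustrative; `model_prepared` is assumed to be a calibrated GraphModule produced by `prepare_qat_fx_by_platform`):

```python
import torch
from mqbench.convert_deploy import convert_deploy
from mqbench.prepare_by_platform import BackendType

# `model_prepared` is assumed to be a prepared and calibrated GraphModule.
model_prepared.eval()

# Option 1: shapes only -- convert_onnx builds random tensors on the model's device.
convert_deploy(model_prepared, BackendType.Tensorrt,
               input_shape_dict={'x': [1, 3, 224, 224]},
               model_name='resnet18_quantized.onnx')

# Option 2: a ready-made dummy input, forwarded to torch.onnx.export as-is
# (so it must already live on the model's device).
convert_deploy(model_prepared, BackendType.Tensorrt,
               dummy_input=(torch.randn(1, 3, 224, 224),),
               model_name='resnet18_quantized.onnx')
```

Since `convert_deploy` now runs the backend deploy functions on a `deepcopy_graphmodule` copy, the caller's prepared model is left untouched.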
@@ -94,9 +99,11 @@ def forward(self, input_0, input_1): """ kwargs = { 'input_shape_dict': input_shape_dict, + 'dummy_input': dummy_input, 'output_path': output_path, 'model_name': model_name, 'onnx_model_path': osp.join(output_path, model_name) } + deploy_model = deepcopy_graphmodule(model) for convert_function in BACKEND_DEPLOY_FUNCTION[backend_type]: - convert_function(model, **kwargs) + convert_function(deploy_model, **kwargs) diff --git a/mqbench/convert_onnx.py b/mqbench/convert_onnx.py index 1718a60..193e074 100644 --- a/mqbench/convert_onnx.py +++ b/mqbench/convert_onnx.py @@ -3,6 +3,7 @@ import onnx from onnx import numpy_helper import numpy as np +from mqbench.utils.logger import logger perchannel_fakequantizer = ['FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine', 'FakeQuantizeDSQPerchannel'] pertensor_fakequantizer = ['LearnablePerTensorAffine', 'FixedPerTensorAffine', 'FakeQuantizeDSQPertensor'] @@ -70,18 +71,21 @@ def get_constant_inputs(node, out2node): class OnnxPreprocess(object): def replace_resize_op_with_upsample(self, graph, out2node): nodes_to_be_removed = [] - idx = 0 + idx = 0 while idx < len(graph.node): node = graph.node[idx] if node.op_type == 'Resize': - print(f"Replace resize op: <{node.name}> with upsample.") - attrs = parse_attrs(node.attribute) + logger.info(f"Replace resize op: <{node.name}> with upsample.") + mode = 'nearest' + for attr in node.attribute: + if attr.name == 'mode': + mode = attr.s upsample_node = onnx.helper.make_node('Upsample', name=node.name, inputs=[node.input[0], node.input[2]], outputs=node.output, - mode=attrs['mode']) - nodes_to_be_removed.append(node) + mode=mode) + nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) graph.node.insert(idx, upsample_node) idx += 1 @@ -97,11 +101,11 @@ def remove_fake_pad_op(self, graph, name2data, inp2node, out2node): if node.op_type == 'Pad': pads = name2data[node.input[1]] if all([x == 0 for x in pads]): - print(f"Remove pad op: <{node.name}>.") + logger.info(f"Remove pad op: <{node.name}>.") next_nodes = inp2node[node.output[0]] for next_node, idx in next_nodes: next_node.input[idx] = node.input[0] - nodes_to_be_removed.append(node) + nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) for node in nodes_to_be_removed: graph.node.remove(node) @@ -136,11 +140,9 @@ def gen_gfpq_param_file(self, graph, clip_val): interp_layer_name = node.name gfpq_param_dict[interp_layer_name + '_permute_' + str(interp_layer_cnt)] = gfpq_param_dict[interp_layer_name] interp_layer_cnt += 1 + return gfpq_param_dict - with open(os.path.join('./', 'nnie_gfpq_param_dict.json'), 'w') as f: - json.dump({"nnie": {"gfpq_param_dict": gfpq_param_dict}}, f, indent=4) - - def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nnie_deploy_model.onnx"): + def remove_fakequantize_and_collect_params(self, onnx_path): model = onnx.load(onnx_path) graph = model.graph out2node, inp2node = update_inp2node_out2node(graph) @@ -168,7 +170,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nni new_data = np.clip(data, -clip_range, clip_range) new_data = numpy_helper.from_array(new_data) named_initializer[tensor_name].raw_data = new_data.raw_data - print(f'clip weights {tensor_name} to range [{-clip_range}, {clip_range}].') + logger.info(f'Clip weights {tensor_name} to range [{-clip_range}, {clip_range}].') else: # fake quantize for activations clip_ranges[node.input[0]] = 
name2data[node.input[1]] @@ -181,8 +183,15 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nni for node in nodes_to_be_removed: graph.node.remove(node) - self.gen_gfpq_param_file(graph, clip_ranges) - onnx.save(model, model_save_path) + gfpq_param_dict = self.gen_gfpq_param_file(graph, clip_ranges) + + output_path = os.path.dirname(onnx_path) + filename = os.path.join(output_path, 'nnie_gfpq_param_dict.json') + with open(filename, 'w') as f: + json.dump({"nnie": {"gfpq_param_dict": gfpq_param_dict}}, f, indent=4) + filename = os.path.join(output_path, 'nnie_deploy_model.onnx') + onnx.save(model, filename) + logger.info("Finish deploy process.") remove_fakequantize_and_collect_params_nnie = NNIE_process().remove_fakequantize_and_collect_params @@ -230,11 +239,11 @@ def deal_with_weight_fakequant(self, node, out2node, inp2node, named_initializer next_node.input[idx] = node.input[0] return redundant_nodes - def deal_with_activation_fakequant(self, node, inp2node): + def deal_with_activation_fakequant(self, node, inp2node): next_nodes = inp2node[node.output[0]] for next_node, idx in next_nodes: next_node.input[idx] = node.input[0] - return + return def parse_qparams(self, node, name2data): tensor_name, scale, zero_point = node.input[:3] @@ -247,11 +256,11 @@ def parse_qparams(self, node, name2data): qmin = qparams['quant_min'] qmax = qparams['quant_max'] else: - print(f'qmin and qmax are not found for <{node.name}>!') + logger.info(f'qmin and qmax are not found for <{node.name}>!') return tensor_name, scale, zero_point, qmin, qmax def clip_weight(self, node, name2data, named_initializer): - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) data = name2data[tensor_name] clip_range_min = (qmin - zero_point) * scale clip_range_max = (qmax - zero_point) * scale @@ -260,10 +269,10 @@ def clip_weight(self, node, name2data, named_initializer): for c in range(data.shape[0]): new_data.append(np.clip(data[c], clip_range_min[c], clip_range_max[c])) new_data = np.array(new_data) - print(f'clip weights {tensor_name} to per-cahnnel clip range.') + logger.info(f'Clip weights <{tensor_name}> to per-channel ranges.') else: new_data = np.clip(data, clip_range_min, clip_range_max) - print(f'clip weights {tensor_name} to range [{clip_range_min}, {clip_range_max}].') + logger.info(f'Clip weights <{tensor_name}> to range [{clip_range_min}, {clip_range_max}].') new_data = numpy_helper.from_array(new_data) named_initializer[tensor_name].raw_data = new_data.raw_data @@ -271,7 +280,7 @@ def post_process_clip_ranges(self, clip_ranges, graph, inp2node): def find_the_closest_clip_range(node): if node.input[0] in clip_ranges: return node.input[0] - elif node.op_type in ['Flatten', 'Resize']: + elif node.op_type in ['Flatten', 'Resize'] and node.output[0] in inp2node: return find_the_closest_clip_range(inp2node[node.output[0]][0][0]) else: return None @@ -281,7 +290,7 @@ def find_the_closest_clip_range(node): tensor_name = find_the_closest_clip_range(node) if tensor_name: clip_ranges[node.input[0]] = clip_ranges[tensor_name] - print(f'Pass <{tensor_name}> clip range to <{node.name}> input <{node.input[0]}>.') + logger.info(f'Pass <{tensor_name}> clip range to <{node.name}> input <{node.input[0]}>.') return clip_ranges def remove_fakequantize_and_collect_params(self, onnx_path, backend): @@ -300,22 +309,22 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend): 
nodes_to_be_removed = [] for node in graph.node: if node.op_type in all_fakequantizer: - nodes_to_be_removed.append(node) + nodes_to_be_removed.append(node) nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) - if node.op_type in perchannel_fakequantizer: + if node.op_type in perchannel_fakequantizer: # fake quantize for weights, suppose per-channel quantize only for weight redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) nodes_to_be_removed.extend(redundant_nodes) self.clip_weight(node, name2data, named_initializer) if backend == 'ppl': - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) - clip_ranges[tensor_name] = {'step': [float(x) for x in scale], - 'zero_point': [int(x) for x in zero_point], - 'min': [float(x) for x in scale * (qmin - zero_point)], - 'max': [float(x) for x in scale * (qmax - zero_point)], - 'bit': int(np.log2(qmax - qmin + 1)), - 'type': "biased", + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + clip_ranges[tensor_name] = {'step': [float(x) for x in scale], + 'zero_point': [int(x) for x in zero_point], + 'min': [float(x) for x in scale * (qmin - zero_point)], + 'max': [float(x) for x in scale * (qmax - zero_point)], + 'bit': int(np.log2(qmax - qmin + 1)), + 'type': "biased", } elif node.op_type in pertensor_fakequantizer: @@ -326,13 +335,13 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend): if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: # fake quantize for weights redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) nodes_to_be_removed.extend(redundant_nodes) self.clip_weight(node, name2data, named_initializer) else: # fake quantize for activations self.deal_with_activation_fakequant(node, inp2node) - tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) for out in graph.output: if out.name == node.output[0]: out.name = tensor_name @@ -340,18 +349,19 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend): if backend == 'tensorrt': clip_ranges[tensor_name] = float(scale * min(-qmin, qmax)) elif backend == 'snpe': - clip_ranges[tensor_name] = {'bitwidth': int(np.log2(qmax - qmin + 1)), - 'min': float(scale * (qmin - zero_point)), - 'max': float(scale * (qmax - zero_point)) - } + clip_ranges[tensor_name] = [ + {'bitwidth': int(np.log2(qmax - qmin + 1)), + 'min': float(scale * (qmin - zero_point)), + 'max': float(scale * (qmax - zero_point))} + ] if backend == 'ppl': - clip_ranges[tensor_name] = {'step': float(scale), - 'zero_point': int(zero_point), - 'min': float(scale * (qmin - zero_point)), - 'max': float(scale * (qmax - zero_point)), - 'bit': int(np.log2(qmax - qmin + 1)), - 'type': "biased", - } + clip_ranges[tensor_name] = {'step': float(scale), + 'zero_point': int(zero_point), + 'min': float(scale * (qmin - zero_point)), + 'max': float(scale * (qmax - zero_point)), + 'bit': int(np.log2(qmax - qmin + 1)), + 'type': "biased", + } for node in nodes_to_be_removed: graph.node.remove(node) @@ -363,9 +373,12 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend): context = {'activation_encodings': clip_ranges, 'param_encodings': 
{}} elif backend == 'ppl': context = {"ppl": clip_ranges} - filename = os.path.join('./', '{}_clip_ranges.json'.format(backend)) + output_path = os.path.dirname(onnx_path) + filename = os.path.join(output_path, '{}_clip_ranges.json'.format(backend)) with open(filename, 'w') as f: json.dump(context, f, indent=4) - onnx.save(model, '{}_deploy_model.onnx'.format(backend)) + filename = os.path.join(output_path, '{}_deploy_model.onnx'.format(backend)) + onnx.save(model, filename) + logger.info("Finish deploy process.") remove_fakequantize_and_collect_params = LinearQuantizer_process().remove_fakequantize_and_collect_params diff --git a/mqbench/custom_quantizer.py b/mqbench/custom_quantizer.py index 3709794..fcac7af 100644 --- a/mqbench/custom_quantizer.py +++ b/mqbench/custom_quantizer.py @@ -1,3 +1,4 @@ +import copy import operator from typing import ( Dict, Any, Callable @@ -24,6 +25,9 @@ from torch.quantization.fx.qconfig_utils import ( get_flattened_qconfig_dict ) +from torch.quantization.quantize_fx import ( + _fuse_fx +) import mqbench.nn as qnn import mqbench.nn.intrinsic @@ -33,31 +37,32 @@ from mqbench.prepare_by_platform import BackendType -@register_model_quantizer(BackendType.SNPE) @register_model_quantizer(BackendType.NNIE) -@register_model_quantizer(BackendType.Academic) class ModelQuantizer(object): """General model quantizer class. First, replace common float module to nn.qat.modules to make weight fake quantized. Second, insert activation fake quantize node before specific layers. Layer type is defined in function_type_to_quant_input / module_type_to_quant_input. - We leave the output not quantized since it is next layer's input. + We only quantize the inputs of layers and leave the output not quantized + since it is next layer's input. 
""" - def __init__(self, extra_quantizer_dict): + def __init__(self, extra_quantizer_dict, extra_fuse_dict): self.additional_function_type = extra_quantizer_dict.get('additional_function_type', []) self.additional_module_type = extra_quantizer_dict.get('additional_module_type', ()) self.exclude_module_name = extra_quantizer_dict.get('exclude_module_name', []) + self.extra_fuse_dict = extra_fuse_dict - def prepare(self, model: GraphModule, qconfig_dict: Dict): - model = self._weight_quant(model, qconfig_dict) - model = self._insert_fake_quantize_for_act_quant(model, qconfig_dict) + def prepare(self, model: GraphModule, qconfig): + model = _fuse_fx(model, self.extra_fuse_dict) + model = self._weight_quant(model, qconfig) + model = self._insert_fake_quantize_for_act_quant(model, qconfig) return model def _insert_fake_quantize_for_act_quant( self, model: GraphModule, - qconfig_dict: Any): + qconfig: Any): graph = model.graph nodes = list(model.graph.nodes) @@ -65,7 +70,7 @@ def _insert_fake_quantize_for_act_quant( node_to_quantize_output = self._find_act_quants(model) for node in node_to_quantize_output: - fake_quantizer = qconfig_dict.activation() + fake_quantizer = qconfig.activation() quantizer_name = node.name + quantizer_prefix setattr(model, quantizer_name, fake_quantizer) logger.info("Insert act quant {}".format(quantizer_name)) @@ -90,13 +95,54 @@ def _fix_succ_recursivly(self, args_tuple, target_node, inserted_node): args_tuple = tuple(_tmp) return args_tuple - def _weight_quant(self, model: GraphModule, qconfig_dict: Dict): + def _weight_quant(self, model: GraphModule, qconfig): logger.info("Replace module to qat module.") - flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig_dict}) + flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig}) propagate_qconfig_(model, flattened_qconfig_dict) self._qat_swap_modules(model, self._additional_qat_module_mapping) return model + @property + def implicit_merge_patterns(self) -> list: + # Layers which do not need quantize among them. + # In reversed order! + return [ + (operator.add, operator.mul) + ] + + def _on_merge_chain(self, modules, pattern, pair, p_pos=0, v_pos=0): + if v_pos == len(pair): + return True + if p_pos == len(pattern): + return v_pos == len(pair) + node = pair[v_pos] + cur_pattern = pattern[p_pos] + # Means current node is matched. + if (node.op == "call_module" and type(modules[node.target]) == cur_pattern) or \ + ((node.op == 'call_function' or node.op == 'call_method') and + node.target == cur_pattern): + # Means compairing pair. + if len(pattern) > p_pos and len(pair) > v_pos: + return self._on_merge_chain(modules, pattern, pair, p_pos + 1, v_pos + 1) + # Means compairing extra node. + matched = False + flatten_args = self._flatten_args(node.args) + for _arg in flatten_args: + extra_pair = (*pair, _arg) + if isinstance(_arg, torch.fx.node.Node) and \ + self._on_merge_chain(modules, pattern, extra_pair, p_pos + 1, v_pos + 1): + matched = True + return matched + # Current node is not matched, skip to next. 
+ else: + return self._on_merge_chain(modules, pattern, pair, p_pos + 1, v_pos) + + def _is_implicit_merge(self, modules, pair): + for pattern in self.implicit_merge_patterns: + if self._on_merge_chain(modules, pattern, pair): + return True + return False + @property def function_type_to_quant_input(self) -> list: return [ @@ -113,6 +159,8 @@ def module_type_to_quant_input(self) -> tuple: torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d, torch.nn.intrinsic.qat.modules.conv_fused.ConvBn2d, torch.nn.qat.modules.conv.Conv2d, + # ConvTranspose + torch.nn.ConvTranspose2d, # Linear torch.nn.qat.modules.linear.Linear, qnn.intrinsic.qat.LinearBn1d, @@ -153,6 +201,9 @@ def _find_act_quants(self, model: GraphModule) -> (set, set): input_node_list = self._flatten_args(node.args) for _node in input_node_list: if isinstance(_node, torch.fx.node.Node): + if self._is_implicit_merge(modules, (node, _node)): + logger.info("Implicit merge: {} + {}".format(_node.name, node.name)) + continue node_need_to_quantize_output.append(_node) return set(node_need_to_quantize_output) @@ -191,13 +242,107 @@ def _convert(self, module, mapping=None, inplace=False, scope=''): return module +@register_model_quantizer(BackendType.Academic) +class AcademicQuantizer(ModelQuantizer): + """Academic setting mostly do not merge BN and leave the first and last layer to higher bits. + """ + def __init__(self, extra_quantizer_dict, extra_fuse_dict): + super().__init__(extra_quantizer_dict, extra_fuse_dict) + self.io_module = {} + self.post_act_8bit_node_name = [] + + def prepare(self, model: GraphModule, qconfig): + self._get_io_module(model) + self._get_post_act_8bit_node_name(model) + model = self._weight_quant(model, qconfig) + model = self._insert_fake_quantize_for_act_quant(model, qconfig) + return model + + def _weight_quant(self, model: GraphModule, qconfig): + logger.info("Replace module to qat module.") + wqconfig_8bit = copy.deepcopy(qconfig) + wqconfig_8bit.weight.p.keywords['quant_min'] = -128 + wqconfig_8bit.weight.p.keywords['quant_max'] = 127 + for name, module in model.named_modules(): + if name in self.io_module.keys(): + logger.info("Set layer {} to 8 bit.".format(name)) + module.qconfig = wqconfig_8bit + flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig}) + propagate_qconfig_(model, flattened_qconfig_dict) + self._qat_swap_modules(model, self._additional_qat_module_mapping) + return model + + @property + def function_type_to_quant_input(self) -> list: + return self.additional_function_type + + @property + def module_type_to_quant_input(self) -> tuple: + return ( + # Conv + torch.nn.qat.modules.conv.Conv2d, + # Linear + torch.nn.qat.modules.linear.Linear, + ) + self.additional_module_type + + def _get_post_act_8bit_node_name(self, model): + for node in self.io_module.values(): + for _arg in node.args: + if isinstance(_arg, torch.fx.node.Node): + self.post_act_8bit_node_name.append(_arg.name) + + def _get_io_module(self, model): + total_args = [] + nodes = list(model.graph.nodes) + for node in nodes: + the_first_layer = False + for _arg in node.args: + if isinstance(_arg, torch.fx.node.Node): + if _arg.op == 'placeholder': + the_first_layer = True + total_args.append(_arg.name) + if the_first_layer: + self.io_module[node.target] = node + if node.op == 'output': + for _arg in node.args: + if isinstance(_arg, torch.fx.node.Node): + self.io_module[_arg.target] = _arg + + def _insert_fake_quantize_for_act_quant(self, model: GraphModule, qconfig): + graph = model.graph + nodes = 
list(model.graph.nodes) + + quantizer_prefix = "_post_act_fake_quantizer" + node_to_quantize_output = self._find_act_quants(model) + + aqconfig_8bit = copy.deepcopy(qconfig.activation) + aqconfig_8bit.p.keywords['quant_min'] = -128 + aqconfig_8bit.p.keywords['quant_max'] = 127 + for node in node_to_quantize_output: + if node.name in self.post_act_8bit_node_name: + logger.info("Set {} post act quantize to 8 bit.".format(node.name)) + fake_quantizer = aqconfig_8bit() + else: + fake_quantizer = qconfig.activation() + quantizer_name = node.name + quantizer_prefix + setattr(model, quantizer_name, fake_quantizer) + logger.info("Insert act quant {}".format(quantizer_name)) + with graph.inserting_after(node): + inserted_node = graph.create_node("call_module", quantizer_name, (node,), {}) + for _node in nodes: + _node.args = self._fix_succ_recursivly(_node.args, node, inserted_node) + + model.recompile() + model.graph.lint() + return model + @register_model_quantizer(BackendType.Tensorrt) class TRTModelQuantizer(ModelQuantizer): """The different points of TRT quantizer are how to deal with add op and the last layer. """ - def __init__(self, extra_quantizer_dict): - super().__init__(extra_quantizer_dict) + def __init__(self, extra_quantizer_dict, extra_fuse_dict): + super().__init__(extra_quantizer_dict, extra_fuse_dict) @property def _merge_add_type(self): @@ -222,6 +367,8 @@ def _find_act_quants(self, model: GraphModule) -> set: node_need_to_quantize_output.extend(input_node_list) else: for _node in input_node_list: + if self._is_implicit_merge(modules, (node, _node)): + continue if isinstance(_node, torch.fx.node.Node): node_need_to_quantize_output.append(_node) @@ -245,6 +392,7 @@ def _find_add_merge_node(self, model, input_node_list, node): return None +@register_model_quantizer(BackendType.SNPE) @register_model_quantizer(BackendType.PPLW8A16) class TotalINTQuantizer(ModelQuantizer): """There is only INT8 calculations in the model. @@ -252,8 +400,8 @@ class TotalINTQuantizer(ModelQuantizer): of the last layers. We quantize every activations tensors and weight tensors using this method. """ - def __init__(self, extra_quantizer_dict): - super().__init__(extra_quantizer_dict) + def __init__(self, extra_quantizer_dict, extra_fuse_dict): + super().__init__(extra_quantizer_dict, extra_fuse_dict) def _find_act_quants(self, model: GraphModule) -> (set, set): node_need_to_quantize_output = super(). 
_find_act_quants(model) @@ -263,4 +411,4 @@ def _find_act_quants(self, model: GraphModule) -> (set, set): for output_node in self._flatten_args(node.args): node_need_to_quantize_output.add(output_node) - return set(node_need_to_quantize_output) \ No newline at end of file + return set(node_need_to_quantize_output) diff --git a/mqbench/fake_quantize/nnie.py b/mqbench/fake_quantize/nnie.py index ebab029..d05c18d 100644 --- a/mqbench/fake_quantize/nnie.py +++ b/mqbench/fake_quantize/nnie.py @@ -15,7 +15,8 @@ def forward(self, X): self.activation_post_process(X.detach()) data_max = torch.max(-self.activation_post_process.min_val, self.activation_post_process.max_val) self.data_max = torch.max(data_max, self.data_max) - X = NNIEQuantizeFunc.apply(X, self.data_max) + if self.fake_quant_enabled[0] == 1: + X = NNIEQuantizeFunc.apply(X, self.data_max) return X diff --git a/mqbench/observer.py b/mqbench/observer.py index aa859b5..8866a5a 100644 --- a/mqbench/observer.py +++ b/mqbench/observer.py @@ -154,6 +154,42 @@ def forward(self, x_orig): return x +class EMAQuantileObserver(ObserverBase): + """Moving average quantile among batches. + """ + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, + quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, ema_ratio=0.9, threshold=0.99999, bins=2048): + super(EMAQuantileObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, pot_scale) + assert self.ch_axis == -1, "Quantile observer only support in per-tensor scheme." + self.ema_ratio = ema_ratio + self.threshold = threshold + self.bins = bins + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch._aminmax(x) + hist = torch.histc(torch.abs(x), bins=self.bins, min=0., max=torch.max(-min_val_cur, max_val_cur)) + cur_total = 0 + clip_value = torch.max(-min_val_cur, max_val_cur) + for i, cnt in enumerate(hist): + if cur_total + cnt >= self.threshold * x.numel(): + clip_value = (i + 0.5) * (max_val_cur / self.bins) + break + + if self.max_val.numel() <= 1 and self.max_val.isinf(): + self.min_val = max(min_val_cur, -clip_value) + self.max_val = min(max_val_cur, clip_value) + else: + self.min_val = self.min_val * self.ema_ratio + max(min_val_cur, -clip_value) * (1.0 - self.ema_ratio) + self.max_val = self.max_val * self.ema_ratio + min(max_val_cur, clip_value) * (1.0 - self.ema_ratio) + return x + + class ClipStdObserver(ObserverBase): """Clip std. """ @@ -241,8 +277,8 @@ class LSQPlusObserver(ObserverBase): LSQ+ observer. 
''' - def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, reduce_range=False, - quant_min=-128, quant_max=128, ch_axis=-1, pot_scale=False): + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, + quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False): super(LSQPlusObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, ch_axis, pot_scale) @@ -282,4 +318,4 @@ def calculate_qparams(self): zero_point = self.quant_min - torch.round(self.min_val / scale) sync_tensor(scale) sync_tensor(zero_point) - return scale, zero_point \ No newline at end of file + return scale, zero_point diff --git a/mqbench/prepare_by_platform.py b/mqbench/prepare_by_platform.py index 8756afe..880746e 100644 --- a/mqbench/prepare_by_platform.py +++ b/mqbench/prepare_by_platform.py @@ -3,7 +3,7 @@ import torch from torch.fx.symbolic_trace import symbolic_trace -from torch.quantization.quantize_fx import _swap_ff_with_fxff, _fuse_fx +from torch.quantization.quantize_fx import _swap_ff_with_fxff from torch.quantization import QConfig from mqbench.fake_quantize import ( @@ -18,7 +18,8 @@ ClipStdObserver, LSQObserver, MinMaxObserver, - EMAMinMaxObserver + EMAMinMaxObserver, + EMAQuantileObserver ) from mqbench.fuser_method_mappings import fuse_custom_config_dict from mqbench.utils.logger import logger @@ -79,28 +80,28 @@ def __str__(self): a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8), default_weight_quantize=LearnableFakeQuantize, default_act_quantize=LearnableFakeQuantize, - default_weight_observer=LSQObserver, - default_act_observer=LSQObserver), + default_weight_observer=MinMaxObserver, + default_act_observer=EMAMinMaxObserver), BackendType.SNPE: dict(qtype='affine', # noqa: E241 w_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=8), a_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=8), default_weight_quantize=LearnableFakeQuantize, default_act_quantize=LearnableFakeQuantize, - default_weight_observer=LSQObserver, - default_act_observer=LSQObserver), + default_weight_observer=MinMaxObserver, + default_act_observer=EMAMinMaxObserver), BackendType.PPLW8A16: dict(qtype='affine', # noqa: E241 w_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=8), a_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=16), default_weight_quantize=LearnableFakeQuantize, default_act_quantize=LearnableFakeQuantize, - default_weight_observer=LSQObserver, - default_act_observer=LSQObserver) + default_weight_observer=MinMaxObserver, + default_act_observer=EMAMinMaxObserver) } ObserverDict = { 'MinMaxObserver': MinMaxObserver, # noqa: E241 'EMAMinMaxObserver': EMAMinMaxObserver, # More general choice. # noqa: E241 - 'QuantileObserver': None, # TODO quantile. # noqa: E241 + 'EMAQuantileObserver': EMAQuantileObserver, # Quantile observer. # noqa: E241 'ClipStdObserver': ClipStdObserver, # Usually used for DSQ. # noqa: E241 'LSQObserver': LSQObserver # Usually used for LSQ. # noqa: E241 } @@ -114,7 +115,7 @@ def __str__(self): 'PACTFakeQuantize': PACTFakeQuantize # PACT # noqa: E241 } -def get_qconfig_by_platform(deploy_backend: BackendType, extra_qparams): +def get_qconfig_by_platform(deploy_backend: BackendType, extra_qparams: Dict): """ Args: @@ -139,18 +140,26 @@ def get_qconfig_by_platform(deploy_backend: BackendType, extra_qparams): same with w_qscheme. 
} } - """ + """ w_observer = extra_qparams.get('w_observer', None) if w_observer: + assert w_observer in ObserverDict, \ + 'Do not support observer name: {}'.format(w_observer) w_observer = ObserverDict[w_observer] a_observer = extra_qparams.get('a_observer', None) if a_observer: + assert a_observer in ObserverDict, \ + 'Do not support observer name: {}'.format(w_observer) a_observer = ObserverDict[a_observer] w_fakequantize = extra_qparams.get('w_fakequantize', None) if w_fakequantize: + assert w_fakequantize in FakeQuantizeDict, \ + 'Do not support fakequantize name: {}'.format(w_fakequantize) w_fakequantize = FakeQuantizeDict[w_fakequantize] a_fakequantize = extra_qparams.get('a_fakequantize', None) - if w_fakequantize: + if a_fakequantize: + assert a_fakequantize in FakeQuantizeDict, \ + 'Do not support fakequantize name: {}'.format(a_fakequantize) a_fakequantize = FakeQuantizeDict[a_fakequantize] backend_params = ParamsTable[deploy_backend] @@ -185,9 +194,9 @@ def get_qconfig_by_platform(deploy_backend: BackendType, extra_qparams): a_fakeq_params = extra_qparams.get('a_fakeq_params', {}) # Observer dot not need extra params for now. if not w_observer: - w_observer = MinMaxObserver + w_observer = backend_params['default_weight_observer'] if not a_observer: - a_observer = EMAMinMaxObserver + a_observer = backend_params['default_act_observer'] # Create qconfig. w_qconfig = w_fakequantize.with_args(observer=w_observer, **w_fakeq_params, **w_qscheme.to_observer_params()) @@ -205,11 +214,30 @@ def prepare_qat_fx_by_platform( model: torch.nn.Module, deploy_backend: BackendType, prepare_custom_config_dict: Dict[str, Any] = {}): + """ + Args: + model (torch.nn.Module): + deploy_backend (BackendType): + + >>> prepare_custom_config_dict : { + extra_qconfig_dict : Dict, Find explainations in get_qconfig_by_platform, + extra_quantizer_dict: Extra params for quantizer. + preserve_attr: Dict, Specify attribute of model which should be preserved + after prepare. Since symbolic_trace only store attributes which is + in forward. If model.func1 and model.backbone.func2 should be preserved, + {"": ["func1"], "backbone": ["func2"] } should work. + Attr below is inherited from Pytorch. + concrete_args: Specify input for model tracing. + extra_fuse_dict: Specify extra fusing patterns and functions. + } + + """ assert model.training, 'prepare_qat_fx_custom only works for models in ' + \ 'train mode' logger.info("Quantize model using {} scheme.".format(deploy_backend)) + _swap_ff_with_fxff(model) # Get Qconfig extra_qconfig_dict = prepare_custom_config_dict.get('extra_qconfig_dict', {}) qconfig = get_qconfig_by_platform(deploy_backend, extra_qconfig_dict) @@ -229,13 +257,11 @@ def prepare_qat_fx_by_platform( graph_module = symbolic_trace(model, concrete_args=concrete_args) # Model fusion. extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {}) - _swap_ff_with_fxff(graph_module) extra_fuse_dict.update(fuse_custom_config_dict) - graph_module = _fuse_fx(graph_module, extra_fuse_dict) # Prepare import mqbench.custom_quantizer # noqa: F401 extra_quantizer_dict = prepare_custom_config_dict.get('extra_quantizer_dict', {}) - quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict) + quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict) prepared = quantizer.prepare(graph_module, qconfig) # Restore attr. 
if 'preserve_attr' in prepare_custom_config_dict: diff --git a/mqbench/utils/logger.py b/mqbench/utils/logger.py index c3ee136..5c23b72 100644 --- a/mqbench/utils/logger.py +++ b/mqbench/utils/logger.py @@ -2,8 +2,8 @@ import sys -QBENCH_LOGGER_NAME = "QBENCH" -logger = logging.getLogger(QBENCH_LOGGER_NAME) +MQBENCH_LOGGER_NAME = "MQBENCH" +logger = logging.getLogger(MQBENCH_LOGGER_NAME) logger.propagate = False stdout_handler = logging.StreamHandler(sys.stdout) fmt = logging.Formatter("[%(name)s] %(levelname)s: %(message)s") @@ -21,4 +21,4 @@ def set_log_level(level): def disable_logging(): - logger.handlers = [] \ No newline at end of file + logger.handlers = [] diff --git a/mqbench/utils/utils.py b/mqbench/utils/utils.py index b128afd..0411fc2 100644 --- a/mqbench/utils/utils.py +++ b/mqbench/utils/utils.py @@ -1,13 +1,17 @@ +import copy + import torch +from torch.fx import GraphModule USE_LINK = False USE_DDP = False try: import spring.linklink as link + assert link.is_initialized() USE_LINK = True -except ModuleNotFoundError: - import torch.distributed as dist +except (ModuleNotFoundError, AssertionError): + import torch.distributed if torch.distributed.is_initialized(): USE_DDP = True @@ -46,4 +50,18 @@ def __exit__(self, *args): def is_tracing_state(): - return torch._C._get_tracing_state() \ No newline at end of file + return torch._C._get_tracing_state() + + +def deepcopy_graphmodule(gm: GraphModule): + """Rewrite the deepcopy of GraphModule. (Copy its 'graph'.) + + Args: + gm (GraphModule): + + Returns: + GraphModule: A deepcopied gm. + """ + copied_gm = copy.deepcopy(gm) + copied_gm.graph = copy.deepcopy(gm.graph) + return copied_gm diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..64eb193 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torch==1.8.1 diff --git a/test/observer/test_observer.py b/test/observer/test_observer.py index 83e6f95..20add10 100644 --- a/test/observer/test_observer.py +++ b/test/observer/test_observer.py @@ -7,6 +7,24 @@ class TestObserver(unittest.TestCase): + def test_ema_observer(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAQuantileObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'FixedFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + def test_ema_observer(self): model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) dummy_input = torch.randn(2, 3, 224, 224, device='cpu') @@ -24,8 +42,6 @@ def test_ema_observer(self): enable_quantization(model_prepared) loss = model_prepared(dummy_input).sum() loss.backward() - model_prepared.eval() - convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_ema.onnx') def test_minmax_observer(self): model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) @@ -44,8 +60,6 @@ def test_minmax_observer(self): enable_quantization(model_prepared) loss = model_prepared(dummy_input).sum() loss.backward() - model_prepared.eval() - 
convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_minmax.onnx') def test_lsq_observer(self): model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) @@ -64,8 +78,6 @@ def test_lsq_observer(self): enable_quantization(model_prepared) loss = model_prepared(dummy_input).sum() loss.backward() - model_prepared.eval() - convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_lsq.onnx') def test_clip_std_observer(self): model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) @@ -83,8 +95,4 @@ def test_clip_std_observer(self): model_prepared(dummy_input) enable_quantization(model_prepared) loss = model_prepared(dummy_input).sum() - loss.backward() - model_prepared.eval() - convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_clip_std.onnx') - - \ No newline at end of file + loss.backward() \ No newline at end of file
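For completeness, the new `EMAQuantileObserver` can be selected by name through `extra_qconfig_dict`, exactly as the test added above does (backend and model choice are illustrative):

```python
import torch
from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType

model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False)
model.train()
extra_qconfig_dict = {
    'w_observer': 'MinMaxObserver',
    'a_observer': 'EMAQuantileObserver',   # quantile observer added in v0.0.2
    'w_fakequantize': 'FixedFakeQuantize',
    'a_fakequantize': 'FixedFakeQuantize',
}
model_prepared = prepare_qat_fx_by_platform(
    model, BackendType.Tensorrt,
    {'extra_qconfig_dict': extra_qconfig_dict})
```

Unknown observer or fake-quantize names now fail fast through the asserts added in `get_qconfig_by_platform`.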