Update version to 0.0.2.
zhangqi3 committed Sep 8, 2021
1 parent e20f4ce commit 3ac8928
Showing 13 changed files with 379 additions and 122 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -1 +1,11 @@
[MQBench](http://mqbench.tech/assets/docs/html/)

# Update V0.0.2

- Fix academic prepare setting.
- More deployable prepare process.
- Fix setup.py.
- Fix deploy on SNPE.
- Fix convert_deploy bug.
- Add Quantile observer.
- Other updates.
1 change: 1 addition & 0 deletions mqbench/__init__.py
@@ -0,0 +1 @@
__version__ = '0.0.2'
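
With the new mqbench/__init__.py, the installed package now reports its version; a minimal sketch, assuming MQBench 0.0.2 is installed in the current environment:

# Check the newly added version attribute (installed package assumed).
import mqbench

print(mqbench.__version__)  # expected to print '0.0.2'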
18 changes: 3 additions & 15 deletions mqbench/adaround.py
@@ -1,4 +1,3 @@
import copy
import os
import numpy as np
from typing import Callable, Dict
@@ -8,7 +7,8 @@
import torch.nn.functional as F
from torch.fx import GraphModule, Node

from .observer import MinMaxObserver, ObserverBase
from mqbench.observer import MinMaxObserver, ObserverBase
from mqbench.utils import deepcopy_graphmodule

_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear, )

@@ -95,7 +95,7 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,
cali_data = get_cali_samples(train_data, n_samples)

# apply rewritten deepcopy of GraphModule
quant_model = _deepcopy_graphmodule(model)
quant_model = deepcopy_graphmodule(model)
quant_model.eval()
model.eval()

@@ -175,18 +175,6 @@ def adaround(model: GraphModule, train_data, n_samples: int = 128,

return quant_model

def _deepcopy_graphmodule(gm: GraphModule):
"""Rewrite the deepcopy of GraphModule. (Copy its 'graph'.)
Args:
gm (GraphModule):
Returns:
GraphModule: A deepcopied gm.
"""
copied_gm = copy.deepcopy(gm)
copied_gm.graph = copy.deepcopy(gm.graph)
return copied_gm

def _insert_observer(gm: GraphModule, insert_type="input"):
"""Insert observers to record the input and output of target layers.
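The local _deepcopy_graphmodule helper removed above is replaced by deepcopy_graphmodule imported from mqbench.utils; a hedged sketch of that shared helper, assuming it keeps the body of the deleted function:

# Assumed implementation of the shared helper in mqbench.utils,
# mirroring the _deepcopy_graphmodule body removed from adaround.py.
import copy

from torch.fx import GraphModule


def deepcopy_graphmodule(gm: GraphModule) -> GraphModule:
    """Deepcopy a GraphModule, including an explicit deepcopy of its 'graph'."""
    copied_gm = copy.deepcopy(gm)
    copied_gm.graph = copy.deepcopy(gm.graph)
    return copied_gm
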
21 changes: 14 additions & 7 deletions mqbench/convert_deploy.py
@@ -6,6 +6,7 @@
import mqbench.custom_symbolic_opset # noqa: F401
import mqbench.fusion_method # noqa: F401
from mqbench.prepare_by_platform import BackendType
from mqbench.utils import deepcopy_graphmodule
from mqbench.utils.logger import logger
from mqbench.utils.registry import (
BACKEND_DEPLOY_FUNCTION,
@@ -37,12 +38,16 @@ def convert_merge_bn(model: GraphModule, **kwargs):
@register_deploy_function(BackendType.PPLW8A16)
@register_deploy_function(BackendType.Tensorrt)
@register_deploy_function(BackendType.NNIE)
def convert_onnx(model: GraphModule, input_shape_dict, onnx_model_path='./test.onnx', **kwargs):
def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_path, **kwargs):
logger.info("Export to onnx.")
device = next(model.parameters()).device
dummy_input = {name: torch.rand(shape).to(device) for name, shape in input_shape_dict.items()}
torch.onnx.export(model, tuple(dummy_input.values()), onnx_model_path,
input_names=list(dummy_input.keys()),
input_names = None
if dummy_input is None:
device = next(model.parameters()).device
dummy_input = {name: torch.rand(shape).to(device) for name, shape in input_shape_dict.items()}
input_names = list(dummy_input.keys())
dummy_input = tuple(dummy_input.values())
torch.onnx.export(model, dummy_input, onnx_model_path,
input_names=input_names,
opset_version=11,
enable_onnx_checker=False)

@@ -72,7 +77,7 @@ def deploy_qparams_pplw8a16(model: GraphModule, onnx_model_path, **kwargs):


def convert_deploy(model: GraphModule, backend_type: BackendType,
input_shape_dict, output_path='./',
input_shape_dict=None, dummy_input=None, output_path='./',
model_name='mqbench_model_quantized.onnx'):
r"""Convert model to onnx model and quantization params depends on backend.
@@ -94,9 +99,11 @@ def forward(self, input_0, input_1):
"""
kwargs = {
'input_shape_dict': input_shape_dict,
'dummy_input': dummy_input,
'output_path': output_path,
'model_name': model_name,
'onnx_model_path': osp.join(output_path, model_name)
}
deploy_model = deepcopy_graphmodule(model)
for convert_function in BACKEND_DEPLOY_FUNCTION[backend_type]:
convert_function(model, **kwargs)
convert_function(deploy_model, **kwargs)
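
A hedged usage sketch of the updated convert_deploy entry point; the model, backend, and input shapes below are placeholders, and either input_shape_dict or a prebuilt dummy_input (new in this commit) can be supplied:

# Illustrative call of the updated convert_deploy signature
# (model, backend, and shapes are assumptions, not part of this commit).
import torch
import torchvision.models as models

from mqbench.convert_deploy import convert_deploy
from mqbench.prepare_by_platform import prepare_by_platform, BackendType  # prepare API assumed

model = models.resnet18(pretrained=False)
model = prepare_by_platform(model, BackendType.SNPE)  # insert fake-quantize nodes (assumed usage)
model.eval()

# Option 1: let convert_onnx build a random dummy input from the shape dict.
convert_deploy(model, BackendType.SNPE,
               input_shape_dict={'input': [1, 3, 224, 224]},
               output_path='./', model_name='mqbench_model_quantized.onnx')

# Option 2 (new): pass a ready-made dummy input instead of shapes.
convert_deploy(model, BackendType.SNPE,
               dummy_input=torch.randn(1, 3, 224, 224),
               output_path='./', model_name='mqbench_model_quantized.onnx')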
105 changes: 59 additions & 46 deletions mqbench/convert_onnx.py
@@ -3,6 +3,7 @@
import onnx
from onnx import numpy_helper
import numpy as np
from mqbench.utils.logger import logger

perchannel_fakequantizer = ['FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine', 'FakeQuantizeDSQPerchannel']
pertensor_fakequantizer = ['LearnablePerTensorAffine', 'FixedPerTensorAffine', 'FakeQuantizeDSQPertensor']
@@ -70,18 +71,21 @@ def get_constant_inputs(node, out2node):
class OnnxPreprocess(object):
def replace_resize_op_with_upsample(self, graph, out2node):
nodes_to_be_removed = []
idx = 0
idx = 0
while idx < len(graph.node):
node = graph.node[idx]
if node.op_type == 'Resize':
print(f"Replace resize op: <{node.name}> with upsample.")
attrs = parse_attrs(node.attribute)
logger.info(f"Replace resize op: <{node.name}> with upsample.")
mode = 'nearest'
for attr in node.attribute:
if attr.name == 'mode':
mode = attr.s
upsample_node = onnx.helper.make_node('Upsample',
name=node.name,
inputs=[node.input[0], node.input[2]],
outputs=node.output,
mode=attrs['mode'])
nodes_to_be_removed.append(node)
mode=mode)
nodes_to_be_removed.append(node)
nodes_to_be_removed.extend(get_constant_inputs(node, out2node))
graph.node.insert(idx, upsample_node)
idx += 1
@@ -97,11 +101,11 @@ def remove_fake_pad_op(self, graph, name2data, inp2node, out2node):
if node.op_type == 'Pad':
pads = name2data[node.input[1]]
if all([x == 0 for x in pads]):
print(f"Remove pad op: <{node.name}>.")
logger.info(f"Remove pad op: <{node.name}>.")
next_nodes = inp2node[node.output[0]]
for next_node, idx in next_nodes:
next_node.input[idx] = node.input[0]
nodes_to_be_removed.append(node)
nodes_to_be_removed.append(node)
nodes_to_be_removed.extend(get_constant_inputs(node, out2node))
for node in nodes_to_be_removed:
graph.node.remove(node)
@@ -136,11 +140,9 @@ def gen_gfpq_param_file(self, graph, clip_val):
interp_layer_name = node.name
gfpq_param_dict[interp_layer_name + '_permute_' + str(interp_layer_cnt)] = gfpq_param_dict[interp_layer_name]
interp_layer_cnt += 1
return gfpq_param_dict

with open(os.path.join('./', 'nnie_gfpq_param_dict.json'), 'w') as f:
json.dump({"nnie": {"gfpq_param_dict": gfpq_param_dict}}, f, indent=4)

def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nnie_deploy_model.onnx"):
def remove_fakequantize_and_collect_params(self, onnx_path):
model = onnx.load(onnx_path)
graph = model.graph
out2node, inp2node = update_inp2node_out2node(graph)
@@ -168,7 +170,7 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nnie_deploy_model.onnx"):
new_data = np.clip(data, -clip_range, clip_range)
new_data = numpy_helper.from_array(new_data)
named_initializer[tensor_name].raw_data = new_data.raw_data
print(f'clip weights {tensor_name} to range [{-clip_range}, {clip_range}].')
logger.info(f'Clip weights {tensor_name} to range [{-clip_range}, {clip_range}].')
else:
# fake quantize for activations
clip_ranges[node.input[0]] = name2data[node.input[1]]
@@ -181,8 +183,15 @@ def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nnie_deploy_model.onnx"):
for node in nodes_to_be_removed:
graph.node.remove(node)

self.gen_gfpq_param_file(graph, clip_ranges)
onnx.save(model, model_save_path)
gfpq_param_dict = self.gen_gfpq_param_file(graph, clip_ranges)

output_path = os.path.dirname(onnx_path)
filename = os.path.join(output_path, 'nnie_gfpq_param_dict.json')
with open(filename, 'w') as f:
json.dump({"nnie": {"gfpq_param_dict": gfpq_param_dict}}, f, indent=4)
filename = os.path.join(output_path, 'nnie_deploy_model.onnx')
onnx.save(model, filename)
logger.info("Finish deploy process.")

remove_fakequantize_and_collect_params_nnie = NNIE_process().remove_fakequantize_and_collect_params
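
A hedged sketch of invoking the NNIE post-processing helper exposed above; the ONNX path is a placeholder, and per this diff the gfpq param file and the deploy model are now written next to the input ONNX instead of the working directory:

# Illustrative direct call (path is a placeholder).
from mqbench.convert_onnx import remove_fakequantize_and_collect_params_nnie

remove_fakequantize_and_collect_params_nnie('./mqbench_model_quantized.onnx')
# Expected outputs in the same directory (based on the diff above):
#   ./nnie_gfpq_param_dict.json
#   ./nnie_deploy_model.onnx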

@@ -230,11 +239,11 @@ def deal_with_weight_fakequant(self, node, out2node, inp2node, named_initializer):
next_node.input[idx] = node.input[0]
return redundant_nodes

def deal_with_activation_fakequant(self, node, inp2node):
def deal_with_activation_fakequant(self, node, inp2node):
next_nodes = inp2node[node.output[0]]
for next_node, idx in next_nodes:
next_node.input[idx] = node.input[0]
return
return

def parse_qparams(self, node, name2data):
tensor_name, scale, zero_point = node.input[:3]
@@ -247,11 +256,11 @@ def parse_qparams(self, node, name2data):
qmin = qparams['quant_min']
qmax = qparams['quant_max']
else:
print(f'qmin and qmax are not found for <{node.name}>!')
logger.info(f'qmin and qmax are not found for <{node.name}>!')
return tensor_name, scale, zero_point, qmin, qmax

def clip_weight(self, node, name2data, named_initializer):
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
data = name2data[tensor_name]
clip_range_min = (qmin - zero_point) * scale
clip_range_max = (qmax - zero_point) * scale
@@ -260,18 +269,18 @@ def clip_weight(self, node, name2data, named_initializer):
for c in range(data.shape[0]):
new_data.append(np.clip(data[c], clip_range_min[c], clip_range_max[c]))
new_data = np.array(new_data)
print(f'clip weights {tensor_name} to per-cahnnel clip range.')
logger.info(f'Clip weights <{tensor_name}> to per-channel ranges.')
else:
new_data = np.clip(data, clip_range_min, clip_range_max)
print(f'clip weights {tensor_name} to range [{clip_range_min}, {clip_range_max}].')
logger.info(f'Clip weights <{tensor_name}> to range [{clip_range_min}, {clip_range_max}].')
new_data = numpy_helper.from_array(new_data)
named_initializer[tensor_name].raw_data = new_data.raw_data

def post_process_clip_ranges(self, clip_ranges, graph, inp2node):
def find_the_closest_clip_range(node):
if node.input[0] in clip_ranges:
return node.input[0]
elif node.op_type in ['Flatten', 'Resize']:
elif node.op_type in ['Flatten', 'Resize'] and node.output[0] in inp2node:
return find_the_closest_clip_range(inp2node[node.output[0]][0][0])
else:
return None
@@ -281,7 +290,7 @@ def find_the_closest_clip_range(node):
tensor_name = find_the_closest_clip_range(node)
if tensor_name:
clip_ranges[node.input[0]] = clip_ranges[tensor_name]
print(f'Pass <{tensor_name}> clip range to <{node.name}> input <{node.input[0]}>.')
logger.info(f'Pass <{tensor_name}> clip range to <{node.name}> input <{node.input[0]}>.')
return clip_ranges

def remove_fakequantize_and_collect_params(self, onnx_path, backend):
@@ -300,22 +309,22 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend):
nodes_to_be_removed = []
for node in graph.node:
if node.op_type in all_fakequantizer:
nodes_to_be_removed.append(node)
nodes_to_be_removed.append(node)
nodes_to_be_removed.extend(get_constant_inputs(node, out2node))

if node.op_type in perchannel_fakequantizer:
if node.op_type in perchannel_fakequantizer:
# fake quantize for weights, suppose per-channel quantize only for weight
redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer)
nodes_to_be_removed.extend(redundant_nodes)
self.clip_weight(node, name2data, named_initializer)
if backend == 'ppl':
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
clip_ranges[tensor_name] = {'step': [float(x) for x in scale],
'zero_point': [int(x) for x in zero_point],
'min': [float(x) for x in scale * (qmin - zero_point)],
'max': [float(x) for x in scale * (qmax - zero_point)],
'bit': int(np.log2(qmax - qmin + 1)),
'type': "biased",
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
clip_ranges[tensor_name] = {'step': [float(x) for x in scale],
'zero_point': [int(x) for x in zero_point],
'min': [float(x) for x in scale * (qmin - zero_point)],
'max': [float(x) for x in scale * (qmax - zero_point)],
'bit': int(np.log2(qmax - qmin + 1)),
'type': "biased",
}

elif node.op_type in pertensor_fakequantizer:
@@ -326,32 +335,33 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend):
if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']:
# fake quantize for weights
redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer)
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
nodes_to_be_removed.extend(redundant_nodes)
self.clip_weight(node, name2data, named_initializer)
else:
# fake quantize for activations
self.deal_with_activation_fakequant(node, inp2node)
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data)
for out in graph.output:
if out.name == node.output[0]:
out.name = tensor_name

if backend == 'tensorrt':
clip_ranges[tensor_name] = float(scale * min(-qmin, qmax))
elif backend == 'snpe':
clip_ranges[tensor_name] = {'bitwidth': int(np.log2(qmax - qmin + 1)),
'min': float(scale * (qmin - zero_point)),
'max': float(scale * (qmax - zero_point))
}
clip_ranges[tensor_name] = [
{'bitwidth': int(np.log2(qmax - qmin + 1)),
'min': float(scale * (qmin - zero_point)),
'max': float(scale * (qmax - zero_point))}
]
if backend == 'ppl':
clip_ranges[tensor_name] = {'step': float(scale),
'zero_point': int(zero_point),
'min': float(scale * (qmin - zero_point)),
'max': float(scale * (qmax - zero_point)),
'bit': int(np.log2(qmax - qmin + 1)),
'type': "biased",
}
clip_ranges[tensor_name] = {'step': float(scale),
'zero_point': int(zero_point),
'min': float(scale * (qmin - zero_point)),
'max': float(scale * (qmax - zero_point)),
'bit': int(np.log2(qmax - qmin + 1)),
'type': "biased",
}

for node in nodes_to_be_removed:
graph.node.remove(node)
@@ -363,9 +373,12 @@ def remove_fakequantize_and_collect_params(self, onnx_path, backend):
context = {'activation_encodings': clip_ranges, 'param_encodings': {}}
elif backend == 'ppl':
context = {"ppl": clip_ranges}
filename = os.path.join('./', '{}_clip_ranges.json'.format(backend))
output_path = os.path.dirname(onnx_path)
filename = os.path.join(output_path, '{}_clip_ranges.json'.format(backend))
with open(filename, 'w') as f:
json.dump(context, f, indent=4)
onnx.save(model, '{}_deploy_model.onnx'.format(backend))
filename = os.path.join(output_path, '{}_deploy_model.onnx'.format(backend))
onnx.save(model, filename)
logger.info("Finish deploy process.")

remove_fakequantize_and_collect_params = LinearQuantizer_process().remove_fakequantize_and_collect_params
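
For the TensorRT, SNPE, and PPL backends, a hedged sketch of the matching entry point; the path and backend name are placeholders, and per this diff the clip-range JSON and deploy model are likewise saved next to the input ONNX:

# Illustrative call for a non-NNIE backend (path and backend are placeholders).
from mqbench.convert_onnx import remove_fakequantize_and_collect_params

remove_fakequantize_and_collect_params('./mqbench_model_quantized.onnx', backend='snpe')
# Expected outputs in the same directory (based on the diff above):
#   ./snpe_clip_ranges.json    -- {'activation_encodings': {...}, 'param_encodings': {}}
#   ./snpe_deploy_model.onnx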