[Feature] Auto-add fake modules to call_function nodes #131

Open
wants to merge 2 commits into base: main
227 changes: 155 additions & 72 deletions mqbench/advanced_ptq.py
@@ -47,6 +47,31 @@ def qnode2fpnode(quant_modules, fp32_modules):
qnode2fpnode_dict = {quant_named_nodes[key]: fp32_named_nodes[key] for key in quant_named_nodes}
return qnode2fpnode_dict


def insert_fake_modules(model, node_list, direction):
    # Attach an Identity "fake module" at each block boundary so that a
    # call_module node exists there for hooks/observers to latch onto, even
    # when the boundary itself is a call_function node.
    graph = model.graph
    for node in node_list:
        # use node.name (always a dot-free string) rather than node.target,
        # which is a callable for call_function nodes
        inserted_node_target = node.name + '_fake_module_' + direction
        setattr(model, inserted_node_target, torch.nn.Identity())
        if direction == 'input':
            # mirror the node's inputs; the original node keeps its own args,
            # the Identity only observes them (note: Identity expects a single
            # tensor argument)
            with graph.inserting_before(node):
                inserted_node = graph.create_node(op='call_module',
                                                  name=inserted_node_target.replace('.', '_'),
                                                  target=inserted_node_target,
                                                  args=node.args,
                                                  kwargs=node.kwargs)
        elif direction == 'output':
            # pass the node's output through the Identity
            with graph.inserting_after(node):
                inserted_node = graph.create_node(op='call_module',
                                                  name=inserted_node_target.replace('.', '_'),
                                                  target=inserted_node_target,
                                                  args=(node,),
                                                  kwargs={})
    model.recompile()
    model.graph.lint()
    return model
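
# A minimal usage sketch of the helper above (illustrative only, not part of
# the patch; any fx-traceable module works):
def _demo_insert_fake_modules():
    import torch.fx
    net = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3)))
    conv_node = next(n for n in net.graph.nodes if n.op == 'call_module')
    net = insert_fake_modules(net, [conv_node], 'output')
    # the graph now contains an extra call_module node whose target is
    # '<conv_name>_fake_module_output', backed by a registered nn.Identity
    print(net.graph)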


def layer_has_weights(nodes, modules):
has_weights = False
for node in nodes:
@@ -233,14 +258,27 @@ def _flatten_args(node):
flattned_args.extend([node])
return flattned_args

def get_io_of_block(nodes):
used_list = []
input_list = []
for node in nodes:
if all([arg not in nodes for arg in _flatten_args(node.kwargs)]) and all([arg not in nodes for arg in _flatten_args(node.args)]):
input_list.append(node)
for arg in _flatten_args(node.kwargs):
if arg in nodes and arg not in used_list:
used_list.append(arg)
for arg in _flatten_args(node.args):
if arg in nodes and arg not in used_list:
used_list.append(arg)
output_list = [node for node in nodes if node not in used_list]
return input_list, output_list
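
# Illustrative behavior sketch (hypothetical node names): for a straight-line
# block conv -> bn -> relu whose conv reads activations produced outside the
# block and whose relu output is consumed outside it,
#   get_io_of_block([conv_node, bn_node, relu_node])
# returns ([conv_node], [relu_node]): conv has no producer inside the block,
# and relu is the only member no other block node consumes.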

def find_used_times(nodes, target):
used = len([_node for _node in target.users if _node in nodes])
return used
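
# e.g. (illustrative) find_used_times(layer_node_list, conv_node) counts how
# many consumers of conv_node's output live inside the current block.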


def find_cur_node(layer_node_list):
node_list = []
used_later = []
@@ -497,44 +535,7 @@ def extract_block(input_nodes, fp32_modules, depth=0):
return layer_node_list + exp_nodes + extract_block(
[exp_nodes[-1]], fp32_modules, depth + 1)


def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
r"""
    Reconstruction for AdaRound, BRECQ, QDrop.
Basic optimization objective:

.. math::

\mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}),

\tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)

where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing on :math:`\beta`, the rounding mask can adapt freely in initial phase and converge to 0 or 1 in later phase.

Args:
        model (torch.nn.Module): a prepared GraphModule to do PTQ
        cali_data (List): a list of calibration tensors
        config (dict): a config for PTQ reconstruction
        graph_module_list (list): a list of the model's child modules that need quantization; if provided, the model is partially quantized, otherwise it is fully quantized.

    >>> sample config : {
        pattern: block (str, available options are [layer, block])
        scale_lr: 4.0e-5 (learning rate for learning the step size of activation quantization)
        warm_up: 0.2 (the first 0.2 * max_count iters run without regularization toward floor or ceil)
        weight: 0.01 (loss weight of the regularization term)
        max_count: 20000 (number of optimization iterations)
        b_range: [20, 2] (beta decay range)
        keep_gpu: True (whether calibration data stays on GPU or is moved to CPU)
        round_mode: learned_hard_sigmoid (weight reconstruction method; currently only learned_hard_sigmoid is supported)
        prob: 0.5 (drop probability of QDrop)
    }

"""
    # the model is assumed to be on CUDA; calibration data is moved to CPU unless config.keep_gpu is set
if not config.keep_gpu:
cali_data = [to_device(inp, 'cpu') for inp in cali_data]
'''set state first'''

def prepare_fp_and_quant_model_for_ptq(model, graph_module_list):
fp32_model = model
fp32_model.eval()
if graph_module_list is None:
@@ -577,6 +578,50 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
enable_quantization(quant_model)
torch.cuda.empty_cache()
checked_nodes = dict()
return fp32_model, quant_model, nodes, g2node, fp32_modules, quant_modules, topology_order_by_node, qnode2fpnode_dict, checked_nodes

def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
r"""
    Reconstruction for AdaRound, BRECQ, QDrop.
Basic optimization objective:

.. math::

\mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}),

\tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)

where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing on :math:`\beta`, the rounding mask can adapt freely in initial phase and converge to 0 or 1 in later phase.

Args:
        model (torch.nn.Module): a prepared GraphModule to do PTQ
        cali_data (List): a list of calibration tensors
        config (dict): a config for PTQ reconstruction
        graph_module_list (list): a list of the model's child modules that need quantization; if provided, the model is partially quantized, otherwise it is fully quantized.

    >>> sample config : {
        pattern: block (str, available options are [layer, block])
        scale_lr: 4.0e-5 (learning rate for learning the step size of activation quantization)
        warm_up: 0.2 (the first 0.2 * max_count iters run without regularization toward floor or ceil)
        weight: 0.01 (loss weight of the regularization term)
        max_count: 20000 (number of optimization iterations)
        b_range: [20, 2] (beta decay range)
        keep_gpu: True (whether calibration data stays on GPU or is moved to CPU)
        round_mode: learned_hard_sigmoid (weight reconstruction method; currently only learned_hard_sigmoid is supported)
        prob: 0.5 (drop probability of QDrop)
    }
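
    A minimal sketch of constructing such a config (illustrative values;
    assumes an attribute-access dict such as easydict.EasyDict, since fields
    are read as config.keep_gpu, config.prob, ...):

    >>> from easydict import EasyDict
    >>> config = EasyDict({
    ...     'pattern': 'block', 'scale_lr': 4.0e-5, 'warm_up': 0.2,
    ...     'weight': 0.01, 'max_count': 20000, 'b_range': [20, 2],
    ...     'keep_gpu': True, 'round_mode': 'learned_hard_sigmoid',
    ...     'prob': 0.5,
    ... })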

"""
    # the model is assumed to be on CUDA; calibration data is moved to CPU unless config.keep_gpu is set
if not config.keep_gpu:
cali_data = [to_device(inp, 'cpu') for inp in cali_data]
'''set state first'''

fp32_model, quant_model, nodes, g2node, fp32_modules, quant_modules, topology_order_by_node, qnode2fpnode_dict, checked_nodes = \
prepare_fp_and_quant_model_for_ptq(model, graph_module_list)

    # set up the list of block node lists for reconstruction
block_list = []
for node in nodes:
if 'exclude_node_prefix' in config:
cont = False
@@ -633,41 +678,79 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
continue
logger.info('the node list is below!')
logger.info(layer_node_list)
fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
fp32_all_inps = []
quant_all_inps = []
fp32_final_oups = None
out_is_cached = False
for _node in layer_node_list:
if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
continue
else:
fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
quant_module = quant_modules[_node]
# fp32 inps: [out_b1, out_b2, ...]
_, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
_, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
_, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
fp32_all_inps.append(fp32_inps)
quant_all_inps.append(quant_inps)
if not out_is_cached:
fp32_final_oups = fp32_oups
out_is_cached = True
cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
cached_oups = fp32_final_oups
quant_modules_by_name = dict()
for node in layer_node_list:
if node.op == 'call_module':
quant_modules_by_name[node.target] = quant_modules[node]
subgraph = extract_subgraph(quant_modules_by_name, layer_node_list,
layer_node_list[-1], g2node)
logger.info(subgraph.code)
subgraph_reconstruction(subgraph, cached_inps, cached_oups, config)
block_list.append(layer_node_list)
for x in layer_node_list:
checked_nodes[x] = True

    # insert fake modules at each block's input/output boundary
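    # the pass below (1) records each block's boundary nodes, (2) deep-copies
    # the model and wraps every boundary with an Identity fake module, and
    # (3) re-prepares both models and extends each block's node list with the
    # new fake-module nodes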
fake_module_dict = {'input': [], 'output': []}
for idx, layer_node_list in enumerate(block_list):
input_nodes, output_nodes = get_io_of_block(layer_node_list)
        logger.info('block {}: input nodes {}, output nodes {}'.format(idx, input_nodes, output_nodes))
block_list[idx] = [node.name for node in layer_node_list]
        for onode in output_nodes:
            # record names, not node objects: node.name is the join key used to
            # look the boundaries up again in the deep-copied graph below
            if onode.name not in fake_module_dict['output']:
                fake_module_dict['output'].append(onode.name)
        for inode in input_nodes:
            if inode.name not in fake_module_dict['input']:
                fake_module_dict['input'].append(inode.name)
    # rebuild the model only when some block boundary needs a fake module
    if len(fake_module_dict['input']) > 0 or len(fake_module_dict['output']) > 0:
model_with_fake_module = deepcopy_graphmodule(model) if graph_module_list is None else deepcopy_mixedmodule(model, graph_module_list)
fake_module_dict['input'] = [node for node in model_with_fake_module.graph.nodes if node.name in fake_module_dict['input']]
fake_module_dict['output'] = [node for node in model_with_fake_module.graph.nodes if node.name in fake_module_dict['output']]
model_with_fake_module = insert_fake_modules(model_with_fake_module, fake_module_dict['input'], 'input')
model_with_fake_module = insert_fake_modules(model_with_fake_module, fake_module_dict['output'], 'output')
fp32_model, quant_model, nodes, g2node, fp32_modules, quant_modules, topology_order_by_node, qnode2fpnode_dict, checked_nodes = \
prepare_fp_and_quant_model_for_ptq(model_with_fake_module, graph_module_list)
for direction in fake_module_dict:
for node in fake_module_dict[direction]:
for idx, layer_node_list in enumerate(block_list):
if node.name in layer_node_list:
block_list[idx].append(node.name + '_fake_module_' + direction)
qname2node = {node.name: node for node in nodes}
for idx, layer_node_name_list in enumerate(block_list):
            layer_node_list = [qname2node[node_name] for node_name in layer_node_name_list]
            block_list[idx] = sorted(layer_node_list, key=lambda x: topology_order_by_node[x])

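    # second reconstruction pass: rerun block-wise reconstruction on the
    # rebuilt model, with each block's fake modules now part of its node list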
for layer_node_list in block_list:
fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
fp32_all_inps = []
quant_all_inps = []
fp32_final_oups = None
out_is_cached = False
for _node in layer_node_list:
if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
continue
else:
fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
quant_module = quant_modules[_node]
# fp32 inps: [out_b1, out_b2, ...]
_, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
_, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
_, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
fp32_all_inps.append(fp32_inps)
quant_all_inps.append(quant_inps)
if not out_is_cached:
fp32_final_oups = fp32_oups
out_is_cached = True
cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
cached_oups = fp32_final_oups
quant_modules_by_name = dict()
for node in layer_node_list:
if node.op == 'call_module':
quant_modules_by_name[node.target] = quant_modules[node]
subgraph = extract_subgraph(quant_modules_by_name, layer_node_list,
layer_node_list[-1], g2node)
logger.info(subgraph.code)
subgraph_reconstruction(subgraph, cached_inps, cached_oups, config)

disable_all(quant_model)
for node in checked_nodes:
if node.op == 'call_module':
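A minimal end-to-end sketch of driving this API (illustrative: the model, data, and EasyDict-style config follow MQBench's documented usage; adapt the names to your setup):

import torch
import torchvision
from easydict import EasyDict
from mqbench.prepare_by_platform import prepare_by_platform, BackendType
from mqbench.advanced_ptq import ptq_reconstruction
from mqbench.utils.state import enable_calibration, enable_quantization

model = torchvision.models.resnet18(pretrained=True).cuda().eval()
model = prepare_by_platform(model, BackendType.Tensorrt)  # insert fake-quant nodes
cali_data = [torch.randn(16, 3, 224, 224) for _ in range(8)]  # illustrative calibration batches
config = EasyDict({'pattern': 'block', 'scale_lr': 4.0e-5, 'warm_up': 0.2,
                   'weight': 0.01, 'max_count': 20000, 'b_range': [20, 2],
                   'keep_gpu': True, 'round_mode': 'learned_hard_sigmoid',
                   'prob': 0.5})

enable_calibration(model)  # collect quantization ranges first
with torch.no_grad():
    for batch in cali_data:
        model(batch.cuda())
model = ptq_reconstruction(model, cali_data, config)  # block-wise reconstruction
enable_quantization(model)  # simulate quantized inference from here on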