diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f215531 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +[MQBench](http://mqbench.tech/assets/docs/html/) diff --git a/application/imagenet_example/calibrator.py b/application/imagenet_example/calibrator.py new file mode 100644 index 0000000..4373cc2 --- /dev/null +++ b/application/imagenet_example/calibrator.py @@ -0,0 +1,57 @@ +import tensorrt as trt +import os +import pycuda.driver as cuda +import pycuda.autoinit +from PIL import Image +import numpy as np + +def load_imagenet_data(cali_data_loader): + dataset = [] + for i, (data, label) in enumerate(cali_data_loader): + data = data.numpy().astype(np.float32) + dataset.append(data) + return dataset + +class ImagenetCalibrator(trt.IInt8EntropyCalibrator2): + def __init__(self, cali_data_loader, cache_file): + # Whenever you specify a custom constructor for a TensorRT class, + # you MUST call the constructor of the parent explicitly. + trt.IInt8EntropyCalibrator2.__init__(self) + + self.cache_file = cache_file + + # Every time get_batch is called, the next batch of size batch_size will be copied to the device and returned. + self.data = load_imagenet_data(cali_data_loader) + self.batch_size = self.data[0].shape[0] + self.current_index = 0 + + # Allocate enough memory for a whole batch. + self.device_input = cuda.mem_alloc(self.data[0].nbytes) + + def get_batch_size(self): + return self.batch_size + + # TensorRT passes along the names of the engine bindings to the get_batch function. + # You don't necessarily have to use them, but they can be useful to understand the order of + # the inputs. The bindings list is expected to have the same ordering as 'names'. + def get_batch(self, names): + if self.current_index == len(self.data): + return None + + batch = self.data[self.current_index].ravel() + cuda.memcpy_htod(self.device_input, batch) + self.current_index += 1 + print('Calibrate batch = {} / {}'.format(self.current_index, len(self.data))) + return [self.device_input] + + + def read_calibration_cache(self): + # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None. 
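+ # NOTE: the cache lookup below is left commented out, so calibration always runs from scratch; uncomment it to reuse a previously written calibration cache file.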
+ # if os.path.exists(self.cache_file): + # with open(self.cache_file, "rb") as f: + # return f.read() + return None + + def write_calibration_cache(self, cache): + with open(self.cache_file, "wb") as f: + f.write(cache) diff --git a/application/imagenet_example/main.py b/application/imagenet_example/main.py new file mode 100644 index 0000000..efec9ca --- /dev/null +++ b/application/imagenet_example/main.py @@ -0,0 +1,491 @@ +import argparse +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +# import torchvision.models as models +import models +from mqbench.convert_deploy import convert_deploy +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.utils.state import enable_calibration, enable_quantization, disable_all + +model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('--train_data', metavar='DIR', + help='path to dataset', required=True) +parser.add_argument('--val_data', metavar='DIR', + help='path to dataset', required=True) +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=100, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', 
default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +parser.add_argument('--model_path', type=str, default=None) +parser.add_argument('--optim', type=str, default='sgd') +parser.add_argument('--not-quant', action='store_true') +parser.add_argument('--deploy', action='store_true') + + +best_acc1 = 0 + +def main(): + args = parser.parse_args() + args.quant = not args.not_quant + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + + # create model + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True, model_path=args.model_path) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + # quantize model + if args.quant: + model = prepare_qat_fx_by_platform(model, BackendType.Tensorrt) + + if not torch.cuda.is_available(): + print('using CPU, this will be slow') + elif args.distributed: + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. 
+ if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + else: + # DataParallel will divide and allocate batch_size to all available GPUs + if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): + model.features = torch.nn.DataParallel(model.features) + model.cuda() + else: + model = torch.nn.DataParallel(model).cuda() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda(args.gpu) + if args.optim == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + elif args.optim == 'adam': + optimizer = torch.optim.Adam(model.parameters(), args.lr, + betas=(0.9, 0.999), eps=1e-08, + weight_decay=args.weight_decay, + amsgrad=False) + + # prepare dataset + train_loader, train_sampler, val_loader, cali_loader = prepare_dataloader(args) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + + state_dict = checkpoint['state_dict'] + model_dict = model.state_dict() + if 'module.' in list(state_dict.keys())[0] and 'module.' 
not in list(model_dict.keys())[0]: + for k in list(state_dict.keys()): + state_dict[k[7:]] = state_dict.pop(k) + + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {}), acc = {}" + .format(args.resume, checkpoint['epoch'], best_acc1)) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + elif args.quant: + enable_calibration(model) + calibrate(cali_loader, model, args) + + cudnn.benchmark = True + + if args.quant: + enable_quantization(model) + + if args.quant and args.deploy: + convert_deploy(model.eval(), BackendType.Tensorrt, input_shape_dict={'data': [10, 3, 224, 224]}) + return + + if args.evaluate: + from mqbench.convert_deploy import convert_merge_bn + convert_merge_bn(model.eval()) + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if not args.multiprocessing_distributed or (args.multiprocessing_distributed + and args.rank % ngpus_per_node == 0): + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer' : optimizer.state_dict(), + }, is_best) + +def prepare_dataloader(args): + traindir = os.path.join(args.train_data, 'train') + valdir = os.path.join(args.val_data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + cali_batch_size = 10 + cali_batch = 10 + cali_dataset = torch.utils.data.Subset(train_dataset, indices=torch.arange(cali_batch_size * cali_batch)) + cali_loader = torch.utils.data.DataLoader(cali_dataset, batch_size=cali_batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + return train_loader, train_sampler, val_loader, cali_loader + +def calibrate(cali_loader, model, args): + model.eval() + print("Start calibration ...") + print("Calibrate images number = ", len(cali_loader.dataset)) + with torch.no_grad(): + for i, (images, target) in enumerate(cali_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + output = model(images) + print("Calibration ==> ", i+1) + print("End calibration.") + return + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = 
AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if torch.cuda.is_available(): + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + if torch.cuda.is_available(): + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + 
str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/application/imagenet_example/models/__init__.py b/application/imagenet_example/models/__init__.py new file mode 100644 index 0000000..00e8c36 --- /dev/null +++ b/application/imagenet_example/models/__init__.py @@ -0,0 +1,2 @@ +from .resnet import * +from .mobilenetv2 import * \ No newline at end of file diff --git a/application/imagenet_example/models/mobilenetv2.py b/application/imagenet_example/models/mobilenetv2.py new file mode 100644 index 0000000..f2466bb --- /dev/null +++ b/application/imagenet_example/models/mobilenetv2.py @@ -0,0 +1,222 @@ +import torch +from torch import nn +from torch import Tensor +# from .._internally_replaced_utils import load_state_dict_from_url +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url +from typing import Callable, Any, Optional, List + + +__all__ = ['MobileNetV2', 'mobilenet_v2'] + + +model_urls = { + 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', +} + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
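+ # e.g. _make_divisible(11.2, 8): the nearest multiple of 8 is 8, which is below 0.9 * 11.2, so the result is bumped up to 16.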
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNActivation(nn.Sequential): + def __init__( + self, + in_planes: int, + out_planes: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Optional[Callable[..., nn.Module]] = None, + dilation: int = 1, + ) -> None: + padding = (kernel_size - 1) // 2 * dilation + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if activation_layer is None: + activation_layer = nn.ReLU6 + super().__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, + bias=False), + norm_layer(out_planes), + activation_layer(inplace=True) + ) + self.out_channels = out_planes + + +# necessary for backwards compatibility +ConvBNReLU = ConvBNActivation + + +class InvertedResidual(nn.Module): + def __init__( + self, + inp: int, + oup: int, + stride: int, + expand_ratio: int, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers: List[nn.Module] = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ]) + self.conv = nn.Sequential(*layers) + self.out_channels = oup + self._is_cn = stride > 1 + + def forward(self, x: Tensor) -> Tensor: + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__( + self, + num_classes: int = 1000, + width_mult: float = 1.0, + inverted_residual_setting: Optional[List[List[int]]] = None, + round_nearest: int = 8, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + norm_layer: Module specifying the normalization layer to use + """ + super(MobileNetV2, self).__init__() + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, 
width_mult), round_nearest) + features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + x = self.features(x) + # Cannot use "squeeze" as batch-size can be 1 + x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def mobilenet_v2(pretrained: bool = False, progress: bool = True, model_path: str = None, **kwargs: Any) -> MobileNetV2: + """ + Constructs a MobileNetV2 architecture from + `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + model = MobileNetV2(**kwargs) + if pretrained: + if model_path: + state_dict = torch.load(model_path) + print(f'load pretrained checkpoint from: {model_path}') + else: + state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], + progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/application/imagenet_example/models/resnet.py b/application/imagenet_example/models/resnet.py new file mode 100644 index 0000000..fddd171 --- /dev/null +++ b/application/imagenet_example/models/resnet.py @@ -0,0 +1,388 @@ +import torch +from torch import Tensor +import torch.nn as nn +# from .._internally_replaced_utils import load_state_dict_from_url +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = 
self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, 
num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + model_path: str, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + if model_path: + state_dict = torch.load(model_path) + print(f'load pretrained checkpoint from: {model_path}') + else: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, model_path: str = None, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, model_path, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, model_path: str = None, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, model_path, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. 
+ The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + \ No newline at end of file diff --git a/application/imagenet_example/onnx2trt.py b/application/imagenet_example/onnx2trt.py new file mode 100644 index 0000000..69747d5 --- /dev/null +++ b/application/imagenet_example/onnx2trt.py @@ -0,0 +1,176 @@ +import onnx +import pycuda.autoinit # noqa F401 +import tensorrt as trt +import torch +import json +import pycuda.driver as cuda +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import os +import numpy as np +import argparse + +from main import AverageMeter, accuracy + +def onnx2trt(onnx_model, + trt_path, + dataset_path, + batch_size=1, + cali_batch=10, + log_level=trt.Logger.ERROR, + max_workspace_size=1 << 30, + device_id=0, + mode='fp32', + dynamic_range_file=None): + if os.path.exists(trt_path): + print(f'The "{trt_path}" exists. Remove it and continue.') + os.remove(trt_path) + + device = torch.device('cuda:{}'.format(device_id)) + + # create builder and network + logger = trt.Logger(log_level) + builder = trt.Builder(logger) + EXPLICIT_BATCH = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + + # parse onnx + parser = trt.OnnxParser(network, logger) + + if isinstance(onnx_model, str): + onnx_model = onnx.load(onnx_model) + + if not parser.parse(onnx_model.SerializeToString()): + error_msgs = '' + for error in range(parser.num_errors): + error_msgs += f'{parser.get_error(error)}\n' + raise RuntimeError(f'parse onnx failed:\n{error_msgs}') + + config = builder.create_builder_config() + config.max_workspace_size = max_workspace_size + + if mode == 'int8': + config.set_flag(trt.BuilderFlag.INT8) + if dynamic_range_file: + with open(dynamic_range_file, 'r') as f: + dynamic_range = json.load(f)['tensorrt']['blob_range'] + + for input_index in range(network.num_inputs): + input_tensor = network.get_input(input_index) + if input_tensor.name in dynamic_range: + amax = dynamic_range[input_tensor.name] + input_tensor.dynamic_range = (-amax, amax) + print(f'Set dynamic range of {input_tensor.name} as [{-amax}, {amax}]') + + for layer_index in range(network.num_layers): + layer = network[layer_index] + output_tensor = layer.get_output(0) + if output_tensor.name in dynamic_range: + amax = dynamic_range[output_tensor.name] + output_tensor.dynamic_range = (-amax, amax) + print(f'Set dynamic range of {output_tensor.name} as [{-amax}, {amax}]') + else: + from calibrator import ImagenetCalibrator + calidir = os.path.join(dataset_path, 'cali') + dataset = datasets.ImageFolder(calidir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])])) + cali_num = min(len(dataset), batch_size * cali_batch) + cali_dataset = torch.utils.data.Subset(dataset, indices=torch.arange(cali_num)) + cali_loader = torch.utils.data.DataLoader(cali_dataset, batch_size=batch_size, 
shuffle=False, + num_workers=1, pin_memory=False) + calibrator = ImagenetCalibrator(cali_loader, cache_file='imagenet.cache') + config.int8_calibrator = calibrator + print(f'Calibration Set!') + + # create engine + with torch.cuda.device(device): + engine = builder.build_engine(network, config) + + with open(trt_path, mode='wb') as f: + f.write(bytearray(engine.serialize())) + return engine + +def infer(engine, img, batch_size, context): + h_input = img + h_output = cuda.pagelocked_empty(batch_size * trt.volume(engine.get_binding_shape(1)[1:]), dtype=np.float32) + # Allocate device memory for inputs and outputs. + d_input = cuda.mem_alloc(4 * trt.volume(engine.get_binding_shape(0))) + d_output = cuda.mem_alloc(4 * trt.volume(engine.get_binding_shape(1))) + # Transfer input data to the GPU. + cuda.memcpy_htod(d_input, h_input) + # Run inference. + context.execute_v2(bindings=[int(d_input), int(d_output)]) + # Return the host output. + cuda.memcpy_dtoh(h_output, d_output) + d_input.free() + d_output.free() + + return h_output.reshape(batch_size, *engine.get_binding_shape(1)[1:]) + +def validate(trt_file, batch_size=64, dataset_path=None): + # deserialize engine + trt_logger = trt.Logger(trt.Logger.INFO) + + with trt.Runtime(trt_logger) as runtime: + with open(trt_file, 'rb') as f: + engine = runtime.deserialize_cuda_engine(f.read()) + context = engine.create_execution_context() + + # prepare dateset + valdir = os.path.join(dataset_path, 'val') + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ])), + batch_size=batch_size, shuffle=False, + num_workers=4, pin_memory=False) + + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + # evaluate + for index, (images, target) in enumerate(val_loader): + images = images.detach().numpy() + output = infer(engine, images, len(images), context) + output = torch.from_numpy(output) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], images.shape[0]) + top5.update(acc5[0], images.shape[0]) + + if index % 100 == 0: + print(f' {index} ==> * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) + print(f' Final ==> * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Onnx to tensorrt') + parser.add_argument('--onnx-path', type=str, default=None) + parser.add_argument('--trt-path', type=str, default=None) + parser.add_argument('--mode', choices=['fp32', 'int8'], default='int8') + parser.add_argument('--clip-range-file', type=str, default=None) + parser.add_argument('--batch-size', type=int, default=10) + parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--verbose', action='store_true') + parser.add_argument('--data-path', type=str, required=True) + args = parser.parse_args() + + if args.onnx_path: + onnx2trt(args.onnx_path, + trt_path=args.trt_path, + mode=args.mode, + dataset_path=args.data_path, + batch_size=10, + cali_batch=10, + log_level=trt.Logger.VERBOSE if args.verbose else trt.Logger.ERROR, + dynamic_range_file=args.clip_range_file) + if args.evaluate: + validate(args.trt_path, batch_size=args.batch_size, dataset_path=args.data_path) + diff --git a/application/imagenet_example/readme.md b/application/imagenet_example/readme.md new file mode 100644 
index 0000000..fd335d1
--- /dev/null
+++ b/application/imagenet_example/readme.md
@@ -0,0 +1,58 @@
+# MQBench Example with ImageNet
+
+We follow the PyTorch [official example](https://github.com/pytorch/examples/tree/master/imagenet) to build the Model Quantization Benchmark example for the ImageNet classification task.
+
+## Requirements
+
+- Install PyTorch ([pytorch.org](http://pytorch.org))
+- `pip install -r requirements.txt`
+- Download the ImageNet dataset from http://www.image-net.org/.
+  - Then move the validation images to labeled subfolders, using [the following shell script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
+
+- Install TensorRT 7.2.1.6 from https://developer.nvidia.com/tensorrt.
+
+## Usage
+
+- **Quantization-Aware Training:**
+
+  - Training hyper-parameters:
+    - batch size = 128
+    - epochs = 1
+    - lr = 1e-4
+    - others, like weight decay and momentum, are kept at their defaults.
+
+  - [model_name] = resnet18 / resnet50 / mobilenet_v2 / ...
+
+  ```
+  python main.py -a [model_name] --epochs 1 --lr 1e-4 --b 128 --seed 99 --pretrained
+  ```
+
+- **Deployment**
+  We provide an example of deploying the quantized model to TensorRT.
+
+  1. First, export the quantized model to ONNX [tensorrt_deploy_model.onnx] and dump the activation clip ranges [tensorrt_clip_ranges.json].
+
+  ```
+  python main.py -a [model_name] --resume [model_save_path]
+  ```
+
+  2. Second, build the TensorRT INT8 engine and evaluate; please make sure [dataset_path] contains the subfolder [val].
+
+  ```
+  python onnx2trt.py --onnx [tensorrt_deploy_model.onnx] --trt [model_name.trt] --clip [tensorrt_clip_ranges.json] --data [dataset_path] --evaluate
+  ```
+
+  If you don't pass in external clip ranges [tensorrt_clip_ranges.json], TensorRT will run calibration with its default IInt8EntropyCalibrator2 algorithm on 100 images, so please make sure [dataset_path] contains the subfolder [cali].
+
+  ```
+  python onnx2trt.py --onnx [tensorrt_deploy_model.onnx] --trt [model_name.trt] --data [dataset_path] --evaluate
+  ```
+
+## Results
+
+| Model | accuracy@fp32 | accuracy@int8
TensorRT Calibration | accuracy@int8
MQBench QAT | accuracy@int8
TensorRT SetRange | +| :--------------- | :------------------------- | :----------------------------------- | :---------------------------- | :---------------------------------- | +| **ResNet18** | Acc@1 69.758 Acc@5 89.078 | Acc@1 69.612 Acc@5 88.980 | Acc@1 69.912 Acc@5 89.150 | Acc@1 69.904 Acc@5 89.182 | +| **ResNet50** | Acc@1 76.130 Acc@5 92.862 | Acc@1 76.074 Acc@5 92.892 | Acc@1 76.114 Acc@5 92.946 | Acc@1 76.320 Acc@5 93.006 | +| **MobileNet_v2** | Acc@1 71.878 Acc@5 90.286 | Acc@1 70.700 Acc@5 89.708 | Acc@1 70.826 Acc@5 89.874 | Acc@1 70.724 Acc@5 89.870 | + diff --git a/application/imagenet_example/requirements.txt b/application/imagenet_example/requirements.txt new file mode 100644 index 0000000..30efe1c --- /dev/null +++ b/application/imagenet_example/requirements.txt @@ -0,0 +1,3 @@ +torch>=1.8.1 +torchvision +pycuda \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html new file mode 100644 index 0000000..9bbd135 --- /dev/null +++ b/docs/_templates/layout.html @@ -0,0 +1,4 @@ +{% extends "!layout.html" %} +{% block extrahead %} + +{% endblock %} diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..6247f7e --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/_static/css/theme_overrides.css b/docs/source/_static/css/theme_overrides.css new file mode 100644 index 0000000..52bb2e3 --- /dev/null +++ b/docs/source/_static/css/theme_overrides.css @@ -0,0 +1,33 @@ +@import url("theme.css"); + +/* override table width restrictions */ +@media screen and (min-width: 767px) { + + /* for header */ + .wy-table-responsive table p { + text-align: center !important; + vertical-align: center !important; + } + + .wy-table-responsive table td { + /* !important prevents the common CSS stylesheets from overriding + this as on RTD they are loaded after this stylesheet */ + white-space: normal !important; + word-wrap: break-word !important; + /* word-break: break-all !important; */ + + text-align: center !important; + vertical-align: center !important; + } + + /* for list */ + .wy-table-responsive table td ol { + display: table-cell; !important; /* to enable the style setting */ + text-align: center !important; + vertical-align: center !important; + } + + .wy-table-responsive { + overflow: visible !important; + } +} diff --git a/docs/source/_static/images/pipeline.png b/docs/source/_static/images/pipeline.png new file mode 100644 index 0000000..eb1c401 Binary files /dev/null and b/docs/source/_static/images/pipeline.png differ diff --git a/docs/source/algorithm/index.rst b/docs/source/algorithm/index.rst new file mode 100644 index 0000000..0e138d2 --- /dev/null +++ b/docs/source/algorithm/index.rst @@ -0,0 +1,107 @@ +Quantization Algorithm +=========================== + +.. _LSQ: https://arxiv.org/abs/1902.08153 +.. _LSQ plus: https://arxiv.org/abs/2004.09576 +.. _DSQ: https://arxiv.org/abs/1908.05033 +.. _PACT: https://arxiv.org/abs/1805.06085 +.. _APoT: https://arxiv.org/abs/1909.13144 +.. _opensource codes: https://github.com/yhhhli/APoT_Quantization +.. _weight standardization: https://github.com/joe-siyuan-qiao/WeightStandardization +.. _QIL: https://arxiv.org/abs/1808.05779 +.. _AdaRound: https://arxiv.org/abs/2004.10568 + + +Post-training Quantization v.s. Quantization-aware Training +----------------------------------------------------------------------- + +1. Post Training Quantization (PTQ): + + Quantize a pre-trained network with limited data and computation resources, including activation range estimation, bn statistics update and other tuning techniques. + +2. Quantization Aware Training (QAT): + + End-to-end Finetuning a pre-trained full-precision model, this requires all training data and huge computation resource. + +QAT Algorithms +--------------------------------- + +**Learned Step Size Quantization**: + +`LSQ`_ leverages the Straight-Through Estimator (i.e. directly pass the gradient in round operation) to learn the quantization scale for each layer. +Please refer to the original paper for detailed derivation of the scale gradient. +For initialization, we use the method proposed in original paper: the scale is determined by :math:`s= \frac{2||\mathbf{w}||_1}{\sqrt{N_{max}}}`. For symmetric quantization, the zero point is initialized to 0, and kept fixed. For asymmetric quantization, zero point is initialized to :math:`N_{min}` if the activation is non-negative. Inspired by `LSQ plus`_, the zero point can also be updated through backpropagation with the help of STE. 
Therefore we make it learnable in asymmetric quantization.
+LSQ uses a gradient scale to stabilize the learning of the scale. The gradient scale is determined by :math:`\frac{1}{\sqrt{MN_{max}}}`, where :math:`M` is the number of elements in that tensor. We extend this gradient scale to per-channel weight learning, where :math:`M` is the number of weights in each filter.
+
+
+**Differentiable Soft Quantization**:
+
+`DSQ`_ uses the hyperbolic tangent function to approximate the conventionally adopted STE. In our implementation, we use :math:`\alpha=0.4` (for the definition please refer to the original paper), which controls the shape and smoothness of the :math:`\mathrm{tanh}` function. For weight quantization, we use the min-max range as
+
+.. raw:: latex html
+
+   \[Clip_{min} = \mu(\mathbf{w}) - 2.6\sigma(\mathbf{w}) \]
+   \[Clip_{max} = \mu(\mathbf{w}) + 2.6\sigma(\mathbf{w}) \]
+
+
+where :math:`\mu(\cdot)` and :math:`\sigma(\cdot)` compute the mean and standard deviation of the tensor. Then, the scale is determined by :math:`s=\frac{\max(-Clip_{min}, Clip_{max})}{N_{max}-N_{min}}` for symmetric quantization, and by :math:`\frac{Clip_{max}-Clip_{min}}{N_{max}-N_{min}}` for asymmetric quantization. The zero point is set to 0 for symmetric quantization and to :math:`N_{min}-\lfloor \frac{Clip_{min}}{s}\rceil` for asymmetric quantization. For activations, we use BatchMinMax as the clipping range, i.e. the min-max range averaged across the batch dimension. This range is further updated with an exponential moving average across batches with momentum 0.9, similar to Batch Normalization.
+
+**Parameterized Clipping Activation**:
+
+`PACT`_ quantizes activations by learning the clipping threshold through STE. The activation is first clipped by a parameter :math:`\alpha`; the clipped activation is then quantized and de-quantized. Although PACT and LSQ both learn the scale, they differ in three ways. First, the clipping range in PACT is hand-crafted and initialized to 6, while LSQ initializes it based on the tensor :math:`L1` norm. Second, PACT receives no gradient from values inside the clipping range, while LSQ does. Third, PACT does not scale the gradient of :math:`\alpha`, while LSQ does.
+Note that PACT originally supports only non-negative, unsigned quantization. To extend it to our hardware settings, we clip the activation to :math:`(-\alpha, \alpha)` in the symmetric case and to :math:`(\beta, \alpha)` in the asymmetric case (where :math:`\beta` is initialized to :math:`-6`).
+For weight quantization, PACT is the same as DoReFa-Net.
+
+**DoReFa-Net**:
+
+DoReFa-Net simply clips the activation to :math:`[0, 1]` and then quantizes it, based on the intuition that most activations fall into this range in older network architectures, e.g. AlexNet and ResNet. In our hardware settings, we modify the activation range to :math:`[-1, 1]` for both symmetric and asymmetric quantization. Weight quantization can be described as:
+
+.. raw:: latex html
+
+   \[\tilde{\mathbf{w}} = \mathrm{tanh}(\mathbf{w}) \frac{1}{\max(|\mathrm{tanh}(\mathbf{w})|)} \]
+   \[\hat{\mathbf{w}} = \mathrm{dequantize}(\mathrm{quantize(\tilde{\mathbf{w}})}) \]
+
+where the first step is a non-linear transformation and the second step is the quantization and de-quantization. The scale is simply :math:`\frac{2}{N_{max}-N_{min}}` for symmetric quantization and :math:`\frac{\max(\tilde{\mathbf{w}}) - \min(\tilde{\mathbf{w}})}{N_{max}-N_{min}}` for asymmetric quantization.
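+
+As a rough illustration of the DoReFa-Net weight path above, the following is a minimal PyTorch sketch (not MQBench's actual ``mqbench.fake_quantize.dorefa`` implementation; the function name and explicit bit-width argument are ours) that applies the tanh transform followed by symmetric fake quantization::
+
+   import torch
+
+   def dorefa_weight_fakequant(w: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
+       # Non-linear transform: tanh, then normalize into [-1, 1].
+       w_tilde = torch.tanh(w) / torch.tanh(w).abs().max()
+       # Signed integer range, e.g. [-128, 127] for 8 bits.
+       n_min, n_max = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
+       scale = 2.0 / (n_max - n_min)  # symmetric scale, as above
+       # Quantize and de-quantize (fake quantization).
+       w_q = torch.clamp(torch.round(w_tilde / scale), n_min, n_max)
+       return w_q * scale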
+
+
+**Additive Powers-of-Two Quantization**:
+
+`APoT`_ quantization combines multiple PoT (Powers-of-Two) terms to compose a set of non-uniform quantization levels. Since the quantization levels are non-uniform in most cases (except for 2-bit, where APoT degenerates to uniform quantization), we do not benchmark it on real hardware. Additionally, APoT introduces weight normalization (similar to the `weight standardization`_ technique) to smooth the learning of the weight clipping range. However, it is unclear how to incorporate this technique with BN folding.
+Therefore, we only reproduce it in our academic setting. The implementation is based on the `opensource codes`_.
+
+
+
+**Quantization Interval Learning**:
+
+`QIL`_ is composed of two units: (1) the first is called the transformer, which maps the weights or activations to :math:`[-1, 1]` (:math:`[0, 1]` for non-negative activations).
+This transformer also has two functionalities: pruning and non-linearity.
+(2) The second is called the quantizer, given by
+
+.. raw:: latex html
+
+   \[ \tilde{\mathbf{w}} = \mathrm{clip}\left((\alpha |\mathbf{w}| + \beta)^{\gamma}, 0, 1\right) * \mathrm{sign}(\mathbf{w})\]
+   \[ \hat{\mathbf{w}} = \mathrm{dequantize}(\mathrm{quantize(\tilde{\mathbf{w}})}), \]
+
+where :math:`\alpha = \frac{1}{2*D}` and :math:`\beta=-\frac{C}{2D}+\frac{1}{2}`. This transformation maps the weights from :math:`[C-D, C+D]` to :math:`[0, 1]` and from :math:`[-C-D, -C+D]` to :math:`[-1, 0]`. As a result, the weights between :math:`[-C+D, C-D]` are pruned. The non-linearity of the transformation is introduced by :math:`\gamma`, which controls the linearity and thus the quantization interval. However, we find this technique extremely unstable: in our experimental reproduction, learning :math:`\gamma` does not converge. In the original paper, the gradient scale of :math:`C` and :math:`D` is set to 0.01; we find this gradient scale also leads to frequent crashes. Thus we use the gradient scale introduced in LSQ, i.e. :math:`\frac{1}{\sqrt{MN_{max}}}`.
+
+
+PTQ Algorithms
+------------------------------
+
+**AdaRound**:
+
+`AdaRound`_ aims to find a globally optimal strategy for rounding the quantized values. Intuitively, rounding-to-nearest is optimal for each individual value, but a theoretical analysis of the quantization loss shows this is not the case for the entire network or even a whole layer. The second-order term of the loss difference contains cross terms of the rounding errors, illustrated here for a layer with two weights:
+
+.. raw:: latex html
+
+   \[ E[ L(x,y,\mathbf{w}) - L(x,y,\mathbf{w}+\Delta \mathbf{w}) ] \approx \Delta \mathbf{w}^T g^{(\mathbf{w})} + \frac12 \Delta \mathbf{w}^T H^{(\mathbf{w})} \Delta \mathbf{w} \approx \Delta \mathbf{w}_1^2 + \Delta \mathbf{w}_2^2 + \Delta \mathbf{w}_1 \Delta \mathbf{w}_2 \]
+
+Hence, it is beneficial to learn a rounding mask for each layer. The authors give a well-designed objective function:
+
+.. raw:: latex html
+
+   \[ \mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}), \]
+   \[ \tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right) \]
+
+where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing :math:`\beta`, the rounding mask can adapt freely in the initial phase and converge to 0 or 1 in the later phase.
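+
+To make the objective concrete, here is a minimal PyTorch sketch of :math:`h(\mathbf{V})` and :math:`f_{reg}(\mathbf{V})` as defined above (the helper names are ours for illustration; this repository's ``mqbench/adaround.py`` implements the same formulas, with :math:`\zeta=1.1` and :math:`\gamma=-0.1` by default)::
+
+   import torch
+
+   def h(v, zeta=1.1, gamma=-0.1):
+       # Rectified sigmoid: soft rounding mask in [0, 1].
+       return torch.clamp(torch.sigmoid(v) * (zeta - gamma) + gamma, 0, 1)
+
+   def f_reg(v, beta=20.0):
+       # Pushes each mask entry towards 0 or 1 as beta is annealed.
+       return (1 - (2 * h(v) - 1).abs().pow(beta)).sum()
+
+   def soft_quant(w, v, s, n, p):
+       # \tilde{W} = s * clip(floor(W / s) + h(V), n, p)
+       return s * torch.clamp(torch.floor(w / s) + h(v), n, p)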
+ diff --git a/docs/source/api/modules.rst b/docs/source/api/modules.rst new file mode 100644 index 0000000..ee854aa --- /dev/null +++ b/docs/source/api/modules.rst @@ -0,0 +1,7 @@ +API Reference +============== + +.. toctree:: + :maxdepth: 4 + + mqbench diff --git a/docs/source/api/mqbench.fake_quantize.rst b/docs/source/api/mqbench.fake_quantize.rst new file mode 100644 index 0000000..c47214c --- /dev/null +++ b/docs/source/api/mqbench.fake_quantize.rst @@ -0,0 +1,69 @@ +mqbench.fake\_quantize package +============================== + +Submodules +---------- + +mqbench.fake\_quantize.dorefa +------------------------------------ + +.. automodule:: mqbench.fake_quantize.dorefa + :members: + :undoc-members: + :show-inheritance: + +mqbench.fake\_quantize.dsq +--------------------------------- + +.. automodule:: mqbench.fake_quantize.dsq + :members: + :undoc-members: + :show-inheritance: + +mqbench.fake\_quantize.fixed +----------------------------------- + +.. automodule:: mqbench.fake_quantize.fixed + :members: + :undoc-members: + :show-inheritance: + +mqbench.fake\_quantize.lsq +--------------------------------- + +.. automodule:: mqbench.fake_quantize.lsq + :members: + :undoc-members: + :show-inheritance: + +mqbench.fake\_quantize.nnie +---------------------------------- + +.. automodule:: mqbench.fake_quantize.nnie + :members: + :undoc-members: + :show-inheritance: + +mqbench.fake\_quantize.pact +---------------------------------- + +.. automodule:: mqbench.fake_quantize.pact + :members: + :undoc-members: + :show-inheritance: + +mqbench.fake\_quantize.quantize\_base +-------------------------------------------- + +.. automodule:: mqbench.fake_quantize.quantize_base + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mqbench.fake_quantize + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.nn.intrinsic.modules.rst b/docs/source/api/mqbench.nn.intrinsic.modules.rst new file mode 100644 index 0000000..31ad8ea --- /dev/null +++ b/docs/source/api/mqbench.nn.intrinsic.modules.rst @@ -0,0 +1,21 @@ +mqbench.nn.intrinsic.modules package +==================================== + +Submodules +---------- + +mqbench.nn.intrinsic.modules.fused +----------------------------------------- + +.. automodule:: mqbench.nn.intrinsic.modules.fused + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mqbench.nn.intrinsic.modules + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.nn.intrinsic.qat.modules.rst b/docs/source/api/mqbench.nn.intrinsic.qat.modules.rst new file mode 100644 index 0000000..65ded79 --- /dev/null +++ b/docs/source/api/mqbench.nn.intrinsic.qat.modules.rst @@ -0,0 +1,21 @@ +mqbench.nn.intrinsic.qat.modules package +======================================== + +Submodules +---------- + +mqbench.nn.intrinsic.qat.modules.linear\_fused +----------------------------------------------------- + +.. automodule:: mqbench.nn.intrinsic.qat.modules.linear_fused + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mqbench.nn.intrinsic.qat.modules + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.nn.intrinsic.qat.rst b/docs/source/api/mqbench.nn.intrinsic.qat.rst new file mode 100644 index 0000000..5e13f06 --- /dev/null +++ b/docs/source/api/mqbench.nn.intrinsic.qat.rst @@ -0,0 +1,18 @@ +mqbench.nn.intrinsic.qat package +================================ + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mqbench.nn.intrinsic.qat.modules + +Module contents +--------------- + +.. automodule:: mqbench.nn.intrinsic.qat + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.nn.intrinsic.rst b/docs/source/api/mqbench.nn.intrinsic.rst new file mode 100644 index 0000000..b68bedf --- /dev/null +++ b/docs/source/api/mqbench.nn.intrinsic.rst @@ -0,0 +1,19 @@ +mqbench.nn.intrinsic package +============================ + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mqbench.nn.intrinsic.modules + mqbench.nn.intrinsic.qat + +Module contents +--------------- + +.. automodule:: mqbench.nn.intrinsic + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.nn.rst b/docs/source/api/mqbench.nn.rst new file mode 100644 index 0000000..3e580c9 --- /dev/null +++ b/docs/source/api/mqbench.nn.rst @@ -0,0 +1,18 @@ +mqbench.nn package +================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mqbench.nn.intrinsic + +Module contents +--------------- + +.. automodule:: mqbench.nn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.rst b/docs/source/api/mqbench.rst new file mode 100644 index 0000000..da91af1 --- /dev/null +++ b/docs/source/api/mqbench.rst @@ -0,0 +1,95 @@ +mqbench package +=============== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mqbench.fake_quantize + mqbench.nn + mqbench.utils + +Submodules +---------- + +mqbench.adaround +----------------------- + +.. automodule:: mqbench.adaround + :members: + :undoc-members: + :show-inheritance: + +mqbench.convert\_deploy +------------------------------ + +.. automodule:: mqbench.convert_deploy + :members: + :undoc-members: + :show-inheritance: + +mqbench.convert\_onnx +---------------------------- + +.. automodule:: mqbench.convert_onnx + :members: + :undoc-members: + :show-inheritance: + +mqbench.custom\_quantizer +-------------------------------- + +.. automodule:: mqbench.custom_quantizer + :members: + :undoc-members: + :show-inheritance: + +mqbench.custom\_symbolic\_opset +-------------------------------------- + +.. automodule:: mqbench.custom_symbolic_opset + :members: + :undoc-members: + :show-inheritance: + +mqbench.fuser\_method\_mappings +-------------------------------------- + +.. automodule:: mqbench.fuser_method_mappings + :members: + :undoc-members: + :show-inheritance: + +mqbench.fusion\_method +----------------------------- + +.. automodule:: mqbench.fusion_method + :members: + :undoc-members: + :show-inheritance: + +mqbench.observer +----------------------- + +.. automodule:: mqbench.observer + :members: + :undoc-members: + :show-inheritance: + +mqbench.prepare\_by\_platform +------------------------------------ + +.. automodule:: mqbench.prepare_by_platform + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mqbench + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/mqbench.utils.rst b/docs/source/api/mqbench.utils.rst new file mode 100644 index 0000000..c7e393f --- /dev/null +++ b/docs/source/api/mqbench.utils.rst @@ -0,0 +1,45 @@ +mqbench.utils package +===================== + +Submodules +---------- + +mqbench.utils.logger +--------------------------- + +.. automodule:: mqbench.utils.logger + :members: + :undoc-members: + :show-inheritance: + +mqbench.utils.registry +----------------------------- + +.. automodule:: mqbench.utils.registry + :members: + :undoc-members: + :show-inheritance: + +mqbench.utils.state +-------------------------- + +.. automodule:: mqbench.utils.state + :members: + :undoc-members: + :show-inheritance: + +mqbench.utils.utils +-------------------------- + +.. automodule:: mqbench.utils.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mqbench.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..616c7f5 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,97 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'MQBench' +copyright = 'MQBench' +author = 'The Greate Cold' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'sphinx.ext.napoleon', + 'sphinxcontrib.contentui', + 'sphinx.ext.doctest', + 'sphinx.ext.mathjax', + 'sphinx.ext.ifconfig', +] + +# Add path of source folder for codes. +import os +import sys + +sys.path.insert(0, os.path.abspath("../../")) + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. 
+pygments_style = 'sphinx' + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +# html_style = 'css/theme_overrides.css' +# +# html_context = { +# 'css_files': [ +# # overrides for wide tables in RTD theme +# '_static/css/theme_overrides.css', +# ], +# } +# + + +def setup(app): + app.add_css_file("css/theme_overrides.css") diff --git a/docs/source/example/index.rst b/docs/source/example/index.rst new file mode 100644 index 0000000..5d7200f --- /dev/null +++ b/docs/source/example/index.rst @@ -0,0 +1,65 @@ +Get Started +========================== +We follow the `PyTorch official example `_ to build the example of Model Quantization Benchmark for ImageNet classification task. + +Requirements +------------- + +- Install PyTorch following `pytorch.org `_ +- Install dependencies:: + + pip install -r requirements.txt + +- Download the ImageNet dataset from `the official website `_ + + - Then, and move validation images to labeled subfolders, using `the following shell script `_ + +- Install TensorRT=7.2.1.6 from `NVIDIA `_ + +Usage +--------- + +- **Quantization-Aware Training:** + + - Training hyper-parameters: + + - batch size = 128 + - epochs = 1 + - lr = 1e-4 + - others like weight decay, momentum are kept as default. + + - ResNet18 / ResNet50 / MobileNet_v2:: + + python main.py -a [model_name] --epochs 1 --lr 1e-4 --b 128 --seed 99 --pretrained + + +- **Deployment** + We provide the example to deploy the quantized model to TensorRT. + + 1. First export the quantized model to ONNX [tensorrt_deploy_model.onnx] and dump the clip ranges [tensorrt_clip_ranges.json] for activations.:: + + python main.py -a [model_name] --resume [model_save_path] + + + 2. Second build the TensorRT INT8 engine and evaluate, please make sure [dataset_path] contains subfolder [val]:: + + python onnx2trt.py --onnx [tensorrt_deploy_model.onnx] --trt [model_name.trt] --clip [tensorrt_clip_ranges.json] --data [dataset_path] --evaluate + + 3. If you don’t pass in external clip ranges [tensorrt_clip_ranges.json], TenosrRT will do calibration using default algorithm IInt8EntropyCalibrator2 with 100 images. 
So, please make sure [dataset_path] contains subfolder [cali]:: + + python onnx2trt.py --onnx [tensorrt_deploy_model.onnx] --trt [model_name.trt] --data [dataset_path] --evaluate + +Results +----------- + ++-------------------+--------------------------------+------------------------------------------------------------------------------------------------------------------+ +| Model | accuracy\@fp32 | accuracy\@int8 | +| | +----------------------------------------+---------------------------------+---------------------------------------+ +| | | TensoRT Calibration | MQBench QAT | TensorRT SetRange | ++===================+================================+========================================+=================================+=======================================+ +| **ResNet18** | Acc\@1 69.758 Acc\@5 89.078 | Acc\@1 69.612 Acc\@5 88.980 | Acc\@1 69.912 Acc\@5 89.150 | Acc\@1 69.904 Acc\@5 89.182 | ++-------------------+--------------------------------+----------------------------------------+---------------------------------+---------------------------------------+ +| **ResNet50** | Acc\@1 76.130 Acc\@5 92.862 | Acc\@1 76.074 Acc\@5 92.892 | Acc\@1 76.114 Acc\@5 92.946 | Acc\@1 76.320 Acc\@5 93.006 | ++-------------------+--------------------------------+----------------------------------------+---------------------------------+---------------------------------------+ +| **MobileNet_v2** | Acc\@1 71.878 Acc\@5 90.286 | Acc\@1 70.700 Acc\@5 89.708 | Acc\@1 70.826 Acc\@5 89.874 | Acc\@1 70.724 Acc\@5 89.870 | ++-------------------+--------------------------------+----------------------------------------+---------------------------------+---------------------------------------+ diff --git a/docs/source/hardware/index.rst b/docs/source/hardware/index.rst new file mode 100644 index 0000000..e621c10 --- /dev/null +++ b/docs/source/hardware/index.rst @@ -0,0 +1,8 @@ +Quantization Hardware +======================================== + +.. toctree:: + :maxdepth: 1 + + nnie + tensorrt diff --git a/docs/source/hardware/nnie.rst b/docs/source/hardware/nnie.rst new file mode 100644 index 0000000..c70a131 --- /dev/null +++ b/docs/source/hardware/nnie.rst @@ -0,0 +1,88 @@ +NNIE +==== +NNIE is a Neural Network Inference Engine of Hisilicon. It support INT8/INT16 quantization. + +.. _NNIE Quantization Scheme: + +Quantization Scheme +--------------------- +8/16 bit per-layer logarithmic quantization. + +The specific quantization formulation is: + +.. math:: + + \begin{equation} + \begin{aligned} + &z = \lfloor 16 * \log_2(c) \rceil - 127 \\ + &\mathtt{fakequant(x)} = \begin{cases} + - 2 ^ {\dfrac{\mathtt{clamp}(\lfloor 16 * \log_2(-x) \rceil - z, 1, 127) + z}{16}}, & x \lt - 2 ^ {\dfrac{z + 1}{16} - 1} \\ + % 0, & - 2 ^ {\dfrac{z + 1}{16} - 1} \le x \lt 2 ^ {\dfrac{z}{16} - 1} \\ + 2 ^ {\dfrac{\mathtt{clamp}(\lfloor 16 * \log_2(x) \rceil - z, 0, 127) + z}{16}}, & x \ge 2 ^ {\dfrac{z}{16} - 1} \\ + zero, & otherwise + \end{cases} + \end{aligned} + \end{equation} + +where :math:`c` is clipping range. :math:`2 ^ {\dfrac{z}{16}}` is the smallest positive value that can be represented after quantization. + +It represents the integer number in *True Form* format. +The highest bit represents the sign and the rest represents the absolute value of the number. + +.. 
list-table:: + :header-rows: 1 + :align: center + + * - Floating Numer + - Integer Number + - Hexadecimal + - Dequantized Floating Number + * - :math:`\bigg(- \infty, - 2 ^ {\dfrac{z + 126.5}{16}}\bigg]` + - -127 + - 0xFF + - :math:`- 2 ^ {\dfrac{z+127}{16}}` + * - ... + - ... + - ... + - ... + * - :math:`\bigg(- 2 ^ {\dfrac{z + 2.5}{16}}, - 2 ^ {\dfrac{z + 1.5}{16}}\bigg]` + - -2 + - 0x82 + - :math:`- 2 ^ {\dfrac{z+2}{16}}` + * - :math:`\bigg(- 2 ^ {\dfrac{z + 1.5}{16}}, - 2 ^ {\dfrac{z + 1}{16} - 1}\bigg)` + - -1 + - 0x81 + - :math:`- 2 ^ {\dfrac{z+1}{16}}` + * - :math:`\bigg[- 2 ^ {\dfrac{z + 1}{16} - 1}, 2 ^ {\dfrac{z}{16} - 1}\bigg)` + - -0 + - 0x80 + - 0 + * - :math:`\bigg[2 ^ {\dfrac{z}{16} - 1}, 2 ^ {\dfrac{z + 0.5}{16}}\bigg)` + - 0 + - 0x00 + - :math:`2 ^ {\dfrac{z}{16}}` + * - :math:`\bigg[2 ^ {\dfrac{z + 0.5}{16}}, 2 ^ {\dfrac{z + 1.5}{16}}\bigg)` + - 1 + - 0x01 + - :math:`2 ^ {\dfrac{z+1}{16}}` + * - ... + - ... + - ... + - ... + * - :math:`\bigg[2 ^ {\dfrac{z + 126.5}{16}}, + \infty\bigg)` + - 127 + - 0x7F + - :math:`2 ^ {\dfrac{z+127}{16}}` + +NNIE performs a per-layer quantization, which means the inputs of the same layer share the same :math:`z_a` and the weights of the same layer share the same :math:`z_w`. + +In fact, when building engine using the official tool of NNIE, it requires the clipping value :math:`c` rather than :math:`z`. :math:`c` needs to be a number in the :download:`gfpq_param_table_8bit.txt` which ensures that :math:`16 * \log_2{c}` is an integer. + +.. attention:: + Pooling: ceil_mode = True + + Avoid using depthwise convolution. + + Only support 2x nearest neighbor upsample. + + For Detection task, you'd better choose RetinaNet structure. diff --git a/docs/source/hardware/tensorrt.rst b/docs/source/hardware/tensorrt.rst new file mode 100644 index 0000000..34a75de --- /dev/null +++ b/docs/source/hardware/tensorrt.rst @@ -0,0 +1,23 @@ +TensorRT +========= + +`NVIDIA TensorRT `_ is a platform for high-performance deep learning inference on GPU device. + +.. _TensorRT Quantization Scheme: + +Quantization Scheme +-------------------- +8bit per-channel symmetric linear quantization. + +.. math:: + + \begin{equation} + q = \mathtt{clamp}(\lfloor x * s \rceil, lb, ub) + \end{equation} + +where :math:`s` is scaling factor to quantize a number from floating range to integer range, :math:`lb` and :math:`ub` are bounds of integer range. +For weights, [lb, ub] = [-127, 127]. For activations, [lb, ub] = [-128, 127]. + +For weights, each filter needs an independent scale :math:`s`. + +In fact, when building the TensorRT engine, the official tool requires the clipping value as quantization parameters, which can be calculated by :math:`c = s * 127`. diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..5242a68 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,24 @@ +.. MQBench documentation master file, created by + sphinx-quickstart on Mon Aug 9 15:27:41 2021. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to MQBench's documentation! +=================================== + +.. 
toctree:: + :maxdepth: 2 + :caption: Contents: + + example/index + algorithm/index + hardware/index + api/modules + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/mqbench/__init__.py b/mqbench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mqbench/adaround.py b/mqbench/adaround.py new file mode 100644 index 0000000..50c4d96 --- /dev/null +++ b/mqbench/adaround.py @@ -0,0 +1,604 @@ +import copy +import os +import numpy as np +from typing import Callable, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.fx import GraphModule, Node + +from .observer import MinMaxObserver, ObserverBase + +_ADAROUND_SUPPORT_TYPE = (nn.Conv2d, nn.Linear, ) + + +def lp_norm(prediction, target, p=2.0): + """Function to calculate LP-Norm loss term. + + Args: + prediction (torch.Tensor): + target (torch.Tensor): + p (float, optional): Order of norm. Defaults to 2.0. + + Returns: + torch.Tensor: + """ + return (prediction - target).abs().pow(p).sum(1).mean() + +def _rectified_sigmoid(x, zeta, gamma): + """Function to generate rounding mask. + + Args: + x (torch.Tensor): + zeta (torch.Tensor): + gamma (torch.Tensor): + + Returns: + torch.Tensor: + """ + return ((zeta - gamma) * torch.sigmoid(x) + gamma).clamp(0, 1) + +def get_cali_samples(train_data_loader, num_samples, no_label=True): + """Generate sub-dataset for calibration. + + Args: + train_data_loader (torch.utils.data.DataLoader): + num_samples (int): + no_label (bool, optional): If the dataloader has no labels. Defaults to True. + + Returns: + torch.Tensor: Concatenated data matrix. + """ + cali_data_list = [] + if no_label: + for batch_data in train_data_loader: + cali_data_list.append(batch_data["image"]) + if len(cali_data_list) >= num_samples: + break + else: + for batch_data, _ in train_data_loader: + cali_data_list.append(batch_data) + if len(cali_data_list) >= num_samples: + break + return torch.cat(cali_data_list, dim=0)[:num_samples].cpu() + +def adaround(model: GraphModule, train_data, n_samples: int = 128, + lr: float = 4e-3, batch_size: int = 128, max_iter: int = 8000, + weight: float = 0.01, beta: float = 20, gamma: float = -0.1, zeta: float = 1.1, + quant_min: int = -128, quant_max: int = 127, per_channel: bool = False): + """Main function to run AdaRound on a given model. + + Args: + model (GraphModule): + train_data (torch.utils.data.DataLoader): + n_samples (int, optional): Defaults to 128. + lr (float, optional): Defaults to 4e-3. + batch_size (int, optional): Defaults to 128. + max_iter (int, optional): Defaults to 8000. + weight (float, optional): Defaults to 0.01. + beta (float, optional): Defaults to 20. + gamma (float, optional): Defaults to -0.1. + zeta (float, optional): Defaults to 1.1. + quant_min (int, optional): Defaults to -128. + quant_max (int, optional): Defaults to 127. + per_channel (bool, optional): Defaults to False. + + Returns: + GraphModule: Modified copy of the given model. 
+ """ + model.cpu() + print("AdaRound: Quant-Range=" + "[{}, {}], Per-Channel={}".format(quant_min, quant_max, per_channel)) + + # sample data from training data + cali_data = get_cali_samples(train_data, n_samples) + + # apply rewritten deepcopy of GraphModule + quant_model = _deepcopy_graphmodule(model) + quant_model.eval() + model.eval() + + # insert observer to record input/output + fp_observer_binding_dict = _insert_observer(model, "output") + quant_observer_binding_dict = _insert_observer(quant_model, "input") + + print("Record Outputs (by CPU) ...") + # apply data to record output + saver = FpOutputSaver(model, observer_binding_dict=fp_observer_binding_dict, + input_data=cali_data) + + # get layers for reconstruction + modules = dict(quant_model.named_modules()) + quant_module_name_list = _get_quant_modules_by_topology(quant_model) + + # TODO: more observer types / affine mode + if per_channel: + qscheme = torch.per_channel_symmetric + ch_axis = 0 + else: + qscheme = torch.per_tensor_symmetric + ch_axis = -1 + + observer_type = MinMaxObserver.with_args(dtype=torch.qint8, quant_min=quant_min, quant_max=quant_max, + reduce_range=False, qscheme=qscheme, ch_axis=ch_axis) + + scale_dict = _init_weight_scale(quant_model, quant_observer_binding_dict.keys(), observer_type) + + # disable gradient for all parameters + for n, m in quant_model.named_modules(): + if hasattr(m, "weight"): + m.weight.requires_grad = False + if hasattr(m, "bias") and getattr(m, "bias") is not None: + m.bias.requires_grad = False + + quant_model.cuda() + cali_data = cali_data.cuda() + + # learn the rounding mask for each layer + for node_name in quant_module_name_list: + print("===> Train for Layer: {}".format(node_name)) + # get input and output tensors + output_tensor = saver.get_result_by_name(node_name).cuda() + input_observer = modules[quant_observer_binding_dict[node_name].name] + cur_node = _get_node_by_name(quant_model, node_name) + if cur_node is not None: + module = modules[cur_node.target] + else: + raise RuntimeError("Node not found in graph.") + module.eval() + + with _Recorder(input_observer): + with torch.no_grad(): + quant_model(cali_data) + input_tensor = input_observer.cache.detach() + + # optimize the 'alpha' + temp_anneal = TempDecay(t_max=max_iter, start_b=beta) + ada_reg_loss = AdaRoundReg(zeta=zeta, gamma=gamma, weight=weight, + temp_anneal=temp_anneal, h_func=_rectified_sigmoid) + + scale, zero_point = scale_dict[node_name] + ada_quantizer = AdaRoundQuantizer(reg=ada_reg_loss, ch_axis=ch_axis, + scale=scale, zero_point=zero_point, + quant_min=quant_min, quant_max=quant_max) + + ada_layer = AdaRoundLayer(module, ada_reg_loss, ada_quantizer).cuda() + + alpha = learning_alpha(input_tensor, output_tensor, + ada_layer, ada_reg_loss, lr, + batch_size, max_iter) + + ada_quantizer.soft_quantize = False + module.weight.data = ada_quantizer(module.weight, alpha) + module.weight.requires_grad = False + + return quant_model + +def _deepcopy_graphmodule(gm: GraphModule): + """Rewrite the deepcopy of GraphModule. (Copy its 'graph'.) + + Args: + gm (GraphModule): + + Returns: + GraphModule: A deepcopied gm. + """ + copied_gm = copy.deepcopy(gm) + copied_gm.graph = copy.deepcopy(gm.graph) + return copied_gm + +def _insert_observer(gm: GraphModule, insert_type="input"): + """Insert observers to record the input and output of target layers. + + Args: + gm (GraphModule): + insert_type (str, optional): Defaults to "input". + + Returns: + Dict: Dict to lookup the observers. 
+ """ + assert insert_type in ["input", "output"], "insert_type should be 'input' or 'output'." + + modules = dict(gm.named_modules()) + nodes = list(gm.graph.nodes) + + observer_prefix = "_{}_observer".format(insert_type) + insert_idx = 0 + observer_binding_dict = {} + + for node in nodes: + if node.op == "call_module" and isinstance(modules[node.target], _ADAROUND_SUPPORT_TYPE): + observer_name = observer_prefix + str(insert_idx) + + observer = TensorObserver(recording=(insert_type == "output")) + setattr(gm, observer_name, observer) + if insert_type == "input": + with gm.graph.inserting_before(node): + insert_node = gm.graph.create_node("call_module", observer_name, node.args, {}) + node.args = (insert_node, ) + else: + with gm.graph.inserting_after(node): + insert_node = gm.graph.create_node("call_module", observer_name, (node, ), {}) + node.replace_all_uses_with(insert_node) + insert_node.args = (node, ) + observer_binding_dict[node.name] = insert_node + insert_idx += 1 + + gm.recompile() + gm.graph.lint() + return observer_binding_dict + + +class TensorObserver(ObserverBase): + recording_enabled: torch.Tensor + + def __init__(self, dtype=torch.float32, recording=True): + """Observer to record tensors. + + Args: + dtype (type, optional): Defaults to torch.float32. + recording (bool, optional): Defaults to True. + """ + # overwrite 'dtype' attr instead of passing as args directly + # to avoid built-in assertion in torch.quantization.observer._ObserverBase + super(TensorObserver, self).__init__(dtype=torch.quint8) + self.dtype = dtype + self._cache = None + self.register_buffer("recording_enabled", torch.tensor([1], dtype=torch.uint8)) + self.enable_recording(recording) + + def enable_recording(self, enabled=True): + assert isinstance(enabled, bool) + self.recording_enabled[0] = 1 if enabled else 0 + + def forward(self, x): + if self.recording_enabled[0] == 1: + self._cache = x + return x + + def calculate_qparams(self, **kwargs): + pass + + @property + def cache(self): + return self._cache + + def clear(self): + self._cache = None + + +class FpOutputSaver: + @torch.no_grad() + def __init__(self, fp_gm: GraphModule, + observer_binding_dict: Dict[str, Node], + save_loc="disk", root="./calibration", + input_data=None): + """ + Currently, there are two options provided to save floating point model + outputs, including saving to disk and caching on GPU memory. + + 1) Dump to disk. If so, the output feature maps are dumped to disk + and loaded to memory when necessary. + + 2) Cache on GPU side. If so, the output feature maps are kept in memory. + This option will be risky with a large network and limited GPU memory. + """ + + assert isinstance(fp_gm, GraphModule), "Input module must be a GraphModule." + assert save_loc in ["disk", "gpu"], "Saving location should be 'disk' or 'gpu'." + + self.module = fp_gm + self.observer_binding = observer_binding_dict + self.save_loc = save_loc + self.data_root = root + self._data = dict() + + if self.save_loc == "disk" and not os.path.exists(self.data_root): + raise NotADirectoryError("The given path is not a folder." 
+ "Ensure you give the correct path.") + saving_operation = self._disk_saving_operation \ + if self.save_loc == "disk" else self._gpu_saving_operation + + self._save(input_data=input_data, saving_operation=saving_operation) + + @torch.no_grad() + def _save(self, input_data, saving_operation: Callable): + modules = dict(self.module.named_modules()) + self._turn_on_all_observers(True) + self.module(input_data) + + for node_name, observer_node in self.observer_binding.items(): + observer = modules[observer_node.target] + observer.enable_recording(False) + # Do saving operation and clear cache. + saving_operation(observer, node_name) + observer.clear() + + self._turn_on_all_observers(False) + + def _disk_saving_operation(self, observer: TensorObserver, node_name: str): + output_numpy = observer.cache.cpu().numpy() + np.save(os.path.join(self.data_root, "{}.npy".format(node_name)), + output_numpy) + + def _gpu_saving_operation(self, observer: TensorObserver, node_name: str): + self._data[node_name] = observer.cache + + def _turn_on_all_observers(self, recording=True): + modules = dict(self.module.named_modules()) + for node in self.module.graph.nodes: + if node.op == "call_module": + if isinstance(modules[node.target], TensorObserver): + modules[node.target].enable_recording(recording) + + def get_result_by_name(self, node_name): + assert node_name in self.observer_binding + if self.save_loc == "disk": + # Load from file according to node_name. + output_numpy = np.load(os.path.join(self.data_root, + "{}.npy".format(node_name))) + saved_data = torch.from_numpy(output_numpy) + saved_data = saved_data + return saved_data + + elif self.save_loc == "gpu": + # Load from GPU memory. + return self._data[node_name] + + +def _get_quant_modules_by_topology(gm: GraphModule): + """Get modules in the model. + + Args: + gm (GraphModule): + + Returns: + list: + """ + module_name_list = [] + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + if node.op == "call_module": + if isinstance(modules[node.target], _ADAROUND_SUPPORT_TYPE): + module_name_list.append(node.name) + return module_name_list + +def _init_weight_scale(gm: GraphModule, observed_module_list, observer_type: Callable): + """Simulate the fake quant modules to calculate scales and zero-points. 
+ + Args: + gm (GraphModule): + observed_module_list (list): + observer_type (Callable): + + Returns: + dict: + """ + scale_dict = dict() + modules = dict(gm.named_modules()) + + for name in observed_module_list: + node = _get_node_by_name(gm, name) + if node.op == "call_module": + observer = observer_type() + module = modules[node.target] + weight = module.weight + observer(weight) + scale, zero_point = observer.calculate_qparams() + scale_dict[name] = (scale.cuda().detach(), zero_point.cuda().detach()) + return scale_dict + +def _get_node_by_name(gm: GraphModule, node_name: str): + """ + + Args: + gm (GraphModule): + node_name (str): + + Returns: + torch.fx.Node: + """ + for node in gm.graph.nodes: + if node.name == node_name: + return node + return None + + +class _Recorder: + def __init__(self, observer: TensorObserver): + self.observer = observer + self.observer.enable_recording(True) + self.observer.clear() + + def __enter__(self): + pass + + def __exit__(self, exception_type, exception_value, exception_traceback): + self.observer.enable_recording(False) + self.observer.clear() + + +class AdaRoundReg(): + def __init__(self, zeta=1.1, gamma=-0.1, weight=0.01, + temp_anneal: Callable = None, h_func: Callable = _rectified_sigmoid): + super(AdaRoundReg, self).__init__() + self.zeta = zeta + self.gamma = gamma + self.weight = weight + self.temp_anneal = temp_anneal + self.h_func = h_func + self.beta = None + + def round_mask(self, alpha): + return self.h_func(alpha, self.zeta, self.gamma) + + def loss(self, alpha, iter_num): + self.beta = self.temp_anneal(iter_num) + return self.weight * (1 - torch.pow((self.round_mask(alpha) + - 0.5).abs() * 2, self.beta)).sum() + + +class TempDecay: + def __init__(self, t_max=10000, rel_start_decay=0.2, start_b=20, end_b=2): + self.t_max = t_max + self.start_decay = rel_start_decay * t_max + self.start_b = start_b + self.end_b = end_b + + def __call__(self, t): + if t < self.start_decay: + return self.start_b + elif t > self.t_max: + return self.end_b + else: + rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) + return self.end_b + 0.5 * (self.start_b + - self.end_b) * (1 + np.cos(rel_t * np.pi)) + + +class AdaRoundQuantizer: + def __init__(self, reg: AdaRoundReg, ch_axis: int, + scale, zero_point, quant_min=-128, quant_max=127, + soft=True): + self.quant_min = quant_min + self.quant_max = quant_max + self.ch_axis = ch_axis + + self.zeta = reg.zeta + self.gamma = reg.gamma + + self.soft_quantize = soft + self.scale = scale + self.zero_point = zero_point + + self.h_func = reg.h_func + + def __call__(self, w, alpha): + scale = self.scale + zero_point = self.zero_point + if self.ch_axis != -1: + new_shape = [1] * len(w.shape) + new_shape[self.ch_axis] = w.shape[self.ch_axis] + scale = self.scale.reshape(new_shape) + zero_point = self.zero_point.reshape(new_shape) + + if self.soft_quantize: + w = (w / scale).floor() + self.h_func(alpha, self.zeta, self.gamma) + else: + w = (w / scale).floor() + (alpha > 0).float() + + w += zero_point + w = w.clamp(self.quant_min, self.quant_max) + w -= zero_point + + w = w * scale + return w + + def __repr__(self): + scale = self.scale.item() + if self.ch_axis != -1: + scale = "per-channel scale of " + str(tuple(self.scale.shape)) + repr_str = "AdaRoundQuantizer(quant_min={}, quant_max={}, scale={}, " \ + "gamma={}, zeta={}, soft_quantize={})".format(self.quant_min, self.quant_max, scale, + self.gamma, self.zeta, self.soft_quantize) + return repr_str + + +class AdaRoundLayer(nn.Module): + def __init__(self, 
module: nn.Module, + reg: AdaRoundReg, quantizer: AdaRoundQuantizer): + super(AdaRoundLayer, self).__init__() + assert isinstance(module, _ADAROUND_SUPPORT_TYPE), \ + "Cannot apply AdaRound on this module." + + self.module = module + self.quantizer = quantizer + + if self.module.bias is not None: + self.module.bias.requires_grad = False + + scale = self.quantizer.scale + if self.quantizer.ch_axis != -1: + new_shape = [1] * len(self.module.weight.shape) + new_shape[self.quantizer.ch_axis] = self.module.weight.shape[self.quantizer.ch_axis] + scale = self.quantizer.scale.reshape(new_shape) + + rest = self.module.weight / scale - (self.module.weight / scale).floor() + rest = -torch.log((reg.zeta - reg.gamma) / (rest - reg.gamma) - 1) + + self.alpha = torch.nn.Parameter(rest.cuda(), True) + + def forward(self, x): + weight = self.quantizer(self.module.weight, self.alpha) + + if isinstance(self.module, nn.Conv2d): + x = F.conv2d(x, weight, self.module.bias, stride=self.module.stride, + padding=self.module.padding, dilation=self.module.dilation, + groups=self.module.groups) + elif isinstance(self.module, nn.Linear): + x = F.linear(x, weight, self.module.bias) + else: + raise RuntimeError("Unsupported module type.") + + return x + + +# TODO: support different loss functions / p in lp_norm +def learning_alpha(in_tensor: torch.Tensor, + fp_out_tensor: torch.Tensor, + ada_layer: AdaRoundLayer, + ada_reg: AdaRoundReg, + learning_rate: float, + batch_size: int, + max_iter: int) -> torch.Tensor: + + optimizer = torch.optim.Adam([ada_layer.alpha], lr=learning_rate) + + for epoch in range(max_iter): + for idx in range(np.ceil(len(in_tensor) / batch_size).astype(int)): + st = idx * batch_size + ed = st + batch_size + + input_ = in_tensor[st:ed].squeeze(1).detach() + fp_output = fp_out_tensor[st:ed].squeeze(1).detach() + output = ada_layer(input_) + + loss_p = lp_norm(output, fp_output) + loss_reg = ada_reg.loss(ada_layer.alpha, epoch) + loss = loss_p + loss_reg + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if epoch % 200 == 0: + print("Epoch: {:<4} L2 Loss: {:>10.3f} Loss P: " + "{:>8.6f} Loss Reg: {:>5.3f} Beta: {:>3.3f}".format(epoch, loss, loss_p, + loss_reg, ada_reg.beta)) + res = ada_reg.round_mask(ada_layer.alpha) + print("Loss: {:>5.3f} Ceil: {:>5} Floor: {:>5} Total: {:>5} Ratio: {:>.3f}".format( + loss, + res[res + 1e-4 >= 1.0].numel(), res[res <= 1e-4].numel(), torch.numel(res), + (res[res + 1e-4 >= 1.0].numel() + res[res <= 1e-4].numel()) / torch.numel(res))) + return ada_layer.alpha + +@torch.no_grad() +def round_to_nearset_quant(m: nn.Module, scale, zero_point, quant_min, quant_max, ch_axis): + w = m.weight + if ch_axis != -1: + new_shape = [1] * len(w.shape) + new_shape[ch_axis] = w.shape[ch_axis] + scale = scale.reshape(new_shape) + zero_point = zero_point.reshape(new_shape) + + w = (w / scale).round() + w += zero_point + w = w.clamp(quant_min, quant_max) + w -= zero_point + w = w * scale + + return w + +if __name__ == "__main__": + pass diff --git a/mqbench/convert_deploy.py b/mqbench/convert_deploy.py new file mode 100644 index 0000000..afde652 --- /dev/null +++ b/mqbench/convert_deploy.py @@ -0,0 +1,102 @@ +import os.path as osp + +import torch +from torch.fx import GraphModule + +import mqbench.custom_symbolic_opset # noqa: F401 +import mqbench.fusion_method # noqa: F401 +from mqbench.prepare_by_platform import BackendType +from mqbench.utils.logger import logger +from mqbench.utils.registry import ( + BACKEND_DEPLOY_FUNCTION, + register_deploy_function, + 
FUSED_MODULE_CONVERT_FUNCTION +) +from mqbench.convert_onnx import ( + remove_fakequantize_and_collect_params_nnie, + remove_fakequantize_and_collect_params +) + + +@register_deploy_function(BackendType.SNPE) +@register_deploy_function(BackendType.PPLW8A16) +@register_deploy_function(BackendType.Tensorrt) +@register_deploy_function(BackendType.NNIE) +def convert_merge_bn(model: GraphModule, **kwargs): + logger.info("Merge BN for deploy.") + nodes = list(model.graph.nodes) + modules = dict(model.named_modules()) + for node in nodes: + if node.op == 'call_module': + if type(modules[node.target]) in FUSED_MODULE_CONVERT_FUNCTION: + FUSED_MODULE_CONVERT_FUNCTION[type(modules[node.target])](model, node) + + +@register_deploy_function(BackendType.Academic) +@register_deploy_function(BackendType.SNPE) +@register_deploy_function(BackendType.PPLW8A16) +@register_deploy_function(BackendType.Tensorrt) +@register_deploy_function(BackendType.NNIE) +def convert_onnx(model: GraphModule, input_shape_dict, onnx_model_path='./test.onnx', **kwargs): + logger.info("Export to onnx.") + device = next(model.parameters()).device + dummy_input = {name: torch.rand(shape).to(device) for name, shape in input_shape_dict.items()} + torch.onnx.export(model, tuple(dummy_input.values()), onnx_model_path, + input_names=list(dummy_input.keys()), + opset_version=11, + enable_onnx_checker=False) + + +@register_deploy_function(BackendType.NNIE) +def deploy_qparams_nnie(model: GraphModule, onnx_model_path, **kwargs): + logger.info("Extract qparams for NNIE.") + remove_fakequantize_and_collect_params_nnie(onnx_model_path) + + +@register_deploy_function(BackendType.Tensorrt) +def deploy_qparams_tensorrt(model: GraphModule, onnx_model_path, **kwargs): + logger.info("Extract qparams for TensorRT.") + remove_fakequantize_and_collect_params(onnx_model_path, backend='tensorrt') + + +@register_deploy_function(BackendType.SNPE) +def deploy_qparams_snpe(model: GraphModule, onnx_model_path, **kwargs): + logger.info("Extract qparams for SNPE.") + remove_fakequantize_and_collect_params(onnx_model_path, backend='snpe') + + +@register_deploy_function(BackendType.PPLW8A16) +def deploy_qparams_pplw8a16(model: GraphModule, onnx_model_path, **kwargs): + logger.info("Extract qparams for PPLW8A16.") + remove_fakequantize_and_collect_params(onnx_model_path, backend='ppl') + + +def convert_deploy(model: GraphModule, backend_type: BackendType, + input_shape_dict, output_path='./', + model_name='mqbench_model_quantized.onnx'): + r"""Convert model to onnx model and quantization params depends on backend. + + Args: + model (GraphModule): GraphModule prepared qat module. + backend_type (BackendType): specific which backend should be converted to. + input_shape_dict (dict): keys are model input name(should be forward function + params name, values are list of tensor dims) + output_path (str, optional): path to save convert results. Defaults to './'. + model_name (str, optional): name of converted onnx model. Defaults to 'mqbench_model_quantized.onnx'. 
+ + >>> note on input_shape_dict: + example: {'input_0': [1, 3, 224, 224] + 'input_1': [1, 3, 112, 112] + } + while forward function signature is like: + def forward(self, input_0, input_1): + pass + """ + kwargs = { + 'input_shape_dict': input_shape_dict, + 'output_path': output_path, + 'model_name': model_name, + 'onnx_model_path': osp.join(output_path, model_name) + } + for convert_function in BACKEND_DEPLOY_FUNCTION[backend_type]: + convert_function(model, **kwargs) diff --git a/mqbench/convert_onnx.py b/mqbench/convert_onnx.py new file mode 100644 index 0000000..1718a60 --- /dev/null +++ b/mqbench/convert_onnx.py @@ -0,0 +1,371 @@ +import json +import os +import onnx +from onnx import numpy_helper +import numpy as np + +perchannel_fakequantizer = ['FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine', 'FakeQuantizeDSQPerchannel'] +pertensor_fakequantizer = ['LearnablePerTensorAffine', 'FixedPerTensorAffine', 'FakeQuantizeDSQPertensor'] +all_fakequantizer = perchannel_fakequantizer + pertensor_fakequantizer + +def update_inp2node_out2node(graph): + out2node = {} + inp2node = {} + for node in graph.node: + for out in node.output: + # suppose each node only has one output + out2node[out] = node + for idx, inp in enumerate(node.input): + # one node may have multiple inputs + if inp not in inp2node: + inp2node[inp] = [] + inp2node[inp].append([node, idx]) + return out2node, inp2node + +def prepare_data(graph): + params = {} + for init in graph.initializer: + params[init.name] = numpy_helper.to_array(init) + for node in graph.node: + if node.op_type == "Constant": + for attr in node.attribute: + if attr.name == "value": + params[node.output[0]] = numpy_helper.to_array(attr.t) + return params + +def prepare_initializer(graph): + named_initializer = {} + for init in graph.initializer: + named_initializer[init.name] = init + return named_initializer + +def parse_attrs(node_attrs): + attrs = {} + for attr in node_attrs: + if attr.type == onnx.AttributeProto.AttributeType.INTS: + attrs[attr.name] = tuple(attr.ints) + elif attr.type == onnx.AttributeProto.AttributeType.INT: + attrs[attr.name] = attr.i + elif attr.type == onnx.AttributeProto.AttributeType.FLOATS: + attrs[attr.name] = tuple(attr.floats) + elif attr.type == onnx.AttributeProto.AttributeType.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == onnx.AttributeProto.AttributeType.TENSOR: + attrs[attr.name] = numpy_helper.to_array(attr.t) + elif attr.type == onnx.AttributeProto.AttributeType.STRING: + attrs[attr.name] = str(attr.s) + elif attr.type == onnx.AttributeProto.AttributeType.STRINGS: + attrs[attr.name] = tuple([str(x) for x in attr.strings]) + else: + raise Exception("ATTR Type [{}] Not Supported!".format(attr.type)) + return attrs + +def get_constant_inputs(node, out2node): + node_list = [] + for inp in node.input: + if inp in out2node and out2node[inp].op_type == 'Constant': + node_list.append(out2node[inp]) + return node_list + +class OnnxPreprocess(object): + def replace_resize_op_with_upsample(self, graph, out2node): + nodes_to_be_removed = [] + idx = 0 + while idx < len(graph.node): + node = graph.node[idx] + if node.op_type == 'Resize': + print(f"Replace resize op: <{node.name}> with upsample.") + attrs = parse_attrs(node.attribute) + upsample_node = onnx.helper.make_node('Upsample', + name=node.name, + inputs=[node.input[0], node.input[2]], + outputs=node.output, + mode=attrs['mode']) + nodes_to_be_removed.append(node) + nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) + 
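+                # Insert the new Upsample at the Resize position; the extra idx increment below
+                # skips the original Resize node, which is removed after the loop.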
graph.node.insert(idx, upsample_node) + idx += 1 + idx += 1 + for node in nodes_to_be_removed: + graph.node.remove(node) + return + + def remove_fake_pad_op(self, graph, name2data, inp2node, out2node): + nodes_to_be_removed = [] + for idx, node in enumerate(graph.node): + node = graph.node[idx] + if node.op_type == 'Pad': + pads = name2data[node.input[1]] + if all([x == 0 for x in pads]): + print(f"Remove pad op: <{node.name}>.") + next_nodes = inp2node[node.output[0]] + for next_node, idx in next_nodes: + next_node.input[idx] = node.input[0] + nodes_to_be_removed.append(node) + nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) + for node in nodes_to_be_removed: + graph.node.remove(node) + return + +class NNIE_process(object): + def gen_gfpq_param_file(self, graph, clip_val): + nnie_exclude_layer_type = ['Flatten', 'Relu', 'PRelu', 'Sigmoid', 'Reshape', + 'Softmax', 'CaffeSoftmax', 'Clip', 'GlobalAveragePool', 'Mul'] + interp_layer_cnt = 0 + gfpq_param_dict = {} + for idx, node in enumerate(graph.node): + # We can not support NNIE group conv. + # Group conv need group-size input params. + if node.op_type == 'Conv' and node.attribute[1].i != 1: + continue + + layer_input_tensor = [] + for in_tensor in node.input: + if in_tensor in clip_val: + clip_value = clip_val[in_tensor] + layer_input_tensor.append(float(clip_value)) + # Upsample layer only reserve one input. + if node.op_type in ['Upsample', 'DynamicUpsample']: + break + + if node.op_type not in nnie_exclude_layer_type and len(layer_input_tensor) > 0: + gfpq_param_dict[node.name] = layer_input_tensor + + # Upsample ---> Upsample + Permute in NNIE. + if node.op_type in ['Upsample', 'DynamicUpsample']: + interp_layer_name = node.name + gfpq_param_dict[interp_layer_name + '_permute_' + str(interp_layer_cnt)] = gfpq_param_dict[interp_layer_name] + interp_layer_cnt += 1 + + with open(os.path.join('./', 'nnie_gfpq_param_dict.json'), 'w') as f: + json.dump({"nnie": {"gfpq_param_dict": gfpq_param_dict}}, f, indent=4) + + def remove_fakequantize_and_collect_params(self, onnx_path, model_save_path="nnie_deploy_model.onnx"): + model = onnx.load(onnx_path) + graph = model.graph + out2node, inp2node = update_inp2node_out2node(graph) + name2data = prepare_data(graph) + named_initializer = prepare_initializer(graph) + + preprocess = OnnxPreprocess() + preprocess.replace_resize_op_with_upsample(graph, out2node) + preprocess.remove_fake_pad_op(graph, name2data, inp2node, out2node) + out2node, inp2node = update_inp2node_out2node(graph) + + nodes_to_be_removed = [] + clip_ranges = {} + for node in graph.node: + if node.op_type == 'NNIEQuantize': + next_nodes = inp2node[node.output[0]] + if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: + # fake quantize for weights + next_node, idx = next_nodes[0] + next_node.input[idx] = node.input[0] + # clip weights + tensor_name = node.input[0] + data = name2data[tensor_name] + clip_range = name2data[node.input[1]] + new_data = np.clip(data, -clip_range, clip_range) + new_data = numpy_helper.from_array(new_data) + named_initializer[tensor_name].raw_data = new_data.raw_data + print(f'clip weights {tensor_name} to range [{-clip_range}, {clip_range}].') + else: + # fake quantize for activations + clip_ranges[node.input[0]] = name2data[node.input[1]] + for next_node, idx in next_nodes: + next_node.input[idx] = node.input[0] + + nodes_to_be_removed.append(node) + nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) + + for node in 
nodes_to_be_removed: + graph.node.remove(node) + + self.gen_gfpq_param_file(graph, clip_ranges) + onnx.save(model, model_save_path) + +remove_fakequantize_and_collect_params_nnie = NNIE_process().remove_fakequantize_and_collect_params + +class LinearQuantizer_process(object): + # some method like dorefa need pre-compute weights + def weight_preprocess(self, target_tensor, out2node, inp2node, named_initializer): + def find_weight(tensor): + if tensor not in named_initializer: + _node = out2node[tensor] + for inp in _node.input: + return find_weight(inp) + return tensor + weight = find_weight(target_tensor) + + # TODO need more general method, like onnxruntime infer + data = numpy_helper.to_array(named_initializer[weight]) + data = np.tanh(data) + data = data / (np.max(np.abs(data)) + 1e-5) + data = numpy_helper.from_array(data) + named_initializer[weight].raw_data = data.raw_data + + redundant_nodes = [] + + def find_redundant_nodes(tensor): + if tensor == target_tensor: + return + nodes = inp2node[tensor] + for node, idx in nodes: + if node not in redundant_nodes: + redundant_nodes.append(node) + redundant_nodes.extend(get_constant_inputs(node, out2node)) + find_redundant_nodes(node.output[0]) + find_redundant_nodes(weight) + return weight, redundant_nodes + + def deal_with_weight_fakequant(self, node, out2node, inp2node, named_initializer): + next_nodes = inp2node[node.output[0]] + assert len(next_nodes) == 1 + next_node, idx = next_nodes[0] + assert next_node.op_type in ['Conv', 'Gemm'] + redundant_nodes = [] + if node.input[0] not in named_initializer: + node.input[0], redundant_nodes = \ + self.weight_preprocess(node.input[0], out2node, inp2node, named_initializer) + next_node.input[idx] = node.input[0] + return redundant_nodes + + def deal_with_activation_fakequant(self, node, inp2node): + next_nodes = inp2node[node.output[0]] + for next_node, idx in next_nodes: + next_node.input[idx] = node.input[0] + return + + def parse_qparams(self, node, name2data): + tensor_name, scale, zero_point = node.input[:3] + scale, zero_point = name2data[scale], name2data[zero_point] + if len(node.input) > 3: + qmin, qmax = node.input[-2:] + qmin, qmax = name2data[qmin], name2data[qmax] + elif len(node.attribute) > 0: + qparams = parse_attrs(node.attribute) + qmin = qparams['quant_min'] + qmax = qparams['quant_max'] + else: + print(f'qmin and qmax are not found for <{node.name}>!') + return tensor_name, scale, zero_point, qmin, qmax + + def clip_weight(self, node, name2data, named_initializer): + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + data = name2data[tensor_name] + clip_range_min = (qmin - zero_point) * scale + clip_range_max = (qmax - zero_point) * scale + if scale.shape[0] > 1: + new_data = [] + for c in range(data.shape[0]): + new_data.append(np.clip(data[c], clip_range_min[c], clip_range_max[c])) + new_data = np.array(new_data) + print(f'clip weights {tensor_name} to per-cahnnel clip range.') + else: + new_data = np.clip(data, clip_range_min, clip_range_max) + print(f'clip weights {tensor_name} to range [{clip_range_min}, {clip_range_max}].') + new_data = numpy_helper.from_array(new_data) + named_initializer[tensor_name].raw_data = new_data.raw_data + + def post_process_clip_ranges(self, clip_ranges, graph, inp2node): + def find_the_closest_clip_range(node): + if node.input[0] in clip_ranges: + return node.input[0] + elif node.op_type in ['Flatten', 'Resize']: + return find_the_closest_clip_range(inp2node[node.output[0]][0][0]) + else: + return None + + 
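+        # Shape-only ops such as Flatten / Resize carry no fake-quantize node of their
+        # own, so reuse the clip range of the closest downstream tensor that has one
+        # for their input.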
for node in graph.node: + if node.op_type in ['Flatten', 'Resize']: + tensor_name = find_the_closest_clip_range(node) + if tensor_name: + clip_ranges[node.input[0]] = clip_ranges[tensor_name] + print(f'Pass <{tensor_name}> clip range to <{node.name}> input <{node.input[0]}>.') + return clip_ranges + + def remove_fakequantize_and_collect_params(self, onnx_path, backend): + model = onnx.load(onnx_path) + graph = model.graph + out2node, inp2node = update_inp2node_out2node(graph) + name2data = prepare_data(graph) + named_initializer = prepare_initializer(graph) + + preprocess = OnnxPreprocess() + preprocess.replace_resize_op_with_upsample(graph, out2node) + preprocess.remove_fake_pad_op(graph, name2data, inp2node, out2node) + out2node, inp2node = update_inp2node_out2node(graph) + + clip_ranges = {} + nodes_to_be_removed = [] + for node in graph.node: + if node.op_type in all_fakequantizer: + nodes_to_be_removed.append(node) + nodes_to_be_removed.extend(get_constant_inputs(node, out2node)) + + if node.op_type in perchannel_fakequantizer: + # fake quantize for weights, suppose per-channel quantize only for weight + redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) + nodes_to_be_removed.extend(redundant_nodes) + self.clip_weight(node, name2data, named_initializer) + if backend == 'ppl': + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + clip_ranges[tensor_name] = {'step': [float(x) for x in scale], + 'zero_point': [int(x) for x in zero_point], + 'min': [float(x) for x in scale * (qmin - zero_point)], + 'max': [float(x) for x in scale * (qmax - zero_point)], + 'bit': int(np.log2(qmax - qmin + 1)), + 'type': "biased", + } + + elif node.op_type in pertensor_fakequantizer: + if node.output[0] in [x.name for x in graph.output]: + inp2node[node.output[0]] = [] + + next_nodes = inp2node[node.output[0]] + if len(next_nodes) == 1 and next_nodes[0][1] == 1 and next_nodes[0][0].op_type in ['Gemm', 'Conv']: + # fake quantize for weights + redundant_nodes = self.deal_with_weight_fakequant(node, out2node, inp2node, named_initializer) + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + nodes_to_be_removed.extend(redundant_nodes) + self.clip_weight(node, name2data, named_initializer) + else: + # fake quantize for activations + self.deal_with_activation_fakequant(node, inp2node) + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams(node, name2data) + for out in graph.output: + if out.name == node.output[0]: + out.name = tensor_name + + if backend == 'tensorrt': + clip_ranges[tensor_name] = float(scale * min(-qmin, qmax)) + elif backend == 'snpe': + clip_ranges[tensor_name] = {'bitwidth': int(np.log2(qmax - qmin + 1)), + 'min': float(scale * (qmin - zero_point)), + 'max': float(scale * (qmax - zero_point)) + } + if backend == 'ppl': + clip_ranges[tensor_name] = {'step': float(scale), + 'zero_point': int(zero_point), + 'min': float(scale * (qmin - zero_point)), + 'max': float(scale * (qmax - zero_point)), + 'bit': int(np.log2(qmax - qmin + 1)), + 'type': "biased", + } + + for node in nodes_to_be_removed: + graph.node.remove(node) + + clip_ranges = self.post_process_clip_ranges(clip_ranges, graph, inp2node) + if backend == 'tensorrt': + context = {"tensorrt": {"blob_range": clip_ranges}} + elif backend == 'snpe': + context = {'activation_encodings': clip_ranges, 'param_encodings': {}} + elif backend == 'ppl': + context = {"ppl": clip_ranges} + filename = os.path.join('./', 
'{}_clip_ranges.json'.format(backend)) + with open(filename, 'w') as f: + json.dump(context, f, indent=4) + onnx.save(model, '{}_deploy_model.onnx'.format(backend)) + +remove_fakequantize_and_collect_params = LinearQuantizer_process().remove_fakequantize_and_collect_params diff --git a/mqbench/custom_quantizer.py b/mqbench/custom_quantizer.py new file mode 100644 index 0000000..3709794 --- /dev/null +++ b/mqbench/custom_quantizer.py @@ -0,0 +1,266 @@ +import operator +from typing import ( + Dict, Any, Callable +) + +import torch +from torch.fx import ( + GraphModule +) +from torch.quantization import ( + propagate_qconfig_, + swap_module +) +from torch.nn.intrinsic import ( + _FusedModule +) +from torch.quantization.quantization_mappings import ( + get_default_qat_module_mappings, + get_default_static_quant_module_mappings +) +from torch.quantization.utils import ( + get_combined_dict +) +from torch.quantization.fx.qconfig_utils import ( + get_flattened_qconfig_dict +) + +import mqbench.nn as qnn +import mqbench.nn.intrinsic +import mqbench.nn.intrinsic.qat # noqa F401 +from mqbench.utils.logger import logger +from mqbench.utils.registry import register_model_quantizer +from mqbench.prepare_by_platform import BackendType + + +@register_model_quantizer(BackendType.SNPE) +@register_model_quantizer(BackendType.NNIE) +@register_model_quantizer(BackendType.Academic) +class ModelQuantizer(object): + """General model quantizer class. + First, replace common float module to nn.qat.modules to make weight fake + quantized. + Second, insert activation fake quantize node before specific layers. Layer + type is defined in function_type_to_quant_input / module_type_to_quant_input. + We leave the output not quantized since it is next layer's input. + """ + def __init__(self, extra_quantizer_dict): + self.additional_function_type = extra_quantizer_dict.get('additional_function_type', []) + self.additional_module_type = extra_quantizer_dict.get('additional_module_type', ()) + self.exclude_module_name = extra_quantizer_dict.get('exclude_module_name', []) + + def prepare(self, model: GraphModule, qconfig_dict: Dict): + model = self._weight_quant(model, qconfig_dict) + model = self._insert_fake_quantize_for_act_quant(model, qconfig_dict) + return model + + def _insert_fake_quantize_for_act_quant( + self, + model: GraphModule, + qconfig_dict: Any): + graph = model.graph + nodes = list(model.graph.nodes) + + quantizer_prefix = "_post_act_fake_quantizer" + node_to_quantize_output = self._find_act_quants(model) + + for node in node_to_quantize_output: + fake_quantizer = qconfig_dict.activation() + quantizer_name = node.name + quantizer_prefix + setattr(model, quantizer_name, fake_quantizer) + logger.info("Insert act quant {}".format(quantizer_name)) + with graph.inserting_after(node): + inserted_node = graph.create_node("call_module", quantizer_name, (node,), {}) + for _node in nodes: + _node.args = self._fix_succ_recursivly(_node.args, node, inserted_node) + + model.recompile() + model.graph.lint() + return model + + def _fix_succ_recursivly(self, args_tuple, target_node, inserted_node): + _tmp = list(args_tuple) + for _i, _arg in enumerate(args_tuple): + if _arg == target_node: + _tmp[_i] = inserted_node + elif isinstance(_arg, tuple): + _tmp[_i] = self._fix_succ_recursivly(_arg, target_node, inserted_node) + elif isinstance(_arg, list): + _tmp[_i] = list(self._fix_succ_recursivly(_arg, target_node, inserted_node)) + args_tuple = tuple(_tmp) + return args_tuple + + def _weight_quant(self, model: GraphModule, 
qconfig_dict: Dict): + logger.info("Replace module to qat module.") + flattened_qconfig_dict = get_flattened_qconfig_dict({'': qconfig_dict}) + propagate_qconfig_(model, flattened_qconfig_dict) + self._qat_swap_modules(model, self._additional_qat_module_mapping) + return model + + @property + def function_type_to_quant_input(self) -> list: + return [ + operator.add, + operator.mul, + torch.cat, + torch.nn.functional.adaptive_avg_pool2d + ] + self.additional_function_type + + @property + def module_type_to_quant_input(self) -> tuple: + return ( + # Conv + torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d, + torch.nn.intrinsic.qat.modules.conv_fused.ConvBn2d, + torch.nn.qat.modules.conv.Conv2d, + # Linear + torch.nn.qat.modules.linear.Linear, + qnn.intrinsic.qat.LinearBn1d, + # Pooling + torch.nn.modules.pooling.MaxPool2d, + torch.nn.modules.pooling.AvgPool2d, + torch.nn.modules.pooling.AdaptiveAvgPool2d, + # BN + torch.nn.BatchNorm2d, + # Prelu mostly do not merge. + torch.nn.PReLU, + # Upsample + torch.nn.Upsample + ) + self.additional_module_type + + def _flatten_args(self, node): + flattned_args = [] + if isinstance(node, dict): + for v in node.values(): + flattned_args.extend(self._flatten_args(v)) + elif isinstance(node, tuple) or isinstance(node, list): + for n in node: + flattned_args.extend(self._flatten_args(n)) + else: + flattned_args.extend([node]) + return flattned_args + + def _find_act_quants(self, model: GraphModule) -> (set, set): + nodes = list(model.graph.nodes) + modules = dict(model.named_modules()) + node_need_to_quantize_output = [] + for node in nodes: + if node.op == "call_module" and node.target in self.exclude_module_name: + continue + if (node.op == "call_module" and isinstance(modules[node.target], self.module_type_to_quant_input)) or \ + ((node.op == 'call_function' or node.op == 'call_method') and + node.target in self.function_type_to_quant_input): + input_node_list = self._flatten_args(node.args) + for _node in input_node_list: + if isinstance(_node, torch.fx.node.Node): + node_need_to_quantize_output.append(_node) + return set(node_need_to_quantize_output) + + @property + def _additional_qat_module_mapping(self): + return { + qnn.intrinsic.LinearBn1d: qnn.intrinsic.qat.LinearBn1d + } + + def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Dict[Callable, Callable]): + all_mappings = get_combined_dict( + get_default_qat_module_mappings(), additional_qat_module_mapping) + root = self._convert(root, all_mappings, inplace=True) + return root + + def _convert(self, module, mapping=None, inplace=False, scope=''): + if mapping is None: + mapping = get_default_static_quant_module_mappings() + + if not inplace: + module = copy.deepcopy(module) + reassign = {} + for name, mod in module.named_children(): + # fused modules are swapped as one unit + new_scope = "{}.{}".format(scope, name) if scope != '' else name + if new_scope in self.exclude_module_name: + logger.info("Skip quant layer: " + new_scope) + continue + if not isinstance(mod, _FusedModule): + self._convert(mod, mapping, True, new_scope) + reassign[name] = swap_module(mod, mapping, {}) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + +@register_model_quantizer(BackendType.Tensorrt) +class TRTModelQuantizer(ModelQuantizer): + """The different points of TRT quantizer are how to deal with add op + and the last layer. 
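+    For the add op, TensorRT typically fuses the elementwise add into the preceding
+    Conv/Linear layer, so that branch's input is not given its own activation
+    fake-quantize node (see _find_add_merge_node below).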
+ """ + def __init__(self, extra_quantizer_dict): + super().__init__(extra_quantizer_dict) + + @property + def _merge_add_type(self): + return (torch.nn.Conv2d, torch.nn.Linear) + + def _find_act_quants(self, model: GraphModule) -> set: + nodes = list(model.graph.nodes) + modules = dict(model.named_modules()) + node_need_to_quantize_output = [] + for node in nodes: + if node.op == "call_module" and node.target in self.exclude_module_name: + continue + if (node.op == "call_module" and isinstance(modules[node.target], self.module_type_to_quant_input)) or \ + ((node.op == 'call_function' or node.op == 'call_method') and + node.target in self.function_type_to_quant_input): + # Add will be merged with previous conv. + input_node_list = self._flatten_args(node.args) + if node.target is operator.add: + merge_node = self._find_add_merge_node(model, input_node_list, node) + if merge_node: + input_node_list.remove(merge_node) + node_need_to_quantize_output.extend(input_node_list) + else: + for _node in input_node_list: + if isinstance(_node, torch.fx.node.Node): + node_need_to_quantize_output.append(_node) + + return set(node_need_to_quantize_output) + + def _find_add_merge_node(self, model, input_node_list, node): + """Find the first input node which has only one successor from the last. + This kind of node can be merge with add. + """ + input_node_list.reverse() + modules = dict(model.named_modules()) + for input_node in input_node_list: + if input_node.op == 'call_module' and type(modules[input_node.target]) in self._merge_add_type: + succ = 0 + for _node in list(model.graph.nodes): + _node_input_list = self._flatten_args(_node.args) + if input_node in _node_input_list: + succ += 1 + if succ == 1: + return input_node + return None + + +@register_model_quantizer(BackendType.PPLW8A16) +class TotalINTQuantizer(ModelQuantizer): + """There is only INT8 calculations in the model. + We quantize the input tensors of all layers and the output tensors + of the last layers. We quantize every activations tensors and weight + tensors using this method. + """ + def __init__(self, extra_quantizer_dict): + super().__init__(extra_quantizer_dict) + + def _find_act_quants(self, model: GraphModule) -> (set, set): + node_need_to_quantize_output = super(). _find_act_quants(model) + nodes = list(model.graph.nodes) + for node in nodes: + if node.op == 'output': + for output_node in self._flatten_args(node.args): + node_need_to_quantize_output.add(output_node) + + return set(node_need_to_quantize_output) \ No newline at end of file diff --git a/mqbench/custom_symbolic_opset.py b/mqbench/custom_symbolic_opset.py new file mode 100644 index 0000000..6fcb1f2 --- /dev/null +++ b/mqbench/custom_symbolic_opset.py @@ -0,0 +1,23 @@ +from torch.onnx import register_custom_op_symbolic + +# Register symbolic op for torch.quantize_function op. 
+ +def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor): + return g.op("::LearnablePerTensorAffine", x, scale, zero_point, quant_min, quant_max) + + +register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11) + + +def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max): + return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max) + + +register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11) + + +def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max): + return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max) + + +register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11) \ No newline at end of file diff --git a/mqbench/fake_quantize/__init__.py b/mqbench/fake_quantize/__init__.py new file mode 100644 index 0000000..315ddc9 --- /dev/null +++ b/mqbench/fake_quantize/__init__.py @@ -0,0 +1,6 @@ +from .dorefa import DoReFaFakeQuantize +from .dsq import DSQFakeQuantize +from .fixed import FixedFakeQuantize +from .lsq import LearnableFakeQuantize +from .nnie import NNIEFakeQuantize +from .pact import PACTFakeQuantize \ No newline at end of file diff --git a/mqbench/fake_quantize/dorefa.py b/mqbench/fake_quantize/dorefa.py new file mode 100644 index 0000000..daabdd1 --- /dev/null +++ b/mqbench/fake_quantize/dorefa.py @@ -0,0 +1,33 @@ +import torch + +from mqbench.fake_quantize.quantize_base import QuantizeBase + + +class DoReFaFakeQuantize(QuantizeBase): + def __init__(self, observer, **observer_kwargs): + super(DoReFaFakeQuantize, self).__init__(observer, **observer_kwargs) + self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) + self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int)) + + def forward(self, X): + X = torch.tanh(X) + X = X.div(X.abs().max() + 1e-5) + + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, self.scale, self.zero_point.long(), self.ch_axis, self.quant_min, self.quant_max) + else: + X = torch.fake_quantize_per_tensor_affine( + X, self.scale.item(), self.zero_point.item(), self.quant_min, self.quant_max) + return X \ No newline at end of file diff --git a/mqbench/fake_quantize/dsq.py b/mqbench/fake_quantize/dsq.py new file mode 100644 index 0000000..cc0ed09 --- /dev/null +++ b/mqbench/fake_quantize/dsq.py @@ -0,0 +1,95 @@ +import math + +import torch + +from mqbench.fake_quantize.quantize_base import QuantizeBase +from mqbench.utils import is_tracing_state + + +def dsq_function_per_tensor(x, scale, zero_point, quant_min, quant_max, alpha): + tanh_scale = 1 / (1 - alpha) + tanh_k = math.log((tanh_scale + 1) / (tanh_scale - 1)) + + x = x / scale + zero_point + x = torch.clamp(x, quant_min, quant_max) + x = x.floor() + (tanh_scale * torch.tanh(tanh_k * (x - x.floor() - 0.5))) * 0.5 + 0.5 + x = (x.round() - x).detach() + x + x = (x - 
zero_point) * scale + + return x + + +def dsq_function_per_channel(x, scale, zero_point, quant_min, quant_max, ch_axis, alpha): + + new_shape = [1] * len(x.shape) + new_shape[ch_axis] = x.shape[ch_axis] + scale = scale.reshape(new_shape) + zero_point = zero_point.reshape(new_shape) + + tanh_scale = 1 / (1 - alpha) + tanh_k = math.log((tanh_scale + 1) / (tanh_scale - 1)) + + x = x / scale + zero_point + x = torch.clamp(x, quant_min, quant_max) + x = x.floor() + (tanh_scale * torch.tanh(tanh_k * (x - x.floor() - 0.5))) * 0.5 + 0.5 + x = (x.round() - x).detach() + x + x = (x - zero_point) * scale + + return x + + +class DSQFakeQuantize(QuantizeBase): + def __init__(self, observer, alpha=0.4, **observer_kwargs): + super(DSQFakeQuantize, self).__init__(observer, **observer_kwargs) + self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) + self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int)) + self.alpha = alpha + + def forward(self, X): + if self.training: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point.float()) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + if is_tracing_state(): + X = FakeQuantizeDSQPerchannel.apply( + X, self.scale, self.zero_point, self.quant_min, self.quant_max, self.ch_axis, self.alpha) + else: + X = dsq_function_per_channel( + X, self.scale, self.zero_point, self.quant_min, self.quant_max, self.ch_axis, self.alpha) + else: + if is_tracing_state(): + X = FakeQuantizeDSQPertensor.apply( + X, self.scale, self.zero_point, self.quant_min, self.quant_max, self.alpha) + else: + X = dsq_function_per_tensor( + X, self.scale, self.zero_point, self.quant_min, self.quant_max, self.alpha) + + return X + + +class FakeQuantizeDSQPerchannel(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale, zero_point, quant_min, quant_max, ch_axis, alpha): + return dsq_function_per_channel(x, scale, zero_point, quant_min, quant_max, ch_axis, alpha) + + @staticmethod + def symbolic(g, x, scale, zero_point, quant_min, quant_max, ch_axis, alpha): + return g.op("::FakeQuantizeDSQPerchannel", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max, alpha_f=alpha) + + +class FakeQuantizeDSQPertensor(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale, zero_point, quant_min, quant_max, alpha): + return dsq_function_per_tensor(x, scale, zero_point, quant_min, quant_max, alpha) + + @staticmethod + def symbolic(g, x, scale, zero_point, quant_min, quant_max, alpha): + return g.op("::FakeQuantizeDSQPertensor", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max, alpha_f=alpha) diff --git a/mqbench/fake_quantize/fixed.py b/mqbench/fake_quantize/fixed.py new file mode 100644 index 0000000..a2a82c4 --- /dev/null +++ b/mqbench/fake_quantize/fixed.py @@ -0,0 +1,81 @@ +import torch + +from mqbench.fake_quantize.quantize_base import QuantizeBase + + +class FixedFakeQuantize(QuantizeBase): + """This is actually torch.quantization.FakeQuantize. 
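+    Scale and zero_point are plain buffers that are refreshed from the observer
+    whenever it is enabled; they are not learnable parameters.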
+ """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) + self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int)) + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, self.scale.data, self.zero_point.data.long(), + self.ch_axis, self.quant_min, self.quant_max) + else: + X = torch.fake_quantize_per_tensor_affine( + X, self.scale.item(), self.zero_point.item(), + self.quant_min, self.quant_max) + return X + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \ + 'scale={}, zero_point={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # We cannot currently register scalar values as buffers, so need to manually + # specify serialization here. + super(FixedFakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + 'scale'] = self.scale + destination[prefix + 'zero_point'] = self.zero_point + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + # Removing this function throws an error that the the size of the loaded tensor does not match the original size + # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass. + local_state = ['scale', 'zero_point'] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. 
+ if name == 'scale': + self.scale.resize_(val.shape) + else: + assert name == 'zero_point' + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == 'scale': + self.scale.copy_(val) + else: + assert name == 'zero_point' + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super(FixedFakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) \ No newline at end of file diff --git a/mqbench/fake_quantize/lsq.py b/mqbench/fake_quantize/lsq.py new file mode 100644 index 0000000..118f595 --- /dev/null +++ b/mqbench/fake_quantize/lsq.py @@ -0,0 +1,131 @@ +from functools import partial + +import torch +from torch.nn.parameter import Parameter + +from mqbench.fake_quantize.quantize_base import QuantizeBase +from mqbench.utils import is_symmetric_quant, is_tracing_state + + +class LearnableFakeQuantize(QuantizeBase): + r""" This is an extension of the FakeQuantize module in fake_quantize.py, which + supports more generalized lower-bit quantization and support learning of the scale + and zero point parameters through backpropagation. For literature references, + please see the class _LearnableFakeQuantizePerTensorOp. + In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + """ + def __init__(self, observer, scale=1., zero_point=0., use_grad_scaling=True, **observer_kwargs): + super(LearnableFakeQuantize, self).__init__(observer, **observer_kwargs) + self.use_grad_scaling = use_grad_scaling + self.scale = Parameter(torch.tensor([scale])) + self.zero_point = Parameter(torch.tensor([zero_point])) + self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps])) + # Check whether the module will load a state dict; + # Initialize the shape of per-channel 'scale' and 'zero-point' before copying values + + class PerChannelLoadHook: + def __init__(self, module): + self.hook = module._register_load_state_dict_pre_hook(partial(self.hook_fn, module=module)) + + def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, + module): + if module.ch_axis == -1: + # no per-channel parameters + return + for module_key, param in module._parameters.items(): + if module_key not in ["scale", "zero_point"]: + continue + candidate = prefix + module_key + if candidate in state_dict: + input_param = state_dict[candidate] + if param.shape != input_param.shape: + param.data = torch.ones_like(input_param, dtype=param.dtype, device=param.device) + + def close(self): + self.hook.remove() + + self.load_state_dict_hook = PerChannelLoadHook(self) + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \ + 'scale={}, zero_point={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point) + + def forward(self, X): + # Learnable fake quantize have to zero_point.float() to make it learnable. 
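+        # When the observer is enabled, (re)initialize scale / zero_point from the
+        # observed statistics; otherwise only keep scale positive and clamped away
+        # from zero so it remains a valid learnable parameter.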
+ if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + + if self.ch_axis != -1: + self.scale.data = torch.ones_like(_scale) + self.zero_point.data = torch.zeros_like(_zero_point.float()) + + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point.float()) + else: + self.scale.data.abs_() + self.scale.data.clamp_(min=self.eps.item()) + + if self.fake_quant_enabled[0] == 1: + if is_symmetric_quant(self.qscheme): + self.zero_point.data.zero_() + else: + self.zero_point.data.clamp_(self.quant_min, self.quant_max).float() + + if self.is_per_channel: + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + if is_tracing_state(): + X = FakeQuantizeLearnablePerchannelAffine.apply( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + X = _fake_quantize_learnable_per_channel_affine_training( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + X = torch._fake_quantize_learnable_per_tensor_affine( + X, self.scale, self.zero_point, + self.quant_min, self.quant_max, grad_factor) + return X + + +def _fake_quantize_learnable_per_channel_affine_training(x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): + zero_point = (zero_point.round() - zero_point).detach() + zero_point + new_shape = [1] * len(x.shape) + new_shape[ch_axis] = x.shape[ch_axis] + scale = grad_scale(scale, grad_factor).reshape(new_shape) + zero_point = grad_scale(zero_point, grad_factor).reshape(new_shape) + x = x / scale + zero_point + x = (x.round() - x).detach() + x + x = torch.clamp(x, quant_min, quant_max) + return (x - zero_point) * scale + + +def grad_scale(t, scale): + return (t - (t * scale)).detach() + (t * scale) + + +class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): + return _fake_quantize_learnable_per_channel_affine_training(x, scale, zero_point, ch_axis, + quant_min, quant_max, grad_factor) + + @staticmethod + def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): + return g.op("::FakeQuantizeLearnablePerchannelAffine", x, scale, zero_point, quant_min_i=quant_min, quant_max_i=quant_max) \ No newline at end of file diff --git a/mqbench/fake_quantize/nnie.py b/mqbench/fake_quantize/nnie.py new file mode 100644 index 0000000..ebab029 --- /dev/null +++ b/mqbench/fake_quantize/nnie.py @@ -0,0 +1,43 @@ +import torch + +from mqbench.fake_quantize.quantize_base import QuantizeBase +from mqbench.utils import no_jit_trace + + +class NNIEFakeQuantize(QuantizeBase): + def __init__(self, observer, **observer_kwargs): + super(NNIEFakeQuantize, self).__init__(observer, **observer_kwargs) + self.register_buffer('data_max', torch.tensor(float('-inf'))) + + def forward(self, X): + with no_jit_trace(): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + data_max = torch.max(-self.activation_post_process.min_val, self.activation_post_process.max_val) + self.data_max = torch.max(data_max, self.data_max) + X = NNIEQuantizeFunc.apply(X, self.data_max) + 
return X + + +class NNIEQuantizeFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x, data_max): + z = (16 * torch.log2(data_max.double())).round() - 127 + x = x.double() + pos_idx = x > 2 ** ((z - 16) / 16) + neg_idx = x < - 2 ** ((z + 1 - 16) / 16) + zero_idx = (x >= - 2 ** ((z + 1 - 16) / 16)) & (x < 2 ** ((z - 16) / 16)) + x[zero_idx] = 0 + x[pos_idx] = 2 ** ((torch.clamp(torch.round(16 * torch.log2(x[pos_idx]) - z), 0, 127) + z) / 16) + x[neg_idx] = - 2 ** ((torch.clamp(torch.round(16 * torch.log2(-x[neg_idx]) - z), 1, 127) + z) / 16) + x = x.float() + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output + return grad_input, None + + @staticmethod + def symbolic(g, x, data_max): + return g.op("::NNIEQuantize", x, data_max) \ No newline at end of file diff --git a/mqbench/fake_quantize/pact.py b/mqbench/fake_quantize/pact.py new file mode 100644 index 0000000..598c232 --- /dev/null +++ b/mqbench/fake_quantize/pact.py @@ -0,0 +1,52 @@ +import torch +from torch.nn.parameter import Parameter + +from mqbench.fake_quantize.quantize_base import QuantizeBase + + +class PACTFakeQuantize(QuantizeBase): + def __init__(self, observer, alpha=6.0, **observer_kwargs): + super(PACTFakeQuantize, self).__init__(observer, **observer_kwargs) + self.alpha = Parameter(torch.tensor([alpha])) + if not self.is_symmetric_quant: + self.n_alpha = Parameter(torch.tensor([-alpha])) + self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) + self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int)) + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \ + 'alpha={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis, self.alpha) + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + X = torch.where(X > self.alpha, self.alpha, X) + self.activation_post_process.max_val.data.fill_(self.alpha.data[0]) + if X.min() < 0: + if self.is_symmetric_quant: + X = torch.where(X < -self.alpha, -self.alpha, X) + self.activation_post_process.min_val.data.fill_(-self.alpha[0].data) + else: + X = torch.where(X < self.n_alpha, self.n_alpha, X) + self.activation_post_process.min_val.data.fill_(self.n_alpha[0].data) + else: + self.activation_post_process.min_val.data.fill_(0.) 
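+            # With the learnable clipping bounds written into the observer's
+            # min_val / max_val, scale and zero_point can be derived from them as usual.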
+ + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + X = torch.fake_quantize_per_tensor_affine( + X, self.scale.item(), self.zero_point.item(), self.quant_min, self.quant_max) + + return X \ No newline at end of file diff --git a/mqbench/fake_quantize/quantize_base.py b/mqbench/fake_quantize/quantize_base.py new file mode 100644 index 0000000..0338178 --- /dev/null +++ b/mqbench/fake_quantize/quantize_base.py @@ -0,0 +1,48 @@ +import torch +from torch.quantization import FakeQuantizeBase +from torch.quantization.observer import MovingAverageMinMaxObserver +from torch.quantization.fake_quantize import _is_per_channel, _is_per_tensor + +from mqbench.utils import is_symmetric_quant + + +class QuantizeBase(FakeQuantizeBase): + r""" This is an extension of the FakeQuantize module in fake_quantize.py, which + supports more generalized lower-bit quantization and support learning of the scale + and zero point parameters through backpropagation. For literature references, + please see the class _LearnableFakeQuantizePerTensorOp. + In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + """ + def __init__(self, observer=MovingAverageMinMaxObserver, **observer_kwargs): + super().__init__() + self.activation_post_process = observer(**observer_kwargs) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + assert self.quant_min <= self.quant_max, \ + 'quant_min must be less than or equal to quant_max' + self.pot_scale = self.activation_post_process.pot_scale + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + assert _is_per_channel(self.qscheme) or \ + _is_per_tensor(self.qscheme), \ + 'Only per channel and per tensor quantization are supported in fake quantize' + \ + ' got qscheme: ' + str(self.qscheme) + self.is_per_channel = _is_per_channel(self.qscheme) + bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.is_symmetric_quant = is_symmetric_quant(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): + return self.activation_post_process.calculate_qparams() + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, '.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis) \ No newline at end of file diff --git a/mqbench/fuser_method_mappings.py b/mqbench/fuser_method_mappings.py new file mode 100644 index 0000000..4ff7211 --- /dev/null +++ b/mqbench/fuser_method_mappings.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import mqbench.nn as qnn +from torch.quantization.fx.fusion_patterns import ConvBNReLUFusion + + +def fuse_linear_bn(linear, bn): + r"""Given the linear and bn modules, fuses them and returns the fused module + + 
Args: + conv: Module instance of type Linear + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Linear(10, 20) + >>> b1 = nn.BatchNorm1d(20) + >>> m2 = fuse_linear_bn(m1, b1) + """ + assert(linear.training == bn.training),\ + "Linear and BN both must be in the same mode (train or eval)." + + if linear.training: + assert bn.affine, 'Only support fusing BatchNorm1d with affine set to True' + assert bn.track_running_stats, 'Only support fusing BatchNorm1d with tracking_running_stats set to True' + return qnn.intrinsic.LinearBn1d(linear, bn) + else: + return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + + +fuse_custom_config_dict = { + "additional_fuser_method_mapping": { + (torch.nn.Linear, torch.nn.BatchNorm1d): fuse_linear_bn + }, + "additional_fusion_pattern": { + (torch.nn.BatchNorm1d, torch.nn.Linear): ConvBNReLUFusion + } +} +# Sinse additional_fuser_method_mapping will not be set because fuser.py:54 +# do not pass this dict. +from torch.quantization.fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD +DEFAULT_OP_LIST_TO_FUSER_METHOD.update({(torch.nn.Linear, torch.nn.BatchNorm1d): fuse_linear_bn}) \ No newline at end of file diff --git a/mqbench/fusion_method.py b/mqbench/fusion_method.py new file mode 100644 index 0000000..8c12869 --- /dev/null +++ b/mqbench/fusion_method.py @@ -0,0 +1,88 @@ +import torch +import torch.nn.intrinsic.qat as nniqat +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval +from torch.quantization.fx.utils import _parent_name + +import mqbench.nn.intrinsic as qnni +import mqbench.nn.intrinsic.qat as qnniqat +from mqbench.utils.registry import register_convert_function + + +@register_convert_function(qnni.LinearBn1d) +def convert_qnni_linearbn(model, fused_node): + modules = dict(model.named_modules()) + fused_module = modules[fused_node.target] + fused_linear = fuse_linear_bn(fused_module[0], fused_module[1]) + linear_parent_name, linear_name = _parent_name(fused_node.target) + setattr(modules[linear_parent_name], linear_name, fused_linear) + + +@register_convert_function(qnniqat.LinearBn1d) +def convert_qnniqat_linearbn(model, fused_node): + modules = dict(model.named_modules()) + fused_module = modules[fused_node.target] + # Create a Linear from FusedModule. + linear = torch.nn.Linear(fused_module.in_features, fused_module.out_features, fused_module.bias is not None) + linear.weight = fused_module.weight + if fused_module.bias is not None: + linear.bias = fused_module.bias + # Merge Linear + BN + fused_linear = fuse_linear_bn_eval(linear.eval(), fused_module.bn) + # We need nn.qat.linear here to export weight quantize node. + linear.qconfig = fused_module.qconfig + linear = torch.nn.qat.Linear.from_float(linear) + # Attach weight fake quantize params. + linear.weight_fake_quant = fused_module.weight_fake_quant + linear_parent_name, linear_name = _parent_name(fused_node.target) + setattr(modules[linear_parent_name], linear_name, fused_linear) + + +@register_convert_function(nniqat.ConvBn2d) +def convert_nniqat_convbn(model, fused_node): + modules = dict(model.named_modules()) + fused_module = modules[fused_node.target] + # Create a Conv2d from FusedModule. 
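+    # Rebuild a float Conv2d carrying the QAT module's weights, fold BN into it in
+    # eval mode, then wrap the result back into nn.qat.Conv2d so the weight
+    # fake-quantize node is still emitted during ONNX export.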
+ conv = torch.nn.Conv2d(fused_module.in_channels, fused_module.out_channels, fused_module.kernel_size, + fused_module.stride, fused_module.padding, fused_module.dilation, + fused_module.groups, fused_module.bias is not None, fused_module.padding_mode) + conv.weight = fused_module.weight + if fused_module.bias is not None: + conv.bias = fused_module.bias + fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn) + # We need nn.qat.conv here to export weight quantize node. + fused_conv.qconfig = fused_module.qconfig + fused_conv = torch.nn.qat.Conv2d.from_float(fused_conv) + # Attach weight fake quantize params. + fused_conv.weight_fake_quant = fused_module.weight_fake_quant + conv_parent_name, conv_name = _parent_name(fused_node.target) + setattr(modules[conv_parent_name], conv_name, fused_conv) + + +@register_convert_function(nniqat.ConvBnReLU2d) +def convert_nniqat_convbnrelu(model, fused_node): + convert_nniqat_convbn(model, fused_node) + modules = dict(model.named_modules()) + fused_module = modules[fused_node.target] + # We need to Insert Relu after Merged conv. + conv_parent_name, conv_name = _parent_name(fused_node.target) + relu_name = 'relu' + # Maybe has another name, but we cannot know for now. + if not hasattr(modules[conv_parent_name], relu_name): + setattr(modules[conv_parent_name], relu_name, + torch.nn.ReLU(inplace=True).train(fused_module.training)) + # Update modules. + modules = dict(model.named_modules()) + graph = model.graph + nodes = list(model.graph.nodes) + with graph.inserting_after(fused_node): + relu_node_name = relu_name if conv_parent_name == "" else "{}.{}".format(conv_parent_name, relu_name) + assert relu_node_name in modules and isinstance(modules[relu_node_name], torch.nn.ReLU) + inserted_node = graph.create_node("call_module", relu_node_name, (fused_node,), {}) + for _node in nodes: + for i, _arg in enumerate(_node.args): + if _arg == fused_node: + _tmp = list(_node.args) + _tmp[i] = inserted_node + _node.args = tuple(_tmp) + model.recompile() + model.graph.lint() \ No newline at end of file diff --git a/mqbench/nn/__init__.py b/mqbench/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mqbench/nn/intrinsic/__init__.py b/mqbench/nn/intrinsic/__init__.py new file mode 100644 index 0000000..9a8067b --- /dev/null +++ b/mqbench/nn/intrinsic/__init__.py @@ -0,0 +1 @@ +from .modules import * \ No newline at end of file diff --git a/mqbench/nn/intrinsic/modules/__init__.py b/mqbench/nn/intrinsic/modules/__init__.py new file mode 100644 index 0000000..f1f5406 --- /dev/null +++ b/mqbench/nn/intrinsic/modules/__init__.py @@ -0,0 +1 @@ +from .fused import LinearBn1d \ No newline at end of file diff --git a/mqbench/nn/intrinsic/modules/fused.py b/mqbench/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000..5d8e49c --- /dev/null +++ b/mqbench/nn/intrinsic/modules/fused.py @@ -0,0 +1,13 @@ + +from torch.nn.intrinsic import _FusedModule +from torch.nn import Linear, BatchNorm1d + + +class LinearBn1d(_FusedModule): + r"""This is a sequential container which calls the Linear and Batch Norm 1d modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, linear, bn): + assert type(linear) == Linear and type(bn) == BatchNorm1d, \ + 'Incorrect types for input modules{}{}'.format( + type(linear), type(bn)) + super().__init__(linear, bn) \ No newline at end of file diff --git a/mqbench/nn/intrinsic/qat/__init__.py b/mqbench/nn/intrinsic/qat/__init__.py new file mode 100644 index 0000000..9a8067b --- /dev/null +++ b/mqbench/nn/intrinsic/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * \ No newline at end of file diff --git a/mqbench/nn/intrinsic/qat/modules/__init__.py b/mqbench/nn/intrinsic/qat/modules/__init__.py new file mode 100644 index 0000000..d33f1ba --- /dev/null +++ b/mqbench/nn/intrinsic/qat/modules/__init__.py @@ -0,0 +1 @@ +from .linear_fused import LinearBn1d \ No newline at end of file diff --git a/mqbench/nn/intrinsic/qat/modules/linear_fused.py b/mqbench/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000..8c22eb1 --- /dev/null +++ b/mqbench/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,177 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn import Linear +from torch.nn.intrinsic import _FusedModule +from torch.nn.parameter import Parameter + +from mqbench.nn.intrinsic import LinearBn1d + + +class LinearBn1d(Linear, _FusedModule): + _FLOAT_MODULE = LinearBn1d + + def __init__(self, + # ConvNd args + in_features, out_features, bias, + # BatchNormNd args + # num_features: out_channels + eps=1e-05, momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None): + Linear.__init__(self, in_features, out_features, False) + assert qconfig, 'qconfig must be provided for QAT module' + self.qconfig = qconfig + self.freeze_bn = freeze_bn if self.training else True + self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) + self.weight_fake_quant = self.qconfig.weight() + if bias: + self.bias = Parameter(torch.empty(out_features)) + else: + self.register_parameter('bias', None) + self.reset_bn_parameters() + + # this needs to be called after reset_bn_parameters, + # as they modify the same state + if self.training: + if freeze_bn: + self.freeze_bn_stats() + else: + self.update_bn_stats() + else: + self.freeze_bn_stats() + + def reset_running_stats(self): + self.bn.reset_running_stats() + + def reset_bn_parameters(self): + self.bn.reset_running_stats() + init.uniform_(self.bn.weight) + init.zeros_(self.bn.bias) + # note: below is actully for Linear, not BN + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def reset_parameters(self): + super(LinearBn1d, self).reset_parameters() + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def _forward(self, input): + assert self.bn.running_var is not None + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + # input.shape = (batch_size, in_features, *) + # scale_factor.shape = (out_feature, ) + # self.weight.shape = (out_feature, in_feature, *) + # self.bias.shape = (out_feature, *) + # output.shape = (batch_size, out_feature, *) + if self.bn.affine: + scale_factor = self.bn.weight / running_std + else: + scale_factor = 1. 
/ running_std + weight_shape = [1] * len(self.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(input.shape) + bias_shape[1] = -1 + scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape)) + # using zero bias here since the bias for original Linear + # will be added later + # Linear layer takes permuted input since the format is (batch_size, *, in_features) + linear_out = F.linear(input, scaled_weight) + linear_orig = linear_out / scale_factor.reshape(bias_shape) + if self.bias is not None: + linear_orig = linear_orig + self.bias.reshape(bias_shape) + linear_out = self.bn(linear_orig) + return linear_out + + def extra_repr(self): + return super(LinearBn1d, self).extra_repr() + + def forward(self, input): + return self._forward(input) + + def train(self, mode=True): + """ + Batchnorm's training behavior is using the self.training flag. Prevent + changing it if BN is frozen. This makes sure that calling `model.train()` + on a model with a frozen BN will behave properly. + """ + self.training = mode + if not self.freeze_bn: + for module in self.children(): + module.train(mode) + return self + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + if version is None or version == 1: + # BN related parameters and buffers were moved into the BN module for v2 + v2_to_v1_names = { + 'bn.weight': 'gamma', + 'bn.bias': 'beta', + 'bn.running_mean': 'running_mean', + 'bn.running_var': 'running_var', + 'bn.num_batches_tracked': 'num_batches_tracked', + } + for v2_name, v1_name in v2_to_v1_names.items(): + if prefix + v1_name in state_dict: + state_dict[prefix + v2_name] = state_dict[prefix + v1_name] + state_dict.pop(prefix + v1_name) + elif prefix + v2_name in state_dict: + # there was a brief period where forward compatibility + # for this module was broken (between + # https://github.com/pytorch/pytorch/pull/38478 + # and https://github.com/pytorch/pytorch/pull/38820) + # and modules emitted the v2 state_dict format while + # specifying that version == 1. This patches the forward + # compatibility issue by allowing the v2 style entries to + # be used. + pass + elif strict: + missing_keys.append(prefix + v2_name) + + super(LinearBn1d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + @classmethod + def from_float(cls, mod): + r"""Create a qat module from a float module or qparams_dict + + Args: `mod` a float module, either produced by torch.quantization utilities + or directly from user + """ + assert type(mod) == cls._FLOAT_MODULE, 'qat.' 
+ cls.__name__ + '.from_float only works for ' + \ + cls._FLOAT_MODULE.__name__ + assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + assert mod.qconfig, 'Input float module must have a valid qconfig' + qconfig = mod.qconfig + linear, bn = mod[0], mod[1] + qat_linearbn = cls(linear.in_features, linear.out_features, False, + bn.eps, bn.momentum, + False, + qconfig) + qat_linearbn.weight = linear.weight + qat_linearbn.bias = linear.bias + qat_linearbn.bn.weight = bn.weight + qat_linearbn.bn.bias = bn.bias + qat_linearbn.bn.running_mean = bn.running_mean + qat_linearbn.bn.running_var = bn.running_var + # mypy error: Cannot determine type of 'num_batches_tracked' + qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[has-type] + return qat_linearbn diff --git a/mqbench/observer.py b/mqbench/observer.py new file mode 100644 index 0000000..aa859b5 --- /dev/null +++ b/mqbench/observer.py @@ -0,0 +1,285 @@ +import math +from typing import Tuple + +import torch +from torch.quantization.observer import _ObserverBase + +from mqbench.utils import sync_tensor, pot_quantization, is_symmetric_quant + + +class ObserverBase(_ObserverBase): + ''' + Support per-tensor / per-channel. + dtype: quant min/max can be inferred from dtype; we actually do not need this. + qscheme: quantization scheme + reduce_range: special for fbgemm to avoid overflow + quant_min: fixed-point value min + quant_max: fixed-point value max + ch_axis: per-channel axis or per-tensor(-1) + The above is similar to the torch observer. + pot_scale: indicates whether the scale is a power of two. + ''' + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, + reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False): + super(ObserverBase, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max) + self.ch_axis = ch_axis + self.pot_scale = pot_scale + self.register_buffer("min_val", torch.tensor(float("inf"))) + self.register_buffer("max_val", torch.tensor(float("-inf"))) + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters.""" + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + if self.pot_scale: + scale = pot_quantization(scale) + sync_tensor(scale) + sync_tensor(zero_point) + return scale, zero_point + + @torch.jit.export + def _calculate_qmin_qmax(self) -> Tuple[int, int]: + r"""Calculates the actual qmin and qmax based on the quantization range, + the observer datatype and whether the range is reduced. + """ + if self.has_customized_qrange: + # This initialization here is to resolve TorchScript compilation issues and allow + # the use of refinement to decouple initial_qmin and initial_qmax from the quantization range. + # The actual values of initial_qmin and initial_qmax will be reset below. + initial_quant_min, initial_quant_max = 0, 255 + # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the + # attribute from Optional valid integers for use, based on TorchScript's requirements.
+ custom_quant_min, custom_quant_max = self.quant_min, self.quant_max + if custom_quant_min is not None and custom_quant_max is not None: + initial_quant_min, initial_quant_max = ( + custom_quant_min, + custom_quant_max, + ) + + qrange_len = initial_quant_max - initial_quant_min + 1 + if is_symmetric_quant(self.qscheme): + quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1 + else: + quant_min, quant_max = 0, qrange_len - 1 + if self.reduce_range: + quant_min, quant_max = quant_min // 2, quant_max // 2 + else: + # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used. + if self.dtype == torch.qint8: + if self.reduce_range: + quant_min, quant_max = -64, 63 + else: + quant_min, quant_max = -128, 127 + elif self.dtype == torch.quint8: + if self.reduce_range: + quant_min, quant_max = 0, 127 + else: + quant_min, quant_max = 0, 255 + else: + quant_min, quant_max = 0, 15 + return quant_min, quant_max + + @torch.jit.export + def extra_repr(self): + return "min_val={}, max_val={} ch_axis={} pot={}".format(self.min_val, self.max_val, self.ch_axis, self.pot_scale) + + +class MinMaxObserver(ObserverBase): + ''' + Calculate minmax of whole calibration dataset. + ''' + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, + reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False): + super(MinMaxObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, ch_axis, pot_scale) + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + + self.min_val = min_val_cur + self.max_val = max_val_cur + + return x + + +class EMAMinMaxObserver(ObserverBase): + """Moving average min/max among batches. + """ + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, + quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, ema_ratio=0.9): + super(EMAMinMaxObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, pot_scale) + self.ema_ratio = ema_ratio + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + + if self.max_val.numel() <= 1 and self.max_val.isinf(): + self.min_val = min_val_cur + self.max_val = max_val_cur + else: + self.min_val = self.min_val * self.ema_ratio + min_val_cur * (1.0 - self.ema_ratio) + self.max_val = self.max_val * self.ema_ratio + max_val_cur * (1.0 - self.ema_ratio) + return x + + +class ClipStdObserver(ObserverBase): + """Clip std. 
+ """ + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, + quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False, std_scale=2.6): + super(ClipStdObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, ch_axis, pot_scale) + self.std_scale = std_scale + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + mean = x.mean() + std = x.std() + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + mean = y.mean(1) + std = y.std(1) + + # using statistics to clip min and max + min_val = torch.minimum(mean - self.std_scale * std, min_val_cur) + max_val = torch.maximum(mean + self.std_scale * std, max_val_cur) + + self.min_val = min_val + self.max_val = max_val + + return x + + +class LSQObserver(ObserverBase): + ''' + LSQ observer. + ''' + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False, + quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False): + super(LSQObserver, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max, ch_axis, pot_scale) + self.tensor_norm = None + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + self.tensor_norm = x.abs().mean() + self.min_val, self.max_val = torch._aminmax(x) + else: + # compute channel-wise mean + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + self.tensor_norm = y.abs().mean(1) + self.min_val, self.max_val = torch._aminmax(y) + + return x + + def calculate_qparams(self): + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + if self.pot_scale: + scale = pot_quantization(scale) + zero_point = torch.zeros_like(self.tensor_norm) + if not is_symmetric_quant(self.qscheme): + if self.min_val >= 0.: + zero_point = self.quant_min - torch.round(self.min_val / scale) + sync_tensor(scale) + sync_tensor(zero_point) + return scale, zero_point + + +class LSQPlusObserver(ObserverBase): + ''' + LSQ+ observer. 
+ ''' + + def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, reduce_range=False, + quant_min=-128, quant_max=128, ch_axis=-1, pot_scale=False): + + super(LSQPlusObserver, self).__init__(dtype, qscheme, reduce_range, + quant_min, quant_max, ch_axis, pot_scale) + self.mean = None + self.std = None + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + self.mean = x.mean() + self.std = x.std() + self.min_val, self.max_val = torch._aminmax(x) + else: + # compute channel-wise mean + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + self.mean = y.mean(1) + self.std = y.std(1) + self.min_val, self.max_val = torch._aminmax(y) + + return x + + def calculate_qparams(self): + scale = torch.maximum((self.mean - 3 * self.std).abs(), + (self.mean + 3 * self.std).abs()) / (self.quant_max - self.quant_min + 1) + if self.pot_scale: + scale = pot_quantization(scale) + zero_point = torch.zeros_like(self.mean) + if not is_symmetric_quant(self.qscheme): + if self.min_val >= 0.: + zero_point = self.quant_min - torch.round(self.min_val / scale) + sync_tensor(scale) + sync_tensor(zero_point) + return scale, zero_point \ No newline at end of file diff --git a/mqbench/prepare_by_platform.py b/mqbench/prepare_by_platform.py new file mode 100644 index 0000000..8756afe --- /dev/null +++ b/mqbench/prepare_by_platform.py @@ -0,0 +1,250 @@ +from enum import Enum +from typing import Any, Dict + +import torch +from torch.fx.symbolic_trace import symbolic_trace +from torch.quantization.quantize_fx import _swap_ff_with_fxff, _fuse_fx +from torch.quantization import QConfig + +from mqbench.fake_quantize import ( + LearnableFakeQuantize, + NNIEFakeQuantize, + FixedFakeQuantize, + DoReFaFakeQuantize, + DSQFakeQuantize, + PACTFakeQuantize +) +from mqbench.observer import ( + ClipStdObserver, + LSQObserver, + MinMaxObserver, + EMAMinMaxObserver +) +from mqbench.fuser_method_mappings import fuse_custom_config_dict +from mqbench.utils.logger import logger +from mqbench.utils.registry import DEFAULT_MODEL_QUANTIZER + + +class BackendType(Enum): + Academic = 'Academic' + Tensorrt = 'Tensorrt' + SNPE = 'SNPE' + PPLW8A16 = 'PPLW8A16' + NNIE = 'NNIE' + + +class QuantizeScheme(object): + """Describe quantization scheme. 
+ """ + def __init__(self, symmetry=True, per_channel=False, pot_scale=False, bit=8): + self.symmetry = symmetry + self.per_channel = per_channel + self.pot_scale = pot_scale + self.bit = bit + if self.per_channel: + self.torch_qscheme = torch.per_channel_symmetric if self.symmetry else torch.per_channel_affine + else: + self.torch_qscheme = torch.per_tensor_symmetric if self.symmetry else torch.per_tensor_affine + + def to_observer_params(self): + return { + 'quant_min': -2 ** (self.bit - 1) if self.symmetry else 0, + 'quant_max': 2 ** (self.bit - 1) - 1 if self.symmetry else 2 ** self.bit - 1, + 'dtype': torch.qint8 if self.symmetry else torch.quint8, + 'pot_scale': self.pot_scale, + 'qscheme': self.torch_qscheme, + 'reduce_range': False, + 'ch_axis': 0 if self.per_channel else -1 + } + + def __str__(self): + return "Symmetric: {} / Bitwidth: {} / Per channel: {} / Pot scale: {}".format(self.symmetry, + self.bit, + self.per_channel, + self.pot_scale) + + +ParamsTable = { + BackendType.Academic: dict(qtype='affine'), # noqa: E241 + BackendType.NNIE: dict(qtype='nnie', # noqa: E241 + # NNIE actually do not need w/a qscheme. We add for initialize observer only. + w_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8), + a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8), + default_weight_quantize=NNIEFakeQuantize, + default_act_quantize=NNIEFakeQuantize, + default_weight_observer=MinMaxObserver, + default_act_observer=EMAMinMaxObserver), + BackendType.Tensorrt: dict(qtype='affine', # noqa: E241 + w_qscheme=QuantizeScheme(symmetry=True, per_channel=True, pot_scale=False, bit=8), + a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8), + default_weight_quantize=LearnableFakeQuantize, + default_act_quantize=LearnableFakeQuantize, + default_weight_observer=LSQObserver, + default_act_observer=LSQObserver), + BackendType.SNPE: dict(qtype='affine', # noqa: E241 + w_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=8), + a_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=8), + default_weight_quantize=LearnableFakeQuantize, + default_act_quantize=LearnableFakeQuantize, + default_weight_observer=LSQObserver, + default_act_observer=LSQObserver), + BackendType.PPLW8A16: dict(qtype='affine', # noqa: E241 + w_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=8), + a_qscheme=QuantizeScheme(symmetry=False, per_channel=False, pot_scale=False, bit=16), + default_weight_quantize=LearnableFakeQuantize, + default_act_quantize=LearnableFakeQuantize, + default_weight_observer=LSQObserver, + default_act_observer=LSQObserver) +} + +ObserverDict = { + 'MinMaxObserver': MinMaxObserver, # noqa: E241 + 'EMAMinMaxObserver': EMAMinMaxObserver, # More general choice. # noqa: E241 + 'QuantileObserver': None, # TODO quantile. # noqa: E241 + 'ClipStdObserver': ClipStdObserver, # Usually used for DSQ. # noqa: E241 + 'LSQObserver': LSQObserver # Usually used for LSQ. 
# noqa: E241
+}
+
+FakeQuantizeDict = {
+ 'FixedFakeQuantize': FixedFakeQuantize, # Unlearnable scale/zeropoint # noqa: E241
+ 'LearnableFakeQuantize': LearnableFakeQuantize, # Learnable scale/zeropoint # noqa: E241
+ 'NNIEFakeQuantize': NNIEFakeQuantize, # Quantize function for NNIE # noqa: E241
+ 'DoReFaFakeQuantize': DoReFaFakeQuantize, # DoReFa # noqa: E241
+ 'DSQFakeQuantize': DSQFakeQuantize, # DSQ # noqa: E241
+ 'PACTFakeQuantize': PACTFakeQuantize # PACT # noqa: E241
+}
+
+def get_qconfig_by_platform(deploy_backend: BackendType, extra_qparams):
+ """Build the QConfig for the given backend, optionally overridden by extra_qparams.
+
+ Args:
+ deploy_backend (BackendType):
+ extra_qparams (dict):
+
+ >>> extra params format: {
+ 'w_observer': str, weight observer name,
+ 'a_observer': str, activation observer name,
+ 'w_fakequantize': str, weight fake quantize function name,
+ 'w_fakeq_params': dict, params for the weight quantize function,
+ 'a_fakequantize': str, activation fake quantize function name,
+ 'a_fakeq_params': dict, params for the activation quantize function,
+ if deploy_backend == BackendType.Academic, the keys below will be used:
+ 'w_qscheme': {
+ 'bit': bitwidth,
+ 'symmetry': whether the quantize scheme is symmetric,
+ 'per_channel': whether the quantize scheme is per-channel,
+ 'pot_scale': whether the scale is a power of two.
+ }
+ 'a_qscheme': {
+ same as w_qscheme.
+ }
+ }
+ """
+ w_observer = extra_qparams.get('w_observer', None)
+ if w_observer:
+ w_observer = ObserverDict[w_observer]
+ a_observer = extra_qparams.get('a_observer', None)
+ if a_observer:
+ a_observer = ObserverDict[a_observer]
+ w_fakequantize = extra_qparams.get('w_fakequantize', None)
+ if w_fakequantize:
+ w_fakequantize = FakeQuantizeDict[w_fakequantize]
+ a_fakequantize = extra_qparams.get('a_fakequantize', None)
+ if a_fakequantize:
+ a_fakequantize = FakeQuantizeDict[a_fakequantize]
+ backend_params = ParamsTable[deploy_backend]
+
+ # The NNIE backend must use NNIEFakeQuantize but leaves the observer adjustable.
+ if backend_params['qtype'] == 'nnie':
+ if not w_observer:
+ w_observer = backend_params['default_weight_observer']
+ if not a_observer:
+ a_observer = backend_params['default_act_observer']
+ w_qscheme = backend_params['w_qscheme']
+ a_qscheme = backend_params['a_qscheme']
+ w_config = backend_params['default_weight_quantize'].with_args(observer=w_observer,
+ **w_qscheme.to_observer_params())
+ a_config = backend_params['default_act_quantize'].with_args(observer=a_observer,
+ **a_qscheme.to_observer_params())
+ return QConfig(activation=a_config, weight=w_config)
+
+ # The Academic setting must specify the quant scheme in the config.
+ if deploy_backend == BackendType.Academic:
+ w_qscheme = QuantizeScheme(**extra_qparams['w_qscheme'])
+ a_qscheme = QuantizeScheme(**extra_qparams['a_qscheme'])
+ else:
+ w_qscheme = backend_params['w_qscheme']
+ a_qscheme = backend_params['a_qscheme']
+
+ # Get the weight / activation fake quantize function and params.
+ if not w_fakequantize:
+ w_fakequantize = backend_params['default_weight_quantize']
+ w_fakeq_params = extra_qparams.get('w_fakeq_params', {})
+ if not a_fakequantize:
+ a_fakequantize = backend_params['default_act_quantize']
+ a_fakeq_params = extra_qparams.get('a_fakeq_params', {})
+ # Observers do not need extra params for now.
+ if not w_observer:
+ w_observer = MinMaxObserver
+ if not a_observer:
+ a_observer = EMAMinMaxObserver
+
+ # Create qconfig.
+ w_qconfig = w_fakequantize.with_args(observer=w_observer, **w_fakeq_params, **w_qscheme.to_observer_params())
+ a_qconfig = a_fakequantize.with_args(observer=a_observer, **a_fakeq_params, **a_qscheme.to_observer_params())
+ logger.info('Weight Qconfig:\n FakeQuantize: {} Params: {}\n'
+ ' Observer: {} Params: {}'.format(w_fakequantize.__name__, w_fakeq_params,
+ w_observer.__name__, str(w_qscheme)))
+ logger.info('Activation Qconfig:\n FakeQuantize: {} Params: {}\n'
+ ' Observer: {} Params: {}'.format(a_fakequantize.__name__, a_fakeq_params,
+ a_observer.__name__, str(a_qscheme)))
+ return QConfig(activation=a_qconfig, weight=w_qconfig)
+
+
+def prepare_qat_fx_by_platform(
+ model: torch.nn.Module,
+ deploy_backend: BackendType,
+ prepare_custom_config_dict: Dict[str, Any] = {}):
+ assert model.training, 'prepare_qat_fx_by_platform only works for models in ' + \
+ 'train mode'
+
+ logger.info("Quantize model using the {} scheme.".format(deploy_backend))
+
+ # Get the qconfig.
+ extra_qconfig_dict = prepare_custom_config_dict.get('extra_qconfig_dict', {})
+ qconfig = get_qconfig_by_platform(deploy_backend, extra_qconfig_dict)
+ # Preserve attributes that symbolic tracing would otherwise drop.
+ preserve_attr_dict = dict()
+ if 'preserve_attr' in prepare_custom_config_dict:
+ for submodule_name in prepare_custom_config_dict['preserve_attr']:
+ cur_module = model
+ if submodule_name != "":
+ cur_module = getattr(model, submodule_name)
+ preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
+ preserve_attr_dict[submodule_name] = {}
+ for attr in preserve_attr_list:
+ preserve_attr_dict[submodule_name][attr] = getattr(cur_module, attr)
+ # Symbolic trace.
+ concrete_args = prepare_custom_config_dict.get('concrete_args', None)
+ graph_module = symbolic_trace(model, concrete_args=concrete_args)
+ # Model fusion.
+ extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {})
+ _swap_ff_with_fxff(graph_module)
+ extra_fuse_dict.update(fuse_custom_config_dict)
+ graph_module = _fuse_fx(graph_module, extra_fuse_dict)
+ # Prepare.
+ import mqbench.custom_quantizer # noqa: F401
+ extra_quantizer_dict = prepare_custom_config_dict.get('extra_quantizer_dict', {})
+ quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict)
+ prepared = quantizer.prepare(graph_module, qconfig)
+ # Restore the preserved attributes.
+ if 'preserve_attr' in prepare_custom_config_dict: + for submodule_name in prepare_custom_config_dict['preserve_attr']: + cur_module = prepared + if submodule_name != "": + cur_module = getattr(prepared, submodule_name) + preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name] + for attr in preserve_attr_list: + logger.info("Preserve attr: {}.{}".format(submodule_name, attr)) + setattr(cur_module, attr, preserve_attr_dict[submodule_name][attr]) + return prepared \ No newline at end of file diff --git a/mqbench/utils/__init__.py b/mqbench/utils/__init__.py new file mode 100644 index 0000000..90f60fd --- /dev/null +++ b/mqbench/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file diff --git a/mqbench/utils/logger.py b/mqbench/utils/logger.py new file mode 100644 index 0000000..c3ee136 --- /dev/null +++ b/mqbench/utils/logger.py @@ -0,0 +1,24 @@ +import logging +import sys + + +QBENCH_LOGGER_NAME = "QBENCH" +logger = logging.getLogger(QBENCH_LOGGER_NAME) +logger.propagate = False +stdout_handler = logging.StreamHandler(sys.stdout) +fmt = logging.Formatter("[%(name)s] %(levelname)s: %(message)s") +stdout_handler.setFormatter(fmt) +stdout_handler.setLevel(logging.DEBUG) +logger.addHandler(stdout_handler) +logger.setLevel(logging.INFO) +logger.parent = None + + +def set_log_level(level): + logger.setLevel(level) + for handler in logger.handlers: + handler.setLevel(level) + + +def disable_logging(): + logger.handlers = [] \ No newline at end of file diff --git a/mqbench/utils/registry.py b/mqbench/utils/registry.py new file mode 100644 index 0000000..e6fa2ed --- /dev/null +++ b/mqbench/utils/registry.py @@ -0,0 +1,33 @@ +from collections import OrderedDict + + +DEFAULT_MODEL_QUANTIZER = OrderedDict() + + +def register_model_quantizer(backend_type): + def insert(quantizer_cls): + DEFAULT_MODEL_QUANTIZER[backend_type] = quantizer_cls + return quantizer_cls + return insert + +BACKEND_DEPLOY_FUNCTION = OrderedDict() + + +def register_deploy_function(backend_type): + def insert(func): + if backend_type in BACKEND_DEPLOY_FUNCTION: + BACKEND_DEPLOY_FUNCTION[backend_type].append(func) + else: + BACKEND_DEPLOY_FUNCTION[backend_type] = [func] + return func + return insert + + +FUSED_MODULE_CONVERT_FUNCTION = OrderedDict() + + +def register_convert_function(module_type): + def insert(func): + FUSED_MODULE_CONVERT_FUNCTION[module_type] = func + return func + return insert \ No newline at end of file diff --git a/mqbench/utils/state.py b/mqbench/utils/state.py new file mode 100644 index 0000000..332703d --- /dev/null +++ b/mqbench/utils/state.py @@ -0,0 +1,30 @@ +import torch + +from mqbench.utils.logger import logger + + +def enable_calibration(model): + logger.info('Enable observer and Disable quantize.') + for name, submodule in model.named_modules(): + if isinstance(submodule, torch.quantization.FakeQuantizeBase): + logger.debug('Enable observer and Disable quant: {}'.format(name)) + submodule.enable_observer() + submodule.disable_fake_quant() + + +def enable_quantization(model): + logger.info('Disable observer and Enable quantize.') + for name, submodule in model.named_modules(): + if isinstance(submodule, torch.quantization.FakeQuantizeBase): + logger.debug('Disable observer and Enable quant: {}'.format(name)) + submodule.disable_observer() + submodule.enable_fake_quant() + + +def disable_all(model): + logger.info('Disable observer and Disable quantize.') + for name, submodule in model.named_modules(): + if isinstance(submodule, 
torch.quantization.FakeQuantizeBase): + logger.debug('Disable observer and Disable quantize: {}'.format(name)) + submodule.disable_observer() + submodule.disable_fake_quant() \ No newline at end of file diff --git a/mqbench/utils/utils.py b/mqbench/utils/utils.py new file mode 100644 index 0000000..b128afd --- /dev/null +++ b/mqbench/utils/utils.py @@ -0,0 +1,49 @@ +import torch + +USE_LINK = False +USE_DDP = False + +try: + import spring.linklink as link + USE_LINK = True +except ModuleNotFoundError: + import torch.distributed as dist + if torch.distributed.is_initialized(): + USE_DDP = True + + +def sync_tensor(tensor): + global USE_LINK + global USE_DDP + if USE_LINK: + if tensor.is_cuda is True: + world_size = link.get_world_size() + link.allreduce(tensor / world_size) + elif USE_DDP: + world_size = dist.get_world_size() + dist.allreduce(tensor / world_size) + + +def pot_quantization(tensor: torch.Tensor): + log2t = torch.log2(tensor) + log2t = (torch.round(log2t) - log2t).detach() + log2t + return 2 ** log2t + + +def is_symmetric_quant(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +class no_jit_trace: + def __enter__(self): + # pylint: disable=protected-access + self.state = torch._C._get_tracing_state() + torch._C._set_tracing_state(None) + + def __exit__(self, *args): + torch._C._set_tracing_state(self.state) + self.state = None + + +def is_tracing_state(): + return torch._C._get_tracing_state() \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..ef5e32a --- /dev/null +++ b/setup.py @@ -0,0 +1,26 @@ +import setuptools +from mqbench import __version__ + + +def read_requirements(): + reqs = [] + with open('requirements.txt', 'r') as fin: + for line in fin.readlines(): + reqs.append(line.strip()) + return reqs + + +setuptools.setup( + name="MQBench", + version=__version__, + author="The Great Cold", + author_email="", + description=("Quantization aware training."), + python_requires='>=3.6', + packages=setuptools.find_packages(), + classifiers=( + 'Development Status :: 3 - Alpha', + "Programming Language :: Python :: 3", + "Operating System :: POSIX :: Linux"), + install_requires=read_requirements() +) diff --git a/test/backend/__init__.py b/test/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/backend/test_backend.py b/test/backend/test_backend.py new file mode 100644 index 0000000..ae9816a --- /dev/null +++ b/test/backend/test_backend.py @@ -0,0 +1,92 @@ +import torch +import unittest + +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.convert_deploy import convert_deploy +from mqbench.utils.state import enable_calibration, enable_quantization + + +class TestQuantizeBackend(unittest.TestCase): + def test_quantize_acedemic(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'LearnableFakeQuantize', + 'w_qscheme': { + 'bit': 4, + 'symmetry': True, + 'per_channel': False, + 'pot_scale': False + }, + 'a_qscheme': { + 'bit': 4, + 'symmetry': True, + 'per_channel': False, + 'pot_scale': False + } + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = 
prepare_qat_fx_by_platform(model_to_quantize, BackendType.Academic, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Academic, {'x': [1, 3, 224, 224]}, model_name='resnet18_acedemic_4bit.onnx') + + def test_quantize_tensorrt(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_trt.onnx') + + def test_quantize_nnie(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.NNIE) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.NNIE, {'x': [1, 3, 224, 224]}, model_name='resnet18_nnie.onnx') + + def test_quantize_snpe(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.SNPE) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.SNPE, {'x': [1, 3, 224, 224]}, model_name='resnet18_snpe.onnx') + + def test_quantize_pplw8a16(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.PPLW8A16) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.PPLW8A16, {'x': [1, 3, 224, 224]}, model_name='resnet18_pplw8a16.onnx') \ No newline at end of file diff --git a/test/fake_quant/__init__.py b/test/fake_quant/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/fake_quant/test_fake_quant.py b/test/fake_quant/test_fake_quant.py new file mode 100644 index 0000000..1af1471 --- /dev/null +++ b/test/fake_quant/test_fake_quant.py @@ -0,0 +1,128 @@ +import torch +import unittest + +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.convert_deploy import convert_deploy +from mqbench.utils.state import enable_calibration, enable_quantization + + +class TestFakeQuantize(unittest.TestCase): + def test_fixed_fake_quantize(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = 
torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'FixedFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_fixed.onnx') + + def test_learnable_fake_quantize(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'LearnableFakeQuantize', + 'a_fakequantize': 'LearnableFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_lsq.onnx') + + def test_nnie_fake_quantize(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'NNIEFakeQuantize', + 'a_fakequantize': 'NNIEFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.NNIE, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.NNIE, {'x': [1, 3, 224, 224]}, model_name='resnet18_nnie.onnx') + + def test_dorefa_fake_quantize(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'DoReFaFakeQuantize', + 'a_fakequantize': 'LearnableFakeQuantize' + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_dorefa_trt.onnx') + + def test_pact_fake_quantize(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = 
torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'PACTFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_pact_trt.onnx') + + def test_dsq_fake_quantize(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'ClipStdObserver', + 'a_observer': 'ClipStdObserver', + 'w_fakequantize': 'DSQFakeQuantize', + 'a_fakequantize': 'DSQFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_dsq_trt.onnx') \ No newline at end of file diff --git a/test/load_ckpt/__init__.py b/test/load_ckpt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/load_ckpt/test_load_ckpt.py b/test/load_ckpt/test_load_ckpt.py new file mode 100644 index 0000000..bf1e515 --- /dev/null +++ b/test/load_ckpt/test_load_ckpt.py @@ -0,0 +1,38 @@ +import torch +import unittest + +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.convert_deploy import convert_deploy +from mqbench.utils.state import enable_calibration, enable_quantization + + +class TestLoadCheckPoint(unittest.TestCase): + def test_case_1(self): + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'LearnableFakeQuantize', + 'a_fakequantize': 'LearnableFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + # First model + model_1 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + model_1 = prepare_qat_fx_by_platform(model_1, BackendType.Tensorrt, prepare_custom_config_dict) + model_1.train() + enable_calibration(model_1) + model_1(dummy_input) + enable_quantization(model_1) + model_1.eval() + prev_output = model_1(dummy_input) + torch.save(model_1.state_dict(), 'saved_model.ckpt') + # Second model + model_2 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + model_2 = prepare_qat_fx_by_platform(model_2, BackendType.Tensorrt, prepare_custom_config_dict) + state_dict = torch.load('saved_model.ckpt') + model_2.load_state_dict(state_dict) + enable_quantization(model_2) + model_2.eval() + new_output = model_2(dummy_input) + # Test + self.assertTrue((new_output - prev_output).abs().sum() < 1e-9) \ No newline at end of file diff --git a/test/model/__init__.py b/test/model/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/test/model/test_model.py b/test/model/test_model.py new file mode 100644 index 0000000..8b2eb86 --- /dev/null +++ b/test/model/test_model.py @@ -0,0 +1,36 @@ +import torch +import unittest + +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.convert_deploy import convert_deploy +from mqbench.utils.state import enable_calibration, enable_quantization +from mqbench.utils.logger import logger + + +class TestQuantizeBackend(unittest.TestCase): + def test_model_ppl(self): + exclude_list = ['googlenet', 'deeplabv3_mobilenet_v3_large', 'inception_v3', 'lraspp_mobilenet_v3_large', + 'mobilenet_v3_large', 'mobilenet_v3_small'] + entrypoints = torch.hub.list('pytorch/vision', force_reload=False) + for entrypoint in entrypoints: + if entrypoint in exclude_list: + continue + logger.info(f'testing {entrypoint}') + if 'deeplab' in entrypoint or 'fcn' in entrypoint: + model_to_quantize = torch.hub.load('pytorch/vision', entrypoint, pretrained=False, pretrained_backbone=False) + else: + model_to_quantize = torch.hub.load('pytorch/vision', entrypoint, pretrained=False) + dummy_input = torch.randn(8, 3, 224, 224, device='cpu') + model_to_quantize.train() + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.PPLW8A16) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + output = model_prepared(dummy_input) + try: + loss = output.sum() + except AttributeError: + loss = output['out'].sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.PPLW8A16, {'x': [1, 3, 224, 224]}, model_name='{}.onnx'.format(entrypoint)) diff --git a/test/observer/__init__.py b/test/observer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/observer/test_observer.py b/test/observer/test_observer.py new file mode 100644 index 0000000..83e6f95 --- /dev/null +++ b/test/observer/test_observer.py @@ -0,0 +1,90 @@ +import torch +import unittest + +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.convert_deploy import convert_deploy +from mqbench.utils.state import enable_calibration, enable_quantization + + +class TestObserver(unittest.TestCase): + def test_ema_observer(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'FixedFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_ema.onnx') + + def test_minmax_observer(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'MinMaxObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'FixedFakeQuantize', + } + 
prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_minmax.onnx') + + def test_lsq_observer(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'LSQObserver', + 'a_observer': 'LSQObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'FixedFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_lsq.onnx') + + def test_clip_std_observer(self): + model_to_quantize = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False) + dummy_input = torch.randn(2, 3, 224, 224, device='cpu') + model_to_quantize.train() + extra_qconfig_dict = { + 'w_observer': 'ClipStdObserver', + 'a_observer': 'ClipStdObserver', + 'w_fakequantize': 'FixedFakeQuantize', + 'a_fakequantize': 'FixedFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + model_prepared = prepare_qat_fx_by_platform(model_to_quantize, BackendType.Tensorrt, prepare_custom_config_dict) + enable_calibration(model_prepared) + model_prepared(dummy_input) + enable_quantization(model_prepared) + loss = model_prepared(dummy_input).sum() + loss.backward() + model_prepared.eval() + convert_deploy(model_prepared, BackendType.Tensorrt, {'x': [1, 3, 224, 224]}, model_name='resnet18_clip_std.onnx') + + \ No newline at end of file diff --git a/test/test.sh b/test/test.sh new file mode 100644 index 0000000..4e65740 --- /dev/null +++ b/test/test.sh @@ -0,0 +1 @@ +python -m unittest discover . 
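For reference, the observers in mqbench/observer.py above only track min_val / max_val; the actual scale and zero_point come from torch's _ObserverBase._calculate_qparams, which ObserverBase.calculate_qparams wraps. Below is a rough standalone sketch of that mapping for an unsigned 8-bit range. It is standard affine quantization, not the exact upstream implementation, and affine_qparams is a hypothetical helper used only for illustration.

import torch

def affine_qparams(min_val, max_val, quant_min=0, quant_max=255):
    # Extend the observed range so that zero is always representable.
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)
    # Map [min_val, max_val] onto the integer range [quant_min, quant_max].
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = quant_min - torch.round(min_val / scale)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale, zero_point

# Example: affine_qparams(torch.tensor(-1.0), torch.tensor(3.0))
# yields scale ~= 0.0157 and zero_point = 64.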
diff --git a/test/test_merge/__init__.py b/test/test_merge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_merge/test_merge.py b/test/test_merge/test_merge.py new file mode 100644 index 0000000..63e1dbd --- /dev/null +++ b/test/test_merge/test_merge.py @@ -0,0 +1,38 @@ +import torch +import unittest + +from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType +from mqbench.convert_deploy import convert_merge_bn +from mqbench.utils.state import enable_calibration, enable_quantization, disable_all + + +class TestMergeBN(unittest.TestCase): + + def test_case_1(self): + + def cos(a, b): + return (a * b).sum() / torch.sqrt((a ** 2).sum()) / torch.sqrt((b ** 2).sum()) + + dummy_input = torch.randn(1, 3, 224, 224, device='cpu') + extra_qconfig_dict = { + 'w_observer': 'MinMaxObserver', + 'a_observer': 'EMAMinMaxObserver', + 'w_fakequantize': 'LearnableFakeQuantize', + 'a_fakequantize': 'LearnableFakeQuantize', + } + prepare_custom_config_dict = {'extra_qconfig_dict': extra_qconfig_dict} + # First model + model_1 = torch.hub.load('pytorch/vision', 'mobilenet_v2', pretrained=False) + model_1 = prepare_qat_fx_by_platform(model_1, BackendType.Tensorrt, prepare_custom_config_dict) + model_1.train() + enable_calibration(model_1) + model_1(dummy_input) + enable_quantization(model_1) + model_1.eval() + prev_output = model_1(dummy_input) + # Convert model + convert_merge_bn(model_1) + new_output = model_1(dummy_input) + # Test + # merge bn import about 1e-8 mean error. + self.assertTrue(cos(new_output, prev_output) >= 0.999) \ No newline at end of file
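Putting the pieces above together, here is a minimal end-to-end sketch assembled from the tests in this patch. The extra_qconfig_dict keys follow the format documented in get_qconfig_by_platform; the preserve_attr entry is an illustrative placeholder (it copies named attributes back onto the prepared GraphModule), and concrete_args is forwarded to torch.fx.symbolic_trace.

import torch
from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType
from mqbench.convert_deploy import convert_deploy
from mqbench.utils.state import enable_calibration, enable_quantization

model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False).train()
prepare_custom_config_dict = {
    'extra_qconfig_dict': {
        'w_observer': 'MinMaxObserver',
        'a_observer': 'EMAMinMaxObserver',
        'w_fakequantize': 'LearnableFakeQuantize',
        'a_fakequantize': 'LearnableFakeQuantize',
    },
    # '' addresses the top-level module; 'training' is only a placeholder attribute.
    'preserve_attr': {'': ['training']},
    # Forwarded to torch.fx.symbolic_trace.
    'concrete_args': None,
}
model = prepare_qat_fx_by_platform(model, BackendType.Tensorrt, prepare_custom_config_dict)

enable_calibration(model)             # observers on, fake quant off
model(torch.randn(2, 3, 224, 224))    # collect weight / activation ranges
enable_quantization(model)            # observers off, fake quant on
# ... a regular QAT training loop would go here ...

model.eval()
convert_deploy(model, BackendType.Tensorrt, {'x': [1, 3, 224, 224]},
               model_name='resnet18_qat.onnx')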