is it possible to quantify the AnimeGANv3 model from FP32 to INT8 #58

Closed
zczjx opened this issue Sep 3, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

@zczjx

zczjx commented Sep 3, 2024

Is it possible to quantize the AnimeGANv3 model from FP32 to INT8?

@TachibanaYoshino
Owner

Quantization is possible, but I haven't tried it yet.
For generative models, quantization needs to ensure the visual quality of its output. It may be more demanding than recognition models. Do you have any good suggestions?
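
As a rough guide to where quality can be lost, affine 8-bit quantization stores a real value x as an integer q with a scale s and a zero point z. The arithmetic below is a sketch under one assumption that the thread does not confirm: that the generator output is calibrated to the range [-1, 1].

```latex
q = \operatorname{clip}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + z,\ 0,\ 255\right),
\qquad
\hat{x} = s\,(q - z),
\qquad
|x - \hat{x}| \le \frac{s}{2} \quad \text{(for } x \text{ inside the calibrated range)}

s = \frac{x_{\max} - x_{\min}}{255} = \frac{2}{255},
\qquad
127.5 \cdot s = 1 \ \text{gray level}
```

Under that assumption, rounding the output tensor alone costs at most half a gray level per pixel. Any visible degradation would more likely come from rounding in the weights and intermediate activations, which is why the choice of calibration images matters.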

@zczjx
Author

zczjx commented Sep 3, 2024

I just want to know whether it is feasible, and then try it.

@TachibanaYoshino TachibanaYoshino added the enhancement New feature or request label Sep 4, 2024
@TachibanaYoshino
Owner

import os
import cv2
import numpy as np
from tqdm import tqdm
# tensorflow version:    2.7.0
import tensorflow.compat.v1 as tf


def load_test_data(image_path, input_shape):
    img = cv2.imread(image_path)
    img = cv2.resize(img, (input_shape[1], input_shape[0]), interpolation = cv2.INTER_LINEAR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)/127.5 - 1.0
    img = np.expand_dims(img, axis=0)
    return img

def pb_to_tflite_QUANTIZED_INT8(pb_file, input_shape, data_path):
    def representative_dataset_gen():
        files = [os.path.join(data_path, x) for x in os.listdir(data_path) ]
        for f in tqdm(files):
            img = load_test_data(f, input_shape)
            # Get sample input data as a numpy array in a method of your choosing.
            yield [img]

    converter = tf.lite.TFLiteConverter.from_frozen_graph(pb_file,
                                                                    input_arrays = ["AnimeGANv3_input"],  # The name of the model input node
                                                                    input_shapes = {'AnimeGANv3_input': [1, input_shape[0], input_shape[1], 3]},
                                                                    output_arrays = ["generator_1/main/out_layer"])  # The name of the model output node
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.allow_custom_ops = True
    # Ensure that if any ops can't be quantized, the converter throws an error
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.target_spec.supported_types = [tf.lite.constants.INT8]    # QUANTIZED_INT8
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    converter.representative_dataset = representative_dataset_gen
    tflite_quant_model = converter.convert()
    # Write the converted model; the context manager closes the file promptly.
    with open("QUANTIZED_int8.tflite", "wb") as f:
        f.write(tflite_quant_model)

def pb_to_tflite_QUANTIZED_FLOAT16(pb_file, input_shape, data_path=None):
    converter = tf.lite.TFLiteConverter.from_frozen_graph(pb_file,
                                                                    input_arrays = ["AnimeGANv3_input"],  # The name of the model input node
                                                                    input_shapes = {'AnimeGANv3_input': [1, input_shape[0], input_shape[1], 3]},
                                                                    output_arrays = ["generator_1/main/out_layer"])  # The name of the model output node
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.allow_custom_ops = True
    converter.target_spec.supported_types = [tf.lite.constants.FLOAT16]   # QUANTIZED FLOAT16
    tflite_quant_model = converter.convert()
    # Write the converted model; the context manager closes the file promptly.
    with open("QUANTIZED_fp16.tflite", "wb") as f:
        f.write(tflite_quant_model)

if __name__ =="__main__":
    data_path = r"../data/imgs"  # At least 100 images
    pb_file = r'../AnimeGANv3_Hayao_36.pb'
    input_shape = [512, 512]
    pb_to_tflite_QUANTIZED_FLOAT16(pb_file, input_shape)
    pb_to_tflite_QUANTIZED_INT8(pb_file, input_shape, data_path)

As shown above, I wrote a script that converts TensorFlow's pb model to a quantized tflite model. It supports two quantization formats, INT8 and float16. You can deploy the resulting models on mobile devices that support tflite, such as Android phones.
Taking the AnimeGANv3_Hayao_36.pb style mentioned in the paper as an example, the quantized model files are as follows:
QUANTIZED_fp16.tflite
QUANTIZED_int8.tflite

The comparison results before and after quantization are as follows:
[Image: comparison of outputs before and after quantization]

After quantization, the output of the AnimeGANv3 model still maintains high visual quality, and the model file is much smaller.
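
To put numbers on the visual comparison and the file-size reduction, one option is to compute PSNR between the FP32 output and the quantized output for the same input, and to compare file sizes on disk. This is a sketch only: PSNR is a swapped-in metric that the thread does not use, it tracks perceived quality only loosely for stylized images, and the function names `psnr` and `size_ratio` are hypothetical.

```python
import os

import numpy as np


def psnr(reference, candidate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError("images must have the same shape")
    mse = np.mean((ref - cand) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))


def size_ratio(original_path, quantized_path):
    """Quantized file size as a fraction of the original file size."""
    return os.path.getsize(quantized_path) / os.path.getsize(original_path)
```

For example, `size_ratio("../AnimeGANv3_Hayao_36.pb", "QUANTIZED_int8.tflite")` gives the INT8 file size relative to the original pb file. A higher PSNR means the quantized output is closer to the FP32 output.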

@TachibanaYoshino TachibanaYoshino pinned this issue Sep 4, 2024
@zczjx
Author

zczjx commented Sep 5, 2024

Awesome, thanks a lot.
Let me try to migrate it to RKNN on the RK3588 ARM Linux platform.

@zczjx
Author

zczjx commented Sep 11, 2024

After porting the tflite library to the RK3588, the model at least works with CPU inference.
Converting the model to RKNN failed; more adaptation work is needed to convert the tflite model to an RKNN-format model.

@zczjx zczjx closed this as completed Sep 11, 2024