Providing an execution script for tpu-mlir/python/sample/detect_yolov8.py #177

Open

wlc952 opened this issue Jul 22, 2024 · 0 comments

wlc952 commented Jul 22, 2024

#!/usr/bin/env python3
# Copyright (C) 2022 Sophgo Technologies Inc.  All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================
try:
    from tpu_mlir.python import *
except ImportError:
    pass

import numpy as np
import os
import sys
import argparse
import cv2
from tools.model_runner import mlir_inference, model_inference, onnx_inference, torch_inference
from utils.preprocess import supported_customization_format

classes = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat',
    9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat',
    16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe',
    24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis',
    31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard',
    37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife',
    44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot',
    52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed',
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard',
    67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book',
    74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}


class YOLOv8:
    def __init__(self, model_path, net_input_shape, input_image, confidence_thres, iou_thres):
        self.model_path = model_path
        self.input_image = input_image
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres
        self.input_size = tuple(map(int, net_input_shape.split(',')))  # parse the constructor arg, not the global args
        self.classes = classes
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
        self.pixel_format = 'rgb'
        self.channel_format = 'nchw'
        self.img = cv2.imread(self.input_image)

    def draw_detections(self, img, box, score, class_id):
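        # Draw a single detection; box is (x, y, w, h) in original-image pixels.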
        x1, y1, w, h = box
        color = self.color_palette[class_id]
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
        label = f"{self.classes[class_id]}: {score:.2f}"
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        return img

    def preproc(self):
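        # Letterbox: resize with the aspect ratio preserved, pad the remainder
        # with gray (114), and return the scale and offsets so postproc() can
        # map boxes back onto the original image.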
        img = self.img
        if len(img.shape) == 3:
            padded_img = np.ones((self.input_size[0], self.input_size[1], 3), dtype=np.uint8) * 114  # letterbox gray
        else:
            padded_img = np.ones(self.input_size, dtype=np.uint8) * 114  # letterbox gray

        r = min(self.input_size[0] / img.shape[0], self.input_size[1] / img.shape[1])

        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        top = int((self.input_size[0] - int(img.shape[0] * r)) / 2)
        left = int((self.input_size[1] - int(img.shape[1] * r)) / 2)
        padded_img[top:int(img.shape[0] * r) + top, left:int(img.shape[1] * r) + left] = resized_img

        if self.channel_format == 'nchw':
            padded_img = padded_img.transpose((2, 0, 1))  # HWC to CHW
        if self.pixel_format == 'rgb':
            padded_img = padded_img[::-1]  # BGR to RGB (channel axis is first after the transpose)

        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)

        return padded_img, r, top, left

    def postproc(self, output, r, top, left):
        def getInter(box1, box2):
            # Boxes are (cx, cy, w, h); convert to corner coordinates first.
            box1_x1, box1_y1, box1_x2, box1_y2 = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2, \
                                                box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
            box2_x1, box2_y1, box2_x2, box2_y2 = box2[0] - box2[2] / 2, box2[1] - box2[3] / 2, \
                                                box2[0] + box2[2] / 2, box2[1] + box2[3] / 2
            if box1_x1 > box2_x2 or box1_x2 < box2_x1:
                return 0
            if box1_y1 > box2_y2 or box1_y2 < box2_y1:
                return 0
            x_list = [box1_x1, box1_x2, box2_x1, box2_x2]
            x_list = np.sort(x_list)
            x_inter = x_list[2] - x_list[1]
            y_list = [box1_y1, box1_y2, box2_y1, box2_y2]
            y_list = np.sort(y_list)
            y_inter = y_list[2] - y_list[1]
            inter = x_inter * y_inter
            return inter
        def getIou(box1, box2):
            inter_area = getInter(box1, box2)
            box1_area = box1[2] * box1[3]
            box2_area = box2[2] * box2[3]
            union = box1_area + box2_area - inter_area
            iou = inter_area / union
            return iou

        img = self.img
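        # YOLOv8 emits a (1, 84, N) tensor per image: 4 box values (cx, cy, w, h)
        # followed by 80 class scores for each of the N candidates.
        # Squeezing and transposing yields (N, 84) for row-wise filtering.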
        pred = np.transpose(np.squeeze(output[0]))
        pred_class = pred[..., 4:]
        pred_conf = np.max(pred_class, axis=-1)
        pred = np.insert(pred, 4, pred_conf, axis=-1)

        conf = pred[..., 4] > self.confidence_thres
        true_pred = pred[conf]
        true_cls_score = true_pred[..., 5:]

        all_cls = [int(np.argmax(cls_scores)) for cls_scores in true_cls_score]
        classes = list(set(all_cls))

        output_boxes = []
        for cls in classes:
            clss_mask = np.array(all_cls) == cls
            clss_boxes = true_pred[clss_mask][:, :6]
            clss_boxes[:, 5] = cls
            clss_boxes = clss_boxes[np.argsort(clss_boxes[:, 4])[::-1]]

            while len(clss_boxes) > 0:
                current_box = clss_boxes[0]
                output_boxes.append(current_box)
                if len(clss_boxes) == 1:
                    break
                ious = np.array([getIou(current_box, box) for box in clss_boxes[1:]])
                clss_boxes = clss_boxes[1:][ious < self.iou_thres]

        for bb in output_boxes:
            x, y, w, h = bb[:4]
            # Undo the letterbox: remove the padding offsets, then rescale.
            box_x = int((x - w / 2 - left) / r)
            box_y = int((y - h / 2 - top) / r)
            box_width = int(w / r)
            box_height = int(h / r)
            box = [box_x, box_y, box_width, box_height]
            score = bb[4]
            class_id = int(bb[5])
            img = self.draw_detections(img, box, score, class_id)

        return img

    def main(self):
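        # Preprocess, dispatch to the tpu-mlir runner matching the model file
        # suffix, then decode the first output tensor.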
        img, ratio, top, left = self.preproc()
        img = np.expand_dims(img, axis=0)
        img /= 255.0  # normalize pixels to [0, 1]

        data = {"data": img}  # input name from model
        output = dict()
        if args.model.endswith('.onnx'):
            output = onnx_inference(data, args.model, False)
        elif args.model.endswith('.pt') or args.model.endswith('.pth'):
            output = torch_inference(data, args.model, False)
        elif args.model.endswith('.mlir'):
            output = mlir_inference(data, args.model, False)
        elif args.model.endswith(".bmodel"):
            output = model_inference(data, args.model)
        elif args.model.endswith(".cvimodel"):
            output = model_inference(data, args.model, False)
        else:
            raise RuntimeError("unsupported model file: {}".format(args.model))
        outputs = next(iter(output.values()))
        return self.postproc(outputs, ratio, top, left) 


def parse_args():
    parser = argparse.ArgumentParser(description='Inference YOLOv8 network.')
    parser.add_argument("--model", type=str, required=True, help="Model definition file")
    parser.add_argument("--net_input_dims", type=str, default="640,640", help="(h,w) of net input")
    parser.add_argument("--input", type=str, required=True, help="Input image for testing")
    parser.add_argument("--output", type=str, required=True, help="Output image after detection")
    parser.add_argument("--conf_thres", type=float, default=0.5, help="Confidence threshold")
    parser.add_argument("--iou_thres", type=float, default=0.6, help="NMS IOU threshold")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    detection = YOLOv8(args.model, args.net_input_dims, args.input, args.conf_thres, args.iou_thres)
    output_image = detection.main()
    cv2.imwrite(args.output, output_image)
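
For reference, a sketch of a typical invocation (the model and image file names below are placeholders; substitute your own ONNX export or tpu-mlir-converted artifact):

python3 detect_yolov8.py \
    --model yolov8n.onnx \
    --net_input_dims 640,640 \
    --input dog.jpg \
    --output dog_out.jpg \
    --conf_thres 0.5 \
    --iou_thres 0.6

The same command line accepts .pt/.pth, .mlir, .bmodel, or .cvimodel files through --model; main() dispatches on the file suffix.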