-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
64 lines (59 loc) · 2.21 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# import the necessary packages
import argparse
import csv
import os
import random

import cv2
import imutils
import numpy as np
from imutils import paths
from keras.models import load_model
from skimage import exposure
from skimage import io
from skimage import transform
def load_label_names(csv_path):
    """Return the sign names listed in the CSV file at *csv_path*.

    The first row is treated as a header and skipped; every following
    non-empty row contributes its second column. Row order matters: the
    class index predicted by the model is used to index this list.

    Raises:
        OSError: if the file cannot be opened.
        IndexError: if a data row has fewer than two columns.
    """
    # csv.reader (unlike a plain str.split(",")) keeps quoted names that
    # contain commas intact, and the with-block closes the file handle.
    with open(csv_path, newline="") as f:
        rows = csv.reader(f)
        next(rows, None)  # skip the header row
        return [row[1] for row in rows if row]


def preprocess_for_model(image_path):
    """Load the image at *image_path* and return it as a one-image batch.

    Resizes to 32x32 pixels, applies Contrast Limited Adaptive Histogram
    Equalization (CLAHE), casts to float32, divides by 255 and prepends a
    batch axis so the result can be passed directly to ``model.predict``.
    These steps are meant to mirror the preprocessing used during training.
    """
    image = io.imread(image_path)
    image = transform.resize(image, (32, 32))
    image = exposure.equalize_adapthist(image, clip_limit=0.1)
    # NOTE(review): per the scikit-image docs, resize() (default
    # preserve_range=False) and equalize_adapthist() already return floats
    # in [0, 1], so this division leaves values in [0, 1/255] rather than
    # [0, 1]. It is kept deliberately: the model must receive exactly what
    # it saw during training, so only change this together with the
    # training script.
    image = image.astype("float32") / 255.0
    return np.expand_dims(image, axis=0)


# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True,
    help="path to pre-trained traffic sign recognizer")
ap.add_argument("-i", "--images", required=True,
    help="path to testing directory containing images")
ap.add_argument("-e", "--examples", required=True,
    help="path to output examples directory")
# optional settings whose defaults equal the values previously hard-coded
ap.add_argument("-l", "--labels", default="signnames.csv",
    help="path to CSV file mapping class indices to sign names")
ap.add_argument("-n", "--num-images", type=int, default=25,
    help="maximum number of randomly sampled images to classify")
args = vars(ap.parse_args())

# load the traffic sign recognizer model
print("[INFO] loading model...")
model = load_model(args["model"])

# load the label names (index i is the name of predicted class i)
labelNames = load_label_names(args["labels"])

# grab the paths to the input images, shuffle them, and grab a sample
print("[INFO] predicting...")
imagePaths = list(paths.list_images(args["images"]))
random.shuffle(imagePaths)
imagePaths = imagePaths[:args["num_images"]]

# loop over the sampled image paths
for (i, imagePath) in enumerate(imagePaths):
    # preprocess the image the same way as during training and classify it
    preds = model.predict(preprocess_for_model(imagePath))
    j = preds.argmax(axis=1)[0]
    label = labelNames[j]

    # reload the image with OpenCV, resize it to 128 pixels wide, and draw
    # the predicted label on it
    image = cv2.imread(imagePath)
    image = imutils.resize(image, width=128)
    cv2.putText(image, label, (5, 15), cv2.FONT_HERSHEY_SIMPLEX,
        0.45, (0, 0, 255), 2)

    # save the annotated image to disk as <index>.png
    p = os.path.join(args["examples"], "{}.png".format(i))
    cv2.imwrite(p, image)

    # report the prediction and display the image; waitKey(0) blocks until
    # a key is pressed, so the images are stepped through one at a time
    print("[INFO] {}".format(label))
    cv2.imshow("Image", image)
    cv2.waitKey(0)