-
Notifications
You must be signed in to change notification settings - Fork 13
/
batch_preprocess.py
68 lines (52 loc) · 2.35 KB
/
batch_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
from utils import get_boundary_label, get_distance_label, binarize_matrix, split_pair_names, normalize_rgb
import tensorflow as tf
import cv2
class DataParser:
    """Shuffle paired image/label samples into train/validation ids and build batches.

    ``self.samples`` comes from ``split_pair_names``; each entry is indexed
    as ``[0]`` = image file path and ``[1]`` = label-mask file path.
    ``get_batch`` loads the pairs for a list of sample indices and returns
    multitask targets (segmentation, boundary, distance).
    """

    def __init__(self, img_path, label_path, label_dict, validation_split,
                 batch_size=8, image_size=256, num_classes=2):
        """Index the samples and randomly split them into train/validation ids.

        Args:
            img_path: Image source forwarded to ``split_pair_names``
                (presumably a directory — confirm against ``utils``).
            label_path: Label source forwarded to ``split_pair_names``.
            label_dict: Mapping forwarded to ``binarize_matrix`` when a mask
                is converted to class indices.
            validation_split: Fraction of samples, in ``[0, 1]``, held out
                for validation.
            batch_size: Training batch size used to derive ``steps_per_epoch``.
            image_size: Side length that images and masks are resized to.
            num_classes: Number of classes in the one-hot segmentation target.

        Raises:
            ValueError: If ``validation_split`` is not within ``[0, 1]``.
        """
        if not 0 <= validation_split <= 1:
            raise ValueError(
                f"validation_split must be in [0, 1], got {validation_split!r}")
        self.img_path = img_path
        self.label_path = label_path
        self.samples = split_pair_names(self.img_path, self.label_path)
        self.n_samples = len(self.samples)
        self.all_ids = list(range(self.n_samples))
        # Uses NumPy's global RNG, so callers can seed it via np.random.seed.
        np.random.shuffle(self.all_ids)
        train_split = 1 - validation_split
        # round() to 9 places before truncating absorbs float error such as
        # (1 - 0.9) * 10 == 0.9999999999999998, which int() alone would turn
        # into 0 instead of 1; genuine fractions (e.g. 5.6) still truncate.
        n_train = int(round(train_split * self.n_samples, 9))
        self.training_ids = self.all_ids[:n_train]
        self.validation_ids = self.all_ids[n_train:]
        self.batch_size = batch_size
        # Whole numbers of steps, rounded up so a trailing partial batch is
        # still counted.
        self.steps_per_epoch = int(np.ceil(len(self.training_ids) / batch_size))
        # NOTE(review): validation batches appear to be twice the training
        # batch size — confirm against the generator that consumes this.
        self.validation_steps = int(
            np.ceil(len(self.validation_ids) / (batch_size * 2)))
        self.image_size = image_size
        self.label_dict = label_dict
        self.num_classes = num_classes

    @staticmethod
    def _read_image(path, flags):
        """Read *path* with ``cv2.imread`` using *flags*; raise if it fails."""
        img = cv2.imread(path, flags)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f"could not read image file {path!r}")
        return img

    def _load_sample(self, index):
        """Return ``(normalised float32 image, one-hot mask)`` for sample *index*."""
        size = (self.image_size, self.image_size)
        # OpenCV loads colour images in BGR order; they are passed to
        # normalize_rgb as-is. NOTE(review): confirm that is intended.
        im = cv2.resize(
            self._read_image(self.samples[index][0], cv2.IMREAD_COLOR), size)
        # Nearest-neighbour keeps the mask's original values; the default
        # bilinear interpolation blends edge pixels into in-between values,
        # which the ``> 0`` threshold below would turn into foreground.
        em = cv2.resize(
            self._read_image(self.samples[index][1], cv2.IMREAD_GRAYSCALE),
            size, interpolation=cv2.INTER_NEAREST)
        em[em > 0] = 255
        # Replicate the mask to 3 channels — presumably binarize_matrix matches
        # colour entries from label_dict; confirm against utils.
        em = np.stack([em, em, em], axis=-1)
        em = binarize_matrix(em, self.label_dict)
        em = tf.keras.utils.to_categorical(em, self.num_classes)
        im = normalize_rgb(im.astype(np.float32))
        return im, em

    def get_batch(self, batch):
        """Load the samples at indices *batch* and build the multitask targets.

        Args:
            batch: Iterable of indices into ``self.samples`` (e.g. a slice of
                ``training_ids`` or ``validation_ids``).

        Returns:
            ``(images, labels)``: ``images`` stacks the normalised float32
            images along a new leading axis; ``labels`` is a dict with keys
            ``'segmentation'``, ``'boundary'`` and ``'distance'``, each stacked
            the same way.

        Raises:
            OSError: If an image or label file cannot be read.
        """
        images, seg, bound, dist = [], [], [], []
        for index in batch:
            im, em = self._load_sample(index)
            images.append(im)
            # Segmentation target is the one-hot mask from to_categorical;
            # boundary/distance targets are derived from that one-hot mask.
            seg.append(em.astype(np.float32))
            bound.append(get_boundary_label(em).astype(np.float32))
            dist.append(get_distance_label(em).astype(np.float32))
        labels = {'segmentation': np.asarray(seg)}
        labels['boundary'] = np.asarray(bound)
        labels['distance'] = np.asarray(dist)
        return np.asarray(images), labels