-
Notifications
You must be signed in to change notification settings - Fork 0
/
upscale.py
149 lines (127 loc) · 4.24 KB
/
upscale.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
A wrapper module based on the default ESRGAN's `test.py`
designed to let the upscaling be called as high-level as possible.
"""
import os.path
import cv2
import numpy as np
import torch
import warnings
from esrgan import architecture as arch
try:
# support type hints in Python 3:
from typing import *
except ImportError:
pass
def upscale_images_generator(
    conversions,  # type: Union[str, Iterable[Union[Tuple[str, str], str]]]
    model_path,  # type: str
    device,  # type: torch.device
    suffix='upRes'
):
    # type: (...) -> Iterator[Tuple[str, str, str]]
    """
    Upscale images 4x with an RRDB (ESRGAN) network, one file at a time.

    This function is basically a ripoff of the default `test.py`
    with anything that might be configured turned into the arguments
    (instead of special files placement and command-line arguments).

    The network is built and its weights are loaded immediately, when this
    function is called. The returned generator, however, hasn't upscaled
    anything yet: each file is read, upscaled and written only when the
    generator is advanced, so you get the intermediate results as soon as
    each file is done processing.

    :param conversions:
        The files to convert. It can be one of:

        * a single `string` - the source path; the output path is generated
          from it using `suffix`.
        * an `iterable` of `strings` - output paths are generated the same way.
        * an `iterable` of 2-size tuples of strings.
          First one is a source path, second one is the output.
          An empty output path falls back to the `suffix`-generated one.

        Strings and tuples may be mixed in one iterable.
    :param model_path:
        The neural-net model: a state dict for the hard-coded 4x, 23-block
        RRDB architecture below. Presumably, the full path of either
        `RRDB_ESRGAN_x4.pth` or `RRDB_PSNR_x4.pth`.
    :param device:
        Either `torch.device('cuda')` or `torch.device('cpu')`.
        The weights are mapped straight onto it when loaded.
    :param suffix:
        The string dash-attached to the source filename
        to generate the name of output file. Used if no out file specified.
        An empty value falls back to 'upRes'.
    :return:
        generator of 3-size tuples of strings:
        * source file
        * output file
        * base name (with ext, but without path) of the source file
    :raises IOError:
        While iterating, if a source image can't be read or the result
        can't be written.
    """
    if isinstance(conversions, str):
        conversions = [conversions]
    if not suffix:
        suffix = 'upRes'

    model = arch.RRDB_Net(
        3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
        mode='CNA', res_scale=1, upsample_mode='upconv'
    )
    # Without map_location, torch.load restores tensors onto the device they
    # were saved from, so a GPU-saved checkpoint would fail on a CPU-only box.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    def _upscale_single(
        conversion  # type: Union[Tuple[str, str], str]
    ):
        # type: (...) -> Tuple[str, str, str]
        """Upscale one file and write the result; return (src, out, base name)."""
        if isinstance(conversion, str):
            src_path = conversion  # type: str
            res_path = ''
        else:
            src_path, res_path = conversion  # type: (str, str)
        base = os.path.basename(src_path)
        if not res_path:
            base_nm, ext = os.path.splitext(base)
            # Keep the source's directory prefix exactly as given. Slice by
            # length: `[:-len(base)]` would give '' when base is empty.
            res_path = src_path[:len(src_path) - len(base)] + base_nm + '-' + suffix + ext

        # cv2.imread does not raise on failure - it returns None.
        img = cv2.imread(src_path, cv2.IMREAD_COLOR)
        if img is None:
            raise IOError('Could not read image: {0}'.format(src_path))
        # HWC BGR 8-bit -> CHW RGB float in [0, 1], plus a batch axis.
        img = img * 1.0 / 255
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_LR = img.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # CHW RGB -> HWC BGR; values are already clamped, so uint8 is lossless.
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        output = (output * 255.0).round().astype(np.uint8)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(res_path, output):
            raise IOError('Could not write image: {0}'.format(res_path))
        return src_path, res_path, base

    # Lazy: nothing is read or upscaled until the caller iterates.
    return (
        _upscale_single(c) for c in conversions
    )
def upscale_images(
    conversions,  # type: Union[str, Iterable[Union[Tuple[str, str], str]]]
    model_path,  # type: str
    device,  # type: torch.device
    suffix='upRes'
):
    """
    Eager counterpart of `upscale_images_generator`: convert every file
    before returning.

    Accepts exactly the same arguments and returns a list holding every
    ``(source path, output path, source base name)`` tuple that the
    generator yields, in the same order.
    """
    pending = upscale_images_generator(
        conversions, model_path, device, suffix=suffix
    )
    return list(pending)
def upscale_from_cmd_args(
    args,  # type: List[str]
    rel_model_path='/models/RRDB_ESRGAN_x4.pth',
    suffix='upRes'
):
    """
    A service function for batch-processing files passed as command-line arguments.

    Each path in `args` is upscaled next to its source (the output name is
    generated with `suffix`), and a numbered progress line is printed per
    file. CUDA is tried first. If that raises AssertionError/RuntimeError,
    the files not yet finished are processed on the CPU. Finally the
    function waits for Enter.

    :param args: source image paths (a single string is treated as one path).
    :param rel_model_path:
        Model path relative to the directory containing this script;
        the leading slash is optional.
    :param suffix: passed through to `upscale_images_generator`.
    """
    script_path = os.path.abspath(__file__)
    parent_dir = os.path.dirname(script_path).replace('\\', '/')
    model_path = parent_dir + '/' + rel_model_path.replace('\\', '/').lstrip('/')
    print('Performing upscale...\nModel path: ' + model_path)

    # Materialize once so the CPU retry still sees the inputs even when
    # `args` was a one-shot iterator.
    sources = [args] if isinstance(args, str) else list(args)
    done = []  # type: List[str]  # sources already written; appended by _perform

    def _perform(
        device  # type: torch.device
    ):
        # Resume after the files that already succeeded, so a CUDA failure
        # mid-batch doesn't redo (and re-number) the finished ones on CPU.
        for i, (src_path, res_path, base) in enumerate(
            upscale_images_generator(sources[len(done):], model_path, device, suffix),
            start=len(done) + 1
        ):
            if i == 1:
                print('\n')
            print('{0}: {1}'.format(i, base))
            done.append(src_path)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            _perform(torch.device('cuda'))
        except (AssertionError, RuntimeError) as err:
            print("CUDA didn't work. Trying on CPU...")
            print('Reason: {0}'.format(err))
            _perform(torch.device('cpu'))
    print('\nComplete.')

    # Wait for Enter before returning (presumably to keep a console window
    # open when the script is launched outside a terminal).
    try:
        read_line = raw_input  # Python 2: its input() would eval() the line
    except NameError:
        read_line = input
    try:
        read_line()
    except EOFError:
        # stdin is closed or not interactive: nothing to wait for.
        pass