-
Notifications
You must be signed in to change notification settings - Fork 65
/
test_video.py
142 lines (118 loc) · 5.04 KB
/
test_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import concurrent.futures
import json
import multiprocessing
import time
import torch
import numpy as np
from src.models.video_model import DMC
from src.utils.common import create_folder, dump_json
from src.utils.test_helper import parse_args, init_func, worker
from tqdm import tqdm
def main():
begin_time = time.time()
torch.backends.cudnn.enabled = True
args = parse_args()
if args.cuda_idx is not None:
cuda_device = ','.join([str(s) for s in args.cuda_idx])
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
worker_num = args.worker
assert worker_num >= 1
with open(args.test_config) as f:
config = json.load(f)
gpu_num = 0
if args.cuda:
gpu_num = torch.cuda.device_count()
multiprocessing.set_start_method("spawn")
threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num,
initializer=init_func,
initargs=(args, gpu_num))
objs = []
count_frames = 0
count_sequences = 0
rate_num = args.rate_num
q_indexes_i = []
if args.q_indexes_i is not None:
assert len(args.q_indexes_i) == rate_num
q_indexes_i = args.q_indexes_i
else:
assert 2 <= rate_num <= DMC.get_qp_num()
for i in np.linspace(0, DMC.get_qp_num() - 1, num=rate_num):
q_indexes_i.append(int(i+0.5))
if not args.force_intra:
if args.q_indexes_p is not None:
assert len(args.q_indexes_p) == rate_num
q_indexes_p = args.q_indexes_p
else:
q_indexes_p = q_indexes_i
print(f"testing {rate_num} rates, using q_indexes: ", end='')
for q in q_indexes_i:
print(f"{q}, ", end='')
print()
root_path = args.force_root_path if args.force_root_path is not None else config['root_path']
config = config['test_classes']
for ds_name in config:
if config[ds_name]['test'] == 0:
continue
for seq in config[ds_name]['sequences']:
count_sequences += 1
for rate_idx in range(rate_num):
cur_args = {}
cur_args['rate_idx'] = rate_idx
cur_args['float16'] = args.float16
cur_args['q_index_i'] = q_indexes_i[rate_idx]
if not args.force_intra:
cur_args['q_index_p'] = q_indexes_p[rate_idx]
cur_args['force_intra'] = args.force_intra
cur_args['reset_interval'] = args.reset_interval
cur_args['seq'] = seq
cur_args['src_type'] = config[ds_name]['src_type']
cur_args['src_height'] = config[ds_name]['sequences'][seq]['height']
cur_args['src_width'] = config[ds_name]['sequences'][seq]['width']
cur_args['intra_period'] = config[ds_name]['sequences'][seq]['intra_period']
if args.force_intra:
cur_args['intra_period'] = 1
if args.force_intra_period > 0:
cur_args['intra_period'] = args.force_intra_period
cur_args['frame_num'] = config[ds_name]['sequences'][seq]['frames']
if args.force_frame_num > 0:
cur_args['frame_num'] = args.force_frame_num
cur_args['rate_gop_size'] = args.rate_gop_size
cur_args['calc_ssim'] = args.calc_ssim
cur_args['dataset_path'] = os.path.join(root_path, config[ds_name]['base_path'])
cur_args['write_stream'] = args.write_stream
cur_args['stream_path'] = args.stream_path
cur_args['save_decoded_frame'] = args.save_decoded_frame
cur_args['ds_name'] = ds_name
cur_args['verbose'] = args.verbose
cur_args['verbose_json'] = args.verbose_json
count_frames += cur_args['frame_num']
obj = threadpool_executor.submit(worker, cur_args)
objs.append(obj)
results = []
for obj in tqdm(objs):
result = obj.result()
results.append(result)
log_result = {}
for ds_name in config:
if config[ds_name]['test'] == 0:
continue
log_result[ds_name] = {}
for seq in config[ds_name]['sequences']:
log_result[ds_name][seq] = {}
for res in results:
log_result[res['ds_name']][res['seq']][f"{res['rate_idx']:03d}"] = res
out_json_dir = os.path.dirname(args.output_path)
if len(out_json_dir) > 0:
create_folder(out_json_dir, True)
with open(args.output_path, 'w') as fp:
dump_json(log_result, fp, float_digits=6, indent=2)
total_minutes = (time.time() - begin_time) / 60
print('Test finished')
print(f'Tested {count_frames} frames from {count_sequences} sequences')
print(f'Total elapsed time: {total_minutes:.1f} min')
if __name__ == "__main__":
main()