-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CPU usage is up, untested if performance has actually increased
- Loading branch information
1 parent
a90d4ab
commit 9461417
Showing
2 changed files
with
152 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from collections import defaultdict | ||
from multiprocessing import Process, Queue | ||
from queue import Empty | ||
import signal | ||
import platform | ||
import numpy | ||
import sys | ||
import os | ||
|
||
|
||
class Worker(object): | ||
|
||
def __init__(self, target, *args): | ||
self.queue_out = Queue() | ||
self.queue_in = Queue() | ||
self.target = target | ||
self.args = args | ||
self.process = Process( | ||
target=self.target, | ||
args=(self.queue_out, self.queue_in,) + self.args | ||
) | ||
self.process.daemon = True | ||
|
||
def start(self): | ||
try: | ||
return self.process.start() | ||
except BrokenPipeError as e: | ||
print("=" * 30) | ||
print("Ran into a broken pipe error") | ||
print("=" * 30) | ||
print("This can occur if you are calling functions directly from a module outside of any class/function") | ||
print("Make sure you have your script entry point inside a function, for example:") | ||
print("\n".join([ | ||
"", | ||
"def main():", | ||
" # code here", | ||
"", | ||
"if __name__ == '__main__':", | ||
" main()" | ||
])) | ||
print("=" * 30) | ||
print("Original exception:") | ||
print("=" * 30) | ||
raise | ||
|
||
def get(self): | ||
return self.queue_in.get() | ||
|
||
def put(self, val): | ||
return self.queue_out.put(val) | ||
|
||
def join(self): | ||
return self.process.join() | ||
|
||
|
||
def iterframes_job(recv_queue, send_queue, times, clip_generator_cls, clip_generator_attrs, dtype): | ||
generator = clip_generator_cls(**clip_generator_attrs) | ||
clip = generator.get_clip() | ||
|
||
for current in iterate_frames_at_times(clip, times, dtype): | ||
send_queue.put(current, timeout=10) | ||
|
||
# Avoiding running ahead of the main thread and filling up memory | ||
# Timeout in 10 seconds in case the main thread has been killed | ||
try: | ||
recv_queue.get(timeout=10) | ||
except Empty: | ||
# For some reason sys.exit(), os._exit(), or raising an exception doesn't work | ||
sig = signal.SIGTERM | ||
if platform.system() == "Windows": | ||
sig = signal.CTRL_C_EVENT | ||
os.kill(os.getpid(), sig) | ||
|
||
|
||
def iterate_frames_at_times(clip, times, dtype): | ||
for time in times: | ||
frame = clip.get_frame(time) | ||
if (dtype is not None) and (frame.dtype != dtype): | ||
frame = frame.astype(dtype) | ||
yield time, frame | ||
|
||
|
||
def get_clip_times(clip, fps): | ||
return numpy.arange(0, clip.duration, 1.0 / fps) | ||
|
||
|
||
def iterframes(threads, clip, fps, dtype, with_times): | ||
attrs = { | ||
"clip": clip, | ||
"fps": fps, | ||
"dtype": dtype, | ||
} | ||
if threads < 1: | ||
generator = singlethread_iterframes | ||
else: | ||
generator = multithread_iterframes | ||
attrs["threads"] = threads | ||
|
||
for current in generator(**attrs): | ||
if with_times: | ||
yield current | ||
else: | ||
yield current[1] | ||
|
||
|
||
def singlethread_iterframes(clip, fps, dtype): | ||
for current in iterate_frames_at_times(clip, get_clip_times(clip, fps), dtype): | ||
yield current | ||
|
||
|
||
def multithread_iterframes(threads, clip, fps, dtype): | ||
times = get_clip_times(clip, fps) | ||
jobsets = defaultdict(list) | ||
for index, time in enumerate(times): | ||
jobsets[index % threads].append(time) | ||
|
||
workers = [Worker(iterframes_job, jobsets[i], clip.generator_cls, clip.generator_attrs, dtype) for i in range(threads)] | ||
for worker in workers: | ||
worker.start() | ||
|
||
for index, time in enumerate(times): | ||
current = workers[index % threads].get() | ||
workers[index % threads].put(True) | ||
yield current | ||
sys.stdout.flush() | ||
|
||
for worker in workers: | ||
worker.join() |