From dd1739cfb3f544456de26629af64f06505c11fad Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 30 Jul 2022 15:02:10 +0000 Subject: [PATCH 1/2] PyTorch DataLoader wrapper for FrameReader --- video2numpy/pytorch_wrapper.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 video2numpy/pytorch_wrapper.py diff --git a/video2numpy/pytorch_wrapper.py b/video2numpy/pytorch_wrapper.py new file mode 100644 index 0000000..1aeb31e --- /dev/null +++ b/video2numpy/pytorch_wrapper.py @@ -0,0 +1,21 @@ +"""Wrapper around FrameReader so it appears to be a PyTorch DataLoader.""" +import torch +from .frame_reader import FrameReader + + +class Dataset(torch.utils.data.IterableDataset): + def __init__(self, frame_reader: FrameReader): + self.frame_reader = frame_reader + + def __len__(self): + return 10 ** 9 + + def __iter__(self): + for item in self.frame_reader: + yield item + + +def fr2dl(frame_reader): + dataset = Dataset(frame_reader) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x[0]) + return dataloader From b91fe6f7da87165481a8e7b3d5c2239f35f29956 Mon Sep 17 00:00:00 2001 From: iejMac Date: Sat, 30 Jul 2022 15:05:08 +0000 Subject: [PATCH 2/2] lint fix --- benchmark/reader_benchmark.py | 2 +- tests/test_modules.py | 2 +- tests/test_read.py | 4 +++- video2numpy/pytorch_wrapper.py | 10 ++++++---- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/benchmark/reader_benchmark.py b/benchmark/reader_benchmark.py index 3139d99..9cd98bf 100644 --- a/benchmark/reader_benchmark.py +++ b/benchmark/reader_benchmark.py @@ -72,7 +72,7 @@ def benchmark_reading(vids, take_en, resize_size, workers): samp_per_s, _, _ = benchmark_reading(vids, ten, resize_size, workers) print(f"samples/s @ {fps} FPS = {samp_per_s}") results.append(samp_per_s) - time.sleep(5) # allow time for reset + time.sleep(5) # allow time for reset plt.plot(video_fps, results) plt.title(f"{args.name}: resize size - {resize_size} | workers - {workers}") diff --git a/tests/test_modules.py b/tests/test_modules.py index 75079ee..1870bfc 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -23,7 +23,7 @@ def test_reader(): reader.start_reading() for vid_frames, info in reader: - vid_frames[0,0,0,0,0] # assert still allocated + vid_frames[0, 0, 0, 0, 0] # assert still allocated mp4_name = info["dst_name"][:-4] + ".mp4" frame_count = vid_frames.shape[0] * vid_frames.shape[1] - info["pad_by"] assert frame_count == FRAME_COUNTS[mp4_name] diff --git a/tests/test_read.py b/tests/test_read.py index 86a6473..bb79975 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -18,7 +18,9 @@ def test_read(): take_en = 2 rs = 100 with tempfile.TemporaryDirectory() as tmpdir: - video2numpy(os.path.join(test_path, "test_list.txt"), tmpdir, take_every_nth=take_en, resize_size=rs, memory_size=1) + video2numpy( + os.path.join(test_path, "test_list.txt"), tmpdir, take_every_nth=take_en, resize_size=rs, memory_size=1 + ) for vid in FRAME_COUNTS.keys(): if vid.endswith(".mp4"): ld = vid[:-4] + ".npy" diff --git a/video2numpy/pytorch_wrapper.py b/video2numpy/pytorch_wrapper.py index 1aeb31e..ad02da7 100644 --- a/video2numpy/pytorch_wrapper.py +++ b/video2numpy/pytorch_wrapper.py @@ -8,7 +8,7 @@ def __init__(self, frame_reader: FrameReader): self.frame_reader = frame_reader def __len__(self): - return 10 ** 9 + return 10**9 def __iter__(self): for item in self.frame_reader: @@ -16,6 +16,8 @@ def __iter__(self): def fr2dl(frame_reader): - dataset = Dataset(frame_reader) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x[0]) - return dataloader + dataset = Dataset(frame_reader) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x[0] + ) + return dataloader