diff --git a/DCVC-FM/README.md b/DCVC-FM/README.md
new file mode 100644
index 0000000..5ac751e
--- /dev/null
+++ b/DCVC-FM/README.md
@@ -0,0 +1,131 @@
+# Introduction
+
+Official Pytorch implementation for DCVC-FM: [Neural Video Compression with **F**eature **M**odulation](https://arxiv.org/abs/2402.17414), in CVPR 2024.
+
+# Prerequisites
+* Python 3.10 and conda, get [Conda](https://www.anaconda.com/)
+* CUDA if want to use GPU
+* Environment
+ ```
+ conda create -n $YOUR_PY_ENV_NAME python=3.10
+ conda activate $YOUR_PY_ENV_NAME
+
+ conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+ pip install -r requirements.txt
+ ```
+
+# Test dataset
+
+We support arbitrary original resolution. The input video resolution will be padded automatically. The reconstructed video will be cropped back to the original size. The distortion (PSNR) is calculated at original resolution.
+
+## YUV 420 content
+
+Put the*.yuv in the folder structure similar to the following structure.
+
+ /media/data/HEVC_B/
+ - BQTerrace_1920x1080_60.yuv
+ - BasketballDrive_1920x1080_50.yuv
+ - ...
+ /media/data/HEVC_D/
+ /media/data/HEVC_C/
+ ...
+
+The dataset structure can be seen in dataset_config_example_yuv420.json.
+
+## RGB content
+
+Please convert YUV 420 data to RGB data using BT.709 conversion matrix.
+
+For example, one video of HEVC Class B can be prepared as:
+* Make the video path:
+ ```
+ mkdir BasketballDrive_1920x1080_50
+ ```
+* Convert YUV to PNG:
+
+We use BT.709 conversion matrix to generate png data to test RGB sequences. Please refer to ./test_data_to_png.py for more details.
+
+At last, the folder structure of dataset is like:
+
+ /media/data/HEVC_B/
+ * BQTerrace_1920x1080_60/
+ - im00001.png
+ - im00002.png
+ - im00003.png
+ - ...
+ * BasketballDrive_1920x1080_50/
+ - im00001.png
+ - im00002.png
+ - im00003.png
+ - ...
+ * ...
+ /media/data/HEVC_D/
+ /media/data/HEVC_C/
+ ...
+
+The dataset structure can be seen in dataset_config_example_rgb.json.
+
+# Build the project
+Please build the C++ code if want to test with actual bitstream writing. There is minor difference about the bits for calculating the bits using entropy (the method used in the paper to report numbers) and actual bitstreaming writing. There is overhead when writing the bitstream into the file and the difference percentage depends on the bitstream size.
+
+## Build the entropy encoding/decoding module
+```bash
+sudo apt-get install cmake g++
+cd src
+mkdir build
+cd build
+conda activate $YOUR_PY_ENV_NAME
+cmake ../cpp -DCMAKE_BUILD_TYPE=Release
+make -j
+```
+
+## Build customized flow warp implementation (especially you want to test fp16 inference)
+```
+sudo apt install ninja-build
+cd ./src/models/extensions/
+python setup.py build_ext --inplace
+```
+
+# Pretrained models
+
+* Download [our pretrained models](https://1drv.ms/f/s!AozfVVwtWWYoi1QkAhlIE-7aAaKV?e=OoemTr) and put them into ./checkpoints folder.
+* Or run the script in ./checkpoints directly to download the model.
+* There are 2 models, one for image coding and the other for video coding.
+
+# Test the models
+
+Example to test pretrained model with four rate points:
+```bash
+python test_video.py --model_path_i ./checkpoints/cvpr2024_image.pth.tar --model_path_p ./checkpoints/cvpr2024_video.pth.tar --rate_num 4 --test_config ./dataset_config_example_yuv420.json --cuda 1 --worker 1 --write_stream 0 --output_path output.json --force_intra_period 9999 --force_frame_num 96
+```
+
+It is recommended that the ```--worker``` number is equal to your GPU number.
+
+You can also specify different ```--rate_num``` values (2~64) to test finer bitrate adjustment.
+
+# Comparing with other method
+Bit saving over VTM-17.0 (HEVC E (600 frames) with single intra-frame setting (i.e. intra-period = –1) and YUV420 colorspace.)
+
+
+
+RD curve of YUV420 PSNR
+
+
+
+# Acknowledgement
+The implementation is based on [CompressAI](https://github.com/InterDigitalInc/CompressAI) and [PyTorchVideoCompression](https://github.com/ZhihaoHu/PyTorchVideoCompression).
+# Citation
+If you find this work useful for your research, please cite:
+
+```
+@inproceedings{li2024neural,
+ title={Neural Video Compression with Feature Modulation},
+ author={Li, Jiahao and Li, Bin and Lu, Yan},
+ booktitle={{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
+ {CVPR} 2024, Seattle, WA, USA, June 17-21, 2024},
+ year={2024}
+}
+```
+
+# Trademarks
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
diff --git a/DCVC-FM/assets/bitsaving.png b/DCVC-FM/assets/bitsaving.png
new file mode 100644
index 0000000..9ed65fa
Binary files /dev/null and b/DCVC-FM/assets/bitsaving.png differ
diff --git a/DCVC-FM/assets/rd_yuv420_psnr.png b/DCVC-FM/assets/rd_yuv420_psnr.png
new file mode 100644
index 0000000..a4d6cd3
Binary files /dev/null and b/DCVC-FM/assets/rd_yuv420_psnr.png differ
diff --git a/DCVC-FM/checkpoints/download.py b/DCVC-FM/checkpoints/download.py
new file mode 100644
index 0000000..36bbc2b
--- /dev/null
+++ b/DCVC-FM/checkpoints/download.py
@@ -0,0 +1,21 @@
+import urllib.request
+
+
+def download_one(url, target):
+ urllib.request.urlretrieve(url, target)
+
+
+def main():
+ urls = {
+ 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211494&authkey=!AOxzcrEFT_h-iCk': 'cvpr2024_image.pth.tar',
+ 'https://onedrive.live.com/download?cid=2866592D5C55DF8C&resid=2866592D5C55DF8C%211493&authkey=!AFxYv6oK1o6GrZc': 'cvpr2024_video.pth.tar',
+ }
+ for url in urls:
+ target = urls[url]
+ print("downloading", target)
+ download_one(url, target)
+ print("downloaded", target)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DCVC-FM/dataset_config_example_rgb.json b/DCVC-FM/dataset_config_example_rgb.json
new file mode 100644
index 0000000..9b8dbd4
--- /dev/null
+++ b/DCVC-FM/dataset_config_example_rgb.json
@@ -0,0 +1,100 @@
+{
+ "root_path": "/media/data/",
+ "test_classes": {
+ "UVG": {
+ "test": 1,
+ "base_path": "UVG",
+ "src_type": "png",
+ "sequences": {
+ "Beauty_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "Bosphorus_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "HoneyBee_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "Jockey_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "ShakeNDry_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "YachtRide_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96}
+ }
+ },
+ "MCL-JCV": {
+ "test": 1,
+ "base_path": "MCL-JCV",
+ "src_type": "png",
+ "sequences": {
+ "videoSRC01_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC02_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC03_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC04_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC05_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC06_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC07_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC08_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC09_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC10_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC11_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC12_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC13_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC14_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC15_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC16_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC17_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC18_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC19_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC20_1920x1080_25": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC21_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC22_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC23_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC24_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC25_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC26_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC27_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC28_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC29_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "videoSRC30_1920x1080_30": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96}
+ }
+ },
+ "HEVC_B": {
+ "test": 1,
+ "base_path": "HEVC_B",
+ "src_type": "png",
+ "sequences": {
+ "BQTerrace_1920x1080_60": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "BasketballDrive_1920x1080_50": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "Cactus_1920x1080_50": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "Kimono1_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96},
+ "ParkScene_1920x1080_24": {"width": 1920, "height": 1080, "frames": 96, "intra_period": 96}
+ }
+ },
+ "HEVC_E": {
+ "test": 1,
+ "base_path": "HEVC_E",
+ "src_type": "png",
+ "sequences": {
+ "FourPeople_1280x720_60": {"width": 1280, "height": 720, "frames": 96, "intra_period": 96},
+ "Johnny_1280x720_60": {"width": 1280, "height": 720, "frames": 96, "intra_period": 96},
+ "KristenAndSara_1280x720_60": {"width": 1280, "height": 720, "frames": 96, "intra_period": 96}
+ }
+ },
+ "HEVC_C": {
+ "test": 1,
+ "base_path": "HEVC_C",
+ "src_type": "png",
+ "sequences": {
+ "BQMall_832x480_60": {"width": 832, "height": 480, "frames": 96, "intra_period": 96},
+ "BasketballDrill_832x480_50": {"width": 832, "height": 480, "frames": 96, "intra_period": 96},
+ "PartyScene_832x480_50": {"width": 832, "height": 480, "frames": 96, "intra_period": 96},
+ "RaceHorses_832x480_30": {"width": 832, "height": 480, "frames": 96, "intra_period": 96}
+ }
+ },
+ "HEVC_D": {
+ "test": 1,
+ "base_path": "HEVC_D",
+ "src_type": "png",
+ "sequences": {
+ "BasketballPass_416x240_50": {"width": 416, "height": 240, "frames": 96, "intra_period": 96},
+ "BlowingBubbles_416x240_50": {"width": 416, "height": 240, "frames": 96, "intra_period": 96},
+ "BQSquare_416x240_60": {"width": 416, "height": 240, "frames": 96, "intra_period": 96},
+ "RaceHorses_416x240_30": {"width": 416, "height": 240, "frames": 96, "intra_period": 96}
+ }
+ }
+ }
+}
diff --git a/DCVC-FM/dataset_config_example_yuv420.json b/DCVC-FM/dataset_config_example_yuv420.json
new file mode 100644
index 0000000..d548014
--- /dev/null
+++ b/DCVC-FM/dataset_config_example_yuv420.json
@@ -0,0 +1,100 @@
+{
+ "root_path": "/media/data/",
+ "test_classes": {
+ "UVG": {
+ "test": 1,
+ "base_path": "UVG",
+ "src_type": "yuv420",
+ "sequences": {
+ "Beauty_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "Bosphorus_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "HoneyBee_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "Jockey_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "ShakeNDry_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 300, "intra_period": -1},
+ "YachtRide_1920x1080_120fps_420_8bit_YUV.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1}
+ }
+ },
+ "MCL-JCV": {
+ "test": 1,
+ "base_path": "MCL-JCV",
+ "src_type": "yuv420",
+ "sequences": {
+ "videoSRC01_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC02_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC03_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC04_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC05_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC06_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC07_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC08_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC09_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC10_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC11_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC12_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC13_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC14_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC15_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC16_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC17_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC18_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC19_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC20_1920x1080_25.yuv": {"width": 1920, "height": 1080, "frames": 125, "intra_period": -1},
+ "videoSRC21_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC22_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC23_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC24_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC25_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC26_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC27_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC28_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1},
+ "videoSRC29_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 120, "intra_period": -1},
+ "videoSRC30_1920x1080_30.yuv": {"width": 1920, "height": 1080, "frames": 150, "intra_period": -1}
+ }
+ },
+ "HEVC_B": {
+ "test": 1,
+ "base_path": "HEVC_B",
+ "src_type": "yuv420",
+ "sequences": {
+ "BQTerrace_1920x1080_60.yuv": {"width": 1920, "height": 1080, "frames": 600, "intra_period": -1},
+ "BasketballDrive_1920x1080_50.yuv": {"width": 1920, "height": 1080, "frames": 500, "intra_period": -1},
+ "Cactus_1920x1080_50.yuv": {"width": 1920, "height": 1080, "frames": 500, "intra_period": -1},
+ "Kimono1_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 240, "intra_period": -1},
+ "ParkScene_1920x1080_24.yuv": {"width": 1920, "height": 1080, "frames": 240, "intra_period": -1}
+ }
+ },
+ "HEVC_E": {
+ "test": 1,
+ "base_path": "HEVC_E",
+ "src_type": "yuv420",
+ "sequences": {
+ "FourPeople_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1},
+ "Johnny_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1},
+ "KristenAndSara_1280x720_60.yuv": {"width": 1280, "height": 720, "frames": 600, "intra_period": -1}
+ }
+ },
+ "HEVC_C": {
+ "test": 1,
+ "base_path": "HEVC_C",
+ "src_type": "yuv420",
+ "sequences": {
+ "BQMall_832x480_60.yuv": {"width": 832, "height": 480, "frames": 600, "intra_period": -1},
+ "BasketballDrill_832x480_50.yuv": {"width": 832, "height": 480, "frames": 500, "intra_period": -1},
+ "PartyScene_832x480_50.yuv": {"width": 832, "height": 480, "frames": 500, "intra_period": -1},
+ "RaceHorses_832x480_30.yuv": {"width": 832, "height": 480, "frames": 300, "intra_period": -1}
+ }
+ },
+ "HEVC_D": {
+ "test": 1,
+ "base_path": "HEVC_D",
+ "src_type": "yuv420",
+ "sequences": {
+ "BasketballPass_416x240_50.yuv": {"width": 416, "height": 240, "frames": 500, "intra_period": -1},
+ "BlowingBubbles_416x240_50.yuv": {"width": 416, "height": 240, "frames": 500, "intra_period": -1},
+ "BQSquare_416x240_60.yuv": {"width": 416, "height": 240, "frames": 600, "intra_period": -1},
+ "RaceHorses_416x240_30.yuv": {"width": 416, "height": 240, "frames": 300, "intra_period": -1}
+ }
+ }
+ }
+}
diff --git a/DCVC-FM/requirements.txt b/DCVC-FM/requirements.txt
new file mode 100644
index 0000000..6c5a43b
--- /dev/null
+++ b/DCVC-FM/requirements.txt
@@ -0,0 +1,8 @@
+numpy>=1.20.0
+scipy
+matplotlib
+torch>=2.0.0
+tensorboard
+tqdm
+bd-metric
+ptflops
diff --git a/DCVC-FM/src/cpp/3rdparty/CMakeLists.txt b/DCVC-FM/src/cpp/3rdparty/CMakeLists.txt
new file mode 100644
index 0000000..8d94573
--- /dev/null
+++ b/DCVC-FM/src/cpp/3rdparty/CMakeLists.txt
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+add_subdirectory(pybind11)
diff --git a/DCVC-FM/src/cpp/3rdparty/pybind11/CMakeLists.txt b/DCVC-FM/src/cpp/3rdparty/pybind11/CMakeLists.txt
new file mode 100644
index 0000000..3c88809
--- /dev/null
+++ b/DCVC-FM/src/cpp/3rdparty/pybind11/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt)
+execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
+ RESULT_VARIABLE result
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download )
+if(result)
+ message(FATAL_ERROR "CMake step for pybind11 failed: ${result}")
+endif()
+execute_process(COMMAND ${CMAKE_COMMAND} --build .
+ RESULT_VARIABLE result
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download )
+if(result)
+ message(FATAL_ERROR "Build step for pybind11 failed: ${result}")
+endif()
+
+add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/
+ ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/
+ EXCLUDE_FROM_ALL)
+
+set(PYBIND11_INCLUDE
+ ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/
+ CACHE INTERNAL "")
diff --git a/DCVC-FM/src/cpp/3rdparty/pybind11/CMakeLists.txt.in b/DCVC-FM/src/cpp/3rdparty/pybind11/CMakeLists.txt.in
new file mode 100644
index 0000000..936202e
--- /dev/null
+++ b/DCVC-FM/src/cpp/3rdparty/pybind11/CMakeLists.txt.in
@@ -0,0 +1,33 @@
+cmake_minimum_required(VERSION 3.6.3)
+
+project(pybind11-download NONE)
+
+include(ExternalProject)
+if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include")
+ ExternalProject_Add(pybind11
+ GIT_REPOSITORY https://github.com/pybind/pybind11.git
+ GIT_TAG v2.10.4
+ GIT_SHALLOW 1
+ SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src"
+ BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build"
+ DOWNLOAD_COMMAND ""
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+ )
+else()
+ ExternalProject_Add(pybind11
+ GIT_REPOSITORY https://github.com/pybind/pybind11.git
+ GIT_TAG v2.10.4
+ GIT_SHALLOW 1
+ SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src"
+ BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build"
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND ""
+ TEST_COMMAND ""
+ )
+endif()
diff --git a/DCVC-FM/src/cpp/CMakeLists.txt b/DCVC-FM/src/cpp/CMakeLists.txt
new file mode 100644
index 0000000..06001c6
--- /dev/null
+++ b/DCVC-FM/src/cpp/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required (VERSION 3.6.3)
+project (MLCodec)
+
+set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+# treat warning as error
+if (MSVC)
+ add_compile_options(/W4 /WX)
+else()
+ add_compile_options(-Wall -Wextra -pedantic -Werror)
+endif()
+
+# The sequence is tricky, put 3rd party first
+add_subdirectory(3rdparty)
+add_subdirectory (ops)
+add_subdirectory (rans)
+add_subdirectory (py_rans)
diff --git a/DCVC-FM/src/cpp/ops/CMakeLists.txt b/DCVC-FM/src/cpp/ops/CMakeLists.txt
new file mode 100644
index 0000000..ed31abb
--- /dev/null
+++ b/DCVC-FM/src/cpp/ops/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required(VERSION 3.7)
+set(PROJECT_NAME MLCodec_CXX)
+project(${PROJECT_NAME})
+
+set(cxx_source
+ ops.cpp
+ )
+
+set(include_dirs
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${PYBIND11_INCLUDE}
+ )
+
+pybind11_add_module(${PROJECT_NAME} ${cxx_source})
+
+target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
+
+# The post build argument is executed after make!
+add_custom_command(
+ TARGET ${PROJECT_NAME} POST_BUILD
+ COMMAND
+ "${CMAKE_COMMAND}" -E copy
+ "$"
+ "${CMAKE_CURRENT_SOURCE_DIR}/../../models/"
+)
diff --git a/DCVC-FM/src/cpp/ops/ops.cpp b/DCVC-FM/src/cpp/ops/ops.cpp
new file mode 100644
index 0000000..9463ab7
--- /dev/null
+++ b/DCVC-FM/src/cpp/ops/ops.cpp
@@ -0,0 +1,91 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+std::vector pmf_to_quantized_cdf(const std::vector &pmf,
+ int precision) {
+ /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
+ * although it's only run once per model after training. See TF/compression
+ * implementation for an optimized version. */
+
+ std::vector cdf(pmf.size() + 1);
+ cdf[0] = 0; /* freq 0 */
+
+ std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
+ return static_cast(std::round(p * (1 << precision)) + 0.5);
+ });
+
+ const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
+
+ std::transform(
+ cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) {
+ return static_cast((((1ull << precision) * p) / total));
+ });
+
+ std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
+ cdf.back() = 1 << precision;
+
+ for (int i = 0; i < static_cast(cdf.size() - 1); ++i) {
+ if (cdf[i] == cdf[i + 1]) {
+ /* Try to steal frequency from low-frequency symbols */
+ uint32_t best_freq = ~0u;
+ int best_steal = -1;
+ for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) {
+ uint32_t freq = cdf[j + 1] - cdf[j];
+ if (freq > 1 && freq < best_freq) {
+ best_freq = freq;
+ best_steal = j;
+ }
+ }
+
+ assert(best_steal != -1);
+
+ if (best_steal < i) {
+ for (int j = best_steal + 1; j <= i; ++j) {
+ cdf[j]--;
+ }
+ } else {
+ assert(best_steal > i);
+ for (int j = i + 1; j <= best_steal; ++j) {
+ cdf[j]++;
+ }
+ }
+ }
+ }
+
+ assert(cdf[0] == 0);
+ assert(cdf.back() == (1u << precision));
+ for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) {
+ assert(cdf[i + 1] > cdf[i]);
+ }
+
+ return cdf;
+}
+
+PYBIND11_MODULE(MLCodec_CXX, m) {
+ m.attr("__name__") = "MLCodec_CXX";
+
+ m.doc() = "C++ utils";
+
+ m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf,
+ "Return quantized CDF for a given PMF");
+}
diff --git a/DCVC-FM/src/cpp/py_rans/CMakeLists.txt b/DCVC-FM/src/cpp/py_rans/CMakeLists.txt
new file mode 100644
index 0000000..b99e3c6
--- /dev/null
+++ b/DCVC-FM/src/cpp/py_rans/CMakeLists.txt
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required(VERSION 3.7)
+set(PROJECT_NAME MLCodec_rans)
+project(${PROJECT_NAME})
+
+set(py_rans_source
+ py_rans.h
+ py_rans.cpp
+ )
+
+set(include_dirs
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${PYBIND11_INCLUDE}
+ )
+
+pybind11_add_module(${PROJECT_NAME} ${py_rans_source})
+
+target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
+target_link_libraries (${PROJECT_NAME} LINK_PUBLIC Rans)
+
+# The post build argument is executed after make!
+add_custom_command(
+ TARGET ${PROJECT_NAME} POST_BUILD
+ COMMAND
+ "${CMAKE_COMMAND}" -E copy
+ "$"
+ "${CMAKE_CURRENT_SOURCE_DIR}/../../models/"
+)
diff --git a/DCVC-FM/src/cpp/py_rans/py_rans.cpp b/DCVC-FM/src/cpp/py_rans/py_rans.cpp
new file mode 100644
index 0000000..0311b03
--- /dev/null
+++ b/DCVC-FM/src/cpp/py_rans/py_rans.cpp
@@ -0,0 +1,281 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "py_rans.h"
+
+#include
+#include
+
+namespace py = pybind11;
+
+RansEncoder::RansEncoder(bool multiThread, int streamPart) {
+ bool useMultiThread = multiThread || streamPart > 1;
+ for (int i = 0; i < streamPart; i++) {
+ if (useMultiThread) {
+ m_encoders.push_back(std::make_shared());
+ } else {
+ m_encoders.push_back(std::make_shared());
+ }
+ }
+}
+
+void RansEncoder::encode_with_indexes(const py::array_t &symbols,
+ const py::array_t &indexes,
+ const int cdf_group_index) {
+ py::buffer_info symbols_buf = symbols.request();
+ py::buffer_info indexes_buf = indexes.request();
+ int16_t *symbols_ptr = static_cast(symbols_buf.ptr);
+ int16_t *indexes_ptr = static_cast(indexes_buf.ptr);
+
+ int encoderNum = static_cast(m_encoders.size());
+ int symbolSize = static_cast(symbols.size());
+ int eachSymbolSize = symbolSize / encoderNum;
+ int lastSymbolSize = symbolSize - eachSymbolSize * (encoderNum - 1);
+ for (int i = 0; i < encoderNum; i++) {
+ int currSymbolSize = i < encoderNum - 1 ? eachSymbolSize : lastSymbolSize;
+ int currOffset = i * eachSymbolSize;
+ auto copySize = sizeof(int16_t) * currSymbolSize;
+ auto vec_symbols = std::make_shared>(currSymbolSize);
+ memcpy(vec_symbols->data(), symbols_ptr + currOffset, copySize);
+ auto vec_indexes = std::make_shared>(eachSymbolSize);
+ memcpy(vec_indexes->data(), indexes_ptr + currOffset, copySize);
+ m_encoders[i]->encode_with_indexes(vec_symbols, vec_indexes,
+ cdf_group_index);
+ }
+}
+
+int RansEncoder::add_cdf(const py::array_t &cdfs,
+ const py::array_t &cdfs_sizes,
+ const py::array_t &offsets) {
+ py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
+ py::buffer_info offsets_buf = offsets.request();
+ int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr);
+ int32_t *offsets_ptr = static_cast(offsets_buf.ptr);
+
+ int cdf_num = static_cast(cdfs_sizes.size());
+ auto vec_cdfs_sizes = std::make_shared>(cdf_num);
+ memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
+ auto vec_offsets = std::make_shared>(offsets.size());
+ memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
+
+ int per_vector_size = static_cast(cdfs.size() / cdf_num);
+ auto vec_cdfs = std::make_shared>>(cdf_num);
+ auto cdfs_raw = cdfs.unchecked<2>();
+ for (int i = 0; i < cdf_num; i++) {
+ std::vector t(per_vector_size);
+ memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
+ vec_cdfs->at(i) = t;
+ }
+
+ int encoderNum = static_cast(m_encoders.size());
+ int cdfIdx = 0;
+ for (int i = 0; i < encoderNum; i++) {
+ cdfIdx = m_encoders[i]->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+ }
+ return cdfIdx;
+}
+
+void RansEncoder::empty_cdf_buffer() {
+ int encoderNum = static_cast(m_encoders.size());
+ for (int i = 0; i < encoderNum; i++) {
+ m_encoders[i]->empty_cdf_buffer();
+ }
+}
+
+void RansEncoder::flush() {
+ for (auto encoder : m_encoders) {
+ encoder->flush();
+ }
+}
+
+py::array_t RansEncoder::get_encoded_stream() {
+ std::vector> results;
+ int maximumSize = 0;
+ int totalSize = 0;
+ int encoderNumber = static_cast(m_encoders.size());
+ for (int i = 0; i < encoderNumber; i++) {
+ std::vector result = m_encoders[i]->get_encoded_stream();
+ results.push_back(result);
+ int nbytes = static_cast(result.size());
+ if (i < encoderNumber - 1 && nbytes > maximumSize) {
+ maximumSize = nbytes;
+ }
+ totalSize += nbytes;
+ }
+
+ int overhead = 1;
+ int perStreamHeader = maximumSize > 65535 ? 4 : 2;
+ if (encoderNumber > 1) {
+ overhead += ((encoderNumber - 1) * perStreamHeader);
+ }
+
+ py::array_t stream(totalSize + overhead);
+ py::buffer_info stream_buf = stream.request();
+ uint8_t *stream_ptr = static_cast(stream_buf.ptr);
+
+ uint8_t flag = static_cast(((encoderNumber - 1) << 4) +
+ (perStreamHeader == 2 ? 1 : 0));
+ memcpy(stream_ptr, &flag, 1);
+ for (int i = 0; i < encoderNumber - 1; i++) {
+ if (perStreamHeader == 2) {
+ uint16_t streamSizes = static_cast(results[i].size());
+ memcpy(stream_ptr + 1 + 2 * i, &streamSizes, 2);
+ } else {
+ uint32_t streamSizes = static_cast(results[i].size());
+ memcpy(stream_ptr + 1 + 4 * i, &streamSizes, 4);
+ }
+ }
+
+ int offset = overhead;
+ for (int i = 0; i < encoderNumber; i++) {
+ int nbytes = static_cast(results[i].size());
+ memcpy(stream_ptr + offset, results[i].data(), nbytes);
+ offset += nbytes;
+ }
+ return stream;
+}
+
+void RansEncoder::reset() {
+ for (auto encoder : m_encoders) {
+ encoder->reset();
+ }
+}
+
+RansDecoder::RansDecoder(int streamPart) {
+ for (int i = 0; i < streamPart; i++) {
+ m_decoders.push_back(std::make_shared());
+ }
+}
+
+void RansDecoder::set_stream(const py::array_t &encoded) {
+ py::buffer_info encoded_buf = encoded.request();
+ uint8_t flag = *(static_cast(encoded_buf.ptr));
+ int numberOfStreams = (flag >> 4) + 1;
+
+ uint8_t perStreamSizeLength = (flag & 0x0f) == 1 ? 2 : 4;
+ std::vector streamSizes;
+ int offset = 1;
+ int totalSize = 0;
+ for (int i = 0; i < numberOfStreams - 1; i++) {
+ uint8_t *currPos = static_cast(encoded_buf.ptr) + offset;
+ if (perStreamSizeLength == 2) {
+ uint16_t streamSize = *(reinterpret_cast(currPos));
+ offset += 2;
+ streamSizes.push_back(streamSize);
+ totalSize += streamSize;
+ } else {
+ uint32_t streamSize = *(reinterpret_cast(currPos));
+ offset += 4;
+ streamSizes.push_back(streamSize);
+ totalSize += streamSize;
+ }
+ }
+ streamSizes.push_back(static_cast(encoded.size()) - offset - totalSize);
+ for (int i = 0; i < numberOfStreams; i++) {
+ auto stream = std::make_shared>(streamSizes[i]);
+ memcpy(stream->data(), static_cast(encoded_buf.ptr) + offset,
+ streamSizes[i]);
+ m_decoders[i]->set_stream(stream);
+ offset += streamSizes[i];
+ }
+}
+
+py::array_t
+RansDecoder::decode_stream(const py::array_t &indexes,
+ const int cdf_group_index) {
+ py::buffer_info indexes_buf = indexes.request();
+ int16_t *indexes_ptr = static_cast(indexes_buf.ptr);
+
+ int decoderNum = static_cast(m_decoders.size());
+ int indexSize = static_cast(indexes.size());
+ int eachSymbolSize = indexSize / decoderNum;
+ int lastSymbolSize = indexSize - eachSymbolSize * (decoderNum - 1);
+
+ std::vector>> results;
+
+ for (int i = 0; i < decoderNum; i++) {
+ int currSymbolSize = i < decoderNum - 1 ? eachSymbolSize : lastSymbolSize;
+ int copySize = sizeof(int16_t) * currSymbolSize;
+ auto vec_indexes = std::make_shared>(currSymbolSize);
+ memcpy(vec_indexes->data(), indexes_ptr + i * eachSymbolSize, copySize);
+
+ std::shared_future> result =
+ std::async(std::launch::async, [=]() {
+ return m_decoders[i]->decode_stream(vec_indexes, cdf_group_index);
+ });
+ results.push_back(result);
+ }
+
+ py::array_t output(indexes.size());
+ py::buffer_info buf = output.request();
+ int offset = 0;
+ for (int i = 0; i < decoderNum; i++) {
+ std::vector result = results[i].get();
+ int resultSize = static_cast(result.size());
+ int copySize = sizeof(int16_t) * resultSize;
+ memcpy(static_cast(buf.ptr) + offset, result.data(), copySize);
+ offset += resultSize;
+ }
+
+ return output;
+}
+
+int RansDecoder::add_cdf(const py::array_t &cdfs,
+ const py::array_t &cdfs_sizes,
+ const py::array_t &offsets) {
+ py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
+ py::buffer_info offsets_buf = offsets.request();
+ int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr);
+ int32_t *offsets_ptr = static_cast(offsets_buf.ptr);
+
+ int cdf_num = static_cast(cdfs_sizes.size());
+ auto vec_cdfs_sizes = std::make_shared>(cdf_num);
+ memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
+ auto vec_offsets = std::make_shared>(offsets.size());
+ memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
+
+ int per_vector_size = static_cast(cdfs.size() / cdf_num);
+ auto vec_cdfs = std::make_shared>>(cdf_num);
+ auto cdfs_raw = cdfs.unchecked<2>();
+ for (int i = 0; i < cdf_num; i++) {
+ std::vector t(per_vector_size);
+ memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
+ vec_cdfs->at(i) = t;
+ }
+ int decoderNum = static_cast(m_decoders.size());
+ int cdfIdx = 0;
+ for (int i = 0; i < decoderNum; i++) {
+ cdfIdx = m_decoders[i]->add_cdf(vec_cdfs, vec_cdfs_sizes, vec_offsets);
+ }
+
+ return cdfIdx;
+}
+
+void RansDecoder::empty_cdf_buffer() {
+ int decoderNum = static_cast(m_decoders.size());
+ for (int i = 0; i < decoderNum; i++) {
+ m_decoders[i]->empty_cdf_buffer();
+ }
+}
+
+PYBIND11_MODULE(MLCodec_rans, m) {
+ m.attr("__name__") = "MLCodec_rans";
+
+ m.doc() = "range Asymmetric Numeral System python bindings";
+
+ py::class_(m, "RansEncoder")
+ .def(py::init())
+ .def("encode_with_indexes", &RansEncoder::encode_with_indexes)
+ .def("flush", &RansEncoder::flush)
+ .def("get_encoded_stream", &RansEncoder::get_encoded_stream)
+ .def("reset", &RansEncoder::reset)
+ .def("add_cdf", &RansEncoder::add_cdf)
+ .def("empty_cdf_buffer", &RansEncoder::empty_cdf_buffer);
+
+ py::class_(m, "RansDecoder")
+ .def(py::init())
+ .def("set_stream", &RansDecoder::set_stream)
+ .def("decode_stream", &RansDecoder::decode_stream)
+ .def("add_cdf", &RansDecoder::add_cdf)
+ .def("empty_cdf_buffer", &RansDecoder::empty_cdf_buffer);
+}
diff --git a/DCVC-FM/src/cpp/py_rans/py_rans.h b/DCVC-FM/src/cpp/py_rans/py_rans.h
new file mode 100644
index 0000000..39045b4
--- /dev/null
+++ b/DCVC-FM/src/cpp/py_rans/py_rans.h
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+#include "rans.h"
+#include
+#include
+
+namespace py = pybind11;
+
+// the classes in this file only perform the type conversion
+// from python type (numpy) to C++ type (vector)
+class RansEncoder {
+public:
+ RansEncoder(bool multiThread, int streamPart);
+
+ RansEncoder(const RansEncoder &) = delete;
+ RansEncoder(RansEncoder &&) = delete;
+ RansEncoder &operator=(const RansEncoder &) = delete;
+ RansEncoder &operator=(RansEncoder &&) = delete;
+
+ void encode_with_indexes(const py::array_t &symbols,
+ const py::array_t &indexes,
+ const int cdf_group_index);
+ void flush();
+ py::array_t get_encoded_stream();
+ void reset();
+ int add_cdf(const py::array_t &cdfs,
+ const py::array_t &cdfs_sizes,
+ const py::array_t &offsets);
+ void empty_cdf_buffer();
+
+private:
+ std::vector> m_encoders;
+};
+
+class RansDecoder {
+public:
+ RansDecoder(int streamPart);
+
+ RansDecoder(const RansDecoder &) = delete;
+ RansDecoder(RansDecoder &&) = delete;
+ RansDecoder &operator=(const RansDecoder &) = delete;
+ RansDecoder &operator=(RansDecoder &&) = delete;
+
+ void set_stream(const py::array_t &);
+
+ py::array_t decode_stream(const py::array_t &indexes,
+ const int cdf_group_index);
+ int add_cdf(const py::array_t &cdfs,
+ const py::array_t &cdfs_sizes,
+ const py::array_t &offsets);
+ void empty_cdf_buffer();
+
+private:
+ std::vector> m_decoders;
+};
diff --git a/DCVC-FM/src/cpp/rans/CMakeLists.txt b/DCVC-FM/src/cpp/rans/CMakeLists.txt
new file mode 100644
index 0000000..ea23ba6
--- /dev/null
+++ b/DCVC-FM/src/cpp/rans/CMakeLists.txt
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+cmake_minimum_required(VERSION 3.7)
+set(PROJECT_NAME Rans)
+project(${PROJECT_NAME})
+
+set(rans_source
+ rans_byte.h
+ rans.h
+ rans.cpp
+ )
+
+set(include_dirs
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${RYG_RANS_INCLUDE}
+ )
+
+if (NOT MSVC)
+ add_compile_options(-fPIC)
+endif()
+add_library (${PROJECT_NAME} ${rans_source})
+target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
diff --git a/DCVC-FM/src/cpp/rans/rans.cpp b/DCVC-FM/src/cpp/rans/rans.cpp
new file mode 100644
index 0000000..14dc46c
--- /dev/null
+++ b/DCVC-FM/src/cpp/rans/rans.cpp
@@ -0,0 +1,362 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Rans64 extensions from:
+ * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/
+ * Unbounded range coding from:
+ * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc
+ **/
+
+#include "rans.h"
+
+#include
+#include
+#include
+
+/* probability range, this could be a parameter... */
+constexpr int precision = 16;
+
+constexpr uint16_t bypass_precision = 2; /* number of bits in bypass mode */
+constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1;
+namespace {
+
+inline void RansEncPutBits(RansState *r, uint8_t **pptr, uint32_t val,
+ uint32_t nbits) {
+ assert(nbits <= 8);
+ assert(val < (1u << nbits));
+
+ uint32_t x = *r;
+ uint32_t freq = 1 << (precision - nbits);
+ // uint32_t x_max =
+ //((RANS_BYTE_L >> precision) << 8) * freq; // this turns into a shift.
+ uint32_t x_max = freq << 15;
+ while (x >= x_max) {
+ *(--(*pptr)) = static_cast(x & 0xff);
+ x >>= 8;
+ }
+
+ *r = (x << nbits) | val;
+}
+
+inline uint32_t RansDecGetBits(RansState *r, uint8_t **pptr, uint32_t n_bits) {
+ uint32_t x = *r;
+ uint32_t val = x & ((1u << n_bits) - 1);
+
+ /* Re-normalize */
+ x = x >> n_bits;
+ if (x < RANS_BYTE_L) {
+ x = (x << 8) | **pptr;
+ *pptr += 1;
+ RansAssert(x >= RANS_BYTE_L);
+ }
+
+ *r = x;
+
+ return val;
+}
+} // namespace
+
+int RansEncoderLib::add_cdf(
+ const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets) {
+
+ auto ransSymbols =
+ std::make_shared>>(cdfs->size());
+ for (int i = 0; i < static_cast(cdfs->size()); i++) {
+ const int32_t *cdf = cdfs->at(i).data();
+ std::vector ransSym(cdfs->at(i).size());
+ const int ransSize = static_cast(ransSym.size() - 1);
+ for (int j = 0; j < ransSize; j++) {
+ ransSym[j] = RansSymbol({static_cast(cdf[j]),
+ static_cast(cdf[j + 1] - cdf[j])});
+ }
+ ransSymbols->at(i) = ransSym;
+ }
+
+ _ransSymbols.push_back(ransSymbols);
+ _cdfs_sizes.push_back(cdfs_sizes);
+ _offsets.push_back(offsets);
+ return static_cast(_ransSymbols.size()) - 1;
+}
+
+void RansEncoderLib::empty_cdf_buffer() {
+ _ransSymbols.clear();
+ _cdfs_sizes.clear();
+ _offsets.clear();
+}
+
+void RansEncoderLib::encode_with_indexes(
+ const std::shared_ptr> symbols,
+ const std::shared_ptr> indexes,
+ const int cdf_group_index) {
+
+ // backward loop on symbols from the end;
+ const int16_t *symbols_ptr = symbols->data();
+ const int16_t *indexes_ptr = indexes->data();
+ const int32_t *cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+ const int32_t *offsets_ptr = _offsets[cdf_group_index]->data();
+ const int symbol_size = static_cast(symbols->size());
+ _syms.reserve(symbol_size * 3 / 2);
+ auto ransSymbols = _ransSymbols[cdf_group_index];
+
+ for (int i = 0; i < symbol_size; ++i) {
+ const int32_t cdf_idx = indexes_ptr[i];
+ if (cdf_idx < 0) {
+ continue;
+ }
+ const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2;
+ int32_t value = symbols_ptr[i] - offsets_ptr[cdf_idx];
+
+ uint32_t raw_val = 0;
+ if (value < 0) {
+ raw_val = -2 * value - 1;
+ value = max_value;
+ } else if (value >= max_value) {
+ raw_val = 2 * (value - max_value);
+ value = max_value;
+ }
+
+ _syms.push_back(ransSymbols->at(cdf_idx)[value]);
+
+ /* Bypass coding mode (value == max_value -> sentinel flag) */
+ if (value == max_value) {
+ /* Determine the number of bypasses (in bypass_precision size) needed to
+ * encode the raw value. */
+ int32_t n_bypass = 0;
+ while ((raw_val >> (n_bypass * bypass_precision)) != 0) {
+ ++n_bypass;
+ }
+
+ /* Encode number of bypasses */
+ int32_t val = n_bypass;
+ while (val >= max_bypass_val) {
+ _syms.push_back({max_bypass_val, 0});
+ val -= max_bypass_val;
+ }
+ _syms.push_back({static_cast(val), 0});
+
+ /* Encode raw value */
+ for (int32_t j = 0; j < n_bypass; ++j) {
+ const int32_t val1 =
+ (raw_val >> (j * bypass_precision)) & max_bypass_val;
+ _syms.push_back({static_cast(val1), 0});
+ }
+ }
+ }
+}
+
+void RansEncoderLib::flush() {
+ RansState rans;
+ RansEncInit(&rans);
+
+ std::vector output(_syms.size()); // too much space ?
+ uint8_t *ptrEnd = output.data() + output.size();
+ uint8_t *ptr = ptrEnd;
+ assert(ptr != nullptr);
+
+ for (auto it = _syms.rbegin(); it < _syms.rend(); it++) {
+ const RansSymbol sym = *it;
+
+ if (sym.range != 0) {
+ RansEncPut(&rans, &ptr, sym.start, sym.range, precision);
+ } else {
+ // unlikely...
+ RansEncPutBits(&rans, &ptr, sym.start, bypass_precision);
+ }
+ }
+
+ RansEncFlush(&rans, &ptr);
+
+ const int nbytes = static_cast(std::distance(ptr, ptrEnd));
+
+ _stream.resize(nbytes);
+ memcpy(_stream.data(), ptr, nbytes);
+}
+
+std::vector RansEncoderLib::get_encoded_stream() { return _stream; }
+
+void RansEncoderLib::reset() { _syms.clear(); }
+
+RansEncoderLibMultiThread::RansEncoderLibMultiThread()
+ : RansEncoderLib(), m_finish(false), m_result_ready(false),
+ m_thread(std::thread(&RansEncoderLibMultiThread::worker, this)) {}
+
+RansEncoderLibMultiThread::~RansEncoderLibMultiThread() {
+ {
+ std::lock_guard lk(m_mutex_pending);
+ std::lock_guard lk1(m_mutex_result);
+ m_finish = true;
+ }
+ m_cv_pending.notify_one();
+ m_cv_result.notify_one();
+ m_thread.join();
+}
+
+void RansEncoderLibMultiThread::encode_with_indexes(
+ const std::shared_ptr> symbols,
+ const std::shared_ptr> indexes,
+ const int cdf_group_index) {
+ PendingTask p;
+ p.workType = WorkType::Encode;
+ p.symbols = symbols;
+ p.indexes = indexes;
+ p.cdf_group_index = cdf_group_index;
+ {
+ std::unique_lock lk(m_mutex_pending);
+ m_pending.push_back(p);
+ }
+ m_cv_pending.notify_one();
+}
+
+void RansEncoderLibMultiThread::flush() {
+ PendingTask p;
+ p.workType = WorkType::Flush;
+ {
+ std::unique_lock lk(m_mutex_pending);
+ m_pending.push_back(p);
+ }
+ m_cv_pending.notify_one();
+}
+
+std::vector RansEncoderLibMultiThread::get_encoded_stream() {
+ std::unique_lock lk(m_mutex_result);
+ m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; });
+ return RansEncoderLib::get_encoded_stream();
+}
+
+void RansEncoderLibMultiThread::reset() {
+ RansEncoderLib::reset();
+ std::lock_guard lk(m_mutex_result);
+ m_result_ready = false;
+}
+
+void RansEncoderLibMultiThread::worker() {
+ while (!m_finish) {
+ std::unique_lock lk(m_mutex_pending);
+ m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; });
+ if (m_finish) {
+ lk.unlock();
+ break;
+ }
+ if (m_pending.size() == 0) {
+ lk.unlock();
+ // std::cout << "contine in worker" << std::endl;
+ continue;
+ }
+ while (m_pending.size() > 0) {
+ auto p = m_pending.front();
+ m_pending.pop_front();
+ lk.unlock();
+ if (p.workType == WorkType::Encode) {
+ RansEncoderLib::encode_with_indexes(p.symbols, p.indexes,
+ p.cdf_group_index);
+ } else if (p.workType == WorkType::Flush) {
+ RansEncoderLib::flush();
+ {
+ std::lock_guard lk_result(m_mutex_result);
+ m_result_ready = true;
+ }
+ m_cv_result.notify_one();
+ }
+ lk.lock();
+ }
+ lk.unlock();
+ }
+}
+
+void RansDecoderLib::set_stream(
+ const std::shared_ptr> encoded) {
+ _stream = encoded;
+ _ptr8 = (uint8_t *)(_stream->data());
+ RansDecInit(&_rans, &_ptr8);
+}
+
+int RansDecoderLib::add_cdf(
+ const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets) {
+ _cdfs.push_back(cdfs);
+ _cdfs_sizes.push_back(cdfs_sizes);
+ _offsets.push_back(offsets);
+ return static_cast(_cdfs.size()) - 1;
+}
+
+void RansDecoderLib::empty_cdf_buffer() {
+ _cdfs.clear();
+ _cdfs_sizes.clear();
+ _offsets.clear();
+}
+
+std::vector RansDecoderLib::decode_stream(
+ const std::shared_ptr> indexes,
+ const int cdf_group_index) {
+
+ int index_size = static_cast(indexes->size());
+ std::vector output(index_size);
+
+ int16_t *outout_ptr = output.data();
+ const int16_t *indexes_ptr = indexes->data();
+ const int32_t *cdfs_sizes_ptr = _cdfs_sizes[cdf_group_index]->data();
+ const int32_t *offsets_ptr = _offsets[cdf_group_index]->data();
+ const auto &cdfs = _cdfs[cdf_group_index];
+ for (int i = 0; i < index_size; ++i) {
+ const int32_t cdf_idx = indexes_ptr[i];
+ if (cdf_idx < 0) {
+ outout_ptr[i] = 0;
+ continue;
+ }
+ const int32_t *cdf = cdfs->at(cdf_idx).data();
+ const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2;
+ const uint32_t cum_freq = RansDecGet(&_rans, precision);
+
+ const auto cdf_end = cdf + cdfs_sizes_ptr[cdf_idx];
+ const auto it = std::find_if(cdf, cdf_end, [cum_freq](int v) {
+ return static_cast(v) > cum_freq;
+ });
+ const uint32_t s = static_cast(std::distance(cdf, it) - 1);
+
+ RansDecAdvance(&_rans, &_ptr8, cdf[s], cdf[s + 1] - cdf[s], precision);
+
+ int32_t value = static_cast(s);
+
+ if (value == max_value) {
+ /* Bypass decoding mode */
+ int32_t val = RansDecGetBits(&_rans, &_ptr8, bypass_precision);
+ int32_t n_bypass = val;
+
+ while (val == max_bypass_val) {
+ val = RansDecGetBits(&_rans, &_ptr8, bypass_precision);
+ n_bypass += val;
+ }
+
+ int32_t raw_val = 0;
+ for (int j = 0; j < n_bypass; ++j) {
+ val = RansDecGetBits(&_rans, &_ptr8, bypass_precision);
+ raw_val |= val << (j * bypass_precision);
+ }
+ value = raw_val >> 1;
+ if (raw_val & 1) {
+ value = -value - 1;
+ } else {
+ value += max_value;
+ }
+ }
+
+ const int32_t offset = offsets_ptr[cdf_idx];
+ outout_ptr[i] = static_cast(value + offset);
+ }
+ return output;
+}
diff --git a/DCVC-FM/src/cpp/rans/rans.h b/DCVC-FM/src/cpp/rans/rans.h
new file mode 100644
index 0000000..2837c17
--- /dev/null
+++ b/DCVC-FM/src/cpp/rans/rans.h
@@ -0,0 +1,152 @@
+/* Copyright 2020 InterDigital Communications, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wsign-compare"
+#endif
+
+#ifdef _MSC_VER
+#pragma warning(disable : 4244)
+#endif
+
+#include "rans_byte.h"
+
+#ifdef _MSC_VER
+#pragma warning(default : 4244)
+#endif
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+struct RansSymbol {
+ uint16_t start;
+ uint16_t range; // range for normal coding and 0 for bypass coding
+};
+
+enum class WorkType {
+ Encode,
+ Flush,
+};
+
+struct PendingTask {
+ WorkType workType;
+ std::shared_ptr> symbols;
+ std::shared_ptr> indexes;
+ int cdf_group_index{0};
+};
+
+/* NOTE: Warning, we buffer everything for now... In case of large files we
+ * should split the bitstream into chunks... Or for a memory-bounded encoder
+ **/
+class RansEncoderLib {
+public:
+ RansEncoderLib() {}
+ virtual ~RansEncoderLib() = default;
+
+ RansEncoderLib(const RansEncoderLib &) = delete;
+ RansEncoderLib(RansEncoderLib &&) = delete;
+ RansEncoderLib &operator=(const RansEncoderLib &) = delete;
+ RansEncoderLib &operator=(RansEncoderLib &&) = delete;
+
+ virtual void
+ encode_with_indexes(const std::shared_ptr> symbols,
+ const std::shared_ptr> indexes,
+ const int cdf_group_index);
+ virtual void flush();
+ virtual std::vector get_encoded_stream();
+ virtual void reset();
+ virtual int
+ add_cdf(const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets);
+ virtual void empty_cdf_buffer();
+
+private:
+ std::vector _syms;
+ std::vector _stream;
+
+ std::vector>>>
+ _ransSymbols;
+ std::vector>> _cdfs_sizes;
+ std::vector>> _offsets;
+};
+
+class RansEncoderLibMultiThread : public RansEncoderLib {
+public:
+ RansEncoderLibMultiThread();
+ virtual ~RansEncoderLibMultiThread();
+
+ virtual void
+ encode_with_indexes(const std::shared_ptr> symbols,
+ const std::shared_ptr> indexes,
+ const int cdf_group_index) override;
+ virtual void flush() override;
+ virtual std::vector get_encoded_stream() override;
+ virtual void reset() override;
+
+ void worker();
+
+private:
+ bool m_finish;
+ bool m_result_ready;
+ std::thread m_thread;
+ std::mutex m_mutex_result;
+ std::mutex m_mutex_pending;
+ std::condition_variable m_cv_pending;
+ std::condition_variable m_cv_result;
+ std::list m_pending;
+};
+
+class RansDecoderLib {
+public:
+ RansDecoderLib() {}
+ virtual ~RansDecoderLib() = default;
+
+ RansDecoderLib(const RansDecoderLib &) = delete;
+ RansDecoderLib(RansDecoderLib &&) = delete;
+ RansDecoderLib &operator=(const RansDecoderLib &) = delete;
+ RansDecoderLib &operator=(RansDecoderLib &&) = delete;
+
+ void set_stream(const std::shared_ptr> encoded);
+
+ std::vector
+ decode_stream(const std::shared_ptr> indexes,
+ const int cdf_group_index);
+
+ virtual int
+ add_cdf(const std::shared_ptr>> cdfs,
+ const std::shared_ptr> cdfs_sizes,
+ const std::shared_ptr> offsets);
+ virtual void empty_cdf_buffer();
+
+private:
+ RansState _rans;
+ uint8_t *_ptr8;
+ std::shared_ptr> _stream;
+
+ std::vector>>> _cdfs;
+ std::vector>> _cdfs_sizes;
+ std::vector>> _offsets;
+};
diff --git a/DCVC-FM/src/cpp/rans/rans_byte.h b/DCVC-FM/src/cpp/rans/rans_byte.h
new file mode 100644
index 0000000..4d42339
--- /dev/null
+++ b/DCVC-FM/src/cpp/rans/rans_byte.h
@@ -0,0 +1,155 @@
+// The code is from https://github.com/rygorous/ryg_rans
+// The original lisence is below.
+
+// To the extent possible under law, Fabian Giesen has waived all
+// copyright and related or neighboring rights to ryg_rans, as
+// per the terms of the CC0 license:
+
+// https://creativecommons.org/publicdomain/zero/1.0
+
+// This work is published from the United States.
+
+// Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg'
+// Giesen 2014
+//
+// Not intended to be "industrial strength"; just meant to illustrate the
+// general idea.
+
+#pragma once
+
+#include
+
+#ifdef assert
+#define RansAssert assert
+#else
+#define RansAssert(x)
+#endif
+
+// READ ME FIRST:
+//
+// This is designed like a typical arithmetic coder API, but there's three
+// twists you absolutely should be aware of before you start hacking:
+//
+// 1. You need to encode data in *reverse* - last symbol first. rANS works
+// like a stack: last in, first out.
+// 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give
+// it a pointer to the *end* of your buffer (exclusive), and it will
+// slowly move towards the beginning as more bytes are emitted.
+// 3. Unlike basically any other entropy coder implementation you might
+// have used, you can interleave data from multiple independent rANS
+// encoders into the same bytestream without any extra signaling;
+// you can also just write some bytes by yourself in the middle if
+// you want to. This is in addition to the usual arithmetic encoder
+// property of being able to switch models on the fly. Writing raw
+// bytes can be useful when you have some data that you know is
+// incompressible, and is cheaper than going through the rANS encode
+// function. Using multiple rANS coders on the same byte stream wastes
+// a few bytes compared to using just one, but execution of two
+// independent encoders can happen in parallel on superscalar and
+// Out-of-Order CPUs, so this can be *much* faster in tight decoding
+// loops.
+//
+// This is why all the rANS functions take the write pointer as an
+// argument instead of just storing it in some context struct.
+
+// --------------------------------------------------------------------------
+
+// L ('l' in the paper) is the lower bound of our normalization interval.
+// Between this and our byte-aligned emission, we use 31 (not 32!) bits.
+// This is done intentionally because exact reciprocals for 31-bit uints
+// fit in 32-bit uints: this permits some optimizations during encoding.
+#define RANS_BYTE_L (1u << 23) // lower bound of our normalization interval
+
+// State for a rANS encoder. Yep, that's all there is to it.
+typedef uint32_t RansState;
+
+// Initialize a rANS encoder.
+static inline void RansEncInit(RansState *r) { *r = RANS_BYTE_L; }
+
+// Renormalize the encoder. Internal function.
+static inline RansState RansEncRenorm(RansState x, uint8_t **pptr,
+ uint32_t freq, uint32_t scale_bits) {
+ (void)scale_bits;
+ // const uint32_t x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq; // this
+ // turns into a shift.
+ const uint32_t x_max = freq << 15;
+ while (x >= x_max) {
+ *(--(*pptr)) = static_cast(x & 0xff);
+ x >>= 8;
+ }
+ return x;
+}
+
+// Encodes a single symbol with range start "start" and frequency "freq".
+// All frequencies are assumed to sum to "1 << scale_bits", and the
+// resulting bytes get written to ptr (which is updated).
+//
+// NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from
+// beginning to end! Likewise, the output bytestream is written *backwards*:
+// ptr starts pointing at the end of the output buffer and keeps decrementing.
+static inline void RansEncPut(RansState *r, uint8_t **pptr, uint32_t start,
+ uint32_t freq, uint32_t scale_bits) {
+ // renormalize
+ RansState x = RansEncRenorm(*r, pptr, freq, scale_bits);
+
+ // x = C(s,x)
+ *r = ((x / freq) << scale_bits) + (x % freq) + start;
+}
+
+// Flushes the rANS encoder.
+static inline void RansEncFlush(RansState *r, uint8_t **pptr) {
+ uint32_t x = *r;
+ uint8_t *ptr = *pptr;
+
+ ptr -= 4;
+ ptr[0] = (uint8_t)(x >> 0);
+ ptr[1] = (uint8_t)(x >> 8);
+ ptr[2] = (uint8_t)(x >> 16);
+ ptr[3] = (uint8_t)(x >> 24);
+
+ *pptr = ptr;
+}
+
+// Initializes a rANS decoder.
+// Unlike the encoder, the decoder works forwards as you'd expect.
+static inline void RansDecInit(RansState *r, uint8_t **pptr) {
+ uint32_t x;
+ uint8_t *ptr = *pptr;
+
+ x = ptr[0] << 0;
+ x |= ptr[1] << 8;
+ x |= ptr[2] << 16;
+ x |= ptr[3] << 24;
+ ptr += 4;
+
+ *pptr = ptr;
+ *r = x;
+}
+
+// Returns the current cumulative frequency (map it to a symbol yourself!)
+static inline uint32_t RansDecGet(RansState *r, uint32_t scale_bits) {
+ return *r & ((1u << scale_bits) - 1);
+}
+
+// Advances in the bit stream by "popping" a single symbol with range start
+// "start" and frequency "freq". All frequencies are assumed to sum to "1 <<
+// scale_bits", and the resulting bytes get written to ptr (which is updated).
+static inline void RansDecAdvance(RansState *r, uint8_t **pptr, uint32_t start,
+ uint32_t freq, uint32_t scale_bits) {
+ uint32_t mask = (1u << scale_bits) - 1;
+
+ // s, x = D(x)
+ uint32_t x = *r;
+ x = freq * (x >> scale_bits) + (x & mask) - start;
+
+ // renormalize
+ if (x < RANS_BYTE_L) {
+ uint8_t *ptr = *pptr;
+ do
+ x = (x << 8) | *ptr++;
+ while (x < RANS_BYTE_L);
+ *pptr = ptr;
+ }
+
+ *r = x;
+}
diff --git a/DCVC-FM/src/models/block_mc.py b/DCVC-FM/src/models/block_mc.py
new file mode 100644
index 0000000..c16e8b9
--- /dev/null
+++ b/DCVC-FM/src/models/block_mc.py
@@ -0,0 +1,80 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+
+import torch
+
+
+CUSTOMIZED_CUDA = False
+try:
+ from .extensions.block_mc_cpp_cuda import block_mc_forward # pylint: disable=E0401, E0611
+ CUSTOMIZED_CUDA = True
+except Exception: # pylint: disable=W0718
+ pass
+
+if not CUSTOMIZED_CUDA:
+ try:
+ from block_mc_cpp_cuda import block_mc_forward # pylint: disable=E0401 # noqa: F811
+ CUSTOMIZED_CUDA = True
+ except Exception: # pylint: disable=W0718
+ pass
+
+if not CUSTOMIZED_CUDA and 'SUPPRESS_CUSTOM_KERNEL_WARNING' not in os.environ:
+ print("cannot import motion compensation in cuda, fallback to pytorch grid_sample.")
+
+
+backward_grid = [{} for _ in range(9)] # 0~7 for GPU, -1 for CPU
+FORCE_RECALCULATE_GRID = False
+
+
+def set_force_recalculate_grid(force):
+ global FORCE_RECALCULATE_GRID
+ FORCE_RECALCULATE_GRID = force
+
+
+def add_grid_cache(flow):
+ device_id = -1 if flow.device == torch.device('cpu') else flow.device.index
+ if str(flow.size()) not in backward_grid[device_id] or FORCE_RECALCULATE_GRID:
+ B, _, H, W = flow.size()
+ tensor_hor = torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32).view(
+ 1, 1, 1, W).expand(B, -1, H, -1)
+ tensor_ver = torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32).view(
+ 1, 1, H, 1).expand(B, -1, -1, W)
+ backward_grid[device_id][str(flow.size())] = torch.cat([tensor_hor, tensor_ver], 1)
+
+
+def torch_warp(feature, flow):
+ device_id = -1 if feature.device == torch.device('cpu') else feature.device.index
+ add_grid_cache(flow)
+ flow = torch.cat([flow[:, 0:1, :, :] / ((feature.size(3) - 1.0) / 2.0),
+ flow[:, 1:2, :, :] / ((feature.size(2) - 1.0) / 2.0)], 1)
+
+ grid = backward_grid[device_id][str(flow.size())] + flow
+ return torch.nn.functional.grid_sample(input=feature,
+ grid=grid.permute(0, 2, 3, 1),
+ mode='bilinear',
+ padding_mode='border',
+ align_corners=True)
+
+
+def flow_warp(im, flow):
+ is_float16 = False
+ if im.dtype == torch.float16:
+ is_float16 = True
+ im = im.to(torch.float32)
+ flow = flow.to(torch.float32)
+ warp = torch_warp(im, flow)
+ if is_float16:
+ warp = warp.to(torch.float16)
+ return warp
+
+
+def block_mc_func(im, flow):
+ if not CUSTOMIZED_CUDA:
+ return flow_warp(im, flow)
+ with torch.no_grad():
+ B, C, H, W = im.size()
+ out = torch.empty_like(im)
+ block_mc_forward(out, im, flow, B, C, H, W)
+ return out
diff --git a/DCVC-FM/src/models/common_model.py b/DCVC-FM/src/models/common_model.py
new file mode 100644
index 0000000..924f538
--- /dev/null
+++ b/DCVC-FM/src/models/common_model.py
@@ -0,0 +1,324 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import math
+
+import torch
+from torch import nn
+
+from .entropy_models import BitEstimator, GaussianEncoder, EntropyCoder
+from ..utils.stream_helper import get_padding_size
+
+
+class CompressionModel(nn.Module):
+ def __init__(self, y_distribution, z_channel, mv_z_channel=None,
+ ec_thread=False, stream_part=1):
+ super().__init__()
+
+ self.y_distribution = y_distribution
+ self.z_channel = z_channel
+ self.mv_z_channel = mv_z_channel
+ self.entropy_coder = None
+ if mv_z_channel is None:
+ self.bit_estimator_z = BitEstimator(64, z_channel)
+ self.bit_estimator_z_mv = None
+ else:
+ self.bit_estimator_z = BitEstimator(1, z_channel)
+ self.bit_estimator_z_mv = BitEstimator(1, mv_z_channel)
+ self.gaussian_encoder = GaussianEncoder(distribution=y_distribution)
+ self.ec_thread = ec_thread
+ self.stream_part = stream_part
+
+ self.masks = {}
+
+ def quant(self, x):
+ return torch.round(x)
+
+ def get_one_q_scale(self, q_scale, q_index):
+ min_q = q_scale[0:1, :, :, :]
+ max_q = q_scale[1:2, :, :, :]
+ step = (torch.log(max_q) - torch.log(min_q)) / (self.get_qp_num() - 1)
+ q = torch.exp(torch.log(min_q) + step * q_index)
+ return q
+
+ def get_curr_q(self, q_scale, q_index):
+ if isinstance(q_index, list):
+ q_step = [self.get_one_q_scale(q_scale, i) for i in q_index]
+ q_step = torch.cat(q_step, dim=0)
+ else:
+ q_step = self.get_one_q_scale(q_scale, q_index)
+
+ return q_step
+
+ @staticmethod
+ def get_index_tensor(q_index, device):
+ if not isinstance(q_index, list):
+ q_index = [q_index]
+ return torch.tensor(q_index, dtype=torch.int32, device=device)
+
+ @staticmethod
+ def get_qp_num():
+ return 64
+
+
+ @staticmethod
+ def probs_to_bits(probs):
+ factor = -1.0 / math.log(2.0)
+ bits = torch.log(probs + 1e-5) * factor
+ bits = torch.clamp(bits, 0, None)
+ return bits
+
+ def get_y_gaussian_bits(self, y, sigma):
+ mu = torch.zeros_like(sigma)
+ sigma = sigma.clamp(1e-5, 1e10)
+ gaussian = torch.distributions.normal.Normal(mu, sigma)
+ probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5)
+ probs = probs.to(torch.float32)
+ return CompressionModel.probs_to_bits(probs)
+
+ def get_y_laplace_bits(self, y, sigma):
+ mu = torch.zeros_like(sigma)
+ sigma = sigma.clamp(1e-5, 1e10)
+ gaussian = torch.distributions.laplace.Laplace(mu, sigma)
+ probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5)
+ probs = probs.to(torch.float32)
+ return CompressionModel.probs_to_bits(probs)
+
+ def get_z_bits(self, z, bit_estimator, index):
+ probs = bit_estimator.get_cdf(z + 0.5, index) - bit_estimator.get_cdf(z - 0.5, index)
+ probs = probs.to(torch.float32)
+ return CompressionModel.probs_to_bits(probs)
+
+ def update(self, force=False):
+ self.entropy_coder = EntropyCoder(self.ec_thread, self.stream_part)
+ self.gaussian_encoder.update(force=force, entropy_coder=self.entropy_coder)
+ self.bit_estimator_z.update(force=force, entropy_coder=self.entropy_coder)
+ if self.bit_estimator_z_mv is not None:
+ self.bit_estimator_z_mv.update(force=force, entropy_coder=self.entropy_coder)
+
+ def pad_for_y(self, y):
+ _, _, H, W = y.size()
+ padding_l, padding_r, padding_t, padding_b = get_padding_size(H, W, 4)
+ y_pad = torch.nn.functional.pad(
+ y,
+ (padding_l, padding_r, padding_t, padding_b),
+ mode="replicate",
+ )
+ return y_pad, (-padding_l, -padding_r, -padding_t, -padding_b)
+
+ @staticmethod
+ def get_to_y_slice_shape(height, width):
+ padding_l, padding_r, padding_t, padding_b = get_padding_size(height, width, 4)
+ return (-padding_l, -padding_r, -padding_t, -padding_b)
+
+ def slice_to_y(self, param, slice_shape):
+ return torch.nn.functional.pad(param, slice_shape)
+
+ @staticmethod
+ def separate_prior(params, is_video=False):
+ if is_video:
+ quant_step, scales, means = params.chunk(3, 1)
+ quant_step = torch.clamp(quant_step, 0.5, None)
+ q_enc = 1. / quant_step
+ q_dec = quant_step
+ else:
+ q = params[:, :2, :, :]
+ q_enc, q_dec = (torch.sigmoid(q) * 1.5 + 0.5).chunk(2, 1)
+ scales, means = params[:, 2:, :, :].chunk(2, 1)
+ return q_enc, q_dec, scales, means
+
+ def get_mask(self, height, width, dtype, device):
+ curr_mask_str = f"{width}x{height}"
+ if curr_mask_str not in self.masks:
+ micro_mask = torch.tensor(((1, 0), (0, 1)), dtype=dtype, device=device)
+ mask_0 = micro_mask.repeat((height + 1) // 2, (width + 1) // 2)
+ mask_0 = mask_0[:height, :width]
+ mask_0 = torch.unsqueeze(mask_0, 0)
+ mask_0 = torch.unsqueeze(mask_0, 0)
+ mask_1 = torch.ones_like(mask_0) - mask_0
+ self.masks[curr_mask_str] = [mask_0, mask_1]
+ return self.masks[curr_mask_str]
+
+ def process_with_mask(self, y, scales, means, mask):
+ scales_hat = scales * mask
+ means_hat = means * mask
+
+ y_res = (y - means_hat) * mask
+ y_q = self.quant(y_res)
+ y_hat = y_q + means_hat
+
+ return y_res, y_q, y_hat, scales_hat
+
+ @staticmethod
+ def get_one_channel_four_parts_mask(height, width, dtype, device):
+ micro_mask_0 = torch.tensor(((1, 0), (0, 0)), dtype=dtype, device=device)
+ mask_0 = micro_mask_0.repeat((height + 1) // 2, (width + 1) // 2)
+ mask_0 = mask_0[:height, :width]
+ mask_0 = torch.unsqueeze(mask_0, 0)
+ mask_0 = torch.unsqueeze(mask_0, 0)
+
+ micro_mask_1 = torch.tensor(((0, 1), (0, 0)), dtype=dtype, device=device)
+ mask_1 = micro_mask_1.repeat((height + 1) // 2, (width + 1) // 2)
+ mask_1 = mask_1[:height, :width]
+ mask_1 = torch.unsqueeze(mask_1, 0)
+ mask_1 = torch.unsqueeze(mask_1, 0)
+
+ micro_mask_2 = torch.tensor(((0, 0), (1, 0)), dtype=dtype, device=device)
+ mask_2 = micro_mask_2.repeat((height + 1) // 2, (width + 1) // 2)
+ mask_2 = mask_2[:height, :width]
+ mask_2 = torch.unsqueeze(mask_2, 0)
+ mask_2 = torch.unsqueeze(mask_2, 0)
+
+ micro_mask_3 = torch.tensor(((0, 0), (0, 1)), dtype=dtype, device=device)
+ mask_3 = micro_mask_3.repeat((height + 1) // 2, (width + 1) // 2)
+ mask_3 = mask_3[:height, :width]
+ mask_3 = torch.unsqueeze(mask_3, 0)
+ mask_3 = torch.unsqueeze(mask_3, 0)
+
+ return mask_0, mask_1, mask_2, mask_3
+
+ def get_mask_four_parts(self, batch, channel, height, width, dtype, device):
+ curr_mask_str = f"{batch}_{channel}x{width}x{height}"
+ with torch.no_grad():
+ if curr_mask_str not in self.masks:
+ assert channel % 4 == 0
+ m = torch.ones((batch, channel // 4, height, width), dtype=dtype, device=device)
+ m0, m1, m2, m3 = self.get_one_channel_four_parts_mask(height, width, dtype, device)
+
+ mask_0 = torch.cat((m * m0, m * m1, m * m2, m * m3), dim=1)
+ mask_1 = torch.cat((m * m3, m * m2, m * m1, m * m0), dim=1)
+ mask_2 = torch.cat((m * m2, m * m3, m * m0, m * m1), dim=1)
+ mask_3 = torch.cat((m * m1, m * m0, m * m3, m * m2), dim=1)
+
+ self.masks[curr_mask_str] = [mask_0, mask_1, mask_2, mask_3]
+ return self.masks[curr_mask_str]
+
+ @staticmethod
+ def combine_four_parts(x_0_0, x_0_1, x_0_2, x_0_3,
+ x_1_0, x_1_1, x_1_2, x_1_3,
+ x_2_0, x_2_1, x_2_2, x_2_3,
+ x_3_0, x_3_1, x_3_2, x_3_3):
+ x_0 = x_0_0 + x_0_1 + x_0_2 + x_0_3
+ x_1 = x_1_0 + x_1_1 + x_1_2 + x_1_3
+ x_2 = x_2_0 + x_2_1 + x_2_2 + x_2_3
+ x_3 = x_3_0 + x_3_1 + x_3_2 + x_3_3
+ return torch.cat((x_0, x_1, x_2, x_3), dim=1)
+
+ @staticmethod
+ def combine_for_writing(x):
+ x0, x1, x2, x3 = x.chunk(4, 1)
+ return (x0 + x1) + (x2 + x3)
+
+ def forward_four_part_prior(self, y, common_params,
+ y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
+ y_spatial_prior_adaptor_3, y_spatial_prior,
+ y_spatial_prior_reduction=None, write=False):
+ '''
+ y_0 means split in channel, the 0/4 quater
+ y_1 means split in channel, the 1/4 quater
+ y_2 means split in channel, the 2/4 quater
+ y_3 means split in channel, the 3/4 quater
+ y_?_0, means multiply with mask_0
+ y_?_1, means multiply with mask_1
+ y_?_2, means multiply with mask_2
+ y_?_3, means multiply with mask_3
+ '''
+ q_enc, q_dec, scales, means = self.separate_prior(common_params,
+ y_spatial_prior_reduction is None)
+ if y_spatial_prior_reduction is not None:
+ common_params = y_spatial_prior_reduction(common_params)
+ dtype = y.dtype
+ device = y.device
+ B, C, H, W = y.size()
+ mask_0, mask_1, mask_2, mask_3 = self.get_mask_four_parts(B, C, H, W, dtype, device)
+
+ y = y * q_enc
+
+ y_res_0, y_q_0, y_hat_0, s_hat_0 = self.process_with_mask(y, scales, means, mask_0)
+
+ y_hat_so_far = y_hat_0
+ params = torch.cat((y_hat_so_far, common_params), dim=1)
+ scales, means = y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(2, 1)
+ y_res_1, y_q_1, y_hat_1, s_hat_1 = self.process_with_mask(y, scales, means, mask_1)
+
+ y_hat_so_far = y_hat_so_far + y_hat_1
+ params = torch.cat((y_hat_so_far, common_params), dim=1)
+ scales, means = y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(2, 1)
+ y_res_2, y_q_2, y_hat_2, s_hat_2 = self.process_with_mask(y, scales, means, mask_2)
+
+ y_hat_so_far = y_hat_so_far + y_hat_2
+ params = torch.cat((y_hat_so_far, common_params), dim=1)
+ scales, means = y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(2, 1)
+ y_res_3, y_q_3, y_hat_3, s_hat_3 = self.process_with_mask(y, scales, means, mask_3)
+
+ y_res = (y_res_0 + y_res_1) + (y_res_2 + y_res_3)
+ y_q = (y_q_0 + y_q_1) + (y_q_2 + y_q_3)
+ y_hat = y_hat_so_far + y_hat_3
+ scales_hat = (s_hat_0 + s_hat_1) + (s_hat_2 + s_hat_3)
+
+ y_hat = y_hat * q_dec
+
+ if write:
+ y_q_w_0 = self.combine_for_writing(y_q_0)
+ y_q_w_1 = self.combine_for_writing(y_q_1)
+ y_q_w_2 = self.combine_for_writing(y_q_2)
+ y_q_w_3 = self.combine_for_writing(y_q_3)
+ scales_w_0 = self.combine_for_writing(s_hat_0)
+ scales_w_1 = self.combine_for_writing(s_hat_1)
+ scales_w_2 = self.combine_for_writing(s_hat_2)
+ scales_w_3 = self.combine_for_writing(s_hat_3)
+ return y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \
+ scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat
+ return y_res, y_q, y_hat, scales_hat
+
+ def compress_four_part_prior(self, y, common_params,
+ y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
+ y_spatial_prior_adaptor_3, y_spatial_prior,
+ y_spatial_prior_reduction=None):
+ return self.forward_four_part_prior(y, common_params,
+ y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
+ y_spatial_prior_adaptor_3, y_spatial_prior,
+ y_spatial_prior_reduction, write=True)
+
+ def decompress_four_part_prior(self, common_params,
+ y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
+ y_spatial_prior_adaptor_3, y_spatial_prior,
+ y_spatial_prior_reduction=None):
+ _, quant_step, scales, means = self.separate_prior(common_params,
+ y_spatial_prior_reduction is None)
+ if y_spatial_prior_reduction is not None:
+ common_params = y_spatial_prior_reduction(common_params)
+ dtype = means.dtype
+ device = means.device
+ B, C, H, W = means.size()
+ mask_0, mask_1, mask_2, mask_3 = self.get_mask_four_parts(B, C, H, W, dtype, device)
+
+ scales_r = self.combine_for_writing(scales * mask_0)
+ y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
+ y_hat_curr_step = (torch.cat((y_q_r, y_q_r, y_q_r, y_q_r), dim=1) + means) * mask_0
+ y_hat_so_far = y_hat_curr_step
+
+ params = torch.cat((y_hat_so_far, common_params), dim=1)
+ scales, means = y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(2, 1)
+ scales_r = self.combine_for_writing(scales * mask_1)
+ y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
+ y_hat_curr_step = (torch.cat((y_q_r, y_q_r, y_q_r, y_q_r), dim=1) + means) * mask_1
+ y_hat_so_far = y_hat_so_far + y_hat_curr_step
+
+ params = torch.cat((y_hat_so_far, common_params), dim=1)
+ scales, means = y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(2, 1)
+ scales_r = self.combine_for_writing(scales * mask_2)
+ y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
+ y_hat_curr_step = (torch.cat((y_q_r, y_q_r, y_q_r, y_q_r), dim=1) + means) * mask_2
+ y_hat_so_far = y_hat_so_far + y_hat_curr_step
+
+ params = torch.cat((y_hat_so_far, common_params), dim=1)
+ scales, means = y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(2, 1)
+ scales_r = self.combine_for_writing(scales * mask_3)
+ y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
+ y_hat_curr_step = (torch.cat((y_q_r, y_q_r, y_q_r, y_q_r), dim=1) + means) * mask_3
+ y_hat_so_far = y_hat_so_far + y_hat_curr_step
+
+ y_hat = y_hat_so_far * quant_step
+
+ return y_hat
diff --git a/DCVC-FM/src/models/entropy_models.py b/DCVC-FM/src/models/entropy_models.py
new file mode 100644
index 0000000..a941fa4
--- /dev/null
+++ b/DCVC-FM/src/models/entropy_models.py
@@ -0,0 +1,304 @@
+import math
+
+import torch
+import numpy as np
+from torch import nn
+import torch.nn.functional as F
+
+
+class EntropyCoder():
+ def __init__(self, ec_thread=False, stream_part=1):
+ super().__init__()
+
+ from .MLCodec_rans import RansEncoder, RansDecoder # pylint: disable=E0401
+ self.encoder = RansEncoder(ec_thread, stream_part)
+ self.decoder = RansDecoder(stream_part)
+
+ @staticmethod
+ def pmf_to_quantized_cdf(pmf, precision=16):
+ from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_cdf # pylint: disable=E0401
+ cdf = _pmf_to_cdf(pmf.tolist(), precision)
+ cdf = torch.IntTensor(cdf)
+ return cdf
+
+ @staticmethod
+ def pmf_to_cdf(pmf, tail_mass, pmf_length, max_length):
+ entropy_coder_precision = 16
+ cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
+ for i, p in enumerate(pmf):
+ prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
+ _cdf = EntropyCoder.pmf_to_quantized_cdf(prob, entropy_coder_precision)
+ cdf[i, : _cdf.size(0)] = _cdf
+ return cdf
+
+ def reset(self):
+ self.encoder.reset()
+
+ def add_cdf(self, cdf, cdf_length, offset):
+ enc_cdf_idx = self.encoder.add_cdf(cdf, cdf_length, offset)
+ dec_cdf_idx = self.decoder.add_cdf(cdf, cdf_length, offset)
+ assert enc_cdf_idx == dec_cdf_idx
+ return enc_cdf_idx
+
+ def encode_with_indexes(self, symbols, indexes, cdf_group_index):
+ self.encoder.encode_with_indexes(symbols.clamp(-30000, 30000).to(torch.int16).cpu().numpy(),
+ indexes.to(torch.int16).cpu().numpy(),
+ cdf_group_index)
+
+ def encode_with_indexes_np(self, symbols, indexes, cdf_group_index):
+ self.encoder.encode_with_indexes(symbols.clip(-30000, 30000).astype(np.int16).reshape(-1),
+ indexes.astype(np.int16).reshape(-1),
+ cdf_group_index)
+
+ def flush(self):
+ self.encoder.flush()
+
+ def get_encoded_stream(self):
+ return self.encoder.get_encoded_stream().tobytes()
+
+ def set_stream(self, stream):
+ self.decoder.set_stream((np.frombuffer(stream, dtype=np.uint8)))
+
+ def decode_stream(self, indexes, cdf_group_index):
+ rv = self.decoder.decode_stream(indexes.to(torch.int16).cpu().numpy(),
+ cdf_group_index)
+ rv = torch.Tensor(rv)
+ return rv
+
+ def decode_stream_np(self, indexes, cdf_group_index):
+ rv = self.decoder.decode_stream(indexes.astype(np.int16).reshape(-1),
+ cdf_group_index)
+ return rv
+
+
+class Bitparm(nn.Module):
+ def __init__(self, qp_num, channel, final=False):
+ super().__init__()
+ self.final = final
+ self.h = nn.Parameter(torch.nn.init.normal_(
+ torch.empty([qp_num, channel, 1, 1]), 0, 0.01))
+ self.b = nn.Parameter(torch.nn.init.normal_(
+ torch.empty([qp_num, channel, 1, 1]), 0, 0.01))
+ if not final:
+ self.a = nn.Parameter(torch.nn.init.normal_(
+ torch.empty([qp_num, channel, 1, 1]), 0, 0.01))
+ else:
+ self.a = None
+
+ def forward(self, x, index):
+ h = torch.index_select(self.h, 0, index)
+ b = torch.index_select(self.b, 0, index)
+ x = x * F.softplus(h) + b
+ if self.final:
+ return x
+
+ a = torch.index_select(self.a, 0, index)
+ return x + torch.tanh(x) * torch.tanh(a)
+
+
+class AEHelper():
+ def __init__(self):
+ super().__init__()
+ self.entropy_coder = None
+ self.cdf_group_index = None
+ self._offset = None
+ self._quantized_cdf = None
+ self._cdf_length = None
+
+ def set_cdf_info(self, quantized_cdf, cdf_length, offset):
+ self._quantized_cdf = quantized_cdf.cpu().numpy()
+ self._cdf_length = cdf_length.reshape(-1).int().cpu().numpy()
+ self._offset = offset.reshape(-1).int().cpu().numpy()
+
+ def get_cdf_info(self):
+ return self._quantized_cdf, \
+ self._cdf_length, \
+ self._offset
+
+
+class BitEstimator(AEHelper, nn.Module):
+ def __init__(self, qp_num, channel):
+ super().__init__()
+ self.f1 = Bitparm(qp_num, channel)
+ self.f2 = Bitparm(qp_num, channel)
+ self.f3 = Bitparm(qp_num, channel)
+ self.f4 = Bitparm(qp_num, channel, True)
+ self.qp_num = qp_num
+ self.channel = channel
+
+ def forward(self, x, index):
+ return self.get_cdf(x, index)
+
+ def get_logits_cdf(self, x, index):
+ x = self.f1(x, index)
+ x = self.f2(x, index)
+ x = self.f3(x, index)
+ x = self.f4(x, index)
+ return x
+
+ def get_cdf(self, x, index):
+ return torch.sigmoid(self.get_logits_cdf(x, index))
+
+ def update(self, force=False, entropy_coder=None):
+ assert entropy_coder is not None
+ self.entropy_coder = entropy_coder
+
+ if not force and self._offset is not None:
+ return
+
+ with torch.no_grad():
+ device = next(self.parameters()).device
+ medians = torch.zeros((self.qp_num, self.channel, 1, 1), device=device)
+ index = torch.arange(self.qp_num, device=device, dtype=torch.int32)
+
+ minima = medians + 50
+ for i in range(50, 1, -1):
+ samples = torch.zeros_like(medians) - i
+ probs = self.forward(samples, index)
+ minima = torch.where(probs < torch.zeros_like(medians) + 0.0001,
+ torch.zeros_like(medians) + i, minima)
+
+ maxima = medians + 50
+ for i in range(50, 1, -1):
+ samples = torch.zeros_like(medians) + i
+ probs = self.forward(samples, index)
+ maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999,
+ torch.zeros_like(medians) + i, maxima)
+
+ minima = minima.int()
+ maxima = maxima.int()
+
+ offset = -minima
+
+ pmf_start = medians - minima
+ pmf_length = maxima + minima + 1
+
+ max_length = pmf_length.max()
+ device = pmf_start.device
+ samples = torch.arange(max_length, device=device)
+
+ samples = samples[None, None, None, :] + pmf_start
+
+ half = float(0.5)
+
+ lower = self.forward(samples - half, index)
+ upper = self.forward(samples + half, index)
+ pmf = upper - lower
+
+ pmf = pmf[:, :, 0, :]
+ tail_mass = lower[:, :, 0, :1] + (1.0 - upper[:, :, 0, -1:])
+
+ pmf = pmf.reshape([-1, max_length])
+ tail_mass = tail_mass.reshape([-1, 1])
+ pmf_length = pmf_length.reshape([-1])
+ offset = offset.reshape([-1])
+ quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
+ cdf_length = pmf_length + 2
+ self.set_cdf_info(quantized_cdf, cdf_length, offset)
+ self.cdf_group_index = self.entropy_coder.add_cdf(*self.get_cdf_info())
+
+ def build_indexes(self, size, qp):
+ B, C, H, W = size
+ indexes = torch.arange(C, dtype=torch.int).view(1, -1, 1, 1) + qp * self.channel
+ return indexes.repeat(B, 1, H, W)
+
+ def build_indexes_np(self, size, qp):
+ return self.build_indexes(size, qp).cpu().numpy()
+
+ def encode(self, x, qp):
+ indexes = self.build_indexes(x.size(), qp)
+ return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1),
+ self.cdf_group_index)
+
+ def decode_stream(self, size, dtype, device, qp):
+ output_size = (1, self.channel, size[0], size[1])
+ indexes = self.build_indexes(output_size, qp)
+ val = self.entropy_coder.decode_stream(indexes.reshape(-1), self.cdf_group_index)
+ val = val.reshape(indexes.shape)
+ return val.to(dtype).to(device)
+
+
+class GaussianEncoder(AEHelper):
+ def __init__(self, distribution='laplace'):
+ super().__init__()
+ assert distribution in ['laplace', 'gaussian']
+ self.distribution = distribution
+ if distribution == 'laplace':
+ self.cdf_distribution = torch.distributions.laplace.Laplace
+ self.scale_min = 0.01
+ self.scale_max = 64.0
+ self.scale_level = 256
+ elif distribution == 'gaussian':
+ self.cdf_distribution = torch.distributions.normal.Normal
+ self.scale_min = 0.11
+ self.scale_max = 64.0
+ self.scale_level = 256
+ self.scale_table = self.get_scale_table(self.scale_min, self.scale_max, self.scale_level)
+
+ self.log_scale_min = math.log(self.scale_min)
+ self.log_scale_max = math.log(self.scale_max)
+ self.log_scale_step = (self.log_scale_max - self.log_scale_min) / (self.scale_level - 1)
+
+ @staticmethod
+ def get_scale_table(min_val, max_val, levels):
+ return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), levels))
+
+ def update(self, force=False, entropy_coder=None):
+ assert entropy_coder is not None
+ self.entropy_coder = entropy_coder
+
+ if not force and self._offset is not None:
+ return
+
+ pmf_center = torch.zeros_like(self.scale_table) + 50
+ scales = torch.zeros_like(pmf_center) + self.scale_table
+ mu = torch.zeros_like(scales)
+ cdf_distribution = self.cdf_distribution(mu, scales)
+ for i in range(50, 1, -1):
+ samples = torch.zeros_like(pmf_center) + i
+ probs = cdf_distribution.cdf(samples)
+ probs = torch.squeeze(probs)
+ pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999,
+ torch.zeros_like(pmf_center) + i, pmf_center)
+
+ pmf_center = pmf_center.int()
+ pmf_length = 2 * pmf_center + 1
+ max_length = torch.max(pmf_length).item()
+
+ device = pmf_center.device
+ samples = torch.arange(max_length, device=device) - pmf_center[:, None]
+ samples = samples.float()
+
+ scales = torch.zeros_like(samples) + self.scale_table[:, None]
+ mu = torch.zeros_like(scales)
+ cdf_distribution = self.cdf_distribution(mu, scales)
+
+ upper = cdf_distribution.cdf(samples + 0.5)
+ lower = cdf_distribution.cdf(samples - 0.5)
+ pmf = upper - lower
+
+ tail_mass = 2 * lower[:, :1]
+
+ quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
+ quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
+
+ self.set_cdf_info(quantized_cdf, pmf_length+2, -pmf_center)
+ self.cdf_group_index = self.entropy_coder.add_cdf(*self.get_cdf_info())
+
+ def build_indexes(self, scales):
+ scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5)
+ indexes = (torch.log(scales) - self.log_scale_min) / self.log_scale_step
+ indexes = indexes.clamp_(0, self.scale_level - 1)
+ return indexes.int()
+
+ def encode(self, x, scales):
+ indexes = self.build_indexes(scales)
+ return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1),
+ self.cdf_group_index)
+
+ def decode_stream(self, scales, dtype, device):
+ indexes = self.build_indexes(scales)
+ val = self.entropy_coder.decode_stream(indexes.reshape(-1),
+ self.cdf_group_index)
+ val = val.reshape(scales.shape)
+ return val.to(device).to(dtype)
diff --git a/DCVC-FM/src/models/extensions/block_mc.cpp b/DCVC-FM/src/models/extensions/block_mc.cpp
new file mode 100644
index 0000000..c9cd9e9
--- /dev/null
+++ b/DCVC-FM/src/models/extensions/block_mc.cpp
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include "block_mc.h"
+#include
+
+void block_mc_forward(torch::Tensor &out, const torch::Tensor &im,
+ const torch::Tensor &flow, const int B, const int C,
+ const int H, const int W) {
+ block_mc_forward_cuda(out, im, flow, B, C, H, W);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("block_mc_forward", &block_mc_forward, "Motion Compensation forward");
+}
diff --git a/DCVC-FM/src/models/extensions/block_mc.h b/DCVC-FM/src/models/extensions/block_mc.h
new file mode 100644
index 0000000..5012a45
--- /dev/null
+++ b/DCVC-FM/src/models/extensions/block_mc.h
@@ -0,0 +1,8 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include
+
+void block_mc_forward_cuda(torch::Tensor &out, const torch::Tensor &im,
+ const torch::Tensor &flow, const int B, const int C,
+ const int H, const int W);
diff --git a/DCVC-FM/src/models/extensions/block_mc_kernel.cu b/DCVC-FM/src/models/extensions/block_mc_kernel.cu
new file mode 100644
index 0000000..cbc2a4a
--- /dev/null
+++ b/DCVC-FM/src/models/extensions/block_mc_kernel.cu
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+#include
+
+#include "block_mc.h"
+#include "common.h"
+#include
+
+inline __device__ float __tofloat(float a) { return a; }
+
+inline __device__ float __tofloat(__half a) { return __half2float(a); }
+
+inline __device__ float __multiply_add(float a, float b, float c) {
+ return __fmaf_rn(a, b, c);
+}
+
+inline __device__ __half __multiply_add(__half a, __half b, __half c) {
+ return __hfma(a, b, c);
+}
+
+template
+__global__ void block_mc_forward_kernel(GPUTensor out, const GPUTensor im,
+ const GPUTensor flow, const int B,
+ const int C, const int H, const int W) {
+ const int b = blockIdx.z;
+ const int h = blockIdx.y * blockDim.y + threadIdx.y;
+ const int w = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (h < H && w < W) {
+ const T x_off = flow.ptr[b * flow.stride[0] + 0 * flow.stride[1] +
+ h * flow.stride[2] + w * flow.stride[3]];
+ const T y_off = flow.ptr[b * flow.stride[0] + 1 * flow.stride[1] +
+ h * flow.stride[2] + w * flow.stride[3]];
+ float x_pos = __tofloat(x_off) + static_cast(w);
+ float y_pos = __tofloat(y_off) + static_cast(h);
+ x_pos = min(max(x_pos, 0.f), static_cast(W - 1));
+ y_pos = min(max(y_pos, 0.f), static_cast(H - 1));
+ int x0 = __float2int_rd(x_pos);
+ int x1 = min(x0 + 1, W - 1);
+ int y0 = __float2int_rd(y_pos);
+ int y1 = min(y0 + 1, H - 1);
+
+ float w_r = x_pos - static_cast(x0);
+ float w_l = 1.f - w_r;
+ float w_b = y_pos - static_cast(y0);
+ float w_t = 1.f - w_b;
+
+ const T wa = __totype(w_l * w_t);
+ const T wb = __totype(w_l * w_b);
+ const T wc = __totype(w_r * w_t);
+ const T wd = __totype(w_r * w_b);
+
+ for (int c = 0; c < C; c++) {
+ const int baseOffset = b * im.stride[0] + c * im.stride[1];
+
+ T r = __totype(0.f);
+ const T ima = im.ptr[baseOffset + y0 * im.stride[2] + x0 * im.stride[3]];
+ r = __multiply_add(ima, wa, r);
+ const T imb = im.ptr[baseOffset + y1 * im.stride[2] + x0 * im.stride[3]];
+ r = __multiply_add(imb, wb, r);
+ const T imc = im.ptr[baseOffset + y0 * im.stride[2] + x1 * im.stride[3]];
+ r = __multiply_add(imc, wc, r);
+ const T imd = im.ptr[baseOffset + y1 * im.stride[2] + x1 * im.stride[3]];
+ r = __multiply_add(imd, wd, r);
+ out.ptr[b * out.stride[0] + c * out.stride[1] + h * out.stride[2] +
+ w * out.stride[3]] = r;
+ }
+ }
+}
+
+void block_mc_forward_cuda(torch::Tensor &out, const torch::Tensor &im,
+ const torch::Tensor &flow, const int B, const int C,
+ const int H, const int W) {
+ const int BLOCK_SIZE = 32;
+ const dim3 gridDim((W + BLOCK_SIZE - 1) / BLOCK_SIZE,
+ (H + BLOCK_SIZE - 1) / BLOCK_SIZE, B);
+ const dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE);
+ if (im.element_size() == 4) {
+ block_mc_forward_kernel
+ <<>>(out, im, flow, B, C, H, W);
+ } else if (im.element_size() == 2) {
+ block_mc_forward_kernel<__half>
+ <<>>(out, im, flow, B, C, H, W);
+ }
+}
diff --git a/DCVC-FM/src/models/extensions/common.h b/DCVC-FM/src/models/extensions/common.h
new file mode 100644
index 0000000..ab367d2
--- /dev/null
+++ b/DCVC-FM/src/models/extensions/common.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+#pragma once
+
+template struct GPUTensor {
+ GPUTensor(torch::Tensor &t) : ptr(static_cast(t.data_ptr())) {
+ assert(sizeof(T) == t.element_size());
+ assert(t.dim() <= 8);
+ for (int i = 0; i < t.dim(); i++) {
+ stride[i] = static_cast(t.stride(i));
+ }
+ }
+ GPUTensor(const torch::Tensor &t) : ptr(static_cast(t.data_ptr())) {
+ assert(sizeof(T) == t.element_size());
+ assert(t.dim() <= 8);
+ for (int i = 0; i < t.dim(); i++) {
+ stride[i] = static_cast(t.stride(i));
+ }
+ }
+
+ T *__restrict__ const ptr;
+ int stride[8] = {0};
+};
+
+template inline __device__ T __totype(float a) {
+ return static_cast(a);
+}
+
+template <> inline __device__ __half __totype(float a) {
+ return __float2half(a);
+}
diff --git a/DCVC-FM/src/models/extensions/setup.py b/DCVC-FM/src/models/extensions/setup.py
new file mode 100644
index 0000000..6a86b25
--- /dev/null
+++ b/DCVC-FM/src/models/extensions/setup.py
@@ -0,0 +1,26 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+cxx_flags = ["-O3"]
+nvcc_flags = ["-O3", "--use_fast_math", "--extra-device-vectorization", "-arch=native"]
+
+
+setup(
+ name='block_mc_cpp',
+ ext_modules=[
+ CUDAExtension(
+ name='block_mc_cpp_cuda',
+ sources=[
+ 'block_mc.cpp',
+ 'block_mc_kernel.cu',
+ ],
+ extra_compile_args={
+ "cxx": cxx_flags,
+ "nvcc": nvcc_flags,
+ },)
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ }
+ )
diff --git a/DCVC-FM/src/models/image_model.py b/DCVC-FM/src/models/image_model.py
new file mode 100644
index 0000000..38ec5ba
--- /dev/null
+++ b/DCVC-FM/src/models/image_model.py
@@ -0,0 +1,225 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+from torch import nn
+
+
+from .common_model import CompressionModel
+from .layers import conv3x3, DepthConvBlock2, DepthConvBlock3, DepthConvBlock4, \
+ ResidualBlockUpsample, ResidualBlockWithStride2
+from .video_net import UNet
+from ..utils.stream_helper import write_ip, get_downsampled_shape
+
+
+class IntraEncoder(nn.Module):
+ def __init__(self, N, inplace=False):
+ super().__init__()
+
+ self.enc_1 = nn.Sequential(
+ ResidualBlockWithStride2(3, 128, inplace=inplace),
+ DepthConvBlock3(128, 128, inplace=inplace),
+ )
+ self.enc_2 = nn.Sequential(
+ ResidualBlockWithStride2(128, 192, inplace=inplace),
+ DepthConvBlock3(192, 192, inplace=inplace),
+ ResidualBlockWithStride2(192, N, inplace=inplace),
+ DepthConvBlock3(N, N, inplace=inplace),
+ nn.Conv2d(N, N, 3, stride=2, padding=1),
+ )
+
+ def forward(self, x, quant_step):
+ out = self.enc_1(x)
+ out = out * quant_step
+ return self.enc_2(out)
+
+
+class IntraDecoder(nn.Module):
+ def __init__(self, N, inplace=False):
+ super().__init__()
+
+ self.dec_1 = nn.Sequential(
+ DepthConvBlock3(N, N, inplace=inplace),
+ ResidualBlockUpsample(N, N, 2, inplace=inplace),
+ DepthConvBlock3(N, N, inplace=inplace),
+ ResidualBlockUpsample(N, 192, 2, inplace=inplace),
+ DepthConvBlock3(192, 192, inplace=inplace),
+ ResidualBlockUpsample(192, 128, 2, inplace=inplace),
+ )
+ self.dec_2 = nn.Sequential(
+ DepthConvBlock3(128, 128, inplace=inplace),
+ ResidualBlockUpsample(128, 16, 2, inplace=inplace),
+ )
+
+ def forward(self, x, quant_step):
+ out = self.dec_1(x)
+ out = out * quant_step
+ return self.dec_2(out)
+
+
+class DMCI(CompressionModel):
+ def __init__(self, N=256, z_channel=128, ec_thread=False, stream_part=1, inplace=False):
+ super().__init__(y_distribution='gaussian', z_channel=z_channel,
+ ec_thread=ec_thread, stream_part=stream_part)
+
+ self.enc = IntraEncoder(N, inplace)
+
+ self.hyper_enc = nn.Sequential(
+ DepthConvBlock4(N, z_channel, inplace=inplace),
+ nn.Conv2d(z_channel, z_channel, 3, stride=2, padding=1),
+ nn.LeakyReLU(inplace=inplace),
+ nn.Conv2d(z_channel, z_channel, 3, stride=2, padding=1),
+ )
+ self.hyper_dec = nn.Sequential(
+ ResidualBlockUpsample(z_channel, z_channel, 2, inplace=inplace),
+ ResidualBlockUpsample(z_channel, z_channel, 2, inplace=inplace),
+ DepthConvBlock4(z_channel, N),
+ )
+
+ self.y_prior_fusion = nn.Sequential(
+ DepthConvBlock4(N, N * 2, inplace=inplace),
+ DepthConvBlock4(N * 2, N * 2 + 2, inplace=inplace),
+ )
+
+ self.y_spatial_prior_reduction = nn.Conv2d(N * 2 + 2, N * 1, 1)
+ self.y_spatial_prior_adaptor_1 = DepthConvBlock2(N * 2, N * 2, inplace=inplace)
+ self.y_spatial_prior_adaptor_2 = DepthConvBlock2(N * 2, N * 2, inplace=inplace)
+ self.y_spatial_prior_adaptor_3 = DepthConvBlock2(N * 2, N * 2, inplace=inplace)
+ self.y_spatial_prior = nn.Sequential(
+ DepthConvBlock2(N * 2, N * 2, inplace=inplace),
+ DepthConvBlock2(N * 2, N * 2, inplace=inplace),
+ DepthConvBlock2(N * 2, N * 2, inplace=inplace),
+ )
+
+ self.dec = IntraDecoder(N, inplace)
+ self.refine = nn.Sequential(
+ UNet(16, 16, inplace=inplace),
+ conv3x3(16, 3),
+ )
+
+ self.q_scale_enc = nn.Parameter(torch.ones((self.get_qp_num(), 128, 1, 1)))
+ self.q_scale_dec = nn.Parameter(torch.ones((self.get_qp_num(), 128, 1, 1)))
+
+ def forward_one_frame(self, x, q_index=None):
+ _, _, H, W = x.size()
+ device = x.device
+ index = self.get_index_tensor(q_index, device)
+ curr_q_enc = torch.index_select(self.q_scale_enc, 0, index)
+ curr_q_dec = torch.index_select(self.q_scale_dec, 0, index)
+
+ y = self.enc(x, curr_q_enc)
+ y_pad, slice_shape = self.pad_for_y(y)
+ z = self.hyper_enc(y_pad)
+ z_q = self.quant(z)
+ z_hat = z_q
+
+ params = self.hyper_dec(z_hat)
+ params = self.y_prior_fusion(params)
+ params = self.slice_to_y(params, slice_shape)
+ y_res, y_q, y_hat, scales_hat = self.forward_four_part_prior(
+ y, params,
+ self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
+ self.y_spatial_prior_adaptor_3, self.y_spatial_prior,
+ y_spatial_prior_reduction=self.y_spatial_prior_reduction)
+
+ x_hat = self.dec(y_hat, curr_q_dec)
+ x_hat = self.refine(x_hat)
+
+ y_for_bit = y_q
+ z_for_bit = z_q
+ bits_y = self.get_y_gaussian_bits(y_for_bit, scales_hat)
+ bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z, index)
+ pixel_num = H * W
+ bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num
+ bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num
+
+ bits = torch.sum(bpp_y + bpp_z) * pixel_num
+
+ return {
+ "x_hat": x_hat,
+ "bit": bits,
+ }
+
+ def encode(self, x, q_index, sps_id=0, output_file=None):
+ # pic_width and pic_height may be different from x's size. X here is after padding
+ # x_hat has the same size with x
+ if output_file is None:
+ encoded = self.forward_one_frame(x, q_index)
+ result = {
+ 'bit': encoded['bit'].item(),
+ 'x_hat': encoded['x_hat'],
+ }
+ return result
+
+ compressed = self.compress(x, q_index)
+ bit_stream = compressed['bit_stream']
+ written = write_ip(output_file, True, sps_id, bit_stream)
+ result = {
+ 'bit': written * 8,
+ 'x_hat': compressed['x_hat'],
+ }
+ return result
+
+ def compress(self, x, q_index):
+ device = x.device
+ index = self.get_index_tensor(q_index, device)
+ curr_q_enc = torch.index_select(self.q_scale_enc, 0, index)
+ curr_q_dec = torch.index_select(self.q_scale_dec, 0, index)
+
+ y = self.enc(x, curr_q_enc)
+ y_pad, slice_shape = self.pad_for_y(y)
+ z = self.hyper_enc(y_pad)
+ z_q = torch.round(z)
+ z_hat = z_q
+
+ params = self.hyper_dec(z_hat)
+ params = self.y_prior_fusion(params)
+ params = self.slice_to_y(params, slice_shape)
+ y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \
+ scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat = self.compress_four_part_prior(
+ y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
+ self.y_spatial_prior_adaptor_3, self.y_spatial_prior,
+ y_spatial_prior_reduction=self.y_spatial_prior_reduction)
+
+ self.entropy_coder.reset()
+ self.bit_estimator_z.encode(z_q, q_index)
+ self.gaussian_encoder.encode(y_q_w_0, scales_w_0)
+ self.gaussian_encoder.encode(y_q_w_1, scales_w_1)
+ self.gaussian_encoder.encode(y_q_w_2, scales_w_2)
+ self.gaussian_encoder.encode(y_q_w_3, scales_w_3)
+ self.entropy_coder.flush()
+
+ x_hat = self.refine(self.dec(y_hat, curr_q_dec)).clamp_(0, 1)
+ bit_stream = self.entropy_coder.get_encoded_stream()
+
+ result = {
+ "bit_stream": bit_stream,
+ "x_hat": x_hat,
+ }
+ return result
+
+ def decompress(self, bit_stream, sps):
+ dtype = next(self.parameters()).dtype
+ device = next(self.parameters()).device
+ index = self.get_index_tensor(sps['qp'], device)
+ curr_q_dec = torch.index_select(self.q_scale_dec, 0, index)
+
+ self.entropy_coder.set_stream(bit_stream)
+ z_size = get_downsampled_shape(sps['height'], sps['width'], 64)
+ y_height, y_width = get_downsampled_shape(sps['height'], sps['width'], 16)
+ slice_shape = self.get_to_y_slice_shape(y_height, y_width)
+ z_q = self.bit_estimator_z.decode_stream(z_size, dtype, device, sps['qp'])
+ z_hat = z_q
+
+ params = self.hyper_dec(z_hat)
+ params = self.y_prior_fusion(params)
+ params = self.slice_to_y(params, slice_shape)
+ y_hat = self.decompress_four_part_prior(params,
+ self.y_spatial_prior_adaptor_1,
+ self.y_spatial_prior_adaptor_2,
+ self.y_spatial_prior_adaptor_3,
+ self.y_spatial_prior,
+ self.y_spatial_prior_reduction)
+
+ x_hat = self.refine(self.dec(y_hat, curr_q_dec)).clamp_(0, 1)
+ return {"x_hat": x_hat}
diff --git a/DCVC-FM/src/models/layers.py b/DCVC-FM/src/models/layers.py
new file mode 100644
index 0000000..5f43237
--- /dev/null
+++ b/DCVC-FM/src/models/layers.py
@@ -0,0 +1,299 @@
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch import nn
+
+
+def conv3x3(in_ch, out_ch, stride=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
+
+
+def subpel_conv3x3(in_ch, out_ch, r=1):
+ """3x3 sub-pixel convolution for up-sampling."""
+ return nn.Sequential(
+ nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
+ )
+
+
+def subpel_conv1x1(in_ch, out_ch, r=1):
+ """1x1 sub-pixel convolution for up-sampling."""
+ return nn.Sequential(
+ nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=1, padding=0), nn.PixelShuffle(r)
+ )
+
+
+def conv1x1(in_ch, out_ch, stride=1):
+ """1x1 convolution."""
+ return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
+
+
+class ResidualBlockWithStride2(nn.Module):
+ def __init__(self, in_ch, out_ch, inplace=False):
+ super().__init__()
+ self.down = nn.Conv2d(in_ch, out_ch, 2, stride=2)
+ self.conv = nn.Sequential(
+ nn.Conv2d(out_ch, out_ch, 3, padding=1),
+ nn.LeakyReLU(inplace=inplace),
+ nn.Conv2d(out_ch, out_ch, 1),
+ nn.LeakyReLU(inplace=inplace),
+ )
+
+ def forward(self, x):
+ x = self.down(x)
+ identity = x
+ out = self.conv(x)
+ out = out + identity
+ return out
+
+
+class ResidualBlockWithStride(nn.Module):
+ """Residual block with a stride on the first convolution.
+
+ Args:
+ in_ch (int): number of input channels
+ out_ch (int): number of output channels
+ stride (int): stride value (default: 2)
+ """
+
+ def __init__(self, in_ch, out_ch, stride=2, inplace=False):
+ super().__init__()
+ self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
+ self.leaky_relu = nn.LeakyReLU(inplace=inplace)
+ self.conv2 = conv3x3(out_ch, out_ch)
+ self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
+ if stride != 1:
+ self.downsample = conv1x1(in_ch, out_ch, stride=stride)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ identity = x
+ out = self.conv1(x)
+ out = self.leaky_relu(out)
+ out = self.conv2(out)
+ out = self.leaky_relu2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = out + identity
+ return out
+
+
+class ResidualBlockUpsample(nn.Module):
+ """Residual block with sub-pixel upsampling on the last convolution.
+
+ Args:
+ in_ch (int): number of input channels
+ out_ch (int): number of output channels
+ upsample (int): upsampling factor (default: 2)
+ """
+
+ def __init__(self, in_ch, out_ch, upsample=2, inplace=False):
+ super().__init__()
+ self.subpel_conv = subpel_conv1x1(in_ch, out_ch, upsample)
+ self.leaky_relu = nn.LeakyReLU(inplace=inplace)
+ self.conv = conv3x3(out_ch, out_ch)
+ self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
+ self.upsample = subpel_conv1x1(in_ch, out_ch, upsample)
+
+ def forward(self, x):
+ identity = x
+ out = self.subpel_conv(x)
+ out = self.leaky_relu(out)
+ out = self.conv(out)
+ out = self.leaky_relu2(out)
+ identity = self.upsample(x)
+ out = out + identity
+ return out
+
+
+class ResidualBlock(nn.Module):
+ """Simple residual block with two 3x3 convolutions.
+
+ Args:
+ in_ch (int): number of input channels
+ out_ch (int): number of output channels
+ """
+
+ def __init__(self, in_ch, out_ch, leaky_relu_slope=0.01, inplace=False):
+ super().__init__()
+ self.conv1 = conv3x3(in_ch, out_ch)
+ self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_relu_slope, inplace=inplace)
+ self.conv2 = conv3x3(out_ch, out_ch)
+ self.adaptor = None
+ if in_ch != out_ch:
+ self.adaptor = conv1x1(in_ch, out_ch)
+
+ def forward(self, x):
+ identity = x
+ if self.adaptor is not None:
+ identity = self.adaptor(identity)
+
+ out = self.conv1(x)
+ out = self.leaky_relu(out)
+ out = self.conv2(out)
+ out = self.leaky_relu(out)
+
+ out = out + identity
+ return out
+
+
+class DepthConv(nn.Module):
+ def __init__(self, in_ch, out_ch, slope=0.01, inplace=False):
+ super().__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_ch, in_ch, 1),
+ nn.LeakyReLU(negative_slope=slope, inplace=inplace),
+ )
+ self.depth_conv = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)
+ self.conv2 = nn.Conv2d(in_ch, out_ch, 1)
+
+ self.adaptor = None
+ if in_ch != out_ch:
+ self.adaptor = nn.Conv2d(in_ch, out_ch, 1)
+
+ def forward(self, x):
+ identity = x
+ if self.adaptor is not None:
+ identity = self.adaptor(identity)
+
+ out = self.conv1(x)
+ out = self.depth_conv(out)
+ out = self.conv2(out)
+
+ return out + identity
+
+
+class DepthConv2(nn.Module):
+ def __init__(self, in_ch, out_ch, slope=0.01, inplace=False):
+ super().__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, 1),
+ nn.LeakyReLU(negative_slope=slope, inplace=inplace),
+ nn.Conv2d(out_ch, out_ch, 3, padding=1, groups=out_ch)
+ )
+ self.conv2 = nn.Conv2d(in_ch, out_ch, 1)
+ self.out_conv = nn.Conv2d(out_ch, out_ch, 1)
+ self.adaptor = None
+ if in_ch != out_ch:
+ self.adaptor = nn.Conv2d(in_ch, out_ch, 1)
+
+ def forward(self, x):
+ identity = x
+ if self.adaptor is not None:
+ identity = self.adaptor(x)
+ x1 = self.conv1(x)
+ x2 = self.conv2(x)
+ x = self.out_conv(x1 * x2)
+ return identity + x
+
+
+class ConvFFN(nn.Module):
+ def __init__(self, in_ch, slope=0.1, inplace=False):
+ super().__init__()
+ internal_ch = max(min(in_ch * 4, 1024), in_ch * 2)
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch, internal_ch, 1),
+ nn.LeakyReLU(negative_slope=slope, inplace=inplace),
+ nn.Conv2d(internal_ch, in_ch, 1),
+ nn.LeakyReLU(negative_slope=slope, inplace=inplace),
+ )
+
+ def forward(self, x):
+ identity = x
+ return identity + self.conv(x)
+
+
+class ConvFFN2(nn.Module):
+ def __init__(self, in_ch, slope=0.1, inplace=False):
+ super().__init__()
+ expansion_factor = 2
+ slope = 0.1
+ internal_ch = in_ch * expansion_factor
+ self.conv = nn.Conv2d(in_ch, internal_ch * 2, 1)
+ self.conv_out = nn.Conv2d(internal_ch, in_ch, 1)
+ self.relu = nn.LeakyReLU(negative_slope=slope, inplace=inplace)
+
+ def forward(self, x):
+ identity = x
+ x1, x2 = self.conv(x).chunk(2, 1)
+ out = x1 * self.relu(x2)
+ return identity + self.conv_out(out)
+
+
+class ConvFFN3(nn.Module):
+ def __init__(self, in_ch, inplace=False):
+ super().__init__()
+ expansion_factor = 2
+ internal_ch = in_ch * expansion_factor
+ self.conv = nn.Conv2d(in_ch, internal_ch * 2, 1)
+ self.conv_out = nn.Conv2d(internal_ch, in_ch, 1)
+ self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
+ self.relu2 = nn.LeakyReLU(negative_slope=0.01, inplace=inplace)
+
+ def forward(self, x):
+ identity = x
+ x1, x2 = self.conv(x).chunk(2, 1)
+ out = self.relu1(x1) + self.relu2(x2)
+ return identity + self.conv_out(out)
+
+
+class DepthConvBlock(nn.Module):
+ def __init__(self, in_ch, out_ch, slope_depth_conv=0.01, slope_ffn=0.1, inplace=False):
+ super().__init__()
+ self.block = nn.Sequential(
+ DepthConv(in_ch, out_ch, slope=slope_depth_conv, inplace=inplace),
+ ConvFFN(out_ch, slope=slope_ffn, inplace=inplace),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DepthConvBlock2(nn.Module):
+ def __init__(self, in_ch, out_ch, slope_depth_conv=0.01, slope_ffn=0.1, inplace=False):
+ super().__init__()
+ self.block = nn.Sequential(
+ DepthConv(in_ch, out_ch, slope=slope_depth_conv, inplace=inplace),
+ ConvFFN2(out_ch, slope=slope_ffn, inplace=inplace),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DepthConvBlock3(nn.Module):
+ def __init__(self, in_ch, out_ch, slope_depth_conv=0.01, slope_ffn=0.1, inplace=False):
+ super().__init__()
+ self.block = nn.Sequential(
+ DepthConv2(in_ch, out_ch, slope=slope_depth_conv, inplace=inplace),
+ ConvFFN2(out_ch, slope=slope_ffn, inplace=inplace),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DepthConvBlock4(nn.Module):
+ def __init__(self, in_ch, out_ch, slope_depth_conv=0.01, inplace=False):
+ super().__init__()
+ self.block = nn.Sequential(
+ DepthConv(in_ch, out_ch, slope=slope_depth_conv, inplace=inplace),
+ ConvFFN3(out_ch, inplace=inplace),
+ )
+
+ def forward(self, x):
+ return self.block(x)
diff --git a/DCVC-FM/src/models/video_model.py b/DCVC-FM/src/models/video_model.py
new file mode 100644
index 0000000..5c57698
--- /dev/null
+++ b/DCVC-FM/src/models/video_model.py
@@ -0,0 +1,581 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import time
+
+import torch
+from torch import nn
+import torch.utils.checkpoint
+
+from .common_model import CompressionModel
+from .video_net import ME_Spynet, ResBlock, UNet2, bilinearupsacling, bilineardownsacling
+from .layers import subpel_conv3x3, subpel_conv1x1, DepthConvBlock, DepthConvBlock4, \
+ ResidualBlockWithStride, ResidualBlockUpsample
+from .block_mc import block_mc_func
+from ..utils.stream_helper import get_downsampled_shape, write_ip, write_p_frames
+
+
+g_ch_1x = 48
+g_ch_2x = 64
+g_ch_4x = 96
+g_ch_8x = 96
+g_ch_16x = 128
+g_ch_z = 64
+
+
+class OffsetDiversity(nn.Module):
+ def __init__(self, in_channel=g_ch_1x, aux_feature_num=g_ch_1x+3+2,
+ offset_num=2, group_num=16, max_residue_magnitude=40, inplace=False):
+ super().__init__()
+ self.in_channel = in_channel
+ self.offset_num = offset_num
+ self.group_num = group_num
+ self.max_residue_magnitude = max_residue_magnitude
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(aux_feature_num, g_ch_2x, 3, 2, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=inplace),
+ nn.Conv2d(g_ch_2x, g_ch_2x, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=inplace),
+ nn.Conv2d(g_ch_2x, 3 * group_num * offset_num, 3, 1, 1),
+ )
+ self.fusion = nn.Conv2d(in_channel * offset_num, in_channel, 1, 1, groups=group_num)
+
+ def forward(self, x, aux_feature, flow):
+ B, C, H, W = x.shape
+ out = self.conv_offset(aux_feature)
+ out = bilinearupsacling(out)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ mask = torch.sigmoid(mask)
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
+ offset = offset + flow.repeat(1, self.group_num * self.offset_num, 1, 1)
+
+ # warp
+ offset = offset.view(B * self.group_num * self.offset_num, 2, H, W)
+ mask = mask.view(B * self.group_num * self.offset_num, 1, H, W)
+ x = x.repeat(1, self.offset_num, 1, 1)
+ x = x.view(B * self.group_num * self.offset_num, C // self.group_num, H, W)
+ x = block_mc_func(x, offset)
+ x = x * mask
+ x = x.view(B, C * self.offset_num, H, W)
+ x = self.fusion(x)
+
+ return x
+
+
+class FeatureExtractor(nn.Module):
+ def __init__(self, inplace=False):
+ super().__init__()
+ self.conv1 = nn.Conv2d(g_ch_1x, g_ch_1x, 3, stride=1, padding=1)
+ self.res_block1 = ResBlock(g_ch_1x, inplace=inplace)
+ self.conv2 = nn.Conv2d(g_ch_1x, g_ch_2x, 3, stride=2, padding=1)
+ self.res_block2 = ResBlock(g_ch_2x, inplace=inplace)
+ self.conv3 = nn.Conv2d(g_ch_2x, g_ch_4x, 3, stride=2, padding=1)
+ self.res_block3 = ResBlock(g_ch_4x, inplace=inplace)
+
+ def forward(self, feature):
+ layer1 = self.conv1(feature)
+ layer1 = self.res_block1(layer1)
+
+ layer2 = self.conv2(layer1)
+ layer2 = self.res_block2(layer2)
+
+ layer3 = self.conv3(layer2)
+ layer3 = self.res_block3(layer3)
+
+ return layer1, layer2, layer3
+
+
+class MultiScaleContextFusion(nn.Module):
+ def __init__(self, inplace=False):
+ super().__init__()
+ self.conv3_up = subpel_conv3x3(g_ch_4x, g_ch_2x, 2)
+ self.res_block3_up = ResBlock(g_ch_2x, inplace=inplace)
+ self.conv3_out = nn.Conv2d(g_ch_4x, g_ch_4x, 3, padding=1)
+ self.res_block3_out = ResBlock(g_ch_4x, inplace=inplace)
+ self.conv2_up = subpel_conv3x3(g_ch_2x * 2, g_ch_1x, 2)
+ self.res_block2_up = ResBlock(g_ch_1x, inplace=inplace)
+ self.conv2_out = nn.Conv2d(g_ch_2x * 2, g_ch_2x, 3, padding=1)
+ self.res_block2_out = ResBlock(g_ch_2x, inplace=inplace)
+ self.conv1_out = nn.Conv2d(g_ch_1x * 2, g_ch_1x, 3, padding=1)
+ self.res_block1_out = ResBlock(g_ch_1x, inplace=inplace)
+
+ def forward(self, context1, context2, context3):
+ context3_up = self.conv3_up(context3)
+ context3_up = self.res_block3_up(context3_up)
+ context3_out = self.conv3_out(context3)
+ context3_out = self.res_block3_out(context3_out)
+ context2_up = self.conv2_up(torch.cat((context3_up, context2), dim=1))
+ context2_up = self.res_block2_up(context2_up)
+ context2_out = self.conv2_out(torch.cat((context3_up, context2), dim=1))
+ context2_out = self.res_block2_out(context2_out)
+ context1_out = self.conv1_out(torch.cat((context2_up, context1), dim=1))
+ context1_out = self.res_block1_out(context1_out)
+ context1 = context1 + context1_out
+ context2 = context2 + context2_out
+ context3 = context3 + context3_out
+
+ return context1, context2, context3
+
+
+class MvEnc(nn.Module):
+ def __init__(self, input_channel, channel, inplace=False):
+ super().__init__()
+ self.enc_1 = nn.Sequential(
+ ResidualBlockWithStride(input_channel, channel, stride=2, inplace=inplace),
+ DepthConvBlock4(channel, channel, inplace=inplace),
+ )
+ self.enc_2 = ResidualBlockWithStride(channel, channel, stride=2, inplace=inplace)
+
+ self.adaptor_0 = DepthConvBlock4(channel, channel, inplace=inplace)
+ self.adaptor_1 = DepthConvBlock4(channel * 2, channel, inplace=inplace)
+ self.enc_3 = nn.Sequential(
+ ResidualBlockWithStride(channel, channel, stride=2, inplace=inplace),
+ DepthConvBlock4(channel, channel, inplace=inplace),
+ nn.Conv2d(channel, channel, 3, stride=2, padding=1),
+ )
+
+ def forward(self, x, context, quant_step):
+ out = self.enc_1(x)
+ out = out * quant_step
+ out = self.enc_2(out)
+ if context is None:
+ out = self.adaptor_0(out)
+ else:
+ out = self.adaptor_1(torch.cat((out, context), dim=1))
+ return self.enc_3(out)
+
+
+class MvDec(nn.Module):
+ def __init__(self, output_channel, channel, inplace=False):
+ super().__init__()
+ self.dec_1 = nn.Sequential(
+ DepthConvBlock4(channel, channel, inplace=inplace),
+ ResidualBlockUpsample(channel, channel, 2, inplace=inplace),
+ DepthConvBlock4(channel, channel, inplace=inplace),
+ ResidualBlockUpsample(channel, channel, 2, inplace=inplace),
+ DepthConvBlock4(channel, channel, inplace=inplace)
+ )
+ self.dec_2 = ResidualBlockUpsample(channel, channel, 2, inplace=inplace)
+ self.dec_3 = nn.Sequential(
+ DepthConvBlock4(channel, channel, inplace=inplace),
+ subpel_conv1x1(channel, output_channel, 2),
+ )
+
+ def forward(self, x, quant_step):
+ feature = self.dec_1(x)
+ out = self.dec_2(feature)
+ out = out * quant_step
+ mv = self.dec_3(out)
+ return mv, feature
+
+
+class ContextualEncoder(nn.Module):
+ def __init__(self, inplace=False):
+ super().__init__()
+ self.conv1 = nn.Conv2d(g_ch_1x + 3, g_ch_2x, 3, stride=2, padding=1)
+ self.res1 = DepthConvBlock4(g_ch_2x * 2, g_ch_2x * 2, inplace=inplace)
+ self.conv2 = nn.Conv2d(g_ch_2x * 2, g_ch_4x, 3, stride=2, padding=1)
+ self.res2 = DepthConvBlock4(g_ch_4x * 2, g_ch_4x * 2, inplace=inplace)
+ self.conv3 = nn.Conv2d(g_ch_4x * 2, g_ch_8x, 3, stride=2, padding=1)
+ self.conv4 = nn.Conv2d(g_ch_8x, g_ch_16x, 3, stride=2, padding=1)
+
+ def forward(self, x, context1, context2, context3, quant_step):
+ feature = self.conv1(torch.cat([x, context1], dim=1))
+ feature = self.res1(torch.cat([feature, context2], dim=1))
+ feature = feature * quant_step
+ feature = self.conv2(feature)
+ feature = self.res2(torch.cat([feature, context3], dim=1))
+ feature = self.conv3(feature)
+ feature = self.conv4(feature)
+ return feature
+
+
+class ContextualDecoder(nn.Module):
+ def __init__(self, inplace=False):
+ super().__init__()
+ self.up1 = subpel_conv3x3(g_ch_16x, g_ch_8x, 2)
+ self.up2 = subpel_conv3x3(g_ch_8x, g_ch_4x, 2)
+ self.res1 = DepthConvBlock4(g_ch_4x * 2, g_ch_4x * 2, inplace=inplace)
+ self.up3 = subpel_conv3x3(g_ch_4x * 2, g_ch_2x, 2)
+ self.res2 = DepthConvBlock4(g_ch_2x * 2, g_ch_2x * 2, inplace=inplace)
+ self.up4 = subpel_conv3x3(g_ch_2x * 2, 32, 2)
+
+ def forward(self, x, context2, context3, quant_step):
+ feature = self.up1(x)
+ feature = self.up2(feature)
+ feature = self.res1(torch.cat([feature, context3], dim=1))
+ feature = self.up3(feature)
+ feature = feature * quant_step
+ feature = self.res2(torch.cat([feature, context2], dim=1))
+ feature = self.up4(feature)
+ return feature
+
+
+class ReconGeneration(nn.Module):
+ def __init__(self, ctx_channel=g_ch_1x, res_channel=32, inplace=False):
+ super().__init__()
+ self.first_conv = nn.Conv2d(ctx_channel + res_channel, g_ch_1x, 3, stride=1, padding=1)
+ self.unet_1 = UNet2(g_ch_1x, g_ch_1x, inplace=inplace)
+ self.unet_2 = UNet2(g_ch_1x, g_ch_1x, inplace=inplace)
+ self.recon_conv = nn.Conv2d(g_ch_1x, 3, 3, stride=1, padding=1)
+
+ def forward(self, ctx, res):
+ feature = self.first_conv(torch.cat((ctx, res), dim=1))
+ feature = self.unet_1(feature)
+ feature = self.unet_2(feature)
+ recon = self.recon_conv(feature)
+ return feature, recon
+
+
+class DMC(CompressionModel):
+ def __init__(self, ec_thread=False, stream_part=1, inplace=False):
+ super().__init__(y_distribution='laplace', z_channel=g_ch_z, mv_z_channel=64,
+ ec_thread=ec_thread, stream_part=stream_part)
+
+ channel_mv = 64
+ channel_N = 64
+
+ self.optic_flow = ME_Spynet()
+ self.align = OffsetDiversity(inplace=inplace)
+
+ self.mv_encoder = MvEnc(2, channel_mv)
+ self.mv_hyper_prior_encoder = nn.Sequential(
+ DepthConvBlock4(channel_mv, channel_N, inplace=inplace),
+ nn.Conv2d(channel_N, channel_N, 3, stride=2, padding=1),
+ nn.LeakyReLU(inplace=inplace),
+ nn.Conv2d(channel_N, channel_N, 3, stride=2, padding=1),
+ )
+ self.mv_hyper_prior_decoder = nn.Sequential(
+ ResidualBlockUpsample(channel_N, channel_N, 2, inplace=inplace),
+ ResidualBlockUpsample(channel_N, channel_N, 2, inplace=inplace),
+ DepthConvBlock4(channel_N, channel_mv),
+ )
+
+ self.mv_y_prior_fusion_adaptor_0 = DepthConvBlock(channel_mv * 1, channel_mv * 2,
+ inplace=inplace)
+ self.mv_y_prior_fusion_adaptor_1 = DepthConvBlock(channel_mv * 2, channel_mv * 2,
+ inplace=inplace)
+
+ self.mv_y_prior_fusion = nn.Sequential(
+ DepthConvBlock(channel_mv * 2, channel_mv * 3, inplace=inplace),
+ DepthConvBlock(channel_mv * 3, channel_mv * 3, inplace=inplace),
+ )
+
+ self.mv_y_spatial_prior_adaptor_1 = nn.Conv2d(channel_mv * 4, channel_mv * 3, 1)
+ self.mv_y_spatial_prior_adaptor_2 = nn.Conv2d(channel_mv * 4, channel_mv * 3, 1)
+ self.mv_y_spatial_prior_adaptor_3 = nn.Conv2d(channel_mv * 4, channel_mv * 3, 1)
+
+ self.mv_y_spatial_prior = nn.Sequential(
+ DepthConvBlock(channel_mv * 3, channel_mv * 3, inplace=inplace),
+ DepthConvBlock(channel_mv * 3, channel_mv * 3, inplace=inplace),
+ DepthConvBlock(channel_mv * 3, channel_mv * 2, inplace=inplace),
+ )
+
+ self.mv_decoder = MvDec(2, channel_mv, inplace=inplace)
+
+ self.feature_adaptor_I = nn.Conv2d(3, g_ch_1x, 3, stride=1, padding=1)
+ self.feature_adaptor = nn.ModuleList([nn.Conv2d(g_ch_1x, g_ch_1x, 1) for _ in range(3)])
+ self.feature_extractor = FeatureExtractor(inplace=inplace)
+ self.context_fusion_net = MultiScaleContextFusion(inplace=inplace)
+
+ self.contextual_encoder = ContextualEncoder(inplace=inplace)
+
+ self.contextual_hyper_prior_encoder = nn.Sequential(
+ DepthConvBlock4(g_ch_16x, g_ch_z, inplace=inplace),
+ nn.Conv2d(g_ch_z, g_ch_z, 3, stride=2, padding=1),
+ nn.LeakyReLU(inplace=inplace),
+ nn.Conv2d(g_ch_z, g_ch_z, 3, stride=2, padding=1),
+ )
+ self.contextual_hyper_prior_decoder = nn.Sequential(
+ ResidualBlockUpsample(g_ch_z, g_ch_z, 2, inplace=inplace),
+ ResidualBlockUpsample(g_ch_z, g_ch_z, 2, inplace=inplace),
+ DepthConvBlock4(g_ch_z, g_ch_16x),
+ )
+
+ self.temporal_prior_encoder = nn.Sequential(
+ nn.Conv2d(g_ch_4x, g_ch_8x, 3, stride=2, padding=1),
+ nn.LeakyReLU(0.1, inplace=inplace),
+ nn.Conv2d(g_ch_8x, g_ch_16x, 3, stride=2, padding=1),
+ )
+
+ self.y_prior_fusion_adaptor_0 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 3,
+ inplace=inplace)
+ self.y_prior_fusion_adaptor_1 = DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3,
+ inplace=inplace)
+
+ self.y_prior_fusion = nn.Sequential(
+ DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
+ DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
+ )
+
+ self.y_spatial_prior_adaptor_1 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1)
+ self.y_spatial_prior_adaptor_2 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1)
+ self.y_spatial_prior_adaptor_3 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1)
+
+ self.y_spatial_prior = nn.Sequential(
+ DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
+ DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
+ DepthConvBlock(g_ch_16x * 3, g_ch_16x * 2, inplace=inplace),
+ )
+
+ self.contextual_decoder = ContextualDecoder(inplace=inplace)
+ self.recon_generation_net = ReconGeneration(inplace=inplace)
+
+ self.mv_y_q_enc = nn.Parameter(torch.ones((2, 1, 1, 1)))
+ self.mv_y_q_dec = nn.Parameter(torch.ones((2, 1, 1, 1)))
+
+ self.y_q_enc = nn.Parameter(torch.ones((2, 1, 1, 1)))
+ self.y_q_dec = nn.Parameter(torch.ones((2, 1, 1, 1)))
+
+ def multi_scale_feature_extractor(self, dpb, fa_idx):
+ if dpb["ref_feature"] is None:
+ feature = self.feature_adaptor_I(dpb["ref_frame"])
+ else:
+ feature = self.feature_adaptor[fa_idx](dpb["ref_feature"])
+ return self.feature_extractor(feature)
+
+ def motion_compensation(self, dpb, mv, fa_idx):
+ warpframe = block_mc_func(dpb["ref_frame"], mv)
+ mv2 = bilineardownsacling(mv) / 2
+ mv3 = bilineardownsacling(mv2) / 2
+ ref_feature1, ref_feature2, ref_feature3 = self.multi_scale_feature_extractor(dpb, fa_idx)
+ context1_init = block_mc_func(ref_feature1, mv)
+ context1 = self.align(ref_feature1, torch.cat(
+ (context1_init, warpframe, mv), dim=1), mv)
+ context2 = block_mc_func(ref_feature2, mv2)
+ context3 = block_mc_func(ref_feature3, mv3)
+ context1, context2, context3 = self.context_fusion_net(context1, context2, context3)
+ return context1, context2, context3, warpframe
+
+ def mv_prior_param_decoder(self, mv_z_hat, dpb, slice_shape=None):
+ mv_params = self.mv_hyper_prior_decoder(mv_z_hat)
+ mv_params = self.slice_to_y(mv_params, slice_shape)
+ ref_mv_y = dpb["ref_mv_y"]
+ if ref_mv_y is None:
+ mv_params = self.mv_y_prior_fusion_adaptor_0(mv_params)
+ else:
+ mv_params = torch.cat((mv_params, ref_mv_y), dim=1)
+ mv_params = self.mv_y_prior_fusion_adaptor_1(mv_params)
+ mv_params = self.mv_y_prior_fusion(mv_params)
+ return mv_params
+
+ def contextual_prior_param_decoder(self, z_hat, dpb, context3, slice_shape=None):
+ hierarchical_params = self.contextual_hyper_prior_decoder(z_hat)
+ hierarchical_params = self.slice_to_y(hierarchical_params, slice_shape)
+ temporal_params = self.temporal_prior_encoder(context3)
+ ref_y = dpb["ref_y"]
+ if ref_y is None:
+ params = torch.cat((temporal_params, hierarchical_params), dim=1)
+ params = self.y_prior_fusion_adaptor_0(params)
+ else:
+ params = torch.cat((temporal_params, hierarchical_params, ref_y), dim=1)
+ params = self.y_prior_fusion_adaptor_1(params)
+ params = self.y_prior_fusion(params)
+ return params
+
+ def get_recon_and_feature(self, y_hat, context1, context2, context3, y_q_dec):
+ recon_image_feature = self.contextual_decoder(y_hat, context2, context3, y_q_dec)
+ feature, x_hat = self.recon_generation_net(recon_image_feature, context1)
+ x_hat = x_hat.clamp_(0, 1)
+ return x_hat, feature
+
+ def motion_estimation_and_mv_encoding(self, x, dpb, mv_y_q_enc):
+ est_mv = self.optic_flow(x, dpb["ref_frame"])
+ ref_mv_feature = dpb["ref_mv_feature"]
+ mv_y = self.mv_encoder(est_mv, ref_mv_feature, mv_y_q_enc)
+ return mv_y
+
+ def get_all_q(self, q_index):
+ mv_y_q_enc = self.get_curr_q(self.mv_y_q_enc, q_index)
+ mv_y_q_dec = self.get_curr_q(self.mv_y_q_dec, q_index)
+ y_q_enc = self.get_curr_q(self.y_q_enc, q_index)
+ y_q_dec = self.get_curr_q(self.y_q_dec, q_index)
+ return mv_y_q_enc, mv_y_q_dec, y_q_enc, y_q_dec
+
+ def compress(self, x, dpb, q_index, fa_idx):
+ # pic_width and pic_height may be different from x's size. x here is after padding
+ # x_hat has the same size with x
+ mv_y_q_enc, mv_y_q_dec, y_q_enc, y_q_dec = self.get_all_q(q_index)
+ mv_y = self.motion_estimation_and_mv_encoding(x, dpb, mv_y_q_enc)
+ mv_y_pad, slice_shape = self.pad_for_y(mv_y)
+ mv_z = self.mv_hyper_prior_encoder(mv_y_pad)
+ mv_z_hat = torch.round(mv_z)
+ mv_params = self.mv_prior_param_decoder(mv_z_hat, dpb, slice_shape)
+ mv_y_q_w_0, mv_y_q_w_1, mv_y_q_w_2, mv_y_q_w_3, \
+ mv_scales_w_0, mv_scales_w_1, mv_scales_w_2, mv_scales_w_3, mv_y_hat = \
+ self.compress_four_part_prior(
+ mv_y, mv_params,
+ self.mv_y_spatial_prior_adaptor_1, self.mv_y_spatial_prior_adaptor_2,
+ self.mv_y_spatial_prior_adaptor_3, self.mv_y_spatial_prior)
+
+ mv_hat, mv_feature = self.mv_decoder(mv_y_hat, mv_y_q_dec)
+ context1, context2, context3, _ = self.motion_compensation(dpb, mv_hat, fa_idx)
+
+ y = self.contextual_encoder(x, context1, context2, context3, y_q_enc)
+ y_pad, slice_shape = self.pad_for_y(y)
+ z = self.contextual_hyper_prior_encoder(y_pad)
+ z_hat = torch.round(z)
+ params = self.contextual_prior_param_decoder(z_hat, dpb, context3, slice_shape)
+ y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \
+ scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat = \
+ self.compress_four_part_prior(
+ y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
+ self.y_spatial_prior_adaptor_3, self.y_spatial_prior)
+
+ self.entropy_coder.reset()
+ self.bit_estimator_z_mv.encode(mv_z_hat, 0)
+ self.bit_estimator_z.encode(z_hat, 0)
+ self.gaussian_encoder.encode(mv_y_q_w_0, mv_scales_w_0)
+ self.gaussian_encoder.encode(mv_y_q_w_1, mv_scales_w_1)
+ self.gaussian_encoder.encode(mv_y_q_w_2, mv_scales_w_2)
+ self.gaussian_encoder.encode(mv_y_q_w_3, mv_scales_w_3)
+ self.gaussian_encoder.encode(y_q_w_0, scales_w_0)
+ self.gaussian_encoder.encode(y_q_w_1, scales_w_1)
+ self.gaussian_encoder.encode(y_q_w_2, scales_w_2)
+ self.gaussian_encoder.encode(y_q_w_3, scales_w_3)
+ self.entropy_coder.flush()
+
+ x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec)
+ bit_stream = self.entropy_coder.get_encoded_stream()
+
+ result = {
+ "dpb": {
+ "ref_frame": x_hat,
+ "ref_feature": feature,
+ "ref_mv_feature": mv_feature,
+ "ref_y": y_hat,
+ "ref_mv_y": mv_y_hat,
+ },
+ "bit_stream": bit_stream,
+ }
+ return result
+
+ def decompress(self, bit_stream, dpb, sps):
+ dtype = next(self.parameters()).dtype
+ device = next(self.parameters()).device
+ torch.cuda.synchronize(device=device)
+ t0 = time.time()
+ _, mv_y_q_dec, _, y_q_dec = self.get_all_q(sps['qp'])
+
+ if bit_stream is not None:
+ self.entropy_coder.set_stream(bit_stream)
+ z_size = get_downsampled_shape(sps['height'], sps['width'], 64)
+ y_height, y_width = get_downsampled_shape(sps['height'], sps['width'], 16)
+ slice_shape = self.get_to_y_slice_shape(y_height, y_width)
+ mv_z_hat = self.bit_estimator_z_mv.decode_stream(z_size, dtype, device, 0)
+ z_hat = self.bit_estimator_z.decode_stream(z_size, dtype, device, 0)
+ mv_params = self.mv_prior_param_decoder(mv_z_hat, dpb, slice_shape)
+ mv_y_hat = self.decompress_four_part_prior(mv_params,
+ self.mv_y_spatial_prior_adaptor_1,
+ self.mv_y_spatial_prior_adaptor_2,
+ self.mv_y_spatial_prior_adaptor_3,
+ self.mv_y_spatial_prior)
+
+ mv_hat, mv_feature = self.mv_decoder(mv_y_hat, mv_y_q_dec)
+ context1, context2, context3, _ = self.motion_compensation(dpb, mv_hat, sps['fa_idx'])
+
+ params = self.contextual_prior_param_decoder(z_hat, dpb, context3, slice_shape)
+ y_hat = self.decompress_four_part_prior(params,
+ self.y_spatial_prior_adaptor_1,
+ self.y_spatial_prior_adaptor_2,
+ self.y_spatial_prior_adaptor_3,
+ self.y_spatial_prior)
+ x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec)
+
+ torch.cuda.synchronize(device=device)
+ t1 = time.time()
+ return {
+ "dpb": {
+ "ref_frame": x_hat,
+ "ref_feature": feature,
+ "ref_mv_feature": mv_feature,
+ "ref_y": y_hat,
+ "ref_mv_y": mv_y_hat,
+ },
+ "decoding_time": t1 - t0,
+ }
+
+ def encode(self, x, dpb, q_index, fa_idx, sps_id=0, output_file=None):
+ # pic_width and pic_height may be different from x's size. x here is after padding
+ # x_hat has the same size with x
+ if output_file is None:
+ encoded = self.forward_one_frame(x, dpb, q_index=q_index, fa_idx=fa_idx)
+ result = {
+ "dpb": encoded['dpb'],
+ "bit": encoded['bit'].item(),
+ }
+ return result
+
+ device = x.device
+ torch.cuda.synchronize(device=device)
+ t0 = time.time()
+ encoded = self.compress(x, dpb, q_index, fa_idx)
+ written = write_ip(output_file, False, sps_id, encoded['bit_stream'])
+ torch.cuda.synchronize(device=device)
+ t1 = time.time()
+ result = {
+ "dpb": encoded["dpb"],
+ "bit": written * 8,
+ "encoding_time": t1 - t0,
+ }
+ return result
+
+ def forward_one_frame(self, x, dpb, q_index=None, fa_idx=0):
+ mv_y_q_enc, mv_y_q_dec, y_q_enc, y_q_dec = self.get_all_q(q_index)
+ index = self.get_index_tensor(0, x.device)
+
+ est_mv = self.optic_flow(x, dpb["ref_frame"])
+ mv_y = self.mv_encoder(est_mv, dpb["ref_mv_feature"], mv_y_q_enc)
+
+ mv_y_pad, slice_shape = self.pad_for_y(mv_y)
+ mv_z = self.mv_hyper_prior_encoder(mv_y_pad)
+ mv_z_hat = self.quant(mv_z)
+ mv_params = self.mv_prior_param_decoder(mv_z_hat, dpb, slice_shape)
+ mv_y_res, mv_y_q, mv_y_hat, mv_scales_hat = self.forward_four_part_prior(
+ mv_y, mv_params, self.mv_y_spatial_prior_adaptor_1, self.mv_y_spatial_prior_adaptor_2,
+ self.mv_y_spatial_prior_adaptor_3, self.mv_y_spatial_prior)
+
+ mv_hat, mv_feature = self.mv_decoder(mv_y_hat, mv_y_q_dec)
+
+ context1, context2, context3, _ = self.motion_compensation(dpb, mv_hat, fa_idx)
+
+ y = self.contextual_encoder(x, context1, context2, context3, y_q_enc)
+ y_pad, slice_shape = self.pad_for_y(y)
+ z = self.contextual_hyper_prior_encoder(y_pad)
+ z_hat = self.quant(z)
+ params = self.contextual_prior_param_decoder(z_hat, dpb, context3, slice_shape)
+ y_res, y_q, y_hat, scales_hat = self.forward_four_part_prior(
+ y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
+ self.y_spatial_prior_adaptor_3, self.y_spatial_prior)
+ x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec)
+
+ _, _, H, W = x.size()
+ pixel_num = H * W
+
+ y_for_bit = y_q
+ mv_y_for_bit = mv_y_q
+ z_for_bit = z_hat
+ mv_z_for_bit = mv_z_hat
+ bits_y = self.get_y_laplace_bits(y_for_bit, scales_hat)
+ bits_mv_y = self.get_y_laplace_bits(mv_y_for_bit, mv_scales_hat)
+ bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z, index)
+ bits_mv_z = self.get_z_bits(mv_z_for_bit, self.bit_estimator_z_mv, index)
+
+ bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num
+ bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num
+ bpp_mv_y = torch.sum(bits_mv_y, dim=(1, 2, 3)) / pixel_num
+ bpp_mv_z = torch.sum(bits_mv_z, dim=(1, 2, 3)) / pixel_num
+
+ bpp = bpp_y + bpp_z + bpp_mv_y + bpp_mv_z
+ bit = torch.sum(bpp) * pixel_num
+
+ return {"dpb": {
+ "ref_frame": x_hat,
+ "ref_feature": feature,
+ "ref_mv_feature": mv_feature,
+ "ref_y": y_hat,
+ "ref_mv_y": mv_y_hat,
+ },
+ "bit": bit,
+ }
diff --git a/DCVC-FM/src/models/video_net.py b/DCVC-FM/src/models/video_net.py
new file mode 100644
index 0000000..68e2e3b
--- /dev/null
+++ b/DCVC-FM/src/models/video_net.py
@@ -0,0 +1,209 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .layers import subpel_conv1x1, DepthConvBlock2, DepthConvBlock4
+from .block_mc import block_mc_func
+
+
+def bilinearupsacling(inputfeature):
+ inputheight = inputfeature.size(2)
+ inputwidth = inputfeature.size(3)
+ outfeature = F.interpolate(
+ inputfeature, (inputheight * 2, inputwidth * 2), mode='bilinear', align_corners=False)
+
+ return outfeature
+
+
+def bilineardownsacling(inputfeature):
+ inputheight = inputfeature.size(2)
+ inputwidth = inputfeature.size(3)
+ outfeature = F.interpolate(
+ inputfeature, (inputheight // 2, inputwidth // 2), mode='bilinear', align_corners=False)
+ return outfeature
+
+
+class ResBlock(nn.Module):
+ def __init__(self, channel, slope=0.01, end_with_relu=False,
+ bottleneck=False, inplace=False):
+ super().__init__()
+ in_channel = channel // 2 if bottleneck else channel
+ self.first_layer = nn.LeakyReLU(negative_slope=slope, inplace=False)
+ self.conv1 = nn.Conv2d(channel, in_channel, 3, padding=1)
+ self.relu = nn.LeakyReLU(negative_slope=slope, inplace=inplace)
+ self.conv2 = nn.Conv2d(in_channel, channel, 3, padding=1)
+ self.last_layer = self.relu if end_with_relu else nn.Identity()
+
+ def forward(self, x):
+ identity = x
+ out = self.first_layer(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.last_layer(out)
+ return identity + out
+
+
+class MEBasic(nn.Module):
+ def __init__(self, complexity_level=0):
+ super().__init__()
+ self.relu = nn.ReLU()
+ self.by_pass = False
+ if complexity_level < 0:
+ self.by_pass = True
+ elif complexity_level == 0:
+ self.conv1 = nn.Conv2d(8, 32, 7, 1, padding=3)
+ self.conv2 = nn.Conv2d(32, 64, 7, 1, padding=3)
+ self.conv3 = nn.Conv2d(64, 32, 7, 1, padding=3)
+ self.conv4 = nn.Conv2d(32, 16, 7, 1, padding=3)
+ self.conv5 = nn.Conv2d(16, 2, 7, 1, padding=3)
+ elif complexity_level == 3:
+ self.conv1 = nn.Conv2d(8, 32, 5, 1, padding=2)
+ self.conv2 = nn.Conv2d(32, 64, 5, 1, padding=2)
+ self.conv3 = nn.Conv2d(64, 32, 5, 1, padding=2)
+ self.conv4 = nn.Conv2d(32, 16, 5, 1, padding=2)
+ self.conv5 = nn.Conv2d(16, 2, 5, 1, padding=2)
+
+ def forward(self, x):
+ if self.by_pass:
+ return x[:, -2:, :, :]
+
+ x = self.relu(self.conv1(x))
+ x = self.relu(self.conv2(x))
+ x = self.relu(self.conv3(x))
+ x = self.relu(self.conv4(x))
+ x = self.conv5(x)
+ return x
+
+
+class ME_Spynet(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.me_8x = MEBasic(0)
+ self.me_4x = MEBasic(0)
+ self.me_2x = MEBasic(3)
+ self.me_1x = MEBasic(3)
+
+ def forward(self, im1, im2):
+ batchsize = im1.size()[0]
+
+ im1_1x = im1
+ im1_2x = F.avg_pool2d(im1_1x, kernel_size=2, stride=2)
+ im1_4x = F.avg_pool2d(im1_2x, kernel_size=2, stride=2)
+ im1_8x = F.avg_pool2d(im1_4x, kernel_size=2, stride=2)
+ im2_1x = im2
+ im2_2x = F.avg_pool2d(im2_1x, kernel_size=2, stride=2)
+ im2_4x = F.avg_pool2d(im2_2x, kernel_size=2, stride=2)
+ im2_8x = F.avg_pool2d(im2_4x, kernel_size=2, stride=2)
+
+ shape_fine = im1_8x.size()
+ zero_shape = [batchsize, 2, shape_fine[2], shape_fine[3]]
+ flow_8x = torch.zeros(zero_shape, dtype=im1.dtype, device=im1.device)
+ flow_8x = self.me_8x(torch.cat((im1_8x, im2_8x, flow_8x), dim=1))
+
+ flow_4x = bilinearupsacling(flow_8x) * 2.0
+ flow_4x = flow_4x + self.me_4x(torch.cat((im1_4x,
+ block_mc_func(im2_4x, flow_4x),
+ flow_4x),
+ dim=1))
+
+ flow_2x = bilinearupsacling(flow_4x) * 2.0
+ flow_2x = flow_2x + self.me_2x(torch.cat((im1_2x,
+ block_mc_func(im2_2x, flow_2x),
+ flow_2x),
+ dim=1))
+
+ flow_1x = bilinearupsacling(flow_2x) * 2.0
+ flow_1x = flow_1x + self.me_1x(torch.cat((im1_1x,
+ block_mc_func(im2_1x, flow_1x),
+ flow_1x),
+ dim=1))
+ return flow_1x
+
+
+class UNet(nn.Module):
+ def __init__(self, in_ch=64, out_ch=64, inplace=False):
+ super().__init__()
+ self.conv1 = DepthConvBlock2(in_ch, 32, inplace=inplace)
+ self.down1 = nn.Conv2d(32, 32, 2, stride=2)
+ self.conv2 = DepthConvBlock2(32, 64, inplace=inplace)
+ self.down2 = nn.Conv2d(64, 64, 2, stride=2)
+ self.conv3 = DepthConvBlock2(64, 128, inplace=inplace)
+
+ self.context_refine = nn.Sequential(
+ DepthConvBlock2(128, 128, inplace=inplace),
+ DepthConvBlock2(128, 128, inplace=inplace),
+ DepthConvBlock2(128, 128, inplace=inplace),
+ DepthConvBlock2(128, 128, inplace=inplace),
+ )
+
+ self.up3 = subpel_conv1x1(128, 64, 2)
+ self.up_conv3 = DepthConvBlock2(128, 64, inplace=inplace)
+
+ self.up2 = subpel_conv1x1(64, 32, 2)
+ self.up_conv2 = DepthConvBlock2(64, out_ch, inplace=inplace)
+
+ def forward(self, x):
+ # encoding path
+ x1 = self.conv1(x)
+ x2 = self.down1(x1)
+
+ x2 = self.conv2(x2)
+ x3 = self.down2(x2)
+
+ x3 = self.conv3(x3)
+ x3 = self.context_refine(x3)
+
+ # decoding + concat path
+ d3 = self.up3(x3)
+ d3 = torch.cat((x2, d3), dim=1)
+ d3 = self.up_conv3(d3)
+
+ d2 = self.up2(d3)
+ d2 = torch.cat((x1, d2), dim=1)
+ d2 = self.up_conv2(d2)
+ return d2
+
+
+class UNet2(nn.Module):
+ def __init__(self, in_ch=64, out_ch=64, inplace=False):
+ super().__init__()
+ self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ self.conv1 = DepthConvBlock4(in_ch, 32, inplace=inplace)
+ self.conv2 = DepthConvBlock4(32, 64, inplace=inplace)
+ self.conv3 = DepthConvBlock4(64, 128, inplace=inplace)
+
+ self.context_refine = nn.Sequential(
+ DepthConvBlock4(128, 128, inplace=inplace),
+ DepthConvBlock4(128, 128, inplace=inplace),
+ DepthConvBlock4(128, 128, inplace=inplace),
+ DepthConvBlock4(128, 128, inplace=inplace),
+ )
+
+ self.up3 = subpel_conv1x1(128, 64, 2)
+ self.up_conv3 = DepthConvBlock4(128, 64, inplace=inplace)
+
+ self.up2 = subpel_conv1x1(64, 32, 2)
+ self.up_conv2 = DepthConvBlock4(64, out_ch, inplace=inplace)
+
+ def forward(self, x):
+ # encoding path
+ x1 = self.conv1(x)
+ x2 = self.max_pool(x1)
+
+ x2 = self.conv2(x2)
+ x3 = self.max_pool(x2)
+
+ x3 = self.conv3(x3)
+ x3 = self.context_refine(x3)
+
+ # decoding + concat path
+ d3 = self.up3(x3)
+ d3 = torch.cat((x2, d3), dim=1)
+ d3 = self.up_conv3(d3)
+
+ d2 = self.up2(d3)
+ d2 = torch.cat((x1, d2), dim=1)
+ d2 = self.up_conv2(d2)
+ return d2
diff --git a/DCVC-FM/src/transforms/functional.py b/DCVC-FM/src/transforms/functional.py
new file mode 100644
index 0000000..104554c
--- /dev/null
+++ b/DCVC-FM/src/transforms/functional.py
@@ -0,0 +1,300 @@
+from typing import Tuple, Union
+
+import numpy as np
+import scipy.ndimage
+import torch
+import torch.nn.functional as F
+
+from torch import Tensor
+
+YCBCR_WEIGHTS = {
+ # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
+ "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
+}
+
+
+def rgb_to_ycbcr420(rgb):
+ '''
+ input is 3xhxw RGB float numpy array, in the range of [0, 1]
+ output is y: 1xhxw, uv: 2x(h/2)x(w/2), in the range of [0, 1]
+ '''
+ c, h, w = rgb.shape
+ assert c == 3
+ assert h % 2 == 0
+ assert w % 2 == 0
+ r, g, b = np.split(rgb, 3, axis=0)
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ y = Kr * r + Kg * g + Kb * b
+ cb = 0.5 * (b - y) / (1 - Kb) + 0.5
+ cr = 0.5 * (r - y) / (1 - Kr) + 0.5
+
+ # to 420
+ cb = np.mean(np.reshape(cb, (1, h//2, 2, w//2, 2)), axis=(-1, -3))
+ cr = np.mean(np.reshape(cr, (1, h//2, 2, w//2, 2)), axis=(-1, -3))
+ uv = np.concatenate((cb, cr), axis=0)
+
+ y = np.clip(y, 0., 1.)
+ uv = np.clip(uv, 0., 1.)
+
+ return y, uv
+
+
+def rgb_to_ycbcr444(rgb):
+ '''
+ input is 3xhxw RGB float numpy array, in the range of [0, 1]
+ output is y: 1xhxw, uv: 2xhxw, in the range of [0, 1]
+ '''
+ c, _, _ = rgb.shape
+ assert c == 3
+ r, g, b = np.split(rgb, 3, axis=0)
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ y = Kr * r + Kg * g + Kb * b
+ cb = 0.5 * (b - y) / (1 - Kb) + 0.5
+ cr = 0.5 * (r - y) / (1 - Kr) + 0.5
+ uv = np.concatenate((cb, cr), axis=0)
+
+ y = np.clip(y, 0., 1.)
+ uv = np.clip(uv, 0., 1.)
+
+ return y, uv
+
+
+def ycbcr420_to_rgb(y, uv, order=1):
+ '''
+ y is 1xhxw Y float numpy array, in the range of [0, 1]
+ uv is 2x(h/2)x(w/2) UV float numpy array, in the range of [0, 1]
+ order: 0 nearest neighbor, 1: binear (default)
+ return value is 3xhxw RGB float numpy array, in the range of [0, 1]
+ '''
+ uv = scipy.ndimage.zoom(uv, (1, 2, 2), order=order)
+ cb = uv[0:1, :, :]
+ cr = uv[1:2, :, :]
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ r = y + (2 - 2 * Kr) * (cr - 0.5)
+ b = y + (2 - 2 * Kb) * (cb - 0.5)
+ g = (y - Kr * r - Kb * b) / Kg
+ rgb = np.concatenate((r, g, b), axis=0)
+ rgb = np.clip(rgb, 0., 1.)
+ return rgb
+
+
+def ycbcr444_to_rgb(y, uv):
+ '''
+ y is 1xhxw Y float numpy array, in the range of [0, 1]
+ uv is 2xhxw UV float numpy array, in the range of [0, 1]
+ return value is 3xhxw RGB float numpy array, in the range of [0, 1]
+ '''
+ cb = uv[0:1, :, :]
+ cr = uv[1:2, :, :]
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ r = y + (2 - 2 * Kr) * (cr - 0.5)
+ b = y + (2 - 2 * Kb) * (cb - 0.5)
+ g = (y - Kr * r - Kb * b) / Kg
+ rgb = np.concatenate((r, g, b), axis=0)
+ rgb = np.clip(rgb, 0., 1.)
+ return rgb
+
+
+def ycbcr420_to_444(y, uv, order=0, separate=False):
+ '''
+ y is 1xhxw Y float numpy array, in the range of [0, 1]
+ uv is 2x(h/2)x(w/2) UV float numpy array, in the range of [0, 1]
+ order: 0 nearest neighbor (default), 1: binear
+ return value is 3xhxw YCbCr float numpy array, in the range of [0, 1]
+ '''
+ uv = scipy.ndimage.zoom(uv, (1, 2, 2), order=order)
+ if separate:
+ return y, uv
+ yuv = np.concatenate((y, uv), axis=0)
+ return yuv
+
+
+def ycbcr444_to_420(yuv):
+ '''
+ input is 3xhxw YUV float numpy array, in the range of [0, 1]
+ output is y: 1xhxw, uv: 2x(h/2)x(w/x), in the range of [0, 1]
+ '''
+ c, h, w = yuv.shape
+ assert c == 3
+ assert h % 2 == 0
+ assert w % 2 == 0
+ y, u, v = np.split(yuv, 3, axis=0)
+
+ # to 420
+ u = np.mean(np.reshape(u, (1, h//2, 2, w//2, 2)), axis=(-1, -3))
+ v = np.mean(np.reshape(v, (1, h//2, 2, w//2, 2)), axis=(-1, -3))
+ uv = np.concatenate((u, v), axis=0)
+
+ y = np.clip(y, 0., 1.)
+ uv = np.clip(uv, 0., 1.)
+
+ return y, uv
+
+
+def rgb_to_ycbcr(rgb):
+ '''
+ input is 3xhxw RGB float numpy array, in the range of [0, 1]
+ output is yuv: 3xhxw, in the range of [0, 1]
+ '''
+ c, h, w = rgb.shape
+ assert c == 3
+ r, g, b = np.split(rgb, 3, axis=0)
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ y = Kr * r + Kg * g + Kb * b
+ cb = 0.5 * (b - y) / (1 - Kb) + 0.5
+ cr = 0.5 * (r - y) / (1 - Kr) + 0.5
+
+ yuv = np.concatenate((y, cb, cr), axis=0)
+ yuv = np.clip(yuv, 0., 1.)
+
+ return yuv
+
+
+def ycbcr_to_rgb(yuv):
+ '''
+ yuv is 3xhxw YCbCr float numpy array, in the range of [0, 1]
+ return value is 3xhxw RGB float numpy array, in the range of [0, 1]
+ '''
+ y, cb, cr = np.split(yuv, 3, axis=0)
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ r = y + (2 - 2 * Kr) * (cr - 0.5)
+ b = y + (2 - 2 * Kb) * (cb - 0.5)
+ g = (y - Kr * r - Kb * b) / Kg
+ rgb = np.concatenate((r, g, b), axis=0)
+ rgb = np.clip(rgb, 0., 1.)
+ return rgb
+
+
+def _check_input_tensor(tensor: Tensor) -> None:
+ if (
+ not isinstance(tensor, Tensor)
+ or not tensor.is_floating_point()
+ or not len(tensor.size()) in (3, 4)
+ or not tensor.size(-3) == 3
+ ):
+ raise ValueError(
+ "Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
+ )
+
+
+def rgb2ycbcr(rgb: Tensor) -> Tensor:
+ """RGB to YCbCr conversion for torch Tensor.
+ Using ITU-R BT.709 coefficients.
+
+ Args:
+ rgb (torch.Tensor): 3D or 4D floating point RGB tensor
+
+ Returns:
+ ycbcr (torch.Tensor): converted tensor
+ """
+ _check_input_tensor(rgb)
+
+ r, g, b = rgb.chunk(3, -3)
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ y = Kr * r + Kg * g + Kb * b
+ cb = 0.5 * (b - y) / (1 - Kb) + 0.5
+ cr = 0.5 * (r - y) / (1 - Kr) + 0.5
+ ycbcr = torch.cat((y, cb, cr), dim=-3)
+ ycbcr = torch.clamp(ycbcr, 0., 1.)
+ return ycbcr
+
+
+def down_and_upsample(yuv: Tensor) -> Tensor:
+ y, u, v = yuv.chunk(3, 1)
+ u = F.avg_pool2d(u, kernel_size=2, stride=2)
+ u = F.interpolate(u, scale_factor=2, mode='nearest')
+ v = F.avg_pool2d(v, kernel_size=2, stride=2)
+ v = F.interpolate(v, scale_factor=2, mode='nearest')
+ return torch.cat((y, u, v), dim=1)
+
+
+def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
+ """YCbCr to RGB conversion for torch Tensor.
+ Using ITU-R BT.709 coefficients.
+
+ Args:
+ ycbcr (torch.Tensor): 3D or 4D floating point RGB tensor
+
+ Returns:
+ rgb (torch.Tensor): converted tensor
+ """
+ _check_input_tensor(ycbcr)
+
+ y, cb, cr = ycbcr.chunk(3, -3)
+ Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
+ r = y + (2 - 2 * Kr) * (cr - 0.5)
+ b = y + (2 - 2 * Kb) * (cb - 0.5)
+ g = (y - Kr * r - Kb * b) / Kg
+ rgb = torch.cat((r, g, b), dim=-3)
+ rgb = torch.clamp(rgb, 0., 1.)
+ return rgb
+
+
+def yuv_444_to_420(
+ yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
+ mode: str = "avg_pool",
+) -> Tuple[Tensor, Tensor, Tensor]:
+ """Convert a 444 tensor to a 420 representation.
+
+ Args:
+ yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
+ input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
+ of 3 (Nx1xHxW) tensors.
+ mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
+ ``'avg_pool'``
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
+ """
+ if mode not in ("avg_pool",):
+ raise ValueError(f'Invalid downsampling mode "{mode}".')
+
+ if mode == "avg_pool":
+
+ def _downsample(tensor):
+ return F.avg_pool2d(tensor, kernel_size=2, stride=2)
+
+ if isinstance(yuv, torch.Tensor):
+ y, u, v = yuv.chunk(3, 1)
+ else:
+ y, u, v = yuv
+
+ return (y, _downsample(u), _downsample(v))
+
+
+def yuv_420_to_444(
+ yuv: Tuple[Tensor, Tensor, Tensor],
+ mode: str = "bilinear",
+ return_tuple: bool = False,
+) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
+ """Convert a 420 input to a 444 representation.
+
+ Args:
+ yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
+ (Nx1xHxW) format
+ mode (str): algorithm used for upsampling: ``'bilinear'`` |
+ ``'nearest'`` Default ``'bilinear'``
+ return_tuple (bool): return input as tuple of tensors instead of a
+ concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
+ tensor (default: False)
+
+ Returns:
+ (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
+ 444
+ """
+ if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv):
+ raise ValueError("Expected a tuple of 3 torch tensors")
+
+ if mode not in ("bilinear", "nearest"):
+ raise ValueError(f'Invalid upsampling mode "{mode}".')
+
+ if mode in ("bilinear", "nearest"):
+
+ def _upsample(tensor):
+ return F.interpolate(tensor, scale_factor=2, mode=mode, align_corners=False)
+
+ y, u, v = yuv
+ u, v = _upsample(u), _upsample(v)
+ if return_tuple:
+ return y, u, v
+ return torch.cat((y, u, v), dim=1)
diff --git a/DCVC-FM/src/transforms/transforms.py b/DCVC-FM/src/transforms/transforms.py
new file mode 100644
index 0000000..a9f7e59
--- /dev/null
+++ b/DCVC-FM/src/transforms/transforms.py
@@ -0,0 +1,118 @@
+from . import functional as F_transforms
+
+__all__ = [
+ "RGB2YCbCr",
+ "YCbCr2RGB",
+ "YUV444To420",
+ "YUV420To444",
+]
+
+
+class RGB2YCbCr:
+ """Convert a RGB tensor to YCbCr.
+ The tensor is expected to be in the [0, 1] floating point range, with a
+ shape of (3xHxW) or (Nx3xHxW).
+ """
+
+ def __call__(self, rgb):
+ """
+ Args:
+ rgb (torch.Tensor): 3D or 4D floating point RGB tensor
+
+ Returns:
+ ycbcr(torch.Tensor): converted tensor
+ """
+ return F_transforms.rgb2ycbcr(rgb)
+
+ def ___repr__(self):
+ return f"{self.__class__.__name__}()"
+
+
+class YCbCr2RGB:
+ """Convert a YCbCr tensor to RGB.
+ The tensor is expected to be in the [0, 1] floating point range, with a
+ shape of (3xHxW) or (Nx3xHxW).
+ """
+
+ def __call__(self, ycbcr):
+ """
+ Args:
+ ycbcr(torch.Tensor): 3D or 4D floating point RGB tensor
+
+ Returns:
+ rgb(torch.Tensor): converted tensor
+ """
+ return F_transforms.ycbcr2rgb(ycbcr)
+
+ def ___repr__(self):
+ return f"{self.__class__.__name__}()"
+
+
+class YUV444To420:
+ """Convert a YUV 444 tensor to a 420 representation.
+
+ Args:
+ mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
+ ``'avg_pool'``
+
+ Example:
+ >>> x = torch.rand(1, 3, 32, 32)
+ >>> y, u, v = YUV444To420()(x)
+ >>> y.size() # 1, 1, 32, 32
+ >>> u.size() # 1, 1, 16, 16
+ """
+
+ def __init__(self, mode: str = "avg_pool"):
+ self.mode = str(mode)
+
+ def __call__(self, yuv):
+ """
+ Args:
+ yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
+ 444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
+ a tuple of 3 (Nx1xHxW) tensors.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
+ """
+ return F_transforms.yuv_444_to_420(yuv, mode=self.mode)
+
+ def ___repr__(self):
+ return f"{self.__class__.__name__}()"
+
+
+class YUV420To444:
+ """Convert a YUV 420 input to a 444 representation.
+
+ Args:
+ mode (str): algorithm used for upsampling: ``'bilinear'`` | ``'nearest'``.
+ Default ``'bilinear'``
+ return_tuple (bool): return input as tuple of tensors instead of a
+ concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
+ tensor (default: False)
+
+ Example:
+ >>> y = torch.rand(1, 1, 32, 32)
+ >>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
+ >>> x = YUV420To444()((y, u, v))
+ >>> x.size() # 1, 3, 32, 32
+ """
+
+ def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
+ self.mode = str(mode)
+ self.return_tuple = bool(return_tuple)
+
+ def __call__(self, yuv):
+ """
+ Args:
+ yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
+ (Nx1xHxW) format
+
+ Returns:
+ (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
+ 444
+ """
+ return F_transforms.yuv_420_to_444(yuv, return_tuple=self.return_tuple)
+
+ def ___repr__(self):
+ return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"
diff --git a/DCVC-FM/src/utils/common.py b/DCVC-FM/src/utils/common.py
new file mode 100644
index 0000000..658cdb5
--- /dev/null
+++ b/DCVC-FM/src/utils/common.py
@@ -0,0 +1,148 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import json
+import os
+from unittest.mock import patch
+
+import numpy as np
+
+
+
+def str2bool(v):
+ return str(v).lower() in ("yes", "y", "true", "t", "1")
+
+
+def create_folder(path, print_if_create=False):
+ if not os.path.exists(path):
+ os.makedirs(path)
+ if print_if_create:
+ print(f"created folder: {path}")
+
+
+@patch('json.encoder.c_make_encoder', None)
+def dump_json(obj, fid, float_digits=-1, **kwargs):
+ of = json.encoder._make_iterencode # pylint: disable=W0212
+
+ def inner(*args, **kwargs):
+ args = list(args)
+ # fifth argument is float formater which we will replace
+ args[4] = lambda o: format(o, '.%df' % float_digits)
+ return of(*args, **kwargs)
+
+ with patch('json.encoder._make_iterencode', wraps=inner):
+ json.dump(obj, fid, **kwargs)
+
+
+def generate_log_json(frame_num, frame_pixel_num, test_time, frame_types, bits, psnrs, ssims,
+ verbose=False):
+ include_yuv = len(psnrs[0]) > 1
+ assert not include_yuv or (len(psnrs[0]) == 4 and len(ssims[0]) == 4)
+ i_bits = 0
+ i_psnr = 0
+ i_psnr_y = 0
+ i_psnr_u = 0
+ i_psnr_v = 0
+ i_ssim = 0
+ i_ssim_y = 0
+ i_ssim_u = 0
+ i_ssim_v = 0
+ p_bits = 0
+ p_psnr = 0
+ p_psnr_y = 0
+ p_psnr_u = 0
+ p_psnr_v = 0
+ p_ssim = 0
+ p_ssim_y = 0
+ p_ssim_u = 0
+ p_ssim_v = 0
+ i_num = 0
+ p_num = 0
+ for idx in range(frame_num):
+ if frame_types[idx] == 0:
+ i_bits += bits[idx]
+ i_psnr += psnrs[idx][0]
+ i_ssim += ssims[idx][0]
+ i_num += 1
+ if include_yuv:
+ i_psnr_y += psnrs[idx][1]
+ i_psnr_u += psnrs[idx][2]
+ i_psnr_v += psnrs[idx][3]
+ i_ssim_y += ssims[idx][1]
+ i_ssim_u += ssims[idx][2]
+ i_ssim_v += ssims[idx][3]
+ else:
+ p_bits += bits[idx]
+ p_psnr += psnrs[idx][0]
+ p_ssim += ssims[idx][0]
+ p_num += 1
+ if include_yuv:
+ p_psnr_y += psnrs[idx][1]
+ p_psnr_u += psnrs[idx][2]
+ p_psnr_v += psnrs[idx][3]
+ p_ssim_y += ssims[idx][1]
+ p_ssim_u += ssims[idx][2]
+ p_ssim_v += ssims[idx][3]
+
+ log_result = {}
+ log_result['frame_pixel_num'] = frame_pixel_num
+ log_result['i_frame_num'] = i_num
+ log_result['p_frame_num'] = p_num
+ log_result['ave_i_frame_bpp'] = i_bits / i_num / frame_pixel_num
+ log_result['ave_i_frame_psnr'] = i_psnr / i_num
+ log_result['ave_i_frame_msssim'] = i_ssim / i_num
+ if include_yuv:
+ log_result['ave_i_frame_psnr_y'] = i_psnr_y / i_num
+ log_result['ave_i_frame_psnr_u'] = i_psnr_u / i_num
+ log_result['ave_i_frame_psnr_v'] = i_psnr_v / i_num
+ log_result['ave_i_frame_msssim_y'] = i_ssim_y / i_num
+ log_result['ave_i_frame_msssim_u'] = i_ssim_u / i_num
+ log_result['ave_i_frame_msssim_v'] = i_ssim_v / i_num
+ if verbose:
+ log_result['frame_bpp'] = list(np.array(bits) / frame_pixel_num)
+ log_result['frame_psnr'] = [v[0] for v in psnrs]
+ log_result['frame_msssim'] = [v[0] for v in ssims]
+ log_result['frame_type'] = frame_types
+ if include_yuv:
+ log_result['frame_psnr_y'] = [v[1] for v in psnrs]
+ log_result['frame_psnr_u'] = [v[2] for v in psnrs]
+ log_result['frame_psnr_v'] = [v[3] for v in psnrs]
+ log_result['frame_msssim_y'] = [v[1] for v in ssims]
+ log_result['frame_msssim_u'] = [v[2] for v in ssims]
+ log_result['frame_msssim_v'] = [v[3] for v in ssims]
+ log_result['test_time'] = test_time
+ if p_num > 0:
+ total_p_pixel_num = p_num * frame_pixel_num
+ log_result['ave_p_frame_bpp'] = p_bits / total_p_pixel_num
+ log_result['ave_p_frame_psnr'] = p_psnr / p_num
+ log_result['ave_p_frame_msssim'] = p_ssim / p_num
+ if include_yuv:
+ log_result['ave_p_frame_psnr_y'] = p_psnr_y / p_num
+ log_result['ave_p_frame_psnr_u'] = p_psnr_u / p_num
+ log_result['ave_p_frame_psnr_v'] = p_psnr_v / p_num
+ log_result['ave_p_frame_msssim_y'] = p_ssim_y / p_num
+ log_result['ave_p_frame_msssim_u'] = p_ssim_u / p_num
+ log_result['ave_p_frame_msssim_v'] = p_ssim_v / p_num
+ else:
+ log_result['ave_p_frame_bpp'] = 0
+ log_result['ave_p_frame_psnr'] = 0
+ log_result['ave_p_frame_msssim'] = 0
+ if include_yuv:
+ log_result['ave_p_frame_psnr_y'] = 0
+ log_result['ave_p_frame_psnr_u'] = 0
+ log_result['ave_p_frame_psnr_v'] = 0
+ log_result['ave_p_frame_msssim_y'] = 0
+ log_result['ave_p_frame_msssim_u'] = 0
+ log_result['ave_p_frame_msssim_v'] = 0
+ log_result['ave_all_frame_bpp'] = (i_bits + p_bits) / (frame_num * frame_pixel_num)
+ log_result['ave_all_frame_psnr'] = (i_psnr + p_psnr) / frame_num
+ log_result['ave_all_frame_msssim'] = (i_ssim + p_ssim) / frame_num
+ if include_yuv:
+ log_result['ave_all_frame_psnr_y'] = (i_psnr_y + p_psnr_y) / frame_num
+ log_result['ave_all_frame_psnr_u'] = (i_psnr_u + p_psnr_u) / frame_num
+ log_result['ave_all_frame_psnr_v'] = (i_psnr_v + p_psnr_v) / frame_num
+ log_result['ave_all_frame_msssim_y'] = (i_ssim_y + p_ssim_y) / frame_num
+ log_result['ave_all_frame_msssim_u'] = (i_ssim_u + p_ssim_u) / frame_num
+ log_result['ave_all_frame_msssim_v'] = (i_ssim_v + p_ssim_v) / frame_num
+
+ return log_result
diff --git a/DCVC-FM/src/utils/metrics.py b/DCVC-FM/src/utils/metrics.py
new file mode 100644
index 0000000..828d230
--- /dev/null
+++ b/DCVC-FM/src/utils/metrics.py
@@ -0,0 +1,94 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+from scipy import signal
+from scipy import ndimage
+
+
+def fspecial_gauss(size, sigma):
+ x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
+ g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
+ return g / g.sum()
+
+
+def calc_ssim(img1, img2, data_range=255):
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ size = 11
+ sigma = 1.5
+ window = fspecial_gauss(size, sigma)
+ K1 = 0.01
+ K2 = 0.03
+ C1 = (K1 * data_range)**2
+ C2 = (K2 * data_range)**2
+ mu1 = signal.fftconvolve(window, img1, mode='valid')
+ mu2 = signal.fftconvolve(window, img2, mode='valid')
+ mu1_sq = mu1 * mu1
+ mu2_sq = mu2 * mu2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = signal.fftconvolve(window, img1 * img1, mode='valid') - mu1_sq
+ sigma2_sq = signal.fftconvolve(window, img2 * img2, mode='valid') - mu2_sq
+ sigma12 = signal.fftconvolve(window, img1 * img2, mode='valid') - mu1_mu2
+
+ return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2)),
+ (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
+
+
+def calc_msssim(img1, img2, data_range=255):
+ '''
+ img1 and img2 are 2D arrays
+ '''
+ level = 5
+ weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
+ height, width = img1.shape
+ if height < 176 or width < 176:
+ # according to HM implementation
+ level = 4
+ weight = np.array([0.0517, 0.3295, 0.3462, 0.2726])
+ if height < 88 or width < 88:
+ assert False
+ downsample_filter = np.ones((2, 2)) / 4.0
+ im1 = img1.astype(np.float64)
+ im2 = img2.astype(np.float64)
+ mssim = np.array([])
+ mcs = np.array([])
+ for _ in range(level):
+ ssim_map, cs_map = calc_ssim(im1, im2, data_range=data_range)
+ mssim = np.append(mssim, ssim_map.mean())
+ mcs = np.append(mcs, cs_map.mean())
+ filtered_im1 = ndimage.filters.convolve(im1, downsample_filter,
+ mode='reflect')
+ filtered_im2 = ndimage.filters.convolve(im2, downsample_filter,
+ mode='reflect')
+ im1 = filtered_im1[::2, ::2]
+ im2 = filtered_im2[::2, ::2]
+ return (np.prod(mcs[0:level - 1]**weight[0:level - 1]) *
+ (mssim[level - 1]**weight[level - 1]))
+
+
+def calc_msssim_rgb(img1, img2, data_range=255):
+ '''
+ img1 and img2 are arrays with 3xHxW
+ '''
+ msssim = 0
+ for i in range(3):
+ msssim += calc_msssim(img1[i, :, :], img2[i, :, :], data_range)
+ return msssim / 3
+
+
+def calc_psnr(img1, img2, data_range=255):
+ '''
+ img1 and img2 are arrays with same shape
+ '''
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean(np.square(img1 - img2))
+ if np.isnan(mse) or np.isinf(mse):
+ return -999.9
+ if mse > 1e-10:
+ psnr = 10 * np.log10(data_range * data_range / mse)
+ else:
+ psnr = 999.9
+ return psnr
diff --git a/DCVC-FM/src/utils/stream_helper.py b/DCVC-FM/src/utils/stream_helper.py
new file mode 100644
index 0000000..ff2b815
--- /dev/null
+++ b/DCVC-FM/src/utils/stream_helper.py
@@ -0,0 +1,249 @@
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import struct
+from pathlib import Path
+
+import torch
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+
+def get_padding_size(height, width, p=64):
+ new_h = (height + p - 1) // p * p
+ new_w = (width + p - 1) // p * p
+ # padding_left = (new_w - width) // 2
+ padding_left = 0
+ padding_right = new_w - width - padding_left
+ # padding_top = (new_h - height) // 2
+ padding_top = 0
+ padding_bottom = new_h - height - padding_top
+ return padding_left, padding_right, padding_top, padding_bottom
+
+
+def get_downsampled_shape(height, width, p):
+ new_h = (height + p - 1) // p * p
+ new_w = (width + p - 1) // p * p
+ return int(new_h / p + 0.5), int(new_w / p + 0.5)
+
+
+def get_state_dict(ckpt_path):
+ ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
+ if "state_dict" in ckpt:
+ ckpt = ckpt['state_dict']
+ if "net" in ckpt:
+ ckpt = ckpt["net"]
+ consume_prefix_in_state_dict_if_present(ckpt, prefix="module.")
+ return ckpt
+
+
+def filesize(filepath: str) -> int:
+ if not Path(filepath).is_file():
+ raise ValueError(f'Invalid file "{filepath}".')
+ return Path(filepath).stat().st_size
+
+
+def write_uints(fd, values, fmt=">{:d}I"):
+ fd.write(struct.pack(fmt.format(len(values)), *values))
+ return len(values) * 4
+
+
+def write_uchars(fd, values, fmt=">{:d}B"):
+ fd.write(struct.pack(fmt.format(len(values)), *values))
+ return len(values)
+
+
+def read_uints(fd, n, fmt=">{:d}I"):
+ sz = struct.calcsize("I")
+ return struct.unpack(fmt.format(n), fd.read(n * sz))
+
+
+def read_uchars(fd, n, fmt=">{:d}B"):
+ sz = struct.calcsize("B")
+ return struct.unpack(fmt.format(n), fd.read(n * sz))
+
+
+def write_bytes(fd, values, fmt=">{:d}s"):
+ if len(values) == 0:
+ return 0
+ fd.write(struct.pack(fmt.format(len(values)), values))
+ return len(values)
+
+
+def read_bytes(fd, n, fmt=">{:d}s"):
+ sz = struct.calcsize("s")
+ return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
+
+
+def write_ushorts(fd, values, fmt=">{:d}H"):
+ fd.write(struct.pack(fmt.format(len(values)), *values))
+ return len(values) * 2
+
+
+def read_ushorts(fd, n, fmt=">{:d}H"):
+ sz = struct.calcsize("H")
+ return struct.unpack(fmt.format(n), fd.read(n * sz))
+
+
+def write_uint_adaptive(f, a):
+ if a <= 32767:
+ a0 = a & 0xff
+ a1 = a >> 8
+ write_uchars(f, (a1, a0))
+ return 2
+
+ assert a < (1 << 30)
+ a0 = a & 0xff
+ a1 = (a >> 8) & 0xff
+ a2 = (a >> 16) & 0xff
+ a3 = (a >> 24) & 0xff
+ a3 = a3 | (1 << 7)
+ write_uchars(f, (a3, a2, a1, a0))
+ return 4
+
+
+def read_uint_adaptive(f):
+ a3 = read_uchars(f, 1)[0]
+ a2 = read_uchars(f, 1)[0]
+
+ if (a3 >> 7) == 0:
+ return (a3 << 8) + a2
+ a3 = a3 & 0x7f
+ a1 = read_uchars(f, 1)[0]
+ a0 = read_uchars(f, 1)[0]
+ return (a3 << 24) + (a2 << 16) + (a1 << 8) + a0
+
+
+class NalType(enum.IntEnum):
+ NAL_SPS = 0
+ NAL_I = 1
+ NAL_P = 2
+ NAL_Ps = 3
+
+
+class SPSHelper():
+ def __init__(self):
+ super().__init__()
+ self.spss = []
+
+ def get_sps_id(self, target_sps):
+ min_id = -1
+ for sps in self.spss:
+ if sps['height'] == target_sps['height'] and sps['width'] == target_sps['width'] and \
+ sps['qp'] == target_sps['qp'] and sps['fa_idx'] == target_sps['fa_idx']:
+ return sps['sps_id'], False
+ if sps['sps_id'] > min_id:
+ min_id = sps['sps_id']
+ assert min_id < 15
+ sps = target_sps.copy()
+ sps['sps_id'] = min_id + 1
+ self.spss.append(sps)
+ return sps['sps_id'], True
+
+ def add_sps_by_id(self, sps):
+ for i in range(len(self.spss)):
+ if self.spss[i]['sps_id'] == sps['sps_id']:
+ self.spss[i] = sps.copy()
+ return
+ self.spss.append(sps.copy())
+
+ def get_sps_by_id(self, sps_id):
+ for sps in self.spss:
+ if sps['sps_id'] == sps_id:
+ return sps
+ return None
+
+
+def write_sps(f, sps):
+ # nal_type(4), sps_id(4)
+ # height (variable)
+ # width (vairable)
+ # qp(6), fa_idx(2)
+ assert sps['sps_id'] < 16
+ assert sps['qp'] < 64
+ assert sps['fa_idx'] < 4
+ written = 0
+ flag = int((NalType.NAL_SPS << 4) + sps['sps_id'])
+ written += write_uchars(f, (flag,))
+ written += write_uint_adaptive(f, sps['height'])
+ written += write_uint_adaptive(f, sps['width'])
+ flag = (sps['qp'] << 2) + sps['fa_idx']
+ written += write_uchars(f, (flag,))
+ return written
+
+
+def read_header(f):
+ header = {}
+ flag = read_uchars(f, 1)[0]
+ nal_type = flag >> 4
+ header['nal_type'] = NalType(nal_type)
+ if nal_type < 3:
+ header['sps_id'] = flag & 0x0f
+ return header
+
+ frame_num_minus1 = flag & 0x0f
+ frame_num = frame_num_minus1 + 1
+ header['frame_num'] = frame_num
+ sps_ids = []
+ for _ in range(0, frame_num, 2):
+ flag = read_uchars(f, 1)[0]
+ sps_ids.append(flag >> 4)
+ sps_ids.append(flag & 0x0f)
+ sps_ids = sps_ids[:frame_num]
+ header['sps_ids'] = sps_ids
+ return header
+
+
+def read_sps_remaining(f, sps_id):
+ sps = {}
+ sps['sps_id'] = sps_id
+ sps['height'] = read_uint_adaptive(f)
+ sps['width'] = read_uint_adaptive(f)
+ flag = read_uchars(f, 1)[0]
+ sps['qp'] = flag >> 2
+ sps['fa_idx'] = flag & 0x03
+ return sps
+
+
+def write_ip(f, is_i_frame, sps_id, bit_stream):
+ written = 0
+ flag = (int(NalType.NAL_I if is_i_frame else NalType.NAL_P) << 4) + sps_id
+ written += write_uchars(f, (flag,))
+ # we write all the streams in the same file, thus, we need to write the per-frame length
+ # if packed independently, we do not need to write it
+ written += write_uint_adaptive(f, len(bit_stream))
+ written += write_bytes(f, bit_stream)
+ return written
+
+
+def read_ip_remaining(f):
+ stream_length = read_uint_adaptive(f)
+ bit_stream = read_bytes(f, stream_length)
+ return bit_stream
+
+
+def write_p_frames(f, sps_ids, bit_stream):
+ frame_num_minus1 = len(sps_ids) - 1
+ assert frame_num_minus1 < 16
+ written = 0
+ flag = (int(NalType.NAL_Ps) << 4) + frame_num_minus1
+ written += write_uchars(f, (flag,))
+ if len(sps_ids) % 2 == 1:
+ sps_ids.append(0)
+ for i in range(0, len(sps_ids), 2):
+ flag = (sps_ids[i] << 4) + sps_ids[i+1]
+ written += write_uchars(f, (flag,))
+ written += write_uint_adaptive(f, len(bit_stream))
+ written += write_bytes(f, bit_stream)
+ return written
diff --git a/DCVC-FM/src/utils/test_helper.py b/DCVC-FM/src/utils/test_helper.py
new file mode 100644
index 0000000..f11a40b
--- /dev/null
+++ b/DCVC-FM/src/utils/test_helper.py
@@ -0,0 +1,486 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import argparse
+import json
+import multiprocessing
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+from src.models.video_model import DMC
+from src.models.image_model import DMCI
+
+from src.utils.common import str2bool, create_folder, generate_log_json
+from src.utils.stream_helper import get_padding_size, get_state_dict, SPSHelper, NalType, \
+ write_sps, read_header, read_sps_remaining, read_ip_remaining
+from src.utils.video_reader import PNGReader, YUVReader
+from src.utils.video_writer import PNGWriter, YUVWriter
+from src.utils.metrics import calc_psnr, calc_msssim, calc_msssim_rgb
+from src.transforms.functional import ycbcr444_to_420, ycbcr420_to_444, \
+ rgb_to_ycbcr444, ycbcr444_to_rgb
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Example testing script")
+
+ parser.add_argument("--ec_thread", type=str2bool, default=False)
+ parser.add_argument("--stream_part_i", type=int, default=1)
+ parser.add_argument("--stream_part_p", type=int, default=1)
+ parser.add_argument('--model_path_i', type=str)
+ parser.add_argument('--model_path_p', type=str)
+ parser.add_argument('--rate_num', type=int, default=4)
+ parser.add_argument('--q_indexes_i', type=int, nargs="+")
+ parser.add_argument('--q_indexes_p', type=int, nargs="+")
+ parser.add_argument("--force_intra", type=str2bool, default=False)
+ parser.add_argument("--force_frame_num", type=int, default=-1)
+ parser.add_argument("--force_intra_period", type=int, default=-1)
+ parser.add_argument("--rate_gop_size", type=int, default=8, choices=[4, 8])
+ parser.add_argument('--reset_interval', type=int, default=32, required=False)
+ parser.add_argument('--test_config', type=str, required=True)
+ parser.add_argument('--force_root_path', type=str, default=None, required=False)
+ parser.add_argument("--worker", "-w", type=int, default=1, help="worker number")
+ parser.add_argument('--float16', type=str2bool, default=False)
+ parser.add_argument("--cuda", type=str2bool, default=False)
+ parser.add_argument('--cuda_idx', type=int, nargs="+", help='GPU indexes to use')
+ parser.add_argument('--calc_ssim', type=str2bool, default=False, required=False)
+ parser.add_argument('--write_stream', type=str2bool, default=False)
+ parser.add_argument('--stream_path', type=str, default="out_bin")
+ parser.add_argument('--save_decoded_frame', type=str2bool, default=False)
+ parser.add_argument('--output_path', type=str, required=True)
+ parser.add_argument('--verbose_json', type=str2bool, default=False)
+ parser.add_argument('--verbose', type=int, default=0)
+
+ args = parser.parse_args()
+ return args
+
+
+def np_image_to_tensor(img):
+ image = torch.from_numpy(img).type(torch.FloatTensor)
+ image = image.unsqueeze(0)
+ return image
+
+
+def get_src_reader(args):
+ if args['src_type'] == 'png':
+ src_reader = PNGReader(args['src_path'], args['src_width'], args['src_height'])
+ elif args['src_type'] == 'yuv420':
+ src_reader = YUVReader(args['src_path'], args['src_width'], args['src_height'])
+ return src_reader
+
+
+def get_src_frame(args, src_reader, device):
+ if args['src_type'] == 'yuv420':
+ y, uv = src_reader.read_one_frame(dst_format="420")
+ yuv = ycbcr420_to_444(y, uv)
+ x = np_image_to_tensor(yuv)
+ y = y[0, :, :]
+ u = uv[0, :, :]
+ v = uv[1, :, :]
+ rgb = None
+ else:
+ assert args['src_type'] == 'png'
+ rgb = src_reader.read_one_frame(dst_format="rgb")
+ y, uv = rgb_to_ycbcr444(rgb)
+ u, v = None, None
+ yuv = np.concatenate((y, uv), axis=0)
+ x = np_image_to_tensor(yuv)
+
+ if args['float16']:
+ x = x.to(torch.float16)
+ x = x.to(device)
+ return x, y, u, v, rgb
+
+
+def get_distortion(args, x_hat, y, u, v, rgb):
+ if args['src_type'] == 'yuv420':
+ yuv_rec = x_hat.squeeze(0).cpu().numpy()
+ y_rec, uv_rec = ycbcr444_to_420(yuv_rec)
+ y_rec = y_rec[0, :, :]
+ u_rec = uv_rec[0, :, :]
+ v_rec = uv_rec[1, :, :]
+ psnr_y = calc_psnr(y, y_rec, data_range=1)
+ psnr_u = calc_psnr(u, u_rec, data_range=1)
+ psnr_v = calc_psnr(v, v_rec, data_range=1)
+ psnr = (6 * psnr_y + psnr_u + psnr_v) / 8
+ if args['calc_ssim']:
+ ssim_y = calc_msssim(y, y_rec, data_range=1)
+ ssim_u = calc_msssim(u, u_rec, data_range=1)
+ ssim_v = calc_msssim(v, v_rec, data_range=1)
+ else:
+ ssim_y, ssim_u, ssim_v = 0., 0., 0.
+ ssim = (6 * ssim_y + ssim_u + ssim_v) / 8
+
+ curr_psnr = [psnr, psnr_y, psnr_u, psnr_v]
+ curr_ssim = [ssim, ssim_y, ssim_u, ssim_v]
+ else:
+ assert args['src_type'] == 'png'
+ yuv_rec = x_hat.squeeze(0).cpu().numpy()
+ rgb_rec = ycbcr444_to_rgb(yuv_rec[:1, :, :], yuv_rec[1:, :, :])
+ psnr = calc_psnr(rgb, rgb_rec, data_range=1)
+ if args['calc_ssim']:
+ msssim = calc_msssim_rgb(rgb, rgb_rec, data_range=1)
+ else:
+ msssim = 0.
+ curr_psnr = [psnr]
+ curr_ssim = [msssim]
+ return curr_psnr, curr_ssim
+
+
+def run_one_point_fast(p_frame_net, i_frame_net, args):
+ frame_num = args['frame_num']
+ rate_gop_size = args['rate_gop_size']
+ verbose = args['verbose']
+ reset_interval = args['reset_interval']
+ verbose_json = args['verbose_json']
+ device = next(i_frame_net.parameters()).device
+
+ frame_types = []
+ psnrs = []
+ msssims = []
+ bits = []
+ index_map = [0, 1, 0, 2, 0, 2, 0, 2]
+
+ start_time = time.time()
+ src_reader = get_src_reader(args)
+ pic_height = args['src_height']
+ pic_width = args['src_width']
+ padding_l, padding_r, padding_t, padding_b = get_padding_size(pic_height, pic_width, 16)
+
+ with torch.no_grad():
+ for frame_idx in range(frame_num):
+ frame_start_time = time.time()
+ x, y, u, v, rgb = get_src_frame(args, src_reader, device)
+
+ # pad if necessary
+ x_padded = F.pad(x, (padding_l, padding_r, padding_t, padding_b), mode="replicate")
+
+ if frame_idx % args['intra_period'] == 0:
+ result = i_frame_net.encode(x_padded, args['q_index_i'])
+ dpb = {
+ "ref_frame": result["x_hat"],
+ "ref_feature": None,
+ "ref_mv_feature": None,
+ "ref_y": None,
+ "ref_mv_y": None,
+ }
+ recon_frame = result["x_hat"]
+ frame_types.append(0)
+ bits.append(result["bit"])
+ else:
+ if reset_interval > 0 and frame_idx % reset_interval == 1:
+ dpb["ref_feature"] = None
+ fa_idx = index_map[frame_idx % rate_gop_size]
+ result = p_frame_net.encode(x_padded, dpb, args['q_index_p'], fa_idx)
+
+ dpb = result["dpb"]
+ recon_frame = dpb["ref_frame"]
+ frame_types.append(1)
+ bits.append(result['bit'])
+
+ recon_frame = recon_frame.clamp_(0, 1)
+ x_hat = F.pad(recon_frame, (-padding_l, -padding_r, -padding_t, -padding_b))
+ frame_end_time = time.time()
+ curr_psnr, curr_ssim = get_distortion(args, x_hat, y, u, v, rgb)
+ psnrs.append(curr_psnr)
+ msssims.append(curr_ssim)
+
+ if verbose >= 2:
+ print(f"frame {frame_idx}, {frame_end_time - frame_start_time:.3f} seconds, "
+ f"bits: {bits[-1]:.3f}, PSNR: {psnrs[-1][0]:.4f}, "
+ f"MS-SSIM: {msssims[-1][0]:.4f} ")
+
+ src_reader.close()
+ test_time = time.time() - start_time
+
+ log_result = generate_log_json(frame_num, pic_height * pic_width, test_time,
+ frame_types, bits, psnrs, msssims, verbose=verbose_json)
+ return log_result
+
+
+def run_one_point_with_stream(p_frame_net, i_frame_net, args):
+ frame_num = args['frame_num']
+ rate_gop_size = args['rate_gop_size']
+ save_decoded_frame = args['save_decoded_frame']
+ verbose = args['verbose']
+ reset_interval = args['reset_interval']
+ verbose_json = args['verbose_json']
+ device = next(i_frame_net.parameters()).device
+
+ src_reader = get_src_reader(args)
+ pic_height = args['src_height']
+ pic_width = args['src_width']
+ padding_l, padding_r, padding_t, padding_b = get_padding_size(pic_height, pic_width, 16)
+
+ frame_types = []
+ psnrs = []
+ msssims = []
+ bits = []
+
+ start_time = time.time()
+ p_frame_number = 0
+ overall_p_encoding_time = 0
+ overall_p_decoding_time = 0
+ index_map = [0, 1, 0, 2, 0, 2, 0, 2]
+
+ bitstream_path = Path(args['curr_bin_path'])
+ output_file = bitstream_path.open("wb")
+ sps_helper = SPSHelper()
+ outstanding_sps_bytes = 0
+ sps_buffer = []
+
+ with torch.no_grad():
+ for frame_idx in range(frame_num):
+ frame_start_time = time.time()
+ x, y, u, v, rgb = get_src_frame(args, src_reader, device)
+
+ # pad if necessary
+ x_padded = F.pad(x, (padding_l, padding_r, padding_t, padding_b), mode="replicate")
+
+ if frame_idx % args['intra_period'] == 0:
+ sps = {
+ 'sps_id': -1,
+ 'height': pic_height,
+ 'width': pic_width,
+ 'qp': args['q_index_i'],
+ 'fa_idx': 0,
+ }
+ sps_id, sps_new = sps_helper.get_sps_id(sps)
+ sps['sps_id'] = sps_id
+ if sps_new:
+ outstanding_sps_bytes += write_sps(output_file, sps)
+ if verbose >= 2:
+ print("new sps", sps)
+ result = i_frame_net.encode(x_padded, args['q_index_i'], sps_id, output_file)
+ dpb = {
+ "ref_frame": result["x_hat"],
+ "ref_feature": None,
+ "ref_mv_feature": None,
+ "ref_y": None,
+ "ref_mv_y": None,
+ }
+ recon_frame = result["x_hat"]
+ frame_types.append(0)
+ bits.append(result["bit"] + outstanding_sps_bytes * 8)
+ outstanding_sps_bytes = 0
+ else:
+ fa_idx = index_map[frame_idx % rate_gop_size]
+ if reset_interval > 0 and frame_idx % reset_interval == 1:
+ dpb["ref_feature"] = None
+ fa_idx = 3
+
+ sps = {
+ 'sps_id': -1,
+ 'height': pic_height,
+ 'width': pic_width,
+ 'qp': args['q_index_p'],
+ 'fa_idx': fa_idx,
+ }
+ sps_id, sps_new = sps_helper.get_sps_id(sps)
+ sps['sps_id'] = sps_id
+ if sps_new:
+ outstanding_sps_bytes += write_sps(output_file, sps)
+ if verbose >= 2:
+ print("new sps", sps)
+ result = p_frame_net.encode(x_padded, dpb, args['q_index_p'], fa_idx, sps_id,
+ output_file)
+
+ dpb = result["dpb"]
+ recon_frame = dpb["ref_frame"]
+ frame_types.append(1)
+ bits.append(result['bit'] + outstanding_sps_bytes * 8)
+ outstanding_sps_bytes = 0
+ p_frame_number += 1
+ overall_p_encoding_time += result['encoding_time']
+
+ recon_frame = recon_frame.clamp_(0, 1)
+ x_hat = F.pad(recon_frame, (-padding_l, -padding_r, -padding_t, -padding_b))
+ frame_end_time = time.time()
+ curr_psnr, curr_ssim = get_distortion(args, x_hat, y, u, v, rgb)
+ psnrs.append(curr_psnr)
+ msssims.append(curr_ssim)
+
+ if verbose >= 2:
+ print(f"frame {frame_idx} encoded, {frame_end_time - frame_start_time:.3f} s, "
+ f"bits: {bits[-1]}, PSNR: {psnrs[-1][0]:.4f}, "
+ f"MS-SSIM: {msssims[-1][0]:.4f} ")
+
+ src_reader.close()
+ output_file.close()
+ sps_helper = SPSHelper()
+ input_file = bitstream_path.open("rb")
+ decoded_frame_number = 0
+ src_reader = get_src_reader(args)
+
+ if save_decoded_frame:
+ if args['src_type'] == 'png':
+ recon_writer = PNGWriter(args['bin_folder'], args['src_width'], args['src_height'])
+ elif args['src_type'] == 'yuv420':
+ recon_writer = YUVWriter(args['curr_rec_path'], args['src_width'], args['src_height'])
+ pending_frame_spss = []
+ with torch.no_grad():
+ while decoded_frame_number < frame_num:
+ new_stream = False
+ if len(pending_frame_spss) == 0:
+ header = read_header(input_file)
+ if header['nal_type'] == NalType.NAL_SPS:
+ sps = read_sps_remaining(input_file, header['sps_id'])
+ sps_helper.add_sps_by_id(sps)
+ if verbose >= 2:
+ print("new sps", sps)
+ continue
+ if header['nal_type'] == NalType.NAL_Ps:
+ pending_frame_spss = header['sps_ids'][1:]
+ sps_id = header['sps_ids'][0]
+ else:
+ sps_id = header['sps_id']
+ new_stream = True
+ else:
+ sps_id = pending_frame_spss[0]
+ pending_frame_spss.pop(0)
+ sps = sps_helper.get_sps_by_id(sps_id)
+ if new_stream:
+ bit_stream = read_ip_remaining(input_file)
+ else:
+ bit_stream = None
+ frame_start_time = time.time()
+ x, y, u, v, rgb = get_src_frame(args, src_reader, device)
+ if header['nal_type'] == NalType.NAL_I:
+ decoded = i_frame_net.decompress(bit_stream, sps)
+ dpb = {
+ "ref_frame": decoded["x_hat"],
+ "ref_feature": None,
+ "ref_mv_feature": None,
+ "ref_y": None,
+ "ref_mv_y": None,
+ }
+ recon_frame = decoded["x_hat"]
+ elif header['nal_type'] == NalType.NAL_P or header['nal_type'] == NalType.NAL_Ps:
+ if sps['fa_idx'] == 3:
+ dpb["ref_feature"] = None
+ decoded = p_frame_net.decompress(bit_stream, dpb, sps)
+ dpb = decoded["dpb"]
+ recon_frame = dpb["ref_frame"]
+ overall_p_decoding_time += decoded['decoding_time']
+
+ recon_frame = recon_frame.clamp_(0, 1)
+ x_hat = F.pad(recon_frame, (-padding_l, -padding_r, -padding_t, -padding_b))
+ frame_end_time = time.time()
+ curr_psnr, curr_ssim = get_distortion(args, x_hat, y, u, v, rgb)
+ assert psnrs[decoded_frame_number][0] == curr_psnr[0]
+
+ if verbose >= 2:
+ stream_length = 0 if bit_stream is None else len(bit_stream) * 8
+ print(f"frame {decoded_frame_number} decoded, "
+ f"{frame_end_time - frame_start_time:.3f} s, "
+ f"bits: {stream_length}, PSNR: {curr_psnr[0]:.4f} ")
+
+ if save_decoded_frame:
+ yuv_rec = x_hat.squeeze(0).cpu().numpy()
+ if args['src_type'] == 'yuv420':
+ y_rec, uv_rec = ycbcr444_to_420(yuv_rec)
+ recon_writer.write_one_frame(y=y_rec, uv=uv_rec, src_format='420')
+ else:
+ assert args['src_type'] == 'png'
+ rgb_rec = ycbcr444_to_rgb(yuv_rec[:1, :, :], yuv_rec[1:, :, :])
+ recon_writer.write_one_frame(rgb=rgb_rec, src_format='rgb')
+ decoded_frame_number += 1
+ input_file.close()
+ src_reader.close()
+
+ if save_decoded_frame:
+ recon_writer.close()
+
+ test_time = time.time() - start_time
+ if verbose >= 1 and p_frame_number > 0:
+ print(f"encoding/decoding {p_frame_number} P frames, "
+ f"average encoding time {overall_p_encoding_time/p_frame_number * 1000:.0f} ms, "
+ f"average decoding time {overall_p_decoding_time/p_frame_number * 1000:.0f} ms.")
+
+ log_result = generate_log_json(frame_num, pic_height * pic_width, test_time,
+ frame_types, bits, psnrs, msssims, verbose=verbose_json)
+ with open(args['curr_json_path'], 'w') as fp:
+ json.dump(log_result, fp, indent=2)
+ return log_result
+
+
+i_frame_net = None # the model is initialized after each process is spawn, thus OK for multiprocess
+p_frame_net = None
+
+
+def worker(args):
+ global i_frame_net
+ global p_frame_net
+
+ sub_dir_name = args['seq']
+ bin_folder = os.path.join(args['stream_path'], args['ds_name'])
+ if args['write_stream']:
+ create_folder(bin_folder, True)
+
+ args['src_path'] = os.path.join(args['dataset_path'], sub_dir_name)
+ args['bin_folder'] = bin_folder
+ args['curr_bin_path'] = os.path.join(bin_folder,
+ f"{args['seq']}_q{args['q_index_i']}.bin")
+ args['curr_rec_path'] = args['curr_bin_path'].replace('.bin', '.yuv')
+ args['curr_json_path'] = args['curr_bin_path'].replace('.bin', '.json')
+
+ if args['write_stream']:
+ result = run_one_point_with_stream(p_frame_net, i_frame_net, args)
+ else:
+ result = run_one_point_fast(p_frame_net, i_frame_net, args)
+
+ result['ds_name'] = args['ds_name']
+ result['seq'] = args['seq']
+ result['rate_idx'] = args['rate_idx']
+
+ return result
+
+
+def init_func(args, gpu_num):
+ torch.backends.cudnn.benchmark = False
+ torch.use_deterministic_algorithms(True)
+ torch.manual_seed(0)
+ torch.set_num_threads(1)
+ np.random.seed(seed=0)
+
+ process_name = multiprocessing.current_process().name
+ process_idx = int(process_name[process_name.rfind('-') + 1:])
+ gpu_id = -1
+ if gpu_num > 0:
+ gpu_id = process_idx % gpu_num
+ if gpu_id >= 0:
+ if args.cuda_idx is not None:
+ gpu_id = args.cuda_idx[gpu_id]
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+ device = "cuda:0"
+ else:
+ device = "cpu"
+
+ global i_frame_net
+ i_state_dict = get_state_dict(args.model_path_i)
+ i_frame_net = DMCI(ec_thread=args.ec_thread, stream_part=args.stream_part_i, inplace=True)
+ i_frame_net.load_state_dict(i_state_dict)
+ i_frame_net = i_frame_net.to(device)
+ i_frame_net.eval()
+
+ global p_frame_net
+ if not args.force_intra:
+ p_state_dict = get_state_dict(args.model_path_p)
+ p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p, inplace=True)
+ p_frame_net.load_state_dict(p_state_dict)
+ p_frame_net = p_frame_net.to(device)
+ p_frame_net.eval()
+
+ if args.write_stream:
+ if p_frame_net is not None:
+ p_frame_net.update(force=True)
+ i_frame_net.update(force=True)
+
+ if args.float16:
+ if p_frame_net is not None:
+ p_frame_net.half()
+ i_frame_net.half()
diff --git a/DCVC-FM/src/utils/video_reader.py b/DCVC-FM/src/utils/video_reader.py
new file mode 100644
index 0000000..c25c9d6
--- /dev/null
+++ b/DCVC-FM/src/utils/video_reader.py
@@ -0,0 +1,184 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+from PIL import Image
+from ..transforms.functional import rgb_to_ycbcr420, ycbcr420_to_rgb, ycbcr444_to_rgb, \
+ ycbcr444_to_420, rgb_to_ycbcr444
+
+
+class VideoReader():
+ def __init__(self, src_path, width, height):
+ self.src_path = src_path
+ self.width = width
+ self.height = height
+ self.eof = False
+
+ def read_one_frame(self, dst_format='rgb'):
+ '''
+ y is 1xhxw Y float numpy array, in the range of [0, 1]
+ uv is 2x(h/2)x(w/2) UV float numpy array, in the range of [0, 1]
+ rgb is 3xhxw float numpy array, in the range of [0, 1]
+ '''
+ raise NotImplementedError
+
+ @staticmethod
+ def _none_exist_frame(dst_format):
+ if dst_format == "420":
+ return None, None
+ assert dst_format == "rgb"
+ return None
+
+ @staticmethod
+ def _get_dst_format(rgb=None, y=None, uv=None, src_format='rgb', dst_format='rgb'):
+ if dst_format == 'rgb':
+ if src_format == '420':
+ rgb = ycbcr420_to_rgb(y, uv, order=1)
+ elif src_format == '444':
+ rgb = ycbcr444_to_rgb(y, uv)
+ return rgb
+ elif dst_format == '420':
+ if src_format == 'rgb':
+ y, uv = rgb_to_ycbcr420(rgb)
+ elif src_format == '444':
+ y, uv = ycbcr444_to_420(np.concatenate((y, uv), axis=0))
+ return y, uv
+ elif dst_format == '444':
+ if src_format == 'rgb':
+ y, uv = rgb_to_ycbcr444(rgb)
+ elif src_format == '420':
+ y, uv = ycbcr444_to_420(y, uv)
+ return y, uv
+ assert False
+
+
+class PNGReader(VideoReader):
+ def __init__(self, src_path, width, height, start_num=1):
+ super().__init__(src_path, width, height)
+
+ pngs = os.listdir(self.src_path)
+ if 'im1.png' in pngs:
+ self.padding = 1
+ elif 'im00001.png' in pngs:
+ self.padding = 5
+ else:
+ raise ValueError('unknown image naming convention; please specify')
+ self.current_frame_index = start_num
+
+ def read_one_frame(self, dst_format="rgb"):
+ if self.eof:
+ return self._none_exist_frame(dst_format)
+
+ png_path = os.path.join(self.src_path,
+ f"im{str(self.current_frame_index).zfill(self.padding)}.png"
+ )
+ if not os.path.exists(png_path):
+ self.eof = True
+ return self._none_exist_frame(dst_format)
+
+ rgb = Image.open(png_path).convert('RGB')
+ rgb = np.asarray(rgb).astype('float32').transpose(2, 0, 1)
+ rgb = rgb / 255.
+ _, height, width = rgb.shape
+ assert height == self.height
+ assert width == self.width
+
+ self.current_frame_index += 1
+ return self._get_dst_format(rgb=rgb, src_format='rgb', dst_format=dst_format)
+
+ def close(self):
+ self.current_frame_index = 1
+
+
+class RGBReader(VideoReader):
+ def __init__(self, src_path, width, height, src_format='rgb', bit_depth=8):
+ super().__init__(src_path, width, height)
+
+ self.src_format = src_format
+ self.bit_depth = bit_depth
+ self.rgb_size = width * height * 3
+ self.dtype = np.uint8
+ self.max_val = 255
+ if bit_depth > 8 and bit_depth <= 16:
+ self.rgb_size = self.rgb_size * 2
+ self.dtype = np.uint16
+ self.max_val = (1 << bit_depth) - 1
+ else:
+ assert bit_depth == 8
+ # pylint: disable=R1732
+ self.file = open(src_path, "rb")
+ # pylint: enable=R1732
+
+ def read_one_frame(self, dst_format="420"):
+ if self.eof:
+ return self._none_exist_frame(dst_format)
+ rgb = self.file.read(self.rgb_size)
+ if not rgb:
+ self.eof = True
+ return self._none_exist_frame(dst_format)
+ rgb = np.frombuffer(rgb, dtype=self.dtype).copy().reshape(3, self.height, self.width)
+ rgb = rgb.astype(np.float32) / self.max_val
+
+ return self._get_dst_format(rgb=rgb, src_format='rgb', dst_format=dst_format)
+
+ def close(self):
+ self.file.close()
+
+
+class YUVReader(VideoReader):
+ def __init__(self, src_path, width, height, src_format='420', bit_depth=8, skip_frame=0):
+ super().__init__(src_path, width, height)
+ if not src_path.endswith('.yuv'):
+ src_path = src_path + '.yuv'
+ self.src_path = src_path
+
+ self.src_format = src_format
+ self.y_size = width * height
+ self.uv_size = self.y_size * 2
+ self.uv_width = width
+ self.uv_height = height
+ self.src_format = '444'
+ if src_format == '420':
+ self.src_format = '420'
+ self.uv_size = width * height // 2
+ self.uv_width = width // 2
+ self.uv_height = height // 2
+ self.dtype = np.uint8
+ self.max_val = 255
+ if bit_depth > 8 and bit_depth <= 16:
+ self.y_size = self.y_size * 2
+ self.uv_size = self.uv_size * 2
+ self.dtype = np.uint16
+ self.max_val = (1 << bit_depth) - 1
+ else:
+ assert bit_depth == 8
+ # pylint: disable=R1732
+ self.file = open(src_path, "rb")
+ # pylint: enable=R1732
+ skipped_frame = 0
+ while not self.eof and skipped_frame < skip_frame:
+ y = self.file.read(self.y_size)
+ uv = self.file.read(self.uv_size)
+ if not y or not uv:
+ self.eof = True
+ skipped_frame += 1
+
+ def read_one_frame(self, dst_format="420"):
+ if self.eof:
+ return self._none_exist_frame(dst_format)
+ y = self.file.read(self.y_size)
+ uv = self.file.read(self.uv_size)
+ if not y or not uv:
+ self.eof = True
+ return self._none_exist_frame(dst_format)
+ y = np.frombuffer(y, dtype=self.dtype).copy().reshape(1, self.height, self.width)
+ uv = np.frombuffer(uv, dtype=self.dtype).copy().reshape(2, self.uv_height, self.uv_width)
+ y = y.astype(np.float32) / self.max_val
+ uv = uv.astype(np.float32) / self.max_val
+
+ return self._get_dst_format(y=y, uv=uv, src_format=self.src_format, dst_format=dst_format)
+
+ def close(self):
+ self.file.close()
diff --git a/DCVC-FM/src/utils/video_writer.py b/DCVC-FM/src/utils/video_writer.py
new file mode 100644
index 0000000..b1063aa
--- /dev/null
+++ b/DCVC-FM/src/utils/video_writer.py
@@ -0,0 +1,131 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+from PIL import Image
+from ..transforms.functional import ycbcr420_to_rgb, rgb_to_ycbcr420, ycbcr444_to_rgb, \
+ rgb_to_ycbcr444, ycbcr420_to_444, ycbcr444_to_420
+
+
+class VideoWriter():
+ def __init__(self, dst_path, width, height):
+ self.dst_path = dst_path
+ self.width = width
+ self.height = height
+
+ def write_one_frame(self, rgb=None, y=None, uv=None, src_format="rgb"):
+ '''
+ y is 1xhxw Y float numpy array, in the range of [0, 1]
+ uv is 2x(h/2)x(w/2) UV float numpy array, in the range of [0, 1]
+ rgb is 3xhxw float numpy array, in the range of [0, 1]
+ '''
+ raise NotImplementedError
+
+
+class PNGWriter(VideoWriter):
+ def __init__(self, dst_path, width, height):
+ super().__init__(dst_path, width, height)
+ self.padding = 5
+ self.current_frame_index = 1
+ os.makedirs(dst_path, exist_ok=True)
+
+ def write_one_frame(self, rgb=None, y=None, uv=None, src_format="rgb"):
+ if src_format == "420":
+ rgb = ycbcr420_to_rgb(y, uv, order=1)
+ elif src_format == "444":
+ rgb = ycbcr444_to_rgb(y, uv)
+ rgb = rgb.transpose(1, 2, 0)
+
+ png_path = os.path.join(self.dst_path,
+ f"im{str(self.current_frame_index).zfill(self.padding)}.png"
+ )
+ img = np.clip(np.rint(rgb * 255), 0, 255).astype(np.uint8)
+ Image.fromarray(img).save(png_path)
+
+ self.current_frame_index += 1
+
+ def close(self):
+ self.current_frame_index = 1
+
+
+class RGBWriter(VideoWriter):
+ def __init__(self, dst_path, width, height, dst_format='rgb', bit_depth=8):
+ super().__init__(dst_path, width, height)
+
+ self.dst_format = dst_format
+ self.bit_depth = bit_depth
+ self.rgb_size = width * height * 3
+ self.dtype = np.uint8
+ self.max_val = 255
+ if bit_depth > 8 and bit_depth <= 16:
+ self.rgb_size = self.rgb_size * 2
+ self.dtype = np.uint16
+ self.max_val = (1 << bit_depth) - 1
+ else:
+ assert bit_depth == 8
+ # pylint: disable=R1732
+ self.file = open(dst_path, "wb")
+ # pylint: enable=R1732
+
+ def write_one_frame(self, rgb=None, y=None, uv=None, src_format="rgb"):
+ if src_format == '420':
+ rgb = ycbcr420_to_rgb(y, uv, order=1)
+ elif src_format == '444':
+ rgb = ycbcr444_to_rgb(y, uv)
+ rgb = np.clip(np.rint(rgb * self.max_val), 0, self.max_val).astype(self.dtype)
+
+ self.file.write(rgb.tobytes())
+
+ def close(self):
+ self.file.close()
+
+
+class YUVWriter(VideoWriter):
+ def __init__(self, dst_path, width, height, dst_format='420', bit_depth=8):
+ super().__init__(dst_path, width, height)
+ if not dst_path.endswith('.yuv'):
+ dst_path = dst_path + '/out.yuv'
+ self.dst_path = dst_path
+
+ self.dst_format = dst_format
+ self.y_size = width * height
+ self.uv_size = width * height
+ if dst_format == '420':
+ self.uv_size = width * height // 2
+ self.bit_depth = bit_depth
+ self.dtype = np.uint8
+ self.max_val = 255
+ if bit_depth > 8 and bit_depth <= 16:
+ self.y_size = self.y_size * 2
+ self.uv_size = self.uv_size * 2
+ self.dtype = np.uint16
+ self.max_val = (1 << bit_depth) - 1
+ else:
+ assert bit_depth == 8
+ self.eof = False
+ # pylint: disable=R1732
+ self.file = open(dst_path, "wb")
+ # pylint: enable=R1732
+
+ def write_one_frame(self, rgb=None, y=None, uv=None, src_format="420"):
+ if src_format == 'rgb':
+ if self.dst_format == '420':
+ y, uv = rgb_to_ycbcr420(rgb)
+ elif self.dst_format == '444':
+ y, uv = rgb_to_ycbcr444(rgb)
+ else:
+ assert False
+ elif src_format == '420' and self.dst_format == '444':
+ y, uv = ycbcr420_to_444(y, uv, separate=True)
+ elif src_format == '444' and self.dst_format == '420':
+ y, uv = ycbcr444_to_420(y, uv)
+ y = np.clip(np.rint(y * self.max_val), 0, self.max_val).astype(self.dtype)
+ uv = np.clip(np.rint(uv * self.max_val), 0, self.max_val).astype(self.dtype)
+
+ self.file.write(y.tobytes())
+ self.file.write(uv.tobytes())
+
+ def close(self):
+ self.file.close()
diff --git a/DCVC-FM/test_data_to_png.py b/DCVC-FM/test_data_to_png.py
new file mode 100644
index 0000000..6c4f456
--- /dev/null
+++ b/DCVC-FM/test_data_to_png.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from src.utils.video_reader import YUVReader
+from src.utils.video_writer import PNGWriter
+
+
+def convert_one_seq_to_png(src_path, width, height, dst_path):
+ src_reader = YUVReader(src_path, width, height, src_format='420')
+ png_writer = PNGWriter(dst_path, width, height)
+ rgb = src_reader.read_one_frame(dst_format='rgb')
+ processed_frame = 0
+ while not src_reader.eof:
+ png_writer.write_one_frame(rgb=rgb, src_format='rgb')
+ processed_frame += 1
+ rgb = src_reader.read_one_frame(dst_format='rgb')
+ print(src_path, processed_frame)
+
+
+def main():
+ src_path = "source_yuv_path"
+ width = 1920
+ height = 1080
+ dst_path = "destination_png_path"
+ convert_one_seq_to_png(src_path, width, height, dst_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DCVC-FM/test_video.py b/DCVC-FM/test_video.py
new file mode 100644
index 0000000..ac750e3
--- /dev/null
+++ b/DCVC-FM/test_video.py
@@ -0,0 +1,142 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import concurrent.futures
+import json
+import multiprocessing
+import time
+
+import torch
+import numpy as np
+from src.models.video_model import DMC
+from src.utils.common import create_folder, dump_json
+from src.utils.test_helper import parse_args, init_func, worker
+from tqdm import tqdm
+
+
+def main():
+ begin_time = time.time()
+
+ torch.backends.cudnn.enabled = True
+ args = parse_args()
+
+ if args.cuda_idx is not None:
+ cuda_device = ','.join([str(s) for s in args.cuda_idx])
+ os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
+ worker_num = args.worker
+ assert worker_num >= 1
+
+ with open(args.test_config) as f:
+ config = json.load(f)
+
+ gpu_num = 0
+ if args.cuda:
+ gpu_num = torch.cuda.device_count()
+
+ multiprocessing.set_start_method("spawn")
+ threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num,
+ initializer=init_func,
+ initargs=(args, gpu_num))
+ objs = []
+
+ count_frames = 0
+ count_sequences = 0
+
+ rate_num = args.rate_num
+ q_indexes_i = []
+ if args.q_indexes_i is not None:
+ assert len(args.q_indexes_i) == rate_num
+ q_indexes_i = args.q_indexes_i
+ else:
+ assert 2 <= rate_num <= DMC.get_qp_num()
+ for i in np.linspace(0, DMC.get_qp_num() - 1, num=rate_num):
+ q_indexes_i.append(int(i+0.5))
+
+ if not args.force_intra:
+ if args.q_indexes_p is not None:
+ assert len(args.q_indexes_p) == rate_num
+ q_indexes_p = args.q_indexes_p
+ else:
+ q_indexes_p = q_indexes_i
+
+ print(f"testing {rate_num} rates, using q_indexes: ", end='')
+ for q in q_indexes_i:
+ print(f"{q}, ", end='')
+ print()
+
+ root_path = args.force_root_path if args.force_root_path is not None else config['root_path']
+ config = config['test_classes']
+ for ds_name in config:
+ if config[ds_name]['test'] == 0:
+ continue
+ for seq in config[ds_name]['sequences']:
+ count_sequences += 1
+ for rate_idx in range(rate_num):
+ cur_args = {}
+ cur_args['rate_idx'] = rate_idx
+ cur_args['float16'] = args.float16
+ cur_args['q_index_i'] = q_indexes_i[rate_idx]
+ if not args.force_intra:
+ cur_args['q_index_p'] = q_indexes_p[rate_idx]
+ cur_args['force_intra'] = args.force_intra
+ cur_args['reset_interval'] = args.reset_interval
+ cur_args['seq'] = seq
+ cur_args['src_type'] = config[ds_name]['src_type']
+ cur_args['src_height'] = config[ds_name]['sequences'][seq]['height']
+ cur_args['src_width'] = config[ds_name]['sequences'][seq]['width']
+ cur_args['intra_period'] = config[ds_name]['sequences'][seq]['intra_period']
+ if args.force_intra:
+ cur_args['intra_period'] = 1
+ if args.force_intra_period > 0:
+ cur_args['intra_period'] = args.force_intra_period
+ cur_args['frame_num'] = config[ds_name]['sequences'][seq]['frames']
+ if args.force_frame_num > 0:
+ cur_args['frame_num'] = args.force_frame_num
+ cur_args['rate_gop_size'] = args.rate_gop_size
+ cur_args['calc_ssim'] = args.calc_ssim
+ cur_args['dataset_path'] = os.path.join(root_path, config[ds_name]['base_path'])
+ cur_args['write_stream'] = args.write_stream
+ cur_args['stream_path'] = args.stream_path
+ cur_args['save_decoded_frame'] = args.save_decoded_frame
+ cur_args['ds_name'] = ds_name
+ cur_args['verbose'] = args.verbose
+ cur_args['verbose_json'] = args.verbose_json
+
+ count_frames += cur_args['frame_num']
+
+ obj = threadpool_executor.submit(worker, cur_args)
+ objs.append(obj)
+
+ results = []
+ for obj in tqdm(objs):
+ result = obj.result()
+ results.append(result)
+
+ log_result = {}
+ for ds_name in config:
+ if config[ds_name]['test'] == 0:
+ continue
+ log_result[ds_name] = {}
+ for seq in config[ds_name]['sequences']:
+ log_result[ds_name][seq] = {}
+
+ for res in results:
+ log_result[res['ds_name']][res['seq']][f"{res['rate_idx']:03d}"] = res
+
+ out_json_dir = os.path.dirname(args.output_path)
+ if len(out_json_dir) > 0:
+ create_folder(out_json_dir, True)
+ with open(args.output_path, 'w') as fp:
+ dump_json(log_result, fp, float_digits=6, indent=2)
+
+ total_minutes = (time.time() - begin_time) / 60
+ print('Test finished')
+ print(f'Tested {count_frames} frames from {count_sequences} sequences')
+ print(f'Total elapsed time: {total_minutes:.1f} min')
+
+
+if __name__ == "__main__":
+ main()
diff --git a/README.md b/README.md
index 03a38c6..589d1f2 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,9 @@ Official Pytorch implementation for Neural Video and Image Compression including
* DCVC-DC: [Neural Video Compression with **D**iverse **C**ontexts](https://arxiv.org/abs/2302.14402), CVPR 2023, in [this folder](./DCVC-DC/).
- The first end-to-end neural video codec to exceed [ECM](https://jvet-experts.org/doc_end_user/documents/27_Teleconference/wg11/JVET-AA0006-v1.zip) using the highest compression ratio low delay configuration with a intra refresh period roughly to one second (32 frames), in terms of PSNR and MS-SSIM for RGB content.
- The first end-to-end neural video codec to exceed ECM using the highest compression ratio low delay configuration with a intra refresh period roughly to one second (32 frames), in terms of PSNR for YUV420 content.
+ * DCVC-FM: [Neural Video Compression with **F**eature **M**odulation](https://arxiv.org/abs/2402.17414), CVPR 2024, in [this folder](./DCVC-FM/).
+ - The first end-to-end neural video codec to exceed ECM using the highest compression ratio low delay configuration with only one intra frame, in terms of PSNR for both YUV420 content and RGB content in a single model.
+ - The first end-to-end neural video codec that support a large quality and bitrate range in a single model.
* Neural Image Codec
* [EVC: Towards Real-Time Neural Image Compression with Mask Decay](https://openreview.net/forum?id=XUxad2Gj40n), ICLR 2023, in [this folder](./EVC/).
@@ -62,6 +65,14 @@ If you find this work useful for your research, please cite:
year={2023}
}
+@inproceedings{li2024neural,
+ title={Neural Video Compression with Feature Modulation},
+ author={Li, Jiahao and Li, Bin and Lu, Yan},
+ booktitle={{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
+ {CVPR} 2024, Seattle, WA, USA, June 17-21, 2024},
+ year={2024}
+}
+
@inproceedings{wang2023EVC,
title={EVC: Towards Real-Time Neural Image Compression with Mask Decay},
author={Wang, Guo-Hua and Li, Jiahao and Li, Bin and Lu, Yan},