diff --git a/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt b/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt index 4c5b16cdc..9ccc219d6 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt +++ b/fairseq2n/python/src/fairseq2n/bindings/CMakeLists.txt @@ -26,6 +26,7 @@ target_sources(py_bindings data/text/init.cc data/text/sentencepiece.cc data/text/text_reader.cc + data/video.cc type_casters/data.cc type_casters/map_fn.cc type_casters/string.cc diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/init.cc b/fairseq2n/python/src/fairseq2n/bindings/data/init.cc index f0408b262..807a1faee 100644 --- a/fairseq2n/python/src/fairseq2n/bindings/data/init.cc +++ b/fairseq2n/python/src/fairseq2n/bindings/data/init.cc @@ -40,6 +40,8 @@ def_data(py::module_ &base) def_audio(m); + def_video(m); + def_image(m); def_data_pipeline(m); diff --git a/fairseq2n/python/src/fairseq2n/bindings/data/video.cc b/fairseq2n/python/src/fairseq2n/bindings/data/video.cc new file mode 100644 index 000000000..cc8a3f4f3 --- /dev/null +++ b/fairseq2n/python/src/fairseq2n/bindings/data/video.cc @@ -0,0 +1,61 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+
+#include "fairseq2n/bindings/module.h"
+
+#include <memory>
+#include <optional>
+
+#include <ATen/Device.h>
+#include <ATen/ScalarType.h>
+
+#include <fairseq2n/data/video/video_decoder.h>
+
+namespace py = pybind11;
+
+namespace fairseq2n {
+
+// Registers the `video` submodule under `data_module` and exposes
+// `video_decoder` to Python as `VideoDecoder`.
+void
+def_video(py::module_ &data_module)
+{
+    py::module_ m = data_module.def_submodule("video");
+
+    // VideoDecoder
+    py::class_<video_decoder, std::shared_ptr<video_decoder>>(m, "VideoDecoder")
+        .def(
+            // The defaults live in the `py::arg` list below; pybind11 always
+            // passes every argument, so the lambda itself declares none.
+            py::init([](
+                std::optional<at::ScalarType> maybe_dtype,
+                std::optional<at::Device> maybe_device,
+                bool pin_memory,
+                bool get_pts_only,
+                bool get_frames_only,
+                int width,
+                int height)
+            {
+                auto opts = video_decoder_options()
+                    .maybe_dtype(maybe_dtype)
+                    .maybe_device(maybe_device)
+                    .pin_memory(pin_memory)
+                    .get_pts_only(get_pts_only)
+                    .get_frames_only(get_frames_only)
+                    .width(width)
+                    .height(height);
+
+                return std::make_shared<video_decoder>(opts);
+            }),
+            py::arg("dtype") = std::nullopt,
+            py::arg("device") = std::nullopt,
+            py::arg("pin_memory") = false,
+            py::arg("get_pts_only") = false,
+            py::arg("get_frames_only") = false,
+            py::arg("width") = 0,
+            py::arg("height") = 0)
+        .def("__call__", &video_decoder::operator(), py::call_guard<py::gil_scoped_release>{});
+
+    map_functors().register_<video_decoder>();
+}
+}  // namespace fairseq2n
diff --git a/fairseq2n/python/src/fairseq2n/bindings/module.h b/fairseq2n/python/src/fairseq2n/bindings/module.h
index 3cf686773..36f00e8be 100644
--- a/fairseq2n/python/src/fairseq2n/bindings/module.h
+++ b/fairseq2n/python/src/fairseq2n/bindings/module.h
@@ -49,4 +49,7 @@ def_text_converters(pybind11::module_ &text_module);
 void
 def_text_reader(pybind11::module_ &text_module);
 
+void
+def_video(pybind11::module_ &data_module);
+
 } // namespace fairseq2n
diff --git a/fairseq2n/src/fairseq2n/CMakeLists.txt b/fairseq2n/src/fairseq2n/CMakeLists.txt
index 34d06c2bc..cb3b6ab4e 100644
--- a/fairseq2n/src/fairseq2n/CMakeLists.txt
+++ b/fairseq2n/src/fairseq2n/CMakeLists.txt
@@ -68,6 +68,10 @@ target_sources(fairseq2n
         data/text/sentencepiece/sp_encoder.cc
         data/text/sentencepiece/sp_model.cc
         data/text/sentencepiece/sp_processor.cc
+        data/video/video_decoder.cc
+
data/video/detail/ffmpeg.cc + data/video/detail/stream.cc + data/video/detail/transform.cc ) if(FAIRSEQ2N_SUPPORT_IMAGE) @@ -115,12 +119,33 @@ target_include_directories(fairseq2n ${system} $ ) +find_path(AVCODEC_INCLUDE_DIR libavcodec/avcodec.h) +find_library(AVCODEC_LIBRARY avcodec) +find_path(AVFORMAT_INCLUDE_DIR libavformat/avformat.h) +find_library(AVFORMAT_LIBRARY avformat) +find_path(AVUTIL_INCLUDE_DIR libavutil/avutil.h) +find_library(AVUTIL_LIBRARY avutil) +find_path(SWSCALE_INCLUDE_DIR libswscale/swscale.h) +find_library(SWSCALE_LIBRARY swscale) + +target_include_directories(fairseq2n + PRIVATE + ${AVCODEC_INCLUDE_DIR} + ${AVFORMAT_INCLUDE_DIR} + ${AVUTIL_INCLUDE_DIR} + ${SWSCALE_INCLUDE_DIR} +) + find_package(PNG REQUIRED) find_package(JPEG REQUIRED) target_link_libraries(fairseq2n PRIVATE ${CMAKE_DL_LIBS} + ${AVCODEC_LIBRARY} + ${AVFORMAT_LIBRARY} + ${AVUTIL_LIBRARY} + ${SWSCALE_LIBRARY} PRIVATE fmt::fmt Iconv::Iconv diff --git a/fairseq2n/src/fairseq2n/data/video/detail/ffmpeg.cc b/fairseq2n/src/fairseq2n/data/video/detail/ffmpeg.cc new file mode 100644 index 000000000..dd10ce70f --- /dev/null +++ b/fairseq2n/src/fairseq2n/data/video/detail/ffmpeg.cc @@ -0,0 +1,211 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+
+#include "fairseq2n/data/video/detail/ffmpeg.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <stdexcept>
+#include <string>
+
+#include <ATen/Functions.h>
+#include <ATen/Tensor.h>
+
+extern "C" {
+    #include <libavutil/imgutils.h>
+}
+
+#include "fairseq2n/exception.h"
+#include "fairseq2n/float.h"
+#include "fairseq2n/fmt.h"
+#include "fairseq2n/memory.h"
+#include "fairseq2n/data/detail/tensor_helpers.h"
+#include "fairseq2n/detail/exception.h"
+
+using namespace std;
+
+namespace fairseq2n::detail {
+
+ffmpeg_decoder::ffmpeg_decoder(video_decoder_options opts)
+  : opts_{opts}
+{}
+
+// Opens the media container held in `block` and decodes each of its streams.
+//
+// Returns a dictionary that maps the stream index (as a string) to the
+// dictionary produced by `open_stream()` for that stream.
+data_dict
+ffmpeg_decoder::open_container(const memory_block &block)
+{
+    auto data_ptr = reinterpret_cast<const uint8_t *>(block.data());
+    size_t data_size = block.size();
+
+    // avio_alloc_context() takes the buffer size as an `int`; reject inputs
+    // that cannot be represented instead of silently truncating the size.
+    constexpr auto max_data_size =
+        static_cast<size_t>(std::numeric_limits<int>::max()) - AV_INPUT_BUFFER_PADDING_SIZE;
+    if (data_size > max_data_size)
+        throw_<std::invalid_argument>("Size is too large to fit in an int");
+
+    // `bd` lives on the stack: the read callback is only invoked while this
+    // function (including the nested `open_stream()` calls) is running.
+    buffer_data bd = {data_ptr, data_size, this};
+
+    fmt_ctx_ = avformat_alloc_context();
+    if (fmt_ctx_ == nullptr)
+        throw_<std::runtime_error>("Failed to allocate AVFormatContext.");
+
+    // Allocate buffer for input/output operations via AVIOContext. It is as
+    // large as the input because no seek callback is installed.
+    avio_ctx_buffer_ = static_cast<uint8_t *>(av_malloc(data_size + AV_INPUT_BUFFER_PADDING_SIZE));
+    if (avio_ctx_buffer_ == nullptr)
+        throw_<std::runtime_error>("Failed to allocate AVIOContext buffer.");
+
+    // Create an AVIOContext for using custom IO
+    avio_ctx_ = avio_alloc_context(
+        avio_ctx_buffer_,
+        static_cast<int>(data_size),
+        0,              // Write flag
+        &bd,            // Pointer to user data
+        &read_callback, // Read function
+        nullptr,        // Write function, not used
+        nullptr         // Seek function, not used
+    );
+    if (avio_ctx_ == nullptr)
+        throw_<std::runtime_error>("Failed to allocate AVIOContext.");
+
+    fmt_ctx_->pb = avio_ctx_;
+    fmt_ctx_->flags |= AVFMT_FLAG_CUSTOM_IO;
+    fmt_ctx_->flags |= AVFMT_FLAG_NONBLOCK;
+
+    // Open media and read the header. No input format is passed so that
+    // libavformat probes the data through the AVIOContext; the AVIO buffer is
+    // still unfilled at this point and must not be probed directly.
+    int ret = avformat_open_input(&fmt_ctx_, nullptr, nullptr, nullptr);
+    if (ret < 0)
+        throw_with_nested<std::invalid_argument>("Failed to open input.");
+
+    // Read data from the media
+    ret = avformat_find_stream_info(fmt_ctx_, nullptr);
+    if (ret < 0)
+        throw_<std::runtime_error>("Failed to find stream information.");
+
+    // Iterate over all streams
+    data_dict all_streams;
+    for (int i = 0; i < static_cast<int>(fmt_ctx_->nb_streams); i++)
+        all_streams[std::to_string(i)] = open_stream(i);
+
+    return all_streams;
+}
+
+// Opens stream `stream_index` and, if it is a video stream, decodes all of
+// its frames to RGB24. Streams of any other type yield an empty dictionary.
+//
+// NOTE: the read loop below consumes the demuxer until it reports an error or
+// EOF and drops packets of other streams, and `open_container()` does not seek
+// back between streams.
+data_dict
+ffmpeg_decoder::open_stream(int stream_index)
+{
+    av_stream_ = std::make_unique<stream>(stream_index, *fmt_ctx_);
+
+    // Skip streams if not video for now
+    if (av_stream_->type_ != AVMEDIA_TYPE_VIDEO)
+        return data_dict{};
+
+    stream &st = *av_stream_;
+
+    st.alloc_resources();
+
+    // Fill codec context with codec parameters
+    int ret = avcodec_parameters_to_context(st.codec_ctx_, st.codec_params_);
+    if (ret < 0)
+        throw_<std::runtime_error>(
+            "Failed to copy decoder parameters to input decoder context for stream {}\n",
+            stream_index);
+
+    // Open the codec
+    ret = avcodec_open2(st.codec_ctx_, st.codec_, nullptr);
+    if (ret < 0)
+        throw_<std::runtime_error>("Failed to open decoder for stream {}\n", stream_index);
+
+    // Create tensor storage for the stream
+    st.init_tensor_storage(opts_);
+
+    const int64_t max_frames = st.metadata_.num_frames;
+    int64_t processed_frames = 0;
+
+    // The conversion context is rebuilt only when the frame geometry or pixel
+    // format changes instead of once per frame.
+    sws_.reset();
+    int sws_width = -1;
+    int sws_height = -1;
+    int sws_format = -1;
+
+    // Receives every frame the decoder currently has available and stores it.
+    auto receive_frames = [&]() {
+        for (;;) {
+            int err = avcodec_receive_frame(st.codec_ctx_, st.frame_);
+            // EAGAIN is not an error, it means we need more input
+            // AVERROR_EOF means decoding finished
+            if (err == AVERROR(EAGAIN) || err == AVERROR_EOF)
+                return;
+            if (err < 0)
+                throw_<std::runtime_error>(
+                    "Error receiving frame from decoder for stream {}\n", stream_index);
+
+            // The tensors were sized from the frame count in the container.
+            if (processed_frames >= max_frames)
+                throw_<std::runtime_error>(
+                    "Stream {} yielded more frames than the {} announced by its container.",
+                    stream_index, max_frames);
+
+            // Transform frame to RGB to guarantee 3 color channels
+            if (sws_ == nullptr || st.frame_->width != sws_width ||
+                st.frame_->height != sws_height || st.frame_->format != sws_format) {
+                sws_width = st.frame_->width;
+                sws_height = st.frame_->height;
+                sws_format = st.frame_->format;
+                sws_ = std::make_unique<transform>(
+                    sws_width, sws_height, static_cast<AVPixelFormat>(sws_format), opts_);
+            }
+            sws_->transform_to_rgb(*st.sw_frame_, *st.frame_, stream_index, opts_);
+
+            // Store PTS in microseconds
+            if (!opts_.get_frames_only())
+                st.tensor_storage_.frame_pts[processed_frames] =
+                    st.frame_->pts * st.metadata_.time_base * 1000000;
+
+            // Store raw frame data for one frame
+            if (!opts_.get_pts_only()) {
+                at::Tensor one_frame = st.tensor_storage_.all_video_frames[processed_frames];
+                // `data_ptr()` already accounts for the storage offset of the view.
+                auto dst_size = std::min<size_t>(
+                    one_frame.nbytes(), std::numeric_limits<int>::max());
+                // Copies row by row, so `linesize` padding in the RGB frame is
+                // dropped, and fails instead of overrunning the destination.
+                int copied = av_image_copy_to_buffer(
+                    static_cast<uint8_t *>(one_frame.data_ptr()), static_cast<int>(dst_size),
+                    st.sw_frame_->data, st.sw_frame_->linesize,
+                    AV_PIX_FMT_RGB24, st.sw_frame_->width, st.sw_frame_->height, 1);
+                if (copied < 0)
+                    throw_<std::runtime_error>(
+                        "Failed to copy the RGB frame into the output tensor for stream {}\n",
+                        stream_index);
+            }
+            processed_frames++;
+
+            av_frame_unref(st.frame_); // Unref old data so the frame can be reused
+            av_frame_unref(st.sw_frame_);
+        }
+    };
+
+    // Iterate over all frames in the stream and decode them
+    while (av_read_frame(fmt_ctx_, st.pkt_) >= 0) {
+        if (st.pkt_->stream_index == stream_index) {
+            // Send raw data packet (compressed frame) to the decoder through the codec context
+            ret = avcodec_send_packet(st.codec_ctx_, st.pkt_);
+            if (ret < 0)
+                throw_<std::runtime_error>(
+                    "Error sending packet to decoder for stream {}\n", stream_index);
+            receive_frames();
+        }
+        av_packet_unref(st.pkt_);
+    }
+
+    // A null packet puts the decoder into draining mode so that the frames it
+    // still buffers are returned as well.
+    ret = avcodec_send_packet(st.codec_ctx_, nullptr);
+    if (ret < 0 && ret != AVERROR_EOF)
+        throw_<std::runtime_error>(
+            "Error sending packet to decoder for stream {}\n", stream_index);
+    if (ret >= 0)
+        receive_frames();
+
+    // The tensors come from `at::empty()`; drop the rows that were never filled.
+    if (processed_frames < max_frames) {
+        if (!opts_.get_pts_only())
+            st.tensor_storage_.all_video_frames =
+                st.tensor_storage_.all_video_frames.narrow(0, 0, processed_frames);
+        if (!opts_.get_frames_only())
+            st.tensor_storage_.frame_pts =
+                st.tensor_storage_.frame_pts.narrow(0, 0, processed_frames);
+    }
+
+    st.init_data_storage(opts_);
+
+    return st.stream_data_;
+}
+
+// C style function used by ffmpeg to read from the memory buffer: copies up to
+// `buf_size` bytes of the remaining input into `buf` and advances the cursor.
+int
+ffmpeg_decoder::read_callback(void *opaque, uint8_t *buf, int buf_size)
+{
+    auto *bd = static_cast<buffer_data *>(opaque);
+
+    if (buf_size < 0) {
+        bd->decoder->error_message_ = "The read size requested by FFmpeg is negative.";
+        return AVERROR(EINVAL);
+    }
+
+    size_t count = std::min(static_cast<size_t>(buf_size), bd->size);
+    if (count == 0)
+        return AVERROR_EOF;
+
+    std::memcpy(buf, bd->ptr, count);
+    bd->ptr += count;
+    bd->size -= count;
+
+    // `count` never exceeds `buf_size`, so it fits in an int.
+    return static_cast<int>(count);
+}
+
+ffmpeg_decoder::~ffmpeg_decoder()
+{
+    // Close the demuxer first; with AVFMT_FLAG_CUSTOM_IO set it leaves the
+    // AVIOContext to the caller.
+    if (fmt_ctx_ != nullptr)
+        avformat_close_input(&fmt_ctx_);
+
+    if (avio_ctx_ != nullptr) {
+        // libavformat may have replaced the buffer, so free the current one.
+        av_freep(&avio_ctx_->buffer);
+        avio_context_free(&avio_ctx_);
+    } else if (avio_ctx_buffer_ != nullptr) {
+        // avio_alloc_context() failed, so the buffer is still ours.
+        av_free(avio_ctx_buffer_);
+    }
+}
+
+}  // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/video/detail/ffmpeg.h b/fairseq2n/src/fairseq2n/data/video/detail/ffmpeg.h
new file mode 100644
index 000000000..f9b20c4e5
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/detail/ffmpeg.h
@@ -0,0 +1,61 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cstdint>
+#include "fairseq2n/data/video/detail/utils.h"
+#include "fairseq2n/data/video/detail/stream.h"
+#include "fairseq2n/data/video/detail/transform.h"
+
+#include "fairseq2n/api.h"
+#include "fairseq2n/data/data.h"
+
+#include <memory>
+#include <string>
+
+extern "C" {
+    #include <libavcodec/avcodec.h>
+    #include <libavformat/avformat.h>
+    #include <libavutil/imgutils.h>
+}
+
+namespace fairseq2n::detail {
+
+// Decodes a media container held in memory through FFmpeg, using a custom
+// AVIOContext that reads from the memory block.
+class FAIRSEQ2_API ffmpeg_decoder {
+public:
+    explicit
+    ffmpeg_decoder(video_decoder_options opts = {});
+
+    ~ffmpeg_decoder();
+
+    // Opens the container in `block` and returns one entry per stream, keyed
+    // by the stream index converted to a string.
+    data_dict
+    open_container(const memory_block &block);
+
+    ffmpeg_decoder(const ffmpeg_decoder&) = delete;
+
+    ffmpeg_decoder& operator=(const ffmpeg_decoder&) = delete;
+
+private:
+    // Decodes the stream at `stream_index`.
+    data_dict
+    open_stream(int stream_index);
+
+    // Read function handed to avio_alloc_context(); `opaque` is a `buffer_data`.
+    static int
+    read_callback(void *opaque, uint8_t *buf, int buf_size);
+
+    const std::string& read_callback_error_message() const { return error_message_; }
+
+private:
+    video_decoder_options opts_;
+    AVFormatContext* fmt_ctx_ = nullptr;
+    AVIOContext* avio_ctx_ = nullptr;
+    uint8_t* avio_ctx_buffer_ = nullptr;
+    std::unique_ptr<stream> av_stream_;
+    std::unique_ptr<transform> sws_;
+    std::string error_message_;  // Written by read_callback() on failure
+};
+
+}  // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/video/detail/stream.cc b/fairseq2n/src/fairseq2n/data/video/detail/stream.cc
new file mode 100644
index 000000000..2ea74e7a0
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/detail/stream.cc
@@ -0,0 +1,120 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "fairseq2n/data/video/detail/stream.h"
+#include "fairseq2n/data/video/detail/utils.h"
+
+#include <cstdint>
+#include <stdexcept>
+
+#include "fairseq2n/exception.h"
+#include "fairseq2n/detail/exception.h"
+
+namespace fairseq2n::detail {
+
+// Initializes the AVStream, AVCodecParameters, AVCodec, and reads the stream
+// metadata. Throws if no decoder is registered for the codec of the stream.
+stream::stream(int stream_index, const AVFormatContext& fmt_ctx) {
+    stream_index_ = stream_index;
+    av_stream_ = fmt_ctx.streams[stream_index];
+    metadata_.numerator = av_stream_->time_base.num;
+    metadata_.denominator = av_stream_->time_base.den;
+    metadata_.duration_microseconds = fmt_ctx.duration;
+    metadata_.fps = av_q2d(av_stream_->avg_frame_rate);
+    metadata_.num_frames = av_stream_->nb_frames;
+    metadata_.time_base = av_q2d(av_stream_->time_base);
+    codec_params_ = av_stream_->codecpar;
+    metadata_.height = codec_params_->height;
+    metadata_.width = codec_params_->width;
+    type_ = codec_params_->codec_type;
+    codec_ = avcodec_find_decoder(codec_params_->codec_id);
+    if (codec_ == nullptr) {
+        throw_<std::runtime_error>("Failed to find decoder for stream {}\n",
+            stream_index_);
+    }
+}
+
+void
+stream::alloc_resources() {
+    // Allocate memory to hold the context for decoding process
+    codec_ctx_ = avcodec_alloc_context3(codec_);
+    if (codec_ctx_ == nullptr) {
+        throw_<std::runtime_error>("Failed to allocate the decoder context for stream {}\n",
+            stream_index_);
+    }
+    // Allocate memory to hold the packet
+    pkt_ = av_packet_alloc();
+    if (pkt_ == nullptr) {
+        throw_<std::runtime_error>("Failed to allocate the packet for stream {}\n",
+            stream_index_);
+    }
+    // Allocate memory to hold the frames
+    frame_ = av_frame_alloc();
+    if (frame_ == nullptr) {
+        throw_<std::runtime_error>("Failed to allocate the frame for stream {}\n",
+            stream_index_);
+    }
+    sw_frame_ = av_frame_alloc();
+    if (sw_frame_ == nullptr) {
+        throw_<std::runtime_error>("Failed to allocate the software frame for stream {}\n",
+            stream_index_);
+    }
+}
+
+// Initializes the tensors for storing raw frames and metadata.
+void
+stream::init_tensor_storage(video_decoder_options opts) {
+    if (!opts.get_pts_only()) {
+        // A non-zero width/height in `opts` selects the size of the RGB frame,
+        // so the tensor must use the same rule to match what gets copied in.
+        const int64_t out_height = (opts.height() != 0) ? opts.height() : metadata_.height;
+        const int64_t out_width = (opts.width() != 0) ? opts.width() : metadata_.width;
+
+        tensor_storage_.all_video_frames = at::empty({metadata_.num_frames, out_height, out_width, 3},
+        at::dtype(opts.maybe_dtype().value_or(at::kByte)).device(at::kCPU).pinned_memory(opts.pin_memory()));
+    }
+
+    if (!opts.get_frames_only()) {
+        tensor_storage_.frame_pts = at::empty({metadata_.num_frames},
+        at::dtype(at::kLong).device(at::kCPU).pinned_memory(opts.pin_memory()));
+    }
+
+    if (!opts.get_pts_only() && !opts.get_frames_only()) {
+        tensor_storage_.timebase = at::tensor({metadata_.numerator, metadata_.denominator},
+        at::dtype(at::kInt).device(at::kCPU).pinned_memory(opts.pin_memory()));
+        tensor_storage_.fps = at::tensor({metadata_.fps},
+        at::dtype(at::kFloat).device(at::kCPU).pinned_memory(opts.pin_memory()));
+        tensor_storage_.duration = at::tensor({metadata_.duration_microseconds},
+        at::dtype(at::kLong).device(at::kCPU).pinned_memory(opts.pin_memory()));
+    }
+}
+
+// Publishes the tensors selected by `opts` into `stream_data_`.
+void
+stream::init_data_storage(video_decoder_options opts) {
+
+    if (!opts.get_pts_only()) {
+        stream_data_["all_video_frames"] = data(tensor_storage_.all_video_frames);
+    }
+
+    if (!opts.get_frames_only()) {
+        stream_data_["frame_pts"] = data(tensor_storage_.frame_pts);
+    }
+
+    if (!opts.get_pts_only() && !opts.get_frames_only()) {
+        stream_data_["timebase"] = data(tensor_storage_.timebase);
+        stream_data_["fps"] = data(tensor_storage_.fps);
+        stream_data_["duration"] = data(tensor_storage_.duration);
+    }
+}
+
+// Declared in the header; null until alloc_resources() has run.
+AVCodecContext*
+stream::get_codec_ctx() const {
+    return codec_ctx_;
+}
+
+stream::~stream() {
+    if (codec_ctx_ != nullptr) {
+        avcodec_free_context(&codec_ctx_);
+    }
+    if (frame_ != nullptr) {
+        av_frame_free(&frame_);
+    }
+    if (sw_frame_ != nullptr) {
+        av_frame_free(&sw_frame_);
+    }
+    if (pkt_ != nullptr) {
+        av_packet_free(&pkt_);
+    }
+}
+
+} // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/video/detail/stream.h b/fairseq2n/src/fairseq2n/data/video/detail/stream.h
new file mode 100644
index 000000000..90074eb1e
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/detail/stream.h
@@ -0,0 +1,69 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "fairseq2n/api.h"
+#include "fairseq2n/data/video/detail/utils.h"
+
+#include <ATen/Functions.h>
+#include <ATen/Tensor.h>
+
+extern "C" {
+    #include <libavcodec/avcodec.h>
+    #include <libavformat/avformat.h>
+}
+#include <cstdint>
+#include "fairseq2n/memory.h"
+#include "fairseq2n/data/detail/tensor_helpers.h"
+
+namespace fairseq2n::detail {
+
+// Owns the FFmpeg decoding resources (codec context, packet, frames) and the
+// output tensors of a single stream of an opened container.
+class FAIRSEQ2_API stream {
+friend class ffmpeg_decoder;
+public:
+    // Reads the metadata of stream `stream_index` and looks up its decoder.
+    stream(int stream_index, const AVFormatContext& fmt_ctx);
+
+    // Allocates the codec context, the packet, and the two frames.
+    void
+    alloc_resources();
+
+    ~stream();
+
+    // Allocates the output tensors selected by `opts`.
+    void
+    init_tensor_storage(video_decoder_options opts);
+
+    // Publishes the output tensors selected by `opts` into the stream dict.
+    void
+    init_data_storage(video_decoder_options opts);
+
+    AVCodecContext* get_codec_ctx() const;
+
+    stream(const stream&) = delete;
+
+    stream& operator=(const stream&) = delete;
+
+private:
+    AVCodecContext* codec_ctx_{nullptr};
+    AVFrame* frame_{nullptr};
+    AVFrame *sw_frame_{nullptr};
+    AVStream *av_stream_{nullptr};
+    AVPacket *pkt_{nullptr};
+    media_metadata metadata_;
+    AVCodecParameters* codec_params_{nullptr};
+    AVMediaType type_{AVMEDIA_TYPE_UNKNOWN};
+    // `const` because avcodec_find_decoder() returns `const AVCodec *` since FFmpeg 5.
+    const AVCodec* codec_{nullptr};
+    int stream_index_{0};
+    tensor_storage tensor_storage_;
+    data_dict stream_data_;
+};
+
+}  // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/video/detail/transform.cc b/fairseq2n/src/fairseq2n/data/video/detail/transform.cc
new file mode 100644
index 000000000..77b870908
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/detail/transform.cc
@@ -0,0 +1,55 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "fairseq2n/data/video/detail/transform.h"
+
+#include <stdexcept>
+
+#include "fairseq2n/exception.h"
+#include "fairseq2n/detail/exception.h"
+
+namespace fairseq2n::detail {
+
+// Creates a bilinear scaling context that converts `width` x `height` frames
+// of pixel format `fmt` to RGB24. A non-zero width/height in `opts` selects
+// the output size; otherwise the input size is kept.
+transform::transform(int width, int height, AVPixelFormat fmt, video_decoder_options opts)
+{
+    int dstWidth = (opts.width() != 0) ? opts.width() : width;
+    int dstHeight = (opts.height() != 0) ? opts.height() : height;
+    sws_ctx_ = sws_getContext(width, height, fmt, dstWidth, dstHeight, AV_PIX_FMT_RGB24,
+                              SWS_BILINEAR, nullptr, nullptr, nullptr);
+    if (sws_ctx_ == nullptr) {
+        throw_<std::runtime_error>("Failed to create the conversion context\n");
+    }
+}
+
+// Converts `frame` to RGB24 into `sw_frame`. Any buffer `sw_frame` still holds
+// is released first and a new one is allocated for the output size.
+void transform::transform_to_rgb(AVFrame& sw_frame, const AVFrame &frame, int stream_index,
+video_decoder_options opts)
+{
+    // AV_PIX_FMT_RGB24 guarantees 3 color channels
+    av_frame_unref(&sw_frame); // Safety check
+    sw_frame.format = AV_PIX_FMT_RGB24;
+    // Same output-size rule as the constructor.
+    sw_frame.width = (opts.width() != 0) ? opts.width() : frame.width;
+    sw_frame.height = (opts.height() != 0) ? opts.height() : frame.height;
+    int ret = av_frame_get_buffer(&sw_frame, 0);
+    if (ret < 0) {
+        throw_<std::runtime_error>("Failed to allocate buffer for the RGB frame for stream {}\n",
+            stream_index);
+    }
+    ret = sws_scale(sws_ctx_, frame.data, frame.linesize, 0, frame.height,
+                    sw_frame.data, sw_frame.linesize);
+    if (ret < 0) {
+        throw_<std::runtime_error>("Failed to convert the frame to RGB for stream {}\n",
+            stream_index);
+    }
+}
+
+transform::~transform()
+{
+    if (sws_ctx_ != nullptr)
+        sws_freeContext(sws_ctx_);
+}
+
+} // namespace fairseq2n::detail
\ No newline at end of file
diff --git a/fairseq2n/src/fairseq2n/data/video/detail/transform.h b/fairseq2n/src/fairseq2n/data/video/detail/transform.h
new file mode 100644
index 000000000..344e7fdf9
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/detail/transform.h
@@ -0,0 +1,45 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "fairseq2n/api.h"
+#include "fairseq2n/data/video/detail/stream.h"
+
+extern "C" {
+    #include <libavutil/frame.h>
+    #include <libavutil/pixfmt.h>
+    #include <libswscale/swscale.h>
+}
+
+namespace fairseq2n::detail {
+
+// Owns a libswscale context that converts (and optionally resizes) decoded
+// frames to packed RGB24.
+class FAIRSEQ2_API transform {
+friend class ffmpeg_decoder;
+
+public:
+    // Creates a context for `width` x `height` frames of pixel format `fmt`; a
+    // non-zero width/height in `opts` selects the output size.
+    transform(int width, int height, AVPixelFormat fmt, video_decoder_options opts);
+
+    ~transform();
+
+    // Converts `frame` into `sw_frame`, (re)allocating the buffer of `sw_frame`.
+    void
+    transform_to_rgb(AVFrame& sw_frame, const AVFrame &frame, int stream_index,
+    video_decoder_options opts);
+
+    transform(const transform&) = delete;
+
+    transform& operator=(const transform&) = delete;
+
+private:
+    SwsContext *sws_ctx_{nullptr};
+};
+
+}  // namespace fairseq2n::detail
\ No newline at end of file
diff --git a/fairseq2n/src/fairseq2n/data/video/detail/utils.h b/fairseq2n/src/fairseq2n/data/video/detail/utils.h
new file mode 100644
index 000000000..c0fc42f62
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/detail/utils.h
@@ -0,0 +1,177 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <stdexcept>
+
+#include <ATen/Device.h>
+#include <ATen/ScalarType.h>
+#include <ATen/Tensor.h>
+
+#include "fairseq2n/detail/exception.h"
+#include "fairseq2n/data/data.h"
+
+namespace fairseq2n::detail {
+
+class ffmpeg_decoder;
+
+// Cursor over the input memory block, passed to the AVIO read callback.
+struct buffer_data {
+    const uint8_t *ptr; // Pointer to the start of the memory_block buffer
+    size_t size;
+    ffmpeg_decoder *decoder; // Used to access decoder in static functions
+};
+
+struct media_metadata {
+  int64_t num_frames{0}; // Number of frames in the stream
+  int numerator{0}; // Time base numerator
+  int denominator{0}; // Time base denominator
+  int64_t duration_microseconds{0}; // Duration of the stream
+  int height{0}; // Height of a frame in pixels
+  int width{0}; // Width of a frame in pixels
+  double time_base{0}; // Time base of the stream
+  double fps{0}; // Frames per second for video streams
+};
+
+struct tensor_storage {
+    at::Tensor all_video_frames;
+    at::Tensor frame_pts;
+    at::Tensor timebase;
+    at::Tensor fps;
+    at::Tensor duration;
+};
+
+// Options of the video decoder.
+//
+// Every one-argument setter leaves `*this` untouched and returns a modified
+// copy, so the result must be used (chained or assigned); calling a setter
+// as a plain statement has no effect.
+class video_decoder_options {
+public:
+    video_decoder_options
+    maybe_dtype(std::optional<at::ScalarType> value) noexcept
+    {
+        auto tmp = *this;
+
+        tmp.maybe_dtype_ = value;
+
+        return tmp;
+    }
+
+    std::optional<at::ScalarType>
+    maybe_dtype() const noexcept
+    {
+        return maybe_dtype_;
+    }
+
+    video_decoder_options
+    maybe_device(std::optional<at::Device> value) noexcept
+    {
+        auto tmp = *this;
+
+        tmp.maybe_device_ = value;
+
+        return tmp;
+    }
+
+    std::optional<at::Device>
+    maybe_device() const noexcept
+    {
+        return maybe_device_;
+    }
+
+    // Throws if `get_frames_only` is already set, since the two are exclusive.
+    video_decoder_options
+    get_pts_only(bool value)
+    {
+        auto tmp = *this;
+
+        if (value && tmp.get_frames_only_) {
+            throw_<std::invalid_argument>("get_pts_only and get_frames_only cannot both be true");
+        }
+
+        tmp.get_pts_only_ = value;
+
+        return tmp;
+    }
+
+    bool
+    get_pts_only() const noexcept
+    {
+        return get_pts_only_;
+    }
+
+    // Throws if `get_pts_only` is already set, since the two are exclusive.
+    video_decoder_options
+    get_frames_only(bool value)
+    {
+        auto tmp = *this;
+
+        if (value && tmp.get_pts_only_) {
+            throw_<std::invalid_argument>("get_pts_only and get_frames_only cannot both be true");
+        }
+
+        tmp.get_frames_only_ = value;
+
+        return tmp;
+    }
+
+    bool
+    get_frames_only() const noexcept
+    {
+        return get_frames_only_;
+    }
+
+    video_decoder_options
+    pin_memory(bool value) noexcept
+    {
+        auto tmp = *this;
+
+        tmp.pin_memory_ = value;
+
+        return tmp;
+    }
+
+    bool
+    pin_memory() const noexcept
+    {
+        return pin_memory_;
+    }
+
+    video_decoder_options
+    width(int value) noexcept
+    {
+        auto tmp = *this;
+
+        tmp.width_ = value;
+
+        return tmp;
+    }
+
+    int
+    width() const noexcept
+    {
+        return width_;
+    }
+
+    video_decoder_options
+    height(int value) noexcept
+    {
+        auto tmp = *this;
+
+        tmp.height_ = value;
+
+        return tmp;
+    }
+
+    int
+    height() const noexcept
+    {
+        return height_;
+    }
+
+private:
+    std::optional<at::ScalarType> maybe_dtype_{};
+    std::optional<at::Device> maybe_device_{};
+    bool pin_memory_ = false;
+    bool get_pts_only_ = false;
+    bool get_frames_only_ = false;
+    int width_ = 0;   // 0 means "no explicit width"
+    int height_ = 0;  // 0 means "no explicit height"
+};
+
+}  // namespace fairseq2n::detail
diff --git a/fairseq2n/src/fairseq2n/data/video/video_decoder.cc b/fairseq2n/src/fairseq2n/data/video/video_decoder.cc
new file mode 100644
index 000000000..2349ff21b
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/video_decoder.cc
@@ -0,0 +1,63 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "fairseq2n/data/video/video_decoder.h"
+
+#include <cstdint>
+#include <stdexcept>
+#include <utility>
+
+#include <ATen/Functions.h>
+#include <ATen/Tensor.h>
+
+#include "fairseq2n/exception.h"
+#include "fairseq2n/float.h"
+#include "fairseq2n/fmt.h"
+#include "fairseq2n/memory.h"
+#include "fairseq2n/data/detail/tensor_helpers.h"
+#include "fairseq2n/detail/exception.h"
+
+using namespace std;
+using namespace fairseq2n::detail;
+
+namespace fairseq2n {
+
+// Validates the requested dtype and stores the decoding options.
+//
+// `pin_memory == true` turns on pinned memory in addition to whatever `opts`
+// requests; `false` leaves the setting carried by `opts` unchanged.
+video_decoder::video_decoder(video_decoder_options opts, bool pin_memory)
+    : opts_{opts}
+{
+    /*
+    dtype is used to determine the type of the output tensor for raw frame data only,
+    which is usually stored as unsigned 8-bit or 10-bit integers corresponding to kByte
+    and kShort in PyTorch.
+    */
+    at::ScalarType dtype = opts_.maybe_dtype().value_or(at::kByte);
+    if (dtype != at::kByte && dtype != at::kShort)
+        throw_<std::invalid_argument>(
+            "`video_decoder` supports only `torch.int16` and `torch.uint8` data types.");
+
+    // The option setters return a modified copy, so the result has to be
+    // assigned back for the argument to take effect.
+    if (pin_memory)
+        opts_ = opts_.pin_memory(true);
+}
+
+// Decodes the memory block held by `d` and returns `{"video": <streams>}`,
+// where `<streams>` is the dictionary built by `ffmpeg_decoder::open_container()`.
+data
+video_decoder::operator()(data &&d) const
+{
+    if (!d.is_memory_block())
+        throw_<std::invalid_argument>(fmt::format(
+            "The input data must be of type `memory_block`, but is of type `{}` instead.", d.type()));
+
+    const memory_block &block = d.as_memory_block();
+    if (block.empty())
+        throw_<std::invalid_argument>("The input memory block has zero length and cannot be decoded.");
+
+    // A fresh decoder per call keeps this const method free of shared state.
+    ffmpeg_decoder decoder(opts_);
+
+    data_dict decoded_video = decoder.open_container(block);
+
+    data_dict output;
+    output.emplace("video", std::move(decoded_video));
+    return output;
+}
+
+}  // namespace fairseq2n
diff --git a/fairseq2n/src/fairseq2n/data/video/video_decoder.h b/fairseq2n/src/fairseq2n/data/video/video_decoder.h
new file mode 100644
index 000000000..e99531a6d
--- /dev/null
+++ b/fairseq2n/src/fairseq2n/data/video/video_decoder.h
@@ -0,0 +1,35 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <optional>
+#include "fairseq2n/data/video/detail/ffmpeg.h"
+
+#include "fairseq2n/api.h"
+#include "fairseq2n/data/data.h"
+
+#include <ATen/Device.h>
+#include <ATen/ScalarType.h>
+
+// NOTE(review): a header-scope using-directive leaks into every includer. It
+// is kept because the Python bindings name `video_decoder_options` without
+// the `detail::` qualifier.
+using namespace fairseq2n::detail;
+
+namespace fairseq2n {
+
+// Map functor that decodes an in-memory media container into tensors.
+class FAIRSEQ2_API video_decoder {
+public:
+    explicit
+    video_decoder(video_decoder_options opts = {}, bool pin_memory = false);
+
+    // Expects `d` to hold a non-empty memory block.
+    data
+    operator()(data &&d) const;
+
+private:
+    // `operator()` builds its own `ffmpeg_decoder` per call, so none is kept here.
+    video_decoder_options opts_;
+};
+
+}  // namespace fairseq2n
diff --git a/src/fairseq2/data/video.py b/src/fairseq2/data/video.py
new file mode 100644
index 000000000..cb7d1bbf4
--- /dev/null
+++ b/src/fairseq2/data/video.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import TYPE_CHECKING, Dict, Optional, TypedDict
+
+from torch import Tensor
+
+from fairseq2 import _DOC_MODE
+from fairseq2.memory import MemoryBlock
+from fairseq2.typing import DataType, Device
+
+if TYPE_CHECKING or _DOC_MODE:
+
+    class VideoDecoder:
+        """Decodes the video streams of an encoded media container."""
+
+        def __init__(
+            self,
+            dtype: Optional[DataType] = None,
+            device: Optional[Device] = None,
+            pin_memory: bool = False,
+            get_pts_only: Optional[bool] = False,
+            get_frames_only: Optional[bool] = False,
+            width: int = 0,
+            height: int = 0,
+        ) -> None:
+            """
+            :param width:
+                The output frame width; ``0`` keeps the width of the source.
+            :param height:
+                The output frame height; ``0`` keeps the height of the source.
+            """
+
+        def __call__(self, memory_block: MemoryBlock) -> "VideoDecoderOutput":
+            ...
+
+else:
+    from fairseq2n.bindings.data.video import VideoDecoder as VideoDecoder
+
+    def _set_module_name() -> None:
+        for t in [VideoDecoder]:
+            t.__module__ = __name__
+
+    _set_module_name()
+
+
+class VideoDecoderOutput(TypedDict):
+    # Maps the index of each stream (as a string) to the dictionary of tensors
+    # decoded from it; streams that are not video map to an empty dictionary.
+    video: Dict[str, Dict[str, Tensor]]
diff --git a/tests/unit/data/video/__init__.py b/tests/unit/data/video/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/data/video/test.mp4 b/tests/unit/data/video/test.mp4
new file mode 100644
index 000000000..76e8176d4
Binary files /dev/null and b/tests/unit/data/video/test.mp4 differ
diff --git a/tests/unit/data/video/test2.MP4 b/tests/unit/data/video/test2.MP4
new file mode 100644
index 000000000..9c78cf812
Binary files /dev/null and b/tests/unit/data/video/test2.MP4 differ
diff --git a/tests/unit/data/video/test_video_decoder.py b/tests/unit/data/video/test_video_decoder.py
new file mode 100644
index 000000000..e46f056a5
--- /dev/null
+++ b/tests/unit/data/video/test_video_decoder.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+from typing import Any, Final
+
+import pytest
+import torch
+
+from fairseq2.data.video import VideoDecoder
+from fairseq2.memory import MemoryBlock
+from fairseq2.typing import DataType
+from tests.common import assert_close, device
+
+# Fixture files that live next to this module.
+TEST_MP4_PATH: Final = Path(__file__).parent.joinpath("test.mp4")
+TEST_UNSUPPORTED_CODEC_PATH: Final = Path(__file__).parent.joinpath("test2.MP4")
+
+
+# Unit tests for the `VideoDecoder` binding: constructor validation, a full
+# decode of `test.mp4`, and the error paths of `__call__`.
+class TestVideoDecoder:
+    # Data types other than uint8 and int16 must be rejected at construction.
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.int64])
+    def test_init_raises_error_when_data_type_is_not_supported(
+        self, dtype: DataType
+    ) -> None:
+        with pytest.raises(
+            ValueError,
+            match=r"^`video_decoder` supports only `torch.int16` and `torch.uint8` data types\.$",
+        ):
+            VideoDecoder(dtype=dtype)
+
+    # Decodes stream "0" of the fixture and pins its shape, dtype, checksum,
+    # and metadata tensors.
+    def test_call_works(self) -> None:
+        decoder = VideoDecoder(device=device)
+
+        with TEST_MP4_PATH.open("rb") as fb:
+            block = MemoryBlock(fb.read())
+
+        output = decoder(block)
+
+        video = output["video"]["0"]["all_video_frames"]
+
+        assert video.shape == torch.Size([24, 2160, 3840, 3])
+
+        assert video.dtype == torch.uint8
+
+        # NOTE(review): the native code allocates its tensors on the CPU, so
+        # this presumably holds only when `device` is the CPU -- confirm.
+        assert video.device == device
+
+        assert_close(video[0][0][0][0], torch.tensor(80, dtype=torch.uint8))
+
+        assert_close(video.sum(), torch.tensor(3816444602, device=device))
+
+        assert_close(output["video"]["0"]["frame_pts"].sum(), torch.tensor(11511489))
+
+        assert_close(output["video"]["0"]["timebase"], torch.tensor([1, 24000], dtype=torch.int32))
+
+        assert_close(output["video"]["0"]["fps"][0], torch.tensor(23.9760))
+
+        assert_close(output["video"]["0"]["duration"][0], torch.tensor(1024000))
+
+    # The second fixture is expected to fail the decoder lookup for stream 2.
+    def test_call_raises_error_when_codec_is_not_supported(self) -> None:
+        decoder = VideoDecoder()
+
+        with TEST_UNSUPPORTED_CODEC_PATH.open("rb") as fb:
+            block = MemoryBlock(fb.read())
+
+        with pytest.raises(
+            RuntimeError,
+            match=r"^Failed to find decoder for stream 2",
+        ):
+            decoder(block)
+
+    # Anything that is not a memory block must be rejected, naming its type.
+    @pytest.mark.parametrize(
+        "value,type_name", [(None, "pyobj"), (123, "int"), ("s", "string")]
+    )
+    def test_call_raises_error_when_input_is_not_memory_block(
+        self, value: Any, type_name: str
+    ) -> None:
+        decoder = VideoDecoder()
+
+        with pytest.raises(
+            ValueError,
+            match=rf"^The input data must be of type `memory_block`, but is of type `{type_name}` instead\.$",
+        ):
+            decoder(value)
+
+    # An empty memory block is rejected with a dedicated message.
+    def test_call_raises_error_when_input_is_empty(self) -> None:
+        decoder = VideoDecoder()
+
+        empty_block = MemoryBlock()
+
+        with pytest.raises(
+            ValueError,
+            match=r"^The input memory block has zero length and cannot be decoded\.$",
+        ):
+            decoder(empty_block)
+
+    # Bytes that are not a media container must fail while opening the input.
+    def test_call_raises_error_when_input_is_invalid(self) -> None:
+        decoder = VideoDecoder()
+
+        block = MemoryBlock(b"foo")
+
+        with pytest.raises(
+            ValueError,
+            match=r"^Failed to open input\.$",
+        ):
+            decoder(block)