Skip to content

Commit

Permalink
Implement different padding strategies
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Feb 10, 2025
1 parent 6a11c6c commit c1137a9
Show file tree
Hide file tree
Showing 7 changed files with 417 additions and 227 deletions.
325 changes: 205 additions & 120 deletions dali/operators/decoder/video/video_decoder_base.h

Large diffs are not rendered by default.

61 changes: 49 additions & 12 deletions dali/operators/decoder/video/video_decoder_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,63 @@ Each output sample is a sequence of frames with shape (F, H, W, C) where:
- F is the number of frames in the sequence (can vary between samples)
- H is the frame height in pixels
- W is the frame width in pixels
- C is the number of color channels)code")
- C is the number of color channels
.. code-block:: python
video_decoder = dali.experimental.decoders.Video(
frames=[0, 10, 20, 30, 40, 50, 40, 30, 20, 10, 0]
)
Example 2: Extract a sequence of frames:
.. code-block:: python
video_decoder = dali.experimental.decoders.Video(
start_frame=0, sequence_length=10, stride=2
)
Example 3: Pad the sequence by repeating the last frame:
.. code-block:: python
video_decoder = dali.experimental.decoders.Video(
start_frame=0, sequence_length=100, stride=2, pad_mode="edge"
)
)code")
.NumInput(1)
.NumOutput(1)
.InputDox(0, "buffer", "TensorList", "Memory buffer containing the encoded video file data")
.InputDox(0, "encoded", "TensorList", "Encoded video stream")
.AddOptionalArg("affine",
R"code(Whether to pin threads to CPU cores (mixed backend only).
If True, each thread in the internal thread pool will be pinned to a specific CPU core.
If False, threads can migrate between cores based on OS scheduling.)code", true)
.AddOptionalArg<int64_t>("start_frame",
R"code(Index of the first frame to extract from each video)code", nullptr, true)
.AddOptionalArg<int64_t>("stride",
R"code(Number of frames to skip between each extracted frame)code", nullptr, true)
.AddOptionalArg<int64_t>("sequence_length",
R"code(Number of frames to extract from each video. If not provided, the whole video is decoded.)code", nullptr, true)
.AddOptionalArg<std::vector<int>>("frames",
R"code(Specifies which frames to extract from each video by their indices.
The indices can be provided in any order and can include duplicates. For example, [0,10,5,10] would extract:
- Frame 0 (first frame)
- Frame 10
- Frame 5
- Frame 10 (again)
This argument cannot be used together with ``start_frame``, ``sequence_length``, ``stride`` or ``pad_mode``.)code", nullptr, true)
.AddOptionalArg<int>("start_frame",
R"code(Index of the first frame to extract from each video. Cannot be used together with frames argument.)code", nullptr, true)
.AddOptionalArg<int>("stride",
R"code(Number of frames to skip between each extracted frame. Cannot be used together with ``frames`` argument.)code", nullptr, true)
.AddOptionalArg<int>("sequence_length",
R"code(Number of frames to extract from each video. If not provided, the whole video is decoded. Cannot be used together with ``frames`` argument.)code", nullptr, true)
.AddOptionalArg<std::string>("pad_mode",
R"code(How to handle videos with insufficient frames:
- none: Return shorter sequences if not enough frames
- constant: Pad with a fixed value (specified by ``pad_value``)
- edge: Repeat the last valid frame)code", "constant", true)
R"code(How to handle videos with insufficient frames when using start_frame/sequence_length/stride:
- 'none': Return shorter sequences if not enough frames: ABC -> ABC
- 'constant': Pad with a fixed value (specified by ``pad_value``): ABC -> ABCPPP
- 'edge': Repeat the last valid frame: ABC -> ABCCCC
- 'reflect_1001' or 'symmetric': Reflect padding, including the last element: ABC -> ABCCBA
- 'reflect_101' or 'reflect': Reflect padding, not including the last element: ABC -> ABCBA
Not relevant when using frames argument.)code", "constant", true)
.AddOptionalArg<int>("pad_value",
R"code(Value used to pad missing frames when pad_mode='constant'. Must be in range [0, 255].)code", 0);

Expand Down
56 changes: 23 additions & 33 deletions dali/operators/reader/loader/video/frames_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ FramesDecoder::FramesDecoder(const std::string &filename)
}
InitAvState();
BuildIndex();
DetectVfr();
is_valid_ = true;
}

Expand Down Expand Up @@ -260,13 +259,9 @@ FramesDecoder::FramesDecoder(const char *memory_file, int memory_file_size, bool
ParseNumFrames();
}

if (!build_index) {
is_valid_ = true;
return;
if (build_index) {
BuildIndex();
}

BuildIndex();
DetectVfr();
is_valid_ = true;
}

Expand Down Expand Up @@ -368,14 +363,17 @@ void FramesDecoder::BuildIndex() {
index_ = vector<IndexEntry>();

int last_keyframe = -1;
while (av_read_frame(av_state_->ctx_, av_state_->packet_) >= 0) {
IndexEntry entry;
// We want to make sure that we call av_packet_unref in every iteration
auto packet = AVPacketScope(av_state_->packet_, av_packet_unref);

while (true) {
int ret = av_read_frame(av_state_->ctx_, av_state_->packet_);
auto packet = AVPacketScope(av_state_->packet_, av_packet_unref);
if (ret != 0) {
break; // End of file
}
if (packet->stream_index != av_state_->stream_id_) {
continue;
}
IndexEntry entry;
if (packet->flags & AV_PKT_FLAG_KEY) {
if (index_->size() == 0) {
entry.is_keyframe = true;
Expand Down Expand Up @@ -414,31 +412,27 @@ void FramesDecoder::BuildIndex() {
entry.last_keyframe_id = last_keyframe;
index_->push_back(entry);
}
av_packet_unref(av_state_->packet_);
index_->back().is_flush_frame = true;
// Make sure that the index is sorted by pts as some frames may not be in the presentation
// order in the container
std::sort(index_->begin(), index_->end(), [](const IndexEntry &a, const IndexEntry &b) {
return a.pts < b.pts;
});
Reset();
}

void FramesDecoder::DetectVfr() {
if (NumFrames() < 3) {
is_vfr_ = false;
return;
}

int pts_step = Index(1).pts - Index(0).pts;
for (int frame_id = 2; frame_id < NumFrames(); ++frame_id) {
if ((Index(frame_id).pts - Index(frame_id - 1).pts) != pts_step) {
is_vfr_ = true;
return;
is_vfr_ = false;
auto& index = *index_;
if (NumFrames() > 3) {
int pts_step = index[1].pts - index[1].pts;
for (int frame_id = 2; frame_id < NumFrames(); ++frame_id) {
if ((index[frame_id].pts - index[frame_id - 1].pts) != pts_step) {
is_vfr_ = true;
break;
}
}
}

is_vfr_ = false;

Reset();
}

void FramesDecoder::CopyToOutput(uint8_t *data) {
Expand Down Expand Up @@ -555,6 +549,9 @@ void FramesDecoder::SeekFrame(int frame_id) {
frame_id >= 0 && frame_id < NumFrames(),
make_string("Invalid seek frame id. frame_id = ", frame_id, ", num_frames = ", NumFrames()));

if (!HasIndex()) {
DALI_FAIL("Functionality is unavailible when index is not built.");
}
auto &frame_entry = Index(frame_id);
int keyframe_id = frame_entry.last_keyframe_id;
auto &keyframe_entry = Index(keyframe_id);
Expand Down Expand Up @@ -621,11 +618,4 @@ bool FramesDecoder::ReadNextFrame(uint8_t *data, bool copy_to_output) {
return ReadFlushFrame(data, copy_to_output);
}

const IndexEntry &FramesDecoder::Index(int frame_id) const {
if (!index_.has_value()) {
DALI_FAIL("Functionality is unavailible when index is not built.");
}

return (*index_)[frame_id];
}
} // namespace dali
16 changes: 8 additions & 8 deletions dali/operators/reader/loader/video/frames_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ class DLL_PUBLIC FramesDecoder {
int NextFrameIdx() { return next_frame_idx_; }

/**
* @brief Returns true if the index was built.
*
* @return Boolean indicating whether or not the index was created.
*/
* @brief Returns true if the index was built.
*
* @return Boolean indicating whether or not the index was created.
*/
bool HasIndex() const { return index_.has_value(); }

FramesDecoder(FramesDecoder&&) = default;
Expand All @@ -214,13 +214,15 @@ class DLL_PUBLIC FramesDecoder {
return is_valid_;
}

const IndexEntry& Index(int frame_id) const {
return (*index_)[frame_id];
}

protected:
std::unique_ptr<AvState> av_state_;

std::optional<std::vector<IndexEntry>> index_ = {};

const IndexEntry &Index(int frame_id) const;

int next_frame_idx_ = 0;

bool is_full_range_ = false;
Expand Down Expand Up @@ -268,8 +270,6 @@ class DLL_PUBLIC FramesDecoder {

bool CheckCodecSupport();

void DetectVfr();

void ParseNumFrames();

void CreateAvState(std::unique_ptr<AvState> &av_state, bool init_codecs);
Expand Down
16 changes: 11 additions & 5 deletions dali/operators/reader/loader/video/frames_decoder_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,12 @@ int FramesDecoderGpu::HandlePictureDisplay(CUVIDPARSERDISPINFO *picture_display_
piped_pts_.pop();

// current_pts is pts of frame that came from the decoder
// NextFramePts() is pts of the frame that we want to return
// Index(NextFrameIdx()).pts is pts of the frame that we want to return
// in this call to ReadNextFrame
// If they are the same, we just return this frame
// If not, we store it in the buffer for later

if (HasIndex() && current_pts == NextFramePts()) {
if (HasIndex() && current_pts == Index(NextFrameIdx()).pts) {
// Currently decoded frame is actually the one we wanted
frame_returned_ = true;

Expand Down Expand Up @@ -530,6 +530,9 @@ void FramesDecoderGpu::SeekFrame(int frame_id) {

bool FramesDecoderGpu::ReadNextFrameWithIndex(uint8_t *data, bool copy_to_output) {
// Check if requested frame was buffered earlier
if (!HasIndex()) {
DALI_FAIL("Functionality is unavailible when index is not built.");
}
for (auto &frame : frame_buffer_) {
if (frame.pts_ != -1 && frame.pts_ == Index(next_frame_idx_).pts) {
if (copy_to_output) {
Expand All @@ -548,9 +551,13 @@ bool FramesDecoderGpu::ReadNextFrameWithIndex(uint8_t *data, bool copy_to_output
current_copy_to_output_ = copy_to_output;
current_frame_output_ = data;

while (av_read_frame(av_state_->ctx_, av_state_->packet_) >= 0) {
// We want to make sure that we call av_packet_unref in every iteration
while (true) {
int ret = av_read_frame(av_state_->ctx_, av_state_->packet_);
auto packet = AVPacketScope(av_state_->packet_, av_packet_unref);
if (ret != 0) {
break; // No more frames in the file
}

if (!SendFrameToParser()) {
continue;
}
Expand All @@ -560,7 +567,6 @@ bool FramesDecoderGpu::ReadNextFrameWithIndex(uint8_t *data, bool copy_to_output
return true;
}
}
av_packet_unref(av_state_->packet_);

DALI_ENFORCE(piped_pts_.size() == 1);

Expand Down
2 changes: 0 additions & 2 deletions dali/operators/reader/loader/video/frames_decoder_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ class DLL_PUBLIC FramesDecoderGpu : public FramesDecoder {

void Reset() override;

int NextFramePts() { return Index(NextFrameIdx()).pts; }

int ProcessPictureDecode(CUVIDPICPARAMS *picture_params);

int HandlePictureDisplay(CUVIDPARSERDISPINFO *picture_display_info);
Expand Down
Loading

0 comments on commit c1137a9

Please sign in to comment.