Skip to content

Commit

Permalink
Code review fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <[email protected]>
  • Loading branch information
jantonguirao committed Feb 11, 2025
1 parent 1755c75 commit 85caa8c
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions dali/operators/image/remap/warp_param_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#define DALI_OPERATORS_IMAGE_REMAP_WARP_PARAM_PROVIDER_H_

#include <cassert>
#include <vector>
#include <string>
#include <vector>
#include "dali/core/dev_buffer.h"
#include "dali/core/mm/memory.h"
#include "dali/core/static_switch.h"
Expand All @@ -42,7 +42,7 @@ class InterpTypeProvider {
auto &tensor_list = ws.ArgumentInput("interp_type");
int n = tensor_list.shape().num_samples();
DALI_ENFORCE(n == 1 || n == num_samples,
"interp_type must be a single value or contain one value per sample");
"interp_type must be a single value or contain one value per sample");
interp_types_.resize(n);
for (int i = 0; i < n; i++)
interp_types_[i] = tensor_list.tensor<DALIInterpType>(i)[0];
Expand All @@ -52,7 +52,7 @@ class InterpTypeProvider {

for (size_t i = 0; i < interp_types_.size(); i++) {
DALI_ENFORCE(interp_types_[i] == DALI_INTERP_NN || interp_types_[i] == DALI_INTERP_LINEAR,
"Only nearest and linear interpolation is supported");
"Only nearest and linear interpolation is supported");
}
}

Expand All @@ -66,6 +66,7 @@ class BorderTypeProvider {
// Returns the border value used for out-of-range source coordinates.
// NOTE(review): presumably populated by SetBorder(spec) — confirm in full file.
BorderType Border() const {
  return border_;
}

protected:
void SetBorder(const OpSpec &spec) {
float fborder;
Expand All @@ -79,8 +80,7 @@ class BorderTypeProvider {
};

// Specialization for BorderClamp: clamping repeats edge pixels, so there is
// no border value to read from the spec — intentionally a no-op.
// Fix: the scrape retained both the pre- and post-change definition of this
// specialization, producing a redefinition; keep a single definition.
template <>
inline void BorderTypeProvider<kernels::BorderClamp>::SetBorder(const OpSpec &spec) {}

/** @brief Provides warp parameters
*
Expand Down Expand Up @@ -216,22 +216,24 @@ class WarpParamProvider : public InterpTypeProvider, public BorderTypeProvider<B

// Releases both parameter buffers and zeroes all bookkeeping so that the
// next AllocParams call starts from a clean state.
virtual void ResetParams() {
  params_cpu_ = {};
  params_cpu_sz_ = 0;
  params_gpu_ = {};
  params_gpu_sz_ = 0;
  params_count_ = 0;
}

// Hook for subclasses to compute the mapping parameters; no-op by default.
// Fix: the scrape retained both the old brace-split and the new one-line
// definition, producing a redefinition; keep a single definition.
virtual void SetParams() {}

// Hook for subclasses to post-process previously set parameters; no-op by default.
// Fix: the scrape retained both the old brace-split and the new one-line
// definition, producing a redefinition; keep a single definition.
virtual void AdjustParams() {}

// Hook for subclasses to sanity-check the computed output sizes; no-op by default.
virtual void ValidateOutputSizes() {}

virtual void GetUniformOutputSize(SpatialShape &out_size) const {
assert(HasExplicitSize() && !HasExplicitPerSampleSize());
std::vector<float> out_size_f = spec_->template GetArgument<std::vector<float>>("size");
DALI_ENFORCE(static_cast<int>(out_size_f.size()) == spatial_ndim,
"output_size must specify same number of dimensions as the input (excluding channels)");
DALI_ENFORCE(
static_cast<int>(out_size_f.size()) == spatial_ndim,
"output_size must specify same number of dimensions as the input (excluding channels)");
for (int d = 0; d < spatial_ndim; d++) {
float s = out_size_f[d];
DALI_ENFORCE(s > 0, "Output size must be positive");
Expand All @@ -246,12 +248,11 @@ class WarpParamProvider : public InterpTypeProvider, public BorderTypeProvider<B
const int N = num_samples_;

DALI_ENFORCE(is_uniform(shape), "Output sizes must be passed as uniform Tensor List.");
DALI_ENFORCE(
(shape.num_samples() == N && shape[0] == TensorShape<>(spatial_ndim)) ||
(shape.num_samples() == 1 && (shape[0] == TensorShape<>(N, spatial_ndim) ||
shape[0] == TensorShape<>(N * spatial_ndim))),
"Output sizes must either be a batch of `dim`-sized tensors, flat array of size "
"num_samples*dim or one 2D tensor of shape {num_samples, dim}.");
DALI_ENFORCE((shape.num_samples() == N && shape[0] == TensorShape<>(spatial_ndim)) ||
(shape.num_samples() == 1 && (shape[0] == TensorShape<>(N, spatial_ndim) ||
shape[0] == TensorShape<>(N * spatial_ndim))),
"Output sizes must either be a batch of `dim`-sized tensors, flat array of size "
"num_samples*dim or one 2D tensor of shape {num_samples, dim}.");

out_sizes.resize(N);
if (shape.num_samples() == N) {
Expand All @@ -261,7 +262,7 @@ class WarpParamProvider : public InterpTypeProvider, public BorderTypeProvider<B
} else {
for (int i = 0; i < N; i++)
for (int d = 0; d < spatial_ndim; d++)
out_sizes[i][d] = shape_list.data[0][i*N + d];
out_sizes[i][d] = shape_list.data[0][i * N + d];
}
}

Expand Down Expand Up @@ -316,17 +317,23 @@ class WarpParamProvider : public InterpTypeProvider, public BorderTypeProvider<B
/** @brief Allocates num_samples MappingParams objects in memory specified by MemoryKind.
 *
 * Tag-dispatches to the memory-kind-specific overload below via a null
 * pointer of the requested kind.
 * Fix: the scrape retained both the pre- and post-change return statements,
 * leaving a duplicate (unreachable) return; keep a single one.
 */
template <typename MemoryKind>
MappingParams *AllocParams(int num_samples) {
  return AllocParams(static_cast<MemoryKind *>(nullptr), num_samples);
}

// Allocates (or reuses) the device-side parameter buffer for `count` entries.
// Grow-only cache: reallocates only when the request exceeds current capacity,
// so repeated calls with a smaller or equal count reuse the existing buffer.
// Fix: the scrape retained the stale pre-diff unconditional allocation line
// in front of the new capacity check, which double-allocated and defeated
// the cache; drop the stale line.
inline MappingParams *AllocParams(mm::memory_kind::device *, int count) {
  if (count > params_gpu_sz_) {
    params_gpu_ = mm::alloc_raw_unique<MappingParams, mm::memory_kind::device>(count);
    params_gpu_sz_ = count;
  }
  params_count_ = count;
  return params_gpu_.get();
}

// Allocates (or reuses) the host-side parameter buffer for `count` entries.
// Grow-only cache, mirroring the device overload above.
// Fix: the scrape retained the stale pre-diff unconditional allocation line
// in front of the new capacity check, which double-allocated and defeated
// the cache; drop the stale line.
inline MappingParams *AllocParams(mm::memory_kind::host *, int count) {
  if (count > params_cpu_sz_) {
    params_cpu_ = mm::alloc_raw_unique<MappingParams, mm::memory_kind::host>(count);
    params_cpu_sz_ = count;
  }
  params_count_ = count;
  return params_cpu_.get();
}
Expand All @@ -341,7 +348,9 @@ class WarpParamProvider : public InterpTypeProvider, public BorderTypeProvider<B
std::vector<SpatialShape> out_sizes_;  // per-sample output spatial shapes
mm::uptr<MappingParams> params_gpu_;   // device-side parameter buffer
mm::uptr<MappingParams> params_cpu_;   // host-side parameter buffer
int params_cpu_sz_ = 0;                // allocated capacity of params_cpu_, in elements
int params_gpu_sz_ = 0;                // allocated capacity of params_gpu_, in elements
// Number of valid parameter entries in the current buffer.
// Fix: the scrape retained both the old declaration (`= -1`) and the new
// one (`= 0`), a duplicate member; keep the `= 0` form, consistent with
// ResetParams() and the grow-only capacity counters above.
int params_count_ = 0;
};

} // namespace dali
Expand Down

0 comments on commit 85caa8c

Please sign in to comment.