From 230c6ce115fbf9d20c71789cc6817f77b26dfba1 Mon Sep 17 00:00:00 2001
From: yangguohao <70266361+yangguohao@users.noreply.github.com>
Date: Thu, 3 Aug 2023 19:44:09 +0800
Subject: [PATCH] FLUID: move limit_by_capacity to PHI (#55948)

---
 .../fluid/operators/limit_by_capacity_op.cc   | 15 ++--
 .../fluid/operators/limit_by_capacity_op.cu   | 85 -------------------
 paddle/fluid/operators/limit_by_capacity_op.h | 37 --------
 .../kernels/cpu/limit_by_capacity_kernel.cc   | 41 +++++++++
 .../kernels/gpu/limit_by_capacity_kernel.cu   | 67 +++++++++++++++
 paddle/phi/kernels/limit_by_capacity_kernel.h | 28 ++++++
 .../phi/ops/compat/limit_by_capacity_sig.cc   | 28 ++++++
 7 files changed, 171 insertions(+), 130 deletions(-)
 delete mode 100644 paddle/fluid/operators/limit_by_capacity_op.cu
 delete mode 100644 paddle/fluid/operators/limit_by_capacity_op.h
 create mode 100644 paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/limit_by_capacity_kernel.cu
 create mode 100644 paddle/phi/kernels/limit_by_capacity_kernel.h
 create mode 100644 paddle/phi/ops/compat/limit_by_capacity_sig.cc

diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc
index e4ce30d41ae63..26f88c305d7b4 100644
--- a/paddle/fluid/operators/limit_by_capacity_op.cc
+++ b/paddle/fluid/operators/limit_by_capacity_op.cc
@@ -12,7 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/limit_by_capacity_op.h"
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -80,10 +86,3 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity,
                              ops::LimitByCapacityOp,
                              ops::LimitByCapacityOpMaker);
-
-PD_REGISTER_STRUCT_KERNEL(limit_by_capacity,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::LimitByCapacityOpCPUKernel,
-                          int,
-                          int64_t) {}
diff --git a/paddle/fluid/operators/limit_by_capacity_op.cu b/paddle/fluid/operators/limit_by_capacity_op.cu
deleted file mode 100644
index 4ddc921144843..0000000000000
--- a/paddle/fluid/operators/limit_by_capacity_op.cu
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-//
-// The file has been adapted from the two files:
-// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cu
-// https://github.com/laekov/fastmoe/blob/master/cuda/balancing.cuh
-// Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4
-// We retain the following license from the original files:
-// Copyright 2021, Jiaao He. All rights reserved.
-// Licensed under the Apache License, Version 2.0 (the "License").
- -#include "paddle/fluid/operators/limit_by_capacity_op.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" - -namespace paddle { -namespace operators { - -template -__global__ void limit_by_capacity_impl( - const T* expc, T* cap, T* out, const int n_expert, const int n_worker) { - int eid, wid; - CUDA_KERNEL_LOOP(i, (n_expert * n_worker)) { - wid = i / n_expert; - eid = i % n_expert; - auto proposal = expc[wid * n_expert + eid]; - auto cap_left = phi::CudaAtomicAdd(cap + eid, proposal * (-1)); - if (cap_left >= proposal) { - out[wid * n_expert + eid] = proposal; - } else if (cap_left >= 0) { - out[wid * n_expert + eid] = cap_left; - } else { - out[wid * n_expert + eid] = 0; - } - } -} - -template -class LimitByCapacityOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto expert_count = context.Input("expert_count"); - auto capacity = context.Input("capacity"); - auto n_worker = context.Attr("n_worker"); - auto out = context.Output("Out"); - - auto n_expert = expert_count->numel() / n_worker; - const auto place = context.GetPlace(); - const auto& dev_ctx = context.template device_context(); - - dim3 grid_dim(256); - dim3 block_dim(1024); - auto out_data = out->mutable_data(place); - const T* ec_data = expert_count->data(); - - phi::DenseTensor capacity_copy; - framework::TensorCopy(*capacity, place, dev_ctx, &capacity_copy); - T* cap_data = capacity_copy.mutable_data(place); - - limit_by_capacity_impl<<>>( - ec_data, cap_data, out_data, n_expert, n_worker); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -PD_REGISTER_STRUCT_KERNEL(limit_by_capacity, - GPU, - ALL_LAYOUT, - ops::LimitByCapacityOpCUDAKernel, - int64_t) {} diff --git a/paddle/fluid/operators/limit_by_capacity_op.h b/paddle/fluid/operators/limit_by_capacity_op.h deleted file mode 100644 index c08183b5f1a67..0000000000000 --- a/paddle/fluid/operators/limit_by_capacity_op.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#pragma once
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-#if defined(PADDLE_WITH_GLOO)
-#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
-#endif
-
-namespace paddle {
-namespace operators {
-
-template <typename T, typename DeviceContext>
-class LimitByCapacityOpCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Do not support limit by capacity op for cpu kernel now."));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc b/paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc
new file mode 100644
index 0000000000000..ea2f6cbc6ee82
--- /dev/null
+++ b/paddle/phi/kernels/cpu/limit_by_capacity_kernel.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/limit_by_capacity_kernel.h"
+#include "paddle/phi/core/errors.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include "paddle/phi/core/distributed/gloo_comm_context.h"
+#endif
+namespace phi {
+
+template <typename T, typename Context>
+void LimitByCapacityKernel(const Context& dev_ctx,
+                           const DenseTensor& expert_count,
+                           const DenseTensor& capacity,
+                           int n_worker,
+                           DenseTensor* Out) {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("limit_by_capacity is not supported on CPU."));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(limit_by_capacity,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::LimitByCapacityKernel,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/limit_by_capacity_kernel.cu b/paddle/phi/kernels/gpu/limit_by_capacity_kernel.cu
new file mode 100644
index 0000000000000..82d2cc47b5fad
--- /dev/null
+++ b/paddle/phi/kernels/gpu/limit_by_capacity_kernel.cu
@@ -0,0 +1,67 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/limit_by_capacity_kernel.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { + +template +__global__ void limit_by_capacity_impl( + const T* expc, T* cap, T* out, const int n_expert, const int n_worker) { + int eid, wid; + CUDA_KERNEL_LOOP(i, (n_expert * n_worker)) { + wid = i / n_expert; + eid = i % n_expert; + auto proposal = expc[wid * n_expert + eid]; + auto cap_left = phi::CudaAtomicAdd(cap + eid, proposal * (-1)); + if (cap_left >= proposal) { + out[wid * n_expert + eid] = proposal; + } else if (cap_left >= 0) { + out[wid * n_expert + eid] = cap_left; + } else { + out[wid * n_expert + eid] = 0; + } + } +} + +template +void LimitByCapacityKernel(const Context& dev_ctx, + const DenseTensor& expert_count, + const DenseTensor& capacity, + int n_worker, + DenseTensor* Out) { + auto expert_count_ptr = &expert_count; + auto n_expert = expert_count_ptr->numel() / n_worker; + + dim3 grid_dim(256); + dim3 block_dim(1024); + auto out_data = dev_ctx.template Alloc(Out); + const T* ec_data = expert_count_ptr->data(); + + phi::DenseTensor capacity_copy; + phi::Copy(dev_ctx, capacity, dev_ctx.GetPlace(), false, &capacity_copy); + T* cap_data = dev_ctx.template Alloc(&capacity_copy); + + limit_by_capacity_impl<<>>( + ec_data, cap_data, out_data, n_expert, n_worker); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + limit_by_capacity, GPU, ALL_LAYOUT, phi::LimitByCapacityKernel, int64_t) {} diff --git a/paddle/phi/kernels/limit_by_capacity_kernel.h b/paddle/phi/kernels/limit_by_capacity_kernel.h new file mode 100644 index 0000000000000..b01e61832ebfe --- /dev/null +++ b/paddle/phi/kernels/limit_by_capacity_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LimitByCapacityKernel(const Context& dev_ctx, + const DenseTensor& expert_count, + const DenseTensor& capacity, + int n_worker, + DenseTensor* Out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/limit_by_capacity_sig.cc b/paddle/phi/ops/compat/limit_by_capacity_sig.cc new file mode 100644 index 0000000000000..939d9aad4dee7 --- /dev/null +++ b/paddle/phi/ops/compat/limit_by_capacity_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature LimitByCapacityOpArgumentMapping(
+    const ArgumentMappingContext& ctx UNUSED) {
+  return KernelSignature(
+      "limit_by_capacity", {"expert_count", "capacity"}, {"n_worker"}, {"Out"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(limit_by_capacity,
+                           phi::LimitByCapacityOpArgumentMapping);
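
For reference, the rule implemented by limit_by_capacity_impl in the kernels above can be restated as a short sequential C++ sketch: each worker's per-expert proposal is granted up to that expert's remaining capacity, and the capacity counter is decremented by the full proposal either way. This is illustrative only and not part of the patch; LimitByCapacityRef and its signature are hypothetical, and it consumes capacity in worker order, whereas the GPU kernel consumes it in whatever order the atomic updates land.

// Illustrative reference (not part of the patch): sequential version of the
// capacity-limiting logic in limit_by_capacity_impl. Names are hypothetical.
#include <cstdint>
#include <vector>

std::vector<std::int64_t> LimitByCapacityRef(
    const std::vector<std::int64_t>& expert_count,  // n_worker * n_expert proposals
    std::vector<std::int64_t> capacity,             // n_expert counters, consumed in place
    int n_expert,
    int n_worker) {
  std::vector<std::int64_t> out(expert_count.size(), 0);
  for (int wid = 0; wid < n_worker; ++wid) {
    for (int eid = 0; eid < n_expert; ++eid) {
      const std::int64_t proposal = expert_count[wid * n_expert + eid];
      const std::int64_t cap_left = capacity[eid];  // capacity before this proposal
      capacity[eid] -= proposal;                    // mirrors CudaAtomicAdd(cap + eid, -proposal)
      if (cap_left >= proposal) {
        out[wid * n_expert + eid] = proposal;       // fully granted
      } else if (cap_left >= 0) {
        out[wid * n_expert + eid] = cap_left;       // partially granted
      } else {
        out[wid * n_expert + eid] = 0;              // expert already full
      }
    }
  }
  return out;
}

For example, with n_worker = 2, n_expert = 1, capacity {3} and proposals {2, 2}, the first worker keeps 2 and the second is cut to 1, matching the GPU kernel's intent up to the ordering of the atomic updates.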