Skip to content

Commit

Permalink
fix_xpu_error_for_dynamic_shape_send_recv
Browse files Browse the repository at this point in the history
  • Loading branch information
AndSonder committed Mar 5, 2025
1 parent 4cd5d5a commit 751b78b
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions paddle/phi/kernels/funcs/send_recv_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_array.h"

#include "glog/logging.h"

namespace phi {

#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
Expand All @@ -31,11 +33,20 @@ void send_shape_info(const Context& dev_ctx,
CommContext* comm_ctx,
int peer,
StreamType stream) {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
PADDLE_ENFORCE_EQ((stream != nullptr && comm_ctx != nullptr),
true,
errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
"to send the shape info."));
#elif defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_EQ((comm_ctx != nullptr),
true,
errors::InvalidArgument(
"Stream should be provided if use BKCL "
"to send the shape info."));
#endif
paddle::DataType shape_dtype = paddle::DataType::INT32;
auto dims = x.dims();
int shape_size = dims.size();
Expand Down Expand Up @@ -93,11 +104,20 @@ DDim recv_shape_info(const Context& dev_ctx,
CommContext* comm_ctx,
int peer) {
StreamType stream = dev_ctx.stream();
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
PADDLE_ENFORCE_EQ((stream != nullptr && comm_ctx != nullptr),
true,
errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
"to send the shape info."));
#elif defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_EQ((comm_ctx != nullptr),
true,
errors::InvalidArgument(
"Stream should be provided if use BKCL "
"to send the shape info."));
#endif
paddle::DataType shape_dtype = paddle::DataType::INT32;

// phi::DenseTensor shape_size_tensortensor(shape_dtype);
Expand Down

0 comments on commit 751b78b

Please sign in to comment.