Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] Support print runtime error log for xdnn error #71431

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 8 additions & 48 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1512,19 +1512,9 @@ void HeterComm<KeyType, ValType, GradType, GPUAccessor>::pull_merge_sparse(
auto xpu_context = xpu_dev_ctx.x_context();

// Initialize the left/right device boundary arrays to -1 before bucketing.
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
// Typo fix: the macro is PADDLE_ENFORCE_XDNN_SUCCESS (not ..._SUSSESS),
// matching its definition in paddle/phi/backends/xpu/enforce_xpu.h.
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_XDNN_SUCCESS(r2, "constant");
#endif

auto accessor_wrapper_ptr =
Expand Down Expand Up @@ -1692,19 +1682,9 @@ void HeterComm<KeyType, ValType, GradType, GPUAccessor>::pull_normal_sparse(
auto xpu_context = xpu_dev_ctx.x_context();

// Initialize the left/right device boundary arrays to -1 before bucketing.
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
// Typo fix: the macro is PADDLE_ENFORCE_XDNN_SUCCESS (not ..._SUSSESS),
// matching its definition in paddle/phi/backends/xpu/enforce_xpu.h.
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_XDNN_SUCCESS(r2, "constant");
#endif

auto d_idx = MemoryAlloc(place, len * sizeof(int));
Expand Down Expand Up @@ -1895,19 +1875,9 @@ void HeterComm<KeyType, ValType, GradType, GPUAccessor>::push_normal_sparse(
auto xpu_context = xpu_dev_ctx.x_context();

// Initialize the left/right device boundary arrays to -1 before bucketing.
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
// Typo fix: the macro is PADDLE_ENFORCE_XDNN_SUCCESS (not ..._SUSSESS),
// matching its definition in paddle/phi/backends/xpu/enforce_xpu.h.
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_XDNN_SUCCESS(r2, "constant");
#endif

auto d_idx = MemoryAlloc(place, len * sizeof(int));
Expand Down Expand Up @@ -2070,19 +2040,9 @@ void HeterComm<KeyType, ValType, GradType, GPUAccessor>::push_sparse(
auto xpu_context = xpu_dev_ctx.x_context();

// Initialize the left/right device boundary arrays to -1 before bucketing.
int r = xpu::constant<int>(xpu_context, d_left_ptr, total_device, -1);
// Typo fix: the macro is PADDLE_ENFORCE_XDNN_SUCCESS (not ..._SUSSESS),
// matching its definition in paddle/phi/backends/xpu/enforce_xpu.h.
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
int r2 = xpu::constant<int>(xpu_context, d_right_ptr, total_device, -1);
PADDLE_ENFORCE_XDNN_SUCCESS(r2, "constant");
#endif

auto d_idx = MemoryAlloc(place, len * sizeof(int));
Expand Down
67 changes: 44 additions & 23 deletions paddle/phi/backends/xpu/enforce_xpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/enforce.h"
#include "xre/cuda_runtime_api.h"

namespace phi {
namespace backends {
Expand Down Expand Up @@ -116,6 +117,7 @@ inline const char* bkclGetErrorString(BKCLResult_t stat) {
#endif

inline const char* xdnnGetErrorString(int stat) {
// Also reused by xfa and xpudnn apis.
switch (stat) {
case baidu::xpu::api::Error_t::SUCCESS:
return "XDNN_SUCCESS";
Expand All @@ -133,19 +135,31 @@ inline const char* xdnnGetErrorString(int stat) {
}

// Builds a human-readable XPU runtime error string of the form
// "XPU Error <code>, <message> ". The diff-merged duplicate body
// (old two-line version plus new three-line version, yielding two
// return statements) is collapsed to the single new implementation.
inline std::string build_xpu_error_msg(int stat) {
  std::string error_msg = "XPU Error <" + std::to_string(stat) + ">, " +
                          xpuGetErrorString(stat) + " ";
  return error_msg;
}

#ifdef PADDLE_WITH_XPU_BKCL
// Builds a human-readable BKCL (collective communication) error string of
// the form "BKCL Error <code>, <message> ".
inline std::string build_xpu_error_msg(BKCLResult_t stat) {
  // static_cast<int>: std::to_string has no enum overload, so this compiles
  // whether BKCLResult_t is an enum or an integer typedef.
  std::string error_msg = "BKCL Error <" +
                          std::to_string(static_cast<int>(stat)) + ">, " +
                          bkclGetErrorString(stat) + " ";
  return error_msg;
}
#endif

// Builds a human-readable XDNN error string: "<op msg> XDNN Error <code>,
// <message> ". Replaces the old build_xpu_xdnn_error_msg (the diff-merged
// duplicate definitions are collapsed to the new name). A separating space
// is kept before "XDNN" so the op name does not run into the error text.
inline std::string build_xdnn_error_msg(int stat, const std::string& msg) {
  std::string error_msg = msg + " XDNN Error <" + std::to_string(stat) +
                          ">, " + xdnnGetErrorString(stat) + " ";
  return error_msg;
}

// Builds a string describing the last device-runtime error, of the form
// "XPU Runtime Error <code>, <message> ". Used to augment XDNN
// RUNTIME_ERROR failures with the underlying runtime diagnostic.
// NOTE(review): relies on the CUDA-compatible runtime API exposed by
// xre/cuda_runtime_api.h — confirm that header is always shipped with the
// XPU toolchain this file builds against.
inline std::string build_runtime_error_msg() {
  auto rt_error_code = cudaGetLastError();
  // static_cast<int>: the error code is an enum type and std::to_string has
  // no enum overload.
  std::string error_msg = "XPU Runtime Error <" +
                          std::to_string(static_cast<int>(rt_error_code)) +
                          ">, " +
                          std::string(cudaGetErrorString(rt_error_code)) + " ";
  return error_msg;
}

namespace details {
Expand Down Expand Up @@ -183,25 +197,32 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS);
} \
} while (0)

/* Checks that an XDNN call returned SUCCESS; otherwise throws an External
 * error carrying the XDNN error message. When the failure is RUNTIME_ERROR,
 * the device runtime's own last-error message is appended so the root cause
 * is visible in the log. (The diff-merged duplicate #define is collapsed to
 * a single definition, and the two throw branches are deduplicated by
 * building the message first and throwing once.) */
#define PADDLE_ENFORCE_XDNN_SUCCESS(COND, MSG)                             \
  do {                                                                     \
    auto __cond__ = (COND);                                                \
    if (UNLIKELY(__cond__ != baidu::xpu::api::Error_t::SUCCESS)) {         \
      std::string __xdnn_msg__ =                                           \
          ::phi::backends::xpu::build_xdnn_error_msg(__cond__, MSG);       \
      if (__cond__ == baidu::xpu::api::Error_t::RUNTIME_ERROR) {           \
        __xdnn_msg__ +=                                                    \
            "\n" + ::phi::backends::xpu::build_runtime_error_msg();        \
      }                                                                    \
      auto __summary__ = common::errors::External(__xdnn_msg__);           \
      __THROW_ERROR_INTERNAL__(__summary__);                               \
    }                                                                      \
  } while (0)

/* Checks that an XDNN workspace/scratch allocation succeeded; a null pointer
 * is reported as NO_ENOUGH_WORKSPACE with an out-of-memory message. (The
 * diff-merged duplicate #define is collapsed to a single definition using
 * the renamed build_xdnn_error_msg helper.) */
#define PADDLE_ENFORCE_XDNN_NOT_NULL(ptr)                                   \
  do {                                                                      \
    if (UNLIKELY(ptr == nullptr)) {                                         \
      auto __summary__ = common::errors::External(                          \
          ::phi::backends::xpu::build_xdnn_error_msg(                       \
              baidu::xpu::api::Error_t::NO_ENOUGH_WORKSPACE,                \
              "XPU memory is not enough"));                                 \
      __THROW_ERROR_INTERNAL__(__summary__);                                \
    }                                                                       \
  } while (0)
#define PADDLE_ENFORCE_XRE_SUCCESS(COND) \
do { \
Expand Down
44 changes: 38 additions & 6 deletions paddle/phi/backends/xpu/xpu_header.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ limitations under the License. */

namespace xpu = baidu::xpu::api;

static std::map<int, std::string> XPUAPIErrorMsg = {
{xpu::Error_t::SUCCESS, "xpu api success"},
{xpu::Error_t::INVALID_PARAM, "xpu api invalid param"},
{xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"},
{xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}};

template <typename T>
class XPUTypeTrait {
public:
Expand Down Expand Up @@ -75,4 +69,42 @@ class XPUTypeToPhiType<bfloat16> {
using Type = phi::dtype::bfloat16;
};

// XPUCopyTypeTrait is the same as XPUTypeTrait except for double, int16_t, and
// uint8_t. Used for ops that simply copy data and do not need to calculate
// on the values: each extra specialization maps to a type of the SAME byte
// width, so reinterpreting the buffer through the mapped type is a lossless
// bitwise copy.
template <typename T>
class XPUCopyTypeTrait {
public:
// Default: copy through the type itself.
using Type = T;
};

template <>
class XPUCopyTypeTrait<phi::dtype::float16> {
public:
// Same mapping as XPUTypeTrait: phi float16 -> device float16 (2 bytes).
using Type = float16;
};

template <>
class XPUCopyTypeTrait<phi::dtype::bfloat16> {
public:
// Same mapping as XPUTypeTrait: phi bfloat16 -> device bfloat16 (2 bytes).
using Type = bfloat16;
};

template <>
class XPUCopyTypeTrait<double> {
public:
// double has no XPU compute type; copy as int64_t (both 8 bytes).
using Type = int64_t;
};

template <>
class XPUCopyTypeTrait<int16_t> {
public:
// int16_t copied as float16 (both 2 bytes); bits pass through unchanged.
using Type = float16;
};

template <>
class XPUCopyTypeTrait<uint8_t> {
public:
// uint8_t copied as int8_t (both 1 byte); bits pass through unchanged.
using Type = int8_t;
};

#endif
10 changes: 5 additions & 5 deletions paddle/phi/kernels/xpu/addmm_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ void AddmmKernel(const Context& dev_ctx,
broadcast_flag = true;
input_2d_ptr = RAII_GUARD.alloc_l3_or_gm<XPUType>(x_dims[0] * y_dims[1]);
PADDLE_ENFORCE_XDNN_NOT_NULL(input_2d_ptr);
int r = xpu::broadcast<XPUType>(dev_ctx.x_context(),
input_ptr,
input_2d_ptr,
common::vectorize<int64_t>(input_dims),
{x_dims[0], y_dims[1]});
r = xpu::broadcast<XPUType>(dev_ctx.x_context(),
input_ptr,
input_2d_ptr,
common::vectorize<int64_t>(input_dims),
{x_dims[0], y_dims[1]});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
}

Expand Down
27 changes: 4 additions & 23 deletions paddle/phi/kernels/xpu/affine_channel_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,40 +77,21 @@ void AffineChannelGradXPUKernel(const Context& dev_ctx,
int r = 0;
if (dscale_d && dbias_d) {
r = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_d, dbias_d, x_shape, rdims);
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
common::errors::External(
"The reduce_sum XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
T* tmp = RAII_GUARD.alloc_l3_or_gm<T>(dy->numel());
PADDLE_ENFORCE_NOT_NULL(
tmp, common::errors::External("XPU has no enough memory"));

r = xpu::mul<T>(dev_ctx.x_context(), dy_d, x->data<T>(), tmp, dy->numel());
PADDLE_ENFORCE_EQ(
r,
xpu::Error_t::SUCCESS,
common::errors::External(
"The mul XPU OP return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
r = xpu::reduce_sum<T>(dev_ctx.x_context(), tmp, dscale_d, x_shape, rdims);
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
common::errors::External(
"The reduce_sum XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
}
if (dx_d) {
r = xpu::broadcast_mul(
dev_ctx.x_context(), dy_d, scale_d, dx_d, x_shape, b_shape);
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
common::errors::External(
"The broadcast_mul XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
}
}
} // namespace phi
Expand Down
14 changes: 2 additions & 12 deletions paddle/phi/kernels/xpu/affine_channel_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,10 @@ void AffineChannelXPUKernel(const Context& dev_ctx,
int r = 0;
r = xpu::broadcast_mul(
dev_ctx.x_context(), x_d, scale_d, y_d, x_shape, b_shape);
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
common::errors::External(
"The broadcast_mul XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
r = xpu::broadcast_add(
dev_ctx.x_context(), y_d, bias_d, y_d, x_shape, b_shape);
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
common::errors::External(
"The broadcast_add XPU OP return wrong value[%d %s]",
r,
XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}

} // namespace phi
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/xpu/amp_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void UpdateLossScalingKernel(const Context& dev_ctx,
DenseTensor* out_good_steps,
DenseTensor* out_bad_steps) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
using XPUTyp = typename XPUTypeTrait<T>::Type;
using XPUType = typename XPUTypeTrait<T>::Type;

PADDLE_ENFORCE_EQ(found_infinite.numel(),
1,
Expand All @@ -72,9 +72,9 @@ void UpdateLossScalingKernel(const Context& dev_ctx,
VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --";
int r = 0;
r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUTyp*>(out_data),
reinterpret_cast<XPUType*>(out_data),
num,
XPUTyp(0.0));
XPUType(0.0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
}
Expand Down
Loading