[CINN] Update ir on gpu codegen (#70829)
* [CINN] Update ir on codegen

* update collect ir nodes

* refine code

* refine code

* update temp space
Dmovic authored Jan 22, 2025
1 parent feb941a commit 5c4f665
Showing 13 changed files with 488 additions and 111 deletions.
185 changes: 185 additions & 0 deletions paddle/cinn/backends/codegen_c.cc
@@ -270,6 +270,77 @@ void CodeGenC::Visit(const ir::For *op) {
str_ += "}";
}
}

void CodeGenC::VisitStmt(const ir::stmt::For &stmt) {
Expr extent = stmt->extent();
Expr min = stmt->min();
int num_task = 1;
if (stmt->is_parallel()) {
str_ += "int num_task = max_concurrency();\n";
DoIndent();
str_ += "omp_set_num_threads(num_task);\n";
DoIndent();
str_ += "auto flambda = [=](int task_id, int num_task) -> int {\n";
IncIndent();
DoIndent();
str_ += "int n_per_task = ";
Expr num_task_var = Var("num_task");
IrPrinter::Visit((stmt->extent() + num_task_var - 1) / num_task_var);
str_ += ";\n";
PADDLE_ENFORCE_EQ(min.as_int32(),
0,
::common::errors::InvalidArgument(
"The min of the for loop should be 0"));
auto task_id = Var("task_id");
auto n_per_task = Var("n_per_task");
min = task_id * n_per_task;
extent = (task_id + 1) * n_per_task;
DoIndent();
}
str_ += "for (";
str_ += GetTypeRepr(Int(32));
str_ += " ";
str_ += stmt->loop_var()->name;
str_ += " = ";
IrPrinter::Visit(min);
str_ += "; ";
str_ += stmt->loop_var()->name;
str_ += " < ";
IrPrinter::Visit(stmt->extent());
if (stmt->is_parallel()) {
str_ += " && ";
str_ += stmt->loop_var()->name;
str_ += " < ";
IrPrinter::Visit(extent);
}
str_ += "; ";

str_ += stmt->loop_var()->name;
str_ += " += 1";
str_ += ") ";

VisitBlock(stmt->body());
if (stmt->is_parallel()) {
str_ += "\n";
DoIndent();
str_ += "return 0;\n";
DecIndent();
DoIndent();
str_ += "};\n";
str_ += "#pragma omp parallel num_threads(num_task)\n";
DoIndent();
str_ += "{\n";
IncIndent();
DoIndent();
str_ += "int task_id = omp_get_thread_num();\n";
DoIndent();
str_ += "flambda(task_id, num_task);\n";
DecIndent();
DoIndent();
str_ += "}";
}
}
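
For a parallel loop, the new VisitStmt(ir::stmt::For) overload keeps the OpenMP strategy of the expression-based visitor above: split the iteration space into num_task chunks, wrap the body in a lambda, and invoke the lambda from an omp parallel region. The following is an illustration of roughly the C/C++ source this emits, not part of the diff; N, max_concurrency(), and the loop body are placeholders for the emitted pieces.

#include <cstdint>
#include <omp.h>

void emitted_parallel_loop(int32_t N) {
  int num_task = max_concurrency();
  omp_set_num_threads(num_task);
  auto flambda = [=](int task_id, int num_task) -> int {
    int n_per_task = (N + num_task - 1) / num_task;
    for (int32_t i = task_id * n_per_task;
         i < N && i < (task_id + 1) * n_per_task; i += 1) {
      // ... body printed by VisitBlock(stmt->body()) ...
    }
    return 0;
  };
#pragma omp parallel num_threads(num_task)
  {
    int task_id = omp_get_thread_num();
    flambda(task_id, num_task);
  }
}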

void CodeGenC::Visit(const ir::PolyFor *op) {
str_ += "for (";
str_ += GetTypeRepr(Int(32));
@@ -310,6 +381,20 @@ void CodeGenC::Visit(const ir::IfThenElse *op) {
IrPrinter::Visit(op->false_case);
}
}

void CodeGenC::VisitStmt(const ir::stmt::IfThenElse &stmt) {
str_ += "if (";
IrPrinter::Visit(stmt->condition());
str_ += ") ";

VisitBlock(stmt->true_case());

if (!stmt->false_case()->stmts().empty()) {
str_ += " else ";
VisitBlock(stmt->false_case());
}
}

void CodeGenC::Visit(const ir::Block *op) {
str_ += "{\n";

@@ -332,6 +417,29 @@ void CodeGenC::Visit(const ir::Block *op) {
DoIndent();
str_ += "}";
}
void CodeGenC::VisitBlock(const ir::stmt::BlockRef &stmt) {
str_ += "{\n";

IncIndent();

// Note: size_t (0 - 1) = 18446744073709551615
if (stmt->stmts().size() >= 1) {
for (int i = 0; i < stmt->stmts().size() - 1; i++) {
DoIndent();
IrPrinter::VisitStmt(stmt->stmts()[i]);
str_ += ";\n";
}
DoIndent();
IrPrinter::VisitStmt(stmt->stmts().back());
str_ += ";";
}

DecIndent();
str_ += "\n";
DoIndent();
str_ += "}";
}
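
The size() >= 1 guard above is what the note is about: stmts().size() is unsigned, so for an empty block size() - 1 wraps to SIZE_MAX instead of becoming -1. A minimal, self-contained sketch of the same guard, assuming only a std::vector-like container:

#include <cstddef>
#include <vector>

void VisitAllButLastThenLast(const std::vector<int>& stmts) {
  // Without this guard, `stmts.size() - 1` for an empty vector wraps to
  // 18446744073709551615 on 64-bit targets and the loop reads out of bounds.
  if (stmts.size() >= 1) {
    for (std::size_t i = 0; i < stmts.size() - 1; ++i) {
      // print stmts[i] followed by ";\n" ...
    }
    // print stmts.back() followed by ";" ...
  }
}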

void CodeGenC::Visit(const ir::Call *op) {
if (op->name == runtime::intrinsic::buffer_malloc) {
PrintCall_buffer_malloc(op);
@@ -527,6 +635,30 @@ void CodeGenC::Visit(const ir::Load *op) {
}
}

void CodeGenC::VisitStmt(const ir::stmt::Store &stmt) {
PADDLE_ENFORCE_EQ(
stmt->is_addr_tensor(),
true,
::common::errors::InvalidArgument(
"The operation type is invalid. It must be an address tensor."));
ir::Expr offset = [&] {
if (store_stmt_to_offset_.count(stmt) == 0) {
store_stmt_to_offset_[stmt] = stmt->index();
}
return store_stmt_to_offset_.at(stmt);
}();
auto *tensor = stmt->tensor().As<ir::_Tensor_>();
PADDLE_ENFORCE_NOT_NULL(tensor,
::common::errors::InvalidArgument(
"The tensor is null. It must not be null."));
str_ += tensor->name;
str_ += "[";
IrPrinter::Visit(offset);
str_ += "]";
str_ += " = ";
IrPrinter::Visit(stmt->value());
}

void CodeGenC::Visit(const ir::Store *op) {
PADDLE_ENFORCE_EQ(
op->is_addr_tensor(),
@@ -560,6 +692,16 @@ void CodeGenC::Visit(const ir::Alloc *op) {
str_ += ")";
}

void CodeGenC::VisitStmt(const ir::stmt::Alloc &stmt) {
str_ += runtime::intrinsic::buffer_malloc;
str_ += "(";
str_ += "(void*)(0), ";

auto *buffer = stmt->destination().As<ir::_Buffer_>();
str_ += buffer->name;
str_ += ")";
}

void CodeGenC::Visit(const ir::Free *op) {
str_ += runtime::intrinsic::buffer_free;
str_ += "(";
@@ -570,6 +712,16 @@ void CodeGenC::Visit(const ir::Free *op) {
str_ += ")";
}

void CodeGenC::VisitStmt(const ir::stmt::Free &stmt) {
str_ += runtime::intrinsic::buffer_free;
str_ += "(";
str_ += "(void*)(0), ";

auto *buffer = stmt->destination().As<ir::_Buffer_>();
str_ += buffer->name;
str_ += ")";
}
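
Both of the new Alloc/Free overloads emit the same runtime calls as their expression-based counterparts. For a buffer named _Buf, and assuming the intrinsics resolve to cinn_buffer_malloc / cinn_buffer_free, the printed text looks roughly as below (illustration only, not part of the diff; the trailing ';' is added by the enclosing block visitor):

cinn_buffer_malloc((void*)(0), _Buf)   // from ir::stmt::Alloc
cinn_buffer_free((void*)(0), _Buf)     // from ir::stmt::Free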

void CodeGenC::Visit(const ir::_Buffer_ *op) { str_ += op->name; }
void CodeGenC::Visit(const ir::_Tensor_ *op) { str_ += op->buffer->name; }
void CodeGenC::Visit(const ir::Let *op) {
@@ -602,6 +754,36 @@ void CodeGenC::Visit(const ir::Let *op) {
}
}

void CodeGenC::VisitStmt(const ir::stmt::Let &stmt) {
bool is_vec = false;
PADDLE_ENFORCE_EQ(stmt->type().valid(),
true,
::common::errors::InvalidArgument(
"The operation type is invalid. It must be valid."));
if (stmt->body().defined() && stmt->body().As<ir::Broadcast>()) {
// broadcast's type is hard to decide, so use c++11 auto instead.
str_ += "auto";
is_vec = true;
} else {
str_ += GetTypeRepr(stmt->type());
}

str_ += " ";
IrPrinter::Visit(stmt->symbol());

// native C array.
if (stmt->type().lanes() > 1 && !is_vec) {
str_ += "[";
str_ += std::to_string(stmt->type().lanes());
str_ += "]";
}

if (stmt->body().defined()) {
str_ += " = ";
IrPrinter::Visit(stmt->body());
}
}
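
For reference, the three branches of VisitStmt(ir::stmt::Let) print output along these lines; the types, names, and initializers are assumptions for illustration, and the trailing ';' comes from the enclosing block visitor:

float x = 0.0f                                 // scalar Let with a defined body
float v[4]                                     // 4-lane type whose body is not a Broadcast
auto b = /* printed Broadcast expression */    // Broadcast body: type left to auto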

void CodeGenC::Visit(const ir::Reduce *op) {
PADDLE_THROW(::common::errors::InvalidArgument(
"Reduce IR is just for internal representation, should not be "
@@ -828,6 +1010,9 @@ void CodeGenC::PrintStackVecType(Type type, int lanes) {
void CodeGenC::Visit(const ir::PrimitiveNode *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::_BufferRange_ *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::ScheduleBlock *op) { CINN_NOT_IMPLEMENTED }
void CodeGenC::VisitStmt(const ir::stmt::Schedule &stmt) {
CINN_NOT_IMPLEMENTED
}
void CodeGenC::Visit(const ir::ScheduleBlockRealize *op) {
CINN_NOT_IMPLEMENTED
}
17 changes: 17 additions & 0 deletions paddle/cinn/backends/codegen_c.h
@@ -105,6 +105,15 @@ class CodeGenC : public ir::IrPrinter {
void Visit(const ir::_Module_* op) override;
void Visit(const ir::_LoweredFunc_* op) override;

void VisitBlock(const ir::stmt::BlockRef& block) override;
void VisitStmt(const ir::stmt::Let& stmt) override;
void VisitStmt(const ir::stmt::Store& stmt) override;
void VisitStmt(const ir::stmt::Alloc& stmt) override;
void VisitStmt(const ir::stmt::Free& stmt) override;
void VisitStmt(const ir::stmt::IfThenElse& stmt) override;
void VisitStmt(const ir::stmt::For& stmt) override;
void VisitStmt(const ir::stmt::Schedule& stmt) override;

#define __DEFINE_VISIT(op__) \
void Visit(const ir::intrinsics::op__* op) override;
INTRINSIC_KIND_FOR_EACH(__DEFINE_VISIT)
@@ -116,11 +125,19 @@ class CodeGenC : public ir::IrPrinter {

friend class ExternFunctionEmitter;

struct StoreHash {
size_t operator()(const ir::stmt::Store& stmt) const {
return std::hash<const Object*>()(stmt.get());
}
};

protected:
Target target_;
std::stringstream ss_;
bool inline_builtin_codes_{true};
std::unordered_map<const ir::Store*, ir::Expr> store_to_offset_;
std::unordered_map<const ir::stmt::Store, ir::Expr, StoreHash>
store_stmt_to_offset_;
std::unordered_map<const ir::Load*, ir::Expr> load_to_offset_;
};
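
store_stmt_to_offset_ is keyed by the statement handle itself, so StoreHash hashes the node the handle points to via stmt.get(); two handles that refer to the same Store therefore share one cache entry. A self-contained sketch of the same idiom using hypothetical stand-in types (not the CINN classes):

#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

struct Node { int index = 0; };
struct Handle {                        // ref-counted handle, like ir::stmt::Store
  std::shared_ptr<Node> impl;
  const Node* get() const { return impl.get(); }
};
struct HandleHash {                    // same idea as CodeGenC::StoreHash
  std::size_t operator()(const Handle& h) const {
    return std::hash<const Node*>()(h.get());  // identity of the node, not the handle
  }
};
struct HandleEq {
  bool operator()(const Handle& a, const Handle& b) const {
    return a.get() == b.get();
  }
};

// Handles wrapping the same node map to the same cached offset.
std::unordered_map<Handle, int, HandleHash, HandleEq> offset_cache;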

39 changes: 26 additions & 13 deletions paddle/cinn/backends/codegen_device_util.cc
@@ -244,6 +244,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
[&](common::HygonDCUArchSYCL) {
call_kernel = runtime::intrinsic::call_sycl_kernel;
});
// TODO(Dmovic): use new ir when backend update done.
ir::Expr call_extern_api =
ir::Call::Make(Void(),
call_kernel.value(),
Expand All @@ -264,7 +265,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
0);

// create memset calls for temp_spaces if needed
std::vector<ir::Expr> call_kernel_stmts;
std::vector<ir::stmt::StmtRef> call_kernel_stmts;
for (auto &temp_space : func_node->temp_spaces) {
if (temp_space.need_zero_init()) {
ir::Expr size = common::cast(temp_space.size(), common::UInt(64));
@@ -274,23 +275,26 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
ir::Expr call_memset = lang::CallExtern(
runtime::intrinsic::call_cuda_memset,
{call_get_arg, ir::Expr(1), ir::Expr(0), size, kernel_stream_});
call_kernel_stmts.push_back(call_memset);
call_kernel_stmts.push_back(ir::stmt::Evaluate(call_memset));
}
}
call_kernel_stmts.push_back(call_extern_api);
call_extern_api = ir::Block::Make(call_kernel_stmts);
call_kernel_stmts.push_back(ir::stmt::Evaluate(call_extern_api));
auto call_extern_api_block = ir::stmt::BlockRef(call_kernel_stmts);

if (buckets_.empty()) {
buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
buckets_.emplace_back(
ir::stmt::IfThenElse(predicate, call_extern_api_block));
} else {
auto false_expr = buckets_.back();
buckets_.pop_back();
buckets_.emplace_back(
ir::IfThenElse::Make(predicate, call_extern_api, false_expr));
buckets_.emplace_back(ir::stmt::IfThenElse(
predicate,
call_extern_api_block,
ir::stmt::BlockRef(std::vector<ir::stmt::StmtRef>{false_expr})));
}

// create infer shape calls for temp_spaces
std::vector<ir::Expr> temp_space_infer_shape_stmts;
std::vector<ir::stmt::StmtRef> temp_space_infer_shape_stmts;
for (int i = 0; i < func_node->temp_spaces.size(); ++i) {
ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int64_t **>());
ir::Expr size =
@@ -301,12 +305,20 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
ir::Expr(0),
size,
tensor_shape_args});
temp_space_infer_shape_stmts.push_back(call_set_value);
temp_space_infer_shape_stmts.push_back(ir::stmt::Evaluate(call_set_value));
}
if (!temp_space_infer_shape_stmts.empty()) {
ir::Expr if_body = ir::Block::Make(temp_space_infer_shape_stmts);
temp_space_infer_shape_body_ =
ir::IfThenElse::Make(predicate, if_body, temp_space_infer_shape_body_);
ir::stmt::BlockRef if_body =
ir::stmt::BlockRef(temp_space_infer_shape_stmts);
if (temp_space_infer_shape_body_.defined()) {
temp_space_infer_shape_body_ = ir::stmt::IfThenElse(
predicate,
if_body,
ir::stmt::BlockRef(
std::vector<ir::stmt::StmtRef>{temp_space_infer_shape_body_}));
} else {
temp_space_infer_shape_body_ = ir::stmt::IfThenElse(predicate, if_body);
}
}
}

@@ -325,7 +337,8 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
0);
ir::Expr let_symbol = ir::Expr(args[i].var_arg());
let_symbol->set_type(type_of<int64_t>());
ir::Expr stmt = ir::Let::Make(let_symbol, call_get_value_in_kernel_args);
ir::stmt::StmtRef stmt =
ir::stmt::Let(let_symbol, call_get_value_in_kernel_args);
arg_defs_.push_back(stmt);
}
}
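
Taken together, these hunks move codegen_device_util.cc from expression-based statements (ir::Block::Make, ir::IfThenElse::Make, ir::Let::Make) to the statement IR (ir::stmt::BlockRef, ir::stmt::IfThenElse, ir::stmt::Let), with ir::stmt::Evaluate wrapping plain extern calls. Below is a sketch of the bucket-chaining pattern from ProcessLoweredFunc, assuming the ir::stmt API exactly as it appears above; predicate, call_kernel, and prev_bucket are placeholders, not values from the codebase.

// Wrap an Expr-level extern call so it can live in a statement block.
std::vector<ir::stmt::StmtRef> body_stmts;
body_stmts.push_back(ir::stmt::Evaluate(call_kernel));
ir::stmt::BlockRef body(body_stmts);

// Chain buckets: the previously built bucket becomes the else-block.
ir::stmt::StmtRef bucket;
if (prev_bucket.defined()) {
  bucket = ir::stmt::IfThenElse(
      predicate, body,
      ir::stmt::BlockRef(std::vector<ir::stmt::StmtRef>{prev_bucket}));
} else {
  bucket = ir::stmt::IfThenElse(predicate, body);
}
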
(Diffs for the remaining 10 changed files are not shown here.)
