Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN]apply new backend ir on some ir utils #70952

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gather_srcs(
ir_base.cc
ir_visitor.cc
ir_printer.cc
ir_collector.cc
ir_mutator.cc
function_definition.cc
buffer.cc
Expand Down
218 changes: 218 additions & 0 deletions paddle/cinn/ir/ir_collector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/ir_collector.h"
#include "paddle/cinn/ir/expr_visitors.h"

namespace cinn {
namespace ir {
using stmt::StmtRef;
using ExprSet = std::vector<Expr>;
using StmtSet = std::vector<StmtRef>;

template <typename SourceT, typename TargetT, typename MiddleT>
IRCollector<SourceT, TargetT> operator*(IRCollector<SourceT, MiddleT> x,
IRCollector<MiddleT, TargetT> y) {
const auto& new_f = [&](const SourceT& source) -> std::vector<TargetT> {
const auto& x_res_set = x.f_(source);
std::vector<TargetT> res;
for (const auto& x_res : x_res_set) {
const auto& y_res = y.f_(x_res);
res.insert(res.end(), y_res.begin(), y_res.end());
}
return res;
};
return IRCollector<SourceT, TargetT>(std::function(new_f),
x.name + "*" + y.name);
}

Stmt2ExprCollector Store2Value = Stmt2ExprCollector(
[](const StmtRef& s) -> ExprSet {
if (s.isa<stmt::Store>()) return {s.as<stmt::Store>()->value()};
return {};
},
"Store2Value");

Stmt2ExprCollector Schedule2IterValues = Stmt2ExprCollector(
[](const StmtRef& s) -> ExprSet {
if (s.isa<stmt::Schedule>()) return s.as<stmt::Schedule>()->iter_values();
return {};
},
"Schedule2IterValues");

Stmt2StmtCollector ScheduleNotRoot = FilterMaker<StmtRef>(
[](const StmtRef& s) -> bool {
return (s.isa<stmt::Schedule>() &&
s.as<stmt::Schedule>()->name().find("root") == std::string::npos);
},
"ScheduleNotRoot");

Stmt2StmtCollector ScheduleIsRoot = FilterMaker<StmtRef>(
[](const StmtRef& s) -> bool {
return (s.isa<stmt::Schedule>() &&
s.as<stmt::Schedule>()->name().find("root") != std::string::npos);
},
"ScheduleIsRoot");

Stmt2StmtCollector ScheduleIsNotInit = FilterMaker<StmtRef>(
[](const StmtRef& s) -> bool {
return (s.isa<stmt::Schedule>() &&
s.as<stmt::Schedule>()->name().find("__reduce_init") ==
std::string::npos);
},
"ScheduleIsNotInit");

Stmt2StmtCollector ScheduleIsInit = FilterMaker<StmtRef>(
[](const StmtRef& s) -> bool {
return (s.isa<stmt::Schedule>() &&
s.as<stmt::Schedule>()->name().find("__reduce_init") !=
std::string::npos);
},
"ScheduleIsInit");

Stmt2StmtCollector IsFor = FilterMaker<StmtRef>(
[](const StmtRef& s) -> bool { return s.isa<stmt::For>(); }, "IsFor");

Stmt2StmtCollector ChildSchedules =
NestedCollectorMaker<StmtRef, StmtRef>(
[](const StmtRef& s) -> bool { return s.isa<stmt::Schedule>(); },
"ChildSchedules") *
ScheduleNotRoot;

Stmt2StmtCollector IsForWithIterVar(const ir::Var& var) {
return FilterMaker<StmtRef>(
[&](const StmtRef& s) -> bool {
return s.isa<stmt::For>() && s.as<stmt::For>()->loop_var() == var;
},
"IsForWithIterVar");
}

Stmt2ExprCollector For2Min = Stmt2ExprCollector(
[](const StmtRef& s) -> ExprSet { return {s.as<stmt::For>()->min()}; },
"For2Min");

Stmt2ExprCollector For2Max = Stmt2ExprCollector(
[](const StmtRef& s) -> ExprSet { return {s.as<stmt::For>()->extent()}; },
"For2Max");

Stmt2StmtCollector ChildStores = NestedCollectorMaker<StmtRef, StmtRef>(
[](const StmtRef& s) -> bool { return s.isa<stmt::Schedule>(); },
"ChildStores");

Stmt2ExprCollector ChildExprsWithoutNested = Stmt2ExprCollector(
[](const StmtRef& s) -> ExprSet {
std::vector<Expr> res;
switch (s->stmt_type()) {
#define __(stmt__) \
case StmtNodeTy::stmt__: \
VisitExpr(s.as<stmt::stmt__>(), \
[&](const Expr& e) { res.emplace_back(e); }); \
break;
NODETY_FORALL_STMT(__)

default:
PADDLE_THROW(::common::errors::InvalidArgument(
"Deadcode, not supported StmtNodeTy"));
#undef __
}
return res;
},
"ChildExprsWithoutNested");

/*
* Collect all exprs in a stmt, including nested blocks.
*/
Stmt2ExprCollector ChildExprs = Stmt2ExprCollector(
[](const StmtRef& s) -> ExprSet {
std::vector<Expr> rs;
Visit(
s,
[&](const stmt::StmtRef& x) {
const auto& rs_without_nested = ChildExprsWithoutNested(x);
rs.insert(
rs.end(), rs_without_nested.begin(), rs_without_nested.end());
},
[&](const stmt::StmtRef& x) {});
return rs;
},
"ChildExprs");

Stmt2ExprCollector ChildTensorLoads = NestedCollectorMaker<StmtRef, Expr>(
[](const ir::Expr& e) {
return e.As<ir::Load>() && e.As<ir::Load>()->is_addr_tensor();
},
"ChildLoads");

Stmt2StmtCollector ChildTensorStores = NestedCollectorMaker<StmtRef, StmtRef>(
[](const StmtRef& s) {
return s.isa<stmt::Store>() && s.as<stmt::Store>()->is_addr_tensor();
},
"ChildTensorStores");

Expr2ExprCollector FilterLoadByTensor(const Tensor& tensor) {
return FilterMaker<Expr>(
[tensor = tensor](const Expr& e) -> bool {
return e.As<ir::Load>() &&
e.As<ir::Load>()->tensor.as_tensor_ref()->name == tensor->name;
},
"FilterLoadByTensor(" + tensor->name + ")");
}

Stmt2StmtCollector ChildFors = NestedCollectorMaker<StmtRef, StmtRef>(
[](const StmtRef& s) { return s.isa<stmt::For>(); }, "ChildFors");

Stmt2StmtCollector ChildIfThenElses = NestedCollectorMaker<StmtRef, StmtRef>(
[](const StmtRef& s) { return s.isa<stmt::IfThenElse>(); },
"ChildIfThenElses");

// TODO(Hongqing-work): update father collector after supporting StmtRef to
// record father.
Stmt2StmtCollector FindFather(const StmtRef& root) {
const auto& f = [root](const StmtRef& child) -> StmtSet {
std::vector<StmtRef> result;
std::vector<StmtRef> cur_fathers;
std::function<stmt::VisitResult(const StmtRef&)> pre_callback =
[&](const StmtRef& current) -> stmt::VisitResult {
if (current == child) {
result = cur_fathers;
return stmt::VisitResult::interrupt();
}
cur_fathers.push_back(current);
return stmt::VisitResult::advance();
};
std::function<stmt::VisitResult(const StmtRef&)> post_callback =
[&](const StmtRef& current) -> stmt::VisitResult {
if (current == child) {
result = cur_fathers;
return stmt::VisitResult::interrupt();
}
cur_fathers.push_back(current);
return stmt::VisitResult::advance();
};
const auto& visit_res = stmt::Visit(root, pre_callback, post_callback);
return result;
};
return Stmt2StmtCollector(f, "FindFather");
}

Stmt2StmtCollector DirectlyFather(const StmtRef& root) {
const auto& f = [root](const StmtRef& child) -> StmtSet {
StmtSet result = FindFather(root)(child);
return {result[result.size() - 1]};
};
return Stmt2StmtCollector(f, "DirectlyFather");
}

} // namespace ir
} // namespace cinn
160 changes: 160 additions & 0 deletions paddle/cinn/ir/ir_collector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/ir/stmt.h"
#include "paddle/cinn/ir/stmt_visitors.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"

namespace cinn {
namespace ir {

template <typename SourceT, typename TargetT>
struct IRCollector {
using CollectorFuncT = std::function<std::vector<TargetT>(const SourceT& x)>;
CollectorFuncT f_;
std::string name;
explicit IRCollector(CollectorFuncT f, std::string s = "") : f_(f), name(s) {}

std::vector<TargetT> operator()(const SourceT& x) const { return f_(x); }
TargetT GetSingle(const SourceT& x) const {
const auto& o = f_(x);
PADDLE_ENFORCE_EQ(
o.size(),
1,
::common::errors::InvalidArgument(
"Try to get single result, but we get %d.", o.size()));
return *o.begin();
}
};

using Expr2ExprCollector = IRCollector<Expr, Expr>;
using Stmt2StmtCollector = IRCollector<stmt::StmtRef, stmt::StmtRef>;
using Stmt2ExprCollector = IRCollector<stmt::StmtRef, Expr>;

template <typename SourceT, typename TargetT, typename MiddleT>
IRCollector<SourceT, TargetT> operator*(IRCollector<SourceT, MiddleT> x,
IRCollector<MiddleT, TargetT> y);

template Expr2ExprCollector operator*
<Expr, Expr, Expr>(Expr2ExprCollector, Expr2ExprCollector);
template Stmt2StmtCollector operator*
<stmt::StmtRef, stmt::StmtRef, stmt::StmtRef>(Stmt2StmtCollector,
Stmt2StmtCollector);
template Stmt2ExprCollector operator*
<stmt::StmtRef, Expr, Expr>(Stmt2ExprCollector, Expr2ExprCollector);
template Stmt2ExprCollector operator*
<stmt::StmtRef, Expr, stmt::StmtRef>(Stmt2StmtCollector,
Stmt2ExprCollector);

template <typename SourceT, typename TargetT>
IRCollector<SourceT, TargetT> NestedCollectorMaker(
std::function<bool(const TargetT&)> teller, std::string name = "");

// TODO(Hongqing-work): move methods of ir_nodes_collector.h to this file.
template <>
Expr2ExprCollector NestedCollectorMaker(std::function<bool(const Expr&)> teller,
std::string name) {
return Expr2ExprCollector(
[=](const Expr& x) -> std::vector<Expr> {
const auto new_func = [=](const ir::Expr* x) -> bool {
return teller(*x);
};
return cinn::ir::ir_utils::CollectIRNodesInOrder(x, new_func);
},
name);
}

template <>
Stmt2StmtCollector NestedCollectorMaker(
std::function<bool(const stmt::StmtRef&)> teller, std::string name) {
return Stmt2StmtCollector(
[=](const stmt::StmtRef& x) -> std::vector<stmt::StmtRef> {
std::vector<stmt::StmtRef> rs;
stmt::Visit(
x,
[&](const stmt::StmtRef& x) {
if (teller(x)) rs.push_back(x);
},
[&](const stmt::StmtRef& x) {});
return rs;
},
name);
}

extern Stmt2ExprCollector ChildExprsWithoutNested;
extern Stmt2ExprCollector ChildExprs;

template <>
Stmt2ExprCollector NestedCollectorMaker(std::function<bool(const Expr&)> teller,
std::string name) {
Stmt2ExprCollector res =
ChildExprs * NestedCollectorMaker<Expr, Expr>(teller);
res.name = name;
return res;
}

template <typename TargetT>
IRCollector<TargetT, TargetT> FilterMaker(
std::function<bool(const TargetT&)> filter, std::string name) {
return IRCollector<TargetT, TargetT>(
[=](const TargetT& x) -> std::vector<TargetT> {
if (filter(x)) return {x};
return {};
},
name);
}

extern Stmt2ExprCollector Store2Value;

extern Stmt2ExprCollector Schedule2IterValues;

extern Stmt2StmtCollector ScheduleNotRoot;

extern Stmt2StmtCollector ScheduleIsRoot;

extern Stmt2StmtCollector ScheduleIsNotInit;

extern Stmt2StmtCollector ScheduleIsInit;

extern Stmt2StmtCollector IsFor;

extern Stmt2StmtCollector ChildSchedules;

extern Stmt2ExprCollector For2Min;

extern Stmt2ExprCollector For2Max;

extern Stmt2StmtCollector ChildStores;

extern Stmt2ExprCollector ChildTensorLoads;

extern Stmt2StmtCollector ChildTensorStores;

extern Stmt2StmtCollector ChildFors;

extern Stmt2StmtCollector ChildIfThenElses;

Stmt2StmtCollector IsForWithIterVar(const Var& var);

Expr2ExprCollector FilterLoadByTensor(const Tensor& tensor);

Stmt2StmtCollector FindFather(const stmt::StmtRef& root);

Stmt2StmtCollector DirectlyFather(const stmt::StmtRef& root);

} // namespace ir
} // namespace cinn
Loading
Loading