-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Do some more simplifications specific to extents #3891
base: main
Are you sure you want to change the base?
Conversation
Review updated until commit d72668e Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
!test |
!test |
Maybe it's time for me to revive #511. With runtime info we should be able to fully simplify these expressions at concretization. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Comments are minor.
// Simplify x + (-x) to 0 | ||
Val* x = nullptr; | ||
auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); | ||
if (uop != nullptr) { | ||
// lhs may be (-x). Pick rhs as x | ||
x = rhs; | ||
} else { | ||
uop = dynamic_cast<UnaryOp*>(rhs->definition()); | ||
// rhs may be (-x). Pick lhs as x | ||
x = lhs; | ||
} | ||
if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && | ||
uop->in()->sameAs(x)) { | ||
return lhs->fusion()->zeroVal(lhs->dtype()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if you had abs(y) + (-abs(y))
this might not catch it because lhs=abs(y)
is not a Neg. Instead, what about something like this?
// Simplify x + (-x) to 0 | |
Val* x = nullptr; | |
auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); | |
if (uop != nullptr) { | |
// lhs may be (-x). Pick rhs as x | |
x = rhs; | |
} else { | |
uop = dynamic_cast<UnaryOp*>(rhs->definition()); | |
// rhs may be (-x). Pick lhs as x | |
x = lhs; | |
} | |
if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && | |
uop->in()->sameAs(x)) { | |
return lhs->fusion()->zeroVal(lhs->dtype()); | |
} | |
// Simplify x + (-x) to 0 | |
Val* x = nullptr; | |
Val* neg_x_in = nullptr; | |
if (auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) { | |
// lhs is -x. Pick rhs as x | |
neg_x_in = uop->in(); | |
x = rhs; | |
} else if (auto uop = dynamic_cast<UnaryOp*>(rhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) { | |
// rhs is -x. Pick lhs as x | |
neg_x_in = uop->in(); | |
x = lhs; | |
} | |
if (x != nullptr && neg_x_in->sameAs(x)) { | |
return lhs->fusion()->zeroVal(lhs->dtype()); | |
} |
const auto normalize_slice_range = [&manual_normalization]( | ||
Slice range, Val* extent) -> Slice { | ||
const auto get_int = [](Val* x) -> std::optional<int64_t> { | ||
if (x != nullptr && x->isConstInt()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor nit: if you initialize an ExpressionEvaluator
in this function I think you could use expr_eval.evaluate(x)
to return a PolymorphicValue
and check pv.hasValue()
instead of using optional<int64_t>
. Then you could avoid the isConstInt()
condition here which will redundantly evaluate the val and replace it with isIntegralScalar()
.
if (range.start->isConstInt()) { | ||
start_int = range.start->evaluate().as<int64_t>(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment above if we have an expr_eval. Not important but will avoid re-computing these vals.
While working on #3848, I noticed
test_unpadded_catop_issue2275_repro2
took an extremely long time (> 30 min). That seems largely due to index hoisting and expression simplification. It took just a few seconds when they were disabled. That's likely due to a lot ofmin
andmax
due to slicing of symbolic extents, as shown below:This PR tries to simplifies these extents a little further, which results in:
The test time is reduced to several seconds by these simplifications.
Confirmed no failure with manual_ci.sh on an H100 machine.