Do some more simplifications specific to extents #3891

Open
wants to merge 5 commits into
base: main

Conversation

naoyam
Collaborator

@naoyam naoyam commented Feb 14, 2025

While working on #3848, I noticed that test_unpadded_catop_issue2275_repro2 took an extremely long time (> 30 min). Most of that time seems to go to index hoisting and expression simplification: with both disabled, the test took just a few seconds. The likely cause is the large number of min and max operations produced by slicing symbolic extents, as shown below:

T10_l_float[iS152{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}, iS153{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 32) )) ), 32) )) )}, iS154{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 4096) )) ), 4096) )) )}, iS155{( ( ( fmax(64, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 128) )) ) - 64 ) + ( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 64) )) ) )}]
   = __bfloat2float(T9_l___bfloat[iS148{( fmax(0, ( fmin(( fmax(0, ( fmin(i0, 2) )) ), 2) )) )}, iS149{( fmax(0, ( fmin(( fmax(0, ( fmin(i1, 32) )) ), 32) )) )}, iS150{( fmax(0, ( fmin(( fmax(0, ( fmin(i2, 4096) )) ), 4096) )) )}, iS151{( ( ( fmax(64, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 128) )) ) - 64 ) + ( fmax(0, ( fmin(( fmax(0, ( fmin(i3, 128) )) ), 64) )) ) )}]);

This PR tries to simplify these extents a little further, which results in:

T10_l_float[?S54{( fmin(i0, 2) )}, ?S55{( fmin(i1, 32) )}, ?S56{( fmin(i2, 4096) )}, ?S57{( ( ( fmax(64, ( fmin(i3, 128) )) ) - 64 ) + ( fmin(i3, 64) ) )}]
      = __bfloat2float(T9_l___bfloat[?S50{( fmin(i0, 2) )}, ?S51{( fmin(i1, 32) )}, ?S52{( fmin(i2, 4096) )}, ?S53{( ( ( fmax(64, ( fmin(i3, 128) )) ) - 64 ) + ( fmin(i3, 64) ) )}]);

The test time is reduced to several seconds by these simplifications.
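
The collapse from the first dump to the second relies on extents being non-negative. Below is a minimal sketch that checks the two forms agree, using plain `int64_t` values in place of nvFuser's symbolic `Val`s. All function names here are hypothetical and are not part of the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Extent form before this PR: fmax(0, fmin(fmax(0, fmin(i, c)), c)).
int64_t extent_before(int64_t i, int64_t c) {
  const int64_t inner = std::max<int64_t>(0, std::min(i, c));
  return std::max<int64_t>(0, std::min(inner, c));
}

// Extent form after this PR: fmin(i, c).
int64_t extent_after(int64_t i, int64_t c) {
  return std::min(i, c);
}

// Last-dimension extent before this PR (the i3 axis in the dump).
int64_t last_before(int64_t i) {
  const int64_t inner = std::max<int64_t>(0, std::min<int64_t>(i, 128));
  const int64_t hi = std::max<int64_t>(64, std::min<int64_t>(inner, 128));
  const int64_t lo = std::max<int64_t>(0, std::min<int64_t>(inner, 64));
  return (hi - 64) + lo;
}

// Last-dimension extent after this PR.
int64_t last_after(int64_t i) {
  return (std::max<int64_t>(64, std::min<int64_t>(i, 128)) - 64) +
      std::min<int64_t>(i, 64);
}

// Returns true if both forms agree for every non-negative i in [0, limit].
bool forms_agree(int64_t limit) {
  const int64_t consts[] = {2, 32, 4096};
  for (int64_t i = 0; i <= limit; ++i) {
    for (int64_t c : consts) {
      if (extent_before(i, c) != extent_after(i, c)) {
        return false;
      }
    }
    if (last_before(i) != last_after(i)) {
      return false;
    }
  }
  return true;
}

// Example driver; call it from main().
void run_checks() {
  assert(forms_agree(1000));
  // A negative input breaks the identity, so the rewrite is only
  // valid because extents are never negative.
  assert(extent_before(-1, 2) != extent_after(-1, 2));
}
```

The brute-force loop is only a sanity check over a small range. The algebraic argument is that `min(i, c)` is already non-negative and at most `c` when both `i` and `c` are non-negative, so the outer `max(0, ...)` and the second `min(..., c)` change nothing.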

Confirmed no failure with manual_ci.sh on an H100 machine.


github-actions bot commented Feb 14, 2025

Review updated until commit d72668e

Description

  • Added simplification for expressions like x + (-x) to 0.

  • Enhanced slice function with specialized simplifications for extents.

  • Updated test cases to reflect new simplifications.


Changes walkthrough 📝

Relevant files
Enhancement
builder.cpp
Simplify x + (-x) to 0                                                                     

csrc/ir/builder.cpp

  • Added logic to simplify expressions of the form x + (-x) to 0.
+15/-0   
alias.cpp
Enhance extent simplification in slice                                     

csrc/ops/alias.cpp

  • Introduced helper functions get_int and min_extents for extent
    simplifications.
  • Enhanced normalize_slice_range to use min_extents for better
    simplification.
+61/-12
Tests
test_resize.cpp
Update test for extent simplification

tests/cpp/test_resize.cpp

  • Updated test case SliceExtentSimplification to reflect new
    simplification logic.
+2/-2

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Simplification Logic

    The logic for simplifying x + (-x) to 0 should be validated to ensure it doesn't inadvertently simplify other expressions incorrectly.

    // Simplify x + (-x) to 0
    Val* x = nullptr;
    auto uop = dynamic_cast<UnaryOp*>(lhs->definition());
    if (uop != nullptr) {
      // lhs may be (-x). Pick rhs as x
      x = rhs;
    } else {
      uop = dynamic_cast<UnaryOp*>(rhs->definition());
      // rhs may be (-x). Pick lhs as x
      x = lhs;
    }
    if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg &&
        uop->in()->sameAs(x)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }
    return IrBuilder::addExpr(lhs, rhs);
    Extent Simplification

    The extent simplification logic should be thoroughly tested to ensure it handles all edge cases correctly, especially with symbolic extents.

    const auto get_int = [](Val* x) -> std::optional<int64_t> {
      if (x != nullptr && x->isConstInt()) {
        return x->evaluate().as<int64_t>();
      } else {
        return std::nullopt;
      }
    };
    
    // Specialized min for extents. Do some more simplification beyond
    // SimplifyingIrBuilder that are only valid for extents.
    const auto min_extents = [&](Val* x, Val* y) -> Val* {
      auto x_int = get_int(x);
      auto y_int = get_int(y);
      // Since extents are never negative, if one is 0, that must be the minimum.
      if (x_int == 0) {
        return x;
      } else if (y_int == 0) {
        return y;
      }
      // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it
      // isn't uncommon.
      auto bop = dynamic_cast<BinaryOp*>(x->definition());
      if (y_int != std::nullopt && bop != nullptr &&
          bop->getBinaryOpType() == BinaryOpType::Min) {
        if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) {
          return SimplifyingIrBuilder::minExpr(
              bop->rhs(), IrBuilder::create<Val>(std::min(*lhs_int, *y_int)));
        } else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) {
          return SimplifyingIrBuilder::minExpr(
              bop->lhs(), IrBuilder::create<Val>(std::min(*rhs_int, *y_int)));
        }
      }
    
      return SimplifyingIrBuilder::minExpr(x, y);
    };
    
    const auto normalize_slice_range =
        [&manual_normalization, &min_extents, &get_int](
            Slice range, Val* extent) -> Slice {
      auto cast_extent =
          SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent);
    
      auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index);
    
      auto start_int = get_int(range.start);
      auto stop_int = get_int(range.stop);
    
      // norm_start = max(0, start < 0 ? start + extent : start)
      if (range.start == nullptr) {
        range.start = zero;
        start_int = 0;
      } else if (start_int != 0) {
        range.start =
            SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start);
        if (!manual_normalization) {
          range.start = SimplifyingIrBuilder::maxExpr(
              zero,
              SimplifyingIrBuilder::whereExpr(
                  SimplifyingIrBuilder::ltExpr(range.start, zero),
                  SimplifyingIrBuilder::addExpr(range.start, cast_extent),
                  range.start));
        }
        if (range.start->isConstInt()) {
          start_int = range.start->evaluate().as<int64_t>();
        }
      }
    
      // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop))
      if (range.stop == nullptr) {
        range.stop = cast_extent;
      } else if (!range.stop->sameAs(extent)) {
        range.stop =
            SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop);
        // Commonly, range.start is zero and stop is non negative
        if (start_int == 0 && stop_int >= 0) {
          range.stop = min_extents(cast_extent, range.stop);
        } else {
          if (!manual_normalization) {
            range.stop = SimplifyingIrBuilder::maxExpr(
                range.start,
                min_extents(
                    cast_extent,
                    SimplifyingIrBuilder::whereExpr(
                        SimplifyingIrBuilder::ltExpr(range.stop, zero),
                        SimplifyingIrBuilder::addExpr(range.stop, cast_extent),
                        range.stop)));
          }
        }
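
The normalization formulas in the comments of the snippet above can be tried out on concrete integers. The following is a minimal sketch, assuming plain `int64_t` values instead of symbolic `Val`s. `Range`, `normalize_general`, and `normalize_fast` are hypothetical names for illustration, not nvFuser functions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

struct Range {
  int64_t start;
  int64_t stop;
};

// General normalization, following the comments in the snippet:
//   norm_start = max(0, start < 0 ? start + extent : start)
//   norm_stop  = max(norm_start, min(extent, stop < 0 ? stop + extent : stop))
Range normalize_general(int64_t start, int64_t stop, int64_t extent) {
  const int64_t s = std::max<int64_t>(0, start < 0 ? start + extent : start);
  const int64_t e =
      std::max(s, std::min(extent, stop < 0 ? stop + extent : stop));
  return {s, e};
}

// Fast path used when start == 0 and stop >= 0: stop = min(extent, stop).
Range normalize_fast(int64_t stop, int64_t extent) {
  return {0, std::min(extent, stop)};
}

// Returns true if the fast path equals the general formula for all
// non-negative extents and stops up to limit.
bool fast_path_matches(int64_t limit) {
  for (int64_t extent = 0; extent <= limit; ++extent) {
    for (int64_t stop = 0; stop <= limit; ++stop) {
      const Range g = normalize_general(0, stop, extent);
      const Range f = normalize_fast(stop, extent);
      if (g.start != f.start || g.stop != f.stop) {
        return false;
      }
    }
  }
  return true;
}

// Example driver; call it from main().
void run_checks() {
  assert(fast_path_matches(64));
  // Negative indices count from the end: [-3, -1) on extent 10 is [7, 9).
  const Range r = normalize_general(-3, -1, 10);
  assert(r.start == 7 && r.stop == 9);
}
```

With a zero start and a non-negative stop, `min(extent, stop)` is already at least `norm_start`, so the outer `max` and the `where` select are redundant. That is the case the `start_int == 0 && stop_int >= 0` branch exploits.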
    Test Expectation

    The test expectation for SliceExtentSimplification should be verified to ensure it accurately reflects the expected simplification.

      auto tv1 =
          slice(tv0, {{IrBuilder::create<Val>(0L), IrBuilder::create<Val>(1L)}});
      // By default, the extent of the tv1 domain is:
      //   i0 + ( ( fmax(0, ( fmin(i0, 1) )) ) + ( -i0 ) )
      // This should be simplified to just:
      //   fmin(i0, 1)
    
      fusion.addOutput(tv1);
    
      auto resize_extent = tv1->axis(0)->extent();
      auto bop = dynamic_cast<BinaryOp*>(resize_extent->definition());
      ASSERT_TRUE(bop != nullptr)
          << "Unexpected resize output extent: " << resize_extent->toInlineString();
      EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min)
          << "Unexpected resize output extent: " << resize_extent->toInlineString();
    }

    @naoyam naoyam marked this pull request as ready for review February 14, 2025 06:42
    @naoyam
    Collaborator Author

    naoyam commented Feb 14, 2025

    !test

    @naoyam naoyam requested a review from jacobhinkle February 14, 2025 06:43
    @naoyam
    Collaborator Author

    naoyam commented Feb 16, 2025

    !test

    @jacobhinkle
    Collaborator

    Maybe it's time for me to revive #511. With runtime info we should be able to fully simplify these expressions at concretization.

    Collaborator

    @jacobhinkle jacobhinkle left a comment

    LGTM. Comments are minor.

    Comment on lines +383 to +397
    // Simplify x + (-x) to 0
    Val* x = nullptr;
    auto uop = dynamic_cast<UnaryOp*>(lhs->definition());
    if (uop != nullptr) {
      // lhs may be (-x). Pick rhs as x
      x = rhs;
    } else {
      uop = dynamic_cast<UnaryOp*>(rhs->definition());
      // rhs may be (-x). Pick lhs as x
      x = lhs;
    }
    if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg &&
        uop->in()->sameAs(x)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }

    I think if you had abs(y) + (-abs(y)) this might not catch it because lhs=abs(y) is not a Neg. Instead, what about something like this?

    Suggested change
    // Simplify x + (-x) to 0
    Val* x = nullptr;
    auto uop = dynamic_cast<UnaryOp*>(lhs->definition());
    if (uop != nullptr) {
      // lhs may be (-x). Pick rhs as x
      x = rhs;
    } else {
      uop = dynamic_cast<UnaryOp*>(rhs->definition());
      // rhs may be (-x). Pick lhs as x
      x = lhs;
    }
    if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg &&
        uop->in()->sameAs(x)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }

    // Simplify x + (-x) to 0
    Val* x = nullptr;
    Val* neg_x_in = nullptr;
    if (auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) {
      // lhs is -x. Pick rhs as x
      neg_x_in = uop->in();
      x = rhs;
    } else if (auto uop = dynamic_cast<UnaryOp*>(rhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) {
      // rhs is -x. Pick lhs as x
      neg_x_in = uop->in();
      x = lhs;
    }
    if (x != nullptr && neg_x_in->sameAs(x)) {
      return lhs->fusion()->zeroVal(lhs->dtype());
    }
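
To make the difference between the two versions concrete, here is a minimal sketch on a toy expression tree. `Node`, `same`, `cancels_original`, and `cancels_suggested` are hypothetical stand-ins for illustration only. They mimic the shape of `UnaryOp` and `sameAs` but are not nvFuser's classes.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy node: a variable, or a unary Neg/Abs of another node.
struct Node {
  enum class Kind { Var, Neg, Abs };
  Kind kind;
  std::string name;          // set for Var only
  std::shared_ptr<Node> in;  // operand of Neg/Abs
};
using NodePtr = std::shared_ptr<Node>;

NodePtr var(const std::string& name) {
  return std::make_shared<Node>(Node{Node::Kind::Var, name, nullptr});
}

NodePtr unary(Node::Kind kind, NodePtr in) {
  return std::make_shared<Node>(Node{kind, "", std::move(in)});
}

// Structural equality, standing in for sameAs.
bool same(const NodePtr& a, const NodePtr& b) {
  if (!a || !b) {
    return a == b;
  }
  return a->kind == b->kind && a->name == b->name && same(a->in, b->in);
}

bool is_unary(const NodePtr& n) {
  return n->kind != Node::Kind::Var;
}

// Original logic: commit to lhs if it is *any* unary op, then test for Neg.
bool cancels_original(const NodePtr& lhs, const NodePtr& rhs) {
  NodePtr uop;
  NodePtr x;
  if (is_unary(lhs)) {
    uop = lhs;
    x = rhs;
  } else {
    if (is_unary(rhs)) {
      uop = rhs;
    }
    x = lhs;
  }
  return uop && uop->kind == Node::Kind::Neg && same(uop->in, x);
}

// Suggested logic: only commit to an operand that is actually a Neg.
bool cancels_suggested(const NodePtr& lhs, const NodePtr& rhs) {
  NodePtr x;
  NodePtr neg_in;
  if (lhs->kind == Node::Kind::Neg) {
    neg_in = lhs->in;
    x = rhs;
  } else if (rhs->kind == Node::Kind::Neg) {
    neg_in = rhs->in;
    x = lhs;
  }
  return x != nullptr && same(neg_in, x);
}

// Example driver; call it from main().
void run_checks() {
  const NodePtr abs_y = unary(Node::Kind::Abs, var("y"));
  const NodePtr neg_abs_y = unary(Node::Kind::Neg, abs_y);
  // abs(y) + (-abs(y)): the original stops at lhs because it is a unary op.
  assert(!cancels_original(abs_y, neg_abs_y));
  assert(cancels_suggested(abs_y, neg_abs_y));
}
```

In this toy model the original logic still catches `(-abs(y)) + abs(y)`, because the `Neg` is then on the left. Only the ordering in the review comment is missed.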

    const auto normalize_slice_range = [&manual_normalization](
        Slice range, Val* extent) -> Slice {
      const auto get_int = [](Val* x) -> std::optional<int64_t> {
        if (x != nullptr && x->isConstInt()) {

    Minor nit: if you initialize an ExpressionEvaluator in this function, I think you could use expr_eval.evaluate(x) to get a PolymorphicValue and check pv.hasValue() instead of using optional<int64_t>. Then you could replace the isConstInt() condition here, which redundantly evaluates the val, with isIntegralScalar().

    Comment on lines +812 to +814
    if (range.start->isConstInt()) {
      start_int = range.start->evaluate().as<int64_t>();
    }

    See the comment above about using an expr_eval. Not important, but it would avoid recomputing these vals.
