
should not remove predicate for 1d TMA #3899

Draft · wants to merge 21 commits into base: main

Conversation

@liqiangxl (Collaborator) commented Feb 14, 2025

What?

  • Remove the predicate only for nD TMA (CpAsyncBulkTensorTile), not for 1D TMA (CpAsyncBulk).
  • Add a test to demonstrate that a predicate is still required for 1D TMA.

Why?
CpAsyncBulkTensorTile handles out-of-bound accesses automatically in hardware, so there is no need to predicate it.
However, CpAsyncBulk still needs a predicate.
Per the PTX reference: "the memory range [srcMem, srcMem + size - 1] must not overflow the source memory space. Otherwise, the behavior is undefined."
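
Below is a minimal sketch of the distinction this PR draws; the LoadStoreOpType member names are assumptions inferred from the utility functions (isCpAsyncBulkTensorTile, isCpAsyncUblk) listed in the walkthrough below, not a quote of the diff:

    // Sketch only: classify TMA loads so predicate elimination can keep the
    // predicate for 1D TMA while dropping it for nD TMA.
    bool isCpAsyncBulkTensorTile(const Expr* expr) { // nD TMA: HW-bounded
      auto ldst = dynamic_cast<const LoadStoreOp*>(expr);
      return ldst != nullptr &&
          ldst->opType() == LoadStoreOpType::CpAsyncBulkTensorTile;
    }
    bool isCpAsyncUblk(const Expr* expr) { // 1D TMA: needs a predicate
      auto ldst = dynamic_cast<const LoadStoreOp*>(expr);
      return ldst != nullptr &&
          ldst->opType() == LoadStoreOpType::CpAsyncBulk;
    }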

Difficulties in handling the circular buffer
We need to predicate both CpAsyncBulk and waitParity to avoid out-of-bound access and deadlock. For the prolog (prefetch = 2), we can change the code to the following, where b26 is a newly added predicate to avoid out-of-bound access.

    for(nvfuser_index_t i25 = 0; i25 < 2; ++i25) {
      bool b26;
      b26 = i19 < (-i25);
      if (((Hopper::electSync(4294967295U) && b9) && b26)) {
        mbarrier::arriveExpectTX(toSmem((&T4[i25])), i2);
        Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (ptr15 + (T0.logical_size[1LL] * i25)), i2, toSmem((&T4[i25])) }), (i6 + (i1 * i25)));
      }
    }

For the main loop, the change is similar except for the extra predicate on wait parity. The wait-parity predicate assumes the split by stages is divisible; otherwise it is not correct. This is the first challenge of adding the predicate. A toy illustration of the issue follows the generated code below.

    for(nvfuser_index_t i27 = 0; i27 < 3; ++i27) {
      nvfuser_index_t i28;
      i28 = T0.logical_size[1LL] * i27;
      nvfuser_index_t i29;
      i29 = ((nvfuser_index_t)threadIdx.x) + i28;
      nvfuser_index_t i30;
      i30 = -i27;
      bool b31;
      b31 = i20 < i30;
      if (((Hopper::electSync(4294967295U) && b9) && b31)) {
        mbarrier::waitParity(toSmem((&T4[(((2LL + i27) % 5) + 5LL)])), (uint32_t)((((2LL + i27) / 5) % 2)));
        mbarrier::arriveExpectTX(toSmem((&T4[((2LL + i27) % 5)])), i2);
        Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (ptr16 + i28), i2, toSmem((&T4[((2LL + i27) % 5)])) }), (i6 + (i1 * ((2 + i27) % 5))));
      }
      if(b31){
        mbarrier::waitParity(toSmem((&T4[i27])), 0U);
      }
    }
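
To see why divisibility matters, here is a toy standalone simulation (not nvFuser code; the stage count matches the generated kernel above, the tile counts are hypothetical). It counts producer arrives against consumer waits per mbarrier slot when an indivisible split rounds the trip count up:

    #include <cstdio>

    int main() {
      constexpr int kStages = 5;
      for (int num_tiles : {10, 7}) { // 10 divides evenly; 7 does not
        int arrives[kStages] = {0}, waits[kStages] = {0};
        // Producer side: the out-of-bound predicate stops arrives at num_tiles.
        for (int i = 0; i < num_tiles; ++i) {
          arrives[i % kStages]++;
        }
        // Consumer side: an indivisible split rounds the trip count up.
        int padded = (num_tiles + kStages - 1) / kStages * kStages;
        for (int i = 0; i < padded; ++i) {
          waits[i % kStages]++;
        }
        for (int s = 0; s < kStages; ++s) {
          if (waits[s] > arrives[s]) {
            std::printf("num_tiles=%d: slot %d waits with no arrive (deadlock)\n",
                        num_tiles, s);
          }
        }
      }
      return 0;
    }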

For the epilog, there is no cpAsyncBulkG2S, so we need to create the predicate using other reference tensors, e.g. the output tensor. This is another challenge of adding the predicate.

Other approaches:
The WAR in #3917 is much simpler. It uses a mod op to avoid out-of-bound access; the redundant TMA load only happens for the tail SMs in the last iteration.
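
For reference, a sketch of what the mod-based index could look like in the generated-code style above; src_extent is a hypothetical name for the source extent, and this is not the actual #3917 diff:

    // Clamp the source offset back into range instead of predicating the
    // load away; out-of-range iterations redundantly reload an in-range
    // tile, which only happens for the tail SMs in the last iteration.
    i28 = (T0.logical_size[1LL] * i27) % src_extent;  // src_extent: hypothetical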

github-actions bot commented Feb 14, 2025

Review updated until commit 5b2a40d

Description

  • Updated predicate elimination to differentiate between 1D and nD TMA operations.

  • Added new utility functions to identify 1D and nD TMA operations.

  • Enhanced predicate handling for CpAsyncUblk operations.

  • Added tests to verify predicate behavior for 1D and nD TMA operations.


Changes walkthrough 📝

Relevant files

Enhancement

  • predicate_elimination.cpp (csrc/device_lower/analysis/predicate_elimination.cpp): Differentiate 1D and nD TMA in predicate elimination
      • Modified needsPredicate to differentiate between 1D and nD TMA operations.
      • Updated comments to clarify the distinction between 1D and nD TMA operations.
      • +9/-9
  • predicate.cpp (csrc/device_lower/pass/predicate.cpp): Enhance predicate handling for Ublk TMA loads
      • Added new functions getUblkTmaLoad and getUblkTmaLoadFromIte to identify Ublk TMA loads.
      • Enhanced ConditionalFromPredicateModifier to handle Ublk TMA loads with inline predicates.
      • Added predicateCpAsyncUblk and predicateUblkWaitParity to combine predicates for Ublk TMA loads.
      • +217/-7
  • unroll.cpp (csrc/device_lower/pass/unroll.cpp): Differentiate 1D and nD TMA in unroll pass
      • Updated dispatch to differentiate between 1D and nD TMA operations.
      • Updated comments to clarify the distinction between 1D and nD TMA operations.
      • +5/-3
  • utils.cpp (csrc/device_lower/utils.cpp): Add utility functions for TMA operations
      • Added isCpAsyncBulkTensorTile and isCpAsyncUblk to identify 1D and nD TMA operations.
      • +39/-0
  • predicate_compute.cpp (csrc/predicate_compute.cpp): Differentiate 1D and nD TMA in inline predicate computation
      • Updated getInlinePredicate to differentiate between 1D and nD TMA operations.
      • Updated comments to clarify the distinction between 1D and nD TMA operations.
      • +3/-3
  • utils.h (csrc/device_lower/utils.h): Add declarations for TMA utility functions
      • Added declarations for isCpAsyncBulkTensorTile and isCpAsyncUblk.
      • +2/-0

Tests

  • test_memory.cpp (tests/cpp/test_memory.cpp): Add tests for TMA predicate handling
      • Added new tests CpAsyncBulkPredicate, CpAsyncBulkPredicateCircularBuffer, and CpAsyncBulkPredicateCircularBuffer2 to verify predicate behavior for 1D and nD TMA operations.
      • +146/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Debug Statements

The code contains multiple std::cout statements, which should be removed or replaced with a proper logging mechanism before merging (see the sketch after the quoted code below).

      std::cout << "\n======================= dispatch:\n" << expr->toString() << std::endl;
      if(auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
        if (Expr* ublk_tma_load = getUblkTmaLoad(ite)) {
          // auto output = ir_utils::getTvOutput(ublk_tma_load);
          auto ldst = dynamic_cast<LoadStoreOp*>(ublk_tma_load);
          ublk_load_to_arrive_expect_ite.insert({ldst, ite});
          std::cout << "Add ublk tma load:\n" << ublk_tma_load->as<LoadStoreOp>()->toString()
                    << std::endl;
        }
      }
    
      if (expr != nullptr && expr->predicate() != nullptr) {
        // Replace expr predicate with bool conditional
        auto conditional = generateConditional(expr->predicate());
    
        if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
          if (expr->isA<kir::IfThenElse>()) {
            // TODO: This logic doesn't seem to fit well here, for unswitch the
            // logic is in the unroll loop to set the thread predicate to the
            // expr. I didn't have a quick way to do that so placing this here for
            // now.
            auto ite = expr->as<kir::IfThenElse>();
    
            NVF_ERROR(
                ite->thenBody().size() == 1,
                "Expecting predicated body to only have one vectorized expression.");
            auto vec_expr = ite->thenBody()[0];
            NVF_ERROR(
                vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>() ||
                    vec_expr->isA<TernaryOp>(),
                "Vectorize predicate exprs only supported on set operations.");
            NVF_ERROR(
                ir_utils::isTvOp(vec_expr),
                "Vectorize predicate exprs only supported on tensor view operations.");
            if (!vec_expr->inputs()[0]->isConstScalar()) {
              conditional = SimplifyingIrBuilder::logicalAndExpr(
                  conditional,
                  GpuLower::current()->threadPredMap().getPredicate(
                      ir_utils::getTvOutput(vec_expr)));
            }
          } else {
            NVF_ERROR(lower_utils::supportInlinePredicate(expr));
            auto thread_pred = GpuLower::current()->threadPredMap().getPredicate(
                ir_utils::getTvOutput(expr));
            NVF_ERROR(thread_pred->isConst() && thread_pred->value());
            conditional = SimplifyingIrBuilder::logicalAndExpr(
                conditional,
                GpuLower::current()->threadPredMap().getPredicate(
                    ir_utils::getTvOutput(expr)));
          }
        }
    
        if (ir_utils::isCpAsyncUblk(expr->predicate()->expr())) {
          predicateCpAsyncUblk(expr, conditional);
        } else {
          NVF_ERROR(conditional != nullptr);
          conditional = GpuLower::current()->commonScalarMap().hoistScalar(
              conditional, for_loops_);
          expr->predicate()->setValue(conditional);
          NVF_ERROR(expr->predicate()->value() != nullptr);
          setWritePredicate(expr);
        }
      }
    
      // may add extra predicate for wait parity to avoid deadlock
      if (expr->isA<kir::MBarrierWaitParity>()) {
        predicateUblkWaitParity(expr);
      }
    
      kir::ExprMutator::dispatch(expr);
    }
    
    void setWritePredicate(Expr* expr) {
      if (expr->writePredicate() != nullptr) {
        auto write_cond = generateConditional(expr->writePredicate());
        if (write_cond) {
          write_cond = GpuLower::current()->commonScalarMap().hoistScalar(
              write_cond, for_loops_);
          expr->writePredicate()->setValue(write_cond);
        } else {
          // If generateConditional returns null, it means no specific
          // predicate needs to be used.
          registerReplace(expr, expr->withWritePredicate(nullptr));
        }
      }
    }
    
    // This function combines the original elect sync predicate with the inline
    // predicate to avoid out-of-bound access for the ublk tma load.
    void predicateCpAsyncUblk(Expr* ite_tma_expr, Val* elect_sync_pred) {
      auto ublk_tma_expr = getUblkTmaLoadFromIte(ite_tma_expr);
      auto ldst = dynamic_cast<LoadStoreOp*>(const_cast<Expr*>(ublk_tma_expr));
      if(ublk_load_to_arrive_expect_ite.find(ldst) == ublk_load_to_arrive_expect_ite.end()) {
        std::cout << "Cannot find ublk tma load for: " << ublk_tma_expr->toString() << std::endl;
        for(auto[tv, ite] : ublk_load_to_arrive_expect_ite) {
          std::cout << "\nublk tma load: " << tv->toString() << std::endl;
          std::cout << "ublk tma itte: " << ite->toString() << std::endl;
        }
        return;
      }
      std::cout << "Find ublk tma load for: " << ublk_tma_expr->toString() << std::endl;
      kir::IfThenElse* ite = ublk_load_to_arrive_expect_ite.at(ldst);
      // inline predicate to avoid out-of-bound access
      auto inline_pred_val = PredicateCompute::getInlinePredicate(
          ublk_tma_expr,
          for_loops_,
          rotated_loop_,
          ite_tma_expr->predicate()->thread_pred(),
          ite_tma_expr->predicate()->predicate_type());
      std::cout << "inline pred:" << inline_pred_val->toString() << std::endl;
      inline_pred_val = GpuLower::current()->commonScalarMap().hoistScalar(
          inline_pred_val, for_loops_);
      // combine inline predicate with the original elect sync predicate
      auto combined_pred_val =
          SimplifyingIrBuilder::logicalAndExpr(elect_sync_pred, inline_pred_val);
      combined_pred_val = GpuLower::current()->commonScalarMap().hoistScalar(
          combined_pred_val, for_loops_);
      // map the mbarrier used in this tma load to the extra inline predicate,
      // this is then used to predicate the mbarrier wait parity.
      kir::TensorIndex* mbarrier =
          GpuLower::current()->tmaCircularBufferInfo().getTensorIndex(
              ublk_tma_expr->as<LoadStoreOp>());
      if(mbarrier){
        std::cout << "insert mbarrier:" << mbarrier->toString() << std::endl;
        mbarrier_inline_predicate_.insert({mbarrier, inline_pred_val});
      }else{
        std::cout << "Cannot find mbarrier for: " << ublk_tma_expr->toString() << std::endl;
      }
      // Since the tma load expr is nested in the ite, we only need to predicate
      // the ite with the combined predicate and remove the tma load predicate by
      // setting it to true.
      ite->predicate()->setValue(combined_pred_val);
      ite_tma_expr->predicate()->setValue(
          IrBuilder::create<Val>(true, DataType::Bool));
    
      // remove this tma load from map
      ublk_load_to_arrive_expect_ite.erase(ldst);
      std::cout << "erase:\n" << ublk_tma_expr->toString()
                << std::endl;
    }
    
    // This function adds the inline predicate to the mbarrier wait parity to
    // avoid deadlock since its corresponding ublk tma load may be predicated with
    // an inline predicate to avoid out-of-bound access.
    void predicateUblkWaitParity(Expr* expr) {
      // find the tensor index used in the mbarrier
      auto wait_parity = dynamic_cast<kir::MBarrierWaitParity*>(expr);
      auto mbarrier = wait_parity->mbarrier();
      kir::TensorIndex* tensor_index = nullptr;
      auto current_def = mbarrier->definition();
      while (current_def && current_def->isA<UnaryOp>()) {
        std::cout << "current def:\n" << current_def->toString() << std::endl;
        auto input = current_def->as<UnaryOp>()->in();
        if (input->isA<kir::TensorIndex>()) {
          tensor_index = input->as<kir::TensorIndex>();
          break;
        }
        current_def = input->definition();
      }
      NVF_CHECK(
          tensor_index,
          "Cannot find tensor index for mbarrier: ",
          mbarrier->toInlineString());
    
      // predicate this wait parity with the inline predicate used to predicate
      // the corresponding ublk tma load.
      if (mbarrier_inline_predicate_.find(tensor_index) !=
          mbarrier_inline_predicate_.end()) {
        auto pred_val = mbarrier_inline_predicate_.at(tensor_index);
        kir::Predicate* pred = IrBuilder::create<kir::Predicate>(pred_val);
        kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
        kir::ExprMutator::registerReplace(expr, inline_ite);
        inline_ite->thenBody().push_back(expr);
      }
    }
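
One way to address the logging concern, assuming nvFuser's debug() stream (csrc/debug.h) is available in this pass; treat the facility and include path as assumptions:

    #include <debug.h>  // assumed location of the debug() stream

    // Route diagnostics through the project's debug stream instead of
    // unconditional std::cout, e.g.:
    debug() << "Add ublk tma load:\n"
            << ublk_tma_load->as<LoadStoreOp>()->toString() << std::endl;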
Error Handling

The function predicateCpAsyncUblk contains a check that returns early if a ublk_tma_load is not found. This should be handled more gracefully, possibly by throwing an exception or logging an error.

    if(ublk_load_to_arrive_expect_ite.find(ldst) == ublk_load_to_arrive_expect_ite.end()) {
      std::cout << "Cannot find ublk tma load for: " << ublk_tma_expr->toString() << std::endl;
      for(auto[tv, ite] : ublk_load_to_arrive_expect_ite) {
        std::cout << "\nublk tma load: " << tv->toString() << std::endl;
        std::cout << "ublk tma itte: " << ite->toString() << std::endl;
      }
      return;
    }
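
A sketch of the stricter handling suggested here, using the NVF_ERROR macro already seen elsewhere in this pass; whether a missing entry is truly unreachable is an assumption:

    // Fail loudly instead of printing and silently returning.
    NVF_ERROR(
        ublk_load_to_arrive_expect_ite.count(ldst) != 0,
        "Cannot find ublk tma load for: ",
        ublk_tma_expr->toString());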
Code Duplication

The functions getUblkTmaLoad and getUblkTmaLoadFromIte have similar logic. Consider refactoring to avoid duplication.

    Expr* getUblkTmaLoad(Expr* ite_expr) {
      if(auto ite = dynamic_cast<kir::IfThenElse*>(ite_expr)){
        const auto& flattened_exprs = ir_utils::flattenScopedExprs(ite->thenBody().exprs());
        bool found_arrive_expect_ = false;
        for(auto expr : flattened_exprs) {
          if (expr->isA<kir::MBarrierArriveExpectTx>()) {
            found_arrive_expect_ = true;
          }
          if (found_arrive_expect_ && ir_utils::isCpAsyncUblk(expr)) {
            return expr;
          }
        }
      }
      return nullptr;
    }
    
    Expr* getUblkTmaLoadFromIte(Expr* ite_expr) {
      if(auto ite = dynamic_cast<kir::IfThenElse*>(ite_expr)){
        const auto& flattened_exprs = ir_utils::flattenScopedExprs(ite->thenBody().exprs());
        for(auto expr : flattened_exprs) {
          if (ir_utils::isCpAsyncUblk(expr)) {
            return expr;
          }
        }
      }
      return nullptr;
    }
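
One possible refactor merges the two helpers behind a flag; this is a sketch based on the quoted code, not the PR's code:

    // require_arrive_expect selects getUblkTmaLoad's stricter behavior: the
    // TMA load must be preceded by an MBarrierArriveExpectTx in the same ITE.
    Expr* getUblkTmaLoadImpl(Expr* ite_expr, bool require_arrive_expect) {
      auto ite = dynamic_cast<kir::IfThenElse*>(ite_expr);
      if (ite == nullptr) {
        return nullptr;
      }
      bool found_arrive_expect = !require_arrive_expect;
      for (auto expr : ir_utils::flattenScopedExprs(ite->thenBody().exprs())) {
        if (expr->isA<kir::MBarrierArriveExpectTx>()) {
          found_arrive_expect = true;
        }
        if (found_arrive_expect && ir_utils::isCpAsyncUblk(expr)) {
          return expr;
        }
      }
      return nullptr;
    }

getUblkTmaLoad(e) then becomes getUblkTmaLoadImpl(e, /*require_arrive_expect=*/true), and getUblkTmaLoadFromIte(e) becomes getUblkTmaLoadImpl(e, /*require_arrive_expect=*/false).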

@liqiangxl (Collaborator, Author) commented:

    !test

@liqiangxl (Collaborator, Author) commented:

    !test

@liqiangxl marked this pull request as draft February 15, 2025 02:25