
Avoid UBLK TMA out-of-bound access #3917

Draft · wants to merge 30 commits into base: main

Conversation

@liqiangxl (Collaborator) commented Feb 18, 2025

Background
CpAsyncBulkTensorTile handles out-of-bound accesses automatically in hardware, so there is no need to predicate it.
However, CpAsyncBulk leads to an illegal memory access if an out-of-bound access happens.
Reference: "the memory range [srcMem, srcMem + size - 1] must not overflow the source memory space. Otherwise, the behavior is undefined."
This happens in persistent kernels with non-divisible circular buffer stages and/or a non-divisible CTA count. For example, a tensor [M, N] is split as [sm_count, M/stages/sm_count, stages, N] and parallelized as [BIDx, Serial, Serial, Bulk]. The TMA load is nested within two for-loops, one over [M/stages/sm_count] and the other over [stages]. Since no predicate is generated for the TMA load, an out-of-bound access happens if any of the splits is not divisible. A simplified sketch of the resulting loop nest follows.
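A minimal sketch of the failure mode, assuming a hypothetical generated kernel (cpAsyncBulk, smem_buf, and base are placeholders; the exact code nvFuser emits differs):

  // Simplified loop nest for the [BIDx, Serial, Serial, Bulk] schedule above.
  // Each CpAsyncBulk copies one row of N floats into shared memory.
  int inner = ((M + stages - 1) / stages + sm_count - 1) / sm_count;
  for (int i = 0; i < inner; ++i) {     // Serial: M/stages/sm_count
    for (int s = 0; s < stages; ++s) {  // Serial: stages
      int row = (blockIdx.x * inner + i) * stages + s;
      // No predicate is generated, so on the tail iterations row may be >= M:
      const float* src = base + (int64_t)row * N;  // overflows the tensor
      cpAsyncBulk(smem_buf, src, N * sizeof(float));
    }
  }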

Candidate Solutions
(1) Add a predicate for CpAsyncBulk:
I explored this approach in PR #3899. However, implementing it becomes challenging in cases involving a circular buffer, as we need to predicate both CpAsyncBulk and waitParity to avoid potential deadlocks, as illustrated below.
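A hedged illustration of the deadlock risk (hypothetical pseudocode; mbarrierArriveExpectTx and mbarrierWaitParity are placeholder names, not code from PR #3899):

  // If only the TMA load is predicated, the mbarrier of a skipped stage is
  // never signaled, so the consumer's parity wait hangs:
  if (row < M) {
    cpAsyncBulk(smem_buf, src, bytes);   // skipped for out-of-bound stages
    mbarrierArriveExpectTx(bar, bytes);  // also skipped
  }
  // ... consumer side ...
  mbarrierWaitParity(bar, parity);  // deadlocks on the skipped stage unless
                                    // this wait is predicated consistently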

(2) Modulo the source address calculation for CpAsyncBulk (this PR)
When calculating the source address, we can apply a modulo operation using the tensor size. For a tensor of shape [M, N], each CpAsyncBulk loads one row. If the computed source address corresponds to (M + x) * N, it causes an overflow; after mod(M×N), it becomes a load from x * N. While this load is redundant, it occurs only in the last iteration of the loop on SMs that would otherwise stay idle, so it has minimal impact on the valid loads and computations across other SMs. The sketch below shows the idea.
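A minimal sketch of the wrap-around, assuming a contiguous row-major tensor (illustrative only; the actual change is applied to the linear index in csrc/id_model/indexing.cpp):

  int64_t logical_size = (int64_t)M * N;       // product of logical extents
  int64_t linear_index = (int64_t)row * N;     // may reach (M + x) * N
  linear_index = linear_index % logical_size;  // wraps to x * N: redundant
                                               // but in-bounds
  const float* src = base + linear_index;
  cpAsyncBulk(smem_buf, src, N * sizeof(float));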


github-actions bot commented Feb 18, 2025

Review updated until commit 48bedc1

Description

  • Added modulo operation to avoid out-of-bound access in CpAsyncBulk.

  • Extended circular buffer tests for 1D TMA loads.

  • Introduced isCpAsyncUblk function to identify UBLK operations.

  • Added new test cases for UBLK predicate scenarios.


Changes walkthrough 📝

Relevant files

Enhancement

utils.cpp (csrc/device_lower/utils.cpp): Added UBLK detection function
  • Added isCpAsyncUblk function to identify UBLK operations.
  +18/-0

indexing.cpp (csrc/id_model/indexing.cpp): Applied modulo to UBLK loads
  • Applied modulo operation to linear index for UBLK loads to avoid out-of-bound access.
  +22/-0

utils.h (csrc/device_lower/utils.h): Declared UBLK detection function
  • Declared isCpAsyncUblk function in header file.
  +1/-0

Tests

test_circular_buffering.cpp (tests/cpp/test_circular_buffering.cpp): Added UBLK predicate tests
  • Added new test suite TmaCircularBufferingTestUblk for UBLK predicate scenarios.
  • Extended test parameters to include UBLK load types.
  +95/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Error Handling

The function isCpAsyncUblk throws an exception if the memory types for CpAsyncBulk are invalid. This could lead to a crash if the function is called with unexpected memory types. Consider handling this case more gracefully; one possible alternative is sketched after the snippet.

  NVF_THROW("Invalid memory types for CpAsyncBulk");
}
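A hedged sketch of the suggestion (the signature and the ir_utils helpers are assumptions, not the PR's actual implementation): report "not UBLK" for unexpected memory types instead of throwing.

  bool isCpAsyncUblk(const Expr* expr) {
    auto* ldst = dynamic_cast<const LoadStoreOp*>(expr);
    if (ldst == nullptr || ldst->opType() != LoadStoreOpType::CpAsyncBulk) {
      return false;
    }
    auto in_mem = ir_utils::getTvInput(ldst)->getMemoryType();
    auto out_mem = ir_utils::getTvOutput(ldst)->getMemoryType();
    // Return false instead of NVF_THROW for unexpected memory types.
    return (in_mem == MemoryType::Global && out_mem == MemoryType::Shared) ||
        (in_mem == MemoryType::Shared && out_mem == MemoryType::Global);
  }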
Logical Size Calculation

The calculation of logical_size in getLinearIndex assumes that all dimensions have a valid extent. If any dimension has an extent of zero, the product is zero and the subsequent modulo becomes a division by zero. Ensure that the extent of each dimension is checked before multiplication; a possible guard is sketched after the snippet.

    auto logical_size = gmem_tv->fusion()->oneVal();
    const auto& logical_domain = gmem_tv->getLogicalDomain();
    for (const auto i : c10::irange(logical_domain.size())) {
      logical_size = SimplifyingIrBuilder::mulExpr(
          logical_size, logical_domain.at(i)->extent());
    }
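A hedged sketch of such a guard (illustrative only; it assumes Val::isZeroInt is available for the extent, as used elsewhere in the codebase):

  auto logical_size = gmem_tv->fusion()->oneVal();
  const auto& logical_domain = gmem_tv->getLogicalDomain();
  for (const auto i : c10::irange(logical_domain.size())) {
    auto* extent = logical_domain.at(i)->extent();
    // Reject a statically-zero extent, which would make the modulo divisor
    // zero.
    NVF_ERROR(!extent->isZeroInt(), "Zero extent in logical domain at ", i);
    logical_size = SimplifyingIrBuilder::mulExpr(logical_size, extent);
  }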
Test Coverage

The new test TmaCircularBufferingTestUblk covers a specific scenario with a prime number for the outer dimension. Consider adding more test cases to cover a wider range of scenarios, including edge cases, to ensure the robustness of the solution.

    // Similar to TmaCircularBufferingTest, but only test 1D TMA (UBLK) with one
    // tensor size. Outer dim is a prime number to test predicate due to
    // non-divisible split.
    class TmaCircularBufferingTestUblk : public TmaCircularBufferingTest {};
    TEST_P(TmaCircularBufferingTestUblk, Predicate) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
    
      if (testEnablesRegisterSharing()) {
        GTEST_SKIP();
        return;
      }
      constexpr at::ScalarType dtype = at::ScalarType::Float;
      CompileParams index32bit{DataType::Int32, 255, false};
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
      auto tv0 = makeContigTensor(2, aten_to_data_type(dtype));
      fusion->addInput(tv0);
      auto tv1 = add(tv0, tv0);
      fusion->addOutput(tv1);
    
      auto tv0a = tv0->cacheAfter(tma_load_type);
      auto tv1c = tv1->cacheBefore();
      tv0a->setMemoryType(MemoryType::Shared);
    
      // tensor_outer_dim is a prime number, not divisible by number_of_stages or
      // number_of_cta when stages is 1. When stages > 1, increase number_of_cta to
      // make sure the 2nd split is also not divisible by number_of_cta.
      int64_t number_of_cta =
          at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
      if (number_of_stages > 1) {
        int64_t after_stages =
            (tensor_outer_dim + number_of_stages - 1) / number_of_stages;
        while (after_stages % number_of_cta == 0) {
          number_of_cta++;
        }
      }
      tv1->split(0, number_of_stages);
      tv1->split(0, number_of_cta, false);
      TransformPropagator propagator(tv1);
      MaxLogicalDomainInfoSpanningTree(tv1).traverse(&propagator);
    
      tv1->axis(0)->parallelize(ParallelType::BIDx);
      scheduler_utils::parallelizeAllLike(tv1);
    
      // TIDx for computation, Bulk for load
      tv1->axis(-1)->parallelize(ParallelType::TIDx);
      tv1c->axis(-1)->parallelize(ParallelType::TIDx);
      tv0a->axis(-1)->parallelize(ParallelType::Bulk);
      inlineMost();
    
      if (number_of_stages > 1) {
        tv0a->circularBuffer(
            number_of_stages, prefetch_distance, circular_buffer_type);
      }
    
      auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
      at::Tensor at_tv0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options);
    
      KernelExecutor ke;
      ke.compile(fusion.get(), {at_tv0}, {}, index32bit);
      auto outputs = ke.run({at_tv0});
      auto at_output = at_tv0 + at_tv0;
      testValidate(
          fusion.get(), outputs, {at_tv0}, {at_output}, __LINE__, __FILE__);
    }
    auto tmaUblkPredicateParams() {
      // When using register sharing with warp-specialized circular buffering, the
      // circular buffer loop must be the outer-most for-loop
      const std::vector<CircularBufferType> all_types{
          Pipelined(false),
          Pipelined(true),
          WarpSpecialized(ParallelType::TIDx),
          WarpSpecialized(ParallelType::TIDy),
          WarpSpecialized(ParallelType::TIDz)};
      int64_t dim0 = 8191, dim1 = 256;
      const std::vector<LoadStoreOpType> tma_types{LoadStoreOpType::CpAsyncBulk};
      std::vector<TmaCircularBufferingParams> values;
      for (int64_t i : {2, 4}) {
        for (int64_t j : c10::irange(-i, i)) {
          for (auto circular_buffer_type : all_types) {
            for (auto tma_load_type : tma_types) {
              values.emplace_back(
                  i, j, dim0, dim1, circular_buffer_type, tma_load_type);
            }
          }
        }
      }
      return testing::ValuesIn(values);
    }
    INSTANTIATE_TEST_SUITE_P(
        UblkTma,
        TmaCircularBufferingTestUblk,
        tmaUblkPredicateParams(),
        tmaName);
    
    } // namespace nvfuser

@liqiangxl (Collaborator, Author) commented:
!test

@liqiangxl (Collaborator, Author) commented:
!test

@liqiangxl (Collaborator, Author) commented:
!test

@liqiangxl (Collaborator, Author) commented:
!test
