DID loop split on reshaped IDs #3875

Open · Priya2698 wants to merge 6 commits into main from pm/reshape_propagate

Conversation

@Priya2698 (Collaborator) commented Feb 11, 2025

This PR updates propagateReshapeTransform to support DID loop split.

When the loop split is applied to the iterdomains being reshaped, the logical reshaped iterdomain is no longer present in the loop domain, since it has been split. In this case, we check whether there is a sharded loop ID and compare the logical reshape iterdomain to the producer of this DID split.
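For illustration, a minimal sketch of the situation, using the same API as the test added in this PR (the test arrives at the same loop domain through propagation; here the split is applied directly to the reshaped tensor for brevity):

// out = reshape(in, {b, s, d*h*e}, {b, s, d*h, e})
// DID loop split on the reshaped iterdomain d*h:
out->split(-2, d, /*inner_split=*/false);        // {b, s, d*h, e} -> {b, s, d, h, e}
out->axis(-3)->parallelize(ParallelType::DIDx);  // {b, s, DIDx{d}, h, e}
// The logical reshape ID d*h is no longer in the loop domain; it is the input
// (split->in()) of the outer Split that produced DIDx{d}, which is what the
// updated propagateReshapeTransform now checks.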

@Priya2698 (Collaborator, Author) commented:

!test


github-actions bot commented Feb 11, 2025

Review updated until commit 73d237a

Description

  • Added support for DID loop split in propagateReshapeTransform

  • Enhanced propagateReshapeTransform to handle reshaped iterdomains

  • Added a new test case for transform propagation with reshape


Changes walkthrough 📝

Relevant files

Enhancement
  csrc/scheduler/utils.cpp — Handle DID loop split in propagateReshapeTransform
    • Added logic to check if the logical ID is directly in the loop domain
    • Added handling for the sharded loop ID and DID split producer comparison
    • Reordered reshape dimensions to the front of the domain
    • +21/-1

Tests
  tests/cpp/test_multidevice_sharding.cpp — Add test for transform propagation with reshape
    • Added a new test case, TransformPropagatorWithReshape
    • Demonstrates transform propagation with reshape and DID loop split
    • +60/-0

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Error Handling

    The error message in NVF_ERROR could be more descriptive, especially regarding the context of the DID loop split and reshaped IDs.

    "Require ",
    logical_id,
    " is in the active domain of ",
    tv->toString(),
    " for view propagation.");
    Code Clarity

    The logic for handling the DID loop split and reshaped IDs could be refactored for better readability and maintainability.

    // Check if logical ID is directly in the loop domain
    auto find_it = std::find(
        tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);
    
    // If not found directly and there is a sharded loop ID,
    // check if the logical ID is the same as the producer of the DID split.
    if (find_it == tv->getLoopDomain().end()) {
      int64_t sharded_axis = getShardedLoopAxis(tv, ParallelType::DIDx);
      if (sharded_axis != -1) {
        // Get the split operation that created the DIDx dimension
        auto split = dynamic_cast<Split*>(
            tv->getLoopDomain().at(sharded_axis)->definition());
        if (split != nullptr && split->in() == logical_id) {
          find_it = std::find(
              tv->getLoopDomain().begin(),
              tv->getLoopDomain().end(),
              split->inner());
        }
      }
    }
    
    NVF_ERROR(
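    One possible shape for such a refactor (a sketch only; the helper name findReshapeIdInLoopDomain is hypothetical, and it merely repackages the logic quoted above):

    // Hypothetical helper: locate logical_id in tv's loop domain, or, if it was
    // consumed by a DID loop split, locate the inner output of that split.
    // Returns loop.end() if neither is found.
    std::vector<IterDomain*>::const_iterator findReshapeIdInLoopDomain(
        TensorView* tv,
        IterDomain* logical_id) {
      const auto& loop = tv->getLoopDomain();
      auto find_it = std::find(loop.begin(), loop.end(), logical_id);
      if (find_it != loop.end()) {
        return find_it;
      }
      int64_t sharded_axis = getShardedLoopAxis(tv, ParallelType::DIDx);
      if (sharded_axis == -1) {
        return loop.end();
      }
      auto* split = dynamic_cast<Split*>(loop.at(sharded_axis)->definition());
      if (split == nullptr || split->in() != logical_id) {
        return loop.end();
      }
      return std::find(loop.begin(), loop.end(), split->inner());
    }

    The caller would then reduce to a single call followed by the NVF_ERROR check.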
    Test Coverage

    Ensure that the new test case covers all edge cases and potential failure scenarios related to DID loop split and reshaped IDs.

    TEST_F(MultiDeviceTest, TransformPropagatorWithReshape) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int d = communicator_->size();
      const int64_t b = 2, s = 2, h = 4, e = 3;
    
      TensorView* in = makeContigConcreteTensor(
          {b, s, d * h * e}); // in: loop domain: {b, s, d*h*e}
      TensorView* out = reshape(
          in,
          {b, s, d * h * e},
          {b, s, d * h, e}); // out: loop domain: {b, s, d*h, e}
    
      fusion->addInput(in);
      fusion->addOutput(out);
    
      auto mesh = DeviceMesh::createForNumDevices(d);
    
      // Propagate transform from reshaped output to input.
      // Without this propagation, the two DID axes on `in` and `out` will not be
      // mapped together in the ID model, which causes scheduling to fail due to
      // resharding.
      TransformPropagator propagator_c2p(out);
      MaxLogicalDomainInfoSpanningTree(out).traverse(&propagator_c2p);
      // in: loop domain: {b, s, d*h, e} after transform propagation
    
      // Loop split and parallelize input
      in->setDeviceMesh(mesh);
      in->split(-2, d, /*inner_split=*/false);
      in->axis(-3)->parallelize(ParallelType::DIDx);
      // in: loop domain: {b, s, DIDx{d}, h, e}
    
      // Propagate DID loop split to output
      TransformPropagator propagator_p2c(in);
      MaxLogicalDomainInfoSpanningTree(in).traverse(&propagator_p2c);
      // out: loop domain: {b, s, d, h, e} after transform propagation
    
      // Parallelize output
      scheduler_utils::parallelizeAllLike(
          in,
          /*pos=*/-1,
          /*selected_tv=*/{out});
      // out: loop domain: {b, s, DIDx{d}, h, e} after parallelization
    
      in->setAllocationDomain(in->getLoopDomain(), true);
      out->setAllocationDomain(out->getLoopDomain(), true);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor in_tensor = at::randn({b, s, h * e}, tensor_options);
      at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0];
      testValidate(
          executor_cache.fusion(),
          {out_tensor},
          {in_tensor},
          {in_tensor.view({b, s, h, e})},
          __LINE__,
          __FILE__);
    }

    @Priya2698 Priya2698 marked this pull request as ready for review February 12, 2025 22:34
    @Priya2698 Priya2698 force-pushed the pm/reshape_propagate branch from d3c602d to 52a7a0c Compare February 12, 2025 22:57
    @Priya2698 Priya2698 requested a review from wujingyue February 12, 2025 22:58
    @Priya2698 Priya2698 force-pushed the pm/reshape_propagate branch from 858f9fc to 90ab5ee Compare February 13, 2025 00:57
    @wujingyue wujingyue requested a review from naoyam February 13, 2025 04:01
    @Priya2698 (Collaborator, Author) commented:

    !test

    csrc/scheduler/utils.cpp (outdated review thread, resolved)

    // Reorder the reshape dimensions to the front of the domain
    Collaborator:

    @naoyam I'm quite confused by this pre-existing logic, which I need to understand before I can understand this PR. Why is it necessary to move the reshape dimensions to the front of the loop domain? It can conflict with the pre-existing assumption that DIDs have to be at the front as well.

    Collaborator:

    I think this is related to the propagation done at line 2279. IIRC, it propagates the outermost N dimensions, where N is old2new.size() in this case. Since here we just want to propagate the transformations related to the rfactor, this is how we limit the propagation.

    We can probably just reorder tv back after line 2279.
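    A rough sketch of that reorder-back idea (assuming old2new is the old-to-new position map used to move the reshape IDs to contiguous front positions, as TensorView::reorder expects; the propagation step stands in for whatever line 2279 actually does):

    // Move the reshape IDs to the front so only they get propagated.
    tv->reorder(old2new);

    // ... propagation limited to the first old2new.size() loop IDs (line 2279) ...

    // Afterwards, undo the reorder so DIDs can stay at the front.
    std::unordered_map<int64_t, int64_t> new2old;
    for (const auto& [old_pos, new_pos] : old2new) {
      new2old[new_pos] = old_pos;
    }
    tv->reorder(new2old);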

    auto find_it = std::find(
    tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);

    // If not found directly and there is a sharded loop ID,
    Collaborator:

    I think I see what the below part is trying to do and why, which seems to make sense, but can you expand the comment and elaborate a little more?

    Collaborator:

    This bakes in several assumptions that will break in the foreseeable future.

    With context parallelism, the sequence dimension s will be split into [tp, iDIDy{cp}, s/tp/cp], so the code below won't be able to find tp and s/tp/cp. Similarly, with overlapping, the sequence dimension s will be split into [sp, iDID{tp}, s/sp/tp], where sp is the stream parallelization factor. See this test for the idea.

    I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...

    Collaborator:

    (I still haven't given up on improving ID model.)

    If we have to do graph traversal like this, we may want to do it in a place where the logic can be generalized and reused (and therefore in the ID model). At this moment, there are two use cases:

    1. splitting reshape: [h] => [d, h/d] and [h] => [d, a/d, h/a]
    2. merging reshape: [a, h/a] => [d, a/d, h/a] and [a, h/a] => [h] => [d, h/d]

    We want the ID model to map the d's in both cases so these reshapes won't be considered resharding.

    How much harder is it to make ID model support these cases than working around using reshape transformation? I suspect the latter has a bigger blast radius because the former is local to ID model and the latter changes TensorViews.

    Collaborator (Author):

    I realized the same limitation for DID loop split on slice:

      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      const int d = communicator_->size();
      const int64_t b = 2, s = 2, h = 4;
    
      TensorView* in = makeContigConcreteTensor(
          {b, s, 3 * d * h});
      TensorView* out = slice(
          in,
          {0, 0, 0},
          {b, s, d * h});
    
      fusion->addInput(in);
      fusion->addOutput(out);
    
      auto mesh = DeviceMesh::createForNumDevices(d);
      for (auto tv: {in, out}) {
        tv->setDeviceMesh(mesh);
        tv->split(-1, d, /*inner_split=*/false);
        tv->axis(-2)->parallelize(ParallelType::DIDx);
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    

    I was trying to manually handle the case of SliceOp in hasDifferentShardings, but it would make certain assumptions about the parallelization patterns and could easily break.

    Collaborator (Author):

    I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...

    Yes, I agree. I wanted to add an example to demonstrate how reshapes can be loop split but it certainly does not cover all the cases.

    Collaborator:

    I still haven't given up on improving ID model.

    (Sorry -- I wish I knew more about IdModel to be more constructive.)

    Another use case to consider is manual sharding -- the user wants to manually shard a subset of TVs to improve perf when our sharding propagation is suboptimal.

    They may well annotate [b,s,h]=>(reshape)=>[b,s,a,h/a] as follows

    in: [b,s,h] => [b,s,d,h/d]
    out: [b,s,a,h/a] => [b,s,d,a/d,h/a]
    

    and expect nvFuser to recognize this reshape is local. In this case, it's hard to replay the reshape on the input because h there is already split by d.
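    For concreteness, a hypothetical sketch of that manual annotation, written with the same API as the test in this PR (the sizes b, s, a, e are illustrative, with h = a * e and a chosen divisible by d):

    auto fusion = std::make_unique<Fusion>();
    FusionGuard fg(fusion.get());

    const int d = communicator_->size();
    const int64_t b = 2, s = 2, a = 2 * d, e = 3; // h = a * e

    TensorView* in = makeContigConcreteTensor({b, s, a * e});   // [b, s, h]
    TensorView* out = reshape(in, {b, s, a * e}, {b, s, a, e}); // [b, s, a, h/a]
    fusion->addInput(in);
    fusion->addOutput(out);

    auto mesh = DeviceMesh::createForNumDevices(d);

    // in: [b, s, h] => [b, s, DIDx{d}, h/d]
    in->setDeviceMesh(mesh);
    in->split(-1, d, /*inner_split=*/false);
    in->axis(-2)->parallelize(ParallelType::DIDx);

    // out: [b, s, a, h/a] => [b, s, DIDx{d}, a/d, h/a]
    out->setDeviceMesh(mesh);
    out->split(-2, d, /*inner_split=*/false);
    out->axis(-3)->parallelize(ParallelType::DIDx);

    As noted above, the hard part is recognizing this reshape as local: on the input, h is already split by d before the reshape could be replayed.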

    auto split = dynamic_cast<Split*>(
    tv->getLoopDomain().at(sharded_axis)->definition());
    if (split != nullptr && split->in() == logical_id) {
    find_it = std::find(
    Collaborator:

    While I understand your intention, I don't know the implications of doing this for TransformPropagator. E.g.

    root=[b, s, h], logical=[b, s, a, h/a], loop=[b, s, d, a/d, h/a]
    

    This will move a/d and h/a to the front, so the new loop domain becomes [a/d, h/a, b, s, d], and then ask TransformPropagator to replay at replayed_pos_ = 2. What is TransformPropagator supposed to do with that? The first two loop IDs (a/d and h/a) don't even form a split in this TV.

    cc @naoyam

    Collaborator (Author):

    I see your point. In the case of reshape with DID loop split, we have already propagated the reshape upwards, so the TransformPropagator only reorders the axes when called later. Without that earlier reshape propagation before the loop split, the behavior could be erroneous, since those IDs don't form a split.

    That said, since the reshape has already been propagated and, as @naoyam mentioned above, the tv is reordered back, maybe this propagation can be skipped altogether.
    Let me think about it more and see what the schedulers expect from this propagateReshapeTransform.

    However, this may not work for the manual sharding case you mentioned in the above comment.
