Skip to content

Commit

Permalink
comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 committed Feb 12, 2025
1 parent 52a7a0c commit 858f9fc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
1 change: 0 additions & 1 deletion csrc/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2247,7 +2247,6 @@ void propagateReshapeTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {
auto split = dynamic_cast<Split*>(
tv->getLoopDomain().at(sharded_axis)->definition());
if (split && split->in() == logical_id) {
// The DIDx axis is not reordered, since
find_it = std::find(
tv->getLoopDomain().begin(),
tv->getLoopDomain().end(),
Expand Down
8 changes: 6 additions & 2 deletions tests/cpp/test_multidevice_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,9 @@ TEST_F(MultiDeviceTest, TransformPropagatorWithReshape) {
auto mesh = DeviceMesh::createForNumDevices(d);

// Propagate transform from reshaped output to input.
// Without this propagation, the two DID axes on `in` and `out` will not be
// mapped together in the ID model. This causes scheduling to fail due to
// resharding.
TransformPropagator propagator_c2p(out);
MaxLogicalDomainInfoSpanningTree(out).traverse(&propagator_c2p);
// in: loop domain: {b, s, d*h, e} after transform propagation
Expand All @@ -735,16 +738,17 @@ TEST_F(MultiDeviceTest, TransformPropagatorWithReshape) {
in->axis(-3)->parallelize(ParallelType::DIDx);
// in: loop domain: {b, s, DIDx{d}, h, e}

// Propagate DID loop split to output
TransformPropagator propagator_p2c(in);
MaxLogicalDomainInfoSpanningTree(in).traverse(&propagator_p2c);
// out: loop domain: {b, s, d, h, e} after transform propagation

// Parallelize out
// Parallelize output
scheduler_utils::parallelizeAllLike(
in,
/*pos=*/-1,
/*selected_tv=*/{out});
// out: loop domain: {b, s, DIDx{d}, h, e} after transform propagation
// out: loop domain: {b, s, DIDx{d}, h, e} after parallelization

in->setAllocationDomain(in->getLoopDomain(), true);
out->setAllocationDomain(out->getLoopDomain(), true);
Expand Down

0 comments on commit 858f9fc

Please sign in to comment.