DID loop split on reshaped IDs #3875
base: main
Conversation
// Reorder the reshape dimensions to the front of the domain
@naoyam I'm quite confused by this pre-existing logic, which I need to understand before I can review this PR. Why is it necessary to move reshape dimensions to the front of the loop domain? It can conflict with the pre-existing assumption that DIDs have to be at the front as well.
I think this is related to the propagation done at line 2279. IIRC, it propagates the outermost N dimensions, where N is old2new.size() in this case. Since here we just want to propagate the transformations related to the rfactor, this is how we limit the propagation.
We can probably just reorder tv back after line 2279.
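For illustration, a minimal sketch of that reorder-then-restore idea; the axis positions and the old2new map below are made up, and the final restore step is the suggestion above, not existing code:

// Hypothetical loop domain: [iDIDx{d}, b, s, a, h/a], with the reshape-related
// IDs a and h/a at positions 3 and 4.
std::unordered_map<int64_t, int64_t> old2new{{3, 0}, {4, 1}};
tv->reorder(old2new); // loop domain becomes [a, h/a, iDIDx{d}, b, s]

// TransformPropagator then replays only the outermost old2new.size() = 2
// dimensions, i.e. just the reshape-related transforms.

// Suggested follow-up: restore the original order so the DID stays outermost.
std::unordered_map<int64_t, int64_t> restore{{0, 3}, {1, 4}};
tv->reorder(restore); // back to [iDIDx{d}, b, s, a, h/a]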
auto find_it = std::find(
    tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);
// If not found directly and there is a sharded loop ID,
I think I see what the below part is trying to do and why, which seems to make sense, but can you expand the comment and elaborate a little more?
This bears several assumptions that will break in the foreseeable future.
With context parallelism, the sequence dimension s will be split into [tp, iDIDy{cp}, s/tp/cp], so the code below won't be able to find tp and s/tp/cp. Similarly, with overlapping, the sequence dimension s will be split into [sp, iDID{tp}, s/sp/tp], where sp is the stream parallelization factor. See this test for the idea.
I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...
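For concreteness, a hedged sketch of the CP-style schedule described above; the axis position and the tp/cp factors are illustrative, not taken from an actual test:

// tv is assumed to be [b, s, ...] with the sequence dimension s at axis 1.
const int64_t tp = 2, cp = 2; // hypothetical TP and CP factors
tv->split(1, tp, /*inner_split=*/false); // [b, tp, s/tp, ...]
tv->split(2, cp, /*inner_split=*/false); // [b, tp, cp, s/tp/cp, ...]
tv->axis(2)->parallelize(ParallelType::DIDy); // [b, tp, iDIDy{cp}, s/tp/cp, ...]
// The sharded split's input is now s/tp rather than the logical ID s, so the
// split->in() == logical_id check in this PR no longer matches.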
(I still haven't given up on improving ID model.)
If we have to do graph traversal like this, we may want to do it in a place where the logic can be generalized and reused (and therefore in the ID model). At this moment, there are two use cases:
- splitting reshape: [h]=>[d, h/d] and [h]=>[d,a/d,h/a]
- merging reshape: [a,h/a]=>[d,a/d,h/a] and [a,h/a]=>[h]=>[d,h/d]
We want the ID model to map the d's in both cases so these reshapes won't be considered resharding.
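To make the first bullet concrete, here is a hedged sketch of the splitting-reshape case, written in the style of the slice example later in this thread; the sizes are arbitrary, chosen to divide evenly:

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
const int64_t d = 2, a = 4, h = 16;

// Splitting reshape: [h] => [a, h/a].
TensorView* in = makeContigConcreteTensor({h});
TensorView* out = reshape(in, {h}, {a, h / a});
fusion->addInput(in);
fusion->addOutput(out);

auto mesh = DeviceMesh::createForNumDevices(d);
for (auto* tv : {in, out}) {
  tv->setDeviceMesh(mesh);
  tv->split(0, d, /*inner_split=*/false); // in: [d, h/d]; out: [d, a/d, h/a]
  tv->axis(0)->parallelize(ParallelType::DIDx);
}
// The goal is for IdModel to map in's and out's d IDs so this reshape is not
// treated as resharding.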
How much harder is it to make the ID model support these cases than to work around it using the reshape transformation? I suspect the latter has a bigger blast radius, because the former is local to the ID model while the latter changes TensorViews.
I realized the same limitation for DID loop split on slice:
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const int d = communicator_->size();
const int64_t b = 2, s = 2, h = 4;

// Slice off the first d*h elements of the last dimension.
TensorView* in = makeContigConcreteTensor({b, s, 3 * d * h});
TensorView* out = slice(in, {0, 0, 0}, {b, s, d * h});
fusion->addInput(in);
fusion->addOutput(out);

auto mesh = DeviceMesh::createForNumDevices(d);
for (auto* tv : {in, out}) {
  tv->setDeviceMesh(mesh);
  // DID loop split on the last dimension: [..., x] => [..., iDIDx{d}, x/d].
  tv->split(-1, d, /*inner_split=*/false);
  tv->axis(-2)->parallelize(ParallelType::DIDx);
  tv->setAllocationDomain(tv->getLoopDomain(), true);
}
I was trying to manually handle the case of SliceOp in hasDifferentShardings, but it would make certain assumptions about the parallelization patterns and can easily break.
I understand this change does fix some narrow cases that we care about at this very moment, but I'll have to think more about how to fix the broader issue...
Yes, I agree. I wanted to add an example to demonstrate how reshapes can be loop split but it certainly does not cover all the cases.
I still haven't given up on improving ID model.
(Sorry -- I wish I knew more about IdModel to be more constructive.)
Another use case to consider is manual sharding -- the user wants to manually shard a subset of TVs to improve perf when our sharding propagation is suboptimal.
They may well annotate [b,s,h] =>(reshape)=> [b,s,a,h/a] as follows:
in: [b,s,h] => [b,s,d,h/d]
out: [b,s,a,h/a] => [b,s,d,a/d,h/a]
and expect nvFuser to recognize this reshape is local. In this case, it's hard to replay the reshape on the input because h there is already split by d.
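A hedged sketch of that manual annotation; b, s, a, h, and d are placeholder sizes, and the scheduling calls mirror the slice example above:

const int64_t b = 2, s = 2, a = 4, h = 16;
const int64_t d = 2; // number of devices

TensorView* in = makeContigConcreteTensor({b, s, h});
TensorView* out = reshape(in, {b, s, h}, {b, s, a, h / a});

// in: [b, s, h] => [b, s, iDIDx{d}, h/d]
in->split(-1, d, /*inner_split=*/false);
in->axis(-2)->parallelize(ParallelType::DIDx);

// out: [b, s, a, h/a] => [b, s, iDIDx{d}, a/d, h/a]
out->split(2, d, /*inner_split=*/false);
out->axis(2)->parallelize(ParallelType::DIDx);

// The user expects nvFuser to treat this reshape as local, but replaying the
// reshape on `in` is hard because h there is already split by d.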
auto split = dynamic_cast<Split*>(
    tv->getLoopDomain().at(sharded_axis)->definition());
if (split != nullptr && split->in() == logical_id) {
  find_it = std::find(
While I understand your intention, I don't know the implications of doing this for TransformPropagator. E.g.
root=[b, s, h], logical=[b, s, a, h/a], loop=[b, s, d, a/d, h/a]
This will move a/d and h/a to the front, so the new loop domain becomes [a/d, h/a, b, s, d], and later ask TransformPropagator to replay at replayed_pos_ 2. What is TransformPropagator supposed to do with that? The first two loop IDs (a/d and h/a) don't even form a split in this TV.
cc @naoyam
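A sketch of that scenario with explicit positions; the reorder mimics what the move-reshape-IDs-to-the-front logic would produce, and the names are illustrative:

// out: root=[b, s, h], logical=[b, s, a, h/a]
TensorView* out = reshape(in, {b, s, h}, {b, s, a, h / a});
out->split(2, d, /*inner_split=*/false); // loop: [b, s, iDIDx{d}, a/d, h/a]
out->axis(2)->parallelize(ParallelType::DIDx);

// Move the reshape-related loop IDs (a/d at position 3, h/a at position 4)
// to the front, as the pre-existing logic does:
std::unordered_map<int64_t, int64_t> old2new{{3, 0}, {4, 1}};
out->reorder(old2new); // loop: [a/d, h/a, b, s, iDIDx{d}]

// TransformPropagator is then asked to replay at replayed_pos_ 2, yet a/d and
// h/a do not come from a single split of one ID in this TV.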
I see your point. In the case of reshape with DID loop split, we have already propagated the reshape upwards, so the TransformPropagator only reorders the axes when called later. In the absence of the earlier reshape propagation before the loop split, the behavior could be erroneous since they don't form a split.
Although, since the reshape has already been propagated and, as @naoyam mentioned above, the tv is reordered back, maybe this propagation can be skipped altogether. Let me think about it more and see what the schedulers expect from this propagateReshapeTransform.
However, this may not work for the manual sharding case you mentioned in the above comment.
This PR updates propagateReshapeTransform to support DID loop split. When the loop split is on the iterdomains being reshaped, the logical reshaped iterdomain is no longer present in the loop domain since it has been split. In this case, we check if there is a sharded loop ID and compare the logical reshape iterdomain to the producer of this DID split.
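Putting the quoted snippets together, the lookup roughly works as sketched below. This is a paraphrase rather than the exact diff: sharded_axis is assumed to be the position of the DID loop ID, and the final search target (the split's inner output) is a guess at the argument elided in the quote above.

// Look for the logical reshape ID directly in the loop domain.
auto find_it = std::find(
    tv->getLoopDomain().begin(), tv->getLoopDomain().end(), logical_id);

// If not found, the logical ID may have been consumed by a DID loop split,
// e.g. [h] => [iDIDx{d}, h/d]. In that case, compare the logical ID against
// the input of the split that produced the sharded loop ID.
if (find_it == tv->getLoopDomain().end() && sharded_axis >= 0) {
  auto* split = dynamic_cast<Split*>(
      tv->getLoopDomain().at(sharded_axis)->definition());
  if (split != nullptr && split->in() == logical_id) {
    // Fall back to the split's inner output, which is in the loop domain.
    find_it = std::find(
        tv->getLoopDomain().begin(),
        tv->getLoopDomain().end(),
        split->inner());
  }
}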