Optimization for Roberta unstick->reshape->transpose->reshape->stick #3056
Signed-off-by: Alexandre Eichenberger <[email protected]>
@@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//
Migrated the code elsewhere so that it can be reused; this was needed to support the reshape op.
@@ -402,8 +473,8 @@ AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) {
return AffineMapAttr::get(map);
}

AffineMapAttr getTiling3DTo4DMap(OpBuilder &b, Value val) {
assert(isTiling3DTo4D(val) &&
AffineMapAttr getLeftmostTiling3DTo4DMap(OpBuilder &b, Value val) {
Many of the prior operations were not specific about whether they applied to the rightmost or leftmost position, as only one was needed. As I added more, I made all names more explicit.
IndexExprScope currScope(&rewriter, loc);
// Here, cannot use the shape found in the reshape op, as it is the original
// shape before memref normalization.
Value input = reshapeOp.getX();
We should check whether X is normalized before processing further, using something like this:
// Input must have no affine layout. In other words, it has been normalized.
if (hasNonIdentityLayout(input.getType()))
  return failure();
Without this check, I see that in your lit test zlow-rewrite.mlir, zlow.reshape with affine_maps is still lowered to memref.reinterpret_cast.
Got it, I did not realize that zlow-rewrite ran twice. It's now fixed.
// CHECK-LABEL: func.func @handle_zlow_reshape
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<8x384x768xf16, #map>, [[PARAM_1_:%.+]]: memref<96x64x384xf16, #map>) -> memref<96x384x384xf16, #map> {

// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [96, 384, 64], strides: [24576, 64, 1] : memref<8x384x768xf16, #map> to memref<96x384x64xf16>
It does not look like what we are expecting, since the input memref is not normalized.
To check this case, you can
- add a new check for this case where we call --normalize-memrefs, by adding this line to the top of this file:
// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --normalize-memrefs --zlow-rewrite --canonicalize %s -split-input-file | FileCheck %s --check-prefix=RESHAPE
- then replace CHECK by RESHAPE in CHECK-DAG, CHECK-LABEL, ..., since we use the prefix RESHAPE for this check.
I simply wrote 2 versions of that test, one without and one with normalized memrefs, and checked that the pattern only applies when the memref is normalized.
Thanks @tungld for the feedback, implemented both suggestions.
LGTM!
Glad to see the performance improvement!
In some situations, a sequence of transformations is a "no-op" under a given ztensor representation.
The pattern exploited here applies to the 3DS layout: the shapes (A, B, C*D) <-> (A*C, B, D) are equivalent when B % 32 = 0 and D % 64 = 0.
Two patterns are detected and transformed to a zhigh/zlow reshape.
A high-level proof is given below. For a detailed proof, one has to follow every step of the transformations above and show equality of memory accesses, namely that accessing (e3, e2, e1) yields the same memory location in the original 3DS tensor as in the final 3DS tensor.
In practice, this PR adds 2 rules to catch these 2 patterns and replace them with a zhigh.Reshape, which is similar to memref.reshape in that it performs no "data layout transformation"; it just provides a mapping between two equivalent shapes. The ZHigh version performs such an equivalency for ZTensor formats such as 3D.
The ZHigh reshape operation is lowered to the equivalent ZLow reshape operation, which is then transformed to a memref.reinterpret_cast operation after all memrefs are normalized.
The PR adds lit tests to catch the patterns listed above, one for the ZHigh to ZLow conversion, and one for ZLow to memref.
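The memory-access equality behind the (A, B, C*D) <-> (A*C, B, D) equivalence can be sanity-checked with a small standalone program. The sketch below is not code from this PR: it assumes the 3DS stick layout maps a logical (e3, e2, e1) to (e3, e1 / 64, 0, e2 / 32, e2 % 32, e1 % 64) in a row-major buffer, and the names offset3DS, shapesAreCompatible, and sameMemoryLocations are hypothetical helpers introduced only for illustration.

```cpp
#include <cstdint>

// ASSUMPTION (not taken from this PR): a 3DS ztensor of logical shape
// (E3, E2, E1) is stored row-major with physical shape
// (E3, ceil(E1/64), 1, ceil(E2/32), 32, 64), and logical (e3, e2, e1) maps to
// (e3, e1 / 64, 0, e2 / 32, e2 % 32, e1 % 64).
// Returns the linear element offset of (e3, e2, e1) under that layout.
int64_t offset3DS(int64_t E2, int64_t E1, int64_t e3, int64_t e2, int64_t e1) {
  const int64_t sticks = (E1 + 63) / 64; // 64-element sticks along e1
  const int64_t pages = (E2 + 31) / 32;  // 32-row pages along e2
  int64_t off = e3;
  off = off * sticks + e1 / 64; // the size-1 dimension contributes nothing
  off = off * pages + e2 / 32;
  off = off * 32 + e2 % 32;
  off = off * 64 + e1 % 64;
  return off;
}

// Precondition as stated in the PR description.
bool shapesAreCompatible(int64_t B, int64_t D) {
  return B % 32 == 0 && D % 64 == 0;
}

// Exhaustively compares element (a, b, c*D + d) of the (A, B, C*D) tensor
// with element (a*C + c, b, d) of the (A*C, B, D) tensor.
bool sameMemoryLocations(int64_t A, int64_t B, int64_t C, int64_t D) {
  for (int64_t a = 0; a < A; ++a)
    for (int64_t b = 0; b < B; ++b)
      for (int64_t c = 0; c < C; ++c)
        for (int64_t d = 0; d < D; ++d)
          if (offset3DS(B, C * D, a, b, c * D + d) !=
              offset3DS(B, D, a * C + c, b, d))
            return false;
  return true;
}
```

Under the assumed layout, sameMemoryLocations(8, 384, 12, 64) holds for the shapes used in the lit test (8x384x768 viewed as 96x384x64), and offset3DS(384, 64, 1, 0, 0) evaluates to 24576, which matches the leading stride of the memref.reinterpret_cast in that test. The check only confirms that the stated conditions are sufficient for these shapes; it does not show that each condition is necessary.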
I checked that the values generated by Roberta with and without this PR were the same. Performance measurements show that in Roberta, the number of transposes was reduced from 48 to 12 (with the number of stick/unstick operations also reduced by 36). The time spent in transpose, stick, and unstick was reduced by 9%, 33%, and 37%, respectively. Overall (with one NNPA and one CPU), the time was reduced by 4%.
At this time, this PR is restricted to static shapes.