make all 3 gemms in Float8Linear support configurability, not user facing #315
Conversation
inpt_tensor: torch.Tensor,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.X,
maybe no default value for gemm role?
we can clean this up in a separate PR; there is extra complexity because we'd need to change the argument order
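(For reference, a rough sketch of why the argument order is the sticking point: a required parameter can't simply be appended after the existing defaulted ones, so making `gemm_input_role` required means either reordering the positional arguments or making it keyword-only. The function names and trimmed signatures below are illustrative, not the actual API.)

```python
import torch

# Option A (hypothetical): make the role required by moving it ahead of the
# defaulted arguments. This changes the positional order, so any existing
# positional callers would break.
def to_float8_reordered(
    inpt_tensor: torch.Tensor,
    linear_mm_config,      # LinearMMConfig
    gemm_input_role,       # GemmInputRole, now required
    reduce_amax: bool = False,
):
    ...

# Option B (hypothetical): keep the current order and make the role
# keyword-only, which avoids reordering but forces callers to pass it by name.
def to_float8_kwonly(
    inpt_tensor: torch.Tensor,
    linear_mm_config,      # LinearMMConfig
    reduce_amax: bool = False,
    *,
    gemm_input_role,       # required, keyword-only
):
    ...
```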
self.backward_config = ScaledMMConfig(
    emulate, False, False, config.pad_inner_dim
# TODO(future): user level configuration of gemms
self.linear_mm_config = LinearMMConfig(
[I think this might be another stylistic thing, so no need to change]:
I would actually make this a func and document it thoroughly. It's not very clear from reading this what everything does, so I would explain in that func the exact recipe we choose by default.
this isn't user facing, so we can clean it up at any time
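As a side note, here is a minimal sketch of the kind of documented helper the comment above is asking for. It assumes `ScaledMMConfig` fields are `(emulate, use_fast_accum, fp8_output, pad_inner_dim)` (inferred from the diff above), that `LinearMMConfig` groups one `ScaledMMConfig` per gemm in the order output, grad_input, grad_weight, and that both are importable from `float8_experimental.float8_tensor`; the helper name and the exact default recipe are illustrative, not what the PR ships.

```python
from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig


def make_default_linear_mm_config(emulate: bool, pad_inner_dim: bool) -> LinearMMConfig:
    """Build the default per-gemm configs for Float8Linear (illustrative recipe).

    * output      (y = x @ w_t):         fast accum on in the forward gemm
    * grad_input  (dL_dX = dL_dY @ w):   fast accum off for backward numerics
    * grad_weight (dL_dW = x_t @ dL_dY): fast accum off for backward numerics
    """
    # ScaledMMConfig fields assumed: (emulate, use_fast_accum, fp8_output, pad_inner_dim)
    return LinearMMConfig(
        ScaledMMConfig(emulate, True, False, pad_inner_dim),   # output
        ScaledMMConfig(emulate, False, False, pad_inner_dim),  # grad_input
        ScaledMMConfig(emulate, False, False, pad_inner_dim),  # grad_weight
    )
```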
@@ -76,6 +84,7 @@ def float8_cat(aten_op, args, kwargs=None):
scale = chunked_tensors[0]._scale
Future PR:
We should share code between unflatten/flatten and here, to just splat out the extra metadata that lives on the tensor.
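A rough sketch of what that shared helper could look like, assuming the extra metadata on a `Float8Tensor` is `_scale`, `_orig_dtype`, `_linear_mm_config`, and `_gemm_input_role` (attribute names inferred from the surrounding diffs; the helper itself is hypothetical):

```python
def get_float8_metadata(t) -> dict:
    # Everything a Float8Tensor carries besides the raw fp8 payload (_data).
    # A single helper like this could back __tensor_flatten__/__tensor_unflatten__
    # and ops such as float8_cat, instead of each call site listing the fields by hand.
    return {
        "scale": t._scale,
        "orig_dtype": t._orig_dtype,
        "linear_mm_config": t._linear_mm_config,
        "gemm_input_role": t._gemm_input_role,
    }
```

A call site like float8_cat could then rebuild the result from the concatenated data plus this dict, provided the constructor's keyword names line up with the keys above.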
…not user facing"

Summary: This PR adds some plumbing for how to eventually make all 3 gemms in a linear fwd/bwd configurable:

1. add `LinearMMConfig` to `Float8Tensor` to tie together the three `ScaledMMConfig` objects, one per gemm
2. add `GemmInputRole` to `Float8Tensor` to specify how to pick the right config
3. plumb all of these throughout the codebase

Note that none of this is user facing, and there is no logic change. Planned follow-ups:

* a future PR will make the per-gemm behavior configurable in a user facing way, which will hook up to the objects introduced in this PR
* a future PR will update the naming from x/w/dL_dY to input/weight/grad_output throughout the codebase

Test Plan:
```
./test/test_everything.sh
```

Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in c58fb5d.
Stack from ghstack (oldest at bottom):
Summary:

This PR adds some plumbing for how to eventually make all 3 gemms in a linear fwd/bwd configurable:

1. add `LinearMMConfig` to `Float8Tensor` to tie together the three `ScaledMMConfig` objects, one per gemm
2. add `GemmInputRole` to `Float8Tensor` to specify how to pick the right config
3. plumb all of these throughout the codebase

Note that none of this is user facing, and there is no logic change. Planned follow-ups:

* a future PR will make the per-gemm behavior configurable in a user facing way, which will hook up to the objects introduced in this PR
* a future PR will update the naming from x/w/dL_dY to input/weight/grad_output throughout the codebase

Test Plan: `./test/test_everything.sh`

Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D59973551
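For readers new to the design, here is a rough sketch (not the code in this PR) of how `GemmInputRole` and `LinearMMConfig` are meant to interact: the roles of the two operands of a scaled mm determine which of the three per-gemm configs applies. The `LinearMMConfig` field names and the standalone definitions below are assumptions for illustration; the role names x/w/dL_dY follow this PR's naming.

```python
from collections import namedtuple
from enum import Enum, auto


class GemmInputRole(Enum):
    X = auto()      # activation
    W = auto()      # weight
    DL_DY = auto()  # incoming gradient


# one per-gemm config entry for each of the three gemms in a linear fwd/bwd
# (field names assumed for illustration)
LinearMMConfig = namedtuple("LinearMMConfig", ["output", "grad_input", "grad_weight"])


def choose_scaled_mm_config(a_role, a_mm_config, b_role, b_mm_config):
    """Pick the per-gemm config based on which gemm the two operands form.

    y     = x @ w_t      -> output config
    dL_dX = dL_dY @ w    -> grad_input config
    dL_dW = x_t @ dL_dY  -> grad_weight config
    """
    roles = {a_role, b_role}
    if roles == {GemmInputRole.X, GemmInputRole.W}:
        assert a_mm_config.output == b_mm_config.output
        return a_mm_config.output
    elif roles == {GemmInputRole.DL_DY, GemmInputRole.W}:
        assert a_mm_config.grad_input == b_mm_config.grad_input
        return a_mm_config.grad_input
    elif roles == {GemmInputRole.X, GemmInputRole.DL_DY}:
        assert a_mm_config.grad_weight == b_mm_config.grad_weight
        return a_mm_config.grad_weight
    raise AssertionError(f"unexpected roles: {a_role}, {b_role}")
```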