Support tensors with only column-wise data #1505
Draft
Description
The quantized tensor infrastructure in TE 2.0 assumes that tensors have row-wise data available, both in the core C++ library and in the PyTorch extensions. This PR relaxes that assumption to support tensors with only column-wise data. This lets us avoid caching unnecessary data after the linear forward pass (the wgrad GEMM only needs the column-wise input) and reduces communication volume in MXFP8 tensor-parallel all-gathers. It is not yet fully optimized (the tensor-parallel linear module still caches the BF16 input tensor instead of the column-wise MXFP8 tensor), but it is a reasonable step.
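To illustrate why the row-wise copy of the input can be dropped after the forward pass, here is a minimal pure-Python sketch. It is not TE code: `ToyTensor`, `linear_forward`, and `linear_wgrad` are hypothetical names, and the sketch treats column-wise data as a plain transpose, ignoring quantization and scaling factors entirely.

```python
from dataclasses import dataclass
from typing import List, Optional

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    b_cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a quantized tensor (not a TE class)."""
    rowwise: Optional[Matrix] = None     # layout [tokens, in_features]
    columnwise: Optional[Matrix] = None  # assumed here to be the transpose

    def drop_rowwise(self) -> "ToyTensor":
        # Keep only the column-wise data, as a column-wise-only tensor would.
        if self.columnwise is None:
            raise ValueError("column-wise data is required")
        return ToyTensor(rowwise=None, columnwise=self.columnwise)


def linear_forward(x: ToyTensor, weight: Matrix):
    # Forward GEMM y = x @ W^T consumes the row-wise input.
    y = matmul(x.rowwise, transpose(weight))
    # Only the column-wise input is cached for the backward pass.
    return y, x.drop_rowwise()


def linear_wgrad(saved_x: ToyTensor, grad_out: Matrix) -> Matrix:
    # dW = dy^T @ x = (x^T @ dy)^T, so only x^T (column-wise) is needed.
    return transpose(matmul(saved_x.columnwise, grad_out))


x_data = [[1, 2], [3, 4], [5, 6]]    # 3 tokens, 2 input features
weight = [[1, 0], [0, 1], [1, 1]]    # 3 output features, 2 input features
x = ToyTensor(rowwise=x_data, columnwise=transpose(x_data))

y, saved = linear_forward(x, weight)
grad_out = [[1, 1, 1] for _ in range(3)]  # dummy upstream gradient
dW = linear_wgrad(saved, grad_out)
```

In this toy setup `saved.rowwise` is `None` during the backward pass, yet the weight gradient is still computed, which is the property the PR relies on.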
Type of change
Changes
Linear module, LayerNormLinear module, BasicLinear op
Checklist: