Support vectorized local reduction for p2p-based ReduceScatter overlap #1452
Conversation
Signed-off-by: Sangkug Lym <[email protected]>

Force-pushed from febe2ec to 4f697a9.

for more information, see https://pre-commit.ci
```cpp
#pragma unroll
for (int input_id = 1; input_id < num_inputs; ++input_id) {
  loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
  for (int i = 0; i < nvec; ++i) {
    // Dequantize the FP8 value and accumulate into the FP32 buffer.
    accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
    // Write the result out on the last input only.
    if (input_id == num_inputs - 1) {
      storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
    }
  }
}
```
Two issues here:

- Correctness bug when `num_inputs == 1`: the loop starts at `input_id = 1`, so its body never executes and the output is never stored.
- Unrolling the loop over `num_inputs` is unnecessary, since `num_inputs` is not known at compile time.
Suggested change:

```cpp
for (int input_id = 1; input_id < num_inputs; ++input_id) {
  loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size);
#pragma unroll
  for (int i = 0; i < nvec; ++i) {
    accum_buf[i] += static_cast<float>(loader.separate()[i]) * (*scale);
  }
}
#pragma unroll
for (int i = 0; i < nvec; ++i) {
  storer.separate()[i] = static_cast<half_dtype>(accum_buf[i]);
}
```
Same issue in `reduce_bf16_cuda`.
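A minimal host-side analog (hypothetical names, not from the PR) that demonstrates the single-input failure mode in the original control flow:

```cpp
#include <cstdio>

int main() {
  const int num_inputs = 1;
  float accum = 42.0f;  // stands in for accum_buf[i], already holding input 0
  float out = -1.0f;    // stands in for the storer's output slot
  for (int input_id = 1; input_id < num_inputs; ++input_id) {
    accum += 1.0f;                    // reduce the remaining inputs
    if (input_id == num_inputs - 1) {
      out = accum;                    // store only on the last iteration
    }
  }
  printf("out = %f\n", out);  // prints -1.000000: the result was never stored
  return 0;
}
```

Hoisting the store out of the loop, as in the suggestion above, writes `accum_buf` unconditionally and fixes the single-input case.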
```cpp
}

template <typename fp8type>
void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs,
                            int input_size, cudaStream_t stream) {
  const int nvec = 32;
```
Since we're using this as a template argument, better to make it explicit that it is known at compile time:

Suggested change:

```cpp
constexpr int nvec = 32;
```
Same issue in `reduce_bf16`.
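A standalone sketch (hypothetical `VecBuf` type, not from the codebase) of why `constexpr` states the intent explicitly: a value used as a template argument must be a constant expression, and `constexpr` guarantees that, rather than relying on `const` with a constant initializer.

```cpp
template <typename T, int nvec>
struct VecBuf {
  T data[nvec];  // nvec participates in the type, so it must be compile-time
};

int main() {
  constexpr int nvec = 32;    // explicitly a compile-time constant
  VecBuf<float, nvec> buf{};  // OK: nvec is a constant expression
  static_assert(sizeof(buf) == nvec * sizeof(float), "checked at compile time");
  return 0;
}
```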
```cpp
transformer_engine::VectorizedLoader<fp8type, nvec, true> loader(inputs_fp8, tot_input_size);
transformer_engine::VectorizedStorer<half_dtype, nvec, true> storer(output_half, input_size);

const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
```
Do we handle the case where the block size doesn't neatly divide the input size? Maybe we can fix it with something like:

Suggested change:

```cpp
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
if (tid >= num_aligned_elements_per_input) {
  return;
}
```
Alternatively, we can change how we configure the CUDA blocks:

```cpp
size_t num_threads = MAX_THREADS / 4;
assert(num_aligned_elements_per_input % num_threads == 0);
size_t num_blocks = num_aligned_elements_per_input / num_threads;
dim3 block(num_threads);
dim3 grid(num_blocks);
```
Same issue in `reduce_bf16_cuda`.
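For reference, a minimal sketch (hypothetical kernel and names, not the PR's code) combining the two approaches: a ceil-division grid so every element is covered even when the count is not a multiple of the block size, plus the in-kernel guard so threads in the partial tail block return early.

```cpp
#include <cstddef>

__global__ void reduce_kernel(const float *in, float *out, size_t num_elements) {
  const size_t tid = threadIdx.x + blockDim.x * static_cast<size_t>(blockIdx.x);
  if (tid >= num_elements) {
    return;  // guard the partial tail block
  }
  out[tid] = in[tid];  // placeholder for the vectorized load/reduce/store
}

void launch_reduce(const float *in, float *out, size_t num_elements,
                   cudaStream_t stream) {
  dim3 block(256);
  // Round up so the grid covers num_elements; the in-kernel guard
  // handles the overshoot in the last block.
  dim3 grid((num_elements + block.x - 1) / block.x);
  reduce_kernel<<<grid, block, 0, stream>>>(in, out, num_elements);
}
```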
Signed-off-by: Sangkug Lym <[email protected]>

Force-pushed from fb891d8 to b1ad009.
@timmoon10 BTW, regarding the lint error,
Description
Vectorized load/store for p2p-based ReduceScatter overlap.
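As a quick illustration of the idea (a generic sketch, not this PR's implementation): vectorized access reinterprets the element pointer as a wider built-in vector type, so each thread moves several elements per memory transaction instead of one.

```cpp
// Copies n4 float4 values; each 128-bit load/store moves four floats at once.
// Assumes the buffers are 16-byte aligned and the element count is padded to
// a multiple of 4, which is what alignment handling in such kernels is for.
__global__ void copy_float4(const float4 *in, float4 *out, size_t n4) {
  const size_t tid = threadIdx.x + blockDim.x * static_cast<size_t>(blockIdx.x);
  if (tid < n4) {
    out[tid] = in[tid];
  }
}
```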