Adding an option to skip zeroing the output tensor for f8f8bf16_rowwise_grouped_dynamic #3685

Open · wants to merge 1 commit into base: main
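What the change enables, as a hedged usage sketch in C++ (not code from this PR): the grouped GEMM gains a trailing zeroing_output_tensor flag, defaulting to true, so existing callers keep the zero-filled output while new callers can skip the fill. The declaration mirrors the one added to quantize.cpp below; the shape comments are assumptions for illustration only.

#include <ATen/ATen.h>

// Declaration as added to quantize.cpp in this PR.
at::Tensor f8f8bf16_rowwise_grouped_dynamic(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor zero_start_index_M,
    bool zeroing_output_tensor = true);

at::Tensor demo_both_paths(
    at::Tensor XQ,                  // assumed [group_count, M, K], FP8
    at::Tensor WQ,                  // assumed [group_count, N, K], FP8
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor zero_start_index_M) {
  // Existing behavior: the output tensor is zero-filled before the GEMM.
  at::Tensor y_default = f8f8bf16_rowwise_grouped_dynamic(
      XQ, WQ, x_scale, w_scale, zero_start_index_M);

  // New option: skip the zero-fill when every output element the caller will
  // read is guaranteed to be written by the GEMM itself.
  at::Tensor y_no_fill = f8f8bf16_rowwise_grouped_dynamic(
      XQ, WQ, x_scale, w_scale, zero_start_index_M,
      /*zeroing_output_tensor=*/false);
  return y_no_fill;
}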
@@ -206,7 +206,8 @@ __global__ void set_kernel_args_fixed_nk_kernel(
int M,
int N,
int K,
-int group_count) {
+int group_count,
+bool zeroing_output_tensor) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each thread is responsible for setting up the arguments for one group.
if (thread_idx < group_count) {
@@ -227,6 +228,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
// Write kernel args to memory.
kernel_args[thread_idx] = kernel_group_args;
}
+if (!zeroing_output_tensor) return;

// Figure out where in memory we are.
// Each thread sets one float 4 which corresponds to 8 bf16 values.
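For readers without the elided kernel body: a minimal sketch (not the PR's kernel) of the pattern this hunk changes. set_kernel_args_fixed_nk_kernel first has one thread per group publish that group's GEMM arguments, then reuses the launched threads to zero the bf16 output with vectorized float4 stores (8 bf16 values each, per the comment above); the new flag simply exits before that zeroing pass. The scalar float output and the grid-stride loop below are simplifying assumptions.

template <typename KernelArgsT>
__global__ void set_args_then_maybe_zero(
    KernelArgsT* kernel_args,          // per-group GEMM argument slots
    const KernelArgsT* per_group_args, // precomputed arguments, one per group
    float* output,                     // placeholder for the bf16 output buffer
    long output_elems,
    int group_count,
    bool zeroing_output_tensor) {
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // One thread per group writes that group's kernel arguments.
  if (thread_idx < group_count) {
    kernel_args[thread_idx] = per_group_args[thread_idx];
  }
  // New early exit: with zeroing disabled, argument setup is all this kernel does.
  if (!zeroing_output_tensor) return;
  // Otherwise the full grid zero-fills the output buffer.
  long stride = static_cast<long>(gridDim.x) * blockDim.x;
  for (long i = thread_idx; i < output_elems; i += stride) {
    output[i] = 0.0f;
  }
}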
@@ -252,7 +254,8 @@ void set_dynamic_kernel_args(
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor output,
-at::Tensor zero_start_index_M) {
+at::Tensor zero_start_index_M,
+bool zeroing_output_tensor) {
// Get current cuda stream.
auto stream = at::cuda::getCurrentHIPStream().stream();
int group_count = XQ.size(0);
@@ -292,7 +295,8 @@ void set_dynamic_kernel_args(
M,
N,
K,
-group_count);
+group_count,
+zeroing_output_tensor);
}

template <typename OutputType>
@@ -433,7 +437,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
-at::Tensor zero_start_index_M) {
+at::Tensor zero_start_index_M,
+bool zeroing_output_tensor = true) {
// Check that input datatypes are valid.
// First confirm that there are the same number of groups in all inputs.
int group_count = XQ.size(0);
@@ -473,7 +478,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
{static_cast<long>(group_count * sizeof(KernelArguments))},
XQ.options().dtype(at::kByte));
set_dynamic_kernel_args(
-kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
+kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M, zeroing_output_tensor);

RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);
@@ -682,14 +682,20 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor w_scale,
-at::Tensor zero_start_index_M) {
+at::Tensor zero_start_index_M,
+bool zeroing_output_tensor = true) {
at::Tensor Y;
int group_count = XQ.size(0);
int M = XQ.size(1);
int N = WQ.size(1);
int K = XQ.size(0);
int total_output_size = group_count * M * N;
-Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16));
+if (zeroing_output_tensor) {
+  Y = at::zeros(total_output_size, XQ.options().dtype(at::kBFloat16));
+} else {
+  Y = at::empty(total_output_size, XQ.options().dtype(at::kBFloat16));
+}

// Return continuous view of output.
at::Tensor output = dispatch_fp8_grouped_kernel<at::Tensor>(
XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
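The allocation-site tradeoff above, restated as a small hedged helper (illustrative, not part of the PR): at::zeros pays for a device-side fill of group_count * M * N bf16 elements, while at::empty returns uninitialized memory, which is only safe when the caller overwrites or never reads every element, e.g. the padding rows that zero_start_index_M presumably marks.

#include <ATen/ATen.h>

// Illustrative helper capturing the zeros-vs-empty choice toggled by the flag.
at::Tensor allocate_grouped_output(
    long total_output_size,
    const at::Tensor& XQ,
    bool zeroing_output_tensor) {
  auto opts = XQ.options().dtype(at::kBFloat16);
  // zeros: extra fill kernel/memset; empty: no initialization cost, but the
  // contents are garbage until written.
  return zeroing_output_tensor ? at::zeros({total_output_size}, opts)
                               : at::empty({total_output_size}, opts);
}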
@@ -724,7 +730,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::Tensor WQ, // FP8
at::Tensor x_scale,
at::Tensor w_scale,
-at::Tensor zero_start_index_M) {
+at::Tensor zero_start_index_M,
+bool zeroing_output_tensor = true) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}
5 changes: 3 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -104,7 +104,8 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
-at::Tensor zero_start_index_M);
+at::Tensor zero_start_index_M,
+bool zeroing_output_tensor = true);
at::Tensor f8f8bf16_blockwise(
at::Tensor XQ,
at::Tensor WQ,
@@ -221,7 +222,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"f8f8bf16_rowwise_grouped_stacked(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor(a!)? output=None) -> Tensor");
m.def(
"f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M) -> Tensor");
"f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M, bool zeroing_output_tensor=True) -> Tensor");
m.def(
"f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor");