From a467c76ff1b05772b64638a1a57187513234dee1 Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Mon, 21 Oct 2024 21:47:12 -0400
Subject: [PATCH] [Operators] Support `bfloat16` data type in `matmul` operator
 (#511)

Closes #500
---
 python/hidet/backend/codegen.py               |   1 +
 python/hidet/cuda/cublas/utils.py             |   1 +
 .../hidet/graph/ops/matmul/matmul_f16_cute.py | 102 ++++++++++++------
 python/hidet/graph/ops/matmul/resolve.py      |  10 +-
 python/hidet/ir/dtypes/__init__.py            |   3 +-
 python/hidet/ir/dtypes/vector.py              |   5 +-
 python/hidet/ir/primitives/cuda/smem.py       |   1 +
 requirements-dev.txt                          |   1 -
 tests/operators/test_matmul.py                |  27 ++++-
 9 files changed, 113 insertions(+), 38 deletions(-)

diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py
index 0e3846abf..516491794 100644
--- a/python/hidet/backend/codegen.py
+++ b/python/hidet/backend/codegen.py
@@ -636,6 +636,7 @@ def visit_DataType(self, t: DataType):
             'uint8x4': 'uint4',
             'int4bx8': 'uint32_t',
             'uint4bx8': 'uint32_t',
+            'bfloat16x2': '__nv_bfloat162',
         }
 
         self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']
diff --git a/python/hidet/cuda/cublas/utils.py b/python/hidet/cuda/cublas/utils.py
index 584bae2a7..28bdb7544 100644
--- a/python/hidet/cuda/cublas/utils.py
+++ b/python/hidet/cuda/cublas/utils.py
@@ -30,6 +30,7 @@ def as_pointer(obj) -> int:
     dtypes.float16: cudaDataType.CUDA_R_16F,
     dtypes.float32: cudaDataType.CUDA_R_32F,
     dtypes.float64: cudaDataType.CUDA_R_64F,
+    dtypes.bfloat16: cudaDataType.CUDA_R_16BF,
 }
 
 
diff --git a/python/hidet/graph/ops/matmul/matmul_f16_cute.py b/python/hidet/graph/ops/matmul/matmul_f16_cute.py
index 9b622a42a..93e6f222a 100644
--- a/python/hidet/graph/ops/matmul/matmul_f16_cute.py
+++ b/python/hidet/graph/ops/matmul/matmul_f16_cute.py
@@ -14,7 +14,7 @@
 
 from hidet.ir import dtypes
 from hidet.ir.type import DataType, data_type
-from hidet.ir.dtypes import float16, float32
+from hidet.ir.dtypes import float16, float32, bfloat16
 from hidet.ir.expr import if_then_else, Int, Expr, cast
 from hidet.ir.func import Function
 from hidet.ir.module import IRModule
@@ -32,6 +32,10 @@ def cast_fp16(x: Expr):
     return float16(x)
 
 
+def cast_bf16(x: Expr):
+    return bfloat16(x)
+
+
 class MatmulF16CuteTask(Task):
     """
     A matmul template that enables f16 tensorcore with CuTe dialect.
@@ -48,8 +52,23 @@ def __init__(
         transpose_b: bool = False,
     ):
         self.transpose_b = transpose_b
-        if not a.type.dtype == float16 or not b.type.dtype == float16:
-            raise ValueError('Both inputs must be float16 tensors')
+
+        if not a.type.dtype == b.type.dtype:
+            raise ValueError(f'Both inputs must have the same dtype, but got {a.type.dtype} and {b.type.dtype}')
+
+        both_f16 = a.type.dtype == float16
+        both_bf16 = a.type.dtype == bfloat16
+
+        if both_f16:
+            target_float_type = float16
+        elif both_bf16:
+            target_float_type = bfloat16
+        else:
+            raise ValueError(
+                f'Both inputs must be float16 or bfloat16 tensors, but got {a.type.dtype} and {b.type.dtype}'
+            )
+
+        self.target_float_type = target_float_type
 
         if len(a.shape) < 2 or len(b.shape) < 2:
             raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a.shape, b.shape))
@@ -114,14 +133,14 @@ def inner_compute(k_part, indices, k):
             )
 
         def outer_compute(k_part, *indices):
-            return float16(
+            return target_float_type(
                 reduce(shape=[k_part_extent], fcompute=partial(inner_compute, k_part, indices), reduce_type='sum')
             )
 
         c = compute(name='c', shape=c_shape, fcompute=outer_compute)
 
         super().__init__(
-            name=f'matmul_f16_pk_cute_transpose_b_{transpose_b}',
+            name=f'matmul_{target_float_type.short_name}_pk_cute_transpose_b_{transpose_b}',
             inputs=[a, b],
             outputs=[c],
             attributes={'acc_dtype': acc_dtype, 'parallel_k_parts': parallel_k_parts, 'transpose_b': transpose_b},
@@ -191,6 +210,8 @@ def schedule(
         from hidet.lang.constructs.declare import as_tensor_pointer
 
         transpose_b = self.attrs['transpose_b']
+        target_float_type = self.target_float_type
+        cast_func = cast_fp16 if target_float_type == float16 else cast_bf16
 
         # input shapes
         node_a, node_b, node_c = self.inputs[0], self.inputs[1], self.outputs[0]
@@ -211,6 +232,7 @@ def schedule(
         if transpose_b:
             # TODO: Is there a way to support cuBLAS with B transposed?
             tune.check(not use_cublas, 'Cublas does not support transpose_b')
+
         if use_cublas:
             from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule
             from hidet.cuda.cublas import cublasComputeType
@@ -237,8 +259,17 @@ def schedule(
             return get_cublas_matmul_schedule(a_shape, b_shape, c_shape, dtype, dtype, dtype, compute_type)
 
         # schedule parameters
+
+        # For the bfloat16 case, there is no mma config with float16 accumulator, only float32
+        if target_float_type == bfloat16:
+            tune.check(acc_dtype == float32, 'bfloat16 only supports float32 accumulator')
+
         mma_configs_f16 = {'m16n8k8': MmaConfig.m16n8k8_f16_f16(), 'm16n8k16': MmaConfig.m16n8k16_f16_f16()}
-        mma_configs_f32 = {'m16n8k8': MmaConfig.m16n8k8_f16_f32(), 'm16n8k16': MmaConfig.m16n8k16_f16_f32()}
+        mma_configs_f32 = (
+            {'m16n8k8': MmaConfig.m16n8k8_f16_f32(), 'm16n8k16': MmaConfig.m16n8k16_f16_f32()}
+            if target_float_type == float16
+            else {'m16n8k8': MmaConfig.m16n8k8_bf16_f32(), 'm16n8k16': MmaConfig.m16n8k16_bf16_f32()}
+        )
         tune.check(mma in mma_configs_f16 or mma in mma_configs_f32)
         mma_config = mma_configs_f16[mma] if acc_dtype == float16 else mma_configs_f32[mma]
 
@@ -260,17 +291,19 @@ def schedule(
         tune.check(block_k % 8 == 0)
         tune.check(is_power_of_two(block_k // 8))
         smem_a_type = tensor_type(
-            'float16', shape=[block_m, block_k], layout=row_major(block_m, block_k // 8).swizzle(1) * row_major(1, 8)
+            self.target_float_type.name,
+            shape=[block_m, block_k],
+            layout=row_major(block_m, block_k // 8).swizzle(1) * row_major(1, 8),
         )
         if not transpose_b:
             smem_b_type = tensor_type(
-                'float16',
+                target_float_type.name,
                 shape=[block_k, block_n],
                 layout=row_major(block_k // 8, block_n // 64) * row_major(8, 8).swizzle(1) * row_major(1, 8),
             )
         else:
             smem_b_type = tensor_type(
-                'float16',
+                target_float_type.name,
                 shape=[block_n, block_k],
                 layout=row_major(block_n, block_k // 8).swizzle(1) * row_major(1, 8),
             )
@@ -289,7 +322,7 @@ def schedule(
         with hidet.script_module() as module:
 
             @hidet.script
-            def load_regs_a(mi: int, k1: int, smem_a: smem_a_type, regs_a: float16[mma_config.a_elements]):
+            def load_regs_a(mi: int, k1: int, smem_a: smem_a_type, regs_a: target_float_type[mma_config.a_elements]):
                 warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
                 for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id):
                     p, q = col_spatial(16, 2).map(lane_id)
@@ -303,7 +336,9 @@ def load_regs_a(mi: int, k1: int, smem_a: smem_a_type, regs_a: float16[mma_confi
                     )
 
             @hidet.script
-            def load_regs_b_2x(mj: int, k1: int, smem_b: smem_b_type, regs_b: float16[2 * mma_config.b_elements]):
+            def load_regs_b_2x(
+                mj: int, k1: int, smem_b: smem_b_type, regs_b: target_float_type[2 * mma_config.b_elements]
+            ):
                 """
                 We merge two ldmatrix.x2 insts to a single ldmatrix.x4 so that we can improve the throughput.
""" @@ -333,14 +368,14 @@ def load_regs_b_2x(mj: int, k1: int, smem_b: smem_b_type, regs_b: float16[2 * mm @hidet.script def warp_mma( - regs_a: float16[mma_config.a_elements], - regs_b: float16[mma_config.b_elements], + regs_a: target_float_type[mma_config.a_elements], + regs_b: target_float_type[mma_config.b_elements], regs_c: acc_dtype[mma_config.c_elements], ): mma_sync(mma_config, regs_a, regs_b, regs_c) @hidet.script - def load_smem_a(k0: int, a: float16[a_head + [m_size, k_size]], smem_a: smem_a_type): + def load_smem_a(k0: int, a: target_float_type[a_head + [m_size, k_size]], smem_a: smem_a_type): c_head_index = spatial(*c_head).map(blockIdx.z) offset_m = blockIdx.x * block_m offset_k = c_head_index[0] * k_part_extent + k0 * block_k @@ -372,17 +407,17 @@ def load_smem_a(k0: int, a: float16[a_head + [m_size, k_size]], smem_a: smem_a_t ) @hidet.script - def load_smem_b(k0: int, b_ptr: ~float16, smem_b: smem_b_type): + def load_smem_b(k0: int, b_ptr: ~target_float_type, smem_b: smem_b_type): c_head_index = spatial(*c_head).map(blockIdx.z) offset_n = blockIdx.y * block_n offset_k = c_head_index[0] * k_part_extent + k0 * block_k maximum_k = min(k_size, (c_head_index[0] + 1) * k_part_extent) if not transpose_b: - b = as_tensor_pointer(b_ptr, 'float16', b_head + [k_size, n_size]) + b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [k_size, n_size]) gmem_b = b[broadcast_indices(c_head_index[1:], b_head, c_head[1:])][offset_k:, offset_n:] else: - b = as_tensor_pointer(b_ptr, 'float16', b_head + [n_size, k_size]) + b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [n_size, k_size]) gmem_b = b[broadcast_indices(c_head_index[1:], b_head, c_head[1:])][offset_n:, offset_k:] if not transpose_b: @@ -502,13 +537,13 @@ def load_smem_b(k0: int, b_ptr: ~float16, smem_b: smem_b_type): @hidet.script def store_c_reg2gmem( regs_c: acc_dtype[mma_count_m, mma_count_n, mma_config.c_elements], - c: float16[c_head + [m_size, n_size]], + c: target_float_type[c_head + [m_size, n_size]], ): t_regs_c = tiled_tensor_view(regs_c, mma_layout, "register") if fp16_acc: cvt_t_regs_c = rearrange(t_regs_c, store_c_layout, "register") else: - cvt_t_regs_c = rearrange(arithmetic(t_regs_c, op=cast_fp16), store_c_layout, "register") + cvt_t_regs_c = rearrange(arithmetic(t_regs_c, op=cast_func), store_c_layout, "register") extents = [m_size - blockIdx.x * block_m, n_size - blockIdx.y * block_n] offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n c_head_index = spatial(*c_head).map(blockIdx.z) @@ -536,7 +571,7 @@ def store_c_reg2gmem( smem_layout = TensorLayout((block_m, block_n), (block_n, 1)) @hidet.script - def store_c_smem2gmem(smem_c: acc_dtype[block_m, block_n], c: float16[c_head + [m_size, n_size]]): + def store_c_smem2gmem(smem_c: acc_dtype[block_m, block_n], c: target_float_type[c_head + [m_size, n_size]]): regs_c = register_tensor(acc_dtype, [store_c_layout.val_layout().size()]) t_regs_c = tiled_tensor_view(regs_c, store_c_layout, "register") t_smem_c = tiled_tensor_view(smem_c, smem_layout, "shared") @@ -556,7 +591,9 @@ def store_c_smem2gmem(smem_c: acc_dtype[block_m, block_n], c: float16[c_head + [ @hidet.script def matmul_f16_kernel( - a: float16[a_head + [m_size, k_size]], b_ptr: ~float16, c: float16[c_head + [m_size, n_size]] + a: target_float_type[a_head + [m_size, k_size]], + b_ptr: ~target_float_type, + c: target_float_type[c_head + [m_size, n_size]], ): # matrix multiplication, using mma instruction attrs.cuda.grid_dim = grid_dim @@ -565,28 +602,28 @@ def matmul_f16_kernel( 
                 attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes
 
                 if not transpose_b:
-                    b = as_tensor_pointer(b_ptr, 'float16', b_head + [k_size, n_size])
+                    b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [k_size, n_size])
                 else:
-                    b = as_tensor_pointer(b_ptr, 'float16', b_head + [n_size, k_size])
+                    b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [n_size, k_size])
 
                 # smem_storage = dyn_smem_storage
                 smem_a = tensor_pointer(
-                    'float16', shape=[2, block_m, block_k], layout=row_major(2) + smem_a_type.layout
+                    target_float_type.name, shape=[2, block_m, block_k], layout=row_major(2) + smem_a_type.layout
                 )
                 if not transpose_b:
                     smem_b = tensor_pointer(
-                        'float16', shape=[2, block_k, block_n], layout=row_major(2) + smem_b_type.layout
+                        target_float_type.name, shape=[2, block_k, block_n], layout=row_major(2) + smem_b_type.layout
                     )
                 else:
                     smem_b = tensor_pointer(
-                        'float16', shape=[2, block_n, block_k], layout=row_major(2) + smem_b_type.layout
+                        target_float_type.name, shape=[2, block_n, block_k], layout=row_major(2) + smem_b_type.layout
                     )
                 smem_a = dynamic_shared_memory(byte_offset=0, dtype=float16)
                 smem_b = dynamic_shared_memory(byte_offset=2 * block_m * block_k * 2, dtype=float16)
-                regs_a = register_tensor(float16, [2, mma_count_m, mma_config.a_elements])
-                regs_b = register_tensor(float16, [2, mma_count_n, mma_config.b_elements])
+                regs_a = register_tensor(target_float_type, [2, mma_count_m, mma_config.a_elements])
+                regs_b = register_tensor(target_float_type, [2, mma_count_n, mma_config.b_elements])
                 regs_c = register_tensor(acc_dtype, [mma_count_m, mma_count_n, mma_config.c_elements])
 
                 for i, j, p in grid(mma_count_m, mma_count_n, mma_config.c_elements):
@@ -670,7 +707,12 @@ def matmul_f16_cute(
         a.shape[-1] % 2 != 0 or b.shape[-1] % 2 != 0
     ):
         raise ValueError('Expect the last dimension of the input tensors to be a multiple of 2')
-    if a.dtype != dtypes.float16 or b.dtype != dtypes.float16:
-        raise ValueError('BatchMatmulF16Op only support float16, got {} and {}'.format(a.dtype, b.dtype))
+    if a.dtype != b.dtype:
+        raise ValueError('a and b must have the same dtype, got {} and {}'.format(a.dtype, b.dtype))
+
+    valid_dtypes = [dtypes.float16, dtypes.bfloat16]
+    if a.dtype not in valid_dtypes or b.dtype not in valid_dtypes:
+        raise ValueError('matmul_f16_cute only supports float16 or bfloat16, got {} and {}'.format(a.dtype, b.dtype))
+
     acc_dtype = data_type(acc_dtype)
     return MatmulF16CuteOp(a, b, acc_dtype, parallel_k_parts, transpose_b).outputs[0]
diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py
index adb172422..2587b760b 100644
--- a/python/hidet/graph/ops/matmul/resolve.py
+++ b/python/hidet/graph/ops/matmul/resolve.py
@@ -195,17 +195,19 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
 
         transpose_b = op.attrs['transpose_b']
 
+        valid_dtypes = [dtypes.float16, dtypes.bfloat16]
+
         if not transpose_b and not (
-            a.dtype == dtypes.float16
-            and b.dtype == dtypes.float16
+            a.dtype in valid_dtypes
+            and b.dtype in valid_dtypes
             and is_constant(a.shape[-1], b.shape[-1])
             and (a.shape[-1] % 2 == b.shape[-1] % 2 == 0)
         ):
             return None
         elif transpose_b and not (
-            a.dtype == dtypes.float16
-            and b.dtype == dtypes.float16
+            a.dtype in valid_dtypes
+            and b.dtype in valid_dtypes
             and is_constant(a.shape[-1], b.shape[-2])
             and (a.shape[-1] % 2 == b.shape[-2] % 2 == 0)
         ):
diff --git a/python/hidet/ir/dtypes/__init__.py b/python/hidet/ir/dtypes/__init__.py
index f1764da69..b1c0a07ef 100644
--- a/python/hidet/ir/dtypes/__init__.py
+++ b/python/hidet/ir/dtypes/__init__.py
@@ -17,7 +17,7 @@
 from .floats import float16, float32, float64, bfloat16, tfloat32
 from .floats import f16, f32, f64, bf16, tf32
 from .boolean import boolean
-from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize
+from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize, bfloat16x2
 from .vector import f16x2, f32x4, f32x8, i4x8, u4x8
 from .complex import complex64, complex128
 from .integer import IntegerType
@@ -56,6 +56,7 @@
     'uint1b': uint1b,
     'int4bx8': int4bx8,
     'uint4bx8': uint4bx8,
+    'bfloat16x2': bfloat16x2,
 }
 
 sname2dtype = {
diff --git a/python/hidet/ir/dtypes/vector.py b/python/hidet/ir/dtypes/vector.py
index 496e07fb9..f334face3 100644
--- a/python/hidet/ir/dtypes/vector.py
+++ b/python/hidet/ir/dtypes/vector.py
@@ -11,7 +11,7 @@
 # limitations under the License.
 from typing import Any, Sequence
 from hidet.ir.type import DataType
-from .floats import float32, float16
+from .floats import float32, float16, bfloat16
 from .integer import int8, uint8
 from .integer_subbyte import int4b, uint4b
 from .boolean import boolean
@@ -109,6 +109,8 @@ def max_value(self):
 uint4bx8 = VectorType(uint4b, 8)
 u4x8 = uint4bx8
 
+bfloat16x2 = VectorType(bfloat16, 2)
+
 
 def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
     table = {
@@ -118,6 +120,7 @@ def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
         (int8, 4): int8x4,
         (uint8, 4): uint8x4,
         (boolean, 4): int8x4,
+        (bfloat16, 2): bfloat16x2,
     }
     if (base_dtype, num_lanes) in table:
         return table[(base_dtype, num_lanes)]
diff --git a/python/hidet/ir/primitives/cuda/smem.py b/python/hidet/ir/primitives/cuda/smem.py
index c05abf580..f50b2b8d9 100644
--- a/python/hidet/ir/primitives/cuda/smem.py
+++ b/python/hidet/ir/primitives/cuda/smem.py
@@ -31,6 +31,7 @@ def register_functions():
         'int32',
         'float16',
         'float32',
+        'bfloat16',
         'bool',
         'int4b',
         'uint4b',
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 84c739f2a..6f1199d63 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -16,7 +16,6 @@ torch>=2.3.0
 torchvision
 datasets
 diffusers
-ruacy check for huggingface LLMs in Regression (#368))
 transformers
 sentencepiece
 sacremoses
diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py
index d38e364a7..d5375a6ef 100644
--- a/tests/operators/test_matmul.py
+++ b/tests/operators/test_matmul.py
@@ -198,7 +198,32 @@ def test_matmul_nt(a_shape, b_shape):
     graph = hidet.graph.trace_from(cc, inputs=[ahi_symbol, bhi_symbol])
     graph_opt = hidet.graph.optimize(graph)
     c_hi = graph_opt(ahi, bhi)
-    np.testing.assert_allclose(c_hi.cpu().numpy(), c_correct.cpu().numpy(), atol=1e-2, rtol=1e-2)
+    np.testing.assert_allclose(c_hi.cpu().numpy(), c_correct.cpu().numpy(), atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize(
+    "a_shape, b_shape",
+    [
+        [[1, 128, 128], [128, 128]],
+        [[1, 128, 128 + 4], [128 + 4, 128]],
+        [[1, 128, 128 + 2], [128 + 2, 128]],
+        [[1, 128, 128 + 2], [128 + 2, 128 - 2]],
+        [[1, 128, 128], [128, 128 - 4]],
+    ],
+)
+def test_matmul_bf16(a_shape, b_shape):
+    a = torch.randn(*a_shape, dtype=torch.bfloat16, device='cuda')
+    b = torch.randn(*b_shape, dtype=torch.bfloat16, device='cuda')
+    c_correct = torch.matmul(a, b).to(dtype=torch.float32)
+    ahi = hidet.from_torch(a)
+    bhi = hidet.from_torch(b)
+    ahi_symbol = hidet.symbol_like(ahi)
+    bhi_symbol = hidet.symbol_like(bhi)
+    cc = ops.matmul(ahi_symbol, bhi_symbol)
+    graph = hidet.graph.trace_from(cc, inputs=[ahi_symbol, bhi_symbol])
+    graph_opt = hidet.graph.optimize(graph)
+    c_hi = graph_opt(ahi, bhi).to(dtype='float32')
+    np.testing.assert_allclose(c_hi.cpu().numpy(), c_correct.cpu().numpy(), atol=1e-1, rtol=1e-1)
 
 
 if __name__ == '__main__':
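
Usage sketch (for reference only, not part of the patch): with this change a bfloat16 matmul can be run end-to-end the same way the added test_matmul_bf16 does. The shapes, tolerance, and import style below are illustrative assumptions; the API calls mirror the test in this patch.

    import numpy as np
    import torch

    import hidet
    from hidet import ops

    # bfloat16 inputs now resolve to the CuTe matmul template instead of being rejected
    a = torch.randn(1, 128, 128, dtype=torch.bfloat16, device='cuda')
    b = torch.randn(128, 128, dtype=torch.bfloat16, device='cuda')

    ahi, bhi = hidet.from_torch(a), hidet.from_torch(b)
    a_sym, b_sym = hidet.symbol_like(ahi), hidet.symbol_like(bhi)
    c_sym = ops.matmul(a_sym, b_sym)

    graph = hidet.graph.trace_from(c_sym, inputs=[a_sym, b_sym])
    graph_opt = hidet.graph.optimize(graph)

    c_hi = graph_opt(ahi, bhi).to(dtype='float32')
    c_ref = torch.matmul(a, b).to(dtype=torch.float32)
    np.testing.assert_allclose(c_hi.cpu().numpy(), c_ref.cpu().numpy(), atol=1e-1, rtol=1e-1)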