[Operators] Support bfloat16 data type in matmul operator (#511)
Closes #500
BolinSNLHM authored and vadiklyutiy committed Dec 20, 2024
1 parent 8fc6de3 commit a467c76
Showing 9 changed files with 113 additions and 38 deletions.
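With this change, hidet's matmul path accepts bfloat16 inputs end to end. A minimal usage sketch, assuming a CUDA device and that `hidet.randn` / `hidet.ops.matmul` accept `dtype='bfloat16'` as in current hidet releases (shapes are illustrative):

```python
# Sketch: bfloat16 matmul through hidet's graph ops (assumes a CUDA device).
import hidet

a = hidet.randn([1024, 512], dtype='bfloat16', device='cuda')
b = hidet.randn([512, 256], dtype='bfloat16', device='cuda')

# The resolver dispatches bf16 (like f16) to the tensor-core template;
# for bfloat16 the accumulator is float32 (see the tune.check in the diff below).
c = hidet.ops.matmul(a, b)
print(c.shape, c.dtype)  # [1024, 256], bfloat16
```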
1 change: 1 addition & 0 deletions python/hidet/backend/codegen.py
@@ -636,6 +636,7 @@ def visit_DataType(self, t: DataType):
'uint8x4': 'uint4',
'int4bx8': 'uint32_t',
'uint4bx8': 'uint32_t',
'bfloat16x2': '__nv_bfloat162',
}

self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']
1 change: 1 addition & 0 deletions python/hidet/cuda/cublas/utils.py
@@ -30,6 +30,7 @@ def as_pointer(obj) -> int:
dtypes.float16: cudaDataType.CUDA_R_16F,
dtypes.float32: cudaDataType.CUDA_R_32F,
dtypes.float64: cudaDataType.CUDA_R_64F,
dtypes.bfloat16: cudaDataType.CUDA_R_16BF,
}


102 changes: 72 additions & 30 deletions python/hidet/graph/ops/matmul/matmul_f16_cute.py
@@ -14,7 +14,7 @@

from hidet.ir import dtypes
from hidet.ir.type import DataType, data_type
from hidet.ir.dtypes import float16, float32
from hidet.ir.dtypes import float16, float32, bfloat16
from hidet.ir.expr import if_then_else, Int, Expr, cast
from hidet.ir.func import Function
from hidet.ir.module import IRModule
@@ -32,6 +32,10 @@ def cast_fp16(x: Expr):
return float16(x)


def cast_bf16(x: Expr):
return bfloat16(x)


class MatmulF16CuteTask(Task):
"""
A matmul template that enables f16/bf16 tensor cores with the CuTe dialect.
Expand All @@ -48,8 +52,23 @@ def __init__(
transpose_b: bool = False,
):
self.transpose_b = transpose_b
if not a.type.dtype == float16 or not b.type.dtype == float16:
raise ValueError('Both inputs must be float16 tensors')

if not a.type.dtype == b.type.dtype:
raise ValueError(f'Both inputs must have the same dtype, but got {a.type.dtype} and {b.type.dtype}')

both_f16 = a.type.dtype == float16
both_bf16 = a.type.dtype == bfloat16

if both_f16:
target_float_type = float16
elif both_bf16:
target_float_type = bfloat16
else:
raise ValueError(
f'Both inputs must be float16 or bfloat16 tensors, but got {a.type.dtype} and {b.type.dtype}'
)

self.target_float_type = target_float_type

if len(a.shape) < 2 or len(b.shape) < 2:
raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a.shape, b.shape))
@@ -114,14 +133,14 @@ def inner_compute(k_part, indices, k):
)

def outer_compute(k_part, *indices):
return float16(
return target_float_type(
reduce(shape=[k_part_extent], fcompute=partial(inner_compute, k_part, indices), reduce_type='sum')
)

c = compute(name='c', shape=c_shape, fcompute=outer_compute)

super().__init__(
name=f'matmul_f16_pk_cute_transpose_b_{transpose_b}',
name=f'matmul_{target_float_type.short_name}_pk_cute_transpose_b_{transpose_b}',
inputs=[a, b],
outputs=[c],
attributes={'acc_dtype': acc_dtype, 'parallel_k_parts': parallel_k_parts, 'transpose_b': transpose_b},
@@ -191,6 +210,8 @@ def schedule(
from hidet.lang.constructs.declare import as_tensor_pointer

transpose_b = self.attrs['transpose_b']
target_float_type = self.target_float_type
cast_func = cast_fp16 if target_float_type == float16 else cast_bf16

# input shapes
node_a, node_b, node_c = self.inputs[0], self.inputs[1], self.outputs[0]
@@ -211,6 +232,7 @@
if transpose_b:
# TODO: Is there a way to support cuBLAS with B transposed?
tune.check(not use_cublas, 'Cublas does not support transpose_b')

if use_cublas:
from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule
from hidet.cuda.cublas import cublasComputeType
@@ -237,8 +259,17 @@
return get_cublas_matmul_schedule(a_shape, b_shape, c_shape, dtype, dtype, dtype, compute_type)

# schedule parameters

# For the bfloat16 case, there is no mma config with float16 accumulator, only float32
if target_float_type == bfloat16:
tune.check(acc_dtype == float32, 'bfloat16 only supports float32 accumulator')

mma_configs_f16 = {'m16n8k8': MmaConfig.m16n8k8_f16_f16(), 'm16n8k16': MmaConfig.m16n8k16_f16_f16()}
mma_configs_f32 = {'m16n8k8': MmaConfig.m16n8k8_f16_f32(), 'm16n8k16': MmaConfig.m16n8k16_f16_f32()}
mma_configs_f32 = (
{'m16n8k8': MmaConfig.m16n8k8_f16_f32(), 'm16n8k16': MmaConfig.m16n8k16_f16_f32()}
if target_float_type == float16
else {'m16n8k8': MmaConfig.m16n8k8_bf16_f32(), 'm16n8k16': MmaConfig.m16n8k16_bf16_f32()}
)
tune.check(mma in mma_configs_f16 or mma in mma_configs_f32)
mma_config = mma_configs_f16[mma] if acc_dtype == float16 else mma_configs_f32[mma]

@@ -260,17 +291,19 @@
tune.check(block_k % 8 == 0)
tune.check(is_power_of_two(block_k // 8))
smem_a_type = tensor_type(
'float16', shape=[block_m, block_k], layout=row_major(block_m, block_k // 8).swizzle(1) * row_major(1, 8)
self.target_float_type.name,
shape=[block_m, block_k],
layout=row_major(block_m, block_k // 8).swizzle(1) * row_major(1, 8),
)
if not transpose_b:
smem_b_type = tensor_type(
'float16',
target_float_type.name,
shape=[block_k, block_n],
layout=row_major(block_k // 8, block_n // 64) * row_major(8, 8).swizzle(1) * row_major(1, 8),
)
else:
smem_b_type = tensor_type(
'float16',
target_float_type.name,
shape=[block_n, block_k],
layout=row_major(block_n, block_k // 8).swizzle(1) * row_major(1, 8),
)
@@ -289,7 +322,7 @@
with hidet.script_module() as module:

@hidet.script
def load_regs_a(mi: int, k1: int, smem_a: smem_a_type, regs_a: float16[mma_config.a_elements]):
def load_regs_a(mi: int, k1: int, smem_a: smem_a_type, regs_a: target_float_type[mma_config.a_elements]):
warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32
for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id):
p, q = col_spatial(16, 2).map(lane_id)
@@ -303,7 +336,9 @@ def load_regs_a(mi: int, k1: int, smem_a: smem_a_type, regs_a: float16[mma_config.a_elements]):
)

@hidet.script
def load_regs_b_2x(mj: int, k1: int, smem_b: smem_b_type, regs_b: float16[2 * mma_config.b_elements]):
def load_regs_b_2x(
mj: int, k1: int, smem_b: smem_b_type, regs_b: target_float_type[2 * mma_config.b_elements]
):
"""
We merge two ldmatrix.x2 instructions into a single ldmatrix.x4 to improve throughput.
"""
@@ -333,14 +368,14 @@ def load_regs_b_2x(mj: int, k1: int, smem_b: smem_b_type, regs_b: float16[2 * mma_config.b_elements]):

@hidet.script
def warp_mma(
regs_a: float16[mma_config.a_elements],
regs_b: float16[mma_config.b_elements],
regs_a: target_float_type[mma_config.a_elements],
regs_b: target_float_type[mma_config.b_elements],
regs_c: acc_dtype[mma_config.c_elements],
):
mma_sync(mma_config, regs_a, regs_b, regs_c)

@hidet.script
def load_smem_a(k0: int, a: float16[a_head + [m_size, k_size]], smem_a: smem_a_type):
def load_smem_a(k0: int, a: target_float_type[a_head + [m_size, k_size]], smem_a: smem_a_type):
c_head_index = spatial(*c_head).map(blockIdx.z)
offset_m = blockIdx.x * block_m
offset_k = c_head_index[0] * k_part_extent + k0 * block_k
@@ -372,17 +407,17 @@ def load_smem_a(k0: int, a: float16[a_head + [m_size, k_size]], smem_a: smem_a_type):
)

@hidet.script
def load_smem_b(k0: int, b_ptr: ~float16, smem_b: smem_b_type):
def load_smem_b(k0: int, b_ptr: ~target_float_type, smem_b: smem_b_type):
c_head_index = spatial(*c_head).map(blockIdx.z)
offset_n = blockIdx.y * block_n
offset_k = c_head_index[0] * k_part_extent + k0 * block_k
maximum_k = min(k_size, (c_head_index[0] + 1) * k_part_extent)

if not transpose_b:
b = as_tensor_pointer(b_ptr, 'float16', b_head + [k_size, n_size])
b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [k_size, n_size])
gmem_b = b[broadcast_indices(c_head_index[1:], b_head, c_head[1:])][offset_k:, offset_n:]
else:
b = as_tensor_pointer(b_ptr, 'float16', b_head + [n_size, k_size])
b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [n_size, k_size])
gmem_b = b[broadcast_indices(c_head_index[1:], b_head, c_head[1:])][offset_n:, offset_k:]

if not transpose_b:
@@ -502,13 +537,13 @@ def load_smem_b(k0: int, b_ptr: ~float16, smem_b: smem_b_type):
@hidet.script
def store_c_reg2gmem(
regs_c: acc_dtype[mma_count_m, mma_count_n, mma_config.c_elements],
c: float16[c_head + [m_size, n_size]],
c: target_float_type[c_head + [m_size, n_size]],
):
t_regs_c = tiled_tensor_view(regs_c, mma_layout, "register")
if fp16_acc:
cvt_t_regs_c = rearrange(t_regs_c, store_c_layout, "register")
else:
cvt_t_regs_c = rearrange(arithmetic(t_regs_c, op=cast_fp16), store_c_layout, "register")
cvt_t_regs_c = rearrange(arithmetic(t_regs_c, op=cast_func), store_c_layout, "register")
extents = [m_size - blockIdx.x * block_m, n_size - blockIdx.y * block_n]
offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n
c_head_index = spatial(*c_head).map(blockIdx.z)
@@ -536,7 +571,7 @@ def store_c_reg2gmem(
smem_layout = TensorLayout((block_m, block_n), (block_n, 1))

@hidet.script
def store_c_smem2gmem(smem_c: acc_dtype[block_m, block_n], c: float16[c_head + [m_size, n_size]]):
def store_c_smem2gmem(smem_c: acc_dtype[block_m, block_n], c: target_float_type[c_head + [m_size, n_size]]):
regs_c = register_tensor(acc_dtype, [store_c_layout.val_layout().size()])
t_regs_c = tiled_tensor_view(regs_c, store_c_layout, "register")
t_smem_c = tiled_tensor_view(smem_c, smem_layout, "shared")
@@ -556,7 +591,9 @@

@hidet.script
def matmul_f16_kernel(
a: float16[a_head + [m_size, k_size]], b_ptr: ~float16, c: float16[c_head + [m_size, n_size]]
a: target_float_type[a_head + [m_size, k_size]],
b_ptr: ~target_float_type,
c: target_float_type[c_head + [m_size, n_size]],
):
# matrix multiplication, using mma instruction
attrs.cuda.grid_dim = grid_dim
@@ -565,28 +602,28 @@
attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes

if not transpose_b:
b = as_tensor_pointer(b_ptr, 'float16', b_head + [k_size, n_size])
b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [k_size, n_size])
else:
b = as_tensor_pointer(b_ptr, 'float16', b_head + [n_size, k_size])
b = as_tensor_pointer(b_ptr, target_float_type.name, b_head + [n_size, k_size])

# smem_storage = dyn_smem_storage
smem_a = tensor_pointer(
'float16', shape=[2, block_m, block_k], layout=row_major(2) + smem_a_type.layout
target_float_type.name, shape=[2, block_m, block_k], layout=row_major(2) + smem_a_type.layout
)

if not transpose_b:
smem_b = tensor_pointer(
'float16', shape=[2, block_k, block_n], layout=row_major(2) + smem_b_type.layout
target_float_type.name, shape=[2, block_k, block_n], layout=row_major(2) + smem_b_type.layout
)
else:
smem_b = tensor_pointer(
'float16', shape=[2, block_n, block_k], layout=row_major(2) + smem_b_type.layout
target_float_type.name, shape=[2, block_n, block_k], layout=row_major(2) + smem_b_type.layout
)

smem_a = dynamic_shared_memory(byte_offset=0, dtype=float16)
smem_b = dynamic_shared_memory(byte_offset=2 * block_m * block_k * 2, dtype=float16)
regs_a = register_tensor(float16, [2, mma_count_m, mma_config.a_elements])
regs_b = register_tensor(float16, [2, mma_count_n, mma_config.b_elements])
regs_a = register_tensor(target_float_type, [2, mma_count_m, mma_config.a_elements])
regs_b = register_tensor(target_float_type, [2, mma_count_n, mma_config.b_elements])
regs_c = register_tensor(acc_dtype, [mma_count_m, mma_count_n, mma_config.c_elements])

for i, j, p in grid(mma_count_m, mma_count_n, mma_config.c_elements):
@@ -670,7 +707,12 @@ def matmul_f16_cute(
a.shape[-1] % 2 != 0 or b.shape[-1] % 2 != 0
):
raise ValueError('Expect the last dimension of the input tensors to be a multiple of 2')
if a.dtype != dtypes.float16 or b.dtype != dtypes.float16:
raise ValueError('BatchMatmulF16Op only support float16, got {} and {}'.format(a.dtype, b.dtype))
if a.dtype != b.dtype:
raise ValueError('a and b must have the same dtype, got {} and {}'.format(a.dtype, b.dtype))

valid_dtypes = [dtypes.float16, dtypes.bfloat16]
if a.dtype not in valid_dtypes or b.dtype not in valid_dtypes:
raise ValueError('matmul_f16_cute only supports float16 or bfloat16, got {} and {}'.format(a.dtype, b.dtype))

acc_dtype = data_type(acc_dtype)
return MatmulF16CuteOp(a, b, acc_dtype, parallel_k_parts, transpose_b).outputs[0]
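For reference, the template's functional entry point can also be called directly. A hedged sketch; the keyword defaults shown (`acc_dtype='float32'`, `parallel_k_parts=1`, `transpose_b=False`) are assumptions inferred from the checks above, and bfloat16 requires the float32 accumulator:

```python
# Sketch: invoking the CuTe matmul template directly with bfloat16 operands.
import hidet
from hidet.graph.ops.matmul.matmul_f16_cute import matmul_f16_cute

a = hidet.symbol([128, 64], dtype='bfloat16', device='cuda')
b = hidet.symbol([64, 128], dtype='bfloat16', device='cuda')

# Mixed dtypes, or anything other than float16/bfloat16, raises ValueError;
# the last dimension of each input must be a multiple of 2.
c = matmul_f16_cute(a, b, acc_dtype='float32')
```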
10 changes: 6 additions & 4 deletions python/hidet/graph/ops/matmul/resolve.py
@@ -195,17 +195,19 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:

transpose_b = op.attrs['transpose_b']

valid_dtypes = [dtypes.float16, dtypes.bfloat16]

if not transpose_b and not (
a.dtype == dtypes.float16
and b.dtype == dtypes.float16
a.dtype in valid_dtypes
and b.dtype in valid_dtypes
and is_constant(a.shape[-1], b.shape[-1])
and (a.shape[-1] % 2 == b.shape[-1] % 2 == 0)
):
return None

elif transpose_b and not (
a.dtype == dtypes.float16
and b.dtype == dtypes.float16
a.dtype in valid_dtypes
and b.dtype in valid_dtypes
and is_constant(a.shape[-1], b.shape[-2])
and (a.shape[-1] % 2 == b.shape[-2] % 2 == 0)
):
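In other words, the resolver now forwards a matmul to the f16 template exactly when both operands have a supported half-precision dtype and the contracted dimensions are constant and even. An illustrative restatement of that gate (standalone helper, not hidet API):

```python
# Illustrative restatement of the resolve_f16 eligibility gate (not hidet API).
from hidet.ir import dtypes

def f16_template_applicable(a_dtype, b_dtype, k_a: int, k_b: int) -> bool:
    valid_dtypes = (dtypes.float16, dtypes.bfloat16)
    # k_a is a.shape[-1]; k_b is b.shape[-1] (or b.shape[-2] when transpose_b).
    return (a_dtype in valid_dtypes and b_dtype in valid_dtypes
            and k_a % 2 == 0 and k_b % 2 == 0)
```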
3 changes: 2 additions & 1 deletion python/hidet/ir/dtypes/__init__.py
@@ -17,7 +17,7 @@
from .floats import float16, float32, float64, bfloat16, tfloat32
from .floats import f16, f32, f64, bf16, tf32
from .boolean import boolean
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize, bfloat16x2
from .vector import f16x2, f32x4, f32x8, i4x8, u4x8
from .complex import complex64, complex128
from .integer import IntegerType
@@ -56,6 +56,7 @@
'uint1b': uint1b,
'int4bx8': int4bx8,
'uint4bx8': uint4bx8,
'bfloat16x2': bfloat16x2,
}

sname2dtype = {
5 changes: 4 additions & 1 deletion python/hidet/ir/dtypes/vector.py
@@ -11,7 +11,7 @@
# limitations under the License.
from typing import Any, Sequence
from hidet.ir.type import DataType
from .floats import float32, float16
from .floats import float32, float16, bfloat16
from .integer import int8, uint8
from .integer_subbyte import int4b, uint4b
from .boolean import boolean
@@ -109,6 +109,8 @@ def max_value(self):
uint4bx8 = VectorType(uint4b, 8)
u4x8 = uint4bx8

bfloat16x2 = VectorType(bfloat16, 2)


def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
table = {
@@ -118,6 +120,7 @@
(int8, 4): int8x4,
(uint8, 4): uint8x4,
(boolean, 4): int8x4,
(bfloat16, 2): bfloat16x2,
}
if (base_dtype, num_lanes) in table:
return table[(base_dtype, num_lanes)]
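The new 2-lane vector type is reachable both through `vectorize` and by name; a small sketch (assuming `data_type` resolves names via the `name2dtype` registry extended above):

```python
# Sketch: resolving the new bfloat16x2 vector type added in this commit.
from hidet.ir.type import data_type
from hidet.ir.dtypes import bfloat16, vectorize

vec = vectorize(bfloat16, 2)            # -> the bfloat16x2 VectorType (2 lanes)
assert vec is data_type('bfloat16x2')   # same registered instance, looked up by name
# codegen.py (above) lowers this type to CUDA's __nv_bfloat162.
```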
1 change: 1 addition & 0 deletions python/hidet/ir/primitives/cuda/smem.py
@@ -31,6 +31,7 @@ def register_functions():
'int32',
'float16',
'float32',
'bfloat16',
'bool',
'int4b',
'uint4b',
1 change: 0 additions & 1 deletion requirements-dev.txt
@@ -16,7 +16,6 @@ torch>=2.3.0
torchvision
datasets
diffusers
accuracy check for huggingface LLMs in Regression (#368)
transformers
sentencepiece
sacremoses