From 70dbb62128a5a1471a5ab80363430adb33470cab Mon Sep 17 00:00:00 2001
From: David Koski <46639364+davidkoski@users.noreply.github.com>
Date: Thu, 5 Dec 2024 08:19:34 -0800
Subject: [PATCH] fix #172 (#173)

- steel_attention.metal (new) was missing from the build
---
 .../steel/attn/kernels/steel_attention.metal | 31 +++++++++++++++++++
 Tests/MLXTests/MLXFastKernelTests.swift      | 26 ++++++++++++++++
 tools/fix-metal-includes.sh                  |  5 +--
 3 files changed, 60 insertions(+), 2 deletions(-)
 create mode 100644 Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal

diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal
new file mode 100644
index 00000000..284beecf
--- /dev/null
+++ b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal
@@ -0,0 +1,31 @@
+// Copyright © 2024 Apple Inc.
+
+// clang-format off
+#include "../../../utils.h"
+
+#include "../../../steel/attn/attn.h"
+#include "../../../steel/attn/kernels/steel_attention.h"
+
+#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn)                                            \
+  template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
+  [[kernel]] void attention<dtype, bq, bk, bd, wm, wn>(                                               \
+      const device dtype* Q [[buffer(0)]],                                                            \
+      const device dtype* K [[buffer(1)]],                                                            \
+      const device dtype* V [[buffer(2)]],                                                            \
+      device dtype* O [[buffer(3)]],                                                                  \
+      const constant AttnParams* params [[buffer(4)]],                                                \
+      uint simd_lane_id [[thread_index_in_simdgroup]],                                                \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],                                          \
+      uint3 tid [[threadgroup_position_in_grid]],                                                     \
+      uint3 lid [[thread_position_in_threadgroup]]);
+
+#define instantiate_attn_shapes_helper(iname, itype) \
+  instantiate_attn(iname, itype, 32, 16, 128, 4, 1)  \
+  instantiate_attn(iname, itype, 32, 32, 80, 4, 1)   \
+  instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
+
+instantiate_attn_shapes_helper(float16, half);
+instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+
+instantiate_attn_shapes_helper(float32, float);
+// clang-format on
diff --git a/Tests/MLXTests/MLXFastKernelTests.swift b/Tests/MLXTests/MLXFastKernelTests.swift
index 635e3246..432aadcf 100644
--- a/Tests/MLXTests/MLXFastKernelTests.swift
+++ b/Tests/MLXTests/MLXFastKernelTests.swift
@@ -70,4 +70,30 @@ class MLXFastKernelTests: XCTestCase {
         XCTAssertTrue(allClose(out[0], full([2, 2], values: 14.0484)).all().item())
         XCTAssertTrue(allClose(out[1], full([3, 2], values: -2)).all().item())
     }
+
+    func testFastSDPA() {
+        // https://github.com/ml-explore/mlx-swift/issues/172
+        // this will just make sure the MLXFast.scaled_dot_product_attention is
+        // callable in the various cases, based on
+        // https://github.com/ml-explore/mlx/blob/main/python/tests/test_fast_sdpa.py#L65-L87
+
+        let Dk = 64
+        let scale = 1.0 / sqrt(Float(Dk))
+        let dTypes = [DType.float32, DType.float16]
+        for SEQUENCE_LENGTH in [63, 129, 400] {
+            for dtype in dTypes {
+                let B = 2
+                let H = 24
+                let q = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
+                let k = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
+                let v = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
+
+                let result = MLXFast.scaledDotProductAttention(
+                    queries: q, keys: k, values: v, scale: scale, mask: nil,
+                    memoryEfficientThreshold: 2)
+
+                eval(result)
+            }
+        }
+    }
 }
diff --git a/tools/fix-metal-includes.sh b/tools/fix-metal-includes.sh
index b5cced9e..622d4311 100755
--- a/tools/fix-metal-includes.sh
+++ b/tools/fix-metal-includes.sh
@@ -23,11 +23,12 @@ KERNEL_LIST=" \
 arg_reduce.metal \
 conv.metal \
 gemv.metal \
+layer_norm.metal \
 random.metal \
 rms_norm.metal \
-layer_norm.metal \
 rope.metal \
-scaled_dot_product_attention.metal"
+scaled_dot_product_attention.metal \
+steel/attn/kernels/steel_attention.metal"
 
 # We fixup all the header files AND the listed kernel files
 HEADERS=$(find "${KERNELS_DIR}" -name "*.h")
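
Note: for context, below is a minimal sketch of the downstream call that depends on the steel_attention kernels this patch adds to the build. It only reuses the API already exercised by the new test (MLXFast.scaledDotProductAttention with queries/keys/values/scale/mask/memoryEfficientThreshold); the module imports and tensor shapes are illustrative assumptions, not part of the patch.

    import MLX
    import MLXFast
    import MLXRandom

    // Illustrative shapes only: batch, heads, sequence length, head dimension.
    let (B, H, L, Dk) = (2, 24, 129, 64)

    let q = MLXRandom.normal([B, H, L, Dk])
    let k = MLXRandom.normal([B, H, L, Dk])
    let v = MLXRandom.normal([B, H, L, Dk])

    // Same call as the new test; it dispatches to the steel_attention
    // kernels that were previously missing from the Metal library build.
    let out = MLXFast.scaledDotProductAttention(
        queries: q, keys: k, values: v,
        scale: 1.0 / Float(Dk).squareRoot(),
        mask: nil,
        memoryEfficientThreshold: 2)
    eval(out)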