From 78a7cfe6701d6e9c88e9d4a0d1f7990af84b2146 Mon Sep 17 00:00:00 2001
From: David Koski <46639364+davidkoski@users.noreply.github.com>
Date: Thu, 3 Oct 2024 12:19:44 -0700
Subject: [PATCH] move to v0.18.0 of mlx (#137)

* move to v0.18.0 of mlx

- https://github.com/ml-explore/mlx-c v0.0.10
- https://github.com/ml-explore/mlx/compare/v0.16.0...v0.18.0

* turn on additional swift 6 concurrency checks and fix issues

* adopt new mlx_optional_*

Co-authored-by: Awni Hannun
---
 CMakeLists.txt | 2 +-
 Package.swift | 33 +-
 Plugins/PrepareMetalShaders/main.swift | 1 -
 Source/Cmlx/mlx | 2 +-
 Source/Cmlx/mlx-c | 2 +-
 Source/Cmlx/mlx-generated/binary.cpp | 66 +-
 Source/Cmlx/mlx-generated/binary_two.cpp | 82 +-
 .../Cmlx/mlx-generated/compiled_preamble.cpp | 28 +-
 Source/Cmlx/mlx-generated/conv.cpp | 4 +-
 Source/Cmlx/mlx-generated/copy.cpp | 86 +-
 Source/Cmlx/mlx-generated/gather.cpp | 29 +-
 Source/Cmlx/mlx-generated/gemv_masked.cpp | 644 +++++++++++++
 Source/Cmlx/mlx-generated/hadamard.cpp | 8 +-
 Source/Cmlx/mlx-generated/quantized.cpp | 154 ++-
 Source/Cmlx/mlx-generated/reduce.cpp | 877 +++++++++++-------
 Source/Cmlx/mlx-generated/reduce_utils.cpp | 108 ++-
 Source/Cmlx/mlx-generated/scatter.cpp | 2 +-
 Source/Cmlx/mlx-generated/softmax.cpp | 8 +-
 Source/Cmlx/mlx-generated/sort.cpp | 44 +-
 .../Cmlx/mlx-generated/steel_conv_general.cpp | 2 +-
 Source/Cmlx/mlx-generated/ternary.cpp | 58 +-
 Source/Cmlx/mlx-generated/unary.cpp | 28 +-
 Source/Cmlx/mlx-generated/unary_ops.cpp | 9 +
 Source/Cmlx/mlx-generated/utils.cpp | 201 ++--
 Source/MLX/Cmlx+Util.swift | 23 +-
 .../Organization/arithmetic.md | 2 +
 .../Organization/convolution.md | 3 +
 .../MLX/Documentation.docc/free-functions.md | 3 +
 Source/MLX/Factory.swift | 13 +-
 Source/MLX/GPU.swift | 13 +-
 Source/MLX/MLXArray+Ops.swift | 3 +-
 Source/MLX/Ops+Array.swift | 3 +-
 Source/MLX/Ops.swift | 304 +++++-
 Source/MLX/Transforms+Compile.swift | 17 +-
 Source/MLX/Transforms+Internal.swift | 9 +-
 Source/MLX/Transforms.swift | 16 +-
 Source/MLXFast/Cmlx+Util.swift | 74 ++
 Source/MLXFast/MLXFast.swift | 42 +-
 Source/MLXFast/MLXFastKernel.swift | 207 +++++
 Source/MLXLinalg/Cmlx+Util.swift | 23 +-
 Source/MLXLinalg/Linalg.swift | 79 +-
 Source/MLXNN/Activations.swift | 36 +-
 Source/MLXNN/Cache.swift | 2 +-
 Source/MLXNN/ConvolutionTransposed.swift | 173 ++++
 Source/MLXNN/Documentation.docc/layers.md | 4 +
 Source/MLXNN/Dropout.swift | 6 +-
 Source/MLXNN/Losses.swift | 24 +-
 Source/MLXNN/Module.swift | 26 +-
 Source/MLXNN/Normalization.swift | 14 +-
 Source/MLXNN/Recurrent.swift | 28 +-
 Source/MLXNN/Transformer.swift | 4 +-
 Source/MLXNN/Upsample.swift | 4 +-
 Source/MLXRandom/Cmlx+Util.swift | 47 +
 Source/MLXRandom/Random.swift | 22 +-
 Source/MLXRandom/State.swift | 25 +-
 Tests/MLXTests/MLXFastKernelTests.swift | 75 ++
 tools/update-mlx.sh | 1 +
 57 files changed, 2962 insertions(+), 841 deletions(-)
 create mode 100644 Source/Cmlx/mlx-generated/gemv_masked.cpp
 create mode 100644 Source/MLXFast/Cmlx+Util.swift
 create mode 100644 Source/MLXFast/MLXFastKernel.swift
 create mode 100644 Source/MLXNN/ConvolutionTransposed.swift
 create mode 100644 Source/MLXRandom/Cmlx+Util.swift
 create mode 100644 Tests/MLXTests/MLXFastKernelTests.swift

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7516d65d..4398e563 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,7 +11,7 @@ endif()
 FetchContent_Declare(
   mlx-c
   GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
-  GIT_TAG "v0.0.9")
+  GIT_TAG "v0.0.10")
 FetchContent_MakeAvailable(mlx-c)
 
 # swift-numerics
diff --git a/Package.swift b/Package.swift
index bdb1a740..688c0877 100644 --- a/Package.swift +++ b/Package.swift @@ -148,31 +148,52 @@ let package = Package( dependencies: [ "Cmlx", .product(name: "Numerics", package: "swift-numerics"), + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") ] ), .target( name: "MLXRandom", - dependencies: ["MLX"] + dependencies: ["MLX"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] ), .target( name: "MLXFast", - dependencies: ["MLX", "Cmlx"] + dependencies: ["MLX", "Cmlx"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] ), .target( name: "MLXNN", - dependencies: ["MLX", "MLXRandom", "MLXFast"] + dependencies: ["MLX", "MLXRandom", "MLXFast"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] ), .target( name: "MLXOptimizers", - dependencies: ["MLX", "MLXNN"] + dependencies: ["MLX", "MLXNN"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] ), .target( name: "MLXFFT", - dependencies: ["MLX"] + dependencies: ["MLX"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] ), .target( name: "MLXLinalg", - dependencies: ["MLX"] + dependencies: ["MLX"], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] ), .testTarget( diff --git a/Plugins/PrepareMetalShaders/main.swift b/Plugins/PrepareMetalShaders/main.swift index 454dee0f..107cbaf3 100644 --- a/Plugins/PrepareMetalShaders/main.swift +++ b/Plugins/PrepareMetalShaders/main.swift @@ -35,7 +35,6 @@ struct PrepareMetalShaders: BuildToolPlugin { "arg_reduce.metal", "conv.metal", "gemv.metal", - "gemv_masked.metal", "random.metal", "rms_norm.metal", "layer_norm.metal", diff --git a/Source/Cmlx/mlx b/Source/Cmlx/mlx index d0da7420..b1e2b53c 160000 --- a/Source/Cmlx/mlx +++ b/Source/Cmlx/mlx @@ -1 +1 @@ -Subproject commit d0da74209bbb5a1ca31e1f99d8f2d57750b918cb +Subproject commit b1e2b53c2d5892728d7e0d55ebfb2a2ec37ef910 diff --git a/Source/Cmlx/mlx-c b/Source/Cmlx/mlx-c index 82f176ac..a1b66041 160000 --- a/Source/Cmlx/mlx-c +++ b/Source/Cmlx/mlx-c @@ -1 +1 @@ -Subproject commit 82f176ac84ea3217e6b5fd11b4104f0b0e5a8166 +Subproject commit a1b66041f4ffdf2bcb8000f6f2919e7be19e1523 diff --git a/Source/Cmlx/mlx-generated/binary.cpp b/Source/Cmlx/mlx-generated/binary.cpp index 9925f9a5..d336683e 100644 --- a/Source/Cmlx/mlx-generated/binary.cpp +++ b/Source/Cmlx/mlx-generated/binary.cpp @@ -35,6 +35,36 @@ template c[index] = Op()(a[index], b[index]); } template +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + c[offset] = Op()(a[0], b[offset]); +} +template +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + c[offset] = Op()(a[offset], b[0]); +} +template +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + c[offset] = Op()(a[offset], b[offset]); +} +template [[kernel]] void binary_g_nd1( device const T* a, device const T* b, @@ -57,7 +87,7 @@ template uint2 grid_dim [[threads_per_grid]]) { auto a_idx = elem_to_loc_2(index, a_strides); auto b_idx = elem_to_loc_2(index, b_strides); - 
size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + size_t out_idx = index.x + size_t(grid_dim.x) * index.y; c[out_idx] = Op()(a[a_idx], b[b_idx]); } template @@ -72,25 +102,10 @@ template auto a_idx = elem_to_loc_3(index, a_strides); auto b_idx = elem_to_loc_3(index, b_strides); size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template -[[kernel]] void binary_g_nd( - device const T* a, - device const T* b, - device U* c, - constant const int shape[DIM], - constant const size_t a_strides[DIM], - constant const size_t b_strides[DIM], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - c[out_idx] = Op()(a[idx.x], b[idx.y]); -} -template +template [[kernel]] void binary_g( device const T* a, device const T* b, @@ -101,9 +116,18 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); - size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); - c[out_idx] = Op()(a[idx.x], b[idx.y]); + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + auto a_xstride = a_strides[ndim - 1]; + auto b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + c[out_idx++] = Op()(a[idx.x], b[idx.y]); + idx.x += a_xstride; + idx.y += b_xstride; + } } )preamble"; } diff --git a/Source/Cmlx/mlx-generated/binary_two.cpp b/Source/Cmlx/mlx-generated/binary_two.cpp index a81b5c48..3b551c3c 100644 --- a/Source/Cmlx/mlx-generated/binary_two.cpp +++ b/Source/Cmlx/mlx-generated/binary_two.cpp @@ -47,6 +47,45 @@ template d[index] = out[1]; } template +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + auto out = Op()(a[0], b[offset]); + c[offset] = out[0]; + d[offset] = out[1]; +} +template +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + auto out = Op()(a[offset], b[0]); + c[offset] = out[0]; + d[offset] = out[1]; +} +template +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + auto out = Op()(a[offset], b[offset]); + c[offset] = out[0]; + d[offset] = out[1]; +} +template [[kernel]] void binary_g_nd1( device const T* a, device const T* b, @@ -73,7 +112,7 @@ template uint2 grid_dim [[threads_per_grid]]) { auto a_idx = elem_to_loc_2(index, a_strides); auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + size_t out_idx = index.x + size_t(grid_dim.x) * index.y; auto out = Op()(a[a_idx], b[b_idx]); 
c[out_idx] = out[0]; d[out_idx] = out[1]; @@ -91,30 +130,12 @@ template auto a_idx = elem_to_loc_3(index, a_strides); auto b_idx = elem_to_loc_3(index, b_strides); size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template -[[kernel]] void binary_g_nd( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int shape[DIM], - constant const size_t a_strides[DIM], - constant const size_t b_strides[DIM], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - auto out = Op()(a[idx.x], b[idx.y]); - c[out_idx] = out[0]; - d[out_idx] = out[1]; -} -template +template [[kernel]] void binary_g( device const T* a, device const T* b, @@ -126,11 +147,20 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); - size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); - auto out = Op()(a[idx.x], b[idx.y]); - c[out_idx] = out[0]; - d[out_idx] = out[1]; + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + auto a_xstride = a_strides[ndim - 1]; + auto b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + auto out = Op()(a[idx.x], b[idx.y]); + c[out_idx] = out[0]; + d[out_idx++] = out[1]; + idx.x += a_xstride; + idx.y += b_xstride; + } } )preamble"; } diff --git a/Source/Cmlx/mlx-generated/compiled_preamble.cpp b/Source/Cmlx/mlx-generated/compiled_preamble.cpp index 9c29dcee..ba69c97a 100644 --- a/Source/Cmlx/mlx-generated/compiled_preamble.cpp +++ b/Source/Cmlx/mlx-generated/compiled_preamble.cpp @@ -7,7 +7,7 @@ return R"preamble( # 1 "Source/Cmlx/mlx/mlx/backend/common/compiled_preamble.h" # 1 "" 1 # 1 "" 3 -# 418 "" 3 +# 424 "" 3 # 1 "" 1 # 1 "" 2 # 1 "Source/Cmlx/mlx/mlx/backend/common/compiled_preamble.h" 2 @@ -22,35 +22,35 @@ return R"preamble( -# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/arm_fp16.h" 1 3 -# 27 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/arm_fp16.h" 3 -# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 1 3 -# 96 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/arm_fp16.h" 1 3 +# 27 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/arm_fp16.h" 3 +# 1 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 1 3 +# 96 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef long long int int64_t; typedef long long unsigned int uint64_t; -# 118 
"/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 118 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef int64_t int_least64_t; typedef uint64_t uint_least64_t; typedef int64_t int_fast64_t; typedef uint64_t uint_fast64_t; -# 193 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 193 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef int int32_t; typedef unsigned int uint32_t; -# 216 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 216 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef int32_t int_least32_t; typedef uint32_t uint_least32_t; typedef int32_t int_fast32_t; typedef uint32_t uint_fast32_t; -# 241 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 241 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef short int16_t; typedef unsigned short uint16_t; -# 255 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 255 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef int16_t int_least16_t; typedef uint16_t uint_least16_t; typedef int16_t int_fast16_t; @@ -74,7 +74,7 @@ typedef int8_t int_least8_t; typedef uint8_t uint_least8_t; typedef int8_t int_fast8_t; typedef uint8_t uint_fast8_t; -# 291 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/stdint.h" 3 +# 291 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/stdint.h" 3 typedef long int intptr_t; @@ -90,7 +90,7 @@ typedef long unsigned int uintptr_t; typedef long int intmax_t; typedef long unsigned int uintmax_t; -# 28 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/15.0.0/include/arm_fp16.h" 2 3 +# 28 "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/16/include/arm_fp16.h" 2 3 typedef __fp16 float16_t; # 7 "Source/Cmlx/mlx/mlx/types/half_types.h" 2 @@ -671,6 +671,10 @@ struct Sign { uint64_t operator()(uint64_t x) { return x != 0; } + + complex64_t operator()(complex64_t x) { + return x == complex64_t(0) ? 
x : x / std::abs(x); + } }; struct Sin { diff --git a/Source/Cmlx/mlx-generated/conv.cpp b/Source/Cmlx/mlx-generated/conv.cpp index d1b2c566..13593557 100644 --- a/Source/Cmlx/mlx-generated/conv.cpp +++ b/Source/Cmlx/mlx-generated/conv.cpp @@ -363,7 +363,7 @@ struct Conv2DWeightBlockLoader { const constant ImplicitGemmConv2DParams* gemm_params_, uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_ -> wt_strides[0]), + : src_ld(params_->wt_strides[0]), thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS), bj(vec_size * (thread_idx % TCOLS)), @@ -581,7 +581,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { const constant ImplicitGemmConv2DParams* gemm_params_, uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_ -> wt_strides[0]), + : src_ld(params_->wt_strides[0]), thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS), bj(vec_size * (thread_idx % TCOLS)), diff --git a/Source/Cmlx/mlx-generated/copy.cpp b/Source/Cmlx/mlx-generated/copy.cpp index 8a9d8d84..69ab0247 100644 --- a/Source/Cmlx/mlx-generated/copy.cpp +++ b/Source/Cmlx/mlx-generated/copy.cpp @@ -17,6 +17,24 @@ template dst[index] = static_cast(src[index]); } template +[[kernel]] void copy_s2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + dst[offset] = static_cast(src[0]); +} +template +[[kernel]] void copy_v2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + dst[offset] = static_cast(src[offset]); +} +template [[kernel]] void copy_g_nd1( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -48,20 +66,7 @@ template index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } -template -[[kernel]] void copy_g_nd( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); - int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); - dst[dst_idx] = static_cast(src[src_idx]); -} -template +template [[kernel]] void copy_g( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -70,10 +75,22 @@ template constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); + auto src_idx = elem_to_loc( + {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); + if (N == 1) { + int64_t dst_idx = + index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z); + dst[dst_idx] = static_cast(src[src_idx]); + return; + } + auto xshape = src_shape[ndim - 1]; int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); - dst[dst_idx] = static_cast(src[src_idx]); + N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z); + auto src_xstride = src_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < 
xshape; ++i) { + dst[dst_idx + i] = static_cast(src[src_idx]); + src_idx += src_xstride; + } } template [[kernel]] void copy_gg_nd1( @@ -108,19 +125,7 @@ template auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template -[[kernel]] void copy_gg_nd( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); - auto dst_idx = elem_to_loc_nd(index, src_shape, dst_strides); - dst[dst_idx] = static_cast(src[src_idx]); -} -template +template [[kernel]] void copy_gg( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -129,9 +134,24 @@ template constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); - auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); - dst[dst_idx] = static_cast(src[src_idx]); + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = static_cast(src[idx.x]); + return; + } + auto src_xstride = src_strides[ndim - 1]; + auto dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = static_cast(src[idx.x]); + idx.x += src_xstride; + idx.y += dst_xstride; + } } )preamble"; } diff --git a/Source/Cmlx/mlx-generated/gather.cpp b/Source/Cmlx/mlx-generated/gather.cpp index 42d1c0dd..954baf7c 100644 --- a/Source/Cmlx/mlx-generated/gather.cpp +++ b/Source/Cmlx/mlx-generated/gather.cpp @@ -28,30 +28,35 @@ METAL_FUNC void gather_impl( const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], const thread Indices& indices, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto ind_idx = index.x; - auto ind_offset = index.y; + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { size_t src_idx = 0; for (int i = 0; i < NIDX; ++i) { size_t idx_loc; if (IDX_NDIM == 0) { idx_loc = 0; } else if (IDX_NDIM == 1) { - idx_loc = ind_idx * indices.strides[indices.ndim * i]; + idx_loc = index.x * indices.strides[indices.ndim * i]; } else { - idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); + idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc += elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); src_idx += idx_val * src_strides[ax]; } - auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim); - size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + size_t out_idx = index.z; + if (IDX_NDIM == 1) { + out_idx += static_cast(grid_dim.z) * index.x; + } else if (IDX_NDIM >= 2) { + out_idx += + grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + } out[out_idx] = src[src_offset + src_idx]; } )preamble"; diff --git a/Source/Cmlx/mlx-generated/gemv_masked.cpp 
b/Source/Cmlx/mlx-generated/gemv_masked.cpp new file mode 100644 index 00000000..c97f9ea7 --- /dev/null +++ b/Source/Cmlx/mlx-generated/gemv_masked.cpp @@ -0,0 +1,644 @@ +namespace mlx::core::metal { + +const char* gemv_masked() { + return R"preamble( +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} + +using namespace metal; +struct _NoMask { + char x; + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; +typedef struct _NoMask nomask_t; +template +struct ScaleOp { + OutT scale; + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, + const int BN, + const int SM, + const int SN, + const int TM, + const int TN> +struct GEMVKernel { + static constant constexpr const int threadsM = BM * SM; + static constant constexpr const int threadsN = BN * SN; + static constant constexpr const int blockM = threadsM * TM; + static constant constexpr const int blockN = threadsN * TN; + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); + static constant constexpr const bool has_operand_mask = !metal::is_same_v; + static constant constexpr const bool has_output_mask = !metal::is_same_v; + static constant constexpr const bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + static constant constexpr const bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + static constant constexpr const short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + static constant constexpr const bool needs_tgp_reduction = BN > 1; + static METAL_FUNC void + load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } + static METAL_FUNC void load_safe( + const device T* src, + thread T dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src[src_offset + tn]; + } + } else { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size ? 
src[src_offset + tn] : 0; + } + } + } + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + int out_row = tid.x * blockM + bm; + if (out_row >= out_vec_size) + return; + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); + const int out_mask_offset = + !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; + int mat_mask_offset = + !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; + const int vec_mask_step = !has_operand_mask ? 
0 : vec_mask_strides[1]; + T out_scale{1}; + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + if (!mask_out) { + if (simdN == 0 && thrN == 0) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = T(0.); + } + } + return; + } + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + mat += out_row * matrix_ld; + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + for (int i = 0; i < n_iter; ++i) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + load_unsafe(in_vec, v_coeff, bn); + if (has_mul_operand_mask) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + int mat_offset = 0; +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + load_unsafe(mat, inter, mat_offset + bn); +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + mat_offset += matrix_ld; + } + } + bn += blockN; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + load_safe(in_vec, v_coeff, bn, in_size); + if (has_mul_operand_mask) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + if (has_mul_output_mask) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + result[tm] *= out_scale; + } + } +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { +#pragma clang loop unroll(full) + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + threadgroup_barrier(mem_flags::mem_none); + if (sgN == 0) { +#pragma clang loop unroll(full) + for (int sgn = 1; sgn < BN; sgn++) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + if (simdN == 0 && thrN == 0) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = result[tm]; + } + } + } +}; +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, + const int BN, + const int SM, + const int SN, + const int TM, + const int TN> +struct GEMVTKernel { + static constant constexpr const int threadsM = BM * SM; + static constant constexpr const int threadsN = BN * SN; + static constant constexpr const int blockM = threadsM * TM; + static constant constexpr const int 
blockN = threadsN * TN; + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + static constant constexpr const bool has_operand_mask = !metal::is_same_v; + static constant constexpr const bool has_output_mask = !metal::is_same_v; + static constant constexpr const bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + static constant constexpr const bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + static constant constexpr const short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + static constant constexpr const bool needs_tgp_reduction = BM > 1; + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + const int simdM = SM * sgM; + const int simdN = SN * sgN; + int cm = (simdM + thrM); + int cn = (simdN + thrN); + int bm = cm * TM; + int bn = cn * TN; + int out_col = tid.x * blockN + bn; + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + out_mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); + const int out_mask_offset = + !has_output_mask ? 0 : n_block_idx; + int mat_mask_offset = + !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; + T out_scale{1}; + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + if (!mask_out) { + if (cm == 0 && out_col < out_vec_size) { + if (out_col + TN <= out_vec_size) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + out_vec[out_col + tn] = T(0.); + } + } else { + for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { + out_vec[out_col + tn] = T(0.); + } + } + } + return; + } + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + if (out_col < out_vec_size) { + out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN; + for (int i = 0; i < n_iter; ++i) { + threadgroup_barrier(mem_flags::mem_none); + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + if (has_mul_operand_mask) { +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] *= block_scale; + } + } +#pragma clang loop unroll(full) + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + bm += blockM; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + if (leftover > 0 && + (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset])))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + if (has_mul_output_mask) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + result[tn] *= out_scale; + } + } +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { +#pragma clang loop unroll(full) + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + if (needs_tgp_reduction) { + threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + threadgroup_barrier(mem_flags::mem_none); + if (sgM == 0) { +#pragma clang loop unroll(full) + for (int sgm = 1; sgm < BM; sgm++) { +#pragma clang loop unroll(full) + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + if (cm == 0 && out_col < out_vec_size) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + out_vec[out_col + j] = result[j]; + } + } + } +}; +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, + const int BN, + const int SM, + const int SN, + const int TM, + const int TN, + const bool kDoNCBatch> +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides 
[[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + out_vec += tid.z * out_vec_size; + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, + const int BN, + const int SM, + const int SN, + const int TM, + const int TN, + const bool kDoNCBatch> +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* vector_batch_stride [[buffer(11)]], + const constant size_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant size_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVTKernel; + threadgroup T tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 
1 : gemv_kernel::tgp_mem_size]; + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + if (has_operand_mask) { + const constant size_t* mask_strides_mat = mask_batch_strides; + const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + out_vec += tid.z * out_vec_size; + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} +)preamble"; +} + +} // namespace mlx::core::metal diff --git a/Source/Cmlx/mlx-generated/hadamard.cpp b/Source/Cmlx/mlx-generated/hadamard.cpp index 2bcf5275..b0839df9 100644 --- a/Source/Cmlx/mlx-generated/hadamard.cpp +++ b/Source/Cmlx/mlx-generated/hadamard.cpp @@ -60,7 +60,7 @@ template radix_func(x); #pragma clang loop unroll(full) for (short r = 0; r < max_radix; r++) { - buf[j + h * r] = x[r]; + buf[j + h * r] = T(x[r]); } h <<= logR; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -78,7 +78,7 @@ template radix_func(x); #pragma clang loop unroll(full) for (short r = 0; r < final_radix; r++) { - buf[j + h * r] = x[r]; + buf[j + h * r] = T(x[r]); } } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -88,7 +88,7 @@ template short index = j * read_width * num_threads + i * read_width; #pragma clang loop unroll(full) for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = buf[index + r] * scale; + out[batch_idx + index + r] = T(buf[index + r] * scale); } } } @@ -118,7 +118,7 @@ template for (short c = 0; c < M; c++) { #pragma clang loop unroll(full) for (short r = 0; r < read_width; r++) { - out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale; + out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale); } } } diff --git a/Source/Cmlx/mlx-generated/quantized.cpp b/Source/Cmlx/mlx-generated/quantized.cpp index a5f8561c..b4de253d 100644 --- a/Source/Cmlx/mlx-generated/quantized.cpp +++ b/Source/Cmlx/mlx-generated/quantized.cpp @@ -589,12 +589,12 @@ METAL_FUNC void qvm_impl( } template < typename T, - const int BM, - const int BK, - const int BN, const int group_size, const int bits, - const bool aligned_N> + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> METAL_FUNC void qmm_t_impl( const device T* x, const device uint32_t* w, @@ -698,11 +698,11 @@ METAL_FUNC void qmm_t_impl( } template < typename T, - const int BM, - const int BK, - const int BN, const int group_size, - const int bits> + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> METAL_FUNC void qmm_n_impl( const device T* x, 
const device uint32_t* w, @@ -965,7 +965,7 @@ template < constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; - qmm_t_impl( + qmm_t_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } template < @@ -993,7 +993,7 @@ template < constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; - qmm_n_impl( + qmm_n_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } template @@ -1237,7 +1237,7 @@ template < s_strides, b_strides, tid); - qmm_t_impl( + qmm_t_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } template < @@ -1301,9 +1301,139 @@ template < s_strides, b_strides, tid); - qmm_n_impl( + qmm_n_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } +template +[[kernel]] void affine_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device T* scales [[buffer(2)]], + device T* biases [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr T eps = T(1e-7); + constexpr int simd_size = 32; + constexpr int uint8_bits = 8; + constexpr T n_bins = (1 << bits) - 1; + constexpr int packs_per_int = uint8_bits / bits; + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + static_assert( + group_size % simd_size == 0, + "Group size must be divisible by simd size."); + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * values_per_reduce; + size_t out_index = offset * writes_per_pack; + T w_thread[values_per_reduce]; + T w_min = Limits::max; + T w_max = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + T val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + w_min = simd_min(w_min); + w_max = simd_max(w_max); + T scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + T edge = side ? w_min : w_max; + T q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + T bias = at_zero ? 
T(0) : edge; + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = scale; + biases[gindex] = bias; + } + uint8_t output = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output += val << (bits * (i % packs_per_int)); + } + if (packs_per_int < values_per_reduce && + i % packs_per_int == packs_per_int - 1) { + out[out_index + i / packs_per_int] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 0; j < writes_per_reduce - 1; j++) { + uint8_t sval = simd_shuffle_down(val, j + 1); + output += sval << (bits * (values_per_reduce + j + i)); + } + } + } + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } +} +template +[[kernel]] void affine_quantize_scales_biases( + const device T* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device uint8_t* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr int uint8_bits = 8; + constexpr int packs_per_int = uint8_bits / bits; + constexpr T n_bins = (1 << bits) - 1; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * packs_per_int; + size_t gindex = in_index / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + uint8_t output = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < packs_per_int; i++) { + uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output += val << (bits * i); + } + } + out[offset] = output; +} +template +[[kernel]] void affine_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr int uint8_bits = 8; + constexpr int packs_per_int = uint8_bits / bits; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * packs_per_int; + size_t gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < packs_per_int; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[oindex + i] = scale * d + bias; + } +} )preamble"; } diff --git a/Source/Cmlx/mlx-generated/reduce.cpp b/Source/Cmlx/mlx-generated/reduce.cpp index 3254d12b..33eaa583 100644 --- a/Source/Cmlx/mlx-generated/reduce.cpp +++ b/Source/Cmlx/mlx-generated/reduce.cpp @@ -3,448 +3,607 @@ namespace mlx::core::metal { const char* reduce() { return R"preamble( -template -METAL_FUNC U per_thread_all_reduce( - const device T* in, - const device size_t& in_size, - uint gid, - uint grid_size) { - Op op; - U total_val = Op::init; - if (gid * N_READS < in_size) { - in += gid * N_READS; - int r = 0; - for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) { - U vals[N_READS] = {op.init}; - for (int i = 0; i < N_READS; i++) { - vals[i] = static_cast(in[i]); - } - for (int i = 0; i < N_READS; i++) { - total_val = op(vals[i], total_val); - } - in += grid_size * N_READS; - } - size_t curr_idx = (gid + r * (size_t)grid_size) * 
N_READS; - if (curr_idx < in_size) { - int max_reads = in_size - curr_idx; - T vals[N_READS]; - for (int i = 0, idx = 0; i < N_READS; i++, idx++) { - idx = idx < max_reads ? idx : max_reads - 1; - vals[i] = in[idx]; - } - for (int i = 0; i < N_READS; i++) { - U val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); - } - } - } - return total_val; -} template [[kernel]] void all_reduce( const device T* in [[buffer(0)]], - device mlx_atomic* out [[buffer(1)]], - const device size_t& in_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint grid_size [[threads_per_grid]], + device U* out [[buffer(1)]], + const constant size_t& in_size [[buffer(2)]], + const constant size_t& row_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - threadgroup U local_vals[simd_size]; - U total_val = - per_thread_all_reduce(in, in_size, gid, grid_size); - total_val = op.simd_reduce(total_val); - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; + threadgroup U shared_vals[simd_size]; + U total = Op::init; + int64_t start_idx = gid.y * row_size; + int64_t actual_row = + (start_idx + row_size <= in_size) ? row_size : in_size - start_idx; + int64_t blocks = actual_row / (lsize.x * N_READS); + int extra = actual_row - blocks * (lsize.x * N_READS); + extra -= lid.x * N_READS; + start_idx += lid.x * N_READS; + in += start_idx; + if (extra >= N_READS) { + blocks++; + extra = 0; } - threadgroup_barrier(mem_flags::mem_threadgroup); - total_val = lid < simd_per_group ? local_vals[lid] : op.init; - total_val = op.simd_reduce(total_val); - if (lid == 0) { - op.atomic_update(out, total_val); - } -} -template -[[kernel]] void all_reduce_no_atomics( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const device size_t& in_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint grid_size [[threads_per_grid]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint thread_group_id [[threadgroup_position_in_grid]]) { - Op op; - threadgroup U local_vals[simd_size]; - U total_val = - per_thread_all_reduce(in, in_size, gid, grid_size); - for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; - lane_offset /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); + for (int64_t b = 0; b < blocks; b++) { + for (int i = 0; i < N_READS; i++) { + total = op(static_cast(in[i]), total); + } + in += lsize.x * N_READS; } - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; + if (extra > 0) { + for (int i = 0; i < extra; i++) { + total = op(static_cast(in[i]), total); + } } - threadgroup_barrier(mem_flags::mem_threadgroup); - total_val = lid < simd_per_group ? 
local_vals[lid] : op.init; - for (uint16_t lane_offset = simd_size / 2; lane_offset > 0; - lane_offset /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, lane_offset)); + total = op.simd_reduce(total); + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + shared_vals[simd_group_id] = total; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init; + total = op.simd_reduce(total); } - if (lid == 0) { - out[thread_group_id] = total_val; + if (lid.x == 0) { + out[gid.y] = total; } } -template +template < + typename T, + typename U, + typename Op, + int NDIMS, + int N_READS = REDUCE_N_READS> [[kernel]] void col_reduce_small( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - const constant size_t& non_col_reductions [[buffer(8)]], - const constant int* non_col_shapes [[buffer(9)]], - const constant size_t* non_col_strides [[buffer(10)]], - const constant int& non_col_ndim [[buffer(11)]], - uint tid [[thread_position_in_grid]]) { - (void)out_size; + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { Op op; - U total_val = Op::init; - auto out_idx = tid; - in += elem_to_loc( - out_idx, - shape + non_col_ndim, - strides + non_col_ndim, - ndim - non_col_ndim); - for (uint i = 0; i < non_col_reductions; i++) { - size_t in_idx = - elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim); - for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) { - U val = static_cast(in[in_idx]); - total_val = op(total_val, val); + looped_elem_to_loc loop; + const device T* row; + if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) { + U totals[31]; + for (int i = 0; i < 31; i++) { + totals[i] = Op::init; } - } - out[out_idx] = total_val; -} -template -METAL_FUNC U _contiguous_strided_reduce( - const device T* in, - threadgroup U* local_data, - uint in_idx, - uint reduction_size, - uint reduction_stride, - uint2 tid, - uint2 lid, - uint2 lsize) { - Op op; - U total_val = Op::init; - uint base_offset = (tid.y * lsize.y + lid.y) * N_READS; - for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) { - uint offset = base_offset + r; - total_val = - op(static_cast(total_val), in[in_idx + offset * reduction_stride]); - } - local_data[lsize.y * lid.x + lid.y] = total_val; - threadgroup_barrier(mem_flags::mem_threadgroup); - U val = Op::init; - if (lid.y == 0) { - for (uint i = 0; i < lsize.y; i++) { - val = op(val, local_data[lsize.y * lid.x + i]); + short stride = reduction_stride; + short size = reduction_size; + short blocks = stride / N_READS; + short extra = stride - blocks * N_READS; + size_t out_idx = tid.x + tsize.y * 
size_t(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + for (uint r = 0; r < non_col_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + for (short i = 0; i < size; i++) { + for (short j = 0; j < blocks; j++) { + for (short k = 0; k < N_READS; k++) { + totals[j * N_READS + k] = + op(totals[j * N_READS + k], + static_cast(row[i * stride + j * N_READS + k])); + } + } + for (short k = 0; k < extra; k++) { + totals[blocks * N_READS + k] = + op(totals[blocks * N_READS + k], + static_cast(row[i * stride + blocks * N_READS + k])); + } + } + loop.next(reduce_shape, reduce_strides); } - } - return val; -} -template -[[kernel]] void col_reduce_general( - const device T* in [[buffer(0)]], - device mlx_atomic* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - threadgroup U* local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]]) { - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim); - Op op; - if (out_idx < out_size) { - U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); - if (lid.y == 0) { - op.atomic_update(out, val, out_idx); + out += out_idx * reduction_stride; + for (short j = 0; j < stride; j++) { + out[j] = totals[j]; } } -} -template -[[kernel]] void col_reduce_general_no_atomics( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - threadgroup U* local_data [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 gid [[thread_position_in_grid]], - uint3 lsize [[threads_per_threadgroup]], - uint3 gsize [[threads_per_grid]]) { - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim); - if (out_idx < out_size) { - U val = _contiguous_strided_reduce( - in, - local_data, - in_idx, - reduction_size, - reduction_stride, - tid.xy, - lid.xy, - lsize.xy); - if (lid.y == 0) { - uint tgsize_y = ceildiv(gsize.y, lsize.y); - uint tgsize_z = ceildiv(gsize.z, lsize.z); - out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val; + else if (reduction_size * non_col_reductions < 32) { + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = Op::init; + } + short size = reduction_size; + size_t offset = size_t(tid.x) * N_READS; + bool safe = offset + N_READS <= reduction_stride; + short extra = reduction_stride - offset; + size_t out_idx = tid.y + tsize.z * size_t(tid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + offset; + for (uint r = 0; r < non_col_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + if (safe) { + for (short i = 0; i < size; i++) { + for (short j = 0; j < N_READS; j++) { + totals[j] = + op(static_cast(row[i * 
reduction_stride + j]), totals[j]); + } + } + } else { + for (short i = 0; i < size; i++) { + for (short j = 0; j < extra; j++) { + totals[j] = + op(static_cast(row[i * reduction_stride + j]), totals[j]); + } + } + } + loop.next(reduce_shape, reduce_strides); + } + out += out_idx * reduction_stride + offset; + if (safe) { + for (short i = 0; i < N_READS; i++) { + out[i] = totals[i]; + } + } else { + for (short i = 0; i < extra; i++) { + out[i] = totals[i]; + } } } -} -template -[[kernel]] void row_reduce_general_small( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - uint lid [[thread_position_in_grid]]) { - Op op; - uint out_idx = lid; - if (out_idx >= out_size) { - return; - } - U total_val = Op::init; - for (short r = 0; r < short(non_row_reductions); r++) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T* in_row = in + in_idx; - for (short i = 0; i < short(reduction_size); i++) { - total_val = op(static_cast(in_row[i]), total_val); + else { + threadgroup U shared_vals[1024]; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = Op::init; + } + short stride = reduction_stride; + short lid = simd_group_id * simd_size + simd_lane_id; + short2 tile((stride + N_READS - 1) / N_READS, 32); + short2 offset((lid % tile.x) * N_READS, lid / tile.x); + short sm_stride = tile.x * N_READS; + bool safe = offset.x + N_READS <= stride; + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x; + size_t total = non_col_reductions * reduction_size; + loop.next(offset.y, reduce_shape, reduce_strides); + for (size_t r = offset.y; r < total; r += simd_size) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + if (safe) { + for (int i = 0; i < N_READS; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[N_READS]; + for (int i = 0; i < N_READS; i++) { + vals[i] = (offset.x + i < stride) ? 
static_cast(row[i]) : op.init; + } + for (int i = 0; i < N_READS; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + loop.next(simd_size, reduce_shape, reduce_strides); + } + for (int i = 0; i < N_READS; i++) { + shared_vals[offset.y * sm_stride + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_READS; i++) { + totals[i] = op.simd_reduce( + shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]); + } + if (simd_lane_id == 0) { + short column = simd_group_id * N_READS; + out += out_idx * reduction_stride + column; + if (column + N_READS <= stride) { + for (int i = 0; i < N_READS; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < stride; i++) { + out[i] = totals[i]; + } + } } } - out[out_idx] = total_val; } -template -[[kernel]] void row_reduce_general_med( +template +[[kernel]] void col_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], - uint tid [[threadgroup_position_in_grid]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - uint out_idx = simd_per_group * tid + simd_group_id; - if (out_idx >= out_size) { - return; + constexpr int n_simdgroups = 4; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + looped_elem_to_loc loop; + const device T* row; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; } - U total_val = Op::init; - if (short(non_row_reductions) == 1) { - uint in_idx = elem_to_loc(out_idx, shape, strides, ndim); - const device T* in_row = in + in_idx; - for (short i = simd_lane_id; i < short(reduction_size); i += 32) { - total_val = op(static_cast(in_row[i]), total_val); + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + size_t column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + size_t total = non_col_reductions * reduction_size; + loop.next(offset.y, reduce_shape, reduce_strides); + for (size_t r = offset.y; r < total; r += BM) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? 
static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } } + loop.next(BM, reduce_shape, reduce_strides); } - else if (short(non_row_reductions) >= 32) { - for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T* in_row = in + in_idx; - for (short i = 0; i < short(reduction_size); i++) { - total_val = op(static_cast(in_row[i]), total_val); + if (BM == 32) { + constexpr int n_outputs = BN / n_simdgroups; + static_assert( + BM != 32 || n_outputs == n_reads, + "The tile should be selected such that n_outputs == n_reads"); + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + if (simd_lane_id == 0) { + size_t out_column = BN * gid.x + out_offset.x; + out += out_idx * reduction_stride + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } } } } else { - const short n_reductions = - short(reduction_size) * short(non_row_reductions); - const short reductions_per_thread = - (n_reductions + simd_size - 1) / simd_size; - const short r_st = simd_lane_id / reductions_per_thread; - const short r_ed = short(non_row_reductions); - const short r_jump = simd_size / reductions_per_thread; - const short i_st = simd_lane_id % reductions_per_thread; - const short i_ed = short(reduction_size); - const short i_jump = reductions_per_thread; - if (r_st < r_jump) { - for (short r = r_st; r < r_ed; r += r_jump) { - uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim); - const device T* in_row = in + in_idx; - for (short i = i_st; i < i_ed; i += i_jump) { - total_val = op(static_cast(in_row[i]), total_val); + short x_block = offset.x / n_reads; + for (int i = 0; i < n_reads; i++) { + shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (offset.y == 0) { + for (int i = 0; i < n_reads; i++) { + for (int j = 1; j < BM; j++) { + totals[i] = + op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]); + } + } + } + if (offset.y == 0) { + out += out_idx * reduction_stride + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; } } } } - total_val = op.simd_reduce(total_val); - if (simd_lane_id == 0) { - out[out_idx] = total_val; +} +template +[[kernel]] void init_reduce( + device T* out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = Op::init; +} +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* inputs[N_WRITES], + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + Op op; + for (int i = 0; i < N_WRITES; i++) { + totals[i] = Op::init; + } + for (int i = 0; i < blocks; i++) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + 
} + inputs[j] += lsize_x * N_READS; + } + } + int index = lid_x * N_READS; + if (index + N_READS <= extra) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } else { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; index + i < extra; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } } } -template -METAL_FUNC U per_thread_row_reduce( +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], const device T* in, const constant size_t& reduction_size, - const constant size_t& out_size, + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + const device T* inputs[N_WRITES]; + inputs[0] = in + lid_x * N_READS; + for (int i = 1; i < N_READS; i++) { + inputs[i] = inputs[i - 1] + reduction_size; + } + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const size_t row_idx, + int blocks, + int extra, const constant int* shape, const constant size_t* strides, const constant int& ndim, uint lsize_x, - uint lid_x, - uint2 tid) { + uint lid_x) { + const device T* inputs[N_WRITES]; + in += lid_x * N_READS; + for (int i = 0; i < N_READS; i++) { + inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); + } + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void threadgroup_reduce( + thread U totals[N_WRITES], + threadgroup U* shared_vals, + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - int idx = tid.y * out_size + tid.x; - int extra_offset = elem_to_loc(idx, shape, strides, ndim); - in += extra_offset + lid_x * N_READS; - U total_val = Op::init; - int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) { - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = in[i]; + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(totals[i]); + } + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + for (int i = 0; i < N_WRITES; i++) { + shared_vals[simd_group_id * N_WRITES + i] = totals[i]; + } } - for (int i = 0; i < N_READS; i++) { - total_val = op(static_cast(vals[i]), total_val); + threadgroup_barrier(mem_flags::mem_threadgroup); + U values[N_WRITES]; + for (int i = 0; i < N_WRITES; i++) { + values[i] = (lid.x < simd_per_group) ? 
shared_vals[lid.x * N_WRITES + i] + : op.init; + } + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(values[i]); } - in += lsize_x * N_READS; } - size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS; - if (reduction_index < reduction_size) { - int max_reads = reduction_size - reduction_index; - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - int idx = min(i, max_reads - 1); - vals[i] = static_cast(in[idx]); +} +template +METAL_FUNC void +thread_reduce(thread U& total, const device T* row, int blocks, int extra) { + Op op; + for (int i = 0; i < blocks; i++) { + U vals[N_READS]; + for (int j = 0; j < N_READS; j++) { + vals[j] = row[j]; } - for (int i = 0; i < N_READS; i++) { - T val = i < max_reads ? vals[i] : Op::init; - total_val = op(static_cast(val), total_val); + for (int j = 0; j < N_READS; j++) { + total = op(vals[j], total); } + row += N_READS; + } + for (int i = 0; i < extra; i++) { + total = op(*row++, total); } - return total_val; } -template -[[kernel]] void row_reduce_general( +template < + typename T, + typename U, + typename Op, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& row_size [[buffer(2)]], + const constant size_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { + Op op; + U total_val = Op::init; + looped_elem_to_loc loop; + const device T* row; + int blocks = row_size / N_READS; + int extra = row_size % N_READS; + if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { + size_t out_idx = tid.x + tsize.y * size_t(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + for (uint r = 0; r < non_row_reductions; r++) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + thread_reduce(total_val, row, blocks, extra); + loop.next(reduce_shape, reduce_strides); + } + out[out_idx] = total_val; + } else { + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); + loop.next(simd_lane_id, reduce_shape, reduce_strides); + for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { + row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + thread_reduce(total_val, row, blocks, extra); + loop.next(simd_size, reduce_shape, reduce_strides); + } + total_val = op.simd_reduce(total_val); + if (simd_lane_id == 0) { + out[out_idx] = total_val; + } + } +} +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +[[kernel]] void row_reduce_simple( const device T* in [[buffer(0)]], - device mlx_atomic* out [[buffer(1)]], + device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], + uint3 gid 
[[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - (void)non_row_reductions; - Op op; - threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce( - in, - reduction_size, - out_size, - shape, - strides, - ndim, - lsize.x, - lid.x, - tid.xy); - total_val = op.simd_reduce(total_val); - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (reduction_size > simd_size) { - total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init; - total_val = op.simd_reduce(total_val); + threadgroup U shared_vals[simd_size * N_WRITES]; + U totals[N_WRITES]; + size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z)); + if (out_idx + N_WRITES > out_size) { + out_idx = out_size - N_WRITES; } + in += out_idx * reduction_size; + out += out_idx; + int blocks = reduction_size / (lsize.x * N_READS); + int extra = reduction_size - blocks * (lsize.x * N_READS); + per_thread_row_reduce( + totals, in, reduction_size, blocks, extra, lsize.x, lid.x); + threadgroup_reduce( + totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); if (lid.x == 0) { - op.atomic_update(out, total_val, tid.x); + for (int i = 0; i < N_WRITES; i++) { + out[i] = totals[i]; + } } } -template -[[kernel]] void row_reduce_general_no_atomics( +template < + typename T, + typename U, + typename Op, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], - const constant size_t& non_row_reductions [[buffer(4)]], - const constant int* shape [[buffer(5)]], - const constant size_t* strides [[buffer(6)]], - const constant int& ndim [[buffer(7)]], + const constant size_t& row_size [[buffer(2)]], + const constant size_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 lsize [[threads_per_threadgroup]], - uint3 gsize [[threads_per_grid]], - uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - (void)non_row_reductions; Op op; - threadgroup U local_vals[simd_size]; - U total_val = per_thread_row_reduce( - in, - reduction_size, - out_size, - shape, - strides, - ndim, - lsize.x, - lid.x, - tid.xy); - for (uint16_t i = simd_size / 2; i > 0; i /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, i)); - } - if (simd_lane_id == 0) { - local_vals[simd_group_id] = total_val; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ceildiv(reduction_size, N_READS) > simd_size) { - total_val = lid.x < simd_per_group ? 
local_vals[lid.x] : op.init; - for (uint16_t i = simd_size / 2; i > 0; i /= 2) { - total_val = op(total_val, simd_shuffle_down(total_val, i)); - } + threadgroup U shared_vals[simd_size]; + U total = Op::init; + size_t out_idx = gid.y + gsize.y * size_t(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; + looped_elem_to_loc loop; + const device T* row; + int blocks = row_size / (lsize.x * N_READS); + int extra = row_size - blocks * (lsize.x * N_READS); + for (size_t i = 0; i < non_row_reductions; i++) { + row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim); + U row_total; + per_thread_row_reduce( + &row_total, &row, blocks, extra, lsize.x, lid.x); + total = op(total, row_total); + loop.next(reduce_shape, reduce_strides); } + threadgroup_reduce( + &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); if (lid.x == 0) { - out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val; + out[out_idx] = total; } } )preamble"; diff --git a/Source/Cmlx/mlx-generated/reduce_utils.cpp b/Source/Cmlx/mlx-generated/reduce_utils.cpp index 970de8dc..de28f912 100644 --- a/Source/Cmlx/mlx-generated/reduce_utils.cpp +++ b/Source/Cmlx/mlx-generated/reduce_utils.cpp @@ -21,52 +21,54 @@ struct mlx_atomic>> { }; template , bool> = true> METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); } template , bool> = true> METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, uint offset) { +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_and_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, uint offset) { +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_add_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_mul_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { T expected = mlx_atomic_load_explicit(object, offset); while (!mlx_atomic_compare_exchange_weak_explicit( object, &expected, val * expected, offset)) { @@ -77,7 +79,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread T* expected, T val, - uint offset) { + size_t offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, @@ -89,7 +91,7 @@ template <> 
METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, float val, - uint offset) { + size_t offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val < expected) { if (mlx_atomic_compare_exchange_weak_explicit( @@ -102,7 +104,7 @@ template <> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, float val, - uint offset) { + size_t offset) { float expected = mlx_atomic_load_explicit(object, offset); while (val > expected) { if (mlx_atomic_compare_exchange_weak_explicit( @@ -121,7 +123,7 @@ union uint_or_packed { }; template struct mlx_atomic_update_helper { - uint operator()(uint_or_packed init, T update, uint elem_offset) { + uint operator()(uint_or_packed init, T update, size_t elem_offset) { Op op; init.val[elem_offset] = op(update, init.val[elem_offset]); return init.bits; @@ -131,9 +133,9 @@ template METAL_FUNC void mlx_atomic_update_and_store( device mlx_atomic* object, T update, - uint offset) { - uint pack_offset = offset / packing_size; - uint elem_offset = offset % packing_size; + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; mlx_atomic_update_helper helper; uint_or_packed expected; expected.bits = @@ -200,9 +202,9 @@ struct __Min { } template , bool> = true> METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { - uint pack_offset = offset / sizeof(T); - uint elem_offset = offset % sizeof(T); +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + size_t pack_offset = offset / sizeof(T); + size_t elem_offset = offset % sizeof(T); uint_or_packed packed_val; packed_val.bits = atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); @@ -210,16 +212,16 @@ mlx_atomic_load_explicit(device mlx_atomic* object, uint offset) { } template , bool> = true> METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, uint offset) { +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_and_explicit( device mlx_atomic* object, T val, - uint offset) { - uint pack_offset = offset / packing_size; - uint elem_offset = offset % packing_size; + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = 4294967295U; identity.val[elem_offset] = val; @@ -227,10 +229,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit( &(object[pack_offset].val), identity.bits, memory_order_relaxed); } template , bool> = true> -METAL_FUNC void -mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, uint offset) { - uint pack_offset = offset / packing_size; - uint elem_offset = offset % packing_size; +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; uint_or_packed identity; identity.bits = 0; identity.val[elem_offset] = val; @@ -241,28 +245,28 @@ template , bool> = true> METAL_FUNC void mlx_atomic_fetch_min_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_max_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> 
METAL_FUNC void mlx_atomic_fetch_add_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> METAL_FUNC void mlx_atomic_fetch_mul_explicit( device mlx_atomic* object, T val, - uint offset) { + size_t offset) { mlx_atomic_update_and_store>(object, val, offset); } template , bool> = true> @@ -270,7 +274,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( device mlx_atomic* object, thread uint* expected, uint val, - uint offset) { + size_t offset) { return atomic_compare_exchange_weak_explicit( &(object[offset].val), expected, @@ -285,13 +289,14 @@ union bool4_or_uint { }; struct None { template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_store_explicit(out, val, offset); } }; template struct And { - bool simd_reduce(bool val) { + template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } + bool simd_reduce_impl(bool val) { return simd_all(val); } static constexpr constant bool init = true; @@ -299,7 +304,7 @@ struct And { device mlx_atomic* out, bool val, int elem_idx, - int offset = 0) { + size_t offset = 0) { if (!val) { bool4_or_uint update; update.b = {true, true, true, true}; @@ -307,7 +312,8 @@ struct And { mlx_atomic_fetch_and_explicit(out, update.i, offset); } } - void atomic_update(device mlx_atomic* out, bool val, uint offset = 0) { + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { if (!val) { mlx_atomic_store_explicit(out, val, offset); } @@ -321,15 +327,16 @@ struct And { }; template struct Or { - bool simd_reduce(bool val) { + template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } + bool simd_reduce_impl(bool val) { return simd_any(val); } static constexpr constant bool init = false; void atomic_update( device mlx_atomic* out, bool val, - uint elem_idx, - uint offset = 0) { + int elem_idx, + size_t offset = 0) { if (val) { bool4_or_uint update; update.b = {false, false, false, false}; @@ -337,7 +344,8 @@ struct Or { mlx_atomic_fetch_or_explicit(out, update.i, offset); } } - void atomic_update(device mlx_atomic* out, bool val, uint offset = 0) { + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { if (val) { mlx_atomic_store_explicit(out, val, offset); } @@ -351,13 +359,14 @@ struct Or { }; template struct Sum { + template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_sum(val); } static constexpr constant U init = U(0); template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_add_explicit(out, val, offset); } U operator()(U a, U b) { @@ -366,13 +375,14 @@ struct Sum { }; template struct Prod { + template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = 
simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_product(val); } static constexpr constant U init = U(1); template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_mul_explicit(out, val, offset); } U operator()(U a, U b) { @@ -381,13 +391,14 @@ struct Prod { }; template struct Min { + template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_min(val); } static constexpr constant U init = Limits::max; template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_min_explicit(out, val, offset); } U operator()(U a, U b) { @@ -396,13 +407,14 @@ struct Min { }; template struct Max { + template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } template - T simd_reduce(T val) { + T simd_reduce_impl(T val) { return simd_max(val); } static constexpr constant U init = Limits::min; template - void atomic_update(device mlx_atomic* out, T val, uint offset = 0) { + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_max_explicit(out, val, offset); } U operator()(U a, U b) { diff --git a/Source/Cmlx/mlx-generated/scatter.cpp b/Source/Cmlx/mlx-generated/scatter.cpp index 38b0c7a6..4bbad074 100644 --- a/Source/Cmlx/mlx-generated/scatter.cpp +++ b/Source/Cmlx/mlx-generated/scatter.cpp @@ -31,7 +31,7 @@ METAL_FUNC void scatter_1d_index_impl( const thread array& idx_buffers, uint2 gid [[thread_position_in_grid]]) { Op op; - uint out_idx = 0; + size_t out_idx = 0; for (int i = 0; i < NIDX; i++) { auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); out_idx += idx_val * out_strides[i]; diff --git a/Source/Cmlx/mlx-generated/softmax.cpp b/Source/Cmlx/mlx-generated/softmax.cpp index 08736bfc..bd08c48e 100644 --- a/Source/Cmlx/mlx-generated/softmax.cpp +++ b/Source/Cmlx/mlx-generated/softmax.cpp @@ -20,7 +20,7 @@ template threadgroup AccT local_max[SIMD_SIZE]; threadgroup AccT local_normalizer[SIMD_SIZE]; AccT ld[N_READS]; - in += gid * axis_size + lid * N_READS; + in += gid * size_t(axis_size) + lid * N_READS; if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { ld[i] = AccT(in[i]); @@ -72,7 +72,7 @@ template } threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = 1 / local_normalizer[0]; - out += gid * axis_size + lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; if (lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { out[i] = T(ld[i] * normalizer); @@ -95,7 +95,7 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - in += gid * axis_size; + in += gid * size_t(axis_size); constexpr int SIMD_SIZE = 32; threadgroup AccT local_max[SIMD_SIZE]; threadgroup AccT local_normalizer[SIMD_SIZE]; @@ -142,7 +142,7 @@ template 
threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); normalizer = 1 / normalizer; - out += gid * axis_size; + out += gid * size_t(axis_size); for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); r++) { int offset = r * lsize * N_READS + lid * N_READS; diff --git a/Source/Cmlx/mlx-generated/sort.cpp b/Source/Cmlx/mlx-generated/sort.cpp index 650f9bcc..fce60a27 100644 --- a/Source/Cmlx/mlx-generated/sort.cpp +++ b/Source/Cmlx/mlx-generated/sort.cpp @@ -268,9 +268,9 @@ template < const constant int& in_stride_sorted_axis [[buffer(3)]], const constant int& out_stride_sorted_axis [[buffer(4)]], const constant int& nc_dim [[buffer(5)]], - const device int* nc_shape [[buffer(6)]], - const device size_t* in_nc_strides [[buffer(7)]], - const device size_t* out_nc_strides [[buffer(8)]], + const constant int* nc_shape [[buffer(6)]], + const constant size_t* in_nc_strides [[buffer(7)]], + const constant size_t* out_nc_strides [[buffer(8)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = @@ -391,8 +391,8 @@ template < const constant int& size_sorted_axis [[buffer(3)]], const constant int& stride_sorted_axis [[buffer(4)]], const constant int& nc_dim [[buffer(5)]], - const device int* nc_shape [[buffer(6)]], - const device size_t* nc_strides [[buffer(7)]], + const constant int* nc_shape [[buffer(6)]], + const constant size_t* nc_strides [[buffer(7)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = KernelMultiBlockMergeSort< @@ -424,13 +424,13 @@ template < bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void -mb_block_partition( +[[kernel]] void mb_block_partition( device idx_t* block_partitions [[buffer(0)]], const device val_t* dev_vals [[buffer(1)]], const device idx_t* dev_idxs [[buffer(2)]], const constant int& size_sorted_axis [[buffer(3)]], const constant int& merge_tiles [[buffer(4)]], + const constant int& n_blocks [[buffer(5)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint3 tgp_dims [[threads_per_threadgroup]]) { @@ -443,18 +443,24 @@ mb_block_partition( block_partitions += tid.y * tgp_dims.x; dev_vals += tid.y * size_sorted_axis; dev_idxs += tid.y * size_sorted_axis; - int merge_group = lid.x / merge_tiles; - int merge_lane = lid.x % merge_tiles; - int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; - int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; - int A_st = min(size_sorted_axis, sort_st); - int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); - int B_st = A_ed; - int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); - int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); - int partition = sort_kernel::merge_partition( - dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at); - block_partitions[lid.x] = A_st + partition; + for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { + int merge_group = i / merge_tiles; + int merge_lane = i % merge_tiles; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * 
merge_lane); + int partition = sort_kernel::merge_partition( + dev_vals + A_st, + dev_vals + B_st, + A_ed - A_st, + B_ed - B_st, + partition_at); + block_partitions[i] = A_st + partition; + } } template < typename val_t, diff --git a/Source/Cmlx/mlx-generated/steel_conv_general.cpp b/Source/Cmlx/mlx-generated/steel_conv_general.cpp index 7d99404e..286b22e1 100644 --- a/Source/Cmlx/mlx-generated/steel_conv_general.cpp +++ b/Source/Cmlx/mlx-generated/steel_conv_general.cpp @@ -152,7 +152,7 @@ struct Conv2DWeightBlockLoaderGeneral { const short base_ww_, uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_ -> wt_strides[0]), + : src_ld(params_->wt_strides[0]), thread_idx(simd_group_id * 32 + simd_lane_id), bi(thread_idx / TCOLS), bj(vec_size * (thread_idx % TCOLS)), diff --git a/Source/Cmlx/mlx-generated/ternary.cpp b/Source/Cmlx/mlx-generated/ternary.cpp index 8ff14fa5..ce4a2c8a 100644 --- a/Source/Cmlx/mlx-generated/ternary.cpp +++ b/Source/Cmlx/mlx-generated/ternary.cpp @@ -12,6 +12,17 @@ template d[index] = Op()(a[index], b[index], c[index]); } template +[[kernel]] void ternary_v2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + d[offset] = Op()(a[offset], b[offset], c[offset]); +} +template [[kernel]] void ternary_g_nd1( device const bool* a, device const T* b, @@ -40,7 +51,7 @@ template auto a_idx = elem_to_loc_2(index, a_strides); auto b_idx = elem_to_loc_2(index, b_strides); auto c_idx = elem_to_loc_2(index, c_strides); - size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + size_t out_idx = index.x + size_t(grid_dim.x) * index.y; d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } template @@ -58,28 +69,10 @@ template auto b_idx = elem_to_loc_3(index, b_strides); auto c_idx = elem_to_loc_3(index, c_strides); size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template -[[kernel]] void ternary_g_nd( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int shape[DIM], - constant const size_t a_strides[DIM], - constant const size_t b_strides[DIM], - constant const size_t c_strides[DIM], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides); - size_t out_idx = - index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); - d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); -} -template +template [[kernel]] void ternary_g( device const bool* a, device const T* b, @@ -92,10 +85,25 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim); - size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); - d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]); + auto idx = elem_to_loc_3_nd( + {N * index.x, index.y, index.z}, + shape, + a_strides, + b_strides, + c_strides, + ndim); + auto xshape = shape[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + auto a_xstride = a_strides[ndim - 1]; + auto b_xstride = b_strides[ndim - 1]; + 
auto c_xstride = c_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); + idx.x += a_xstride; + idx.y += b_xstride; + idx.z += c_xstride; + } } )preamble"; } diff --git a/Source/Cmlx/mlx-generated/unary.cpp b/Source/Cmlx/mlx-generated/unary.cpp index df31ec75..00d67fb7 100644 --- a/Source/Cmlx/mlx-generated/unary.cpp +++ b/Source/Cmlx/mlx-generated/unary.cpp @@ -10,15 +10,33 @@ template out[index] = Op()(in[index]); } template +[[kernel]] void unary_v2( + device const T* in, + device T* out, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + size_t offset = index.x + grid_dim.x * size_t(index.y); + out[offset] = Op()(in[offset]); +} +template [[kernel]] void unary_g( device const T* in, device T* out, - device const int* in_shape, - device const size_t* in_strides, + constant const int* in_shape, + constant const size_t* in_strides, device const int& ndim, - uint index [[thread_position_in_grid]]) { - auto idx = elem_to_loc(index, in_shape, in_strides, ndim); - out[index] = Op()(in[idx]); + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = + elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto xshape = in_shape[ndim - 1]; + auto xstride = in_strides[ndim - 1]; + size_t out_idx = + N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + out[out_idx++] = Op()(in[idx]); + idx += xstride; + } } )preamble"; } diff --git a/Source/Cmlx/mlx-generated/unary_ops.cpp b/Source/Cmlx/mlx-generated/unary_ops.cpp index 28536626..6b1be9ea 100644 --- a/Source/Cmlx/mlx-generated/unary_ops.cpp +++ b/Source/Cmlx/mlx-generated/unary_ops.cpp @@ -91,6 +91,7 @@ float expm1f(float a) { float r; r = expm1f_scaled_unchecked(a, 1.0f); if (abs(a - 1.0f) > 88.0f) { + r = pow(2, a); r = fma(r, r, -1.0f); } return r; @@ -369,6 +370,14 @@ struct Sign { uint32_t operator()(uint32_t x) { return x != 0; }; + template <> + complex64_t operator()(complex64_t x) { + if (x == complex64_t(0)) { + return x; + } + return x / + (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); + }; }; struct Sin { template diff --git a/Source/Cmlx/mlx-generated/utils.cpp b/Source/Cmlx/mlx-generated/utils.cpp index bd3e1a04..dccadb30 100644 --- a/Source/Cmlx/mlx-generated/utils.cpp +++ b/Source/Cmlx/mlx-generated/utils.cpp @@ -176,6 +176,8 @@ struct complex64_t { float real; float imag; constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; + constexpr complex64_t() : real(0), imag(0) {}; + constexpr complex64_t() threadgroup : real(0), imag(0) {}; template < typename T, typename = typename enable_if>::type> @@ -262,7 +264,8 @@ constexpr complex64_t operator%(complex64_t a, complex64_t b) { return {real, imag}; } static constant constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; -static constant constexpr int REDUCE_N_READS = 16; +static constant constexpr int REDUCE_N_READS = 4; +static constant constexpr int REDUCE_N_WRITES = 4; static constant constexpr int SOFTMAX_N_READS = 4; static constant constexpr int RMS_N_READS = 4; static constant constexpr int RMS_LOOPED_LIMIT = 4096; @@ -291,11 +294,20 @@ struct Limits { static constexpr constant bool max = true; static constexpr constant bool min = false; }; +template <> +struct Limits { + static constexpr constant complex64_t max = complex64_t( + metal::numeric_limits::infinity(), + 
metal::numeric_limits::infinity()); + static constexpr constant complex64_t min = complex64_t( + -metal::numeric_limits::infinity(), + -metal::numeric_limits::infinity()); +}; template METAL_FUNC stride_t elem_to_loc( uint elem, - device const int* shape, - device const stride_t* strides, + constant const int* shape, + constant const stride_t* strides, int ndim) { stride_t loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { @@ -306,7 +318,7 @@ METAL_FUNC stride_t elem_to_loc( } template METAL_FUNC stride_t elem_to_loc( - uint elem, + stride_t elem, constant const int* shape, constant const stride_t* strides, int ndim) { @@ -344,67 +356,16 @@ METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) { return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; } -template -METAL_FUNC size_t elem_to_loc_nd( - uint elem, - device const int* shape, - device const size_t* strides) { - size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1]; -#pragma clang loop unroll(full) - for (int d = NDIM - 2; d >= 0; --d) { - elem /= shape[d + 1]; - loc += (elem % shape[d]) * strides[d]; - } - return loc; -} -template -METAL_FUNC size_t elem_to_loc_nd( - uint3 elem, - constant const int shape[NDIM], - constant const size_t strides[NDIM]) { - size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2]; - for (int d = NDIM - 3; d >= 0; --d) { - loc += (elem.z % shape[d]) * strides[d]; - elem.z /= shape[d]; - } - return loc; -} -template -METAL_FUNC int64_t elem_to_loc_nd( - uint elem, - constant const int shape[NDIM], - constant const int64_t strides[NDIM]) { - int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1]; -#pragma clang loop unroll(full) - for (int d = NDIM - 2; d >= 0; --d) { - elem /= shape[d + 1]; - loc += (elem % shape[d]) * strides[d]; - } - return loc; -} -template -METAL_FUNC int64_t elem_to_loc_nd( - uint3 elem, - constant const int shape[NDIM], - constant const int64_t strides[NDIM]) { - int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2]; - for (int d = NDIM - 3; d >= 0; --d) { - loc += (elem.z % shape[d]) * strides[d]; - elem.z /= shape[d]; - } - return loc; -} -METAL_FUNC uint2 elem_to_loc_2_nd( +template +METAL_FUNC ulong2 elem_to_loc_2_nd( uint3 elem, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, + constant const stride_t* a_strides, + constant const stride_t* b_strides, int ndim) { - uint2 loc = { - static_cast( - elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), - static_cast( - elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + ulong2 loc = { + ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; loc.x += l * a_strides[d]; @@ -413,20 +374,17 @@ METAL_FUNC uint2 elem_to_loc_2_nd( } return loc; } -METAL_FUNC uint3 elem_to_loc_3_nd( +METAL_FUNC ulong3 elem_to_loc_3_nd( uint3 elem, constant const int* shape, constant const size_t* a_strides, constant const size_t* b_strides, constant const size_t* c_strides, int ndim) { - uint3 loc = { - static_cast( - elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), - static_cast( - elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]), - static_cast( - elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])}; + ulong3 loc = { + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2], + elem.x * b_strides[ndim - 1] + 
elem.y * b_strides[ndim - 2], + elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; loc.x += l * a_strides[d]; @@ -436,49 +394,66 @@ METAL_FUNC uint3 elem_to_loc_3_nd( } return loc; } -template -METAL_FUNC uint2 elem_to_loc_2_nd( - uint3 elem, - constant const int shape[NDIM], - constant const size_t a_strides[NDIM], - constant const size_t b_strides[NDIM]) { - uint2 loc = { - static_cast( - elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), - static_cast( - elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])}; - for (int d = NDIM - 3; d >= 0; --d) { - uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - elem.z /= shape[d]; +template +struct looped_elem_to_loc { + looped_elem_to_loc inner_looper; + offset_t offset{0}; + int index{0}; + void next(const constant int* shape, const constant size_t* strides) { + index++; + offset += strides[dim - 1]; + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + void next(int n, const constant int* shape, const constant size_t* strides) { + index += n; + offset += n * strides[dim - 1]; + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + offset_t + location(offset_t, const constant int*, const constant size_t*, int) { + return offset; } - return loc; -} -template -METAL_FUNC uint3 elem_to_loc_3_nd( - uint3 elem, - constant const int shape[NDIM], - constant const size_t a_strides[NDIM], - constant const size_t b_strides[NDIM], - constant const size_t c_strides[NDIM]) { - uint3 loc = { - static_cast( - elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), - static_cast( - elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]), - static_cast( - elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])}; - for (int d = NDIM - 3; d >= 0; --d) { - uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - loc.z += l * c_strides[d]; - elem.z /= shape[d]; +}; +template +struct looped_elem_to_loc<1, offset_t> { + offset_t offset{0}; + void next(const constant int*, const constant size_t* strides) { + offset += strides[0]; } - return loc; -} -inline size_t ceildiv(size_t N, size_t M) { + void next(int n, const constant int*, const constant size_t* strides) { + offset += n * strides[0]; + } + offset_t + location(offset_t, const constant int*, const constant size_t*, int) { + return offset; + } +}; +template +struct looped_elem_to_loc<0, offset_t> { + void next(const constant int*, const constant size_t*) {} + void next(int, const constant int*, const constant size_t*) {} + offset_t location( + offset_t idx, + const constant int* shape, + const constant size_t* strides, + int ndim) { + return elem_to_loc(idx, shape, strides, ndim); + } +}; +template +inline T ceildiv(T N, U M) { return (N + M - 1) / M; } inline float log1p(float x) { @@ -512,6 +487,10 @@ inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { inline bool simd_shuffle_down(bool data, uint16_t delta) { return simd_shuffle_down(static_cast(data), delta); } +inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); +} )preamble"; } diff --git a/Source/MLX/Cmlx+Util.swift 
b/Source/MLX/Cmlx+Util.swift index 12ef59aa..f91ec2ee 100644 --- a/Source/MLX/Cmlx+Util.swift +++ b/Source/MLX/Cmlx+Util.swift @@ -22,7 +22,7 @@ func mlx_describe(_ ptr: OpaquePointer) -> String? { // return a +1 mlx_vector_array containing the given arrays func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { let result = mlx_vector_array_new()! - mlx_vector_array_add_arrays(result, arrays.map { $0.ctx }, arrays.count) + mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) return result } @@ -142,3 +142,24 @@ func new_mlx_closure(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> mlx_closure { return mlx_closure_new_with_payload(trampoline, payload, free)! } + +func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { + let a = mlx_tuple_array_array_get_0(tuple)! + let b = mlx_tuple_array_array_get_1(tuple)! + return (MLXArray(a), MLXArray(b)) +} + +func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { + let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! + defer { mlx_free(a) } + let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! + defer { mlx_free(b) } + return (mlx_vector_array_values(a), mlx_vector_array_values(b)) +} + +func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { + let a = mlx_tuple_array_array_array_get_0(tuple)! + let b = mlx_tuple_array_array_array_get_1(tuple)! + let c = mlx_tuple_array_array_array_get_2(tuple)! + return (MLXArray(a), MLXArray(b), MLXArray(c)) +} diff --git a/Source/MLX/Documentation.docc/Organization/arithmetic.md b/Source/MLX/Documentation.docc/Organization/arithmetic.md index 5955900a..0067bc5b 100644 --- a/Source/MLX/Documentation.docc/Organization/arithmetic.md +++ b/Source/MLX/Documentation.docc/Organization/arithmetic.md @@ -164,6 +164,7 @@ Note: the `-` and `/` operators are not able to be linked here. - ``floorDivide(_:_:stream:)`` - ``isNaN(_:stream:)`` - ``isInf(_:stream:)`` +- ``isFinite(_:stream:)`` - ``isPosInf(_:stream:)`` - ``isNegInf(_:stream:)`` - ``leftShift(_:_:stream:)`` @@ -178,6 +179,7 @@ Note: the `-` and `/` operators are not able to be linked here. 
- ``maximum(_:_:stream:)`` - ``minimum(_:_:stream:)`` - ``multiply(_:_:stream:)`` +- ``nanToNum(_:nan:posInf:negInf:stream:)`` - ``negative(_:stream:)`` - ``notEqual(_:_:stream:)`` - ``pow(_:_:stream:)-7pe7j`` diff --git a/Source/MLX/Documentation.docc/Organization/convolution.md b/Source/MLX/Documentation.docc/Organization/convolution.md index 44483c08..664a9e61 100644 --- a/Source/MLX/Documentation.docc/Organization/convolution.md +++ b/Source/MLX/Documentation.docc/Organization/convolution.md @@ -13,4 +13,7 @@ MLX has several functions to support convolutions: - ``conv3d(_:_:stride:padding:dilation:groups:stream:)`` - ``convGeneral(_:_:strides:padding:kernelDilation:inputDilation:groups:flip:stream:)-9t1sj`` - ``convGeneral(_:_:strides:padding:kernelDilation:inputDilation:groups:flip:stream:)-6j1nr`` +- ``convTransposed1d(_:_:stride:padding:dilation:groups:stream:)`` +- ``convTransposed2d(_:_:stride:padding:dilation:groups:stream:)`` +- ``convTransposed3d(_:_:stride:padding:dilation:groups:stream:)`` - ``convolve(_:_:mode:stream:)`` diff --git a/Source/MLX/Documentation.docc/free-functions.md b/Source/MLX/Documentation.docc/free-functions.md index 998ac058..714b9830 100644 --- a/Source/MLX/Documentation.docc/free-functions.md +++ b/Source/MLX/Documentation.docc/free-functions.md @@ -226,4 +226,7 @@ operations as methods for convenience. - ``diag(_:k:stream:)`` - ``diagonal(_:offset:axis1:axis2:stream:)`` +- ``einsum(_:operands:stream:)`` +- ``einsum(_:_:stream:)`` +- ``hadamardTransform(_:scale:stream:)`` - ``view(_:dtype:stream:)`` diff --git a/Source/MLX/Factory.swift b/Source/MLX/Factory.swift index cd8c79e4..94b141d0 100644 --- a/Source/MLX/Factory.swift +++ b/Source/MLX/Factory.swift @@ -537,9 +537,10 @@ public func eye( /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` public func full( - _ shape: [Int], values: MLXArray, type: T.Type, stream: StreamOrDevice = .default + _ shape: [Int], values: ScalarOrArray, type: T.Type, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_full(shape.asInt32, shape.count, values.ctx, T.dtype.cmlxDtype, stream.ctx)) + let values = values.asMLXArray(dtype: nil) + return MLXArray(mlx_full(shape.asInt32, shape.count, values.ctx, T.dtype.cmlxDtype, stream.ctx)) } /// Construct an array with the given value. @@ -563,8 +564,12 @@ public func full( /// - /// - ``full(_:values:type:stream:)`` /// - ``repeated(_:count:axis:stream:)`` -public func full(_ shape: [Int], values: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_full(shape.asInt32, shape.count, values.ctx, values.dtype.cmlxDtype, stream.ctx)) +public func full(_ shape: [Int], values: ScalarOrArray, stream: StreamOrDevice = .default) + -> MLXArray +{ + let values = values.asMLXArray(dtype: nil) + return MLXArray( + mlx_full(shape.asInt32, shape.count, values.ctx, values.dtype.cmlxDtype, stream.ctx)) } /// Create a square identity matrix. diff --git a/Source/MLX/GPU.swift b/Source/MLX/GPU.swift index da2be343..50d6147d 100644 --- a/Source/MLX/GPU.swift +++ b/Source/MLX/GPU.swift @@ -24,9 +24,16 @@ public enum GPU { static let queue = DispatchQueue(label: "GPUEnum") - static var _relaxedMemoryLimit = true - static var _cacheLimit: Int? - static var _memoryLimit: Int? + // note: these are guarded by the queue above + #if swift(>=5.10) + nonisolated(unsafe) static var _relaxedMemoryLimit = true + nonisolated(unsafe) static var _cacheLimit: Int? + nonisolated(unsafe) static var _memoryLimit: Int? 
+ #else + static var _relaxedMemoryLimit = true + static var _cacheLimit: Int? + static var _memoryLimit: Int? + #endif /// Snapshot of memory stats. /// diff --git a/Source/MLX/MLXArray+Ops.swift b/Source/MLX/MLXArray+Ops.swift index 86454416..ac1f9fca 100644 --- a/Source/MLX/MLXArray+Ops.swift +++ b/Source/MLX/MLXArray+Ops.swift @@ -2550,7 +2550,8 @@ extension MLXArray { /// - ``take(_:axis:stream:)`` /// - ``take(_:_:axis:stream:)`` public func take(_ indices: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_take_all(ctx, indices.ctx, stream.ctx)) + let input = self.reshaped([-1], stream: stream) + return MLXArray(mlx_take(input.ctx, indices.ctx, 0, stream.ctx)) } /// Transpose the dimensions of the array. diff --git a/Source/MLX/Ops+Array.swift b/Source/MLX/Ops+Array.swift index d285b3a1..17dab891 100644 --- a/Source/MLX/Ops+Array.swift +++ b/Source/MLX/Ops+Array.swift @@ -1598,7 +1598,8 @@ public func take( public func take(_ array: MLXArray, _ indices: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_take_all(array.ctx, indices.ctx, stream.ctx)) + let input = array.reshaped([-1], stream: stream) + return MLXArray(mlx_take(input.ctx, indices.ctx, 0, stream.ctx)) } /// Transpose the dimensions of the array. diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index ff297761..83920d2f 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -411,6 +411,24 @@ public func ceil(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr MLXArray(mlx_ceil(array.ctx, stream.ctx)) } +/// Clip the values of the array between the given minimum and maximum. +/// +/// - Parameters: +/// - array: input array +/// - min: minimum value (must broadcast to `array`) +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``clip(_:max:stream:)`` +/// - ``clip(_:min:max:stream:)`` +public func clip( + _ array: MLXArray, min: A, stream: StreamOrDevice = .default +) -> MLXArray { + let (array, min) = toArrays(array, min) + return MLXArray(mlx_clip(array.ctx, min.ctx, nil, stream.ctx)) +} + /// Clip the values of the array between the given minimum and maximum. /// /// - Parameters: @@ -423,11 +441,11 @@ public func ceil(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``clip(_:max:stream:)`` public func clip( - _ array: MLXArray, min: A, max: B? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, min: A, max: B, stream: StreamOrDevice = .default ) -> MLXArray { let (array, min) = toArrays(array, min) - let (_, max) = max == nil ? (array, nil) : toArrays(array, max!) - return MLXArray(mlx_clip(array.ctx, min.ctx, max?.ctx, stream.ctx)) + let (_, max) = toArrays(array, max) + return MLXArray(mlx_clip(array.ctx, min.ctx, max.ctx, stream.ctx)) } /// Clip the values of the array up to the given maximum. @@ -439,6 +457,7 @@ public func clip( /// /// ### See Also /// - +/// - ``clip(_:min:stream:)`` /// - ``clip(_:min:max:stream:)`` public func clip(_ array: MLXArray, max: A, stream: StreamOrDevice = .default) -> MLXArray @@ -674,6 +693,129 @@ public func convGeneral( groups.int32, flip, stream.ctx)) } +/// 1D transposed convolution over an input with several channels. +/// +/// > Only the default `groups=1` is currently supported. 
+/// +/// - Parameters: +/// - array: input array of shape `[N, H, C_in]` +/// - weight: weight array of shape `[C_out, H, C_in]` +/// - stride: kernel stride +/// - padding: input padding +/// - dilation: kernel dilation +/// - groups: input feature groups +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``conv1d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convTransposed2d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convTransposed3d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convolve(_:_:mode:stream:)`` +public func convTransposed1d( + _ array: MLXArray, _ weight: MLXArray, stride: Int = 1, padding: Int = 0, dilation: Int = 1, + groups: Int = 1, stream: StreamOrDevice = .default +) -> MLXArray { + MLXArray( + mlx_conv_transpose1d( + array.ctx, weight.ctx, stride.int32, padding.int32, dilation.int32, groups.int32, + stream.ctx)) +} + +/// 2D transposed convolution over an input with several channels. +/// +/// > Only the default `groups=1` is currently supported. +/// +/// The numeric parameters may be given as single values: +/// +/// ```swift +/// padding: 1 +/// ``` +/// +/// This will produce a padding of `(1, 1)`. You can also give an array: +/// +/// ```swift +/// padding: [2, 3] +/// ``` +/// +/// See ``IntOrPair`` for more information. +/// +/// - Parameters: +/// - array: input array of shape `[N, H, W, C_in]` +/// - weight: weight array of shape `[C_out, H, W, C_in]` +/// - stride: kernel stride +/// - padding: input padding +/// - dilation: kernel dilation +/// - groups: input feature groups +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``IntOrPair`` +/// - ``conv1d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convTransposed1d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convTransposed3d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convolve(_:_:mode:stream:)`` +/// - ``convGeneral(_:_:strides:padding:kernelDilation:inputDilation:groups:flip:stream:)-9t1sj`` +public func convTransposed2d( + _ array: MLXArray, _ weight: MLXArray, stride: IntOrPair = 1, padding: IntOrPair = 0, + dilation: IntOrPair = 1, groups: Int = 1, stream: StreamOrDevice = .default +) -> MLXArray { + MLXArray( + mlx_conv_transpose2d( + array.ctx, weight.ctx, stride.first.int32, stride.second.int32, padding.first.int32, + padding.second.int32, dilation.first.int32, dilation.second.int32, groups.int32, + stream.ctx)) +} + +/// 3D transposed convolution over an input with several channels. +/// +/// > Only the default `groups=1` is currently supported. +/// +/// The numeric parameters may be given as single values: +/// +/// ```swift +/// padding: 1 +/// ``` +/// +/// This will produce a padding of `(1, 1, 1)`. You can also give an array: +/// +/// ```swift +/// padding: [2, 3, 3] +/// ``` +/// +/// See ``IntOrTriple`` for more information. 
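As a quick illustration of the new transposed-convolution free functions, a minimal sketch; the shapes follow the `[N, H, W, C_in]` / `[C_out, H, W, C_in]` convention documented above, and the printed shape is an expectation rather than a verified result:

```swift
import MLX

// NHWC input and [C_out, kH, kW, C_in] weight, matching the documented shapes
let x = MLXArray.zeros([1, 8, 8, 3])
let w = MLXArray.zeros([16, 4, 4, 3])

// scalar parameters broadcast to pairs, e.g. stride: 2 means (2, 2)
let y = convTransposed2d(x, w, stride: 2, padding: 1)
print(y.shape)  // expected [1, 16, 16, 16]
```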
+/// +/// - Parameters: +/// - array: input array of shape `[N, D, H, W, C_in]` +/// - weight: weight array of shape `[C_out, D, H, W, C_in]` +/// - stride: kernel stride +/// - padding: input padding +/// - dilation: kernel dilation +/// - groups: input feature groups +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``IntOrTriple`` +/// - ``conv1d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convTransposed1d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convTransposed3d(_:_:stride:padding:dilation:groups:stream:)`` +/// - ``convolve(_:_:mode:stream:)`` +/// - ``convGeneral(_:_:strides:padding:kernelDilation:inputDilation:groups:flip:stream:)-9t1sj`` +public func convTransposed3d( + _ array: MLXArray, _ weight: MLXArray, stride: IntOrTriple = 1, padding: IntOrTriple = 0, + dilation: IntOrTriple = 1, groups: Int = 1, stream: StreamOrDevice = .default +) -> MLXArray { + MLXArray( + mlx_conv_transpose3d( + array.ctx, weight.ctx, + stride.first.int32, stride.second.int32, stride.third.int32, + padding.first.int32, padding.second.int32, padding.third.int32, + dilation.first.int32, dilation.second.int32, dilation.third.int32, + groups.int32, stream.ctx)) +} + /// Mode for ``convolve(_:_:mode:stream:)`` public enum ConvolveMode: Sendable { case full @@ -830,6 +972,36 @@ public func divmod( return (result[0], result[1]) } +/// Perform the Einstein summation convention on the operands. +/// +/// - Parameters: +/// - subscripts: Einstein summation convention equation +/// - operands: input arrays +/// - stream: stream or device to evaluate on +public func einsum(_ subscripts: String, _ operands: MLXArray..., stream: StreamOrDevice = .default) + -> MLXArray +{ + einsum(subscripts, operands: operands, stream: stream) +} + +/// Perform the Einstein summation convention on the operands. +/// +/// - Parameters: +/// - subscripts: Einstein summation convention equation +/// - operands: input arrays +/// - stream: stream or device to evaluate on +public func einsum(_ subscripts: String, operands: [MLXArray], stream: StreamOrDevice = .default) + -> MLXArray +{ + let subscripts = mlx_string_new(subscripts.cString(using: .utf8))! + defer { mlx_free(subscripts) } + + let operands = new_mlx_vector_array(operands) + defer { mlx_free(operands) } + + return MLXArray(mlx_einsum(subscripts, operands, stream.ctx)) +} + /// Element-wise equality. /// /// Equality comparison on two arrays with . @@ -1045,6 +1217,22 @@ public func greaterEqual( return MLXArray(mlx_greater_equal(a.ctx, b.ctx, stream.ctx)) } +/// Perform the Walsh-Hadamard transform along the final axis. +/// +/// Supports sizes `n = m*2^k` for `m` in `(1, 12, 20, 28)` and `2^k <= 8192` +/// for ``DType/float32`` and `2^k <= 16384` for ``DType/float16`` and ``DType/bfloat16``. +/// +/// - Parameters: +/// - array: input array +/// - scale: scale the output by this factor -- default is `1.0/sqrt(array.dim(-1))` +/// - stream: stream to evaluate on +public func hadamardTransform( + _ array: MLXArray, scale: Float? = nil, stream: StreamOrDevice = .default +) -> MLXArray { + let scale = mlx_optional_float(value: scale ?? 0, has_value: scale != nil) + return MLXArray(mlx_hadamard_transform(array.ctx, scale, stream.ctx)) +} + /// Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes. 
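A short sketch of the new `einsum` and `hadamardTransform` entry points added above; the equation is an ordinary matrix multiply and the size restrictions documented for the Hadamard transform apply:

```swift
import MLX

let a = MLXArray(0 ..< 6, [2, 3]).asType(.float32)
let b = MLXArray(0 ..< 6, [3, 2]).asType(.float32)

// "ij,jk->ik" is a plain matrix multiply expressed via both einsum overloads
let viaVariadic = einsum("ij,jk->ik", a, b)
let viaOperands = einsum("ij,jk->ik", operands: [a, b])

// nil scale falls back to 1/sqrt(n), per the documentation above
let h = hadamardTransform(MLXArray.ones([4, 4]))
```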
/// /// - Parameters: @@ -1118,6 +1306,19 @@ public func isInf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr MLXArray(mlx_isinf(array.ctx, stream.ctx)) } +/// Return a boolean array indicating which elements are finite. +/// +/// - Parameters: +/// - array: input array +/// - stream: stream or device to evaluate on +/// - Returns: The boolean array indicating which elements are infinity. +/// +/// ### See Also +/// - +public func isFinite(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + MLXArray(mlx_isfinite(array.ctx, stream.ctx)) +} + /// Return a boolean array indicating which elements are negative infinity. /// /// - Parameters: @@ -1512,6 +1713,29 @@ public func multiply( return MLXArray(mlx_multiply(a.ctx, b.ctx, stream.ctx)) } +/// Replace NaN and Inf values with finite numbers. +/// +/// - Parameters: +/// - array: input array +/// - nan: value to replace NaN with +/// - posInf: value to replace positive inifinites with. If not specified will use +/// the largest finite value for the given dtype. +/// - negInf: value to replace negative inifinites with. If not specified will use +/// the negative of the largest finite value for the given dtype. +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func nanToNum( + _ array: MLXArray, + nan: Float = 0, posInf: Float? = 0, negInf: Float? = 0, + stream: StreamOrDevice = .default +) -> MLXArray { + let posInf = mlx_optional_float(value: posInf ?? 0, has_value: posInf != nil) + let negInf = mlx_optional_float(value: negInf ?? 0, has_value: negInf != nil) + return MLXArray(mlx_nan_to_num(array.ctx, nan, posInf, negInf, stream.ctx)) +} + /// Element-wise negation. /// /// Negate the values in the array. @@ -1579,12 +1803,21 @@ public func outer( MLXArray(mlx_outer(a.ctx, b.ctx, stream.ctx)) } +/// Mode for ``padded(_:width:value:stream:)`` +public enum PadMode: String { + /// pads with constant value + case constant + /// pads with the edge values of the array + case edge +} + /// Pad an array with a constant value. /// /// - Parameters: /// - array: the array to pad /// - width: either an `Int` number of values to pad before AND after each axis or an array of 2 giving the /// before and after counts +/// - mode: padding mode, see ``PadMode`` /// - value: constant value to pad the edges with /// - stream: stream or device to evaluate on /// @@ -1592,7 +1825,8 @@ public func outer( /// - /// - ``padded(_:widths:value:stream:)`` public func padded( - _ array: MLXArray, width: IntOrPair, value: MLXArray? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, width: IntOrPair, mode: PadMode = .constant, value: MLXArray? = nil, + stream: StreamOrDevice = .default ) -> MLXArray { let ndim = array.ndim let axes = Array(Int32(0) ..< Int32(ndim)) @@ -1600,8 +1834,12 @@ public func padded( let highPads = (0 ..< ndim).map { _ in width.second.int32 } let value = value ?? MLXArray(0, dtype: array.dtype) + let mlx_mode = mlx_string_new(mode.rawValue.cString(using: .utf8))! + defer { mlx_free(mlx_mode) } + return MLXArray( - mlx_pad(array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, stream.ctx)) + mlx_pad( + array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, mlx_mode, stream.ctx)) } /// Pad an array with a constant value. 
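A small sketch combining the new `nanToNum` with the `mode:` parameter on `padded`; the values are arbitrary and `.edge` repeats the boundary values as described by ``PadMode``:

```swift
import MLX

let raw: [Float] = [1, .infinity, .nan, -.infinity]
let x = MLXArray(raw)

// posInf/negInf are given explicitly here; see the parameter documentation above for the defaults
let cleaned = nanToNum(x, nan: 0, posInf: 1e6, negInf: -1e6)

// pad one element on each side, repeating edge values rather than a constant
let edgePadded = padded(MLXArray([1, 2, 3]), width: 1, mode: .edge)
print(edgePadded)  // expected [1, 1, 2, 3, 3]
```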
@@ -1609,6 +1847,7 @@ public func padded( /// - Parameters: /// - array: the array to pad /// - widths: array of int or pairs giving the before/after amounts for each axis +/// - mode: padding mode, see ``PadMode`` /// - value: constant value to pad the edges with /// - stream: stream or device to evaluate on /// @@ -1616,7 +1855,7 @@ public func padded( /// - /// - ``padded(_:width:value:stream:)`` public func padded( - _ array: MLXArray, widths: [IntOrPair], value: MLXArray? = nil, + _ array: MLXArray, widths: [IntOrPair], mode: PadMode = .constant, value: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let ndim = array.ndim @@ -1625,8 +1864,12 @@ public func padded( let highPads = widths.map { $0.second.int32 } let value = value ?? MLXArray(0, dtype: array.dtype) + let mlx_mode = mlx_string_new(mode.rawValue.cString(using: .utf8))! + defer { mlx_free(mlx_mode) } + return MLXArray( - mlx_pad(array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, stream.ctx)) + mlx_pad( + array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, mlx_mode, stream.ctx)) } /// Returns a partitioned copy of the array such that the smaller `kth` @@ -1676,6 +1919,46 @@ public func partitioned(_ array: MLXArray, kth: Int, stream: StreamOrDevice = .d MLXArray(mlx_partition_all(array.ctx, kth.int32, stream.ctx)) } +/// Put values along an axis at the specified indices. +/// +/// - Parameters: +/// - array: destination array +/// - indices: Indices array. These should be broadcastable with the input array excluding the `axis` dimension. +/// - values: Values array. These should be broadcastable with the indices. +/// - axis: Axis in the destination to put the values to +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``takeAlong(_:_:stream:)`` +public func putAlong( + _ array: MLXArray, _ indices: MLXArray, values: MLXArray, axis: Int, + stream: StreamOrDevice = .default +) -> MLXArray { + MLXArray(mlx_put_along_axis(array.ctx, indices.ctx, values.ctx, axis.int32, stream.ctx)) +} + +/// Put values along an axis at the specified indices in a flattened array. +/// +/// - Parameters: +/// - array: destination array +/// - indices: Indices array. These should be broadcastable with the flattened input array +/// - values: Values array. These should be broadcastable with the flattened input array +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +/// - ``takeAlong(_:_:axis:stream:) +public func putAlong( + _ array: MLXArray, _ indices: MLXArray, values: MLXArray, stream: StreamOrDevice = .default +) + -> MLXArray +{ + let input = array.reshaped([-1], stream: stream) + let result = MLXArray(mlx_put_along_axis(input.ctx, indices.ctx, values.ctx, 0, stream.ctx)) + return result.reshaped(array.shape, stream: stream) +} + /// Quantize the matrix `w` using `bits` bits per element. /// /// Note, every `group_size` elements in a row of `w` are quantized @@ -1694,11 +1977,10 @@ public func partitioned(_ array: MLXArray, kth: Int, stream: StreamOrDevice = .d public func quantized( _ w: MLXArray, groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default ) -> (wq: MLXArray, scales: MLXArray, biases: MLXArray) { - let result = mlx_quantize(w.ctx, groupSize.int32, bits.int32, stream.ctx)! - defer { mlx_free(result) } + let result_tuple = mlx_quantize(w.ctx, groupSize.int32, bits.int32, stream.ctx)! 
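For reference, a sketch of the new `putAlong` overloads and the tuple-returning `quantized`; this is a usage illustration only, and `dequantized` is assumed to be the unchanged pre-existing companion op:

```swift
import MLX
import MLXRandom

// scatter 9s into column 0 of each row
let base = MLXArray.zeros([2, 3])
let indices = MLXArray([Int32(0), 0]).reshaped([2, 1])
let values = MLXArray([Float(9), 9]).reshaped([2, 1])
let scattered = putAlong(base, indices, values: values, axis: 1)

// quantized now returns a named tuple instead of a vector of arrays
let w = MLXRandom.normal([64, 128])
let (wq, scales, biases) = quantized(w, groupSize: 64, bits: 4)
let wHat = dequantized(wq, scales: scales, biases: biases, groupSize: 64, bits: 4)
```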
+ defer { mlx_free(result_tuple) } - let arrays = mlx_vector_array_values(result) - return (arrays[0], arrays[1], arrays[2]) + return mlx_tuple_values(result_tuple) } /// Perform the matrix multiplication with the quantized matrix `w`. The diff --git a/Source/MLX/Transforms+Compile.swift b/Source/MLX/Transforms+Compile.swift index 6fd7ab03..99395971 100644 --- a/Source/MLX/Transforms+Compile.swift +++ b/Source/MLX/Transforms+Compile.swift @@ -3,10 +3,11 @@ import Cmlx import Foundation -class CompiledFunction { +// Note: this is all immutable state -- the `id` property is only set at init time +final class CompiledFunction: @unchecked (Sendable) { /// unique (for the lifetime of the object) identifier for the compiled function - var id: UInt! + private var id: UInt! /// the function to compile let f: ([MLXArray]) -> [MLXArray] @@ -128,9 +129,7 @@ class CompiledFunction { public func compile( inputs: [any Updatable] = [], outputs: [any Updatable] = [], shapeless: Bool = false, _ f: @escaping ([MLXArray]) -> [MLXArray] -) -> ( - [MLXArray] -) -> [MLXArray] { +) -> @Sendable ([MLXArray]) -> [MLXArray] { let compileState = CompiledFunction(inputs: inputs, outputs: outputs, shapeless: shapeless, f) return { arrays in @@ -147,9 +146,7 @@ public func compile( public func compile( inputs: [any Updatable] = [], outputs: [any Updatable] = [], shapeless: Bool = false, _ f: @escaping (MLXArray) -> MLXArray -) -> ( - MLXArray -) -> MLXArray { +) -> @Sendable (MLXArray) -> MLXArray { let compileState = CompiledFunction(inputs: inputs, outputs: outputs, shapeless: shapeless) { [f($0[0])] } @@ -169,7 +166,7 @@ public func compile( inputs: [any Updatable] = [], outputs: [any Updatable] = [], shapeless: Bool = false, _ f: @escaping (MLXArray, MLXArray) -> MLXArray ) - -> (MLXArray, MLXArray) -> MLXArray + -> @Sendable (MLXArray, MLXArray) -> MLXArray { let compileState = CompiledFunction(inputs: inputs, outputs: outputs, shapeless: shapeless) { [f($0[0], $0[1])] @@ -188,7 +185,7 @@ public func compile( /// - ``compile(inputs:outputs:shapeless:_:)-7korq`` public func compile( inputs: [any Updatable] = [], outputs: [any Updatable] = [], shapeless: Bool = false, - _ f: @escaping (MLXArray, MLXArray, MLXArray) -> MLXArray + _ f: @Sendable @escaping (MLXArray, MLXArray, MLXArray) -> MLXArray ) -> (MLXArray, MLXArray, MLXArray) -> MLXArray { diff --git a/Source/MLX/Transforms+Internal.swift b/Source/MLX/Transforms+Internal.swift index 1589314a..57ca2e6f 100644 --- a/Source/MLX/Transforms+Internal.swift +++ b/Source/MLX/Transforms+Internal.swift @@ -14,13 +14,8 @@ private func valueAndGradient(apply valueAndGrad: mlx_closure_value_and_grad, ar let vector_pair = mlx_closure_value_and_grad_apply(valueAndGrad, input_vector)! defer { mlx_free(vector_pair) } - let values = mlx_vector_vector_array_get(vector_pair, 0)! - defer { mlx_free((values)) } - - let gradient = mlx_vector_vector_array_get(vector_pair, 1)! - defer { mlx_free((gradient)) } - - return (mlx_vector_array_values(values), mlx_vector_array_values(gradient)) + let (values, gradient) = mlx_tuple_vectors(vector_pair) + return (values, gradient) } func buildGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [Int]) -> ( diff --git a/Source/MLX/Transforms.swift b/Source/MLX/Transforms.swift index d765f899..5218d877 100644 --- a/Source/MLX/Transforms.swift +++ b/Source/MLX/Transforms.swift @@ -29,13 +29,7 @@ public func jvp( mlx_free(closure) - let v1 = mlx_vector_vector_array_get(vector_pair, 0)! 
- defer { mlx_free((v1)) } - - let v2 = mlx_vector_vector_array_get(vector_pair, 1)! - defer { mlx_free((v2)) } - - return (mlx_vector_array_values(v1), mlx_vector_array_values(v2)) + return mlx_tuple_vectors(vector_pair) } /// Compute the vector-Jacobian product. @@ -64,13 +58,7 @@ public func vjp( mlx_free(closure) - let v1 = mlx_vector_vector_array_get(vector_pair, 0)! - defer { mlx_free((v1)) } - - let v2 = mlx_vector_vector_array_get(vector_pair, 1)! - defer { mlx_free((v2)) } - - return (mlx_vector_array_values(v1), mlx_vector_array_values(v2)) + return mlx_tuple_vectors(vector_pair) } /// Returns a function that computes the gradient and result of `f`, computing the gradient with respect to the ``NestedDictionary``. diff --git a/Source/MLXFast/Cmlx+Util.swift b/Source/MLXFast/Cmlx+Util.swift new file mode 100644 index 00000000..a36e123b --- /dev/null +++ b/Source/MLXFast/Cmlx+Util.swift @@ -0,0 +1,74 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx +import Foundation +import MLX + +@inline(__always) +func mlx_free(_ ptr: OpaquePointer) { + mlx_free(UnsafeMutableRawPointer(ptr)) +} + +// return a +1 mlx_vector_array containing the given arrays +func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { + let result = mlx_vector_array_new()! + mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) + return result +} + +func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { + (0 ..< mlx_vector_array_size(vector_array)) + .map { index in + // ctx is a +1 object, the array takes ownership + let ctx = mlx_vector_array_get(vector_array, index)! + return MLXArray(ctx) + } +} + +func new_mlx_vector_string(_ values: [String]) -> mlx_vector_string { + let result = mlx_vector_string_new()! + for value in values { + let mlxString = mlx_string_new(value.cString(using: .utf8))! + mlx_vector_string_add_value(result, mlxString) + mlx_free(mlxString) + } + return result +} + +func new_mlx_vector_vector_int(_ values: [[Int]]) -> mlx_vector_vector_int { + let result = mlx_vector_vector_int_new()! + for value in values { + let vector = mlx_vector_int_from_data(value.map { Int32($0) }, value.count)! + mlx_vector_vector_int_add_value(result, vector) + mlx_free(vector) + } + return result +} + +func new_mlx_vector_array_dtype(_ values: [DType]) -> mlx_vector_array_dtype { + mlx_vector_array_dtype_from_data( + values.map { $0.cmlxDtype }, + values.count + ) +} + +func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { + let a = mlx_tuple_array_array_get_0(tuple)! + let b = mlx_tuple_array_array_get_1(tuple)! + return (MLXArray(a), MLXArray(b)) +} + +func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { + let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! + defer { mlx_free(a) } + let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! + defer { mlx_free(b) } + return (mlx_vector_array_values(a), mlx_vector_array_values(b)) +} + +func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { + let a = mlx_tuple_array_array_array_get_0(tuple)! + let b = mlx_tuple_array_array_array_get_1(tuple)! + let c = mlx_tuple_array_array_array_get_2(tuple)! 
+ return (MLXArray(a), MLXArray(b), MLXArray(c)) +} diff --git a/Source/MLXFast/MLXFast.swift b/Source/MLXFast/MLXFast.swift index 3c63f98c..79b2cf96 100644 --- a/Source/MLXFast/MLXFast.swift +++ b/Source/MLXFast/MLXFast.swift @@ -23,12 +23,14 @@ import MLX /// /// > Note: `MLXNN.RoPE` uses this implementation internally. public func RoPE( - _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float, scale: Float, offset: Int, - stream: StreamOrDevice = .default + _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, offset: Int, + freqs: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( + let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) + return MLXArray( mlx_fast_rope( - array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), stream.ctx)) + array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), + freqs?.ctx, stream.ctx)) } /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` @@ -55,11 +57,14 @@ public func RoPE( /// ``` public func scaledDotProductAttention( queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, - stream: StreamOrDevice = .default + memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( + let memoryEfficientThreshold = mlx_optional_int( + value: Int32(memoryEfficientThreshold ?? 0), has_value: memoryEfficientThreshold != nil) + return MLXArray( mlx_fast_scaled_dot_product_attention( - queries.ctx, keys.ctx, values.ctx, scale, mask?.ctx, stream.ctx)) + queries.ctx, keys.ctx, values.ctx, scale, mask?.ctx, + memoryEfficientThreshold, stream.ctx)) } /// Root Mean Square normalization (RMS norm). @@ -96,3 +101,26 @@ public func layerNorm( ) -> MLXArray { MLXArray(mlx_fast_layer_norm(x.ctx, weight?.ctx, bias?.ctx, eps, stream.ctx)) } + +/// Quantize the matrix `w` using the provided `scales` and +/// `biases` and the `groupSize` and `bits` configuration. + +/// For details, please see +/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.affine_quantize.html) +/// +/// - Parameters: +/// - w: Matrix to be quantized +/// - scales: The scales to use per `groupSize` elements of `w` +/// - biases: The biases to use per `groupSize` elements of `w` +/// - groupSize: The size of the group in `w` that shares a scale and bias. +/// - bits: The number of bits occupied by each element in `w`. +/// - stream: stream or device to evaluate on +/// - Returns: quantized version of `w` +public func affineQuantized( + _ w: MLXArray, scales: MLXArray, biases: MLXArray, groupSize: Int = 64, bits: Int = 4, + stream: StreamOrDevice = .default +) -> MLXArray { + MLXArray( + mlx_fast_affine_quantize( + w.ctx, scales.ctx, biases.ctx, Int32(groupSize), Int32(bits), stream.ctx)) +} diff --git a/Source/MLXFast/MLXFastKernel.swift b/Source/MLXFast/MLXFastKernel.swift new file mode 100644 index 00000000..8425a40d --- /dev/null +++ b/Source/MLXFast/MLXFastKernel.swift @@ -0,0 +1,207 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx +import MLX + +/// Marker protocol for types that can be used in the `template` of a kernel call. 
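A sketch of calling the updated `scaledDotProductAttention`, which now accepts an optional `memoryEfficientThreshold`; the shapes are illustrative and `MLXRandom.normal` is used only to produce inputs:

```swift
import MLX
import MLXFast
import MLXRandom

// [batch, heads, sequence, headDim]
let q = MLXRandom.normal([1, 8, 32, 64])
let k = MLXRandom.normal([1, 8, 32, 64])
let v = MLXRandom.normal([1, 8, 32, 64])

let out = MLXFast.scaledDotProductAttention(
    queries: q, keys: k, values: v,
    scale: 1 / Float(64).squareRoot(),
    mask: nil,
    memoryEfficientThreshold: nil)
```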
+/// +/// Currently: +/// - `Int` +/// - `Bool` +/// - `DType` +/// +/// See also: ``MLXFastKernel`` +public protocol KernelTemplateArg {} + +extension Bool: KernelTemplateArg {} +extension Int: KernelTemplateArg {} +extension DType: KernelTemplateArg {} + +/// Add a ``KernelTemplateArg`` to the tuple of vectors +private func add( + name: String, + value: any KernelTemplateArg, + to vector: mlx_vector_tuple_string_variant_int_bool_array_dtype +) { + let name = mlx_string_new(name.cString(using: .utf8))! + defer { mlx_free(name) } + + let value = + switch value { + case let value as Bool: + mlx_variant_int_bool_array_dtype_new_with_bool(value)! + + case let value as Int: + mlx_variant_int_bool_array_dtype_new_with_int(Int32(value))! + + case let value as DType: + mlx_variant_int_bool_array_dtype_new_with_array_dtype(value.cmlxDtype)! + + default: + fatalError("Unable to handle KernelTemplateArg with type: \(type(of: value)).") + } + + defer { mlx_free(value) } + + let tuple = mlx_tuple_string_variant_int_bool_array_dtype_new(name, value)! + defer { mlx_free(tuple) } + + mlx_vector_tuple_string_variant_int_bool_array_dtype_add_value(vector, tuple) +} + +/// Container for a kernel created by +/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:)``. +/// +/// The ``callAsFunction(inputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:stream:)`` +/// can be used to evaluate the kernel with inputs: +/// +/// ```swift +/// let a = normal([2, 2]) +/// let kernel = MLXFast.metalKernel( +/// name: "basic", +/// inputNames: ["a"], +/// outputNames: ["out1"], +/// source: """ +/// uint elem = thread_position_in_grid.x; +/// out1[elem] = a[elem]; +/// """) +/// +/// let out = kernel( +/// inputs: [a], +/// grid: (4, 1, 1), +/// threadGroup: (2, 1, 1), +/// outputShapes: [[2, 2]], +/// outputDTypes: [.float32]) +/// ``` +open class MLXFastKernel { + let kernel: mlx_closure_metal_kernel_function + public let outputNames: [String] + + init( + name: String, inputNames: [String], outputNames: [String], + source: String, header: String = "", + ensureRowContiguous: Bool = true, + atomicOutputs: Bool = false + ) { + self.outputNames = outputNames + + let mlxName = mlx_string_new(name.cString(using: .utf8))! + defer { mlx_free(mlxName) } + + let mlxInputNames = new_mlx_vector_string(inputNames) + defer { mlx_free(mlxInputNames) } + let mlxOutputNames = new_mlx_vector_string(outputNames) + defer { mlx_free(mlxOutputNames) } + + let mlxSource = mlx_string_new(source.cString(using: .utf8))! + defer { mlx_free(mlxSource) } + let mlxHeader = mlx_string_new(header.cString(using: .utf8))! + defer { mlx_free(mlxHeader) } + + self.kernel = mlx_fast_metal_kernel( + mlxName, mlxInputNames, mlxOutputNames, mlxSource, mlxHeader, ensureRowContiguous, + atomicOutputs) + } + + deinit { + mlx_free(kernel) + } + + /// Call the prepared metal kernel. + /// + /// See ``MLXFastKernel`` for example. Use + /// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:)`` + /// to create an instance. 
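A sketch extending the example above with template arguments, which the new ``KernelTemplateArg`` plumbing supports; the kernel name, source, and the `T` template parameter are illustrative:

```swift
import MLX
import MLXFast

let kernel = MLXFast.metalKernel(
    name: "scale_by_two",
    inputNames: ["inp"],
    outputNames: ["out"],
    source: """
        uint elem = thread_position_in_grid.x;
        out[elem] = T(2.0) * inp[elem];
        """)

let x = MLXArray.ones([16])
let result = kernel(
    inputs: [x],
    template: [("T", DType.float32)],
    grid: (16, 1, 1),
    threadGroup: (16, 1, 1),
    outputShapes: [[16]],
    outputDTypes: [.float32])
// result is [MLXArray], one entry per name in outputNames
```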
+ /// + /// - Parameters: + /// - inputs: inputs passed to the metal kernel + /// - template: template arguments + /// - grid: 3-tuple specifying the grid to launch the kernel with + /// - threadGroup: 3-tuple specifying the threadgroup size to use + /// - outputShapes: list of shapes for each output in ``outputNames`` + /// - outputDTypes: list of data types for each output in ``outputNames`` + /// - initValue: optional value to use to initialize all of the output arrays + /// - verbose: if true will print the full generated source code of the kernel when run + /// - stream: stream to run on + /// - Returns: array of `MLXArray` + public func callAsFunction( + inputs: [ScalarOrArray], + template: [(String, KernelTemplateArg)]? = nil, + grid: (Int, Int, Int), + threadGroup: (Int, Int, Int), + outputShapes: [[Int]], + outputDTypes: [DType], + initValue: Float? = nil, + verbose: Bool = false, + stream: StreamOrDevice = .default + ) -> [MLXArray] { + // convert all the inputs into the mlx-c types + let inputs = new_mlx_vector_array(inputs.map { $0.asMLXArray(dtype: nil) }) + defer { mlx_free(inputs) } + + let outputShapes = new_mlx_vector_vector_int(outputShapes) + defer { mlx_free(outputShapes) } + + let outputDTypes = new_mlx_vector_array_dtype(outputDTypes) + defer { mlx_free(outputDTypes) } + + let grid = mlx_tuple_int_int_int_new(Int32(grid.0), Int32(grid.1), Int32(grid.2))! + defer { mlx_free(grid) } + + let threadGroup = mlx_tuple_int_int_int_new( + Int32(threadGroup.0), Int32(threadGroup.1), Int32(threadGroup.2))! + defer { mlx_free(threadGroup) } + + let templateVector = mlx_vector_tuple_string_variant_int_bool_array_dtype_new()! + defer { mlx_free(templateVector) } + if let template { + for (name, value) in template { + add(name: name, value: value, to: templateVector) + } + } + + let initValue = mlx_optional_float(value: initValue ?? 0, has_value: initValue != nil) + + let result = mlx_closure_metal_kernel_function_apply( + kernel, + inputs, + outputShapes, + outputDTypes, + grid, + threadGroup, + templateVector, + initValue, + verbose, + stream.ctx)! + defer { mlx_free(result) } + + return mlx_vector_array_values(result) + } +} + +/// A jit-compiled custom Metal kernel defined from a source string. +/// +/// - Parameters: +/// - name: name for the kernel +/// - inputNames: parameter names of the inputs in the function signature +/// - outputNames: parameter names of the outputs in the function signature +/// - source: source code -- this is the body of a function in Metal, +/// the function signature will be automatically generated. +/// - header: header source code to include before the main function. Useful +/// for helper functions or includes that should live outside of the main function body. +/// - ensureRowContiguous: whether to ensure the inputs are row contiguous +/// before the kernel runs (at a performance cost) +/// - atomicOutputs: whether to use atomic outputs in the function signature, +/// e.g. 
`device atomic` +/// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it +public func metalKernel( + name: String, inputNames: [String], outputNames: [String], + source: String, header: String = "", + ensureRowContiguous: Bool = true, + atomicOutputs: Bool = false +) -> MLXFastKernel { + MLXFastKernel( + name: name, inputNames: inputNames, outputNames: outputNames, + source: source, header: header, + ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs) +} diff --git a/Source/MLXLinalg/Cmlx+Util.swift b/Source/MLXLinalg/Cmlx+Util.swift index 938b8319..06d5024d 100644 --- a/Source/MLXLinalg/Cmlx+Util.swift +++ b/Source/MLXLinalg/Cmlx+Util.swift @@ -12,7 +12,7 @@ func mlx_free(_ ptr: OpaquePointer) { // return a +1 mlx_vector_array containing the given arrays func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { let result = mlx_vector_array_new()! - mlx_vector_array_add_arrays(result, arrays.map { $0.ctx }, arrays.count) + mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) return result } @@ -24,3 +24,24 @@ func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { return MLXArray(ctx) } } + +func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { + let a = mlx_tuple_array_array_get_0(tuple)! + let b = mlx_tuple_array_array_get_1(tuple)! + return (MLXArray(a), MLXArray(b)) +} + +func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { + let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! + defer { mlx_free(a) } + let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! + defer { mlx_free(b) } + return (mlx_vector_array_values(a), mlx_vector_array_values(b)) +} + +func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { + let a = mlx_tuple_array_array_array_get_0(tuple)! + let b = mlx_tuple_array_array_array_get_1(tuple)! + let c = mlx_tuple_array_array_array_get_2(tuple)! + return (MLXArray(a), MLXArray(b), MLXArray(c)) +} diff --git a/Source/MLXLinalg/Linalg.swift b/Source/MLXLinalg/Linalg.swift index 54b54ec2..0f9dd58f 100644 --- a/Source/MLXLinalg/Linalg.swift +++ b/Source/MLXLinalg/Linalg.swift @@ -149,15 +149,18 @@ public func norm( /// /// See ``norm(_:ord:axes:keepDims:stream:)-4dwwp`` public func norm( - _ array: MLXArray, ord: NormKind? = nil, keepDims: Bool = false, - stream: StreamOrDevice = .default + _ array: MLXArray, ord: NormKind? = nil, axis: IntOrArray? = nil, + keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { if let ord { let ord_str = mlx_string_new(ord.rawValue.cString(using: .utf8))! defer { mlx_free(ord_str) } - return MLXArray(mlx_linalg_norm_ord(array.ctx, ord_str, nil, 0, keepDims, stream.ctx)) + return MLXArray( + mlx_linalg_norm_ord( + array.ctx, ord_str, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx)) } else { - return MLXArray(mlx_linalg_norm(array.ctx, nil, 0, keepDims, stream.ctx)) + return MLXArray( + mlx_linalg_norm(array.ctx, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx)) } } @@ -165,9 +168,12 @@ public func norm( /// /// See ``norm(_:ord:axes:keepDims:stream:)-3t3ay`` public func norm( - _ array: MLXArray, ord: Double, keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, ord: Double, axis: IntOrArray? 
= nil, + keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_linalg_norm_p(array.ctx, ord, nil, 0, keepDims, stream.ctx)) + MLXArray( + mlx_linalg_norm_p( + array.ctx, ord, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx)) } /// The QR factorization of the input matrix. @@ -178,11 +184,10 @@ public func norm( /// /// - Returns: the `Q` and `R` matrices public func qr(_ array: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) { - let result_vector = mlx_linalg_qr(array.ctx, stream.ctx)! - defer { mlx_free(result_vector) } + let result_tuple = mlx_linalg_qr(array.ctx, stream.ctx)! + defer { mlx_free(result_tuple) } - let arrays = mlx_vector_array_values(result_vector) - return (arrays[0], arrays[1]) + return mlx_tuple_values(result_tuple) } /// The Singular Value Decomposition (SVD) of the input matrix. @@ -219,6 +224,24 @@ public func inv(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra MLXArray(mlx_linalg_inv(array.ctx, stream.ctx)) } +/// Compute the inverse of a triangular square matrix. +/// +/// This function supports arrays with at least 2 dimensions. When the input +/// has more than two dimensions, the inverse is computed for each matrix +/// in the last two dimensions of `array`. +/// +/// - Parameters: +/// - array: input array +/// - upper: true if the array is an upper triangular matrix +/// - stream: stream or device to evaluate on +/// - Returns: `ainv` such that `dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])` +public func triInv( + _ array: MLXArray, upper: Bool = false, + stream: StreamOrDevice = .default +) -> MLXArray { + MLXArray(mlx_linalg_tri_inv(array.ctx, upper, stream.ctx)) +} + /// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. /// /// This function supports arrays with at least 2 dimensions. When the input @@ -237,3 +260,39 @@ public func cholesky(_ array: MLXArray, upper: Bool = false, stream: StreamOrDev { MLXArray(mlx_linalg_cholesky(array.ctx, upper, stream.ctx)) } + +/// Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition. +/// +/// This function supports arrays with at least 2 dimensions. When the input +/// has more than two dimensions, the Cholesky decomposition is computed for each matrix +/// in the last two dimensions of `a`. +/// +/// If the input matrix is not a triangular matrix behavior is undefined. +/// +/// - Parameters: +/// - array: input array +/// - upper: if true return the upper triangular Cholesky factor, otherwise the lower triangular +/// Cholesky factor. +/// - stream: stream or device to evaluate on +public func choleskyInv(_ array: MLXArray, upper: Bool = false, stream: StreamOrDevice = .default) + -> MLXArray +{ + MLXArray(mlx_linalg_cholesky_inv(array.ctx, upper, stream.ctx)) +} + +/// Compute the cross product of two arrays along a specified axis. +/// +/// The cross product is defined for arrays with size 2 or 3 in the +/// specified axis. If the size is 2 then the third value is assumed +/// to be zero. 
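A brief sketch of the linear-algebra additions (`axis:` on `norm`, `triInv`, and `cross`); the values are arbitrary and the expected results appear only as comments:

```swift
import MLX
import MLXLinalg

// per-row L2 norms using the new axis parameter
let m = MLXArray(0 ..< 6, [2, 3]).asType(.float32)
let rowNorms = MLXLinalg.norm(m, axis: 1)

// inverse of a lower-triangular matrix
let l = MLXArray([Float(1), 0, 2, 3]).reshaped([2, 2])
let lInv = MLXLinalg.triInv(l, upper: false)

// cross product along the last axis
let c = MLXLinalg.cross(MLXArray([Float(1), 0, 0]), MLXArray([Float(0), 1, 0]))
// expected [0, 0, 1]
```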
+/// +/// - Parameters: +/// - a: input array +/// - b: input array +/// - axis: axis along which to compute the cross product +/// - stream: stream or device to evaluate on +public func cross(_ a: MLXArray, _ b: MLXArray, axis: Int = -1, stream: StreamOrDevice = .default) + -> MLXArray +{ + MLXArray(mlx_linalg_cross(a.ctx, b.ctx, axis.int32, stream.ctx)) +} diff --git a/Source/MLXNN/Activations.swift b/Source/MLXNN/Activations.swift index 52d18c7a..3dbf2b22 100644 --- a/Source/MLXNN/Activations.swift +++ b/Source/MLXNN/Activations.swift @@ -230,7 +230,7 @@ public func geluApproximate(_ x: MLXArray) -> MLXArray { /// This is: /// /// ```swift -/// x * sigmoid(1.773 * x) +/// x * sigmoid(1.702 * x) /// ``` /// /// ### See Also @@ -672,6 +672,8 @@ open class GELU: Module, UnaryLayer { case none /// See ``geluApproximate(_:)`` case precise + /// Alias for ``precise`` -- see ``geluApproximate(_:)`` + case tanh /// See ``geluFastApproximate(_:)`` case fast } @@ -687,7 +689,7 @@ open class GELU: Module, UnaryLayer { switch approximation { case .none: gelu(x) - case .precise: + case .precise, .tanh: geluApproximate(x) case .fast: geluFastApproximate(x) @@ -766,85 +768,85 @@ open class SELU: Module, UnaryLayer { // MARK: - Compiled Activation Functions -private let compiledLeakyRelu: (MLXArray, MLXArray) -> MLXArray = { +private let compiledLeakyRelu: @Sendable (MLXArray, MLXArray) -> MLXArray = { compile(shapeless: true) { x, negativeSlope in maximum(negativeSlope * x, x) } }() -private let compiledElu: (MLXArray, MLXArray) -> MLXArray = { +private let compiledElu: @Sendable (MLXArray, MLXArray) -> MLXArray = { compile(shapeless: true) { x, alpha in which(x .> 0, x, alpha * (MLX.exp(x) - 1)) } }() -private let compiledRelu6: (MLXArray) -> MLXArray = { +private let compiledRelu6: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in minimum(maximum(x, 0), 6) } }() -private let compiledSoftsign: (MLXArray) -> MLXArray = { +private let compiledSoftsign: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in x / (1 + abs(x)) } }() -private let compiledCelu: (MLXArray, MLXArray) -> MLXArray = { +private let compiledCelu: @Sendable (MLXArray, MLXArray) -> MLXArray = { compile(shapeless: true) { x, alpha in maximum(x, 0.0) + alpha * (exp(minimum(x, 0.0) / alpha) - 1) } }() -private let compiledSilu: (MLXArray) -> MLXArray = { +private let compiledSilu: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in x * sigmoid(x) } }() -private let compiledLogSigmoid: (MLXArray) -> MLXArray = { +private let compiledLogSigmoid: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in -softplus(-x) } }() -private let compiledGelu: (MLXArray) -> MLXArray = { +private let compiledGelu: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in x * (1 + erf(x / sqrt(2))) / 2 } }() -private let compiledGeluApproximate: (MLXArray) -> MLXArray = { +private let compiledGeluApproximate: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in 0.5 * x * (1 + tanh(sqrt(2 / Float.pi) * (x + 0.044715 * x ** 3))) } }() -private let compiledGeluFastApproximate: (MLXArray) -> MLXArray = { +private let compiledGeluFastApproximate: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in - x * sigmoid(1.773 * x) + x * sigmoid(1.702 * x) } }() -private let compiledSelu: (MLXArray) -> MLXArray = { +private let compiledSelu: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in elu(x, alpha: 1.67326) * 1.0507 } }() -private let 
compiledPrelu: (MLXArray, MLXArray) -> MLXArray = { +private let compiledPrelu: @Sendable (MLXArray, MLXArray) -> MLXArray = { compile(shapeless: true) { x, alpha in maximum(0, x) + alpha * minimum(0, x) } }() -private let compiledMish: (MLXArray) -> MLXArray = { +private let compiledMish: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in x * tanh(softplus(x)) } }() -private let compiledHardSwish: (MLXArray) -> MLXArray = { +private let compiledHardSwish: @Sendable (MLXArray) -> MLXArray = { compile(shapeless: true) { x in let maxXPlus3 = maximum(x + 3, 0) return x * minimum(maxXPlus3, 6) / 6 diff --git a/Source/MLXNN/Cache.swift b/Source/MLXNN/Cache.swift index ed658d16..f9c46f51 100644 --- a/Source/MLXNN/Cache.swift +++ b/Source/MLXNN/Cache.swift @@ -5,7 +5,7 @@ import Foundation /// Simple cache for holding prepared MLXArrays, etc. /// /// See ``RoPE`` -class Cache { +class Cache: @unchecked (Sendable) { let queue = DispatchQueue(label: "Cache") diff --git a/Source/MLXNN/ConvolutionTransposed.swift b/Source/MLXNN/ConvolutionTransposed.swift new file mode 100644 index 00000000..11ae260d --- /dev/null +++ b/Source/MLXNN/ConvolutionTransposed.swift @@ -0,0 +1,173 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXRandom + +/// Applies a 1-dimensional transposed convolution over the multi-channel input sequence. +/// +/// ### See Also +/// - ``ConvTransposed2d`` +/// - ``ConvTransposed3d`` +/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:bias:)`` +open class ConvTransposed1d: Module, UnaryLayer { + + public let weight: MLXArray + public let bias: MLXArray? + public let padding: Int + public let stride: Int + + /// Applies a 1-dimensional transposed convolution over the multi-channel input sequence. + /// + /// The channels are expected to be last i.e. the input shape should be `NLC` where: + /// + /// - `N` is the batch dimension + /// - `L` is the sequence length + /// - `C` is the number of input channels + /// + /// - Parameters: + /// - inputChannels: number of input channels (`C` from the discussion) + /// - outputChannels: number of output channels + /// - kernelSize: size of the convolution filters + /// - stride: stride when applying the filter + /// - padding: how many positions to 0-pad the input with + /// - bias: if `true` add a learnable bias to the output + public init( + inputChannels: Int, + outputChannels: Int, + kernelSize: Int, + stride: Int = 1, + padding: Int = 0, + bias: Bool = true + ) { + let scale = sqrt(1 / Float(inputChannels * kernelSize)) + + self.weight = uniform( + low: -scale, high: scale, [outputChannels, kernelSize, inputChannels]) + self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.padding = padding + self.stride = stride + } + + open func callAsFunction(_ x: MLXArray) -> MLXArray { + var y = convTransposed1d(x, weight, stride: stride, padding: padding) + if let bias { + y = y + bias + } + return y + } +} + +/// Applies a 2-dimensional transposed convolution over the multi-channel input image. +/// +/// ### See Also +/// - ``ConvTransposed1d`` +/// - ``ConvTransposed3d`` +/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:bias:)`` +open class ConvTransposed2d: Module, UnaryLayer { + + public let weight: MLXArray + public let bias: MLXArray? + public let padding: (Int, Int) + public let stride: (Int, Int) + + /// Applies a 2-dimensional transposed convolution over the multi-channel input image. + /// + /// The channels are expected to be last i.e. 
the input shape should be `NHWC` where: + /// + /// - `N` is the batch dimension + /// - `H` is the input image height + /// - `W` is the input image width + /// - `C` is the number of input channels + /// + /// - Parameters: + /// - inputChannels: number of input channels (`C` from the discussion) + /// - outputChannels: number of output channels + /// - kernelSize: size of the convolution filters + /// - stride: stride when applying the filter + /// - padding: how many positions to 0-pad the input with + /// - bias: if `true` add a learnable bias to the output + public init( + inputChannels: Int, + outputChannels: Int, + kernelSize: IntOrPair, + stride: IntOrPair = 1, + padding: IntOrPair = 0, + bias: Bool = true + ) { + let scale = sqrt(1 / Float(inputChannels * kernelSize.first * kernelSize.second)) + + self.weight = uniform( + low: -scale, high: scale, + [outputChannels, kernelSize.first, kernelSize.second, inputChannels]) + self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.padding = padding.values + self.stride = stride.values + } + + open func callAsFunction(_ x: MLXArray) -> MLXArray { + var y = convTransposed2d(x, weight, stride: .init(stride), padding: .init(padding)) + if let bias { + y = y + bias + } + return y + } +} + +/// Applies a 3-dimensional transposed convolution over the multi-channel input image. +/// +/// ### See Also +/// - ``ConvTransposed1d`` +/// - ``ConvTransposed2d`` +/// - ``init(inputChannels:outputChannels:kernelSize:stride:padding:bias:)`` +open class ConvTransposed3d: Module, UnaryLayer { + + public let weight: MLXArray + public let bias: MLXArray? + public let padding: (Int, Int, Int) + public let stride: (Int, Int, Int) + + /// Applies a 3-dimensional transposed convolution over the multi-channel input image. + /// + /// The channels are expected to be last i.e. the input shape should be `NDHWC` where: + /// + /// - `N` is the batch dimension + /// - `D` is the input image depth + /// - `H` is the input image height + /// - `W` is the input image width + /// - `C` is the number of input channels + /// + /// - Parameters: + /// - inputChannels: number of input channels (`C` from the discussion) + /// - outputChannels: number of output channels + /// - kernelSize: size of the convolution filters + /// - stride: stride when applying the filter + /// - padding: how many positions to 0-pad the input with + /// - bias: if `true` add a learnable bias to the output + public init( + inputChannels: Int, + outputChannels: Int, + kernelSize: IntOrTriple, + stride: IntOrTriple = 1, + padding: IntOrTriple = 0, + bias: Bool = true + ) { + let scale = sqrt( + 1 / Float(inputChannels * kernelSize.first * kernelSize.second * kernelSize.third)) + + self.weight = uniform( + low: -scale, high: scale, + [outputChannels, kernelSize.first, kernelSize.second, kernelSize.third, inputChannels]) + self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.padding = padding.values + self.stride = stride.values + } + + open func callAsFunction(_ x: MLXArray) -> MLXArray { + var y = convTransposed3d(x, weight, stride: .init(stride), padding: .init(padding)) + if let bias { + y = y + bias + } + return y + } +} diff --git a/Source/MLXNN/Documentation.docc/layers.md b/Source/MLXNN/Documentation.docc/layers.md index 8a94fa54..972ee383 100644 --- a/Source/MLXNN/Documentation.docc/layers.md +++ b/Source/MLXNN/Documentation.docc/layers.md @@ -16,6 +16,10 @@ These can be used with ``Sequential``. 
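To illustrate the transposed-convolution layers added to the layer list that follows, a minimal sketch; the shapes follow the NHWC convention described in the class documentation and the printed shape is an expectation:

```swift
import MLX
import MLXNN

// decoder-style upsampling block using the new transposed convolution layer
let upsample = ConvTransposed2d(
    inputChannels: 8, outputChannels: 4, kernelSize: 4, stride: 2, padding: 1)

let x = MLXArray.zeros([1, 16, 16, 8])  // NHWC
let y = upsample(x)
print(y.shape)  // expected [1, 32, 32, 4]
```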
- ``AvgPool2d`` - ``Conv1d`` - ``Conv2d`` +- ``Conv3d`` +- ``ConvTransposed1d`` +- ``ConvTransposed2d`` +- ``ConvTransposed3d`` - ``Dropout`` - ``Dropout2d`` - ``Dropout3d`` diff --git a/Source/MLXNN/Dropout.swift b/Source/MLXNN/Dropout.swift index 1ff07d72..0da70310 100644 --- a/Source/MLXNN/Dropout.swift +++ b/Source/MLXNN/Dropout.swift @@ -28,7 +28,7 @@ open class Dropout: Module, UnaryLayer { } let mask = bernoulli(p1, x.shape) - return (1 / p1) * mask * x + return (mask * x) * (1 / p1) } } @@ -79,7 +79,7 @@ open class Dropout2d: Module, UnaryLayer { maskShape[maskShape.endIndex - 3] = 1 let mask = bernoulli(p1, maskShape) - return (1 / p1) * mask * x + return (mask * x) * (1 / p1) } } @@ -127,6 +127,6 @@ open class Dropout3d: Module, UnaryLayer { maskShape[maskShape.endIndex - 4] = 1 let mask = bernoulli(p1, maskShape) - return (1 / p1) * mask * x + return (mask * x) * (1 / p1) } } diff --git a/Source/MLXNN/Losses.swift b/Source/MLXNN/Losses.swift index 68f5f9fb..563020b6 100644 --- a/Source/MLXNN/Losses.swift +++ b/Source/MLXNN/Losses.swift @@ -74,18 +74,38 @@ public func crossEntropy( /// Computes the binary cross entropy loss. /// +/// By default, this function takes the pre-sigmoid logits, which results in a faster +/// and more precise loss. For improved numerical stability when `withLogits` is true, +/// the loss calculation clips the input probabilities (in log-space) to a minimum value +/// of `-100`. +/// /// - Parameters: /// - logits: unnormalized predicted logits /// - targets: binary target values in {0, 1} +/// - weights: optional weights for each target +/// - withLogits: whether the `logits` parameter is logits or probabilities /// - reduction: reduction type /// - Returns: computed binary cross entropy loss /// /// ### See Also /// - public func binaryCrossEntropy( - logits: MLXArray, targets: MLXArray, reduction: LossReduction = .none + logits: MLXArray, targets: MLXArray, + weights: MLXArray? 
= nil, withLogits: Bool = true,
+    reduction: LossReduction = .none
 ) -> MLXArray {
-    let loss = logAddExp(0, logits) - targets * logits
+    var loss: MLXArray
+    if withLogits {
+        loss = logAddExp(0, logits) - targets * logits
+    } else {
+        let logInputsClip = clip(log(logits), min: -100)
+        let logInputsInverseClip = clip(log(1 - logits), min: -100)
+        loss = -(targets * logInputsClip + (1 - targets) * logInputsInverseClip)
+    }
+    if let weights {
+        precondition(weights.shape == loss.shape)
+        loss *= weights
+    }
     return reduction.reduce(loss: loss)
 }
diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift
index 8f58de68..f40b6267 100644
--- a/Source/MLXNN/Module.swift
+++ b/Source/MLXNN/Module.swift
@@ -887,7 +887,8 @@ extension Module {
     ///
     /// ### See Also
     /// -
-    static public let filterAll = { (module: Module, key: String, item: ModuleItem) in
+    static public let filterAll: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         true
     }

@@ -895,7 +896,8 @@ extension Module {
     ///
     /// ### See Also
     /// -
-    static public let filterValidChild = { (module: Module, key: String, item: ModuleItem) in
+    static public let filterValidChild: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary: true
         case .value(.module): true
@@ -910,7 +912,8 @@ extension Module {
     /// -
     /// - ``filterLocalParameters``
     /// - ``filterTrainableParameters``
-    static public let filterValidParameters = { (module: Module, key: String, item: ModuleItem) in
+    static public let filterValidParameters: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary: !key.hasPrefix("_")
         case .value(.parameters), .value(.module): !key.hasPrefix("_")
@@ -925,7 +928,8 @@ extension Module {
     /// -
     /// - ``filterValidParameters``
     /// - ``filterTrainableParameters``
-    static public let filterLocalParameters = { (module: Module, key: String, item: ModuleItem) in
+    static public let filterLocalParameters: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary: !key.hasPrefix("_")
         case .value(.parameters): !key.hasPrefix("_")
@@ -941,7 +945,7 @@ extension Module {
     /// - ``freeze(recursive:keys:strict:)``
     /// - ``filterValidParameters``
     /// - ``filterLocalParameters``
-    static public let filterTrainableParameters = {
+    static public let filterTrainableParameters: @Sendable (Module, String, ModuleItem) -> Bool = {
         (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary, .value(.parameters), .value(.module):
@@ -954,7 +958,8 @@ extension Module {
     ///
     /// ### See Also
     /// -
-    static public let filterOther = { (module: Module, key: String, item: ModuleItem) in
+    static public let filterOther: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .value(.other): true
         default: false
@@ -1062,7 +1067,8 @@ extension Module {
     /// ### See Also
     /// -
     /// - ``filterMap(filter:map:isLeaf:)``
-    static public let isLeafDefault = { (module: Module, key: String, item: ModuleItem) in
+    static public let isLeafDefault: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary, .none, .value(.module): false
         case .value(.parameters), .value(.other), .value(.none): true
@@ -1074,7 +1080,8 @@ extension Module {
     /// ### See Also
     /// -
     /// - ``filterMap(filter:map:isLeaf:)``
-    static public let isLeafModule = { (module: Module, key: String, item: ModuleItem) in
+    static public let isLeafModule: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary, .none: false
         case .value(.module): true
@@ -1087,7 +1094,8 @@ extension Module {
     /// ### See Also
     /// -
     /// - ``filterMap(filter:map:isLeaf:)``
-    static public let isLeafModuleNoChildren = { (module: Module, key: String, item: ModuleItem) in
+    static public let isLeafModuleNoChildren: @Sendable (Module, String, ModuleItem) -> Bool = {
+        (module: Module, key: String, item: ModuleItem) in
         switch item {
         case .array, .dictionary, .none: false
         case .value(.module(let m)): m.children().isEmpty
diff --git a/Source/MLXNN/Normalization.swift b/Source/MLXNN/Normalization.swift
index 1f7a2716..ce839d45 100644
--- a/Source/MLXNN/Normalization.swift
+++ b/Source/MLXNN/Normalization.swift
@@ -199,17 +199,17 @@ open class GroupNorm: Module, UnaryLayer {
         let batch = x.dim(0)
         let dims = x.dim(-1)
         let rest = x.shape.dropFirst().dropLast()
+        let groupSize = dims / groupCount

         // split into groups
-        var x = x.reshaped(batch, -1, groupCount, dims / groupCount)
-        x = x.transposed(0, 1, 3, 2).reshaped(batch, -1, groupCount)
+        var x = x.reshaped(batch, -1, groupCount, groupSize)
+        x = x.transposed(0, 2, 1, 3).reshaped(batch, groupCount, -1)

         // normalize
-        let means = mean(x, axis: 1, keepDims: true)
-        let variance = variance(x, axis: 1, keepDims: true)
-        x = (x - means) * rsqrt(variance + eps)
-        x = x.reshaped(batch, -1, dims / groupCount, groupCount)
-        x = x.transposed(0, 1, 3, 2).reshaped([batch] + rest + [dims])
+        x = MLXFast.layerNorm(x, weight: nil, bias: nil, eps: eps)
+
+        x = x.reshaped(batch, groupCount, -1, groupSize)
+        x = x.transposed(0, 2, 1, 3).reshaped([batch] + rest + [dims])

         return x
     }
diff --git a/Source/MLXNN/Recurrent.swift b/Source/MLXNN/Recurrent.swift
index 8c4d35a5..2e0fc2aa 100644
--- a/Source/MLXNN/Recurrent.swift
+++ b/Source/MLXNN/Recurrent.swift
@@ -39,7 +39,7 @@ open class RNN: Module {
         let scale = 1 / sqrt(Float(hiddenSize))
         self._wxh.wrappedValue = MLXRandom.uniform(
-            low: -scale, high: scale, [inputSize, hiddenSize])
+            low: -scale, high: scale, [hiddenSize, inputSize])
         self._whh.wrappedValue = MLXRandom.uniform(
             low: -scale, high: scale, [hiddenSize, hiddenSize])
         if bias {
@@ -53,16 +53,16 @@
         var x = x

         if let bias {
-            x = addMM(bias, x, wxh)
+            x = addMM(bias, x, wxh.T)
         } else {
-            x = matmul(x, wxh)
+            x = matmul(x, wxh.T)
         }

         var hidden: MLXArray! = hidden
         var allHidden = [MLXArray]()
         for index in 0 ..< x.dim(-2) {
             if hidden != nil {
-                hidden = x[.ellipsis, index, 0...] + matmul(hidden, whh)
+                hidden = addMM(x[.ellipsis, index, 0...], hidden, whh.T)
             } else {
                 hidden = x[.ellipsis, index, 0...]
             }
@@ -107,9 +107,9 @@ open class GRU: Module {
         self.hiddenSize = hiddenSize
         let scale = 1 / sqrt(Float(hiddenSize))
         self._wx.wrappedValue = MLXRandom.uniform(
-            low: -scale, high: scale, [inputSize, 3 * hiddenSize])
+            low: -scale, high: scale, [3 * hiddenSize, inputSize])
         self._wh.wrappedValue = MLXRandom.uniform(
-            low: -scale, high: scale, [hiddenSize, 3 * hiddenSize])
+            low: -scale, high: scale, [3 * hiddenSize, hiddenSize])
         if bias {
             self.b = MLXRandom.uniform(low: -scale, high: scale, [3 * hiddenSize])
             self.bhn = MLXRandom.uniform(low: -scale, high: scale, [hiddenSize])
@@ -123,9 +123,9 @@ open class GRU: Module {
         var x = x

         if let b {
-            x = addMM(b, x, wx)
+            x = addMM(b, x, wx.T)
         } else {
-            x = matmul(x, wx)
+            x = matmul(x, wx.T)
         }

         let x_rz = x[.ellipsis, .stride(to: -hiddenSize)]
@@ -138,7 +138,7 @@ open class GRU: Module {
             var rz = x_rz[.ellipsis, index, 0...]
             var hProj_n: MLXArray!

             if hidden != nil {
-                let hProj = matmul(hidden, wh)
+                let hProj = matmul(hidden, wh.T)
                 let hProj_rz = hProj[.ellipsis, .stride(to: -hiddenSize)]
                 hProj_n = hProj[.ellipsis, .stride(from: -hiddenSize)]
@@ -204,9 +204,9 @@ open class LSTM: Module {
     public init(inputSize: Int, hiddenSize: Int, bias: Bool = true) {
         let scale = 1 / sqrt(Float(hiddenSize))
         self._wx.wrappedValue = MLXRandom.uniform(
-            low: -scale, high: scale, [inputSize, 4 * hiddenSize])
+            low: -scale, high: scale, [4 * hiddenSize, inputSize])
         self._wh.wrappedValue = MLXRandom.uniform(
-            low: -scale, high: scale, [hiddenSize, 4 * hiddenSize])
+            low: -scale, high: scale, [4 * hiddenSize, hiddenSize])
         if bias {
             self.bias = MLXRandom.uniform(low: -scale, high: scale, [4 * hiddenSize])
         } else {
@@ -220,9 +220,9 @@
         var x = x

         if let bias {
-            x = addMM(bias, x, wx)
+            x = addMM(bias, x, wx.T)
         } else {
-            x = matmul(x, wx)
+            x = matmul(x, wx.T)
         }

         var hidden: MLXArray! = hidden
@@ -233,7 +233,7 @@
         for index in 0 ..< x.dim(-2) {
             var ifgo = x[.ellipsis, index, 0...]
             if hidden != nil {
-                ifgo = ifgo + matmul(hidden, wh)
+                ifgo = addMM(ifgo, hidden, wh.T)
             }

             let pieces = split(ifgo, parts: 4, axis: -1)
diff --git a/Source/MLXNN/Transformer.swift b/Source/MLXNN/Transformer.swift
index 8e2feea7..213b0ee4 100644
--- a/Source/MLXNN/Transformer.swift
+++ b/Source/MLXNN/Transformer.swift
@@ -158,9 +158,9 @@ class TransformerEncoderLayer: Module {
         } else {
             y = attention(x, keys: x, values: x, mask: mask)
             y = dropout1(y)
-            y = ln1(x + y)
+            x = ln1(x + y)

-            y = linear1(y)
+            y = linear1(x)
             y = activation(y)
             y = dropout2(y)
             y = linear2(y)
diff --git a/Source/MLXNN/Upsample.swift b/Source/MLXNN/Upsample.swift
index 7244de0a..f6948ac6 100644
--- a/Source/MLXNN/Upsample.swift
+++ b/Source/MLXNN/Upsample.swift
@@ -230,7 +230,7 @@ private func linearIndices(dimension: Int, scale: Float, alignCorners: Bool, dim
     ]
 }

-private let compiledGetWeight1: (MLXArray, MLXArray) -> MLXArray = {
+private let compiledGetWeight1: @Sendable (MLXArray, MLXArray) -> MLXArray = {
     // PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
     // and uses -0.75 for antialiasing=false (compatibility with OpenCV)

@@ -241,7 +241,7 @@ private let compiledGetWeight1: (MLXArray, MLXArray) -> MLXArray = {
     }
 }()

-private let compiledGetWeight2: (MLXArray, MLXArray) -> MLXArray = {
+private let compiledGetWeight2: @Sendable (MLXArray, MLXArray) -> MLXArray = {
     // PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
     // and uses -0.75 for antialiasing=false (compatibility with OpenCV)

diff --git a/Source/MLXRandom/Cmlx+Util.swift b/Source/MLXRandom/Cmlx+Util.swift
new file mode 100644
index 00000000..06d5024d
--- /dev/null
+++ b/Source/MLXRandom/Cmlx+Util.swift
@@ -0,0 +1,47 @@
+// Copyright © 2024 Apple Inc.
+
+import Cmlx
+import Foundation
+import MLX
+
+@inline(__always)
+func mlx_free(_ ptr: OpaquePointer) {
+    mlx_free(UnsafeMutableRawPointer(ptr))
+}
+
+// return a +1 mlx_vector_array containing the given arrays
+func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array {
+    let result = mlx_vector_array_new()!
+    mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count)
+    return result
+}
+
+func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] {
+    (0 ..< mlx_vector_array_size(vector_array))
+        .map { index in
+            // ctx is a +1 object, the array takes ownership
+            let ctx = mlx_vector_array_get(vector_array, index)!
+            return MLXArray(ctx)
+        }
+}
+
+func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) {
+    let a = mlx_tuple_array_array_get_0(tuple)!
+    let b = mlx_tuple_array_array_get_1(tuple)!
+    return (MLXArray(a), MLXArray(b))
+}
+
+func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) {
+    let a = mlx_tuple_vector_array_vector_array_get_0(tuple)!
+    defer { mlx_free(a) }
+    let b = mlx_tuple_vector_array_vector_array_get_1(tuple)!
+    defer { mlx_free(b) }
+    return (mlx_vector_array_values(a), mlx_vector_array_values(b))
+}
+
+func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) {
+    let a = mlx_tuple_array_array_array_get_0(tuple)!
+    let b = mlx_tuple_array_array_array_get_1(tuple)!
+    let c = mlx_tuple_array_array_array_get_2(tuple)!
+    return (MLXArray(a), MLXArray(b), MLXArray(c))
+}
diff --git a/Source/MLXRandom/Random.swift b/Source/MLXRandom/Random.swift
index 5ed48ac4..a49c782b 100644
--- a/Source/MLXRandom/Random.swift
+++ b/Source/MLXRandom/Random.swift
@@ -37,8 +37,9 @@ public func split(key: MLXArray, into num: Int, stream: StreamOrDevice = .defaul
 /// ### See Also
 /// - ``split(key:into:stream:)``
 public func split(key: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) {
-    let keys = MLXArray(mlx_random_split_equal_parts(key.ctx, 2, stream.ctx))
-    return (keys[0], keys[1])
+    let keys = mlx_random_split(key.ctx, stream.ctx)!
+    defer { mlx_free(keys) }
+    return mlx_tuple_values(keys)
 }

 /// Generate uniformly distributed random numbers with a `RangeExpression`.
@@ -561,3 +562,20 @@ public func categorical(
         mlx_random_categorical_num_samples(
             logits.ctx, axis.int32, count.int32, key.ctx, stream.ctx))
 }
+
+/// Sample numbers from a Laplace distribution.
+///
+/// - Parameters:
+///   - shape: shape of the output
+///   - dtype: type of the output
+///   - loc: mean of the distribution
+///   - scale: scale "b" of the distribution
+public func laplace(
+    _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1,
+    key: MLXArray? = nil, stream: StreamOrDevice = .default
+) -> MLXArray {
+    let key = key ?? globalState.next()
+    return MLXArray(
+        mlx_random_laplace(
+            shape.asInt32, shape.count, dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx))
+}
diff --git a/Source/MLXRandom/State.swift b/Source/MLXRandom/State.swift
index be4f6a42..70c1c105 100644
--- a/Source/MLXRandom/State.swift
+++ b/Source/MLXRandom/State.swift
@@ -3,9 +3,14 @@
 import Foundation
 import MLX

-/// Global random state
-public class RandomState: Updatable, Evaluatable {
+/// Global random state.
+///
+/// Note: although this type is thread-safe, the MLXArrays that it returns are not -- do not
+/// evaluate these values or expressions that depend on them across multiple threads
+/// simultaneously.
+public class RandomState: Updatable, Evaluatable, @unchecked Sendable {
     private var state: MLXArray
+    private let lock = NSLock()

     init() {
         let now = mach_approximate_time()
@@ -13,17 +18,23 @@
     }

     public func innerState() -> [MLXArray] {
-        [state]
+        lock.withLock {
+            [state]
+        }
     }

     public func next() -> MLXArray {
-        let (a, b) = split(key: state)
-        self.state = a
-        return b
+        lock.withLock {
+            let (a, b) = split(key: state)
+            self.state = a
+            return b
+        }
     }

     public func seed(_ seed: UInt64) {
-        state = key(seed)
+        lock.withLock {
+            state = key(seed)
+        }
     }
 }
diff --git a/Tests/MLXTests/MLXFastKernelTests.swift b/Tests/MLXTests/MLXFastKernelTests.swift
new file mode 100644
index 00000000..0e0d22bc
--- /dev/null
+++ b/Tests/MLXTests/MLXFastKernelTests.swift
@@ -0,0 +1,75 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+import MLX
+import MLXFast
+import MLXRandom
+import XCTest
+
+class MLXFastKernelTests: XCTestCase {
+
+    func testCustomKernelBasic() {
+        // based on def test_custom_kernel_basic
+        MLXRandom.seed(7)
+        let a = normal([2, 2])
+        let kernel = MLXFast.metalKernel(
+            name: "basic",
+            inputNames: ["a"],
+            outputNames: ["out1"],
+            source: """
+                uint elem = thread_position_in_grid.x;
+                out1[elem] = a[elem];
+                """)
+
+        let out = kernel(
+            inputs: [a],
+            grid: (4, 1, 1),
+            threadGroup: (2, 1, 1),
+            outputShapes: [[2, 2]],
+            outputDTypes: [.float32])
+
+        XCTAssertTrue(allClose(out[0], a).all().item())
+    }
+
+    func testCustomKernelArgs() {
+        // based on def test_custom_kernel_args
+        MLXRandom.seed(7)
+        let a = normal([3, 6])
+        let c = normal([2, 2]).asType(.bfloat16)
+
+        let kernel = MLXFast.metalKernel(
+            name: "arg_test",
+            inputNames: ["a", "b", "c", "d"],
+            outputNames: ["out1", "out2"],
+            source: """
+                uint elem = thread_position_in_grid.x;
+                T tmp = a[0];
+                if (e) {
+                    out1[elem] = a[1] + b[2] + c[3] + d + f;
+                } else {
+                    out1[elem] = 1;
+                }
+                out2[elem] = a[1] + b[2] + c[1] - d;
+                """)
+
+        let out = kernel(
+            inputs: [
+                a,
+                MLXArray([3, 4, 5]),
+                c,
+                7.3,
+            ],
+            template: [
+                ("e", true),
+                ("f", 3),
+                ("T", DType.float16),
+            ],
+            grid: (6, 1, 1),
+            threadGroup: (2, 1, 1),
+            outputShapes: [[2, 2], [3, 2]],
+            outputDTypes: [.float32, .int32])
+
+        XCTAssertTrue(allClose(out[0], full([2, 2], values: 14.0484)).all().item())
+        XCTAssertTrue(allClose(out[1], full([3, 2], values: -2)).all().item())
+    }
+}
diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh
index 6d5a0f66..bdf7941d 100755
--- a/tools/update-mlx.sh
+++ b/tools/update-mlx.sh
@@ -35,6 +35,7 @@ make \
   fft \
   gather \
   gemm \
+  gemv_masked \
   hadamard \
   quantized \
   reduce \
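For context, a minimal sketch of how two of the APIs introduced above fit together: the Laplace sampler added to Random.swift and the metal kernel builder exercised by MLXFastKernelTests. The names `noise` and `copy`, the shapes, and the kernel body are illustrative assumptions, not part of the patch:

import MLX
import MLXFast
import MLXRandom

// Laplace-distributed samples via the new laplace(_:dtype:loc:scale:key:stream:)
let noise = MLXRandom.laplace([2, 3], loc: 0, scale: 1)

// a trivial element-copy kernel built with MLXFast.metalKernel, called the same
// way as in testCustomKernelBasic above (one thread per element)
let copy = MLXFast.metalKernel(
    name: "copy",
    inputNames: ["a"],
    outputNames: ["out"],
    source: """
        uint elem = thread_position_in_grid.x;
        out[elem] = a[elem];
        """)
let out = copy(
    inputs: [noise],
    grid: (6, 1, 1),
    threadGroup: (2, 1, 1),
    outputShapes: [[2, 3]],
    outputDTypes: [.float32])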