diff --git a/CMakeLists.txt b/CMakeLists.txt index 132a494f..7516d65d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,17 +11,9 @@ endif() FetchContent_Declare( mlx-c GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG "v0.0.8") + GIT_TAG "v0.0.9") FetchContent_MakeAvailable(mlx-c) -# TEMPORARY OVERRIDE -- 0.0.8 depends on v0.14.0 but we need v0.15.2 for iOS / -# float16 issues -FetchContent_Declare( - mlx - GIT_REPOSITORY "https://github.com/ml-explore/mlx.git" - GIT_TAG v0.15.2) -FetchContent_MakeAvailable(mlx) - # swift-numerics set(swift_numerics_patch git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/swift-numerics.patch) diff --git a/Package.swift b/Package.swift index 09c15177..bdb1a740 100644 --- a/Package.swift +++ b/Package.swift @@ -105,6 +105,10 @@ let package = Package( "mlx/mlx/distributed/mpi", "mlx/mlx/distributed/ops.cpp", "mlx/mlx/distributed/primitives.cpp", + + // the mlx-c side of distributed + "include/mlx/c/distributed.cpp", + "include/mlx/c/distributed_group.cpp", ], cSettings: [ diff --git a/Source/Cmlx/mlx b/Source/Cmlx/mlx index d6383a1c..d0da7420 160000 --- a/Source/Cmlx/mlx +++ b/Source/Cmlx/mlx @@ -1 +1 @@ -Subproject commit d6383a1c6a2d663d6912bf20f3f57c727cd606d5 +Subproject commit d0da74209bbb5a1ca31e1f99d8f2d57750b918cb diff --git a/Source/Cmlx/mlx-c b/Source/Cmlx/mlx-c index 44910a61..82f176ac 160000 --- a/Source/Cmlx/mlx-c +++ b/Source/Cmlx/mlx-c @@ -1 +1 @@ -Subproject commit 44910a6117209ba3bb4db4398f6a9c0f15106d8d +Subproject commit 82f176ac84ea3217e6b5fd11b4104f0b0e5a8166 diff --git a/Source/Cmlx/mlx-generated/hadamard.cpp b/Source/Cmlx/mlx-generated/hadamard.cpp new file mode 100644 index 00000000..2bcf5275 --- /dev/null +++ b/Source/Cmlx/mlx-generated/hadamard.cpp @@ -0,0 +1,128 @@ +namespace mlx::core::metal { + +const char* hadamard() { + return R"preamble( + +using namespace metal; +template +METAL_FUNC void radix_func(thread float* x) { + constexpr short logR = __builtin_ctz(R); + short h = 1; +#pragma 
clang loop unroll(full) + for (short s = 0; s < logR; s++) { +#pragma clang loop unroll(full) + for (short i = 0; i < R / 2; i++) { + short k = i & (h - 1); + short j = ((i - k) << 1) + k; + float a = x[j]; + float b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + h <<= 1; + } +} +template +[[kernel]] void hadamard_n( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + constexpr short num_threads = N / max_radix; + constexpr short logN = __builtin_ctz(N); + constexpr short logR = __builtin_ctz(max_radix); + constexpr short num_steps = logN / logR; + constexpr short logFinal = logN % logR; + constexpr short final_radix = 1 << (logFinal); + int batch_idx = elem.x * N; + short i = elem.y; + threadgroup T buf[N]; +#pragma clang loop unroll(full) + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; +#pragma clang loop unroll(full) + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float x[max_radix]; + short h = 1; +#pragma clang loop unroll(full) + for (short s = 0; s < num_steps; s++) { + short k = i & (h - 1); + short j = ((i - k) << logR) + k; +#pragma clang loop unroll(full) + for (short r = 0; r < max_radix; r++) { + x[r] = buf[j + h * r]; + } + radix_func(x); +#pragma clang loop unroll(full) + for (short r = 0; r < max_radix; r++) { + buf[j + h * r] = x[r]; + } + h <<= logR; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + if (final_radix > 1) { +#pragma clang loop unroll(full) + for (int t = 0; t < max_radix / final_radix; t++) { + short index = i + t * num_threads; + short k = index & (h - 1); + short j = ((index - k) << logFinal) + k; +#pragma clang loop unroll(full) + for (short r = 0; r < final_radix; r++) { + x[r] = buf[j + h * r]; + } + radix_func(x); 
+#pragma clang loop unroll(full) + for (short r = 0; r < final_radix; r++) { + buf[j + h * r] = x[r]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } +#pragma clang loop unroll(full) + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; +#pragma clang loop unroll(full) + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = buf[index + r] * scale; + } + } +} +template +[[kernel]] void hadamard_m( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + int index = elem.x * grid.y + elem.y; + short i = index % (N / read_width); + int batch_idx = index / (N / read_width) * M * N; + float x[read_width][M]; +#pragma clang loop unroll(full) + for (short c = 0; c < M; c++) { +#pragma clang loop unroll(full) + for (short r = 0; r < read_width; r++) { + x[r][c] = in[batch_idx + c * N + i * read_width + r]; + } + } +#pragma clang loop unroll(full) + for (short r = 0; r < read_width; r++) { + hadamard_radix_m(x[r]); + } +#pragma clang loop unroll(full) + for (short c = 0; c < M; c++) { +#pragma clang loop unroll(full) + for (short r = 0; r < read_width; r++) { + out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale; + } + } +} +)preamble"; +} + +} // namespace mlx::core::metal diff --git a/Source/MLX/Documentation.docc/free-functions.md b/Source/MLX/Documentation.docc/free-functions.md index 46b20f43..56127860 100644 --- a/Source/MLX/Documentation.docc/free-functions.md +++ b/Source/MLX/Documentation.docc/free-functions.md @@ -222,3 +222,4 @@ operations as methods for convenience. 
- ``diag(_:k:stream:)`` - ``diagonal(_:offset:axis1:axis2:stream:)`` +- ``view(_:dtype:stream:)`` diff --git a/Source/MLX/MLXArray+Ops.swift b/Source/MLX/MLXArray+Ops.swift index 04898615..86454416 100644 --- a/Source/MLX/MLXArray+Ops.swift +++ b/Source/MLX/MLXArray+Ops.swift @@ -2658,4 +2658,20 @@ extension MLXArray { MLXArray(mlx_var_all(ctx, keepDims, ddof.int32, stream.ctx)) } + /// View the array as a different type. + /// + /// The output array will change along the last axis if the input array's + /// type and the output array's type do not have the same size. + /// + /// Note: the view op does not imply that the input and output arrays share + /// their underlying data. The view only guarantees that the binary + /// representation of each element (or group of elements) is the same. + /// + /// - Parameters: + /// - dtype: type to change to + /// - stream: stream or device to evaluate on + /// - Returns: array with the new type + public func view(dtype: DType, stream: StreamOrDevice = .default) -> MLXArray { + MLXArray(mlx_view(ctx, dtype.cmlxDtype, stream.ctx)) + } } diff --git a/Source/MLX/Ops+Array.swift b/Source/MLX/Ops+Array.swift index ce115343..d285b3a1 100644 --- a/Source/MLX/Ops+Array.swift +++ b/Source/MLX/Ops+Array.swift @@ -1715,3 +1715,22 @@ public func variance( ) -> MLXArray { MLXArray(mlx_var_all(array.ctx, keepDims, ddof.int32, stream.ctx)) } + +/// View the array as a different type. +/// +/// The output array will change along the last axis if the input array's +/// type and the output array's type do not have the same size. +/// +/// Note: the view op does not imply that the input and output arrays share +/// their underlying data. The view only guarantees that the binary +/// representation of each element (or group of elements) is the same.
+/// +/// - Parameters: +/// - dtype: type to change to +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - ``MLXArray/view(dtype:stream:)`` +public func view(_ array: MLXArray, dtype: DType, stream: StreamOrDevice = .default) -> MLXArray { + MLXArray(mlx_view(array.ctx, dtype.cmlxDtype, stream.ctx)) +} diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh index 585e13ba..6d5a0f66 100755 --- a/tools/update-mlx.sh +++ b/tools/update-mlx.sh @@ -18,7 +18,8 @@ cmake ../Source/Cmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0 # NOTE: # until mlx supports overriding the METAL_VERSION you will need to edit -# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION. +# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION +# to "3.0" # # Also Plugins/PrepareMetalShaders/main.swift kernels needs to be in sync. @@ -34,6 +35,7 @@ make \ fft \ gather \ gemm \ + hadamard \ quantized \ reduce \ reduce_utils \