diff --git a/CMakeLists.txt b/CMakeLists.txt index 4398e563..096e6e0b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ endif() FetchContent_Declare( mlx-c GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG "v0.0.10") + GIT_TAG "v0.1.0") FetchContent_MakeAvailable(mlx-c) # swift-numerics diff --git a/Package.swift b/Package.swift index edfd3347..981079c9 100644 --- a/Package.swift +++ b/Package.swift @@ -204,15 +204,6 @@ let package = Package( sources: ["Tutorial.swift"] ), - // ------ - // Internal Tools - - .executableTarget( - name: "GenerateGrad", - path: "Source/Tools", - sources: ["GenerateGrad.swift"] - ), - ], cxxLanguageStandard: .gnucxx17 ) diff --git a/Source/Cmlx/mlx b/Source/Cmlx/mlx index c21331d4..bb303c45 160000 --- a/Source/Cmlx/mlx +++ b/Source/Cmlx/mlx @@ -1 +1 @@ -Subproject commit c21331d47f65477795a0dcfa7223f115d4200871 +Subproject commit bb303c45a55d7147bc261e9aa8be218d49500d09 diff --git a/Source/Cmlx/mlx-c b/Source/Cmlx/mlx-c index a1b66041..e8889ddf 160000 --- a/Source/Cmlx/mlx-c +++ b/Source/Cmlx/mlx-c @@ -1 +1 @@ -Subproject commit a1b66041f4ffdf2bcb8000f6f2919e7be19e1523 +Subproject commit e8889ddf56b3bd734eb79d5e5b13586a39c01403 diff --git a/Source/Cmlx/mlx-generated/binary.cpp b/Source/Cmlx/mlx-generated/binary.cpp index d336683e..627045c2 100644 --- a/Source/Cmlx/mlx-generated/binary.cpp +++ b/Source/Cmlx/mlx-generated/binary.cpp @@ -72,11 +72,11 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); c[index] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, @@ -85,12 +85,12 @@ template constant const size_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, @@ -99,13 +99,17 @@ template constant const size_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void binary_g( device const T* a, device const T* b, @@ -116,13 +120,12 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto 
a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { c[out_idx++] = Op()(a[idx.x], b[idx.y]); idx.x += a_xstride; diff --git a/Source/Cmlx/mlx-generated/binary_two.cpp b/Source/Cmlx/mlx-generated/binary_two.cpp index 3b551c3c..9b70de4f 100644 --- a/Source/Cmlx/mlx-generated/binary_two.cpp +++ b/Source/Cmlx/mlx-generated/binary_two.cpp @@ -94,13 +94,13 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); auto out = Op()(a[a_idx], b[b_idx]); c[index] = out[0]; d[index] = out[1]; } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, @@ -110,14 +110,14 @@ template constant const size_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, @@ -127,15 +127,19 @@ template constant const size_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void binary_g( device const T* a, device const T* b, @@ -147,13 +151,12 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { auto out = Op()(a[idx.x], b[idx.y]); c[out_idx] = out[0]; diff --git a/Source/Cmlx/mlx-generated/compiled_preamble.cpp b/Source/Cmlx/mlx-generated/compiled_preamble.cpp index ba69c97a..0c1453de 100644 --- a/Source/Cmlx/mlx-generated/compiled_preamble.cpp +++ b/Source/Cmlx/mlx-generated/compiled_preamble.cpp @@ -593,6 +593,13 @@ struct Floor { } }; +struct Imag { + template + T operator()(T x) { + return 
std::imag(x); + } +}; + struct Log { template T operator()(T x) { @@ -635,6 +642,13 @@ struct Negative { } }; +struct Real { + template + T operator()(T x) { + return std::real(x); + } +}; + struct Round { template T operator()(T x) { diff --git a/Source/Cmlx/mlx-generated/conv.cpp b/Source/Cmlx/mlx-generated/conv.cpp index 13593557..c3a3e38e 100644 --- a/Source/Cmlx/mlx-generated/conv.cpp +++ b/Source/Cmlx/mlx-generated/conv.cpp @@ -693,10 +693,396 @@ struct BlockSwizzle { }; } } +#pragma METAL internals : enable +namespace metal { +template +struct is_empty : metal::bool_constant<__is_empty(T)> {}; +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; +template +struct is_static : metal::bool_constant>::value> {}; +template +struct pointer_element {}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +using pointer_element_t = typename pointer_element>::type; +} +#pragma METAL internals : disable + +#pragma METAL internals : enable +namespace mlx { +namespace steel { +template +struct integral_constant { + static constexpr constant T value = v; + using value_type = T; + using type = integral_constant; + METAL_FUNC constexpr operator value_type() const noexcept { + return value; + } +}; +template +using bool_constant = integral_constant; +using true_type = bool_constant; +using false_type = bool_constant; +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral> + : bool_constant::value> {}; +template +constexpr constant bool is_integral_v = is_integral::value; +template +using Int = integral_constant; +template METAL_FUNC constexpr auto operator+( integral_constant, integral_constant) { constexpr auto res = tv + uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator-( integral_constant, integral_constant) { constexpr auto res = tv - uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator*( integral_constant, integral_constant) { constexpr auto res = tv * uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator/( integral_constant, integral_constant) { constexpr auto res = tv / uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator==( integral_constant, integral_constant) { constexpr auto res = tv == uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator!=( integral_constant, integral_constant) { constexpr auto res = tv != uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator<( integral_constant, integral_constant) { constexpr auto res = tv < uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator>( integral_constant, integral_constant) { constexpr auto res = tv > uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator<=( integral_constant, integral_constant) { constexpr auto res = tv <= uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator>=( integral_constant, integral_constant) { constexpr auto res = tv >= uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator&&( integral_constant, integral_constant) { constexpr auto res = tv && uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator||( integral_constant, 
integral_constant) { constexpr auto res = tv || uv; return integral_constant{}; }; +template +METAL_FUNC constexpr T sum(T x) { + return x; +} +template +METAL_FUNC constexpr auto sum(T x, Us... us) { + return x + sum(us...); +} +} +} +#pragma METAL internals : disable using namespace metal; namespace mlx { namespace steel { +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; +template +struct BaseMMAFrag { + static constant constexpr const int kFragRows = 8; + static constant constexpr const int kFragCols = 8; + static constant constexpr const int kElemsPerFrag = (kFragRows * kFragCols) / 32; + static constant constexpr const int kElemRows = 1; + static constant constexpr const int kElemCols = 2; + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread 
frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + mma(D_mat, A_mat, B_mat, C_mat); + D = reinterpret_cast(D_mat.thread_elements()); + } + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + static constant constexpr const int kFragRows = MMAFrag_t::kFragRows; + static constant constexpr const int kFragCols = MMAFrag_t::kFragCols; + static constant constexpr const int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + static constant constexpr const int kTileRows = kTileRows_; + static constant constexpr const int kTileCols = kTileCols_; + static constant constexpr const int kRows = kTileRows * kFragRows; + static constant constexpr const int kCols = kTileCols * kFragCols; + static constant constexpr const int kNumFrags = kTileRows * kTileCols; + static constant constexpr const int kElemsPerTile = kNumFrags * kElemsPerFrag; + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + frag_type val_frags[kNumFrags] = {frag_type(0)}; + METAL_FUNC MMATile() thread {} + METAL_FUNC constexpr void clear() { +#pragma clang loop unroll(full) + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; +#pragma clang loop unroll(full) + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + template + METAL_FUNC void load(const threadgroup U* src) { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + template + METAL_FUNC void store(threadgroup U* dst) const { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + template + METAL_FUNC void load(const device U* src, const int ld) { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + template + METAL_FUNC void store(device U* dst, const int ld) const { +#pragma clang loop unroll(full) + for (short i = 0; i < 
kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { +#pragma clang loop unroll(full) + for (int i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { +#pragma clang loop unroll(full) + for (int i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { +#pragma clang loop unroll(full) + for (short m = 0; m < M; ++m) { +#pragma clang loop unroll(full) + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? (N - 1 - n) : n; +#pragma clang loop unroll(full) + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} template < typename T, typename U, @@ -712,125 +1098,78 @@ template < typename AccumType = float, typename Epilogue = TransformNone> struct BlockMMA { - static constant constexpr const short TM_stride = 8 * WM; - static constant constexpr const short TN_stride = 8 * WN; - static constant constexpr const short TM = BM / TM_stride; - static constant constexpr const short TN = BN / TN_stride; - static constant constexpr const short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - static constant constexpr const short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - static constant constexpr const short jump_a = {transpose_a ? lda_tgp : 1}; - static constant constexpr const short jump_b = {transpose_b ? ldb_tgp : 1}; - static constant constexpr const short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - static constant constexpr const short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - const short tm; - const short tn; + static constant constexpr const short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + static constant constexpr const short TM_stride = kFragSize * WM; + static constant constexpr const short TN_stride = kFragSize * WN; + static constant constexpr const short TM = BM / (kFragSize * WM); + static constant constexpr const short TN = BN / (kFragSize * WN); + static constant constexpr const short A_str_m = transpose_a ? 1 : lda_tgp; + static constant constexpr const short A_str_k = transpose_a ? lda_tgp : 1; + static constant constexpr const short B_str_k = transpose_b ? 1 : ldb_tgp; + static constant constexpr const short B_str_n = transpose_b ? 
ldb_tgp : 1; + static constant constexpr const short tile_stride_a = kFragSize * A_str_k; + static constant constexpr const short tile_stride_b = kFragSize * B_str_k; + MMATile Atile; + MMATile Btile; + MMATile Ctile; short sm; short sn; short As_offset; short Bs_offset; METAL_FUNC BlockMMA( ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - short qid = simd_lane_id / 4; - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; + sm += tm; + sn += tn; } METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { As += As_offset; Bs += Bs_offset; #pragma clang loop unroll(full) - for (short kk = 0; kk < BK; kk += 8) { + for (short kk = 0; kk < BK; kk += kFragSize) { simdgroup_barrier(mem_flags::mem_none); -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); - } + Atile.template load(As); simdgroup_barrier(mem_flags::mem_none); -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); - } + Btile.template load(Bs); simdgroup_barrier(mem_flags::mem_none); -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? 
(TN - 1 - j) : j; - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); - } - } + tile_matmad(Ctile, Atile, Btile, Ctile); As += tile_stride_a; Bs += tile_stride_b; } } - METAL_FUNC void store_result(device U* D, const int ldd) const { - D += (sm + tm) * ldd + tn + sn; -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { + METAL_FUNC void store_result(device U* D, const int ldd) { #pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; - D[offset] = outs[0]; - D[offset + 1] = outs[1]; - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } + D += sm * ldd + sn; + Ctile.template store(D, ldd); } METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { - D += (sm + tm) * ldd + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { #pragma clang loop unroll(full) - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { -#pragma clang loop unroll(full) - for (int j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - if (j * TN_stride < dst_tile_dims.x) { - D[offset] = Epilogue::apply(accum[0]); - } - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset + 1] = Epilogue::apply(accum[1]); - } - } - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + Ctile.template store_safe(D, ldd, dst_tile_dims); } template METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { #pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - thread auto& accum = results[i * TN + j].thread_elements(); - accum[0] = epilogue_op.apply(accum[0]); - accum[1] = epilogue_op.apply(accum[1]); - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } template @@ -839,15 +1178,17 @@ struct BlockMMA { const int ldc, const int fdc, thread const BinaryEpilogue& epilogue_op) { - C += (sm + tm) * ldc + (tn + sn) * fdc; + C += (sm)*ldc + (sn)*fdc; #pragma clang loop unroll(full) for (short i = 0; i < TM; i++) { #pragma clang loop unroll(full) for (short j = 0; j < TN; j++) { - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - accum[0] = epilogue_op.apply(accum[0], C[offset_c]); - accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); +#pragma clang loop unroll(full) + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -858,25 +1199,28 @@ struct BlockMMA { const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue& epilogue_op) { - C += (sm + tm) * ldc + (tn + sn) * fdc; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + 
dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; #pragma clang loop unroll(full) for (short i = 0; i < TM; i++) { #pragma clang loop unroll(full) for (short j = 0; j < TN; j++) { - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - U c_elems[2] = {0}; - if ((j * TN_stride + 1) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; - c_elems[1] = C[offset_c + fdc]; - } else if ((j * TN_stride) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + U c_elems[kelems] = {0}; +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); } - accum[0] = epilogue_op.apply(accum[0], c_elems[0]); - accum[1] = epilogue_op.apply(accum[1], c_elems[1]); } } } @@ -887,20 +1231,20 @@ struct BlockMMA { const int ldc, const int fdc, thread const Epilogue& epilogue_op) const { - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; #pragma clang loop unroll(full) for (short i = 0; i < TM; i++) { #pragma clang loop unroll(full) for (short j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -912,24 +1256,26 @@ struct BlockMMA { const int fdc, short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; #pragma clang loop unroll(full) for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { #pragma clang loop unroll(full) for (int j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } diff --git a/Source/Cmlx/mlx-generated/copy.cpp b/Source/Cmlx/mlx-generated/copy.cpp index 69ab0247..117c5c4a 100644 --- a/Source/Cmlx/mlx-generated/copy.cpp +++ b/Source/Cmlx/mlx-generated/copy.cpp @@ -40,33 
+40,33 @@ template device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); + auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; + auto src_idx = elem_to_loc_2(index, src_strides); + IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + auto src_idx = elem_to_loc_3(index, src_strides); + IdxT dst_idx = + index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -75,17 +75,16 @@ template constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc( + auto src_idx = elem_to_loc( {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); if (N == 1) { - int64_t dst_idx = - index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z); + IdxT dst_idx = + index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); return; } auto xshape = src_shape[ndim - 1]; - int64_t dst_idx = - N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z); + IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); auto src_xstride = src_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[dst_idx + i] = static_cast(src[src_idx]); @@ -99,33 +98,33 @@ template constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { - auto src_idx = 
elem_to_loc_3(index, src_strides); - auto dst_idx = elem_to_loc_3(index, dst_strides); + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -134,7 +133,7 @@ template constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, src_shape, src_strides, @@ -144,8 +143,8 @@ template dst[idx.y] = static_cast(src[idx.x]); return; } - auto src_xstride = src_strides[ndim - 1]; - auto dst_xstride = dst_strides[ndim - 1]; + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; auto xshape = src_shape[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[idx.y] = static_cast(src[idx.x]); diff --git a/Source/Cmlx/mlx-generated/gather.cpp b/Source/Cmlx/mlx-generated/gather.cpp index 954baf7c..57103f94 100644 --- a/Source/Cmlx/mlx-generated/gather.cpp +++ b/Source/Cmlx/mlx-generated/gather.cpp @@ -7,10 +7,11 @@ struct Indices { const array buffers; const constant int* shapes; const constant size_t* strides; + const constant bool* row_contiguous; const int ndim; }; template -METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { if (is_unsigned_v) { return idx; } else { @@ -18,7 +19,7 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { } } -template +template METAL_FUNC void gather_impl( const device T* src [[buffer(0)]], device T* out [[buffer(1)]], @@ -30,32 +31,34 @@ METAL_FUNC void gather_impl( const thread Indices& indices, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - size_t src_idx = 0; + LocT src_idx = 0; for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; + LocT idx_loc; if (IDX_NDIM == 0) { idx_loc = 0; } else if (IDX_NDIM == 1) { - idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); } else { - idx_loc = index.x * indices.strides[indices.ndim * i]; - idx_loc += elem_to_loc( - index.y, - &indices.shapes[indices.ndim * i + 1], - &indices.strides[indices.ndim * i + 1], - indices.ndim - 1); + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); + idx_loc += indices.row_contiguous[i] + ? 
index.y + : elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; + src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); } - auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); - size_t out_idx = index.z; + auto src_offset = + elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + LocT out_idx = index.z; if (IDX_NDIM == 1) { - out_idx += static_cast(grid_dim.z) * index.x; + out_idx += static_cast(grid_dim.z) * index.x; } else if (IDX_NDIM >= 2) { - out_idx += - grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); } out[out_idx] = src[src_offset + src_idx]; } diff --git a/Source/Cmlx/mlx-generated/gemm.cpp b/Source/Cmlx/mlx-generated/gemm.cpp index 1152a759..8c490e38 100644 --- a/Source/Cmlx/mlx-generated/gemm.cpp +++ b/Source/Cmlx/mlx-generated/gemm.cpp @@ -182,10 +182,396 @@ struct BlockSwizzle { }; } } +#pragma METAL internals : enable +namespace metal { +template +struct is_empty : metal::bool_constant<__is_empty(T)> {}; +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; +template +struct is_static : metal::bool_constant>::value> {}; +template +struct pointer_element {}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +using pointer_element_t = typename pointer_element>::type; +} +#pragma METAL internals : disable + +#pragma METAL internals : enable +namespace mlx { +namespace steel { +template +struct integral_constant { + static constexpr constant T value = v; + using value_type = T; + using type = integral_constant; + METAL_FUNC constexpr operator value_type() const noexcept { + return value; + } +}; +template +using bool_constant = integral_constant; +using true_type = bool_constant; +using false_type = bool_constant; +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral> + : bool_constant::value> {}; +template +constexpr constant bool is_integral_v = is_integral::value; +template +using Int = integral_constant; +template METAL_FUNC constexpr auto operator+( integral_constant, integral_constant) { constexpr auto res = tv + uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator-( integral_constant, integral_constant) { constexpr auto res = tv - uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator*( integral_constant, integral_constant) { constexpr auto res = tv * uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator/( integral_constant, integral_constant) { constexpr auto res = tv / uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator==( integral_constant, integral_constant) { constexpr auto res = tv == uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator!=( integral_constant, integral_constant) { constexpr auto res = tv != uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator<( integral_constant, integral_constant) { constexpr auto res = tv < uv; return integral_constant{}; }; +template 
METAL_FUNC constexpr auto operator>( integral_constant, integral_constant) { constexpr auto res = tv > uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator<=( integral_constant, integral_constant) { constexpr auto res = tv <= uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator>=( integral_constant, integral_constant) { constexpr auto res = tv >= uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator&&( integral_constant, integral_constant) { constexpr auto res = tv && uv; return integral_constant{}; }; +template METAL_FUNC constexpr auto operator||( integral_constant, integral_constant) { constexpr auto res = tv || uv; return integral_constant{}; }; +template +METAL_FUNC constexpr T sum(T x) { + return x; +} +template +METAL_FUNC constexpr auto sum(T x, Us... us) { + return x + sum(us...); +} +} +} +#pragma METAL internals : disable using namespace metal; namespace mlx { namespace steel { +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; +template +struct BaseMMAFrag { + static constant constexpr const int kFragRows = 8; + static constant constexpr const int kFragCols = 8; + static constant constexpr const int kElemsPerFrag = (kFragRows * kFragCols) / 32; + static constant constexpr const int kElemRows = 1; + static constant constexpr const int kElemCols = 2; + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, 
+ typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; +#pragma clang loop unroll(full) + for (short i = 0; i < kElemRows; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + mma(D_mat, A_mat, B_mat, C_mat); + D = reinterpret_cast(D_mat.thread_elements()); + } + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + static constant constexpr const int kFragRows = MMAFrag_t::kFragRows; + static constant constexpr const int kFragCols = MMAFrag_t::kFragCols; + static constant constexpr const int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + static constant constexpr const int kTileRows = kTileRows_; + static constant constexpr const int kTileCols = kTileCols_; + static constant constexpr const int kRows = kTileRows * kFragRows; + static constant constexpr const int kCols = kTileCols * kFragCols; + static constant constexpr const int kNumFrags = kTileRows * kTileCols; + static constant constexpr const int kElemsPerTile = kNumFrags * kElemsPerFrag; + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + frag_type val_frags[kNumFrags] = {frag_type(0)}; + METAL_FUNC MMATile() thread {} + METAL_FUNC constexpr void clear() { +#pragma clang loop unroll(full) + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; +#pragma clang loop unroll(full) + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + template + METAL_FUNC void load(const threadgroup U* src) { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + template + METAL_FUNC void store(threadgroup U* dst) const { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 
0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + template + METAL_FUNC void load(const device U* src, const int ld) { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + template + METAL_FUNC void store(device U* dst, const int ld) const { +#pragma clang loop unroll(full) + for (short i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { +#pragma clang loop unroll(full) + for (int i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { +#pragma clang loop unroll(full) + for (int i = 0; i < kTileRows; ++i) { +#pragma clang loop unroll(full) + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { +#pragma clang loop unroll(full) + for (short m = 0; m < M; ++m) { +#pragma clang loop unroll(full) + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? (N - 1 - n) : n; +#pragma clang loop unroll(full) + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} template < typename T, typename U, @@ -201,125 +587,78 @@ template < typename AccumType = float, typename Epilogue = TransformNone> struct BlockMMA { - static constant constexpr const short TM_stride = 8 * WM; - static constant constexpr const short TN_stride = 8 * WN; - static constant constexpr const short TM = BM / TM_stride; - static constant constexpr const short TN = BN / TN_stride; - static constant constexpr const short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - static constant constexpr const short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - static constant constexpr const short jump_a = {transpose_a ? lda_tgp : 1}; - static constant constexpr const short jump_b = {transpose_b ? ldb_tgp : 1}; - static constant constexpr const short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - static constant constexpr const short tile_stride_b = {transpose_b ? 
8 : 8 * ldb_tgp}; - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - const short tm; - const short tn; + static constant constexpr const short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + static constant constexpr const short TM_stride = kFragSize * WM; + static constant constexpr const short TN_stride = kFragSize * WN; + static constant constexpr const short TM = BM / (kFragSize * WM); + static constant constexpr const short TN = BN / (kFragSize * WN); + static constant constexpr const short A_str_m = transpose_a ? 1 : lda_tgp; + static constant constexpr const short A_str_k = transpose_a ? lda_tgp : 1; + static constant constexpr const short B_str_k = transpose_b ? 1 : ldb_tgp; + static constant constexpr const short B_str_n = transpose_b ? ldb_tgp : 1; + static constant constexpr const short tile_stride_a = kFragSize * A_str_k; + static constant constexpr const short tile_stride_b = kFragSize * B_str_k; + MMATile Atile; + MMATile Btile; + MMATile Ctile; short sm; short sn; short As_offset; short Bs_offset; METAL_FUNC BlockMMA( ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - short qid = simd_lane_id / 4; - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; + sm += tm; + sn += tn; } METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { As += As_offset; Bs += Bs_offset; #pragma clang loop unroll(full) - for (short kk = 0; kk < BK; kk += 8) { + for (short kk = 0; kk < BK; kk += kFragSize) { simdgroup_barrier(mem_flags::mem_none); -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); - } + Atile.template load(As); simdgroup_barrier(mem_flags::mem_none); -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); - } + Btile.template load(Bs); simdgroup_barrier(mem_flags::mem_none); -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? 
(TN - 1 - j) : j; - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); - } - } + tile_matmad(Ctile, Atile, Btile, Ctile); As += tile_stride_a; Bs += tile_stride_b; } } - METAL_FUNC void store_result(device U* D, const int ldd) const { - D += (sm + tm) * ldd + tn + sn; -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { + METAL_FUNC void store_result(device U* D, const int ldd) { #pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; - D[offset] = outs[0]; - D[offset + 1] = outs[1]; - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } + D += sm * ldd + sn; + Ctile.template store(D, ldd); } METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { - D += (sm + tm) * ldd + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { #pragma clang loop unroll(full) - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { -#pragma clang loop unroll(full) - for (int j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - if (j * TN_stride < dst_tile_dims.x) { - D[offset] = Epilogue::apply(accum[0]); - } - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset + 1] = Epilogue::apply(accum[1]); - } - } - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); } + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + Ctile.template store_safe(D, ldd, dst_tile_dims); } template METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { #pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - thread auto& accum = results[i * TN + j].thread_elements(); - accum[0] = epilogue_op.apply(accum[0]); - accum[1] = epilogue_op.apply(accum[1]); - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } template @@ -328,15 +667,17 @@ struct BlockMMA { const int ldc, const int fdc, thread const BinaryEpilogue& epilogue_op) { - C += (sm + tm) * ldc + (tn + sn) * fdc; + C += (sm)*ldc + (sn)*fdc; #pragma clang loop unroll(full) for (short i = 0; i < TM; i++) { #pragma clang loop unroll(full) for (short j = 0; j < TN; j++) { - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - accum[0] = epilogue_op.apply(accum[0], C[offset_c]); - accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); +#pragma clang loop unroll(full) + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -347,25 +688,28 @@ struct BlockMMA { const int fdc, short2 dst_tile_dims, thread const BinaryEpilogue& epilogue_op) { - C += (sm + tm) * ldc + (tn + sn) * fdc; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + 
dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; #pragma clang loop unroll(full) for (short i = 0; i < TM; i++) { #pragma clang loop unroll(full) for (short j = 0; j < TN; j++) { - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - U c_elems[2] = {0}; - if ((j * TN_stride + 1) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; - c_elems[1] = C[offset_c + fdc]; - } else if ((j * TN_stride) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + U c_elems[kelems] = {0}; +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); } - accum[0] = epilogue_op.apply(accum[0], c_elems[0]); - accum[1] = epilogue_op.apply(accum[1], c_elems[1]); } } } @@ -376,20 +720,20 @@ struct BlockMMA { const int ldc, const int fdc, thread const Epilogue& epilogue_op) const { - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; #pragma clang loop unroll(full) for (short i = 0; i < TM; i++) { #pragma clang loop unroll(full) for (short j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -401,24 +745,26 @@ struct BlockMMA { const int fdc, short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; #pragma clang loop unroll(full) for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { #pragma clang loop unroll(full) for (int j = 0; j < TN; j++) { - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } diff --git a/Source/Cmlx/mlx-generated/metal/bf16.h b/Source/Cmlx/mlx-generated/metal/bf16.h index e8c22afb..f5d48670 100644 --- a/Source/Cmlx/mlx-generated/metal/bf16.h +++ b/Source/Cmlx/mlx-generated/metal/bf16.h 
@@ -6,12 +6,6 @@ using namespace metal; -#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) - -typedef bfloat bfloat16_t; - -#else - ///////////////////////////////////////////////////////////////////////////// // Helpers ///////////////////////////////////////////////////////////////////////////// @@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) { } // namespace metal #pragma METAL internals : disable +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return x.bits_; +} -#endif - -#include "bf16_math.h" +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()); +} diff --git a/Source/Cmlx/mlx-generated/metal/bf16_math.h b/Source/Cmlx/mlx-generated/metal/bf16_math.h index 79e1ef15..0643fb3e 100644 --- a/Source/Cmlx/mlx-generated/metal/bf16_math.h +++ b/Source/Cmlx/mlx-generated/metal/bf16_math.h @@ -2,8 +2,6 @@ #pragma once -#include "bf16.h" - /////////////////////////////////////////////////////////////////////////////// // Metal math for bfloat16 /////////////////////////////////////////////////////////////////////////////// @@ -369,18 +367,6 @@ instantiate_metal_math_funcs( return static_cast(__metal_simd_xor(static_cast(data))); \ } -#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310) - -#define bfloat16_to_uint16(x) as_type(x) -#define uint16_to_bfloat16(x) as_type(x) - -#else - -#define bfloat16_to_uint16(x) x.bits_ -#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()) - -#endif - namespace metal { instantiate_metal_simd_comm_funcs( diff --git a/Source/Cmlx/mlx-generated/metal/binary.h b/Source/Cmlx/mlx-generated/metal/binary.h index d64488e9..4b260bc3 100644 --- a/Source/Cmlx/mlx-generated/metal/binary.h +++ b/Source/Cmlx/mlx-generated/metal/binary.h @@ -77,12 +77,12 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); c[index] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, @@ -91,13 +91,13 @@ template constant const size_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, @@ -106,14 +106,18 @@ template constant const size_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void binary_g( device const T* a, device const T* 
b, @@ -124,13 +128,12 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { c[out_idx++] = Op()(a[idx.x], b[idx.y]); idx.x += a_xstride; diff --git a/Source/Cmlx/mlx-generated/metal/binary_two.h b/Source/Cmlx/mlx-generated/metal/binary_two.h index a4a3130b..6057dd41 100644 --- a/Source/Cmlx/mlx-generated/metal/binary_two.h +++ b/Source/Cmlx/mlx-generated/metal/binary_two.h @@ -99,14 +99,14 @@ template constant const size_t& a_stride, constant const size_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); auto out = Op()(a[a_idx], b[b_idx]); c[index] = out[0]; d[index] = out[1]; } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, @@ -116,15 +116,15 @@ template constant const size_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, @@ -134,16 +134,20 @@ template constant const size_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void binary_g( device const T* a, device const T* b, @@ -155,13 +159,12 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { auto 
out = Op()(a[idx.x], b[idx.y]); c[out_idx] = out[0]; diff --git a/Source/Cmlx/mlx-generated/metal/conv.metal b/Source/Cmlx/mlx-generated/metal/conv.metal index c03d09c3..12f79eea 100644 --- a/Source/Cmlx/mlx-generated/metal/conv.metal +++ b/Source/Cmlx/mlx-generated/metal/conv.metal @@ -4,8 +4,8 @@ #include #include -#include "bf16.h" #include "steel/conv/params.h" +#include "utils.h" #define MLX_MTL_CONST static constant constexpr const diff --git a/Source/Cmlx/mlx-generated/metal/copy.h b/Source/Cmlx/mlx-generated/metal/copy.h index 914aebfd..2113c825 100644 --- a/Source/Cmlx/mlx-generated/metal/copy.h +++ b/Source/Cmlx/mlx-generated/metal/copy.h @@ -42,36 +42,36 @@ template device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); + auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; + auto src_idx = elem_to_loc_2(index, src_strides); + IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - int64_t dst_idx = - index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); + auto src_idx = elem_to_loc_3(index, src_strides); + IdxT dst_idx = + index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_g( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -80,17 +80,16 @@ template constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc( + auto src_idx = elem_to_loc( {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); if (N == 1) { - int64_t dst_idx = - index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z); + IdxT dst_idx = + index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); return; } auto xshape = src_shape[ndim - 1]; - int64_t dst_idx = - N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z); + IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); auto src_xstride = src_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[dst_idx + i] = static_cast(src[src_idx]); @@ -105,36 +104,36 @@ template constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg_nd2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant 
const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg_nd3( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - auto dst_idx = elem_to_loc_3(index, dst_strides); + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } -template +template [[kernel]] void copy_gg( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], @@ -143,7 +142,7 @@ template constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, src_shape, src_strides, @@ -153,8 +152,8 @@ template dst[idx.y] = static_cast(src[idx.x]); return; } - auto src_xstride = src_strides[ndim - 1]; - auto dst_xstride = dst_strides[ndim - 1]; + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; auto xshape = src_shape[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { dst[idx.y] = static_cast(src[idx.x]); diff --git a/Source/Cmlx/mlx-generated/metal/gather.h b/Source/Cmlx/mlx-generated/metal/gather.h index 8063c6f6..32caa869 100644 --- a/Source/Cmlx/mlx-generated/metal/gather.h +++ b/Source/Cmlx/mlx-generated/metal/gather.h @@ -4,7 +4,7 @@ #include "indexing.h" -template +template METAL_FUNC void gather_impl( const device T* src [[buffer(0)]], device T* out [[buffer(1)]], @@ -16,34 +16,36 @@ METAL_FUNC void gather_impl( const thread Indices& indices, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - size_t src_idx = 0; + LocT src_idx = 0; for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; + LocT idx_loc; if (IDX_NDIM == 0) { idx_loc = 0; } else if (IDX_NDIM == 1) { - idx_loc = index.x * indices.strides[indices.ndim * i]; + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); } else { - idx_loc = index.x * indices.strides[indices.ndim * i]; - idx_loc += elem_to_loc( - index.y, - &indices.shapes[indices.ndim * i + 1], - &indices.strides[indices.ndim * i + 1], - indices.ndim - 1); + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); + idx_loc += indices.row_contiguous[i] + ? 
index.y + : elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; + src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); } - auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + auto src_offset = + elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); - size_t out_idx = index.z; + LocT out_idx = index.z; if (IDX_NDIM == 1) { - out_idx += static_cast(grid_dim.z) * index.x; + out_idx += static_cast(grid_dim.z) * index.x; } else if (IDX_NDIM >= 2) { - out_idx += - grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); } out[out_idx] = src[src_offset + src_idx]; } diff --git a/Source/Cmlx/mlx-generated/metal/gemv.metal b/Source/Cmlx/mlx-generated/metal/gemv.metal index 00e0704e..5ff010e7 100644 --- a/Source/Cmlx/mlx-generated/metal/gemv.metal +++ b/Source/Cmlx/mlx-generated/metal/gemv.metal @@ -3,8 +3,6 @@ #include #include -#include "bf16.h" -#include "defines.h" #include "utils.h" #include "steel/utils.h" @@ -912,4 +910,4 @@ template < // clang-format off instantiate_gemv_t_bs_blocks(float32, float); instantiate_gemv_t_bs_blocks(float16, half); -instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on \ No newline at end of file +instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/indexing.h b/Source/Cmlx/mlx-generated/metal/indexing.h index 9f76e477..05bef96b 100644 --- a/Source/Cmlx/mlx-generated/metal/indexing.h +++ b/Source/Cmlx/mlx-generated/metal/indexing.h @@ -9,11 +9,12 @@ struct Indices { const array buffers; const constant int* shapes; const constant size_t* strides; + const constant bool* row_contiguous; const int ndim; }; template -METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { if (is_unsigned_v) { return idx; } else { diff --git a/Source/Cmlx/mlx-generated/metal/jit/bf16.h b/Source/Cmlx/mlx-generated/metal/jit/bf16.h new file mode 100644 index 00000000..8675ed84 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/jit/bf16.h @@ -0,0 +1,16 @@ +// Copyright © 2024 Apple Inc. + +// clang-format off +#define jit_if #if +#define jit_else #else +#define jit_endif #endif + +jit_if (__METAL_VERSION__ >= 310) + +#include "../metal_3_1/bf16.h" + +jit_else + +#include "../metal_3_0/bf16.h" + +jit_endif // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/layer_norm.metal b/Source/Cmlx/mlx-generated/metal/layer_norm.metal index 79f04d7b..c3b06cab 100644 --- a/Source/Cmlx/mlx-generated/metal/layer_norm.metal +++ b/Source/Cmlx/mlx-generated/metal/layer_norm.metal @@ -3,8 +3,6 @@ #include #include -#include "bf16.h" -#include "defines.h" #include "utils.h" using namespace metal; diff --git a/Source/Cmlx/mlx-generated/metal/metal_3_0/bf16.h b/Source/Cmlx/mlx-generated/metal/metal_3_0/bf16.h new file mode 100644 index 00000000..f5d48670 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/metal_3_0/bf16.h @@ -0,0 +1,314 @@ +// Copyright © 2023 Apple Inc. 
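The new metal_3_0/bf16.h that starts here carries the software bfloat16 used when the device has no native bfloat type (the metal_3_1 header added further below just typedefs the hardware type, and the jit/bf16.h shim above picks between them at JIT time). The representation is simply the upper 16 bits of an IEEE-754 float32, and float_to_bfloat_bits rounds to nearest, ties to even, before truncating. A host-side C++ re-creation of the two helpers, assuming memcpy in place of Metal's as_type bit casts; the sample values are illustrative.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bfloat_bits(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  if (std::isnan(x)) return 0x7FC0;              // canonical quiet NaN
  // Add 0x7FFF plus the LSB of the kept half: round to nearest, ties to even.
  bits += ((bits >> 16) & 1u) + 0x7FFFu;
  return uint16_t(bits >> 16);                   // sign, exponent, top 7 mantissa bits
}

float bfloat_bits_to_float(uint16_t x) {
  uint32_t bits = uint32_t(x) << 16;             // lower 16 mantissa bits become 0
  float out;
  std::memcpy(&out, &bits, sizeof out);
  return out;
}

int main() {
  // 1 + 1/256 lies exactly between two bfloat16 codes and rounds down to the
  // even one (1.0); 1 + 3/256 is also a tie but rounds up, since up is even.
  const float samples[] = {1.0f, 1.00390625f, 1.01171875f, 3.14159265f};
  for (float x : samples) {
    uint16_t b = float_to_bfloat_bits(x);
    std::printf("%.8f -> 0x%04X -> %.8f\n", x, unsigned(b), bfloat_bits_to_float(b));
  }
  return 0;
}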
+ +#pragma once + +#include + +using namespace metal; + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + 
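The operator and comparison macros that follow never do arithmetic in 16 bits: both operands are widened to float32, the operation runs there, and the result is narrowed back, so the emulated type needs no bfloat ALU support at all. A minimal host-side sketch of that promote, compute, narrow strategy; Bf16 is a stand-in for _MLX_BFloat16, and narrowing truncates here for brevity where the real header rounds to nearest even as shown above.

#include <cstdint>
#include <cstdio>
#include <cstring>

struct Bf16 {
  uint16_t bits;
  explicit Bf16(float x) {                       // narrow: keep the upper 16 bits
    uint32_t b;
    std::memcpy(&b, &x, sizeof b);
    bits = uint16_t(b >> 16);
  }
  operator float() const {                       // widen: restore the upper 16 bits
    uint32_t b = uint32_t(bits) << 16;
    float out;
    std::memcpy(&out, &b, sizeof out);
    return out;
  }
};

// Same shape as one bfloat_binop expansion: the work happens in float.
Bf16 operator+(Bf16 a, Bf16 b) { return Bf16(float(a) + float(b)); }

int main() {
  Bf16 a(1.5f), b(0.25f);
  std::printf("1.5 + 0.25 = %g as bfloat16\n", float(a + b));  // 1.75, exactly representable
  return 0;
}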
+///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + 
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat numeric limits +///////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable + +namespace metal { + +template <> +struct _numeric_limits_impl : _fp_numeric_limits_impl_base { + static constexpr constant int digits = 8; + static constexpr constant int digits10 = 2; + static constexpr constant int max_digits10 = 4; + static constexpr constant int radix = 2; + static constexpr constant int min_exponent = -125; + static constexpr constant int min_exponent10 = -37; + static constexpr constant int max_exponent = 128; + static constexpr constant int max_exponent10 = 38; + + static constexpr bfloat16_t min() { + return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t lowest() { + return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t max() { + return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t epsilon() { + return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t round_error() { + return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t infinity() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t quiet_NaN() { + return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t signaling_NaN() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t denorm_min() { + return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat()); + } +}; + +METAL_FUNC bool isnan(_MLX_BFloat16 x) { + return x != x; +} + +} // namespace metal + +#pragma METAL internals : disable +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return x.bits_; +} + +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()); +} diff 
--git a/Source/Cmlx/mlx-generated/metal/metal_3_1/bf16.h b/Source/Cmlx/mlx-generated/metal/metal_3_1/bf16.h new file mode 100644 index 00000000..aa3c3c78 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/metal_3_1/bf16.h @@ -0,0 +1,16 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +typedef bfloat bfloat16_t; +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return as_type(x); +} + +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return as_type(x); +} diff --git a/Source/Cmlx/mlx-generated/metal/quantized.h b/Source/Cmlx/mlx-generated/metal/quantized.h index 4f388b9f..ad53f082 100644 --- a/Source/Cmlx/mlx-generated/metal/quantized.h +++ b/Source/Cmlx/mlx-generated/metal/quantized.h @@ -8,12 +8,13 @@ using namespace metal; #define MLX_MTL_CONST static constant constexpr const MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U sum = 0; @@ -27,6 +28,21 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + else if (bits == 4) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -37,6 +53,16 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { sum += x[i]; @@ -50,8 +76,8 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U sum = 0; @@ -63,8 +89,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; } } @@ -76,8 +115,15 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 2] = x[i + 2] / 256.0f; 
x_thread[i + 3] = x[i + 3] / 4096.0f; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; } } @@ -86,9 +132,10 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { sum += x[i]; x_thread[i] = x[i]; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; } return sum; @@ -102,8 +149,8 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U accum = 0; @@ -117,6 +164,26 @@ inline U qdot( } } + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -128,6 +195,23 @@ inline U qdot( } } + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { accum += x_thread[i] * w[i]; @@ -146,8 +230,8 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U accum = 0; @@ -161,6 +245,26 @@ inline U qdot_safe( } } + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { @@ -172,6 +276,23 @@ inline U qdot_safe( } } + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + 
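The new 3-bit branches above lean on a pre-scaling trick: load_vector divides each activation by the power of two at which its weight field starts (1, 8, 64, 2, 16, 128, 4, 32 across the eight lanes), so qdot can multiply by the masked but unshifted weight bytes, adding a factor of 256 where a field straddles a byte boundary. A host-side C++ check that this masked-and-prescaled sum matches a plain decode-then-multiply dot product; scale and bias are left out to keep the check on the bit layout, and the sample bytes and activations are arbitrary.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t w[3] = {0xB6, 0x5D, 0xE2};       // three bytes hold eight 3-bit values
  const float x[8]   = {1.0f, -2.0f, 0.5f, 3.0f, -1.0f, 4.0f, 0.25f, 2.0f};

  // Reference: decode each 3-bit value (same layout as the qouter/dequantize
  // branches in this file), then take the dot product with x.
  const int v[8] = {
      w[0] & 0x7,
      (w[0] >> 3) & 0x7,
      ((w[0] >> 6) & 0x3) | ((w[1] & 0x1) << 2),
      (w[1] >> 1) & 0x7,
      (w[1] >> 4) & 0x7,
      ((w[1] >> 7) & 0x1) | ((w[2] & 0x3) << 1),
      (w[2] >> 2) & 0x7,
      (w[2] >> 5) & 0x7};
  float ref = 0;
  for (int i = 0; i < 8; ++i) ref += v[i] * x[i];

  // qdot path: pre-scale x as load_vector does, then use the masked bytes.
  const float xs[8] = {x[0],      x[1] / 8,   x[2] / 64, x[3] / 2,
                       x[4] / 16, x[5] / 128, x[6] / 4,  x[7] / 32};
  float accum = 0;
  accum += (w[0] & 0x07) * xs[0];
  accum += (w[0] & 0x38) * xs[1];
  accum += (w[0] & 0xc0) * xs[2];
  accum += (w[1] & 0x01) * (xs[2] * 256.0f);     // high bit of the straddling value
  accum += (w[1] & 0x0e) * xs[3];
  accum += (w[1] & 0x70) * xs[4];
  accum += (w[1] & 0x80) * xs[5];
  accum += (w[2] & 0x03) * (xs[5] * 256.0f);     // high bits of the straddling value
  accum += (w[2] & 0x1c) * xs[6];
  accum += (w[2] & 0xe0) * xs[7];

  std::printf("reference=%g qdot-style=%g match=%d\n", ref, accum,
              std::fabs(ref - accum) < 1e-4f);
  return 0;
}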
accum += (w[2] & 0xfc) * x_thread[3]; + } + } + else if (bits == 8) { for (int i = 0; i < N; i++) { accum += x_thread[i] * w[i]; @@ -185,8 +306,8 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -198,12 +319,45 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { } } + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + else if (bits == 4) { U s[2] = {scale, scale / 16.0f}; for (int i = 0; i < (values_per_thread / 2); i++) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + + } else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } } else if (bits == 8) { @@ -217,8 +371,8 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); if (bits == 2) { U s[4] = { @@ -234,6 +388,22 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + else if (bits == 4) { U s[2] = {scale, scale / static_cast(16.0f)}; for (int i = 0; i < (N / 2); i++) { @@ -242,6 +412,18 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + 
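For 6-bit weights the surrounding hunks pack four values into three bytes (pack_factor 4, bytes_per_pack 3), with the second and third values split across byte boundaries; the qouter branch above and the dequantize branch that continues just below read them back with identical masks. A host-side round trip of that layout: pack6 is an assumed packing helper inferred from those masks, and unpack6 mirrors the dequantize expressions, using | rather than + on the disjoint bit fields.

#include <cstdint>
#include <cstdio>

void pack6(const int v[4], uint8_t w[3]) {
  w[0] = uint8_t((v[0] & 0x3f) | ((v[1] & 0x03) << 6));
  w[1] = uint8_t(((v[1] >> 2) & 0x0f) | ((v[2] & 0x0f) << 4));
  w[2] = uint8_t(((v[2] >> 4) & 0x03) | ((v[3] & 0x3f) << 2));
}

void unpack6(const uint8_t w[3], int out[4]) {   // same masks as dequantize, bits == 6
  out[0] = w[0] & 0x3f;
  out[1] = ((w[0] >> 6) & 0x03) | ((w[1] & 0x0f) << 2);
  out[2] = ((w[1] >> 4) & 0x0f) | ((w[2] & 0x03) << 4);
  out[3] = (w[2] >> 2) & 0x3f;
}

int main() {
  const int v[4] = {0, 17, 42, 63};              // any values in [0, 63]
  uint8_t w[3];
  int back[4];
  pack6(v, w);
  unpack6(w, back);
  for (int i = 0; i < 4; ++i)
    std::printf("v[%d]=%d  unpacked=%d\n", i, v[i], back[i]);
  return 0;
}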
w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + else if (bits == 8) { for (int i = 0; i < N; i++) { w_local[i] = scale * w[i] + bias; @@ -266,10 +448,11 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); - MLX_MTL_CONST short pack_factor = 32 / bits; + MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -285,12 +468,12 @@ struct QuantizedBlockLoader { const short bj; threadgroup T* dst; - const device uint32_t* src; + const device uint8_t* src; const device T* scales; const device T* biases; QuantizedBlockLoader( - const device uint32_t* src_, + const device uint8_t* src_, const device T* scales_, const device T* biases_, const int src_ld_, @@ -299,14 +482,16 @@ struct QuantizedBlockLoader { ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), tile_stride( - reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), group_step_cnt(0), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), bi(n_reads * thread_idx / BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED), dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld / pack_factor + bj), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size), biases(biases_ + bi * src_ld / group_size) {} @@ -319,7 +504,7 @@ struct QuantizedBlockLoader { T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); } } @@ -346,7 +531,10 @@ struct QuantizedBlockLoader { T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); } } @@ -371,6 +559,63 @@ struct QuantizedBlockLoader { } }; +template +METAL_FUNC void qmv_quad_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = 
in_vec_size / group_size; + const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.y * in_vec_size + quad_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device T* sl = scales + row * in_vec_size_g * quads_per_simd; + const device T* bl = biases + row * in_vec_size_g * quads_per_simd; + + U s = sl[0]; + U b = bl[0]; + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + template METAL_FUNC void qmv_fast_impl( const device uint32_t* w, @@ -383,25 +628,30 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int packs_per_thread = bits > 2 ? 2 : 1; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; + const device uint8_t* ws = (const device uint8_t*)w; + typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -411,8 +661,7 @@ METAL_FUNC void qmv_fast_impl( U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -421,7 +670,7 @@ METAL_FUNC void qmv_fast_impl( result[row] += qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -447,21 +696,25 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid 
[[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; + const device uint8_t* ws = (const device uint8_t*)w; + typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; @@ -474,7 +727,8 @@ METAL_FUNC void qmv_impl( // In this case we need to properly guard all our reads because there isn't // even 1 tile in the matrix if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -485,8 +739,7 @@ METAL_FUNC void qmv_impl( U sum = load_vector(x, x_thread); for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -496,7 +749,7 @@ METAL_FUNC void qmv_impl( qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -505,18 +758,20 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); - for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } } for (int row = 0; out_row + row < out_vec_size; row++) { @@ -529,7 +784,8 @@ METAL_FUNC void qmv_impl( // In this case the last tile is moved back to redo some output values else { - w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; scales += used_out_row * in_vec_size_g + 
simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -540,8 +796,7 @@ METAL_FUNC void qmv_impl( U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -551,7 +806,7 @@ METAL_FUNC void qmv_impl( qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -560,21 +815,21 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); - for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot_safe( - wl, x_thread, s, b, sum, remaining); + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } } - for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); if (simd_lid == 0) { @@ -586,24 +841,28 @@ METAL_FUNC void qmv_impl( template METAL_FUNC void qvm_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, + const int in_vec_size, + const int out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 
4 : 3; constexpr int tn = 32 / pack_factor; - constexpr int blocksize = SIMD_SIZE; + constexpr int block_size = SIMD_SIZE; + + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; typedef struct { - uint32_t wi[tn]; + uint8_t wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; @@ -613,11 +872,10 @@ METAL_FUNC void qvm_impl( thread U x_local = 0; // Adjust positions - const int out_vec_size_w = out_vec_size / pack_factor; + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; - int out_col = - tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn; - w += out_col / pack_factor + simd_lid * out_vec_size_w; + int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; scales += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g; x += tid.y * in_vec_size + simd_lid; @@ -627,43 +885,42 @@ METAL_FUNC void qvm_impl( return; } - // Loop over in_vec in blocks of blocksize - int remaining = in_vec_size % blocksize; + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; if (remaining == 0) { - for (int i = 0; i < in_vec_size; i += blocksize) { + for (int i = 0; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); - + w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; } } else { - for (int i = blocksize; i < in_vec_size; i += blocksize) { + for (int i = block_size; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; } if (static_cast(simd_lid) < remaining) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); } else { x_local = 0; scale = 0; @@ -697,16 +954,16 @@ template < const int BK = 32, const int BN = 32> METAL_FUNC void qmm_t_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, - const constant int& M, - const constant int& N, const constant int& K, + const constant int& N, + const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -718,8 +975,9 @@ METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 
3 : 1; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -737,13 +995,15 @@ METAL_FUNC void qmm_t_impl( bits>; // Set the block - const int K_w = K / pack_factor; + const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; + auto wl = (const device uint8_t*)w; + x += y_row * K; - w += y_col * K_w; + wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; y += y_row * N + y_col; @@ -752,7 +1012,7 @@ METAL_FUNC void qmm_t_impl( const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { @@ -794,6 +1054,7 @@ METAL_FUNC void qmm_t_impl( loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); @@ -818,16 +1079,16 @@ template < const int BK = 32, const int BN = 32> METAL_FUNC void qmm_n_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, - const constant int& M, - const constant int& N, const constant int& K, + const constant int& N, + const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -839,9 +1100,11 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -858,11 +1121,13 @@ METAL_FUNC void qmm_n_impl( group_size, bits>; + auto wl = (const device uint8_t*)w; + // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; x += y_row * K; - w += y_col / pack_factor; + wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; y += y_row * N + y_col; @@ -870,7 +1135,7 @@ METAL_FUNC void qmm_n_impl( // Make the x loader and mma operation const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { @@ -942,6 +1207,45 @@ METAL_FUNC void qmm_n_impl( } } +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant size_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant size_t* w_strides, + const constant size_t* s_strides, + const constant size_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + template METAL_FUNC void adjust_matrix_offsets( const device T*& x, @@ -996,7 +1300,58 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template +template +[[kernel]] void qmv_quad( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_quad_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid); +} + +template [[kernel]] void qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1005,9 +1360,35 @@ template device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& 
out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmv_fast_impl( w, scales, @@ -1021,7 +1402,7 @@ template simd_lid); } -template +template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1030,9 +1411,35 @@ template device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmv_impl( w, scales, @@ -1046,23 +1453,49 @@ template simd_lid); } -template +template [[kernel]] void qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qvm_impl( - x, w, scales, biases, + x, y, in_vec_size, out_vec_size, @@ -1071,23 +1504,87 @@ template simd_lid); } +template +[[kernel]] void qvm_split_k( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size 
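qmv_fast and qmv now take batch handling as a compile-time template flag rather than separate entry points, so the offset adjustment only exists in instantiations that need it. A minimal sketch of that pattern (strides and names below are illustrative, not the kernel's):

    #include <cstdio>

    // The "batched" flag is a template parameter, so the unbatched
    // instantiation contains no offset-adjustment code at all.
    template <bool batched>
    void kernel_like(const float* x, const float* w, int batch_stride, int tid_z) {
        if (batched) { // constant-folded per instantiation
            x += tid_z * batch_stride;
            w += tid_z * batch_stride;
        }
        printf("batched=%d x=%p w=%p\n", int(batched), (const void*)x, (const void*)w);
    }

    int main() {
        float x[8] = {0}, w[8] = {0};
        kernel_like<false>(x, w, 4, 1); // single matrix: pointers untouched
        kernel_like<true>(x, w, 4, 1);  // batched: advance to batch element 1
        return 0;
    }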
[[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& final_block_size [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; + + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + template < typename T, const int group_size, const int bits, const bool aligned_N, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], + const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1099,26 +1596,53 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], + const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], + const constant int& M 
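qvm_split_k splits the reduction dimension across tid.z and shrinks the last partition when in_vec_size is not an exact multiple of split_k, receiving the leftover length as final_block_size. One plausible host-side partitioning consistent with that check (the launch code is not part of this diff, so the chunking below is an assumption):

    #include <cstdio>

    int main() {
        const int in_vec_size = 1026;  // hypothetical K
        const int split_k = 8;
        // Ceil-divide K into split_k chunks; the last chunk absorbs the
        // remainder (this would be the final_block_size in buffer 15).
        const int block = (in_vec_size + split_k - 1) / split_k;
        const int final_block_size = in_vec_size - (split_k - 1) * block;
        for (int tid_z = 0; tid_z < split_k; ++tid_z) {
            int in_vec_size_adj =
                (tid_z % split_k == split_k - 1) ? final_block_size : block;
            printf("partition %d reduces %d elements\n", tid_z, in_vec_size_adj);
        }
        return 0;
    }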
[[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1131,8 +1655,27 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template @@ -1141,23 +1684,23 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1202,23 +1745,23 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides 
[[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1259,27 +1802,27 @@ template template [[kernel]] void bs_qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint 
simd_lid [[thread_index_in_simdgroup]]) { @@ -1306,10 +1849,10 @@ template b_strides, tid); qvm_impl( - x, w, scales, biases, + x, y, in_vec_size, out_vec_size, @@ -1327,28 +1870,28 @@ template < const int BK = 32, const int BN = 32> [[kernel]] void bs_qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], + const constant int& batch_ndims [[buffer(16)]], + const constant int* batch_shape [[buffer(17)]], + const device uint32_t* lhs_indices [[buffer(18)]], + const device uint32_t* rhs_indices [[buffer(19)]], + const constant size_t* lhs_strides [[buffer(20)]], + const constant size_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1383,7 +1926,7 @@ template < b_strides, tid); qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < @@ -1394,28 +1937,28 @@ template < const int BK = 32, const int BN = 32> [[kernel]] void bs_qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* 
rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], + const constant int& batch_ndims [[buffer(16)]], + const constant int* batch_shape [[buffer(17)]], + const device uint32_t* lhs_indices [[buffer(18)]], + const device uint32_t* rhs_indices [[buffer(19)]], + const constant size_t* lhs_strides [[buffer(20)]], + const constant size_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1451,7 +1994,7 @@ template < b_strides, tid); qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template @@ -1464,13 +2007,14 @@ template uint2 grid_dim [[threads_per_grid]]) { constexpr T eps = T(1e-7); constexpr int simd_size = 32; - constexpr int uint8_bits = 8; constexpr T n_bins = (1 << bits) - 1; - constexpr int packs_per_int = uint8_bits / bits; + constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int values_per_reduce = group_size / simd_size; constexpr int writes_per_reduce = packs_per_int / values_per_reduce; constexpr int writes_per_pack = writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; static_assert( group_size % simd_size == 0, @@ -1478,7 +2022,9 @@ template size_t offset = index.x + grid_dim.x * size_t(index.y); size_t in_index = offset * values_per_reduce; - size_t out_index = offset * writes_per_pack; + size_t out_index = power_of_2_bits + ? 
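The quantization constants above replace the old uint8-based packs_per_int with a scheme that also covers 3- and 6-bit widths: power-of-two widths still pack into single bytes, while 3- and 6-bit values pack in groups totalling 24 bits, i.e. 3 bytes per pack. A small standalone check of those constants (plain C++, mirroring the expressions in the diff):

    #include <cstdio>

    int main() {
        const int widths[] = {2, 3, 4, 6, 8};
        for (int bits : widths) {
            int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
            bool power_of_2_bits = (bits & (bits - 1)) == 0;
            int bytes_per_pack = power_of_2_bits ? 1 : 3;
            // packs_per_int * bits == bytes_per_pack * 8 in every case
            printf("bits=%d values/pack=%d bytes/pack=%d (%d bits)\n",
                   bits, packs_per_int, bytes_per_pack, packs_per_int * bits);
        }
        return 0;
    }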
offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; T w_thread[values_per_reduce]; T w_min = Limits::max; @@ -1511,7 +2057,9 @@ template biases[gindex] = bias; } - uint8_t output = 0; + // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t + uint32_t output = 0; + #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); @@ -1527,47 +2075,23 @@ template output = 0; } else { #pragma clang loop unroll(full) - for (int j = 0; j < writes_per_reduce - 1; j++) { - uint8_t sval = simd_shuffle_down(val, j + 1); - output += sval << (bits * (values_per_reduce + j + i)); + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = simd_shuffle_down(val, j); + output += sval << (bits * (j * values_per_reduce + i)); } } } - if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { - out[out_index / writes_per_reduce] = output; - } -} - -template -[[kernel]] void affine_quantize_scales_biases( - const device T* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - device uint8_t* out [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr int uint8_bits = 8; - constexpr int packs_per_int = uint8_bits / bits; - constexpr T n_bins = (1 << bits) - 1; - - size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t in_index = offset * packs_per_int; - size_t gindex = in_index / group_size; - - T scale = scales[gindex]; - T bias = biases[gindex]; - - uint8_t output = 0; -#pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { - uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins); - if (bits == 8) { - output = val; - } else { - output += val << (bits * i); + if (bits == 3 || bits == 6) { + if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else { + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; } } - out[offset] = output; } template @@ -1578,26 +2102,48 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int uint8_bits = 8; - constexpr int packs_per_int = uint8_bits / bits; + constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 
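For the 3- and 6-bit paths, affine_quantize now accumulates a whole pack into a uint32_t and stores its low three bytes instead of writing one byte per pack. A host-side sketch of the 3-bit case with scale and bias folded out (illustrative only; the kernel additionally gathers values from neighboring SIMD lanes via simd_shuffle_down):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int bits = 3;
        uint8_t vals[8] = {1, 7, 0, 5, 3, 2, 6, 4}; // already-quantized levels
        // Accumulate a whole pack into 32 bits, as the kernel now does.
        uint32_t output = 0;
        for (int i = 0; i < 8; ++i) {
            output += uint32_t(vals[i]) << (bits * i);
        }
        // Store the low three bytes, mirroring the bits == 3 write path.
        uint8_t out[3] = {
            uint8_t(output & 0xff),
            uint8_t((output & 0xff00) >> 8),
            uint8_t((output & 0xff0000) >> 16),
        };
        printf("packed: %02x %02x %02x\n", out[0], out[1], out[2]);
        for (int i = 0; i < 8; ++i) { // shift-and-mask recovers each level
            printf("val[%d] = %u (expected %u)\n", i,
                   unsigned((output >> (bits * i)) & 0x7), unsigned(vals[i]));
        }
        return 0;
    }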
1 : 3; size_t offset = index.x + grid_dim.x * size_t(index.y); size_t oindex = offset * packs_per_int; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; - uint val = w[offset]; + out += oindex; + + if (bits == 3) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x7) * scale + bias; + out[1] = ((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + out[3] = ((w[1] & 0xe) >> 1) * scale + bias; + out[4] = ((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + + } else if (bits == 6) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x3f) * scale + bias; + out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { - uint8_t d; - if (bits == 2) { - d = (val >> (bits * i)) & 0x03; - } else if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; + for (int i = 0; i < packs_per_int; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * d + bias; } - out[oindex + i] = scale * d + bias; } } diff --git a/Source/Cmlx/mlx-generated/metal/random.metal b/Source/Cmlx/mlx-generated/metal/random.metal index 5a1704d2..17adffd3 100644 --- a/Source/Cmlx/mlx-generated/metal/random.metal +++ b/Source/Cmlx/mlx-generated/metal/random.metal @@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { [[kernel]] void rbitsc( device const uint32_t* keys, device char* out, - device const bool& odd, - device const uint& bytes_per_key, + constant const bool& odd, + constant const uint& bytes_per_key, uint2 grid_dim [[threads_per_grid]], uint2 index [[thread_position_in_grid]]) { auto kidx = 2 * index.x; @@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { [[kernel]] void rbits( device const uint32_t* keys, device char* out, - device const bool& odd, - device const uint& bytes_per_key, + constant const bool& odd, + constant const uint& bytes_per_key, constant const int& ndim, constant const int* key_shape, constant const size_t* key_strides, diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h index 381d5e20..e0d08392 100644 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h @@ -1,6 +1,11 @@ // Copyright © 2023-2024 Apple Inc. -template +template < + typename T, + typename U, + typename Op, + typename IdxT = int64_t, + int N_READS = REDUCE_N_READS> [[kernel]] void all_reduce( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], @@ -16,10 +21,10 @@ template threadgroup U shared_vals[simd_size]; U total = Op::init; - int64_t start_idx = gid.y * row_size; - int64_t actual_row = + IdxT start_idx = gid.y * IdxT(row_size); + IdxT actual_row = (start_idx + row_size <= in_size) ? 
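affine_dequantize gains explicit unpacking for the non-power-of-two widths; the mask-and-shift expressions reassemble each value from the three packed bytes. The sketch below replays the 6-bit expressions from the diff on a CPU with scale = 1 and bias = 0 to show they invert the packing:

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint8_t vals[4] = {63, 0, 42, 17};   // arbitrary 6-bit levels
        uint32_t packed = 0;
        for (int i = 0; i < 4; ++i) {
            packed += uint32_t(vals[i]) << (6 * i);
        }
        uint8_t w[3] = {
            uint8_t(packed & 0xff),
            uint8_t((packed >> 8) & 0xff),
            uint8_t((packed >> 16) & 0xff),
        };
        // Same mask/shift expressions as the kernel, with scale = 1, bias = 0.
        uint8_t out[4];
        out[0] = w[0] & 0x3f;
        out[1] = ((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2);
        out[2] = ((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4);
        out[3] = (w[2] >> 2) & 0x3f;
        for (int i = 0; i < 4; ++i) {
            printf("out[%d] = %u (expected %u)\n", i,
                   unsigned(out[i]), unsigned(vals[i]));
        }
        return 0;
    }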
row_size : in_size - start_idx; - int64_t blocks = actual_row / (lsize.x * N_READS); + IdxT blocks = actual_row / (lsize.x * N_READS); int extra = actual_row - blocks * (lsize.x * N_READS); extra -= lid.x * N_READS; start_idx += lid.x * N_READS; @@ -30,7 +35,7 @@ template extra = 0; } - for (int64_t b = 0; b < blocks; b++) { + for (IdxT b = 0; b < blocks; b++) { for (int i = 0; i < N_READS; i++) { total = op(static_cast(in[i]), total); } diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h index 52e763dd..2fa5132d 100644 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h @@ -1,11 +1,6 @@ // Copyright © 2023-2024 Apple Inc. -template < - typename T, - typename U, - typename Op, - int NDIMS, - int N_READS = REDUCE_N_READS> +template [[kernel]] void col_reduce_small( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], @@ -20,170 +15,129 @@ template < const constant size_t& non_col_reductions [[buffer(10)]], uint3 gid [[threadgroup_position_in_grid]], uint3 gsize [[threadgroups_per_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[thread_position_in_grid]], - uint3 tsize [[threads_per_grid]]) { + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + constexpr int n_reads = 4; Op op; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; - // Case 1: Small row small column - if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) { - U totals[31]; - for (int i = 0; i < 31; i++) { - totals[i] = Op::init; - } - - short stride = reduction_stride; - short size = reduction_size; - short blocks = stride / N_READS; - short extra = stride - blocks * N_READS; + U totals[n_reads]; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } - size_t out_idx = tid.x + tsize.y * size_t(tid.y); - in += elem_to_loc(out_idx, shape, strides, ndim); + IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; + if (column >= reduction_stride) { + return; + } + bool safe = column + n_reads <= reduction_stride; - for (uint r = 0; r < non_col_reductions; r++) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; - for (short i = 0; i < size; i++) { - for (short j = 0; j < blocks; j++) { - for (short k = 0; k < N_READS; k++) { - totals[j * N_READS + k] = - op(totals[j * N_READS + k], - static_cast(row[i * stride + j * N_READS + k])); - } - } - for (short k = 0; k < extra; k++) { - totals[blocks * N_READS + k] = - op(totals[blocks * N_READS + k], - static_cast(row[i * stride + blocks * N_READS + k])); - } + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(lid.y, reduce_shape, reduce_strides); + for (IdxT r = lid.y; r < total_rows; r += lsize.y) { + row = in + loop.location(); + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? 
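The reduction kernels are now templated on an index type IdxT instead of hard-coding size_t or int64_t, presumably so the host can instantiate a narrower index when the problem size allows it and keep 64-bit arithmetic only where it is needed. A toy illustration of the overflow the wide instantiation guards against (numbers are arbitrary):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Products such as non_col_reductions * reduction_size can exceed
        // 2^32; the 64-bit IdxT instantiation keeps them exact while the
        // 32-bit one stays cheaper for small launches.
        uint32_t non_col_reductions = 70000, reduction_size = 70000;
        uint32_t narrow = non_col_reductions * reduction_size;        // wraps
        int64_t wide = int64_t(non_col_reductions) * reduction_size;  // exact
        printf("32-bit product: %u\n", narrow);
        printf("64-bit product: %lld\n", (long long)wide);
        return 0;
    }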
static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); } - - loop.next(reduce_shape, reduce_strides); - } - out += out_idx * reduction_stride; - for (short j = 0; j < stride; j++) { - out[j] = totals[j]; } + loop.next(lsize.y, reduce_shape, reduce_strides); } - // Case 2: Long row small column - else if (reduction_size * non_col_reductions < 32) { - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = Op::init; + if (lsize.y > 1) { + // lsize.y should be <= 8 + threadgroup U shared_vals[32 * 8 * n_reads]; + for (int i = 0; i < n_reads; i++) { + shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i]; } - - short size = reduction_size; - size_t offset = size_t(tid.x) * N_READS; - bool safe = offset + N_READS <= reduction_stride; - short extra = reduction_stride - offset; - - size_t out_idx = tid.y + tsize.z * size_t(tid.z); - in += elem_to_loc(out_idx, shape, strides, ndim) + offset; - - for (uint r = 0; r < non_col_reductions; r++) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); - - if (safe) { - for (short i = 0; i < size; i++) { - for (short j = 0; j < N_READS; j++) { - totals[j] = - op(static_cast(row[i * reduction_stride + j]), totals[j]); - } - } - } else { - for (short i = 0; i < size; i++) { - for (short j = 0; j < extra; j++) { - totals[j] = - op(static_cast(row[i * reduction_stride + j]), totals[j]); - } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (int i = 0; i < n_reads; i++) { + totals[i] = shared_vals[lid.x * n_reads + i]; + } + for (uint j = 1; j < lsize.y; j++) { + for (int i = 0; i < n_reads; i++) { + totals[i] = + op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i], + totals[i]); } } - - loop.next(reduce_shape, reduce_strides); } - out += out_idx * reduction_stride + offset; + } + + if (lid.y == 0) { + out += out_idx * IdxT(reduction_stride) + column; if (safe) { - for (short i = 0; i < N_READS; i++) { + for (int i = 0; i < n_reads; i++) { out[i] = totals[i]; } } else { - for (short i = 0; i < extra; i++) { + for (int i = 0; column + i < reduction_stride; i++) { out[i] = totals[i]; } } } +} - // Case 3: Long row medium column - else { - threadgroup U shared_vals[1024]; - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = Op::init; - } - - short stride = reduction_stride; - short lid = simd_group_id * simd_size + simd_lane_id; - short2 tile((stride + N_READS - 1) / N_READS, 32); - short2 offset((lid % tile.x) * N_READS, lid / tile.x); - short sm_stride = tile.x * N_READS; - bool safe = offset.x + N_READS <= stride; - - size_t out_idx = gid.y + gsize.y * size_t(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x; - - // Read cooperatively and contiguously and aggregate the partial results. - size_t total = non_col_reductions * reduction_size; - loop.next(offset.y, reduce_shape, reduce_strides); - for (size_t r = offset.y; r < total; r += simd_size) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); - - if (safe) { - for (int i = 0; i < N_READS; i++) { - totals[i] = op(static_cast(row[i]), totals[i]); - } - } else { - U vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = (offset.x + i < stride) ? 
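The rewritten col_reduce_small gives each thread n_reads adjacent columns and only falls back to per-element bounds checks for the ragged tail, substituting the reduction's init value for out-of-range lanes. A CPU rendering of that safe/unsafe split for a sum reduction (sizes are made up):

    #include <cstdio>

    int main() {
        const int n_reads = 4, reduction_stride = 10; // illustrative sizes
        float row[reduction_stride];
        for (int i = 0; i < reduction_stride; ++i) row[i] = 1.0f;

        for (int column = 0; column < reduction_stride; column += n_reads) {
            bool safe = column + n_reads <= reduction_stride;
            float vals[n_reads];
            for (int i = 0; i < n_reads; ++i) {
                // Out-of-range lanes contribute the init value (0 for a sum).
                vals[i] = (safe || column + i < reduction_stride)
                    ? row[column + i]
                    : 0.0f;
            }
            float partial = vals[0] + vals[1] + vals[2] + vals[3];
            printf("columns %d..%d safe=%d partial=%g\n",
                   column, column + n_reads - 1, int(safe), partial);
        }
        return 0;
    }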
static_cast(row[i]) : op.init; - } - for (int i = 0; i < N_READS; i++) { - totals[i] = op(vals[i], totals[i]); - } - } - - loop.next(simd_size, reduce_shape, reduce_strides); - } +template +[[kernel]] void col_reduce_longcolumn( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + Op op; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; - // Each thread holds N_READS partial results but the simdgroups are not - // aligned to do the reduction across the simdgroup so we write our results - // in the shared memory and read them back according to the simdgroup. - for (int i = 0; i < N_READS; i++) { - shared_vals[offset.y * sm_stride + offset.x + i] = totals[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_READS; i++) { - totals[i] = op.simd_reduce( - shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]); - } + IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + lid.x; + + U total = Op::init; + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); + for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; + r += lsize.y * gsize.z) { + row = in + loop.location(); + total = op(static_cast(*row), total); + loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); + } - // Write the output. - if (simd_lane_id == 0) { - short column = simd_group_id * N_READS; - out += out_idx * reduction_stride + column; - if (column + N_READS <= stride) { - for (int i = 0; i < N_READS; i++) { - out[i] = totals[i]; - } - } else { - for (int i = 0; column + i < stride; i++) { - out[i] = totals[i]; - } - } + threadgroup U shared_vals[32 * 32]; + shared_vals[lid.y * lsize.x + lid.x] = total; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (uint i = 1; i < lsize.y; i++) { + total = op(total, shared_vals[i * lsize.x + lid.x]); } + out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = + total; } } @@ -198,7 +152,14 @@ template < * totals with a loop. * 7. 
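col_reduce_longcolumn is a new variant for very tall, narrow reductions: each thread strides down one column accumulating every lsize.y-th row, and the per-thread partials are combined through threadgroup memory by the lid.y == 0 threads. The same strategy on a CPU, with the threadgroup step reduced to a plain loop (sizes are illustrative):

    #include <cstdio>
    #include <vector>

    int main() {
        const int rows = 1000, lsize_y = 8;       // illustrative sizes
        std::vector<float> column(rows, 1.0f);    // reduce a column of ones
        // Each "thread" ty accumulates every lsize_y-th row of the column.
        float partial[lsize_y] = {};
        for (int ty = 0; ty < lsize_y; ++ty) {
            for (int r = ty; r < rows; r += lsize_y) {
                partial[ty] += column[r];
            }
        }
        // The kernel performs this final combine through threadgroup memory.
        float total = 0.0f;
        for (int ty = 0; ty < lsize_y; ++ty) {
            total += partial[ty];
        }
        printf("sum = %g (expected %d)\n", total, rows);
        return 0;
    }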
Write them to the output */ -template +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> [[kernel]] void col_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], @@ -216,14 +177,14 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - constexpr int n_simdgroups = 4; + constexpr int n_simdgroups = 8; constexpr short tgp_size = n_simdgroups * simd_size; constexpr short n_reads = (BM * BN) / tgp_size; constexpr short n_read_blocks = BN / n_reads; threadgroup U shared_vals[BN * BM]; U totals[n_reads]; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; for (int i = 0; i < n_reads; i++) { @@ -232,17 +193,17 @@ template short lid = simd_group_id * simd_size + simd_lane_id; short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); - size_t column = BN * gid.x + offset.x; + IdxT column = BN * gid.x + offset.x; bool safe = column + n_reads <= reduction_stride; - size_t out_idx = gid.y + gsize.y * size_t(gid.z); - size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + column; - size_t total = non_col_reductions * reduction_size; + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); loop.next(offset.y, reduce_shape, reduce_strides); - for (size_t r = offset.y; r < total; r += BM) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + for (IdxT r = offset.y; r < total; r += BM) { + row = in + loop.location(); if (safe) { for (int i = 0; i < n_reads; i++) { @@ -282,8 +243,8 @@ template // Write the output. if (simd_lane_id == 0) { - size_t out_column = BN * gid.x + out_offset.x; - out += out_idx * reduction_stride + out_column; + IdxT out_column = BN * gid.x + out_offset.x; + out += out_idx * IdxT(reduction_stride) + out_column; if (out_column + n_outputs <= reduction_stride) { for (int i = 0; i < n_outputs; i++) { out[i] = totals[i]; @@ -316,7 +277,7 @@ template // Write the output. 
if (offset.y == 0) { - out += out_idx * reduction_stride + column; + out += out_idx * IdxT(reduction_stride) + column; if (safe) { for (int i = 0; i < n_reads; i++) { out[i] = totals[i]; @@ -329,3 +290,109 @@ template } } } + +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> +[[kernel]] void col_reduce_2pass( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 8; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + constexpr int n_outputs = BN / n_simdgroups; + constexpr short outer_blocks = 32; + static_assert(BM == 32, "BM should be equal to 32"); + + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + IdxT column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT block_idx = full_idx / IdxT(out_size); + IdxT out_idx = full_idx % IdxT(out_size); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); + for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) { + row = in + loop.location(); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(outer_blocks * BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. 
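col_reduce_2pass splits one long column reduction across 32 outer blocks by folding the block index into the flat launch index (full_idx / out_size and full_idx % out_size), so each block reduces every 32nd BM-row tile into its own slot of an out_size-strided buffer; a separate pass, not shown in this diff, presumably folds those 32 partials together. A sketch of which rows each block touches (sizes are illustrative):

    #include <cstdio>

    int main() {
        const int BM = 32, outer_blocks = 32;  // constants from the kernel
        const int total_rows = 4096;           // illustrative reduction length
        for (int block_idx = 0; block_idx < 3; ++block_idx) { // first few blocks
            printf("block %d tiles:", block_idx);
            for (int r = block_idx * BM; r < total_rows; r += outer_blocks * BM) {
                printf(" [%d..%d)", r, r + BM);
            }
            printf("\n");
        }
        return 0;
    }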
+ if (simd_lane_id == 0) { + IdxT out_column = BN * gid.x + out_offset.x; + out += full_idx * IdxT(reduction_stride) + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h index af8a01da..74636125 100644 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h @@ -193,6 +193,7 @@ template < typename T, typename U, typename Op, + typename IdxT, int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void row_reduce_small( @@ -214,20 +215,20 @@ template < Op op; U total_val = Op::init; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); // Precompute some row reduction numbers const device T* row; - int blocks = row_size / N_READS; - int extra = row_size % N_READS; + int blocks = IdxT(row_size) / N_READS; + int extra = IdxT(row_size) % N_READS; if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. - size_t out_idx = tid.x + tsize.y * size_t(tid.y); - in += elem_to_loc(out_idx, shape, strides, ndim); + IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + row = in + loop.location(); thread_reduce(total_val, row, blocks, extra); loop.next(reduce_shape, reduce_strides); } @@ -236,13 +237,13 @@ template < } else { // Collaboratively reduce over non_row_reductions in the simdgroup. Each // thread reduces every 32nd row and then a simple simd reduce. 
- size_t out_idx = gid.y + gsize.y * size_t(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim); + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); loop.next(simd_lane_id, reduce_shape, reduce_strides); for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + row = in + loop.location(); thread_reduce(total_val, row, blocks, extra); loop.next(simd_size, reduce_shape, reduce_strides); } @@ -259,6 +260,7 @@ template < typename T, typename U, typename Op, + typename IdxT = size_t, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> [[kernel]] void row_reduce_simple( @@ -277,15 +279,15 @@ template < U totals[N_WRITES]; // Move to the row - size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z)); + IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z)); if (out_idx + N_WRITES > out_size) { out_idx = out_size - N_WRITES; } - in += out_idx * reduction_size; + in += out_idx * IdxT(reduction_size); out += out_idx; // Each thread reduces across the row - int blocks = reduction_size / (lsize.x * N_READS); + int blocks = IdxT(reduction_size) / (lsize.x * N_READS); int extra = reduction_size - blocks * (lsize.x * N_READS); per_thread_row_reduce( totals, in, reduction_size, blocks, extra, lsize.x, lid.x); @@ -306,6 +308,7 @@ template < typename T, typename U, typename Op, + typename IdxT, int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void row_reduce_looped( @@ -330,19 +333,20 @@ template < threadgroup U shared_vals[simd_size]; U total = Op::init; - size_t out_idx = gid.y + gsize.y * size_t(gid.z); + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it // needs a small refactor. 
- in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; + in += elem_to_loc(out_idx, shape, strides, ndim) + + lid.x * N_READS; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; - int blocks = row_size / (lsize.x * N_READS); + int blocks = IdxT(row_size) / (lsize.x * N_READS); int extra = row_size - blocks * (lsize.x * N_READS); - for (size_t i = 0; i < non_row_reductions; i++) { - row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim); + for (IdxT i = 0; i < non_row_reductions; i++) { + row = in + loop.location(); // Each thread reduces across the row U row_total; diff --git a/Source/Cmlx/mlx-generated/metal/rms_norm.metal b/Source/Cmlx/mlx-generated/metal/rms_norm.metal index 9b52a986..14c04946 100644 --- a/Source/Cmlx/mlx-generated/metal/rms_norm.metal +++ b/Source/Cmlx/mlx-generated/metal/rms_norm.metal @@ -3,8 +3,6 @@ #include #include -#include "bf16.h" -#include "defines.h" #include "utils.h" using namespace metal; @@ -17,12 +15,15 @@ template constant float& eps, constant uint& axis_size, constant uint& w_stride, - threadgroup float* local_inv_mean [[threadgroup(0)]], - threadgroup float* local_sums [[threadgroup(1)]], uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + + threadgroup float local_inv_mean[1]; + threadgroup float local_sums[SIMD_SIZE]; + float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; @@ -84,13 +85,15 @@ template constant float& eps, constant uint& axis_size, constant uint& w_stride, - threadgroup float* local_inv_mean [[threadgroup(0)]], - threadgroup float* local_sums [[threadgroup(1)]], uint gid [[threadgroup_position_in_grid]], uint lid [[thread_position_in_threadgroup]], uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + threadgroup float local_inv_mean[1]; + threadgroup float local_sums[SIMD_SIZE]; + float acc = 0; x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; @@ -376,8 +379,6 @@ template constant float& eps, \ constant uint& axis_size, \ constant uint& w_stride, \ - threadgroup float* local_inv_mean [[threadgroup(0)]], \ - threadgroup float* local_sums [[threadgroup(1)]], \ uint gid [[thread_position_in_grid]], \ uint lid [[thread_position_in_threadgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ @@ -407,8 +408,6 @@ template constant float& eps, \ constant uint& axis_size, \ constant uint& w_stride, \ - threadgroup float* local_inv_mean [[threadgroup(0)]], \ - threadgroup float* local_sums [[threadgroup(1)]], \ uint gid [[thread_position_in_grid]], \ uint lid [[thread_position_in_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \ diff --git a/Source/Cmlx/mlx-generated/metal/rope.metal b/Source/Cmlx/mlx-generated/metal/rope.metal index cc9a4648..106d23f6 100644 --- a/Source/Cmlx/mlx-generated/metal/rope.metal +++ b/Source/Cmlx/mlx-generated/metal/rope.metal @@ -2,7 +2,6 @@ #include -#include "bf16.h" #include "utils.h" template void rope_single_impl( diff --git a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal index 8ec0a579..04db37d3 100644 --- a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal 
+++ b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal @@ -1,1469 +1,23 @@ -#include #include -#include "steel/defines.h" -#include "steel/gemm/transforms.h" -#include "steel/utils.h" +#include "sdpa_vector.h" +#include "utils.h" -#include "scaled_dot_product_attention_params.h" using namespace metal; -using namespace mlx::steel; - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short alignment = 1, - short n_reads = (BCOLS * BROWS) / (tgp_size), - short TCOLS = BCOLS / n_reads, - short TROWS = tgp_size / TCOLS> -struct BlockLoaderFA { - STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; - STEEL_CONST short vec_size = n_reads; - - // Leading dimension for src - const int src_ld; - const int tile_stride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - struct alignas(alignment * sizeof(T)) ReadVector { - uint8_t v[sizeof(T) * vec_size]; - }; - - /* Constructor */ - METAL_FUNC BlockLoaderFA( - const device T* src_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - *((threadgroup ReadVector*)(&dst[i * dst_ld])) = - *((const device ReadVector*)(&src[i * src_ld])); - } - } - - /* Load from device memory into threadgroup memory - with bound checking */ - METAL_FUNC void load_safe(short2 src_tile_dim) const { - src_tile_dim = src_tile_dim - short2(bj, bi); - - // Skip loading if thread has no valid reads - if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - return; - } - - // Use fast thread memory for bound checks - bool tmp_idx[vec_size]; - T tmp_val[vec_size]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - // Make sure tmp_idx only contains valid indices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); - } - - // Read valid indices into tmp_val - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; - } - - // Zero out uneeded values - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); - } - - // Copy values to threadgroup memory - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = tmp_val[j]; - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - src += tile_stride; - } - METAL_FUNC void next(short n) { - src += n * tile_stride; - } -}; - -template -struct LoopAlignment {}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMAFA { - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; - // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; - - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; - - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; - - STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - // Offsets within threadgroup - const short tm; - const short tn; - - short sm; - short sn; - - ushort sid; - ushort slid; - - short As_offset; - short Bs_offset; - - /* Constructor */ - METAL_FUNC BlockMMAFA( - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - slid = simd_lane_id; - sid = simd_group_id; - - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? 
((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; - - // Iterate over BK in blocks of 8 - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup A as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); - } - - simdgroup_barrier(mem_flags::mem_none); - - // Load elements from threadgroup B as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); - } - - simdgroup_barrier(mem_flags::mem_none); - - // Multiply and accumulate into result simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? (TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); - } - } - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; - } - } - - METAL_FUNC void rescale_output(const threadgroup float* Corrections) { - // Loop over all simdgroup tiles - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - short row = sm + tm + i * TM_stride; - float scale_value = Corrections[row]; - - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - // int offset = (i * TM_stride) * ldc + (j * TN_stride); - accum[0] *= scale_value; - accum[1] *= scale_value; - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* C, const int ldc) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + tn + sn; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; - - // Write out C - C[offset] = outs[0]; - C[offset + 1] = outs[1]; - } - } - } - - METAL_FUNC void store_result_to_tgp_memory( - threadgroup U* C, - const int ldc, - short2 dst_tile_dims) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = 
Epilogue::apply(accum[1]); - } - } - } - } - } - - METAL_FUNC void - store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldc + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - C[offset] = Epilogue::apply(accum[0]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - C[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; - } - } - } - - METAL_FUNC void store_result_safe( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); - } - } - } - } - } - - METAL_FUNC void clear_results() { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - results[i * TN + j] = simdgroup_matrix(0); - } - } - } -}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct FastAttentionKernel { - STEEL_CONST short tgp_padding = 16 / sizeof(T); - STEEL_CONST short float_padding = 16 / sizeof(float); - STEEL_CONST short tgp_mem_size_q = - transpose_q ? 
BK * (BM + tgp_padding) : BM * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_k = - transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_v = - transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); - STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); - - // maxes, rowsums, rescale - STEEL_CONST short tgp_mem_size_corrections = - 4 * (BM * sizeof(float) + float_padding); - - STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; - - STEEL_CONST short tgp_mem_size = share_kv_smem - ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections - : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + - tgp_mem_size_corrections + tgp_mem_size_v; - - STEEL_CONST short tgp_size = WM * WN * 32; - - static_assert(transpose_q == false, "Expected Q not transposed."); - static_assert(transpose_k == true, "Expected K transposed."); - static_assert(transpose_v == false, "Expected V not transposed."); - static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); - - using loader_q_t = BlockLoaderFA< - T, - transpose_q ? BK : BM, - transpose_q ? BM : BK, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - !transpose_q, - tgp_size>; - - using loader_k_t = BlockLoaderFA< - T, - transpose_k ? BN : BK, - transpose_k ? BK : BN, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - transpose_k, - tgp_size>; - - using loader_v_t = BlockLoaderFA< - T, - transpose_v ? BK : BN, - transpose_v ? BN : BK, - transpose_v ? BN + tgp_padding : BK + tgp_padding, - transpose_v, - tgp_size>; - - using mma_qk_t = BlockMMAFA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_q ? BM + tgp_padding : BK + tgp_padding, - transpose_k ? BK + tgp_padding : BN + tgp_padding, - AccumType, - Epilogue>; - - using mma_sv_t = BlockMMAFA< - T, - U, - BM, - BK, - BN, - WM, - WN, - false, - transpose_v, - BN + tgp_padding, - BK + tgp_padding, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_k_t& loader_b, - thread mma_qk_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - (void)tgp_bm; - - short2 tile_dims_B = transpose_k ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); - - // not valid for gemm_k_iterations > 1 (so, BK == d_k) - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - } - - static METAL_FUNC void initialize_corrections( - threadgroup float* C, - uint simd_lane_id, - uint simd_group_id) { - if (simd_group_id == 0) { - threadgroup float* maxes = C; - threadgroup float* sums = C + (BM + float_padding); - threadgroup float* o_rescale = sums + (BM + float_padding); - threadgroup float* output_rescale = o_rescale + (BM + float_padding); - - if (simd_lane_id < BM) { - maxes[simd_lane_id] = -INFINITY; // m_i - sums[simd_lane_id] = 0.f; // l_i - o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) - output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i - } - } - } - - static METAL_FUNC void rescale_ss( - threadgroup T* Ss, - threadgroup float* Corrections, - uint simd_group_id, - uint simd_lane_id, - short2 local_blocks, - float alpha) { - if (simd_group_id == 0) { - short row_offset = BM + float_padding; - threadgroup float* maxes = Corrections; - threadgroup float* sums = Corrections + row_offset; - threadgroup float* o_rescale = sums + row_offset; - threadgroup float* output_scales = o_rescale + row_offset; - - if (simd_lane_id < uint(local_blocks.y)) { - float m_i_old = maxes[simd_lane_id]; - float l_i_old = sums[simd_lane_id]; - - float m_i_new = m_i_old; - float l_i_new = l_i_old; - - short offset = simd_lane_id * (BN + tgp_padding); - - float m_ij = -INFINITY; - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - m_ij = max(m_ij, val); - } - - m_i_new = max(m_ij, m_i_new); - - float rowsum = 0.f; // lij - - for (short j = 0; j < local_blocks.x; j++) { - float val = alpha * float(Ss[offset + j]); - float P_i_j = exp(val - m_ij); - rowsum += P_i_j; - P_i_j = P_i_j * exp(m_ij - m_i_new); - Ss[offset + j] = T(P_i_j); - } - - l_i_new = - exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; - maxes[simd_lane_id] = m_i_new; - sums[simd_lane_id] = l_i_new; - float rescale = l_i_old * exp(m_i_old - m_i_new); - o_rescale[simd_lane_id] = rescale; - output_scales[simd_lane_id] = 1.0 / l_i_new; - } - } - } - - /* Main kernel function */ - static METAL_FUNC void run( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device U* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - threadgroup T* Qs [[threadgroup(0)]], - threadgroup T* Ks [[threadgroup(1)]], - threadgroup T* Ss [[threadgroup(2)]], - threadgroup T* Vs [[threadgroup(3)]], - threadgroup float* Corrections [[threadgroup(4)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in Q, O; and head in K, V. - const int c_row = tid_y * BM; - - Q += transpose_q ? 
c_row : c_row * params->ldq; - thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); - - short tgp_bm = min(BM, params->M - c_row); - short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - - loader_q.load_safe(tile_dims_Q); - - initialize_corrections(Corrections, simd_lane_id, simd_group_id); - - O += c_row * params->ldo; - - // Prepare threadgroup mma operation - thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); - thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); - thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); - thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); - - for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; - n_block++) { - short c_col = BN; - - // Prepare threadgroup loading operations - short gemm_k_iterations = params->gemm_k_iterations_aligned; - short tgp_bn_qk = min(BN, params->N - c_col * n_block); - threadgroup_barrier(mem_flags::mem_none); - - /////////////////////////////////////////////////////////////////////////////// - { // Loop over K - unaligned case - - if (tgp_bm == BM && tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } else if (tgp_bn_qk == BN) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else if (tgp_bm == BM) { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - - } else { - gemm_loop( - Qs, - Ks, - gemm_k_iterations, - loader_k, - mma_qk_op, - tgp_bm, - tgp_bn_qk); - } - } - - mma_qk_op.store_result_to_tgp_memory( - Ss, BN + tgp_padding, short2(BN, BM)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - rescale_ss( - Ss, - Corrections, - simd_group_id, - simd_lane_id, - short2(tgp_bn_qk, tgp_bm), - params->alpha); - - loader_v.load_safe(short2(BK, tgp_bn_qk)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); - mma_softmax_sv_op.rescale_output(o_scales); - - mma_softmax_sv_op.mma(Ss, Vs); - - threadgroup float* final_output_scales = - Corrections + 3 * (BM + float_padding); - - mma_softmax_sv_op.rescale_output(final_output_scales); - - loader_v.next(); - loader_k.next(BN); - - mma_qk_op.clear_results(); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); - } -}; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_q, - bool transpose_k, - bool transpose_v, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant MLXFastAttentionParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using attention_kernel = FastAttentionKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_q, - transpose_k, - transpose_v, - MN_aligned, - K_aligned>; - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant size_t* Q_bstrides = batch_strides; - const constant 
size_t* KV_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); - - Q += batch_offsets.x; - K += batch_offsets.y; - V += batch_offsets.y; - - } else { - Q += params->batch_stride_q * tid.z; - K += params->batch_stride_k * tid.z; - V += params->batch_stride_v * tid.z; - } - - // same shape as input - O += params->batch_stride_o * tid.z; - threadgroup T Qs[attention_kernel::tgp_mem_size_q]; - threadgroup T Ss[attention_kernel::tgp_mem_size_s]; - threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; - - if (attention_kernel::share_kv_smem) { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } else { - threadgroup T Ks[attention_kernel::tgp_mem_size_k]; - threadgroup T Vs[attention_kernel::tgp_mem_size_v]; - attention_kernel::run( - Q, - K, - V, - O, - params, - Qs, - Ks, - Ss, - Vs, - Corrections, - simd_lane_id, - simd_group_id, - tid, - lid); - } -} - -#define instantiate_fast_inference_self_attention_kernel( \ - itype, otype, bm, bn, bk, wm, wn) \ - template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ - "_itype_" #itype)]] [[kernel]] void \ - attention( \ - const device itype* Q [[buffer(0)]], \ - const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - device otype* O [[buffer(3)]], \ - const constant MLXFastAttentionParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); - -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 64, - 2, - 2); -instantiate_fast_inference_self_attention_kernel( - float, - float, - 16, - 16, - 128, - 2, - 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); -instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); - -template < - typename T, - typename T2, - typename T4, - uint16_t TILE_SIZE_CONST, - uint16_t NSIMDGROUPS> -[[kernel]] void fast_inference_sdpa_compute_partials_template( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - const device uint64_t& L [[buffer(3)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], - device float* O_partials [[buffer(5)]], - device float* p_lse [[buffer(6)]], - device float* p_maxes [[buffer(7)]], - threadgroup T* threadgroup_block [[threadgroup(0)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - constexpr const size_t DK = 128; - constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8; - constexpr const size_t THREADS_PER_SIMDGROUP = 32; - constexpr const uint iter_offset = NSIMDGROUPS * 4; - const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS; - uint kv_head_offset_factor = tid.x; - if (is_gqa) { - int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS; - kv_head_offset_factor = tid.x / q_kv_head_ratio; - } - constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / 
NSIMDGROUPS / 4; - constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = - TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS); - constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR; - constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * - SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * - NSIMDGROUPS; - - threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block; -#pragma clang loop unroll(full) - for (uint i = 0; i < 8; i++) { - smemFlush - [simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + - i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - // TODO: multiple query sequence length for speculative decoding - const uint tgroup_query_head_offset = - tid.x * DK + tid.z * (params.N_Q_HEADS * DK); - - const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L; - const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK; - const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK; - - const device T* baseK = - K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset; - const device T* baseQ = Q + tgroup_query_head_offset; - - device T4* simdgroupQueryData = (device T4*)baseQ; - - constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS; - float threadAccum[ACCUM_PER_GROUP]; - -#pragma clang loop unroll(full) - for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; - threadAccumIndex++) { - threadAccum[threadAccumIndex] = -INFINITY; - } - - uint KROW_ACCUM_INDEX = 0; - - const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST; - const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L; - const bool LAST_TILE_ALIGNED = - (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST)); - - T4 thread_data_x4; - T4 thread_data_y4; - if (!LAST_TILE || LAST_TILE_ALIGNED) { - thread_data_x4 = *(simdgroupQueryData + simd_lane_id); -#pragma clang loop unroll(full) - for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; - KROW += NSIMDGROUPS) { - const uint KROW_OFFSET = KROW * DK; - const device T* baseKRow = baseK + KROW_OFFSET; - device T4* keysData = (device T4*)baseKRow; - thread_data_y4 = *(keysData + simd_lane_id); - T kq_scalar = dot(thread_data_x4, thread_data_y4); - threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); - KROW_ACCUM_INDEX++; - } - } else { - thread_data_x4 = *(simdgroupQueryData + simd_lane_id); - const uint START_ROW = tid.y * TILE_SIZE_CONST; - const device T* baseKThisHead = - K + tgroup_k_batch_offset + tgroup_k_head_offset; - - for (size_t KROW = START_ROW + simd_group_id; KROW < L; - KROW += NSIMDGROUPS) { - const uint KROW_OFFSET = KROW * DK; - const device T* baseKRow = baseKThisHead + KROW_OFFSET; - device T4* keysData = (device T4*)baseKRow; - thread_data_y4 = *(keysData + simd_lane_id); - T kq_scalar = dot(thread_data_x4, thread_data_y4); - threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar); - KROW_ACCUM_INDEX++; - } - } - threadgroup float* smemP = (threadgroup float*)threadgroup_block; - -#pragma clang loop unroll(full) - for (size_t i = 0; i < P_VEC4; i++) { - thread_data_x4 = - T4(threadAccum[4 * i], - threadAccum[4 * i + 1], - threadAccum[4 * i + 2], - threadAccum[4 * i + 3]); - simdgroup_barrier(mem_flags::mem_none); - thread_data_y4 = simd_sum(thread_data_x4); - if (simd_lane_id == 0) { - const uint base_smem_p_offset = i * iter_offset + simd_group_id; - smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x); - smemP[base_smem_p_offset + 
NSIMDGROUPS * 1] = float(thread_data_y4.y); - smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z); - smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - float groupMax; - float lse = 0.f; - - constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32; - constexpr const size_t ACCUM_ARRAY_LENGTH = - TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1; - float4 pvals[ACCUM_ARRAY_LENGTH]; - -#pragma clang loop unroll(full) - for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; - accum_array_iter++) { - pvals[accum_array_iter] = float4(-INFINITY); - } - - if (TILE_SIZE_CONST == 64) { - threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block; - float2 vals = smemPtrFlt2[simd_lane_id]; - vals *= params.INV_ALPHA; - float maxval = max(vals.x, vals.y); - simdgroup_barrier(mem_flags::mem_none); - groupMax = simd_max(maxval); - - float2 expf_shifted = exp(vals - groupMax); - float sumExpLocal = expf_shifted.x + expf_shifted.y; - simdgroup_barrier(mem_flags::mem_none); - float tgroupExpSum = simd_sum(sumExpLocal); - - lse = log(tgroupExpSum); - float2 local_p_hat = expf_shifted / tgroupExpSum; - pvals[0].x = local_p_hat.x; - pvals[0].y = local_p_hat.y; - smemPtrFlt2[simd_lane_id] = float2(0.f); - } - constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64; - constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128; - - if (TILE_SIZE_LARGER_THAN_64) { - float maxval = -INFINITY; - threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block; -#pragma clang loop unroll(full) - for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { - float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP]; - vals *= params.INV_ALPHA; - pvals[i] = vals; - maxval = fmax3(vals.x, vals.y, maxval); - maxval = fmax3(vals.z, vals.w, maxval); - } - simdgroup_barrier(mem_flags::mem_none); - groupMax = simd_max(maxval); - - float sumExpLocal = 0.f; -#pragma clang loop unroll(full) - for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { - pvals[i] = exp(pvals[i] - groupMax); - sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w; - } - simdgroup_barrier(mem_flags::mem_none); - float tgroupExpSum = simd_sum(sumExpLocal); - lse = log(tgroupExpSum); -#pragma clang loop unroll(full) - for (int i = 0; i < TILE_SIZE_ITERS_128; i++) { - pvals[i] = pvals[i] / tgroupExpSum; - smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f); - } - } - - threadgroup T* smemV = (threadgroup T*)threadgroup_block; - - const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK; - const size_t v_head_offset = kv_head_offset_factor * L * DK; - - const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK; - const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset; - device T* baseV = (device T*)V + v_offset; - - threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV); - - if (!LAST_TILE || LAST_TILE_ALIGNED) { -#pragma clang loop unroll(full) - for (size_t col = 0; col < MATRIX_COLS; col++) { - uint matrix_load_loop_iter = 0; - constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8; - - for (size_t tile_start = simd_group_id; - tile_start < TILE_SIZE_CONST_DIV_8; - tile_start += NSIMDGROUPS) { - simdgroup_matrix tmp; - ulong simdgroup_matrix_offset = - matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + - simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - ulong2 matrixOrigin = - ulong2(col * 
SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset); - simdgroup_load(tmp, baseV, DK, matrixOrigin, true); - const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0); - const ulong elemsPerRowSmem = TILE_SIZE_CONST; - simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false); - matrix_load_loop_iter++; - }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (TILE_SIZE_CONST == 64) { - T2 local_p_hat = T2(pvals[0].x, pvals[0].y); - uint loop_iter = 0; - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - -#pragma clang loop unroll(full) - for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; - row += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); - threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; - T2 v_local = *(smemV2 + simd_lane_id); - - T val = dot(local_p_hat, v_local); - simdgroup_barrier(mem_flags::mem_none); - - T row_sum = simd_sum(val); - oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = - float(row_sum); - loop_iter++; - } - } - - if (TILE_SIZE_CONST > 64) { - constexpr const size_t TILE_SIZE_CONST_DIV_128 = - (TILE_SIZE_CONST + 1) / 128; - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - uint loop_iter = 0; - for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; - row += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row); - - T row_sum = 0.f; - for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) { - threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; - T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP); - T4 p_local = T4(pvals[i]); - T val = dot(p_local, v_local); - row_sum += val; - } - simdgroup_barrier(mem_flags::mem_none); - row_sum = simd_sum(row_sum); - oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = - float(row_sum); - loop_iter++; - } - } - } - } else { - const int32_t START_ROW = tid.y * TILE_SIZE_CONST; - const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1; - const device T* baseVThisHead = V + v_batch_offset + v_head_offset; - constexpr const int ROWS_PER_ITER = 8; -#pragma clang loop unroll(full) - for (size_t col = 0; col < MATRIX_COLS; col++) { - uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - int32_t tile_start; - for (tile_start = - START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; - tile_start < MAX_START_ROW; - tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) { - simdgroup_matrix tmp; - ulong2 matrixOrigin = - ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start); - simdgroup_load( - tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true); - const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0); - constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST; - simdgroup_store( - tmp, - smemV, - elemsPerRowSmem, - matrixOriginSmem, - /* transpose */ false); - smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR; - }; - - tile_start = - ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR); - - const int32_t INT_L = int32_t(L); - for (int row_index = tile_start + simd_group_id; row_index < INT_L; - row_index += NSIMDGROUPS) { - if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) { - const uint elems_per_row_gmem = DK; - const uint col_index_v_gmem = - col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id; - const uint row_index_v_gmem = row_index; - - const uint elems_per_row_smem = TILE_SIZE_CONST; - const uint col_index_v_smem = row_index % TILE_SIZE_CONST; - 
const uint row_index_v_smem = simd_lane_id; - - const uint scalar_offset_gmem = - row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem; - const uint scalar_offset_smem = - row_index_v_smem * elems_per_row_smem + col_index_v_smem; - T vdata = T(*(baseVThisHead + scalar_offset_gmem)); - smemV[scalar_offset_smem] = vdata; - smem_col_index += NSIMDGROUPS; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (TILE_SIZE_CONST == 64) { - T2 local_p_hat = T2(pvals[0].x, pvals[0].y); - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - for (size_t smem_row_index = simd_group_id; - smem_row_index < ROWS_PER_ITER; - smem_row_index += NSIMDGROUPS) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index); - threadgroup T2* smemV2 = (threadgroup T2*)smemV_row; - T2 v_local = *(smemV2 + simd_lane_id); - T val = dot(local_p_hat, v_local); - simdgroup_barrier(mem_flags::mem_none); - T row_sum = simd_sum(val); - oPartialSmem[smem_row_index] = float(row_sum); - } - } - - if (TILE_SIZE_CONST > 64) { - threadgroup float* oPartialSmem = - smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col; - uint loop_count = 0; - for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER; - row_index += NSIMDGROUPS) { - T row_sum = 0.f; - for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; - tile_iters++) { - threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index); - threadgroup T4* smemV2 = (threadgroup T4*)smemV_row; - T4 v_local = - *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP); - T4 p_local = T4(pvals[tile_iters]); - row_sum += dot(p_local, v_local); - } - simdgroup_barrier(mem_flags::mem_none); - row_sum = simd_sum(row_sum); - oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = - float(row_sum); - loop_count++; - } - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (simd_group_id == 0) { - threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial; - float4 vals = *(oPartialVec4 + simd_lane_id); - device float* oPartialGmem = - O_partials + tid.x * DK * params.KV_TILES + tid.y * DK; - device float4* oPartialGmemVec4 = (device float4*)oPartialGmem; - oPartialGmemVec4[simd_lane_id] = vals; - } - - if (simd_group_id == 0 && simd_lane_id == 0) { - const uint tileIndex = tid.y; - const uint gmem_partial_scalar_offset = - tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + - tileIndex; - p_lse[gmem_partial_scalar_offset] = lse; - p_maxes[gmem_partial_scalar_offset] = groupMax; - } -} - -#define instantiate_fast_inference_sdpa_to_partials_kernel( \ - itype, itype2, itype4, tile_size, nsimdgroups) \ - template [[host_name("fast_inference_sdpa_compute_partials_" #itype \ - "_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \ - fast_inference_sdpa_compute_partials_template< \ - itype, \ - itype2, \ - itype4, \ - tile_size, \ - nsimdgroups>( \ - const device itype* Q [[buffer(0)]], \ - const device itype* K [[buffer(1)]], \ - const device itype* V [[buffer(2)]], \ - const device uint64_t& L [[buffer(3)]], \ - const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \ - device float* O_partials [[buffer(5)]], \ - device float* p_lse [[buffer(6)]], \ - device float* p_maxes [[buffer(7)]], \ - threadgroup itype* threadgroup_block [[threadgroup(0)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]]); - // clang-format off -#define 
instantiate_fast_inference_sdpa_to_partials_shapes_helper( \ - itype, itype2, itype4, tile_size) \ - instantiate_fast_inference_sdpa_to_partials_kernel( \ - itype, itype2, itype4, tile_size, 4) \ - instantiate_fast_inference_sdpa_to_partials_kernel( \ - itype, itype2, itype4, tile_size, 8) // clang-format on - -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 64); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 128); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 256); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - float, - float2, - float4, - 512); - -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 64); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 128); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 256); -instantiate_fast_inference_sdpa_to_partials_shapes_helper( - half, - half2, - half4, - 512); - -template -void fast_inference_sdpa_reduce_tiles_template( - const device float* O_partials [[buffer(0)]], - const device float* p_lse [[buffer(1)]], - const device float* p_maxes [[buffer(2)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], - device T* O [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - constexpr const int DK = 128; - const ulong offset_rows = - tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES; - const device float* p_lse_row = p_lse + offset_rows; - const device float* p_rowmax_row = p_maxes + offset_rows; - // reserve some number of registers. this constitutes an assumption on max - // value of KV TILES. 
- constexpr const uint8_t reserve = 128; - float p_lse_regs[reserve]; - float p_rowmax_regs[reserve]; - float weights[reserve]; - - float true_max = -INFINITY; - for (size_t i = 0; i < params.KV_TILES; i++) { - p_lse_regs[i] = float(*(p_lse_row + i)); - p_rowmax_regs[i] = float(*(p_rowmax_row + i)); - true_max = fmax(p_rowmax_regs[i], true_max); - weights[i] = exp(p_lse_regs[i]); - } - - float denom = 0.f; - for (size_t i = 0; i < params.KV_TILES; i++) { - weights[i] *= exp(p_rowmax_regs[i] - true_max); - denom += weights[i]; - } - - const device float* O_partials_with_offset = O_partials + - tid.z * params.N_Q_HEADS * DK * params.KV_TILES + - tid.x * DK * params.KV_TILES; - - float o_value = 0.f; - for (size_t i = 0; i < params.KV_TILES; i++) { - float val = *(O_partials_with_offset + i * DK + lid.x); - o_value += val * weights[i] / denom; - } - device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK; - O_gmem[lid.x] = T(o_value); - return; -} - -kernel void fast_inference_sdpa_reduce_tiles_float( - const device float* O_partials [[buffer(0)]], - const device float* p_lse [[buffer(1)]], - const device float* p_maxes [[buffer(2)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], - device float* O [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - fast_inference_sdpa_reduce_tiles_template( - O_partials, p_lse, p_maxes, params, O, tid, lid); -} - -kernel void fast_inference_sdpa_reduce_tiles_half( - const device float* O_partials [[buffer(0)]], - const device float* p_lse [[buffer(1)]], - const device float* p_maxes [[buffer(2)]], - const device MLXScaledDotProductAttentionParams& params [[buffer(3)]], - device half* O [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - fast_inference_sdpa_reduce_tiles_template( - O_partials, p_lse, p_maxes, params, O, tid, lid); -} +// SDPA vector instantiations +#define instantiate_sdpa_vector(type, head_dim) \ + instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \ + instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \ + instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim) + +#define instantiate_sdpa_vector_heads(type) \ + instantiate_sdpa_vector(type, 64) \ + instantiate_sdpa_vector(type, 96) \ + instantiate_sdpa_vector(type, 128) + +instantiate_sdpa_vector_heads(float) +instantiate_sdpa_vector_heads(bfloat16_t) +instantiate_sdpa_vector_heads(float16_t) + // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention_params.h b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention_params.h deleted file mode 100644 index a77dad26..00000000 --- a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention_params.h +++ /dev/null @@ -1,42 +0,0 @@ -// -// scaled_dot_product_attention_params.h -// mlx - -#pragma once - -struct MLXFastAttentionParams { - const int M; - const int N; - const int K; - - const int ldq; // ldq == ldo - const int ldk; - const int ldv; - const int lds; - const int ldo; - - const int tiles_n; - const int tiles_m; - - const int batch_stride_q; - const int batch_stride_k; - const int batch_stride_v; - const int batch_stride_o; - - const int swizzle_log; - const int gemm_n_iterations_aligned; - const int gemm_k_iterations_aligned; - const int gemm_sv_m_block_iterations; - - const int batch_ndim; - const float alpha; -}; 
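For reference, the removed fast_inference_sdpa_reduce_tiles_template above merges the per-tile partial outputs using each tile's log-sum-exp and row max. A minimal scalar C++ sketch of that same correction, with hypothetical names and plain std containers rather than the kernel's actual buffers and launch configuration:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Merge per-tile attention partials for one output element.
    // Each tile t produced a partial output o_partial[t] (already normalized by
    // that tile's own softmax sum), its log-sum-exp lse[t], and its row max rowmax[t].
    float merge_tiles(const std::vector<float>& o_partial,
                      const std::vector<float>& lse,
                      const std::vector<float>& rowmax) {
      float true_max = -INFINITY;
      for (float m : rowmax) true_max = std::fmax(true_max, m);
      float num = 0.f, denom = 0.f;
      for (std::size_t t = 0; t < o_partial.size(); ++t) {
        // weight_t = l_t * exp(m_t - M), with l_t = exp(lse_t) and M the global max
        float w = std::exp(lse[t] + rowmax[t] - true_max);
        num += w * o_partial[t];
        denom += w;
      }
      return num / denom;
    }

This mirrors the weights[i] = exp(lse_i) * exp(m_i - true_max) accumulation in the deleted kernel; the normalization by denom keeps the combined softmax numerically stable regardless of tile order.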
- -struct MLXScaledDotProductAttentionParams { - // Associated dimensions & transposition information - const uint QUERY_SEQUENCE_LENGTH = 1; - const uint N_Q_HEADS = 32; - const uint N_KV_HEADS = 32; - const uint KV_TILES = 1; - const float INV_ALPHA = 0.08838834764831843f; -}; diff --git a/Source/Cmlx/mlx-generated/metal/scan.h b/Source/Cmlx/mlx-generated/metal/scan.h index 67b27ba8..cfa84c04 100644 --- a/Source/Cmlx/mlx-generated/metal/scan.h +++ b/Source/Cmlx/mlx-generated/metal/scan.h @@ -1,7 +1,38 @@ // Copyright © 2023-2024 Apple Inc. +#pragma once + +#define DEFINE_SIMD_SCAN() \ + template = true> \ + T simd_scan(T val) { \ + return simd_scan_impl(val); \ + } \ + \ + template = true> \ + T simd_scan(T val) { \ + for (int i = 1; i <= 16; i *= 2) { \ + val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \ + } \ + return val; \ + } + +#define DEFINE_SIMD_EXCLUSIVE_SCAN() \ + template = true> \ + T simd_exclusive_scan(T val) { \ + return simd_exclusive_scan_impl(val); \ + } \ + \ + template = true> \ + T simd_exclusive_scan(T val) { \ + val = simd_scan(val); \ + return simd_shuffle_and_fill_up(val, init, 1); \ + } + template struct CumSum { + DEFINE_SIMD_SCAN() + DEFINE_SIMD_EXCLUSIVE_SCAN() + static constexpr constant U init = static_cast(0); template @@ -9,17 +40,20 @@ struct CumSum { return a + b; } - U simd_scan(U x) { + U simd_scan_impl(U x) { return simd_prefix_inclusive_sum(x); } - U simd_exclusive_scan(U x) { + U simd_exclusive_scan_impl(U x) { return simd_prefix_exclusive_sum(x); } }; template struct CumProd { + DEFINE_SIMD_SCAN() + DEFINE_SIMD_EXCLUSIVE_SCAN() + static constexpr constant U init = static_cast(1.0f); template @@ -27,11 +61,11 @@ struct CumProd { return a * b; } - U simd_scan(U x) { + U simd_scan_impl(U x) { return simd_prefix_inclusive_product(x); } - U simd_exclusive_scan(U x) { + U simd_exclusive_scan_impl(U x) { return simd_prefix_exclusive_product(x); } }; @@ -47,7 +81,7 @@ struct CumProd { bool simd_scan(bool x) { for (int i = 1; i <= 16; i *= 2) { - bool other = simd_shuffle_up(x, i); + bool other = simd_shuffle_and_fill_up(x, init, i); x &= other; } return x; @@ -70,7 +104,7 @@ struct CumMax { U simd_scan(U x) { for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_up(x, i); + U other = simd_shuffle_and_fill_up(x, init, i); x = (x >= other) ? x : other; } return x; @@ -93,7 +127,7 @@ struct CumMin { U simd_scan(U x) { for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_up(x, i); + U other = simd_shuffle_and_fill_up(x, init, i); x = (x <= other) ? 
x : other; } return x; @@ -178,20 +212,22 @@ template < const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& axis_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; Op op; // Position the pointers - in += (gid / lsize) * axis_size; - out += (gid / lsize) * axis_size; + size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; + in += offset; + out += offset; // Compute the number of simd_groups - uint simd_groups = lsize / simd_size; + uint simd_groups = lsize.x / simd_size; // Allocate memory U prefix = Op::init; @@ -210,9 +246,9 @@ template < // value // Write block - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { // Compute the block offset - uint offset = r * lsize * N_READS + lid * N_READS; + uint offset = r * lsize.x * N_READS + lid.x * N_READS; // Read the values if (reverse) { @@ -275,7 +311,7 @@ template < values, out + axis_size - offset - N_READS, offset, axis_size); } } else { - if (lid == 0 && offset == 0) { + if (lid.x == 0 && offset == 0) { out[axis_size - 1] = Op::init; } if ((offset + N_READS + 1) < axis_size) { @@ -298,7 +334,7 @@ template < values, out + offset, offset, axis_size); } } else { - if (lid == 0 && offset == 0) { + if (lid.x == 0 && offset == 0) { out[0] = Op::init; } if ((offset + N_READS + 1) < axis_size) { @@ -332,86 +368,98 @@ template < device U* out [[buffer(1)]], const constant size_t& axis_size [[buffer(2)]], const constant size_t& stride [[buffer(3)]], - uint2 gid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]]) { + const constant size_t& stride_blocks [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + constexpr int BM = 32; + constexpr int BN = 32; + constexpr int BN_pad = 32 + 16 / sizeof(U); + constexpr int n_simds = BN / N_READS; + constexpr int n_scans = BN / n_simds; Op op; - // Allocate memory - threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32]; - U values[N_READS]; - U prefix[N_READS]; - for (int i = 0; i < N_READS; i++) { + threadgroup U read_buffer[BM * BN_pad]; + U values[n_scans]; + U prefix[n_scans]; + for (int i = 0; i < n_scans; i++) { prefix[i] = Op::init; } // Compute offsets - int offset = gid.y * axis_size * stride; - int global_index_x = gid.x * lsize.y * N_READS; - - for (uint j = 0; j < axis_size; j += simd_size) { + size_t full_gid = gid.y + gsize.y * size_t(gid.z); + size_t offset = full_gid / stride_blocks * axis_size * stride; + size_t global_index_x = full_gid % stride_blocks * BN; + uint read_offset_y = (lid.x * N_READS) / BN; + uint read_offset_x = (lid.x * N_READS) % BN; + uint scan_offset_y = simd_lane_id; + uint scan_offset_x = simd_group_id * n_scans; + + uint stride_limit = 
stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + threadgroup U* read_into = + read_buffer + read_offset_y * BN_pad + read_offset_x; + threadgroup U* read_from = + read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { // Calculate the indices for the current thread - uint index_y = j + lid.y; + uint index_y = j + read_offset_y; uint check_index_y = index_y; - uint index_x = global_index_x + lid.x * N_READS; if (reverse) { index_y = axis_size - 1 - index_y; } // Read in SM - if (check_index_y < axis_size && (index_x + N_READS) < stride) { + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - in[offset + index_y * stride + index_x + i]; + read_into[i] = in[index_y * stride + i]; } } else { for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (index_x + i) < stride) { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - in[offset + index_y * stride + index_x + i]; + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = in[index_y * stride + i]; } else { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - Op::init; + read_into[i] = Op::init; } } } threadgroup_barrier(mem_flags::mem_threadgroup); // Read strided into registers - for (int i = 0; i < N_READS; i++) { - values[i] = - read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i]; + for (int i = 0; i < n_scans; i++) { + values[i] = read_from[i]; } - // Do we need the following barrier? Shouldn't all simd threads execute - // simultaneously? simdgroup_barrier(mem_flags::mem_threadgroup); // Perform the scan - for (int i = 0; i < N_READS; i++) { + for (int i = 0; i < n_scans; i++) { values[i] = op.simd_scan(values[i]); values[i] = op(values[i], prefix[i]); prefix[i] = simd_shuffle(values[i], simd_size - 1); } // Write to SM - for (int i = 0; i < N_READS; i++) { - read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = - values[i]; + for (int i = 0; i < n_scans; i++) { + read_from[i] = values[i]; } threadgroup_barrier(mem_flags::mem_threadgroup); // Write to device memory if (!inclusive) { if (check_index_y == 0) { - if ((index_x + N_READS) < stride) { + if ((read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { - out[offset + index_y * stride + index_x + i] = Op::init; + out[index_y * stride + i] = Op::init; } } else { for (int i = 0; i < N_READS; i++) { - if ((index_x + i) < stride) { - out[offset + index_y * stride + index_x + i] = Op::init; + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = Op::init; } } } @@ -424,16 +472,14 @@ template < check_index_y += 1; } } - if (check_index_y < axis_size && (index_x + N_READS) < stride) { + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { - out[offset + index_y * stride + index_x + i] = - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + out[index_y * stride + i] = read_into[i]; } } else { for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (index_x + i) < stride) { - out[offset + index_y * stride + index_x + i] = - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; } } } diff 
--git a/Source/Cmlx/mlx-generated/metal/scatter.h b/Source/Cmlx/mlx-generated/metal/scatter.h index 6c9f84e8..fe10b705 100644 --- a/Source/Cmlx/mlx-generated/metal/scatter.h +++ b/Source/Cmlx/mlx-generated/metal/scatter.h @@ -4,73 +4,57 @@ #include "indexing.h" -template -METAL_FUNC void scatter_1d_index_impl( - const device T* updates [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* out_shape [[buffer(3)]], - const constant size_t* out_strides [[buffer(4)]], - const constant size_t& out_ndim [[buffer(5)]], - const constant int* upd_shape [[buffer(6)]], - const constant size_t& upd_ndim [[buffer(7)]], - const constant size_t& upd_size [[buffer(8)]], - const thread array& idx_buffers, - uint2 gid [[thread_position_in_grid]]) { - Op op; - - size_t out_idx = 0; - for (int i = 0; i < NIDX; i++) { - auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); - out_idx += idx_val * out_strides[i]; - } - - if (upd_ndim > 1) { - auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim); - out_idx += out_offset; - } else { - out_idx += gid.x; - } - - op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx); -} - -template +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + bool UPD_ROW_CONTIG, + int NWORK, + typename LocT> METAL_FUNC void scatter_impl( - const device T* updates [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* upd_shape [[buffer(3)]], - const constant size_t* upd_strides [[buffer(4)]], - const constant size_t& upd_ndim [[buffer(5)]], - const constant size_t& upd_size [[buffer(6)]], - const constant int* out_shape [[buffer(7)]], - const constant size_t* out_strides [[buffer(8)]], - const constant size_t& out_ndim [[buffer(9)]], - const constant int* axes [[buffer(10)]], + const device T* updates, + device mlx_atomic* out, + const constant int* upd_shape, + const constant size_t* upd_strides, + const constant size_t& upd_ndim, + const constant size_t& upd_size, + const constant int* out_shape, + const constant size_t* out_strides, + const constant size_t& out_ndim, + const constant int* axes, + const constant size_t& idx_size, const thread Indices& indices, uint2 gid [[thread_position_in_grid]]) { Op op; - auto ind_idx = gid.y; - auto ind_offset = gid.x; - - size_t out_idx = 0; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; - } + auto ind_idx = gid.y * NWORK; + LocT out_offset = 0; if (upd_size > 1) { - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - out_idx += out_offset; + out_offset = elem_to_loc( + gid.x, upd_shape + indices.ndim, out_strides, out_ndim); } - auto upd_idx = - elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx); + for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { + LocT out_idx = out_offset; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = indices.row_contiguous[i] + ? 
ind_idx + : elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += + static_cast(idx_val) * static_cast(out_strides[ax]); + } + auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; + if constexpr (!UPD_ROW_CONTIG) { + upd_idx = + elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); + } + op.atomic_update(out, updates[upd_idx], out_idx); + } } diff --git a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h new file mode 100644 index 00000000..8b6af638 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h @@ -0,0 +1,292 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& v_stride, + const constant float& scale, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D / BD; + constexpr int stride = BN * D; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + simd_lid * elem_per_thread; + keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; + out += head_idx * D + simd_gid * elem_per_thread; + + // Read the query and 0 the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (int i = simd_gid; i < N; i += BN) { + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } + + // Move the pointers to the next kv + keys += stride; + values += stride; + } + + // Each thread has a partial part of the output so we need to combine them. 
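The key/value loop just above maintains a running maximum and a rescaled running sum, so the softmax over all keys is computed in one streaming pass without materializing the scores (online softmax). A minimal scalar C++ sketch of one update step, with hypothetical names and a single output element instead of the kernel's per-thread vector:

    #include <cmath>

    struct OnlineSoftmaxAccum {
      float max_score = -INFINITY;
      float sum_exp = 0.f;
      float out = 0.f;  // one output element; the kernel keeps D/BD of these per thread

      // Fold in one key's score and the matching value element.
      void update(float score, float value) {
        float new_max = std::fmax(max_score, score);
        float factor = std::exp(max_score - new_max);  // rescale previously accumulated state
        float exp_score = std::exp(score - new_max);
        sum_exp = sum_exp * factor + exp_score;
        out = out * factor + exp_score * value;
        max_score = new_max;
      }

      float finalize() const { return out / sum_exp; }
    };

Each simdgroup in the kernel runs this recurrence over its slice of keys; the combine step that follows then reconciles the per-simdgroup (max, sum, output) triples in threadgroup memory.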
+ + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +template +[[kernel]] void sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device float* out [[buffer(3)]], + device float* sums [[buffer(4)]], + device float* maxs [[buffer(5)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant size_t& v_stride, + const constant float& scale, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 8; + constexpr int BD = 32; + constexpr int elem_per_thread = D / BD; + constexpr int stride = BN * D; + constexpr int blocks = 32; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int block_idx = tid.z; + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + simd_lid * elem_per_thread; + keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D + + simd_lid * elem_per_thread; + values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D + + simd_lid * elem_per_thread; + out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread; + sums += head_idx * blocks + block_idx; + maxs += head_idx * blocks + block_idx; + + // Read the query and 0 the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -1e9; + U sum_exp_score = 0; + + // For each key + for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } + + // Move the pointers to the next kv + keys += blocks * stride; + values += blocks * stride; + } + + // Each thread has a partial part of the output so we need to combine them. 
+ + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BN + simd_gid] = + o[i] * fast::exp(max_scores[simd_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // And write the output + if (simd_gid == 0) { + U output = outputs[simd_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[simd_lid * BN + j]; + } + out[i] = static_cast(output); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } +} + +template +[[kernel]] void sdpa_vector_2pass_2( + const device float* partials [[buffer(0)]], + const device float* sums [[buffer(1)]], + const device float* maxs [[buffer(2)]], + device T* out [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D / BD; + constexpr int blocks = 32; + + typedef float U; + + thread U o[elem_per_thread]; + threadgroup U outputs[BN * BD]; + + // Adjust positions + const int head_idx = tid.y; + partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread; + sums += head_idx * blocks; + maxs += head_idx * blocks; + out += head_idx * D + simd_gid * elem_per_thread; + + // First everybody reads the max and sum_exp + U max_score = maxs[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + U sum_exp_score = simd_sum(sums[simd_lid] * factor); + + // Now read the block into registers and then use shared memory to transpose + // it + for (int i = 0; i < elem_per_thread; i++) { + o[i] = partials[i]; + } + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h b/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h new file mode 100644 index 00000000..8851df68 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h @@ -0,0 +1,296 @@ +// Copyright © 2024 Apple Inc. 
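The two-pass kernels above split decoding attention over 32 key/value blocks: sdpa_vector_2pass_1 emits one unnormalized partial output plus a (sum_exp, max) pair per block, and sdpa_vector_2pass_2 merges them with the same stable-softmax correction seen elsewhere in this diff. A minimal C++ sketch of that second-pass merge for one output element, with hypothetical names (the real kernel does this with simd reductions and a threadgroup transpose):

    #include <cmath>

    // partial[b]: unnormalized per-block output (exp(score - maxs[b]) already folded in)
    // sums[b]:    per-block sum of exponentials, relative to maxs[b]
    // maxs[b]:    per-block running max
    float merge_blocks(const float* partial, const float* sums, const float* maxs, int blocks) {
      float new_max = -INFINITY;
      for (int b = 0; b < blocks; ++b) new_max = std::fmax(new_max, maxs[b]);
      float num = 0.f, denom = 0.f;
      for (int b = 0; b < blocks; ++b) {
        float factor = std::exp(maxs[b] - new_max);
        num += partial[b] * factor;
        denom += sums[b] * factor;
      }
      return num / denom;
    }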
+ +#pragma once + +#include "../../steel/attn/loader.h" +#include "../../steel/attn/mma.h" +#include "../../steel/attn/params.h" +#include "../../steel/attn/transforms.h" +#include "../../steel/gemm/params.h" +#include "../../steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h new file mode 100644 index 00000000..c5c69c30 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h @@ -0,0 +1,349 @@ +// Copyright © 2024 Apple Inc. 
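// The attention kernel below is a flash-attention style single pass over the
// key/value sequence: a BQ x BD tile of Q is kept in threadgroup memory (with
// the scale folded in via TransformScale), and BK-row blocks of K and V are
// streamed through. For each block it forms S = Q @ K^T with simdgroup MMA
// tiles, updates an online softmax per row (running max, exponentiation,
// rescale of the partial output by exp(old_max - new_max), running sum),
// accumulates O += S @ V with the already-exponentiated S, and only divides by
// the row sums after the last block. align_Q and align_K enable bound-checked
// loads and -inf masking for the ragged final blocks.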
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Seqeunce + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Seqeunce + + // Prepare threadgroup memory + constexpr short padQ = 0; // 16 / sizeof(T); + constexpr short padK = 0; // 16 / sizeof(T); + constexpr short padV = 0; // 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + threadgroup T Qs[BQ * (BD + padQ)]; + threadgroup T Ks[(BK + padK) * BD]; + threadgroup T Vs[BK * (BD + padV)]; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short 
reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + TransformScale ts(static_cast(params->scale)); + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks apply scale + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ)); + } else { + loader_q.load_unsafe(); + } + loader_q.apply_inplace_op(ts); + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::min; + } + + // Loop over KV seq length + for (int kb = 0; kb < params->NK; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + } else { + loader_k.load_unsafe(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do S = Q @ K.T + Stile.clear(); + + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + &Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Mask out of length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + const short lim = params->kL - params->NK_aligned * BK; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= lim) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load V 
blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + } else { + loader_v.load_unsafe(); + } + + // Do softmax + + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Stile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp(max_score[i] - new_max[i]); + } + + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } + + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); + + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } + + // Update O + Otile.template row_bin_op(factor); + + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); + Vtile.template load(&Vs[Vs_offset]); + + simdgroup_barrier(mem_flags::mem_none); + + // Do O = S @ V + tile_matmad(Otile, Stile, Vtile, Otile); + + // Prepare for next iteration + loader_k.next(); + loader_v.next(); + } + + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = + short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); + } else { + Otile.template store(O, params->O_strides[2]); + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h new file mode 100644 index 00000000..75d695e6 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h @@ -0,0 +1,264 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "../../steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; + +template < + typename T, + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + /* Constructor */ + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h b/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h new file mode 100644 index 00000000..621a0e1e --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h @@ -0,0 +1,726 @@ +// Copyright © 2024 Apple Inc. 
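// MMA helpers for the attention kernels. BaseMMAFrag wraps one 8x8 simdgroup
// matrix fragment, of which each of the 32 threads owns a 1x2 slice; it
// provides load/store (plus bound-checked *_safe variants), an mma() that
// forwards to simdgroup_multiply_accumulate, and per-row reduce / bin_op hooks
// used by the softmax. MMATile composes a kTileRows x kTileCols grid of such
// fragments, and BlockMMA drives the loop over BK: load A and B tiles from
// threadgroup memory, tile_matmad them into the C tile, then apply an epilogue
// when storing the result to device memory.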
+ +#pragma once + +#include +#include +#include + +#include "../../steel/attn/transforms.h" +#include "../../steel/defines.h" +#include "../../steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + 
(off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_reduce( + 
frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? 
(N - 1 - n) : n; + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + 
Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short 
kelems = decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/params.h b/Source/Cmlx/mlx-generated/metal/steel/attn/params.h new file mode 100644 index 00000000..a9d7c7b4 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/params.h @@ -0,0 +1,36 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Attn param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + size_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + size_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + size_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +} // namespace steel +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h b/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h new file mode 100644 index 00000000..3d8ca054 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h @@ -0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. 
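// Epilogue transforms applied when accumulator tiles are written back:
// TransformNone casts to the output type, TransformAdd adds the existing C
// value, and TransformAxpby computes alpha * x + beta * c. AccumHelper fixes
// the accumulation type to float, and BlockSwizzle remaps the 2-D threadgroup
// id according to swizzle_log, the usual tile-swizzling trick for keeping
// concurrently scheduled threadgroups on nearby output tiles.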
+ +#pragma once + +#include "../../steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h index d52b2654..9261b871 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h @@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general( // Store results to device memory { // Adjust for simdgroup and thread locatio - int offset_m = c_row + mma_op.sm + mma_op.tm; - int offset_n = c_col + mma_op.sn + mma_op.tn; + int offset_m = c_row + mma_op.sm; + int offset_n = c_col + mma_op.sn; C += offset_n; if (offset_n >= gemm_params->N) @@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general( STEEL_PRAGMA_UNROLL for (int j = 0; j < mma_t::TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = - mma_op.results[i * mma_t::TN + j].thread_elements(); + thread const auto& accum = mma_op.Ctile.frag_at(i, j); int offset = offset_cm + (j * mma_t::TN_stride); - // Apply epilogue and output C - if (j * mma_t::TN_stride < diff) { - C[offset] = Epilogue::apply(accum[0]); - } + constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; - if (j * mma_t::TN_stride + 1 < diff) { - C[offset + 1] = Epilogue::apply(accum[1]); + // Apply epilogue and output C + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * mma_t::TN_stride + k) < diff) { + C[offset + k] = Epilogue::apply(accum[k]); + } } } } diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h index 948bbc61..3852fc87 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h @@ -8,6 +8,7 @@ #include "../../steel/defines.h" #include "../../steel/gemm/transforms.h" +#include "../../steel/utils/integral_constant.h" using namespace metal; @@ -18,6 +19,347 @@ using namespace metal; namespace mlx { namespace steel { +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + 
"Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; + +template < + 
typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + 
STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? (N - 1 - n) : n; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + template < typename T, typename U, @@ -33,39 +375,38 @@ template < typename AccumType = float, typename Epilogue = TransformNone> struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = 8 * WM; + STEEL_CONST short TM_stride = kFragSize * WM; // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = 8 * WN; + STEEL_CONST short TN_stride = kFragSize * WN; // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; + STEEL_CONST short TM = BM / (kFragSize * WM); // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; + STEEL_CONST short TN = BN / (kFragSize * WN); - // Strides of A, B along reduction axis - STEEL_CONST short simd_stride_a = { - transpose_a ? TM_stride : TM_stride * lda_tgp}; - STEEL_CONST short simd_stride_b = { - transpose_b ? TN_stride * ldb_tgp : TN_stride}; + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K - // Jump between elements - STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; - STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N - STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; - STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; + MMATile Atile; + MMATile Btile; + MMATile Ctile; // Offsets within threadgroup - const short tm; - const short tn; - short sm; short sn; @@ -75,18 +416,21 @@ struct BlockMMA { /* Constructor */ METAL_FUNC BlockMMA( ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + ushort simd_lane_id [[thread_index_in_simdgroup]]) { // Determine thread position in simdgroup matrix - short qid = simd_lane_id / 4; - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; // Determine thread and simdgroup offset - As_offset = - transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); - Bs_offset = - transpose_b ? 
((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; } /* (BM, BK) X (BK, BN) multiply accumulate function */ @@ -95,47 +439,20 @@ struct BlockMMA { As += As_offset; Bs += Bs_offset; - // Iterate over BK in blocks of 8 + // Iterate over BK in blocks of kFragSize STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += 8) { + for (short kk = 0; kk < BK; kk += kFragSize) { simdgroup_barrier(mem_flags::mem_none); - // Load elements from threadgroup A as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = - static_cast(As[i * simd_stride_a + 0]); - Asimd[i].thread_elements()[1] = - static_cast(As[i * simd_stride_a + jump_a]); - } + Atile.template load(As); simdgroup_barrier(mem_flags::mem_none); - // Load elements from threadgroup B as simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = - static_cast(Bs[j * simd_stride_b + 0]); - Bsimd[j].thread_elements()[1] = - static_cast(Bs[j * simd_stride_b + jump_b]); - } + Btile.template load(Bs); simdgroup_barrier(mem_flags::mem_none); - // Multiply and accumulate into result simdgroup matrices - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - short j_serp = (i % 2) ? (TN - 1 - j) : j; - - simdgroup_multiply_accumulate( - results[i * TN + j_serp], - Asimd[i], - Bsimd[j_serp], - results[i * TN + j_serp]); - } - } + tile_matmad(Ctile, Atile, Btile, Ctile); // Progress to next simdgroup tile As += tile_stride_a; @@ -144,58 +461,35 @@ struct BlockMMA { } /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* D, const int ldd) const { - // Adjust for simdgroup and thread location - D += (sm + tm) * ldd + tn + sn; - - // Loop over all simdgroup tiles + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } - // Apply epilogue - U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + // Adjust for simdgroup and thread location + D += sm * ldd + sn; - // Write out D - D[offset] = outs[0]; - D[offset + 1] = outs[1]; - } - } + Ctile.template store(D, ldd); } METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + // Adjust for simdgroup and thread location - D += (sm + tm) * ldd + (tn + sn); - dst_tile_dims -= short2(tn + sn, sm + tm); + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& 
accum = results[i * TN + j].thread_elements(); - int offset = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset] = Epilogue::apply(accum[0]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset + 1] = Epilogue::apply(accum[1]); - } - } - } - } + Ctile.template store_safe(D, ldd, dst_tile_dims); } /* Apply epilogue */ @@ -203,16 +497,8 @@ struct BlockMMA { METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); - - // Apply epilogue - accum[0] = epilogue_op.apply(accum[0]); - accum[1] = epilogue_op.apply(accum[1]); - } + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); } } @@ -224,7 +510,7 @@ struct BlockMMA { const int fdc, thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; + C += (sm)*ldc + (sn)*fdc; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -232,12 +518,14 @@ struct BlockMMA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; // Apply epilogue - accum[0] = epilogue_op.apply(accum[0], C[offset_c]); - accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -251,8 +539,8 @@ struct BlockMMA { short2 dst_tile_dims, thread const BinaryEpilogue& epilogue_op) { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; @@ -263,22 +551,26 @@ struct BlockMMA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread auto& accum = results[i * TN + j].thread_elements(); + thread auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + // Read C - U c_elems[2] = {0}; + U c_elems[kelems] = {0}; - if ((j * TN_stride + 1) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; - c_elems[1] = C[offset_c + fdc]; - } else if ((j * TN_stride) < dst_tile_dims.x) { - c_elems[0] = C[offset_c]; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } } // Apply epilogue - accum[0] = epilogue_op.apply(accum[0], c_elems[0]); - accum[1] = epilogue_op.apply(accum[1], c_elems[1]); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } } } } @@ -292,8 +584,10 @@ struct BlockMMA { const int fdc, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr 
short kelems = decltype(Ctile)::kElemsPerFrag; // Loop over all simdgroup tiles STEEL_PRAGMA_UNROLL @@ -301,18 +595,15 @@ struct BlockMMA { STEEL_PRAGMA_UNROLL for (short j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); // Apply epilogue - U outs[2] = { - epilogue_op.apply(accum[0], C[offset_c]), - epilogue_op.apply(accum[1], C[offset_c + fdc])}; - - // Write out D - D[offset_d] = outs[0]; - D[offset_d + 1] = outs[1]; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } @@ -326,30 +617,32 @@ struct BlockMMA { short2 dst_tile_dims, thread const Epilogue& epilogue_op) const { // Adjust for simdgroup and thread location - C += (sm + tm) * ldc + (tn + sn) * fdc; - D += (sm + tm) * ldd + tn + sn; - dst_tile_dims -= short2(tn + sn, sm + tm); + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + STEEL_PRAGMA_UNROLL for (int i = 0; i < TM; i++) { if (i * TM_stride < dst_tile_dims.y) { STEEL_PRAGMA_UNROLL for (int j = 0; j < TN; j++) { // Get accumulated result and associated offset in C - thread const auto& accum = results[i * TN + j].thread_elements(); + thread const auto& accum = Ctile.frag_at(i, j); int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - // Apply epilogue and output C - if (j * TN_stride < dst_tile_dims.x) { - D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); - } - - if (j * TN_stride + 1 < dst_tile_dims.x) { - D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } } } } diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h b/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h new file mode 100644 index 00000000..0d55c399 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h @@ -0,0 +1,96 @@ +// Copyright © 2024 Apple Inc. 
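A note on the BlockMMA changes above: the hand-rolled Asimd/Bsimd/results simdgroup_matrix arrays become MMATile members, and the epilogues now run over the flat elems() view of the accumulator tile, so the code no longer hard-codes two output elements per 8x8 fragment. A minimal host-side C++ sketch of the size arithmetic and of the shape of the new store_result/apply_epilogue loops; BM = BN = 32 and WM = WN = 2 mirror the qmm_t_impl defaults later in this diff, kElemsPerFrag = 2 matches the pair of thread_elements() the old code wrote out, and the epilogue lambda is only a stand-in for Epilogue::apply:

#include <array>
#include <cstdio>

int main() {
  // Illustrative configuration (assumed here; the kernel takes these from
  // its template parameters).
  constexpr int BM = 32, BN = 32;   // threadgroup tile
  constexpr int WM = 2, WN = 2;     // simdgroups along M and N
  constexpr int kFragSize = 8;      // 8x8 simdgroup matrix fragments
  constexpr int kElemsPerFrag = 2;  // accumulator elements per thread per fragment

  // Same arithmetic as the new BlockMMA members.
  constexpr int TM = BM / (kFragSize * WM);               // fragment rows per simdgroup -> 2
  constexpr int TN = BN / (kFragSize * WN);               // fragment cols per simdgroup -> 2
  constexpr int kElemsPerTile = TM * TN * kElemsPerFrag;  // -> 8

  std::array<float, kElemsPerTile> ctile{};          // stand-in for Ctile.elems()
  auto epilogue = [](float x) { return 0.5f * x; };  // stand-in for Epilogue::apply

  // Shape of the new store_result / apply_epilogue loops: one flat pass over
  // every accumulated element, independent of kElemsPerFrag.
  for (int i = 0; i < kElemsPerTile; ++i) {
    ctile[i] = epilogue(ctile[i]);
  }

  std::printf("TM=%d TN=%d elems per tile=%d\n", TM, TN, kElemsPerTile);
  return 0;
}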
+ +#pragma once + +#include +#include "../../steel/utils/type_traits.h" + +#pragma METAL internals : enable + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +struct integral_constant { + static constexpr constant T value = v; + using value_type = T; + using type = integral_constant; + + METAL_FUNC constexpr operator value_type() const noexcept { + return value; + } + + // METAL_FUNC constexpr value_type operator()() const noexcept { + // return value; + // } +}; + +template +using bool_constant = integral_constant; +using true_type = bool_constant; +using false_type = bool_constant; + +template +struct is_integral : bool_constant::value> {}; + +template +struct is_integral> + : bool_constant::value> {}; + +template +constexpr constant bool is_integral_v = is_integral::value; + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); +integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... us) { + return x + sum(us...); +} + +} // namespace steel +} // namespace mlx + +#pragma METAL internals : disable \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h b/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h new file mode 100644 index 00000000..f004dc83 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h @@ -0,0 +1,55 @@ +// Copyright © 2024 Apple Inc. 
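The integral_constant header just above is what lets MMATile pass strides such as Int<kFragRows>{} into MMAFrag_t::load/store as zero-cost compile-time values: the generated operator overloads keep arithmetic between two constants a constant type. A simplified sketch of the same pattern in plain C++17, with METAL_FUNC, the constant address space, and all but one of the generated operators omitted; it shows the idea, not the header itself:

#include <cstdio>

// Minimal stand-in for mlx::steel::integral_constant.
template <typename T, T v>
struct integral_constant {
  static constexpr T value = v;
  using value_type = T;
  constexpr operator value_type() const noexcept { return value; }
};

template <int N>
using Int = integral_constant<int, N>;

// One of the operators the integral_const_binop macro expands to:
// combining two constants yields another constant.
template <typename T, T tv, T uv>
constexpr auto operator+(integral_constant<T, tv>, integral_constant<T, uv>) {
  return integral_constant<T, tv + uv>{};
}

// The variadic reduction from the bottom of the header.
template <typename T>
constexpr T sum(T x) {
  return x;
}
template <typename T, typename... Us>
constexpr auto sum(T x, Us... us) {
  return x + sum(us...);
}

int main() {
  constexpr auto total = sum(Int<1>{}, Int<2>{}, Int<4>{});
  static_assert(total == 7, "still a compile-time constant");
  std::printf("%d\n", int(total));
  return 0;
}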
+ +#pragma once + +#include + +#pragma METAL internals : enable + +namespace metal { + +template +struct is_empty : metal::bool_constant<__is_empty(T)> {}; + +#ifdef __cpp_variable_templates +template +constexpr constant bool is_empty_v = is_empty::value; +#endif + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; + +template +struct is_static : metal::bool_constant>::value> {}; + +template +struct pointer_element {}; + +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; + +template +using pointer_element_t = typename pointer_element>::type; + +} // namespace metal + +#pragma METAL internals : disable \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/ternary.h b/Source/Cmlx/mlx-generated/metal/ternary.h index 2bd1242c..e19ea23d 100644 --- a/Source/Cmlx/mlx-generated/metal/ternary.h +++ b/Source/Cmlx/mlx-generated/metal/ternary.h @@ -32,13 +32,13 @@ template constant const size_t& b_strides, constant const size_t& c_strides, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd2( device const bool* a, device const T* b, @@ -49,14 +49,14 @@ template constant const size_t c_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - auto c_idx = elem_to_loc_2(index, c_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd3( device const bool* a, device const T* b, @@ -67,15 +67,14 @@ template constant const size_t c_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - auto c_idx = elem_to_loc_3(index, c_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g( device const bool* a, device const T* b, @@ -88,7 +87,7 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_3_nd( + auto idx = elem_to_loc_3_nd( {N * index.x, index.y, index.z}, shape, a_strides, @@ -96,11 +95,10 @@ template c_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - 
auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; - auto c_xstride = c_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + IdxT c_xstride = c_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); idx.x += a_xstride; diff --git a/Source/Cmlx/mlx-generated/metal/unary.h b/Source/Cmlx/mlx-generated/metal/unary.h index 8d404ae2..acfe176e 100644 --- a/Source/Cmlx/mlx-generated/metal/unary.h +++ b/Source/Cmlx/mlx-generated/metal/unary.h @@ -1,38 +1,42 @@ // Copyright © 2024 Apple Inc. -template +template [[kernel]] void unary_v( device const T* in, - device T* out, + device U* out, uint index [[thread_position_in_grid]]) { out[index] = Op()(in[index]); } -template +template [[kernel]] void unary_v2( device const T* in, - device T* out, + device U* out, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { size_t offset = index.x + grid_dim.x * size_t(index.y); out[offset] = Op()(in[offset]); } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void unary_g( device const T* in, - device T* out, + device U* out, constant const int* in_shape, constant const size_t* in_strides, device const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto idx = elem_to_loc( + {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); auto xshape = in_shape[ndim - 1]; - auto xstride = in_strides[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + IdxT xstride = in_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { out[out_idx++] = Op()(in[idx]); idx += xstride; diff --git a/Source/Cmlx/mlx-generated/metal/unary_ops.h b/Source/Cmlx/mlx-generated/metal/unary_ops.h index b7346405..87df55a9 100644 --- a/Source/Cmlx/mlx-generated/metal/unary_ops.h +++ b/Source/Cmlx/mlx-generated/metal/unary_ops.h @@ -238,6 +238,13 @@ struct Floor { }; }; +struct Imag { + template + T operator()(T x) { + return x.imag; + }; +}; + struct Log { template T operator()(T x) { @@ -280,6 +287,13 @@ struct Negative { }; }; +struct Real { + template + T operator()(T x) { + return x.real; + }; +}; + struct Round { template T operator()(T x) { diff --git a/Source/Cmlx/mlx-generated/metal/utils.h b/Source/Cmlx/mlx-generated/metal/utils.h index e94901c9..30e155b2 100644 --- a/Source/Cmlx/mlx-generated/metal/utils.h +++ b/Source/Cmlx/mlx-generated/metal/utils.h @@ -3,7 +3,13 @@ #pragma once #include + +// The correct bf16.h is included based on the metal version +// by giving the correct path to -I during compilation +// e.g. 
mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0 #include "bf16.h" + +#include "bf16_math.h" #include "complex.h" #include "defines.h" @@ -83,44 +89,45 @@ struct Limits { /////////////////////////////////////////////////////////////////////////////// // Single Array with generic dims -template -METAL_FUNC stride_t elem_to_loc( +template +METAL_FUNC IdxT elem_to_loc( uint elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = 0; + IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * strides[i]; + loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } return loc; } -template -METAL_FUNC stride_t elem_to_loc( - stride_t elem, +template +METAL_FUNC IdxT elem_to_loc( + StrideT elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = 0; + IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * strides[i]; + loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } return loc; } // Non templated version to handle arbitrary dims -template -METAL_FUNC stride_t elem_to_loc( +template +METAL_FUNC IdxT elem_to_loc( uint3 elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + IdxT loc = + elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); for (int d = ndim - 3; d >= 0; --d) { - loc += (elem.z % shape[d]) * strides[d]; + loc += (elem.z % shape[d]) * IdxT(strides[d]); elem.z /= shape[d]; } return loc; @@ -129,61 +136,65 @@ METAL_FUNC stride_t elem_to_loc( /////////////////////////////////////////////////////////////////////////////// // Single Array with fixed N dims -template -METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) { - return elem * stride; +template +METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) { + return elem * IdxT(stride); } -template -METAL_FUNC stride_t -elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) { - return elem.x * strides[1] + elem.y * strides[0]; +template +METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) { + return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); } -template -METAL_FUNC stride_t -elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) { - return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +template +METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) { + return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + + elem.z * IdxT(strides[0]); } /////////////////////////////////////////////////////////////////////////////// // Multiple Arrays with generic dims -template -METAL_FUNC ulong2 elem_to_loc_2_nd( +template +METAL_FUNC vec elem_to_loc_2_nd( uint3 elem, constant const int* shape, - constant const stride_t* a_strides, - constant const stride_t* b_strides, + constant const StrideT* a_strides, + constant const StrideT* b_strides, int ndim) { - ulong2 loc = { - ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), - ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + vec loc = { + IdxT( + elem.x * IdxT(a_strides[ndim - 1]) + + IdxT(elem.y) * IdxT(a_strides[ndim - 2])), + IdxT( + elem.x * IdxT(b_strides[ndim - 1]) + + elem.y * IdxT(b_strides[ndim - 2]))}; for (int d = ndim - 
3; d >= 0; --d) { uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); elem.z /= shape[d]; } return loc; } -METAL_FUNC ulong3 elem_to_loc_3_nd( +template +METAL_FUNC vec elem_to_loc_3_nd( uint3 elem, constant const int* shape, constant const size_t* a_strides, constant const size_t* b_strides, constant const size_t* c_strides, int ndim) { - ulong3 loc = { - elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2], - elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2], - elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]}; + vec loc = { + elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]), + elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]), + elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - loc.z += l * c_strides[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + loc.z += l * IdxT(c_strides[d]); elem.z /= shape[d]; } return loc; @@ -193,16 +204,21 @@ METAL_FUNC ulong3 elem_to_loc_3_nd( // Elem to loc in a loop utils /////////////////////////////////////////////////////////////////////////////// -template -struct looped_elem_to_loc { - looped_elem_to_loc inner_looper; - offset_t offset{0}; +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; int index{0}; + LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + void next(const constant int* shape, const constant size_t* strides) { + if (dim == 0) { + return; + } index++; - offset += strides[dim - 1]; - + offset += OffsetT(strides[dim - 1]); if (index >= shape[dim - 1]) { index = 0; inner_looper.next(shape, strides); @@ -211,13 +227,21 @@ struct looped_elem_to_loc { } void next(int n, const constant int* shape, const constant size_t* strides) { + if (dim == 0) { + return; + } index += n; - offset += n * strides[dim - 1]; + offset += n * OffsetT(strides[dim - 1]); if (index >= shape[dim - 1]) { int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } index = 0; - inner_looper.next(shape, strides); offset = inner_looper.offset; if (extra > 0) { next(extra, shape, strides); @@ -225,41 +249,58 @@ struct looped_elem_to_loc { } } - offset_t - location(offset_t, const constant int*, const constant size_t*, int) { + OffsetT location() { return offset; } }; -template -struct looped_elem_to_loc<1, offset_t> { - offset_t offset{0}; +template +struct LoopedElemToLoc<1, OffsetT, true> { + int dim; + OffsetT offset{0}; + uint index{0}; - void next(const constant int*, const constant size_t* strides) { - offset += strides[0]; + LoopedElemToLoc(int dim) : dim(dim) {} + + void next(const constant int* shape, const constant size_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } } - void next(int n, const constant int*, const constant size_t* strides) { - offset += n * strides[0]; + void next(int n, const constant int* shape, const constant size_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } } - offset_t - 
location(offset_t, const constant int*, const constant size_t*, int) { + OffsetT location() { return offset; } }; -template -struct looped_elem_to_loc<0, offset_t> { - void next(const constant int*, const constant size_t*) {} - void next(int, const constant int*, const constant size_t*) {} - - offset_t location( - offset_t idx, - const constant int* shape, - const constant size_t* strides, - int ndim) { - return elem_to_loc(idx, shape, strides, ndim); +template +struct LoopedElemToLoc<1, OffsetT, false> { + OffsetT offset{0}; + + LoopedElemToLoc(int) {} + + void next(const constant int*, const constant size_t* strides) { + offset += OffsetT(strides[0]); + } + + void next(int n, const constant int*, const constant size_t* strides) { + offset += n * OffsetT(strides[0]); + } + + OffsetT location() { + return offset; } }; @@ -320,3 +361,63 @@ inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { return complex64_t( simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); } + +inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} + +inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} + +inline bool simd_shuffle_up(bool data, uint16_t delta) { + return simd_shuffle_up(static_cast(data), delta); +} + +inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta)); +} + +inline uint64_t +simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline int64_t +simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) { + return simd_shuffle_and_fill_up( + static_cast(data), static_cast(filling), delta); +} + +inline complex64_t simd_shuffle_and_fill_up( + complex64_t data, + complex64_t filling, + uint16_t delta) { + return complex64_t( + simd_shuffle_and_fill_up(data.real, filling.real, delta), + simd_shuffle_and_fill_up(data.imag, filling.imag, delta)); +} + +inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline int64_t simd_shuffle(int64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline bool simd_shuffle(bool data, uint16_t lane) { + return simd_shuffle(static_cast(data), lane); +} + +inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { + return complex64_t( + simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); +} diff --git a/Source/Cmlx/mlx-generated/quantized.cpp b/Source/Cmlx/mlx-generated/quantized.cpp index b4de253d..70fec3a0 100644 --- a/Source/Cmlx/mlx-generated/quantized.cpp +++ b/Source/Cmlx/mlx-generated/quantized.cpp @@ -4,11 +4,12 @@ const char* quantized() { return R"preamble( using namespace metal; static constant constexpr const int SIMD_SIZE = 32; +static constant constexpr const int QUAD_SIZE = 4; template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 
8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { @@ -19,6 +20,20 @@ inline U load_vector(const device T* x, thread U* x_thread) { x_thread[i + 3] = x[i + 3] / 64.0f; } } + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } else if (bits == 4) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -28,6 +43,15 @@ inline U load_vector(const device T* x, thread U* x_thread) { x_thread[i + 3] = x[i + 3] / 4096.0f; } } + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { sum += x[i]; @@ -39,8 +63,8 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < N; i += 4) { @@ -50,8 +74,19 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; + } + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; } } else if (bits == 4) { @@ -62,8 +97,14 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 3] = x[i + 3] / 4096.0f; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; + } + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; } } else if (bits == 8) { @@ -71,9 +112,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { sum += x[i]; x_thread[i] = x[i]; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; } return sum; } @@ -85,8 +126,8 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U accum = 0; if 
(bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { @@ -97,6 +138,22 @@ inline U qdot( x_thread[4 * i + 3] * (w[i] & 0xc0)); } } + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -107,6 +164,18 @@ inline U qdot( x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + accum += (w[0] & 0x3f) * x_thread[0]; + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + accum += (w[2] & 0xfc) * x_thread[3]; + } + } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { accum += x_thread[i] * w[i]; @@ -123,8 +192,8 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (N / 4); i++) { @@ -135,6 +204,22 @@ inline U qdot_safe( x_thread[4 * i + 3] * (w[i] & 0xc0)); } } + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { @@ -145,6 +230,18 @@ inline U qdot_safe( x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + accum += (w[0] & 0x3f) * x_thread[0]; + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + accum += (w[2] & 0xfc) * x_thread[3]; + } + } else if (bits == 8) { for (int i = 0; i < N; i++) { accum += x_thread[i] * w[i]; @@ -156,8 +253,8 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -167,12 +264,41 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { 
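Before the qouter and dequantize hunks that follow: the new bits == 3 path packs eight 3-bit values into three bytes, and the masks and shifts added here are the inverse of that layout (values 2 and 5 straddle a byte boundary, which is why qdot multiplies their high parts by 256 against the pre-scaled x_thread values instead of shifting). A self-contained C++ round-trip check of that layout; pack3 is a helper invented for the illustration, while unpack3 copies the expressions from the bits == 3 branch of dequantize:

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical packer for eight 3-bit values; only here to exercise the layout.
std::array<uint8_t, 3> pack3(const std::array<uint8_t, 8>& v) {
  std::array<uint8_t, 3> w{};
  w[0] = uint8_t(v[0] | (v[1] << 3) | ((v[2] & 0x3) << 6));
  w[1] = uint8_t((v[2] >> 2) | (v[3] << 1) | (v[4] << 4) | ((v[5] & 0x1) << 7));
  w[2] = uint8_t((v[5] >> 1) | (v[6] << 2) | (v[7] << 5));
  return w;
}

// Same bit expressions as the bits == 3 branch of dequantize / qouter.
std::array<uint8_t, 8> unpack3(const std::array<uint8_t, 3>& w) {
  return {
      uint8_t(w[0] & 0x7),
      uint8_t((w[0] & 0x38) >> 3),
      uint8_t(((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)),
      uint8_t((w[1] & 0xe) >> 1),
      uint8_t((w[1] & 0x70) >> 4),
      uint8_t(((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)),
      uint8_t((w[2] & 0x1c) >> 2),
      uint8_t((w[2] & 0xe0) >> 5)};
}

int main() {
  const std::array<uint8_t, 8> values = {0, 7, 5, 3, 6, 4, 1, 2};
  const auto packed = pack3(values);
  const auto unpacked = unpack3(packed);
  for (int i = 0; i < 8; ++i) {
    assert(unpacked[i] == values[i]);  // eight 3-bit values round-trip through 24 bits
  }
  std::printf("3-bit round trip ok\n");
  return 0;
}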
result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); } } + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } else if (bits == 4) { U s[2] = {scale, scale / 16.0f}; for (int i = 0; i < (values_per_thread / 2); i++) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + } else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } } else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { @@ -184,8 +310,8 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); if (bits == 2) { U s[4] = { scale, @@ -199,6 +325,20 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; } } + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } else if (bits == 4) { U s[2] = {scale, scale / static_cast(16.0f)}; for (int i = 0; i < (N / 2); i++) { @@ -206,6 +346,16 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; } } + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } else if (bits == 8) { for (int i = 0; i < N; i++) { w_local[i] = scale * w[i] + bias; @@ -229,9 +379,10 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); 
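The loader rewrite below (and the matching qmv/qvm/qmm changes further down) moves from uint32_t addressing to byte addressing so the 3- and 6-bit formats, whose packed groups occupy three bytes, share a code path with the power-of-two widths. The useful invariant is bits * pack_factor == 8 * bytes_per_pack, which keeps strides such as in_vec_size * bytes_per_pack / pack_factor in whole bytes. A small C++ check of the constants the new constexpr expressions produce; in_vec_size = 4096 is an arbitrary example row length:

#include <cassert>
#include <cstdio>
#include <initializer_list>

struct PackInfo {
  int pack_factor;     // values decoded per packed group
  int bytes_per_pack;  // bytes occupied by one packed group
};

// Per-byte variant used by QuantizedBlockLoader (8 / bits for power-of-two widths).
constexpr PackInfo loader_pack(int bits) {
  return {bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits,
          (bits == 3 || bits == 6) ? 3 : 1};
}

// Per-uint32 variant used by qmv_fast_impl / qmv_impl / qvm_impl.
constexpr PackInfo vector_pack(int bits) {
  const bool power_of_2 = (bits & (bits - 1)) == 0;
  return {bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits, power_of_2 ? 4 : 3};
}

int main() {
  const int in_vec_size = 4096;  // illustrative row length
  for (int bits : {2, 3, 4, 6, 8}) {
    for (PackInfo p : {loader_pack(bits), vector_pack(bits)}) {
      // Holds for every supported width, so byte offsets are always exact.
      assert(bits * p.pack_factor == 8 * p.bytes_per_pack);
      const int row_bytes = in_vec_size * p.bytes_per_pack / p.pack_factor;
      std::printf("bits=%d pack_factor=%d bytes_per_pack=%d row_bytes=%d\n",
                  bits, p.pack_factor, p.bytes_per_pack, row_bytes);
    }
  }
  return 0;
}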
- static constant constexpr const short pack_factor = 32 / bits; + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); + static constant constexpr const short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + static constant constexpr const short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; static constant constexpr const short BCOLS_PACKED = BCOLS / pack_factor; static constant constexpr const short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -244,11 +395,11 @@ struct QuantizedBlockLoader { const short bi; const short bj; threadgroup T* dst; - const device uint32_t* src; + const device uint8_t* src; const device T* scales; const device T* biases; QuantizedBlockLoader( - const device uint32_t* src_, + const device uint8_t* src_, const device T* scales_, const device T* biases_, const int src_ld_, @@ -257,14 +408,16 @@ struct QuantizedBlockLoader { ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), tile_stride( - reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), group_step_cnt(0), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), bi(n_reads * thread_idx / BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED), dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld / pack_factor + bj), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size), biases(biases_ + bi * src_ld / group_size) {} void load_unsafe() const { @@ -275,7 +428,7 @@ struct QuantizedBlockLoader { T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); } } void load_safe(short2 src_tile_dim) const { @@ -298,7 +451,10 @@ struct QuantizedBlockLoader { T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); } } void next() { @@ -321,6 +477,53 @@ struct QuantizedBlockLoader { } } }; +template +METAL_FUNC void qmv_quad_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid; + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + quad_lid / 
scale_step_per_thread; + x += tid.y * in_vec_size + quad_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; + U sum = load_vector(x, x_thread); + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device T* sl = scales + row * in_vec_size_g * quads_per_simd; + const device T* bl = biases + row * in_vec_size_g * quads_per_simd; + U s = sl[0]; + U b = bl[0]; + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} template METAL_FUNC void qmv_fast_impl( const device uint32_t* w, @@ -333,21 +536,24 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int packs_per_thread = bits > 2 ? 2 : 1; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; - const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -355,15 +561,14 @@ METAL_FUNC void qmv_fast_impl( for (int k = 0; k < in_vec_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; U b = bl[0]; result[row] += qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -387,17 +592,20 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = 32 / bits; + constexpr 
int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; - const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; @@ -406,7 +614,8 @@ METAL_FUNC void qmv_impl( return; } if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -415,8 +624,7 @@ METAL_FUNC void qmv_impl( for (; k < in_vec_size - block_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; @@ -424,7 +632,7 @@ METAL_FUNC void qmv_impl( result[row] += qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -433,16 +641,18 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); - for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } } for (int row = 0; out_row + row < out_vec_size; row++) { result[row] = simd_sum(result[row]); @@ -452,7 +662,8 @@ METAL_FUNC void qmv_impl( } } else { - w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -461,8 +672,7 @@ METAL_FUNC void qmv_impl( for (; k < in_vec_size - block_size; k += block_size) { U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device 
uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; @@ -470,7 +680,7 @@ METAL_FUNC void qmv_impl( result[row] += qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -479,17 +689,18 @@ METAL_FUNC void qmv_impl( static_cast(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread); - U sum = - load_vector_safe(x, x_thread, remaining); - for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - U s = sl[0]; - U b = bl[0]; - result[row] += qdot_safe( - wl, x_thread, s, b, sum, remaining); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } } for (int row = 0; row < results_per_simdgroup; row++) { result[row] = simd_sum(result[row]); @@ -501,34 +712,36 @@ METAL_FUNC void qmv_impl( } template METAL_FUNC void qvm_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, + const int in_vec_size, + const int out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 
4 : 3; constexpr int tn = 32 / pack_factor; - constexpr int blocksize = SIMD_SIZE; + constexpr int block_size = SIMD_SIZE; + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; typedef struct { - uint32_t wi[tn]; + uint8_t wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; thread U result[tn * pack_factor] = {0}; thread U scale = 1; thread U bias = 0; thread U x_local = 0; - const int out_vec_size_w = out_vec_size / pack_factor; + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; - int out_col = - tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn; - w += out_col / pack_factor + simd_lid * out_vec_size_w; + int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; scales += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g; x += tid.y * in_vec_size + simd_lid; @@ -536,38 +749,38 @@ METAL_FUNC void qvm_impl( if (out_col >= out_vec_size) { return; } - int remaining = in_vec_size % blocksize; + int remaining = in_vec_size % block_size; if (remaining == 0) { - for (int i = 0; i < in_vec_size; i += blocksize) { + for (int i = 0; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; } } else { - for (int i = blocksize; i < in_vec_size; i += blocksize) { + for (int i = block_size; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; } if (static_cast(simd_lid) < remaining) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); } else { x_local = 0; scale = 0; @@ -596,16 +809,16 @@ template < const int BK = 32, const int BN = 32> METAL_FUNC void qmm_t_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, - const constant int& M, - const constant int& N, const constant int& K, + const constant int& N, + const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -615,8 +828,9 @@ METAL_FUNC void qmm_t_impl( (void)lid; constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 
3 : 1; using mma_t = mlx::steel:: BlockMMA; using loader_x_t = @@ -630,19 +844,20 @@ METAL_FUNC void qmm_t_impl( WM * WN * SIMD_SIZE, group_size, bits>; - const int K_w = K / pack_factor; + const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; + auto wl = (const device uint8_t*)w; x += y_row * K; - w += y_col * K_w; + wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; y += y_row * N + y_col; const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { if (!aligned_N && num_outs < BN) { @@ -704,16 +919,16 @@ template < const int BK = 32, const int BN = 32> METAL_FUNC void qmm_n_impl( - const device T* x, const device uint32_t* w, const device T* scales, const device T* biases, + const device T* x, device T* y, threadgroup T* Xs, threadgroup T* Ws, - const constant int& M, - const constant int& N, const constant int& K, + const constant int& N, + const constant int& M, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -723,9 +938,11 @@ METAL_FUNC void qmm_n_impl( (void)lid; constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel:: @@ -739,16 +956,17 @@ METAL_FUNC void qmm_n_impl( WM * WN * SIMD_SIZE, group_size, bits>; + auto wl = (const device uint8_t*)w; const int y_row = tid.y * BM; const int y_col = tid.x * BN; x += y_row * K; - w += y_col / pack_factor; + wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; y += y_row * N + y_col; const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { if ((K % BK) != 0) { @@ -817,6 +1035,43 @@ METAL_FUNC void qmm_n_impl( } } template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant size_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant size_t* w_strides, + const constant size_t* s_strides, + const constant size_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} +template METAL_FUNC void adjust_matrix_offsets( const device T*& x, const device uint32_t*& w, @@ -868,7 +1123,57 @@ METAL_FUNC void adjust_matrix_offsets( } y += tid.z * output_stride; } -template +template +[[kernel]] void qmv_quad( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_quad_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid); +} +template [[kernel]] void qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -877,9 +1182,35 @@ template device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + 
const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmv_fast_impl( w, scales, @@ -892,7 +1223,7 @@ template simd_gid, simd_lid); } -template +template [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -901,9 +1232,35 @@ template device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmv_impl( w, scales, @@ -916,23 +1273,49 @@ template simd_gid, simd_lid); } -template +template [[kernel]] void qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], const constant int& in_vec_size [[buffer(5)]], const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qvm_impl( - x, w, scales, biases, + x, y, in_vec_size, out_vec_size, @@ -940,23 +1323,83 @@ template simd_gid, simd_lid); } +template +[[kernel]] void qvm_split_k( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + 
const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& final_block_size [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid); +} template < typename T, const int group_size, const int bits, const bool aligned_N, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], + const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -965,25 +1408,52 @@ template < constexpr int BK_padded = (BK + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, const int group_size, const int bits, + const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> [[kernel]] void qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], + const constant int& K [[buffer(5)]], const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape 
[[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -993,8 +1463,26 @@ template < constexpr int BN_padded = (BN + 16 / sizeof(T)); threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template [[kernel]] void bs_qmv_fast( @@ -1002,23 +1490,23 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1062,23 +1550,23 @@ template const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const 
constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1118,27 +1606,27 @@ template } template [[kernel]] void bs_qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& batch_ndims [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant size_t* lhs_strides [[buffer(11)]], - const constant size_t* rhs_strides [[buffer(12)]], - const constant int& x_batch_ndims [[buffer(13)]], - const constant int* x_shape [[buffer(14)]], - const constant size_t* x_strides [[buffer(15)]], - const constant int& w_batch_ndims [[buffer(16)]], - const constant int* w_shape [[buffer(17)]], - const constant size_t* w_strides [[buffer(18)]], - const constant size_t* s_strides [[buffer(19)]], - const constant size_t* b_strides [[buffer(20)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant size_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant size_t* w_strides [[buffer(12)]], + const constant size_t* s_strides [[buffer(13)]], + const constant size_t* b_strides [[buffer(14)]], + const constant int& batch_ndims [[buffer(15)]], + const constant int* batch_shape [[buffer(16)]], + const device uint32_t* lhs_indices [[buffer(17)]], + const device uint32_t* rhs_indices [[buffer(18)]], + const constant size_t* lhs_strides [[buffer(19)]], + const constant size_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid 
[[thread_index_in_simdgroup]]) { @@ -1165,10 +1653,10 @@ template b_strides, tid); qvm_impl( - x, w, scales, biases, + x, y, in_vec_size, out_vec_size, @@ -1185,28 +1673,28 @@ template < const int BK = 32, const int BN = 32> [[kernel]] void bs_qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], + const constant int& batch_ndims [[buffer(16)]], + const constant int* batch_shape [[buffer(17)]], + const device uint32_t* lhs_indices [[buffer(18)]], + const device uint32_t* rhs_indices [[buffer(19)]], + const constant size_t* lhs_strides [[buffer(20)]], + const constant size_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1238,7 +1726,7 @@ template < b_strides, tid); qmm_t_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template < typename T, @@ -1248,28 +1736,28 @@ template < const int BK = 32, const int BN = 32> [[kernel]] void bs_qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], const constant int& M [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& K [[buffer(9)]], - const constant int& batch_ndims [[buffer(10)]], - const constant int* batch_shape [[buffer(11)]], - const constant size_t* lhs_strides [[buffer(12)]], - const constant size_t* 
rhs_strides [[buffer(13)]], - const constant int& x_batch_ndims [[buffer(14)]], - const constant int* x_shape [[buffer(15)]], - const constant size_t* x_strides [[buffer(16)]], - const constant int& w_batch_ndims [[buffer(17)]], - const constant int* w_shape [[buffer(18)]], - const constant size_t* w_strides [[buffer(19)]], - const constant size_t* s_strides [[buffer(20)]], - const constant size_t* b_strides [[buffer(21)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant size_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant size_t* w_strides [[buffer(13)]], + const constant size_t* s_strides [[buffer(14)]], + const constant size_t* b_strides [[buffer(15)]], + const constant int& batch_ndims [[buffer(16)]], + const constant int* batch_shape [[buffer(17)]], + const device uint32_t* lhs_indices [[buffer(18)]], + const device uint32_t* rhs_indices [[buffer(19)]], + const constant size_t* lhs_strides [[buffer(20)]], + const constant size_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1302,7 +1790,7 @@ template < b_strides, tid); qmm_n_impl( - x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template [[kernel]] void affine_quantize( @@ -1314,19 +1802,22 @@ template uint2 grid_dim [[threads_per_grid]]) { constexpr T eps = T(1e-7); constexpr int simd_size = 32; - constexpr int uint8_bits = 8; constexpr T n_bins = (1 << bits) - 1; - constexpr int packs_per_int = uint8_bits / bits; + constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int values_per_reduce = group_size / simd_size; constexpr int writes_per_reduce = packs_per_int / values_per_reduce; constexpr int writes_per_pack = writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; static_assert( group_size % simd_size == 0, "Group size must be divisible by simd size."); size_t offset = index.x + grid_dim.x * size_t(index.y); size_t in_index = offset * values_per_reduce; - size_t out_index = offset * writes_per_pack; + size_t out_index = power_of_2_bits + ? 
offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; T w_thread[values_per_reduce]; T w_min = Limits::max; T w_max = 0; @@ -1352,7 +1843,7 @@ template scales[gindex] = scale; biases[gindex] = bias; } - uint8_t output = 0; + uint32_t output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); @@ -1367,43 +1858,23 @@ template output = 0; } else { #pragma clang loop unroll(full) - for (int j = 0; j < writes_per_reduce - 1; j++) { - uint8_t sval = simd_shuffle_down(val, j + 1); - output += sval << (bits * (values_per_reduce + j + i)); + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = simd_shuffle_down(val, j); + output += sval << (bits * (j * values_per_reduce + i)); } } } - if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { - out[out_index / writes_per_reduce] = output; - } -} -template -[[kernel]] void affine_quantize_scales_biases( - const device T* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - device uint8_t* out [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr int uint8_bits = 8; - constexpr int packs_per_int = uint8_bits / bits; - constexpr T n_bins = (1 << bits) - 1; - size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t in_index = offset * packs_per_int; - size_t gindex = in_index / group_size; - T scale = scales[gindex]; - T bias = biases[gindex]; - uint8_t output = 0; -#pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { - uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins); - if (bits == 8) { - output = val; - } else { - output += val << (bits * i); + if (bits == 3 || bits == 6) { + if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else { + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; } } - out[offset] = output; } template [[kernel]] void affine_dequantize( @@ -1413,25 +1884,45 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int uint8_bits = 8; - constexpr int packs_per_int = uint8_bits / bits; + constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; size_t offset = index.x + grid_dim.x * size_t(index.y); size_t oindex = offset * packs_per_int; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; - uint val = w[offset]; + out += oindex; + if (bits == 3) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x7) * scale + bias; + out[1] = ((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + out[3] = ((w[1] & 0xe) >> 1) * scale + bias; + out[4] = ((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } else if (bits == 6) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x3f) * scale + bias; + out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { - uint8_t d; - if (bits == 2) { - d = (val >> (bits * i)) & 0x03; - } else if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; + for (int i = 0; i < packs_per_int; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * d + bias; } - out[oindex + i] = scale * d + bias; } } )preamble"; diff --git a/Source/Cmlx/mlx-generated/reduce.cpp b/Source/Cmlx/mlx-generated/reduce.cpp index 33eaa583..896d19db 100644 --- a/Source/Cmlx/mlx-generated/reduce.cpp +++ b/Source/Cmlx/mlx-generated/reduce.cpp @@ -3,7 +3,12 @@ namespace mlx::core::metal { const char* reduce() { return R"preamble( -template +template < + typename T, + typename U, + typename Op, + typename IdxT = int64_t, + int N_READS = REDUCE_N_READS> [[kernel]] void all_reduce( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], @@ -18,10 +23,10 @@ template Op op; threadgroup U shared_vals[simd_size]; U total = Op::init; - int64_t start_idx = gid.y * row_size; - int64_t actual_row = + IdxT start_idx = gid.y * IdxT(row_size); + IdxT actual_row = (start_idx + row_size <= in_size) ? 
row_size : in_size - start_idx; - int64_t blocks = actual_row / (lsize.x * N_READS); + IdxT blocks = actual_row / (lsize.x * N_READS); int extra = actual_row - blocks * (lsize.x * N_READS); extra -= lid.x * N_READS; start_idx += lid.x * N_READS; @@ -30,7 +35,7 @@ template blocks++; extra = 0; } - for (int64_t b = 0; b < blocks; b++) { + for (IdxT b = 0; b < blocks; b++) { for (int i = 0; i < N_READS; i++) { total = op(static_cast(in[i]), total); } @@ -54,12 +59,7 @@ template out[gid.y] = total; } } -template < - typename T, - typename U, - typename Op, - int NDIMS, - int N_READS = REDUCE_N_READS> +template [[kernel]] void col_reduce_small( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], @@ -74,145 +74,128 @@ template < const constant size_t& non_col_reductions [[buffer(10)]], uint3 gid [[threadgroup_position_in_grid]], uint3 gsize [[threadgroups_per_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[thread_position_in_grid]], - uint3 tsize [[threads_per_grid]]) { + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + constexpr int n_reads = 4; Op op; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; - if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) { - U totals[31]; - for (int i = 0; i < 31; i++) { - totals[i] = Op::init; - } - short stride = reduction_stride; - short size = reduction_size; - short blocks = stride / N_READS; - short extra = stride - blocks * N_READS; - size_t out_idx = tid.x + tsize.y * size_t(tid.y); - in += elem_to_loc(out_idx, shape, strides, ndim); - for (uint r = 0; r < non_col_reductions; r++) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); - for (short i = 0; i < size; i++) { - for (short j = 0; j < blocks; j++) { - for (short k = 0; k < N_READS; k++) { - totals[j * N_READS + k] = - op(totals[j * N_READS + k], - static_cast(row[i * stride + j * N_READS + k])); - } - } - for (short k = 0; k < extra; k++) { - totals[blocks * N_READS + k] = - op(totals[blocks * N_READS + k], - static_cast(row[i * stride + blocks * N_READS + k])); - } + U totals[n_reads]; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; + if (column >= reduction_stride) { + return; + } + bool safe = column + n_reads <= reduction_stride; + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(lid.y, reduce_shape, reduce_strides); + for (IdxT r = lid.y; r < total_rows; r += lsize.y) { + row = in + loop.location(); + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? 
static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); } - loop.next(reduce_shape, reduce_strides); - } - out += out_idx * reduction_stride; - for (short j = 0; j < stride; j++) { - out[j] = totals[j]; } + loop.next(lsize.y, reduce_shape, reduce_strides); } - else if (reduction_size * non_col_reductions < 32) { - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = Op::init; + if (lsize.y > 1) { + threadgroup U shared_vals[32 * 8 * n_reads]; + for (int i = 0; i < n_reads; i++) { + shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i]; } - short size = reduction_size; - size_t offset = size_t(tid.x) * N_READS; - bool safe = offset + N_READS <= reduction_stride; - short extra = reduction_stride - offset; - size_t out_idx = tid.y + tsize.z * size_t(tid.z); - in += elem_to_loc(out_idx, shape, strides, ndim) + offset; - for (uint r = 0; r < non_col_reductions; r++) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); - if (safe) { - for (short i = 0; i < size; i++) { - for (short j = 0; j < N_READS; j++) { - totals[j] = - op(static_cast(row[i * reduction_stride + j]), totals[j]); - } - } - } else { - for (short i = 0; i < size; i++) { - for (short j = 0; j < extra; j++) { - totals[j] = - op(static_cast(row[i * reduction_stride + j]), totals[j]); - } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (int i = 0; i < n_reads; i++) { + totals[i] = shared_vals[lid.x * n_reads + i]; + } + for (uint j = 1; j < lsize.y; j++) { + for (int i = 0; i < n_reads; i++) { + totals[i] = + op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i], + totals[i]); } } - loop.next(reduce_shape, reduce_strides); } - out += out_idx * reduction_stride + offset; + } + if (lid.y == 0) { + out += out_idx * IdxT(reduction_stride) + column; if (safe) { - for (short i = 0; i < N_READS; i++) { + for (int i = 0; i < n_reads; i++) { out[i] = totals[i]; } } else { - for (short i = 0; i < extra; i++) { + for (int i = 0; column + i < reduction_stride; i++) { out[i] = totals[i]; } } } - else { - threadgroup U shared_vals[1024]; - U totals[N_READS]; - for (int i = 0; i < N_READS; i++) { - totals[i] = Op::init; - } - short stride = reduction_stride; - short lid = simd_group_id * simd_size + simd_lane_id; - short2 tile((stride + N_READS - 1) / N_READS, 32); - short2 offset((lid % tile.x) * N_READS, lid / tile.x); - short sm_stride = tile.x * N_READS; - bool safe = offset.x + N_READS <= stride; - size_t out_idx = gid.y + gsize.y * size_t(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x; - size_t total = non_col_reductions * reduction_size; - loop.next(offset.y, reduce_shape, reduce_strides); - for (size_t r = offset.y; r < total; r += simd_size) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); - if (safe) { - for (int i = 0; i < N_READS; i++) { - totals[i] = op(static_cast(row[i]), totals[i]); - } - } else { - U vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = (offset.x + i < stride) ? 
static_cast(row[i]) : op.init; - } - for (int i = 0; i < N_READS; i++) { - totals[i] = op(vals[i], totals[i]); - } - } - loop.next(simd_size, reduce_shape, reduce_strides); - } - for (int i = 0; i < N_READS; i++) { - shared_vals[offset.y * sm_stride + offset.x + i] = totals[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_READS; i++) { - totals[i] = op.simd_reduce( - shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]); - } - if (simd_lane_id == 0) { - short column = simd_group_id * N_READS; - out += out_idx * reduction_stride + column; - if (column + N_READS <= stride) { - for (int i = 0; i < N_READS; i++) { - out[i] = totals[i]; - } - } else { - for (int i = 0; column + i < stride; i++) { - out[i] = totals[i]; - } - } - } +} +template +[[kernel]] void col_reduce_longcolumn( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + Op op; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + lid.x; + U total = Op::init; + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); + for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; + r += lsize.y * gsize.z) { + row = in + loop.location(); + total = op(static_cast(*row), total); + loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); + } + threadgroup U shared_vals[32 * 32]; + shared_vals[lid.y * lsize.x + lid.x] = total; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (uint i = 1; i < lsize.y; i++) { + total = op(total, shared_vals[i * lsize.x + lid.x]); + } + out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = + total; } } -template +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> [[kernel]] void col_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], @@ -230,28 +213,28 @@ template uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - constexpr int n_simdgroups = 4; + constexpr int n_simdgroups = 8; constexpr short tgp_size = n_simdgroups * simd_size; constexpr short n_reads = (BM * BN) / tgp_size; constexpr short n_read_blocks = BN / n_reads; threadgroup U shared_vals[BN * BM]; U totals[n_reads]; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; for (int i = 0; i < n_reads; i++) { totals[i] = Op::init; } short lid = simd_group_id * simd_size + simd_lane_id; short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); - size_t column = BN * gid.x + offset.x; + IdxT column = BN * gid.x + offset.x; bool safe = column + n_reads <= reduction_stride; - 
size_t out_idx = gid.y + gsize.y * size_t(gid.z); - size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + column; - size_t total = non_col_reductions * reduction_size; + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); loop.next(offset.y, reduce_shape, reduce_strides); - for (size_t r = offset.y; r < total; r += BM) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + for (IdxT r = offset.y; r < total; r += BM) { + row = in + loop.location(); if (safe) { for (int i = 0; i < n_reads; i++) { totals[i] = op(static_cast(row[i]), totals[i]); @@ -283,8 +266,8 @@ template op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); } if (simd_lane_id == 0) { - size_t out_column = BN * gid.x + out_offset.x; - out += out_idx * reduction_stride + out_column; + IdxT out_column = BN * gid.x + out_offset.x; + out += out_idx * IdxT(reduction_stride) + out_column; if (out_column + n_outputs <= reduction_stride) { for (int i = 0; i < n_outputs; i++) { out[i] = totals[i]; @@ -311,7 +294,7 @@ template } } if (offset.y == 0) { - out += out_idx * reduction_stride + column; + out += out_idx * IdxT(reduction_stride) + column; if (safe) { for (int i = 0; i < n_reads; i++) { out[i] = totals[i]; @@ -324,6 +307,98 @@ template } } } +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> +[[kernel]] void col_reduce_2pass( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant size_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 8; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + constexpr int n_outputs = BN / n_simdgroups; + constexpr short outer_blocks = 32; + static_assert(BM == 32, "BM should be equal to 32"); + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + IdxT column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT block_idx = full_idx / IdxT(out_size); + IdxT out_idx = full_idx % IdxT(out_size); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); + for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) { + row = in + loop.location(); + if (safe) { + for (int i = 0; 
i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + loop.next(outer_blocks * BM, reduce_shape, reduce_strides); + } + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + if (simd_lane_id == 0) { + IdxT out_column = BN * gid.x + out_offset.x; + out += full_idx * IdxT(reduction_stride) + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } +} template [[kernel]] void init_reduce( device T* out [[buffer(0)]], @@ -473,6 +548,7 @@ template < typename T, typename U, typename Op, + typename IdxT, int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void row_reduce_small( @@ -493,25 +569,25 @@ template < uint3 tsize [[threads_per_grid]]) { Op op; U total_val = Op::init; - looped_elem_to_loc loop; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; - int blocks = row_size / N_READS; - int extra = row_size % N_READS; + int blocks = IdxT(row_size) / N_READS; + int extra = IdxT(row_size) % N_READS; if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { - size_t out_idx = tid.x + tsize.y * size_t(tid.y); - in += elem_to_loc(out_idx, shape, strides, ndim); + IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + row = in + loop.location(); thread_reduce(total_val, row, blocks, extra); loop.next(reduce_shape, reduce_strides); } out[out_idx] = total_val; } else { - size_t out_idx = gid.y + gsize.y * size_t(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim); + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); loop.next(simd_lane_id, reduce_shape, reduce_strides); for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { - row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); + row = in + loop.location(); thread_reduce(total_val, row, blocks, extra); loop.next(simd_size, reduce_shape, reduce_strides); } @@ -525,6 +601,7 @@ template < typename T, typename U, typename Op, + typename IdxT = size_t, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> [[kernel]] void row_reduce_simple( @@ -541,13 +618,13 @@ template < uint simd_group_id [[simdgroup_index_in_threadgroup]]) { threadgroup U shared_vals[simd_size * N_WRITES]; U totals[N_WRITES]; - size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z)); + IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z)); if (out_idx + N_WRITES > out_size) { out_idx = out_size - N_WRITES; } - in += out_idx * reduction_size; + in += out_idx * IdxT(reduction_size); out += out_idx; - int blocks = reduction_size / (lsize.x * N_READS); + int blocks = IdxT(reduction_size) / (lsize.x * N_READS); int extra = reduction_size - blocks * (lsize.x * N_READS); 
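// Editorial sketch — the reduce, ternary, and unary kernels in this diff gain a
// `typename IdxT` template parameter so flat offsets are computed in a
// caller-chosen index width (presumably a 32-bit type when the array is small
// enough, 64-bit otherwise). A minimal host-side C++ illustration of that idea,
// assuming nothing beyond what the kernels show; `flat_out_index` is an
// illustrative name, not an MLX function.
#include <cstdint>
#include <cstdio>

template <typename IdxT>
IdxT flat_out_index(uint32_t x, uint32_t y, uint32_t z,
                    uint32_t grid_x, uint32_t grid_y) {
  // Mirrors expressions such as:
  //   IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
  return IdxT(x) + IdxT(grid_x) * (IdxT(y) + IdxT(grid_y) * IdxT(z));
}

int main() {
  // ~8.6e9 elements: wraps around in 32-bit arithmetic, exact in 64-bit.
  std::printf("32-bit: %u\n",
              (unsigned)flat_out_index<uint32_t>(3, 512, 8192, 1024, 1024));
  std::printf("64-bit: %lld\n",
              (long long)flat_out_index<int64_t>(3, 512, 8192, 1024, 1024));
  return 0;
}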
per_thread_row_reduce( totals, in, reduction_size, blocks, extra, lsize.x, lid.x); @@ -563,6 +640,7 @@ template < typename T, typename U, typename Op, + typename IdxT, int NDIMS, int N_READS = REDUCE_N_READS> [[kernel]] void row_reduce_looped( @@ -586,14 +664,15 @@ template < Op op; threadgroup U shared_vals[simd_size]; U total = Op::init; - size_t out_idx = gid.y + gsize.y * size_t(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; - looped_elem_to_loc loop; + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim) + + lid.x * N_READS; + LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; - int blocks = row_size / (lsize.x * N_READS); + int blocks = IdxT(row_size) / (lsize.x * N_READS); int extra = row_size - blocks * (lsize.x * N_READS); - for (size_t i = 0; i < non_row_reductions; i++) { - row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim); + for (IdxT i = 0; i < non_row_reductions; i++) { + row = in + loop.location(); U row_total; per_thread_row_reduce( &row_total, &row, blocks, extra, lsize.x, lid.x); diff --git a/Source/Cmlx/mlx-generated/scan.cpp b/Source/Cmlx/mlx-generated/scan.cpp index ff8d806c..c62fd78b 100644 --- a/Source/Cmlx/mlx-generated/scan.cpp +++ b/Source/Cmlx/mlx-generated/scan.cpp @@ -4,29 +4,33 @@ const char* scan() { return R"preamble( template struct CumSum { + template = true> T simd_scan(T val) { return simd_scan_impl(val); } template = true> T simd_scan(T val) { for (int i = 1; i <= 16; i *= 2) { val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); } return val; } + template = true> T simd_exclusive_scan(T val) { return simd_exclusive_scan_impl(val); } template = true> T simd_exclusive_scan(T val) { val = simd_scan(val); return simd_shuffle_and_fill_up(val, init, 1); } static constexpr constant U init = static_cast(0); template U operator()(U a, T b) { return a + b; } - U simd_scan(U x) { + U simd_scan_impl(U x) { return simd_prefix_inclusive_sum(x); } - U simd_exclusive_scan(U x) { + U simd_exclusive_scan_impl(U x) { return simd_prefix_exclusive_sum(x); } }; template struct CumProd { + template = true> T simd_scan(T val) { return simd_scan_impl(val); } template = true> T simd_scan(T val) { for (int i = 1; i <= 16; i *= 2) { val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); } return val; } + template = true> T simd_exclusive_scan(T val) { return simd_exclusive_scan_impl(val); } template = true> T simd_exclusive_scan(T val) { val = simd_scan(val); return simd_shuffle_and_fill_up(val, init, 1); } static constexpr constant U init = static_cast(1.0f); template U operator()(U a, T b) { return a * b; } - U simd_scan(U x) { + U simd_scan_impl(U x) { return simd_prefix_inclusive_product(x); } - U simd_exclusive_scan(U x) { + U simd_exclusive_scan_impl(U x) { return simd_prefix_exclusive_product(x); } }; @@ -39,7 +43,7 @@ struct CumProd { } bool simd_scan(bool x) { for (int i = 1; i <= 16; i *= 2) { - bool other = simd_shuffle_up(x, i); + bool other = simd_shuffle_and_fill_up(x, init, i); x &= other; } return x; @@ -58,7 +62,7 @@ struct CumMax { } U simd_scan(U x) { for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_up(x, i); + U other = simd_shuffle_and_fill_up(x, init, i); x = (x >= other) ? x : other; } return x; @@ -77,7 +81,7 @@ struct CumMin { } U simd_scan(U x) { for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_up(x, i); + U other = simd_shuffle_and_fill_up(x, init, i); x = (x <= other) ? 
x : other; } return x; @@ -156,21 +160,23 @@ template < const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& axis_size [[buffer(2)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; Op op; - in += (gid / lsize) * axis_size; - out += (gid / lsize) * axis_size; - uint simd_groups = lsize / simd_size; + size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; + in += offset; + out += offset; + uint simd_groups = lsize.x / simd_size; U prefix = Op::init; U values[N_READS]; threadgroup U simdgroup_sums[32]; - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { - uint offset = r * lsize * N_READS + lid * N_READS; + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { + uint offset = r * lsize.x * N_READS + lid.x * N_READS; if (reverse) { if ((offset + N_READS) < axis_size) { load_unsafe( @@ -219,7 +225,7 @@ template < values, out + axis_size - offset - N_READS, offset, axis_size); } } else { - if (lid == 0 && offset == 0) { + if (lid.x == 0 && offset == 0) { out[axis_size - 1] = Op::init; } if ((offset + N_READS + 1) < axis_size) { @@ -242,7 +248,7 @@ template < values, out + offset, offset, axis_size); } } else { - if (lid == 0 && offset == 0) { + if (lid.x == 0 && offset == 0) { out[0] = Op::init; } if ((offset + N_READS + 1) < axis_size) { @@ -273,68 +279,82 @@ template < device U* out [[buffer(1)]], const constant size_t& axis_size [[buffer(2)]], const constant size_t& stride [[buffer(3)]], - uint2 gid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]]) { + const constant size_t& stride_blocks [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + constexpr int BM = 32; + constexpr int BN = 32; + constexpr int BN_pad = 32 + 16 / sizeof(U); + constexpr int n_simds = BN / N_READS; + constexpr int n_scans = BN / n_simds; Op op; - threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32]; - U values[N_READS]; - U prefix[N_READS]; - for (int i = 0; i < N_READS; i++) { + threadgroup U read_buffer[BM * BN_pad]; + U values[n_scans]; + U prefix[n_scans]; + for (int i = 0; i < n_scans; i++) { prefix[i] = Op::init; } - int offset = gid.y * axis_size * stride; - int global_index_x = gid.x * lsize.y * N_READS; - for (uint j = 0; j < axis_size; j += simd_size) { - uint index_y = j + lid.y; + size_t full_gid = gid.y + gsize.y * size_t(gid.z); + size_t offset = full_gid / stride_blocks * axis_size * stride; + size_t global_index_x = full_gid % stride_blocks * BN; + uint read_offset_y = (lid.x * N_READS) / BN; + uint read_offset_x = (lid.x * N_READS) % BN; + uint scan_offset_y = simd_lane_id; + uint scan_offset_x = simd_group_id * n_scans; + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + 
global_index_x + read_offset_x; + threadgroup U* read_into = + read_buffer + read_offset_y * BN_pad + read_offset_x; + threadgroup U* read_from = + read_buffer + scan_offset_y * BN_pad + scan_offset_x; + for (uint j = 0; j < axis_size; j += BM) { + uint index_y = j + read_offset_y; uint check_index_y = index_y; - uint index_x = global_index_x + lid.x * N_READS; if (reverse) { index_y = axis_size - 1 - index_y; } - if (check_index_y < axis_size && (index_x + N_READS) < stride) { + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - in[offset + index_y * stride + index_x + i]; + read_into[i] = in[index_y * stride + i]; } } else { for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (index_x + i) < stride) { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - in[offset + index_y * stride + index_x + i]; + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = in[index_y * stride + i]; } else { - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = - Op::init; + read_into[i] = Op::init; } } } threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_READS; i++) { - values[i] = - read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i]; + for (int i = 0; i < n_scans; i++) { + values[i] = read_from[i]; } simdgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_READS; i++) { + for (int i = 0; i < n_scans; i++) { values[i] = op.simd_scan(values[i]); values[i] = op(values[i], prefix[i]); prefix[i] = simd_shuffle(values[i], simd_size - 1); } - for (int i = 0; i < N_READS; i++) { - read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = - values[i]; + for (int i = 0; i < n_scans; i++) { + read_from[i] = values[i]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (!inclusive) { if (check_index_y == 0) { - if ((index_x + N_READS) < stride) { + if ((read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { - out[offset + index_y * stride + index_x + i] = Op::init; + out[index_y * stride + i] = Op::init; } } else { for (int i = 0; i < N_READS; i++) { - if ((index_x + i) < stride) { - out[offset + index_y * stride + index_x + i] = Op::init; + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = Op::init; } } } @@ -347,16 +367,14 @@ template < check_index_y += 1; } } - if (check_index_y < axis_size && (index_x + N_READS) < stride) { + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { for (int i = 0; i < N_READS; i++) { - out[offset + index_y * stride + index_x + i] = - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + out[index_y * stride + i] = read_into[i]; } } else { for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (index_x + i) < stride) { - out[offset + index_y * stride + index_x + i] = - read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i]; + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; } } } diff --git a/Source/Cmlx/mlx-generated/scatter.cpp b/Source/Cmlx/mlx-generated/scatter.cpp index 4bbad074..fa709676 100644 --- a/Source/Cmlx/mlx-generated/scatter.cpp +++ b/Source/Cmlx/mlx-generated/scatter.cpp @@ -7,10 +7,11 @@ struct Indices { const array buffers; const constant int* shapes; const constant size_t* strides; + const constant bool* row_contiguous; const int 
ndim; }; template -METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { if (is_unsigned_v) { return idx; } else { @@ -18,68 +19,57 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { } } -template -METAL_FUNC void scatter_1d_index_impl( - const device T* updates [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* out_shape [[buffer(3)]], - const constant size_t* out_strides [[buffer(4)]], - const constant size_t& out_ndim [[buffer(5)]], - const constant int* upd_shape [[buffer(6)]], - const constant size_t& upd_ndim [[buffer(7)]], - const constant size_t& upd_size [[buffer(8)]], - const thread array& idx_buffers, - uint2 gid [[thread_position_in_grid]]) { - Op op; - size_t out_idx = 0; - for (int i = 0; i < NIDX; i++) { - auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]); - out_idx += idx_val * out_strides[i]; - } - if (upd_ndim > 1) { - auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim); - out_idx += out_offset; - } else { - out_idx += gid.x; - } - op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx); -} -template +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + bool UPD_ROW_CONTIG, + int NWORK, + typename LocT> METAL_FUNC void scatter_impl( - const device T* updates [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* upd_shape [[buffer(3)]], - const constant size_t* upd_strides [[buffer(4)]], - const constant size_t& upd_ndim [[buffer(5)]], - const constant size_t& upd_size [[buffer(6)]], - const constant int* out_shape [[buffer(7)]], - const constant size_t* out_strides [[buffer(8)]], - const constant size_t& out_ndim [[buffer(9)]], - const constant int* axes [[buffer(10)]], + const device T* updates, + device mlx_atomic* out, + const constant int* upd_shape, + const constant size_t* upd_strides, + const constant size_t& upd_ndim, + const constant size_t& upd_size, + const constant int* out_shape, + const constant size_t* out_strides, + const constant size_t& out_ndim, + const constant int* axes, + const constant size_t& idx_size, const thread Indices& indices, uint2 gid [[thread_position_in_grid]]) { Op op; - auto ind_idx = gid.y; - auto ind_offset = gid.x; - size_t out_idx = 0; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; - } + auto ind_idx = gid.y * NWORK; + LocT out_offset = 0; if (upd_size > 1) { - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - out_idx += out_offset; + out_offset = elem_to_loc( + gid.x, upd_shape + indices.ndim, out_strides, out_ndim); + } + for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { + LocT out_idx = out_offset; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = indices.row_contiguous[i] + ? 
ind_idx + : elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += + static_cast(idx_val) * static_cast(out_strides[ax]); + } + auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; + if constexpr (!UPD_ROW_CONTIG) { + upd_idx = + elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); + } + op.atomic_update(out, updates[upd_idx], out_idx); } - auto upd_idx = - elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx); } )preamble"; } diff --git a/Source/Cmlx/mlx-generated/steel_conv_general.cpp b/Source/Cmlx/mlx-generated/steel_conv_general.cpp index 286b22e1..98e34d93 100644 --- a/Source/Cmlx/mlx-generated/steel_conv_general.cpp +++ b/Source/Cmlx/mlx-generated/steel_conv_general.cpp @@ -315,8 +315,8 @@ implicit_gemm_conv_2d_general( } threadgroup_barrier(mem_flags::mem_none); { - int offset_m = c_row + mma_op.sm + mma_op.tm; - int offset_n = c_col + mma_op.sn + mma_op.tn; + int offset_m = c_row + mma_op.sm; + int offset_n = c_col + mma_op.sn; C += offset_n; if (offset_n >= gemm_params->N) return; @@ -335,14 +335,14 @@ implicit_gemm_conv_2d_general( oh * params->out_strides[1] + ow * params->out_strides[2]; #pragma clang loop unroll(full) for (int j = 0; j < mma_t::TN; j++) { - thread const auto& accum = - mma_op.results[i * mma_t::TN + j].thread_elements(); + thread const auto& accum = mma_op.Ctile.frag_at(i, j); int offset = offset_cm + (j * mma_t::TN_stride); - if (j * mma_t::TN_stride < diff) { - C[offset] = Epilogue::apply(accum[0]); - } - if (j * mma_t::TN_stride + 1 < diff) { - C[offset + 1] = Epilogue::apply(accum[1]); + constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; +#pragma clang loop unroll(full) + for (short k = 0; k < kelems; k++) { + if ((j * mma_t::TN_stride + k) < diff) { + C[offset + k] = Epilogue::apply(accum[k]); + } } } } diff --git a/Source/Cmlx/mlx-generated/ternary.cpp b/Source/Cmlx/mlx-generated/ternary.cpp index ce4a2c8a..8ecd64a8 100644 --- a/Source/Cmlx/mlx-generated/ternary.cpp +++ b/Source/Cmlx/mlx-generated/ternary.cpp @@ -32,12 +32,12 @@ template constant const size_t& b_strides, constant const size_t& c_strides, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd2( device const bool* a, device const T* b, @@ -48,13 +48,13 @@ template constant const size_t c_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - auto c_idx = elem_to_loc_2(index, c_strides); - size_t out_idx = index.x + size_t(grid_dim.x) * index.y; + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd3( device const bool* a, device const T* b, @@ -65,14 +65,13 @@ template constant const size_t 
c_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - auto c_idx = elem_to_loc_3(index, c_strides); - size_t out_idx = - index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g( device const bool* a, device const T* b, @@ -85,7 +84,7 @@ template constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_3_nd( + auto idx = elem_to_loc_3_nd( {N * index.x, index.y, index.z}, shape, a_strides, @@ -93,11 +92,10 @@ template c_strides, ndim); auto xshape = shape[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); - auto a_xstride = a_strides[ndim - 1]; - auto b_xstride = b_strides[ndim - 1]; - auto c_xstride = c_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + IdxT c_xstride = c_strides[ndim - 1]; for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); idx.x += a_xstride; diff --git a/Source/Cmlx/mlx-generated/unary.cpp b/Source/Cmlx/mlx-generated/unary.cpp index 00d67fb7..cdffc5f3 100644 --- a/Source/Cmlx/mlx-generated/unary.cpp +++ b/Source/Cmlx/mlx-generated/unary.cpp @@ -2,37 +2,41 @@ namespace mlx::core::metal { const char* unary() { return R"preamble( -template +template [[kernel]] void unary_v( device const T* in, - device T* out, + device U* out, uint index [[thread_position_in_grid]]) { out[index] = Op()(in[index]); } -template +template [[kernel]] void unary_v2( device const T* in, - device T* out, + device U* out, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { size_t offset = index.x + grid_dim.x * size_t(index.y); out[offset] = Op()(in[offset]); } -template +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = size_t> [[kernel]] void unary_g( device const T* in, - device T* out, + device U* out, constant const int* in_shape, constant const size_t* in_strides, device const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = - elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto idx = elem_to_loc( + {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); auto xshape = in_shape[ndim - 1]; - auto xstride = in_strides[ndim - 1]; - size_t out_idx = - N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); + IdxT xstride = in_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { out[out_idx++] = Op()(in[idx]); idx += xstride; diff --git a/Source/Cmlx/mlx-generated/unary_ops.cpp b/Source/Cmlx/mlx-generated/unary_ops.cpp index 6b1be9ea..0d30a145 100644 --- a/Source/Cmlx/mlx-generated/unary_ops.cpp +++ b/Source/Cmlx/mlx-generated/unary_ops.cpp @@ -308,6 +308,12 @@ struct Floor { return x; }; }; +struct Imag { + template + T operator()(T x) { + return x.imag; + }; +}; 
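// [Reviewer note, not part of the generated preamble] Imag above and the matching Real
// operator added further down in this file return the .imag / .real component of their
// input. They presumably pair with the unary kernels in unary.cpp, which in this diff
// gain a separate output type U, since these ops map a complex input type to a real output.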
struct Log { template T operator()(T x) { @@ -344,6 +350,12 @@ struct Negative { return -x; }; }; +struct Real { + template + T operator()(T x) { + return x.real; + }; +}; struct Round { template T operator()(T x) { diff --git a/Source/Cmlx/mlx-generated/utils.cpp b/Source/Cmlx/mlx-generated/utils.cpp index dccadb30..8bc287f7 100644 --- a/Source/Cmlx/mlx-generated/utils.cpp +++ b/Source/Cmlx/mlx-generated/utils.cpp @@ -2,6 +2,17 @@ namespace mlx::core::metal { const char* utils() { return R"preamble( +#if (__METAL_VERSION__ >= 310) +using namespace metal; +typedef bfloat bfloat16_t; +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return as_type(x); +} +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return as_type(x); +} + + #else using namespace metal; constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > @@ -150,6 +161,15 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) { } } #pragma METAL internals : disable +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return x.bits_; +} +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat()); +} + + #endif + namespace metal { METAL_FUNC bfloat16_t abs(bfloat16_t x) { return static_cast(__metal_fabs(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t acos(bfloat16_t x) { return static_cast(__metal_acos(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t acosh(bfloat16_t x) { return static_cast(__metal_acosh(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t asin(bfloat16_t x) { return static_cast(__metal_asin(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t asinh(bfloat16_t x) { return static_cast(__metal_asinh(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t atan(bfloat16_t y_over_x) { return static_cast( __metal_atan(static_cast(y_over_x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t atan2(bfloat16_t y, bfloat16_t x) { return static_cast( __metal_atan2(static_cast(y), static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t atanh(bfloat16_t x) { return static_cast(__metal_atanh(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t ceil(bfloat16_t x) { return static_cast(__metal_ceil(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t cos(bfloat16_t x) { return static_cast(__metal_cos(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t cosh(bfloat16_t x) { return static_cast(__metal_cosh(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t cospi(bfloat16_t x) { return static_cast(__metal_cospi(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t divide(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_divide(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t exp(bfloat16_t x) { return static_cast(__metal_exp(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t exp10(bfloat16_t x) { return static_cast(__metal_exp10(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t exp2(bfloat16_t x) { return static_cast(__metal_exp2(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fabs(bfloat16_t x) { return static_cast(__metal_fabs(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fdim(bfloat16_t x, bfloat16_t y) { float t = static_cast(x - y); return static_cast(select(t, float(0), t < float(0) || x == y)); } 
METAL_FUNC bfloat16_t floor(bfloat16_t x) { return static_cast(__metal_floor(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fma(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fma( static_cast(x), static_cast(y), static_cast(z))); } METAL_FUNC bfloat16_t fmax(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_fmax(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fmax3(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fmax3( static_cast(x), static_cast(y), static_cast(z), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fmedian3(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fmedian3( static_cast(x), static_cast(y), static_cast(z), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fmin(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_fmin(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fmin3(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fmin3( static_cast(x), static_cast(y), static_cast(z), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fmod(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_fmod(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t fract(bfloat16_t x) { return static_cast(__metal_fract(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t frexp(bfloat16_t x, thread int& exp) { return static_cast(__metal_frexp(static_cast(x), &exp)); } METAL_FUNC bfloat16_t ldexp(bfloat16_t x, int k) { return static_cast(__metal_ldexp(static_cast(x), k, __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t log(bfloat16_t x) { return static_cast(__metal_log(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t log10(bfloat16_t x) { return static_cast(__metal_log10(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t log2(bfloat16_t x) { return static_cast(__metal_log2(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_fmax(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t max3(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fmax3( static_cast(x), static_cast(y), static_cast(z), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t median3(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fmedian3( static_cast(x), static_cast(y), static_cast(z), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_fmin(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t min3(bfloat16_t x, bfloat16_t y, bfloat16_t z) { return static_cast(__metal_fmin3( static_cast(x), static_cast(y), static_cast(z), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t nextafter(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_nextafter(static_cast(x), static_cast(y))); } METAL_FUNC bfloat16_t pow(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_pow(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t powr(bfloat16_t x, bfloat16_t y) { return static_cast( __metal_powr(static_cast(x), static_cast(y), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t rint(bfloat16_t x) { return static_cast(__metal_rint(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t round(bfloat16_t x) { return 
static_cast(__metal_round(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t rsqrt(bfloat16_t x) { return static_cast(__metal_rsqrt(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t sin(bfloat16_t x) { return static_cast(__metal_sin(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t sinh(bfloat16_t x) { return static_cast(__metal_sinh(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t sinpi(bfloat16_t x) { return static_cast(__metal_sinpi(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t sqrt(bfloat16_t x) { return static_cast(__metal_sqrt(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t tan(bfloat16_t x) { return static_cast(__metal_tan(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t tanh(bfloat16_t x) { return static_cast(__metal_tanh(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t tanpi(bfloat16_t x) { return static_cast(__metal_tanpi(static_cast(x), __METAL_MAYBE_FAST_MATH__)); } METAL_FUNC bfloat16_t trunc(bfloat16_t x) { return static_cast(__metal_trunc(static_cast(x), __METAL_MAYBE_FAST_MATH__)); }; namespace fast { @@ -160,7 +180,7 @@ METAL_FUNC bfloat16_t abs(bfloat16_t x) { return static_cast(__metal } } namespace metal { -METAL_FUNC bfloat16_t simd_broadcast(bfloat16_t data, ushort broadcast_lane_id) { return _MLX_BFloat16(__metal_simd_broadcast(data.bits_, broadcast_lane_id), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle(bfloat16_t data, ushort simd_lane_id) { return _MLX_BFloat16(__metal_simd_shuffle(data.bits_, simd_lane_id), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_down( bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo) { return _MLX_BFloat16(__metal_simd_shuffle_and_fill_down( data.bits_, filling_data.bits_, delta, modulo), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_down( bfloat16_t data, bfloat16_t filling_data, ushort delta) { return _MLX_BFloat16(__metal_simd_shuffle_and_fill_down( data.bits_, filling_data.bits_, delta, __metal_get_simdgroup_size(ushort())), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_up( bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo) { return _MLX_BFloat16(__metal_simd_shuffle_and_fill_up( data.bits_, filling_data.bits_, delta, modulo), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_up( bfloat16_t data, bfloat16_t filling_data, ushort delta) { return _MLX_BFloat16(__metal_simd_shuffle_and_fill_up( data.bits_, filling_data.bits_, delta, __metal_get_simdgroup_size(ushort())), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta) { return _MLX_BFloat16(__metal_simd_shuffle_down(data.bits_, delta), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_rotate_down(bfloat16_t data, ushort delta) { return _MLX_BFloat16(__metal_simd_shuffle_rotate_down(data.bits_, delta), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_rotate_up(bfloat16_t data, ushort delta) { return _MLX_BFloat16(__metal_simd_shuffle_rotate_up(data.bits_, delta), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_up(bfloat16_t data, ushort delta) { return _MLX_BFloat16(__metal_simd_shuffle_up(data.bits_, delta), _MLX_BFloat16::bits_to_bfloat()); } METAL_FUNC bfloat16_t simd_shuffle_xor(bfloat16_t data, ushort mask) { 
return _MLX_BFloat16(__metal_simd_shuffle_xor(data.bits_, mask), _MLX_BFloat16::bits_to_bfloat()); }; +METAL_FUNC bfloat16_t simd_broadcast(bfloat16_t data, ushort broadcast_lane_id) { return uint16_to_bfloat16( __metal_simd_broadcast(bfloat16_to_uint16(data), broadcast_lane_id)); } METAL_FUNC bfloat16_t simd_shuffle(bfloat16_t data, ushort simd_lane_id) { return uint16_to_bfloat16( __metal_simd_shuffle(bfloat16_to_uint16(data), simd_lane_id)); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_down( bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo) { return uint16_to_bfloat16(__metal_simd_shuffle_and_fill_down( bfloat16_to_uint16(data), bfloat16_to_uint16(filling_data), delta, modulo)); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_down( bfloat16_t data, bfloat16_t filling_data, ushort delta) { return uint16_to_bfloat16(__metal_simd_shuffle_and_fill_down( bfloat16_to_uint16(data), bfloat16_to_uint16(filling_data), delta, __metal_get_simdgroup_size(ushort()))); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_up( bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo) { return uint16_to_bfloat16(__metal_simd_shuffle_and_fill_up( bfloat16_to_uint16(data), bfloat16_to_uint16(filling_data), delta, modulo)); } METAL_FUNC bfloat16_t simd_shuffle_and_fill_up( bfloat16_t data, bfloat16_t filling_data, ushort delta) { return uint16_to_bfloat16(__metal_simd_shuffle_and_fill_up( bfloat16_to_uint16(data), bfloat16_to_uint16(filling_data), delta, __metal_get_simdgroup_size(ushort()))); } METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta) { return uint16_to_bfloat16( __metal_simd_shuffle_down(bfloat16_to_uint16(data), delta)); } METAL_FUNC bfloat16_t simd_shuffle_rotate_down(bfloat16_t data, ushort delta) { return uint16_to_bfloat16( __metal_simd_shuffle_rotate_down(bfloat16_to_uint16(data), delta)); } METAL_FUNC bfloat16_t simd_shuffle_rotate_up(bfloat16_t data, ushort delta) { return uint16_to_bfloat16( __metal_simd_shuffle_rotate_up(bfloat16_to_uint16(data), delta)); } METAL_FUNC bfloat16_t simd_shuffle_up(bfloat16_t data, ushort delta) { return uint16_to_bfloat16( __metal_simd_shuffle_up(bfloat16_to_uint16(data), delta)); } METAL_FUNC bfloat16_t simd_shuffle_xor(bfloat16_t data, ushort mask) { return uint16_to_bfloat16( __metal_simd_shuffle_xor(bfloat16_to_uint16(data), mask)); }; METAL_FUNC bfloat16_t simd_max(bfloat16_t data) { return static_cast(__metal_simd_max(static_cast(data))); } METAL_FUNC bfloat16_t simd_min(bfloat16_t data) { return static_cast(__metal_simd_min(static_cast(data))); } METAL_FUNC bfloat16_t simd_prefix_exclusive_product(bfloat16_t data) { return static_cast( __metal_simd_prefix_exclusive_product(static_cast(data))); } METAL_FUNC bfloat16_t simd_prefix_exclusive_sum(bfloat16_t data) { return static_cast( __metal_simd_prefix_exclusive_sum(static_cast(data))); } METAL_FUNC bfloat16_t simd_prefix_inclusive_product(bfloat16_t data) { return static_cast( __metal_simd_prefix_inclusive_product(static_cast(data))); } METAL_FUNC bfloat16_t simd_prefix_inclusive_sum(bfloat16_t data) { return static_cast( __metal_simd_prefix_inclusive_sum(static_cast(data))); } METAL_FUNC bfloat16_t simd_product(bfloat16_t data) { return static_cast(__metal_simd_product(static_cast(data))); } METAL_FUNC bfloat16_t simd_sum(bfloat16_t data) { return static_cast(__metal_simd_sum(static_cast(data))); } METAL_FUNC bfloat16_t simd_xor(bfloat16_t data) { return static_cast(__metal_simd_xor(static_cast(data))); }; } using namespace metal; @@ -303,105 +323,115 
@@ struct Limits { -metal::numeric_limits::infinity(), -metal::numeric_limits::infinity()); }; -template -METAL_FUNC stride_t elem_to_loc( +template +METAL_FUNC IdxT elem_to_loc( uint elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = 0; + IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * strides[i]; + loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } return loc; } -template -METAL_FUNC stride_t elem_to_loc( - stride_t elem, +template +METAL_FUNC IdxT elem_to_loc( + StrideT elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = 0; + IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * strides[i]; + loc += (elem % shape[i]) * IdxT(strides[i]); elem /= shape[i]; } return loc; } -template -METAL_FUNC stride_t elem_to_loc( +template +METAL_FUNC IdxT elem_to_loc( uint3 elem, constant const int* shape, - constant const stride_t* strides, + constant const StrideT* strides, int ndim) { - stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + IdxT loc = + elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); for (int d = ndim - 3; d >= 0; --d) { - loc += (elem.z % shape[d]) * strides[d]; + loc += (elem.z % shape[d]) * IdxT(strides[d]); elem.z /= shape[d]; } return loc; } -template -METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) { - return elem * stride; +template +METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) { + return elem * IdxT(stride); } -template -METAL_FUNC stride_t -elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) { - return elem.x * strides[1] + elem.y * strides[0]; +template +METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) { + return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); } -template -METAL_FUNC stride_t -elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) { - return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +template +METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) { + return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + + elem.z * IdxT(strides[0]); } -template -METAL_FUNC ulong2 elem_to_loc_2_nd( +template +METAL_FUNC vec elem_to_loc_2_nd( uint3 elem, constant const int* shape, - constant const stride_t* a_strides, - constant const stride_t* b_strides, + constant const StrideT* a_strides, + constant const StrideT* b_strides, int ndim) { - ulong2 loc = { - ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), - ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + vec loc = { + IdxT( + elem.x * IdxT(a_strides[ndim - 1]) + + IdxT(elem.y) * IdxT(a_strides[ndim - 2])), + IdxT( + elem.x * IdxT(b_strides[ndim - 1]) + + elem.y * IdxT(b_strides[ndim - 2]))}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); elem.z /= shape[d]; } return loc; } -METAL_FUNC ulong3 elem_to_loc_3_nd( +template +METAL_FUNC vec elem_to_loc_3_nd( uint3 elem, constant const int* shape, constant const size_t* a_strides, constant const size_t* b_strides, constant const size_t* c_strides, int ndim) { - ulong3 loc = { - elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2], - 
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2], - elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]}; + vec loc = { + elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]), + elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]), + elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])}; for (int d = ndim - 3; d >= 0; --d) { uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - loc.z += l * c_strides[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + loc.z += l * IdxT(c_strides[d]); elem.z /= shape[d]; } return loc; } -template -struct looped_elem_to_loc { - looped_elem_to_loc inner_looper; - offset_t offset{0}; +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; int index{0}; + LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} void next(const constant int* shape, const constant size_t* strides) { + if (dim == 0) { + return; + } index++; - offset += strides[dim - 1]; + offset += OffsetT(strides[dim - 1]); if (index >= shape[dim - 1]) { index = 0; inner_looper.next(shape, strides); @@ -409,47 +439,68 @@ struct looped_elem_to_loc { } } void next(int n, const constant int* shape, const constant size_t* strides) { + if (dim == 0) { + return; + } index += n; - offset += n * strides[dim - 1]; + offset += n * OffsetT(strides[dim - 1]); if (index >= shape[dim - 1]) { int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } index = 0; - inner_looper.next(shape, strides); offset = inner_looper.offset; if (extra > 0) { next(extra, shape, strides); } } } - offset_t - location(offset_t, const constant int*, const constant size_t*, int) { + OffsetT location() { return offset; } }; -template -struct looped_elem_to_loc<1, offset_t> { - offset_t offset{0}; - void next(const constant int*, const constant size_t* strides) { - offset += strides[0]; +template +struct LoopedElemToLoc<1, OffsetT, true> { + int dim; + OffsetT offset{0}; + uint index{0}; + LoopedElemToLoc(int dim) : dim(dim) {} + void next(const constant int* shape, const constant size_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } } - void next(int n, const constant int*, const constant size_t* strides) { - offset += n * strides[0]; + void next(int n, const constant int* shape, const constant size_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } } - offset_t - location(offset_t, const constant int*, const constant size_t*, int) { + OffsetT location() { return offset; } }; -template -struct looped_elem_to_loc<0, offset_t> { - void next(const constant int*, const constant size_t*) {} - void next(int, const constant int*, const constant size_t*) {} - offset_t location( - offset_t idx, - const constant int* shape, - const constant size_t* strides, - int ndim) { - return elem_to_loc(idx, shape, strides, ndim); +template +struct LoopedElemToLoc<1, OffsetT, false> { + OffsetT offset{0}; + LoopedElemToLoc(int) {} + void next(const constant int*, const constant size_t* strides) { + offset += OffsetT(strides[0]); + } + void next(int n, const constant int*, const constant size_t* strides) { + offset += n * 
OffsetT(strides[0]); + } + OffsetT location() { + return offset; } }; template @@ -491,6 +542,54 @@ inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { return complex64_t( simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); } +inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} +inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} +inline bool simd_shuffle_up(bool data, uint16_t delta) { + return simd_shuffle_up(static_cast(data), delta); +} +inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta)); +} +inline uint64_t +simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} +inline int64_t +simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} +inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) { + return simd_shuffle_and_fill_up( + static_cast(data), static_cast(filling), delta); +} +inline complex64_t simd_shuffle_and_fill_up( + complex64_t data, + complex64_t filling, + uint16_t delta) { + return complex64_t( + simd_shuffle_and_fill_up(data.real, filling.real, delta), + simd_shuffle_and_fill_up(data.imag, filling.imag, delta)); +} +inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} +inline int64_t simd_shuffle(int64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} +inline bool simd_shuffle(bool data, uint16_t lane) { + return simd_shuffle(static_cast(data), lane); +} +inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { + return complex64_t( + simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); +} )preamble"; } diff --git a/Source/MLX/Cmlx+Util.swift b/Source/MLX/Cmlx+Util.swift index f91ec2ee..3d7d83a5 100644 --- a/Source/MLX/Cmlx+Util.swift +++ b/Source/MLX/Cmlx+Util.swift @@ -3,34 +3,17 @@ import Cmlx import Foundation -@inline(__always) -func mlx_free(_ ptr: OpaquePointer) { - mlx_free(UnsafeMutableRawPointer(ptr)) -} - -@inline(__always) -func mlx_retain(_ ptr: OpaquePointer) { - mlx_retain(UnsafeMutableRawPointer(ptr)) -} - -func mlx_describe(_ ptr: OpaquePointer) -> String? { - let description = mlx_tostring(UnsafeMutableRawPointer(ptr))! - defer { mlx_free(description) } - return String(cString: mlx_string_data(description)) -} - // return a +1 mlx_vector_array containing the given arrays func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { - let result = mlx_vector_array_new()! - mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) - return result + mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count) } func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { (0 ..< mlx_vector_array_size(vector_array)) .map { index in // ctx is a +1 object, the array takes ownership - let ctx = mlx_vector_array_get(vector_array, index)! 
+ var ctx = mlx_array_new() + mlx_vector_array_get(&ctx, vector_array, index) return MLXArray(ctx) } } @@ -38,21 +21,22 @@ func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { func mlx_map_array_values(_ mlx_map: mlx_map_string_to_array) -> [String: MLXArray] { var result = [String: MLXArray]() - let iterator = mlx_map_string_to_array_iterate(mlx_map)! - defer { mlx_free(iterator) } + let iterator = mlx_map_string_to_array_iterator_new(mlx_map) + defer { mlx_map_string_to_array_iterator_free(iterator) } - while !mlx_map_string_to_array_iterator_end(iterator) { - let mlx_key = mlx_map_string_to_array_iterator_key(iterator)! - defer { mlx_free(mlx_key) } - let key = String(cString: mlx_string_data(mlx_key)) + var mlx_key: UnsafePointer? + var mlx_value = mlx_array_new() + defer { mlx_array_free(mlx_value) } - // note: transfer ownership - let mlx_array_ctx = mlx_map_string_to_array_iterator_value(iterator)! - let array = MLXArray(mlx_array_ctx) + while mlx_map_string_to_array_iterator_next(&mlx_key, &mlx_value, iterator) == 0 { + guard let mlx_key else { continue } - result[key] = array + let key = String(cString: mlx_key) + var new = mlx_array_new() + mlx_array_set(&new, mlx_value) + let array = MLXArray(new) - mlx_map_string_to_array_iterator_next(iterator) + result[key] = array } return result @@ -61,51 +45,40 @@ func mlx_map_array_values(_ mlx_map: mlx_map_string_to_array) -> [String: MLXArr func mlx_map_string_values(_ mlx_map: mlx_map_string_to_string) -> [String: String] { var result = [String: String]() - let iterator = mlx_map_string_to_string_iterate(mlx_map)! - defer { mlx_free(iterator) } + let iterator = mlx_map_string_to_string_iterator_new(mlx_map) + defer { mlx_map_string_to_string_iterator_free(iterator) } - while !mlx_map_string_to_string_iterator_end(iterator) { - let mlx_key = mlx_map_string_to_string_iterator_key(iterator)! - defer { mlx_free(mlx_key) } - let key = String(cString: mlx_string_data(mlx_key)) + var mlx_key: UnsafePointer? + var mlx_value: UnsafePointer? - // note: transfer ownership - let mlx_value = mlx_map_string_to_string_iterator_value(iterator)! - defer { mlx_free(mlx_value) } - let value = String(cString: mlx_string_data(mlx_value)) + while mlx_map_string_to_string_iterator_next(&mlx_key, &mlx_value, iterator) == 0 { + guard let mlx_key, let mlx_value else { continue } - result[key] = value + let key = String(cString: mlx_key) + let value = String(cString: mlx_value) - mlx_map_string_to_string_iterator_next(iterator) + result[key] = value } return result } func new_mlx_array_map(_ dictionary: [String: MLXArray]) -> mlx_map_string_to_array { - let mlx_map = mlx_map_string_to_array_new()! + let mlx_map = mlx_map_string_to_array_new() for (key, array) in dictionary { - let mlx_key = mlx_string_new(key.cString(using: .utf8))! - defer { mlx_free(mlx_key) } - - mlx_map_string_to_array_insert(mlx_map, mlx_key, array.ctx) + mlx_map_string_to_array_insert(mlx_map, key.cString(using: .utf8), array.ctx) } return mlx_map } func new_mlx_string_map(_ dictionary: [String: String]) -> mlx_map_string_to_string { - let mlx_map = mlx_map_string_to_string_new()! + let mlx_map = mlx_map_string_to_string_new() for (key, value) in dictionary { - let mlx_key = mlx_string_new(key.cString(using: .utf8))! - defer { mlx_free(mlx_key) } - - let mlx_value = mlx_string_new(value.cString(using: .utf8))! 
- defer { mlx_free(mlx_value) } - - mlx_map_string_to_string_insert(mlx_map, mlx_key, mlx_value) + mlx_map_string_to_string_insert( + mlx_map, key.cString(using: .utf8), value.cString(using: .utf8)) } return mlx_map @@ -130,36 +103,25 @@ func new_mlx_closure(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> mlx_closure { // the C function that the mlx_closure will call -- this will convert // arguments & results and call the captured `f()` - func trampoline(vector_array: mlx_vector_array?, payload: UnsafeMutableRawPointer?) - -> mlx_vector_array? + func trampoline( + resultOut: UnsafeMutablePointer?, vector_array: mlx_vector_array, + payload: UnsafeMutableRawPointer? + ) + -> Int32 { let state = Unmanaged.fromOpaque(payload!).takeUnretainedValue() - let arrays = mlx_vector_array_values(vector_array!) + let arrays = mlx_vector_array_values(vector_array) let result = state.f(arrays) - return new_mlx_vector_array(result) - } - - return mlx_closure_new_with_payload(trampoline, payload, free)! -} -func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { - let a = mlx_tuple_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_get_1(tuple)! - return (MLXArray(a), MLXArray(b)) -} + if let resultOut { + resultOut.pointee = new_mlx_vector_array(result) + } else { + fatalError("no resultOut pointer") + } -func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { - let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! - defer { mlx_free(a) } - let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! - defer { mlx_free(b) } - return (mlx_vector_array_values(a), mlx_vector_array_values(b)) -} + return 0 + } -func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { - let a = mlx_tuple_array_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_array_get_1(tuple)! - let c = mlx_tuple_array_array_array_get_2(tuple)! - return (MLXArray(a), MLXArray(b), MLXArray(c)) + return mlx_closure_new_func_payload(trampoline, payload, free) } diff --git a/Source/MLX/DType.swift b/Source/MLX/DType.swift index 740c87a6..013b94e0 100644 --- a/Source/MLX/DType.swift +++ b/Source/MLX/DType.swift @@ -39,7 +39,7 @@ public enum DType: Hashable, Sendable { case bfloat16 case complex64 - init(_ cmlxDtype: mlx_array_dtype) { + init(_ cmlxDtype: mlx_dtype) { switch cmlxDtype { case MLX_BOOL: self = .bool case MLX_UINT8: self = .uint8 @@ -59,7 +59,7 @@ public enum DType: Hashable, Sendable { } } - public var cmlxDtype: mlx_array_dtype { + public var cmlxDtype: mlx_dtype { switch self { case .bool: MLX_BOOL case .uint8: MLX_UINT8 diff --git a/Source/MLX/Device.swift b/Source/MLX/Device.swift index 8d1d504a..8d451c85 100644 --- a/Source/MLX/Device.swift +++ b/Source/MLX/Device.swift @@ -41,24 +41,22 @@ public final class Device: @unchecked Sendable, Equatable { case DeviceType.gpu: cDeviceType = MLX_GPU } - ctx = mlx_device_new(cDeviceType, index) + self.ctx = mlx_device_new_type(cDeviceType, index) } public init() { - ctx = mlx_default_device() + var ctx = mlx_device_new() + mlx_get_default_device(&ctx) + self.ctx = ctx } deinit { - mlx_free(ctx) + mlx_device_free(ctx) } static public let cpu: Device = Device(.cpu) static public let gpu: Device = Device(.gpu) - static public func defaultDevice() -> Device { - return Device() - } - public var deviceType: DeviceType? 
{ switch mlx_device_get_type(ctx) { case MLX_CPU: .cpu @@ -67,6 +65,27 @@ public final class Device: @unchecked Sendable, Equatable { } } + static let _lock = NSLock() + #if swift(>=5.10) + nonisolated(unsafe) static var _defaultDevice = gpu + nonisolated(unsafe) static var _defaultStream = Stream(gpu) + #else + static var _defaultDevice = gpu + static var _defaultStream = Stream(gpu) + #endif + + static public func defaultDevice() -> Device { + _lock.withLock { + _defaultDevice + } + } + + static func defaultStream() -> Stream { + _lock.withLock { + _defaultStream + } + } + /// Set the default device. /// /// For example: @@ -80,7 +99,11 @@ public final class Device: @unchecked Sendable, Equatable { /// ### See Also /// - ``StreamOrDevice/default`` static public func setDefault(device: Device) { - mlx_set_default_device(device.ctx) + _lock.withLock { + mlx_set_default_device(device.ctx) + _defaultDevice = device + _defaultStream = Stream(device) + } } /// Compare two ``Device`` for equality -- this does not compare the index, just the device type. @@ -91,7 +114,10 @@ public final class Device: @unchecked Sendable, Equatable { extension Device: CustomStringConvertible { public var description: String { - mlx_describe(ctx) ?? String(describing: type(of: self)) + var s = mlx_string_new() + mlx_device_tostring(&s, ctx) + defer { mlx_string_free(s) } + return String(cString: mlx_string_data(s), encoding: .utf8)! } } diff --git a/Source/MLX/Factory.swift b/Source/MLX/Factory.swift index 94b141d0..7d171236 100644 --- a/Source/MLX/Factory.swift +++ b/Source/MLX/Factory.swift @@ -25,7 +25,9 @@ extension MLXArray { static public func zeros( _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_zeros(shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_zeros(&result, shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array of zeros with a given ``DType`` @@ -48,7 +50,9 @@ extension MLXArray { static public func zeros( _ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_zeros(shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_zeros(&result, shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array of zeros. @@ -69,7 +73,9 @@ extension MLXArray { /// - ``zeros(_:type:stream:)`` /// - ``ones(_:type:stream:)`` static public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_zeros_like(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_zeros_like(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Construct an array of ones. 
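Note: the Factory.swift hunks above and below all migrate to the same mlx-c 0.1.0 calling convention, where the C function fills a caller-supplied handle instead of returning a new one. A minimal sketch of that pattern, using only symbols already present in this diff (makeZeros itself is an illustrative name, not part of the change):

    // Construct-into-out-parameter pattern used throughout Factory.swift.
    func makeZeros(_ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default) -> MLXArray {
        var result = mlx_array_new()   // fresh, empty handle that receives the result
        mlx_zeros(&result, shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx)
        return MLXArray(result)        // MLXArray takes ownership of the filled handle
    }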
@@ -92,7 +98,9 @@ extension MLXArray { static public func ones( _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_ones(shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_ones(&result, shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array of ones with a given ``DType`` @@ -115,7 +123,9 @@ extension MLXArray { static public func ones( _ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_ones(shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_ones(&result, shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array of ones. @@ -136,7 +146,9 @@ extension MLXArray { /// - ``ones(_:type:stream:)`` /// - ``zeros(_:type:stream:)`` static public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_ones_like(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_ones_like(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Create an identity matrix or a general diagonal matrix. @@ -162,7 +174,9 @@ extension MLXArray { _ n: Int, m: Int? = nil, k: Int = 0, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_eye(n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_eye(&result, n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array with the given value. @@ -190,7 +204,9 @@ extension MLXArray { static public func full( _ shape: [Int], values: MLXArray, type: T.Type, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_full(shape.asInt32, shape.count, values.ctx, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_full(&result, shape.asInt32, shape.count, values.ctx, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array with the given value. @@ -217,8 +233,11 @@ extension MLXArray { static public func full(_ shape: [Int], values: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray( - mlx_full(shape.asInt32, shape.count, values.ctx, values.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + + mlx_full( + &result, shape.asInt32, shape.count, values.ctx, values.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Create a square identity matrix. @@ -241,7 +260,9 @@ extension MLXArray { static public func identity( _ n: Int, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_identity(n.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_identity(&result, n.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Generate `num` evenly spaced numbers over interval `[start, stop]` for `BinaryInteger`. 
@@ -265,8 +286,11 @@ extension MLXArray { static public func linspace( _ start: T, _ stop: T, count: Int = 50, stream: StreamOrDevice = .default ) -> MLXArray where T: BinaryInteger { - MLXArray( - mlx_linspace(Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + + mlx_linspace( + &result, Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Generate `num` evenly spaced numbers over interval `[start, stop]` for `BinaryFloatingPoint`. @@ -290,8 +314,11 @@ extension MLXArray { static public func linspace( _ start: T, _ stop: T, count: Int = 50, stream: StreamOrDevice = .default ) -> MLXArray where T: BinaryFloatingPoint { - MLXArray( - mlx_linspace(Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + + mlx_linspace( + &result, Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Repeat an array along a specified axis. @@ -306,7 +333,9 @@ extension MLXArray { static public func `repeat`( _ array: MLXArray, count: Int, axis: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_repeat(array.ctx, count.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat(&result, array.ctx, count.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Repeat a flattened array along axis 0. @@ -321,7 +350,9 @@ extension MLXArray { static public func `repeat`(_ array: MLXArray, count: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_repeat_all(array.ctx, count.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat_all(&result, array.ctx, count.int32, stream.ctx) + return MLXArray(result) } /// Repeat an array along a specified axis. @@ -346,7 +377,9 @@ extension MLXArray { static public func repeated( _ array: MLXArray, count: Int, axis: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_repeat(array.ctx, count.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat(&result, array.ctx, count.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Repeat a flattened array along axis 0. @@ -370,7 +403,9 @@ extension MLXArray { static public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_repeat_all(array.ctx, count.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat_all(&result, array.ctx, count.int32, stream.ctx) + return MLXArray(result) } /// An array with ones at and below the given diagonal and zeros elsewhere. @@ -395,7 +430,9 @@ extension MLXArray { _ n: Int, m: Int? = nil, k: Int = 0, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_tri(n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_tri(&result, n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } } @@ -420,7 +457,9 @@ extension MLXArray { public func zeros( _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_zeros(shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_zeros(&result, shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array of zeros. 
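Note: the same status-code convention drives the map iterators rewritten in Cmlx+Util.swift earlier in this diff. A reduced sketch of that protocol (keys(of:) is an illustrative helper; the CChar element type of the key/value pointers is inferred from the String(cString:) calls in the wrapper):

    func keys(of map: mlx_map_string_to_string) -> [String] {
        let it = mlx_map_string_to_string_iterator_new(map)
        defer { mlx_map_string_to_string_iterator_free(it) }

        var result = [String]()
        var key: UnsafePointer<CChar>? = nil
        var value: UnsafePointer<CChar>? = nil
        // _iterator_next fills both out-parameters and returns 0 while entries remain
        while mlx_map_string_to_string_iterator_next(&key, &value, it) == 0 {
            if let key { result.append(String(cString: key)) }
        }
        return result
    }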
@@ -441,7 +480,9 @@ public func zeros( /// - ``zeros(_:type:stream:)`` /// - ``ones(_:type:stream:)`` public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_zeros_like(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_zeros_like(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Construct an array of ones. @@ -464,7 +505,9 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML public func ones( _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_ones(shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_ones(&result, shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array of ones. @@ -485,7 +528,9 @@ public func ones( /// - ``ones(_:type:stream:)`` /// - ``zeros(_:type:stream:)`` public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_ones_like(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_ones_like(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Create an identity matrix or a general diagonal matrix. @@ -511,7 +556,9 @@ public func eye( _ n: Int, m: Int? = nil, k: Int = 0, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_eye(n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_eye(&result, n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array with the given value. @@ -539,8 +586,10 @@ public func eye( public func full( _ shape: [Int], values: ScalarOrArray, type: T.Type, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() let values = values.asMLXArray(dtype: nil) - return MLXArray(mlx_full(shape.asInt32, shape.count, values.ctx, T.dtype.cmlxDtype, stream.ctx)) + mlx_full(&result, shape.asInt32, shape.count, values.ctx, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Construct an array with the given value. @@ -567,9 +616,10 @@ public func full( public func full(_ shape: [Int], values: ScalarOrArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() let values = values.asMLXArray(dtype: nil) - return MLXArray( - mlx_full(shape.asInt32, shape.count, values.ctx, values.dtype.cmlxDtype, stream.ctx)) + mlx_full(&result, shape.asInt32, shape.count, values.ctx, values.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Create a square identity matrix. @@ -592,7 +642,9 @@ public func full(_ shape: [Int], values: ScalarOrArray, stream: StreamOrDevice = public func identity( _ n: Int, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_identity(n.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_identity(&result, n.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Generate `num` evenly spaced numbers over interval `[start, stop]`. 
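Note: callbacks flow the other way through the same convention. new_mlx_closure (Cmlx+Util.swift, earlier in this diff) now registers a trampoline with mlx_closure_new_func_payload that writes its result through a pointer and returns a status code. A reduced sketch of that shape, assuming the result pointer's pointee type is mlx_vector_array (the real trampoline also unwraps the payload and converts the argument arrays):

    // Reduced trampoline shape for mlx_closure_new_func_payload.
    func trampoline(
        resultOut: UnsafeMutablePointer<mlx_vector_array>?,
        args: mlx_vector_array,
        payload: UnsafeMutableRawPointer?
    ) -> Int32 {
        let outputs: [MLXArray] = []                  // placeholder for the captured closure's results
        resultOut?.pointee = new_mlx_vector_array(outputs)
        return 0                                      // 0 reports success back to mlx-c
    }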
@@ -616,7 +668,9 @@ public func identity( public func linspace( _ start: T, _ stop: T, count: Int = 50, stream: StreamOrDevice = .default ) -> MLXArray where T: BinaryInteger { - MLXArray(mlx_linspace(Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_linspace(&result, Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Generate `num` evenly spaced numbers over interval `[start, stop]`. @@ -640,7 +694,9 @@ public func linspace( public func linspace( _ start: T, _ stop: T, count: Int = 50, stream: StreamOrDevice = .default ) -> MLXArray where T: BinaryFloatingPoint { - MLXArray(mlx_linspace(Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_linspace(&result, Double(start), Double(stop), count.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Repeat an array along a specified axis. @@ -655,7 +711,9 @@ public func linspace( public func `repeat`(_ array: MLXArray, count: Int, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_repeat(array.ctx, count.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat(&result, array.ctx, count.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Repeat a flattened array along axis 0. @@ -668,7 +726,9 @@ public func `repeat`(_ array: MLXArray, count: Int, axis: Int, stream: StreamOrD /// - ``full(_:values:stream:)`` @available(*, deprecated, renamed: "repeated(_:count:stream:)") public func `repeat`(_ array: MLXArray, count: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_repeat_all(array.ctx, count.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat_all(&result, array.ctx, count.int32, stream.ctx) + return MLXArray(result) } /// Repeat an array along a specified axis. @@ -693,7 +753,9 @@ public func `repeat`(_ array: MLXArray, count: Int, stream: StreamOrDevice = .de public func repeated(_ array: MLXArray, count: Int, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_repeat(array.ctx, count.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat(&result, array.ctx, count.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Repeat a flattened array along axis 0. @@ -715,7 +777,9 @@ public func repeated(_ array: MLXArray, count: Int, axis: Int, stream: StreamOrD /// - ``repeated(_:count:axis:stream:)`` /// - ``full(_:values:stream:)`` public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_repeat_all(array.ctx, count.int32, stream.ctx)) + var result = mlx_array_new() + mlx_repeat_all(&result, array.ctx, count.int32, stream.ctx) + return MLXArray(result) } /// An array with ones at and below the given diagonal and zeros elsewhere. @@ -740,5 +804,7 @@ public func tri( _ n: Int, m: Int? = nil, k: Int = 0, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_tri(n.int32, (m ?? n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_tri(&result, n.int32, (m ?? 
n).int32, k.int32, T.dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } diff --git a/Source/MLX/GPU.swift b/Source/MLX/GPU.swift index 50d6147d..17c08246 100644 --- a/Source/MLX/GPU.swift +++ b/Source/MLX/GPU.swift @@ -133,7 +133,9 @@ public enum GPU { /// Note, this will not always match memory use reported by the system because /// it does not include cached memory buffers. public static var activeMemory: Int { - mlx_metal_get_active_memory() + var result: size_t = 0 + mlx_metal_get_active_memory(&result) + return result } /// Get the cache size in bytes. @@ -141,7 +143,9 @@ public enum GPU { /// The cache includes memory not currently used that has not been returned /// to the system allocator. public static var cacheMemory: Int { - mlx_metal_get_cache_memory() + var result: size_t = 0 + mlx_metal_get_cache_memory(&result) + return result } /// Get the peak amount of active memory in bytes. @@ -149,7 +153,9 @@ public enum GPU { /// The maximum memory used is recorded from the beginning of the program /// execution. public static var peakMemory: Int { - mlx_metal_get_peak_memory() + var result: size_t = 0 + mlx_metal_get_peak_memory(&result) + return result } /// Return a snapshot of memory stats -- see ``Snapshot`` for more details. @@ -180,8 +186,9 @@ public enum GPU { // set it to a reasonable value in order to read it, then set it back // to current - let current = mlx_metal_set_cache_limit(cacheMemory) - mlx_metal_set_cache_limit(current) + var current: size_t = 0 + mlx_metal_set_cache_limit(¤t, cacheMemory) + mlx_metal_set_cache_limit(¤t, current) _cacheLimit = current return current } @@ -199,7 +206,8 @@ public enum GPU { public static func set(cacheLimit: Int) { queue.sync { _cacheLimit = cacheLimit - mlx_metal_set_cache_limit(cacheLimit) + var current: size_t = 0 + mlx_metal_set_cache_limit(¤t, cacheLimit) } } @@ -217,8 +225,10 @@ public enum GPU { return memoryLimit } - let current = mlx_metal_set_memory_limit(activeMemory, _relaxedMemoryLimit) - mlx_metal_set_memory_limit(current, _relaxedMemoryLimit) + var current: size_t = 0 + var discard: size_t = 0 + mlx_metal_set_memory_limit(¤t, activeMemory, _relaxedMemoryLimit) + mlx_metal_set_memory_limit(&discard, current, _relaxedMemoryLimit) return current } } @@ -236,7 +246,8 @@ public enum GPU { queue.sync { _relaxedMemoryLimit = relaxed _memoryLimit = memoryLimit - mlx_metal_set_memory_limit(memoryLimit, relaxed) + var current: size_t = 0 + mlx_metal_set_memory_limit(¤t, memoryLimit, relaxed) } } @@ -257,9 +268,7 @@ public enum GPU { /// See [the documentation](https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html) /// for more information. public static func startCapture(url: URL) { - let path = mlx_string_new(url.path().cString(using: .utf8))! - defer { mlx_free(path) } - mlx_metal_start_capture(path) + mlx_metal_start_capture(url.path().cString(using: .utf8)) } /// Stop the metal capture. 
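Note: GPU.swift follows suit. The Metal memory queries now report through a size_t out-parameter, and because mlx-c only exposes a setter that returns the previous limit, the wrapper reads the current cache limit by setting a throwaway value and immediately restoring it. A compact sketch of that read, mirroring the hunks just above (readCacheLimit is an illustrative name):

    func readCacheLimit() -> Int {
        var previous: size_t = 0
        mlx_metal_set_cache_limit(&previous, GPU.cacheMemory)  // old limit comes back in `previous`
        var discard: size_t = 0
        mlx_metal_set_cache_limit(&discard, previous)          // put the original limit back
        return previous
    }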
diff --git a/Source/MLX/MLXArray+Indexing.swift b/Source/MLX/MLXArray+Indexing.swift index f3116fc0..f5016670 100644 --- a/Source/MLX/MLXArray+Indexing.swift +++ b/Source/MLX/MLXArray+Indexing.swift @@ -63,11 +63,8 @@ extension MLXArray { /// Replace the interior ctx (`mlx_array` pointer) with a new value by transferring ownership @inline(__always) - func update(ctx: OpaquePointer) { - if ctx != self.ctx { - mlx_free(self.ctx) - self.ctx = ctx - } + func update(_ ctx: mlx_array) { + mlx_array_set(&self.ctx, ctx) } /// allow addressing as a positive index or negative (from end) using given axis @@ -213,10 +210,12 @@ extension MLXArray { stops[axis] = upper strides[axis] = 1 - return MLXArray( - mlx_slice( - ctx, starts, starts.count, stops, stops.count, strides, strides.count, - stream.ctx)) + var result = mlx_array_new() + mlx_slice( + &result, + ctx, starts, starts.count, stops, stops.count, strides, strides.count, + stream.ctx) + return MLXArray(result) } set { // this is [0 ..., 0 ..., range] where the number of full range leading expressions @@ -286,10 +285,12 @@ extension MLXArray { } strides[axis] = stride.int32 - return MLXArray( - mlx_slice( - ctx, starts, starts.count, stops, stops.count, strides, strides.count, - stream.ctx)) + var result = mlx_array_new() + mlx_slice( + &result, + ctx, starts, starts.count, stops, stops.count, strides, strides.count, + stream.ctx) + return MLXArray(result) } set { // see mlx_set_item_nd @@ -387,11 +388,13 @@ extension MLXArray { src: self, operations: operations, update: newValue, stream: stream) if !indices.isEmpty { let indices_vector = new_mlx_vector_array(indices) - defer { mlx_free(indices_vector) } + defer { mlx_vector_array_free(indices_vector) } - let result = MLXArray( - mlx_scatter(self.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_scatter( + &result, self.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx) self.update(result) + mlx_array_free(result) return } else { self.update(update) @@ -552,10 +555,12 @@ func getItem(src: MLXArray, operation: MLXArrayIndexOperation, stream: StreamOrD ends[0] = slice.end(size) strides[0] = slice.stride - return MLXArray( - mlx_slice( - src.ctx, starts, starts.count, ends, ends.count, strides, strides.count, - stream.ctx)) + var result = mlx_array_new() + mlx_slice( + &result, + src.ctx, starts, starts.count, ends, ends.count, strides, strides.count, + stream.ctx) + return MLXArray(result) case .array(let indices): return src.take(indices, axis: 0, stream: stream) @@ -699,9 +704,11 @@ func getItemND( axis += 1 } - src = MLXArray( - mlx_slice( - src.ctx, starts, starts.count, ends, ends.count, strides, strides.count, stream.ctx)) + var result = mlx_array_new() + mlx_slice( + &result, + src.ctx, starts, starts.count, ends, ends.count, strides, strides.count, stream.ctx) + src = MLXArray(result) // Unsqueeze handling if remainingIndices.count > src.ndim || squeezeNeeded { @@ -804,15 +811,16 @@ func gatherND( // Do the gather let indices = new_mlx_vector_array(gatherIndices) - defer { mlx_free(indices) } + defer { mlx_vector_array_free(indices) } let axes = Array(0 ..< operations.count.int32) var sliceSizes = shape32 for i in 0 ..< operations.count { sliceSizes[i] = 1 } - let gathered = MLXArray( - mlx_gather(src.ctx, indices, axes, axes.count, sliceSizes, sliceSizes.count, stream.ctx)) + var tmp = mlx_array_new() + mlx_gather(&tmp, src.ctx, indices, axes, axes.count, sliceSizes, sliceSizes.count, stream.ctx) + let gathered = MLXArray(tmp) 
let gatheredShape = gathered.shape // Squeeze the dims @@ -856,11 +864,12 @@ func updateSlice( ends[0] = slice.end(size) strides[0] = slice.stride - let result = MLXArray( - mlx_slice_update( - src.ctx, update.ctx, starts, starts.count, ends, ends.count, strides, strides.count, - stream.ctx)) - return result + var result = mlx_array_new() + mlx_slice_update( + &result, + src.ctx, update.ctx, starts, starts.count, ends, ends.count, strides, strides.count, + stream.ctx) + return MLXArray(result) } // Expand ellipses into a series of ':' (full slice) slices @@ -923,11 +932,12 @@ func updateSlice( update = reshaped(update, updateReshape) - let result = MLXArray( - mlx_slice_update( - src.ctx, update.ctx, starts, starts.count, ends, ends.count, strides, strides.count, - stream.ctx)) - return result + var result = mlx_array_new() + mlx_slice_update( + &result, + src.ctx, update.ctx, starts, starts.count, ends, ends.count, strides, strides.count, + stream.ctx) + return MLXArray(result) } // MARK: - index set (scatter) diff --git a/Source/MLX/MLXArray+Init.swift b/Source/MLX/MLXArray+Init.swift index a3d7a84a..260598df 100644 --- a/Source/MLX/MLXArray+Init.swift +++ b/Source/MLX/MLXArray+Init.swift @@ -22,7 +22,7 @@ extension MLXArray { /// ### See Also /// - public convenience init(_ value: Int32) { - self.init(mlx_array_from_int(value)) + self.init(mlx_array_new_int(value)) } /// Initalizer allowing creation of scalar (0-dimension) `MLXArray` from an `Int` as @@ -43,7 +43,7 @@ extension MLXArray { (Int(Int32.min) ... Int(Int32.max)).contains(value), "\(value) is out of range for Int32 -- please use MLXArray(int64: Int) if you need 64 bits." ) - self.init(mlx_array_from_int(Int32(value))) + self.init(mlx_array_new_int(Int32(value))) } /// Initalizer allowing creation of scalar (0-dimension) `MLXArray` from an `Int` as @@ -60,7 +60,7 @@ extension MLXArray { public convenience init(int64 value: Int) { self.init( withUnsafePointer(to: value) { ptr in - mlx_array_from_data(ptr, [], 0, Int.dtype.cmlxDtype) + mlx_array_new_data(ptr, [], 0, Int.dtype.cmlxDtype) }) } @@ -73,7 +73,7 @@ extension MLXArray { /// ### See Also /// - public convenience init(_ value: Bool) { - self.init(mlx_array_from_bool(value)) + self.init(mlx_array_new_bool(value)) } /// Initalizer allowing creation of scalar (0-dimension) `MLXArray` from a `Float`. @@ -85,7 +85,7 @@ extension MLXArray { /// ### See Also /// - public convenience init(_ value: Float) { - self.init(mlx_array_from_float(value)) + self.init(mlx_array_new_float(value)) } /// Initalizer allowing creation of scalar (0-dimension) `MLXArray` from a `HasDType` value. @@ -99,7 +99,7 @@ extension MLXArray { public convenience init(_ value: T) { self.init( withUnsafePointer(to: value) { ptr in - mlx_array_from_data(ptr, [], 0, T.dtype.cmlxDtype) + mlx_array_new_data(ptr, [], 0, T.dtype.cmlxDtype) }) } @@ -114,9 +114,10 @@ extension MLXArray { /// - public convenience init(bfloat16 value: Float32) { let stream = StreamOrDevice.default - let v_mlx = mlx_array_from_float(Float32(value))! - defer { mlx_free(v_mlx) } - let v_bfloat = mlx_astype(v_mlx, DType.bfloat16.cmlxDtype, stream.ctx)! 
+ let v_mlx = mlx_array_new_float(Float32(value)) + defer { mlx_array_free(v_mlx) } + var v_bfloat = mlx_array_new() + mlx_astype(&v_bfloat, v_mlx, DType.bfloat16.cmlxDtype, stream.ctx) self.init(v_bfloat) } @@ -143,7 +144,7 @@ extension MLXArray { default: self.init( withUnsafePointer(to: value) { ptr in - mlx_array_from_data(ptr, [], 0, T.dtype.cmlxDtype) + mlx_array_new_data(ptr, [], 0, T.dtype.cmlxDtype) }) } } else { @@ -241,7 +242,7 @@ extension MLXArray { self.init( value.withUnsafeBufferPointer { ptr in let shape = shape ?? [value.count] - return mlx_array_from_data( + return mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, T.dtype.cmlxDtype) }) } @@ -271,7 +272,7 @@ extension MLXArray { .map { Int32($0) } .withUnsafeBufferPointer { ptr in let shape = shape ?? [value.count] - return mlx_array_from_data( + return mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, Int32.dtype.cmlxDtype) }) } @@ -294,7 +295,7 @@ extension MLXArray { value .withUnsafeBufferPointer { ptr in let shape = shape ?? [value.count] - return mlx_array_from_data( + return mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, Int.dtype.cmlxDtype) }) } @@ -316,7 +317,7 @@ extension MLXArray { self.init( floats.withUnsafeBufferPointer { ptr in let shape = shape ?? [floats.count] - return mlx_array_from_data( + return mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, Float.dtype.cmlxDtype) }) } @@ -372,7 +373,7 @@ extension MLXArray { self.init( value.withUnsafeBufferPointer { ptr in let shape = shape ?? [value.count] - return mlx_array_from_data( + return mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, Int.dtype.cmlxDtype) }) } @@ -393,7 +394,7 @@ extension MLXArray { shapePrecondition(shape: shape, count: ptr.count) let shape = shape ?? [ptr.count] self.init( - mlx_array_from_data( + mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, T.dtype.cmlxDtype)) } @@ -430,7 +431,7 @@ extension MLXArray { let buffer = ptr.assumingMemoryBound(to: type) shapePrecondition(shape: shape, count: buffer.count) let shape = shape ?? [buffer.count] - return mlx_array_from_data( + return mlx_array_new_data( ptr.baseAddress!, shape.asInt32, shape.count.int32, T.dtype.cmlxDtype) }) } @@ -440,7 +441,7 @@ extension MLXArray { /// - real: real part /// - imaginary: imaginary part public convenience init(real: Float, imaginary: Float) { - self.init(mlx_array_from_data([real, imaginary], [], 0, DType.complex64.cmlxDtype)) + self.init(mlx_array_new_data([real, imaginary], [], 0, DType.complex64.cmlxDtype)) } /// Create a ``DType/complex64`` scalar from `Complex`. @@ -478,8 +479,8 @@ extension MLXArray: ExpressibleByArrayLiteral { public convenience init(arrayLiteral elements: Int32...) { let ctx = elements.withUnsafeBufferPointer { ptr in let shape = [Int32(elements.count)] - return mlx_array_from_data( - ptr.baseAddress!, shape, Int32(shape.count), Int32.dtype.cmlxDtype)! 
+ return mlx_array_new_data( + ptr.baseAddress!, shape, Int32(shape.count), Int32.dtype.cmlxDtype) } self.init(ctx) } diff --git a/Source/MLX/MLXArray+Ops.swift b/Source/MLX/MLXArray+Ops.swift index ac1f9fca..72ba8da2 100644 --- a/Source/MLX/MLXArray+Ops.swift +++ b/Source/MLX/MLXArray+Ops.swift @@ -41,7 +41,9 @@ extension MLXArray { /// - ``add(_:_:stream:)`` public static func + (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_add(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_add(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise addition. @@ -109,7 +111,9 @@ extension MLXArray { /// - ``subtract(_:_:stream:)`` public static func - (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_subtract(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_subtract(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise subtraction. @@ -171,7 +175,9 @@ extension MLXArray { /// - ``negative(_:stream:)`` public static prefix func - (lhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_negative(lhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_negative(&result, lhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise multiplication. @@ -194,7 +200,9 @@ extension MLXArray { /// - ``matmul(_:_:stream:)`` public static func * (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_multiply(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_multiply(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise multiplication. @@ -267,7 +275,9 @@ extension MLXArray { /// - ``pow(_:_:stream:)-49xi0`` public static func ** (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_power(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_power(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise power with a ``ScalarOrArray`` (scalar) argument. @@ -307,7 +317,9 @@ extension MLXArray { /// - ``floorDivide(_:_:stream:)`` public static func / (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_divide(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_divide(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise division. @@ -372,7 +384,9 @@ extension MLXArray { /// - ``remainder(_:_:stream:)`` public static func % (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_remainder(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_remainder(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise remainder with a ``ScalarOrArray`` (scalar) argument. @@ -410,7 +424,9 @@ extension MLXArray { /// - ``logicalNot(_:stream:)`` public static prefix func .! (lhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_logical_not(lhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_logical_not(&result, lhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise equality. 
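The initializer hunks above swap the `mlx_array_from_*` constructors for the `mlx_array_new_*` family without changing the Swift-facing signatures. A short sketch of the scalar and literal initializers touched here (values are arbitrary):

```swift
import MLX

let i = MLXArray(3)                      // Int32-backed scalar
let big = MLXArray(int64: 1 << 40)       // explicit 64-bit scalar
let f = MLXArray(Float(1.5))             // float32 scalar
let flag = MLXArray(true)                // bool scalar
let c = MLXArray(real: 1, imaginary: 2)  // complex64 scalar
let v: MLXArray = [1, 2, 3]              // Int32 vector via array literal
```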
@@ -435,7 +451,9 @@ extension MLXArray { /// - ``allClose(_:_:rtol:atol:equalNaN:stream:)`` public static func .== (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_equal(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_equal(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise equality with a ``ScalarOrArray`` (scalar) argument. @@ -467,7 +485,9 @@ extension MLXArray { /// - ``lessEqual(_:_:stream:)`` public static func .<= (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_less_equal(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_less_equal(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise less than or equal with a ``ScalarOrArray`` (scalar) argument. @@ -499,7 +519,9 @@ extension MLXArray { /// - ``greaterEqual(_:_:stream:)`` public static func .>= (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_greater_equal(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_greater_equal(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise greater than or equal with a ``ScalarOrArray`` (scalar) argument. @@ -531,7 +553,9 @@ extension MLXArray { /// - ``notEqual(_:_:stream:)`` public static func .!= (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_not_equal(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_not_equal(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise not equal with a ``ScalarOrArray`` (scalar) argument. @@ -563,7 +587,9 @@ extension MLXArray { /// - ``less(_:_:stream:)`` public static func .< (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_less(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_less(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise less than with a ``ScalarOrArray`` (scalar) argument. @@ -595,7 +621,9 @@ extension MLXArray { /// - ``greater(_:_:stream:)`` public static func .> (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_greater(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_greater(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise greater than with a ``ScalarOrArray`` (scalar) argument. @@ -625,7 +653,9 @@ extension MLXArray { /// - ``logicalAnd(_:_:stream:)`` public static func .&& (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_logical_and(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_logical_and(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise logical or. @@ -646,7 +676,9 @@ extension MLXArray { /// - ``logicalOr(_:_:stream:)`` public static func .|| (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_logical_or(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_logical_or(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise bitwise and. 
@@ -659,7 +691,9 @@ extension MLXArray { /// - ``bitwiseAnd(_:_:stream:)`` public static func & (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_bitwise_and(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_bitwise_and(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise bitwise and. @@ -696,7 +730,9 @@ extension MLXArray { /// - ``bitwiseOr(_:_:stream:)`` public static func | (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_bitwise_or(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_bitwise_or(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise bitwise or. @@ -733,7 +769,9 @@ extension MLXArray { /// - ``bitwiseXOr(_:_:stream:)`` public static func ^ (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_bitwise_xor(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_bitwise_xor(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise bitwise xor. @@ -771,7 +809,9 @@ extension MLXArray { /// - ``leftShift(_:_:stream:)`` public static func << (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_left_shift(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_left_shift(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise left shift. @@ -811,7 +851,9 @@ extension MLXArray { /// - ``rightShift(_:_:stream:)`` public static func >> (lhs: MLXArray, rhs: MLXArray) -> MLXArray { let s = StreamOrDevice.default - return MLXArray(mlx_right_shift(lhs.ctx, rhs.ctx, s.ctx)) + var result = mlx_array_new() + mlx_right_shift(&result, lhs.ctx, rhs.ctx, s.ctx) + return MLXArray(result) } /// Element-wise right shift. @@ -1086,21 +1128,27 @@ extension MLXArray { extension MLXArray { func broadcast(to shape: [Int32], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_broadcast_to(ctx, shape, shape.count, stream.ctx)) + var result = mlx_array_new() + mlx_broadcast_to(&result, ctx, shape, shape.count, stream.ctx) + return MLXArray(result) } func scattered( indices: [MLXArray], updates: MLXArray, axes: [Int32], stream: StreamOrDevice = .default ) -> MLXArray { let vector_array = new_mlx_vector_array(indices) - defer { mlx_free(vector_array) } + defer { mlx_vector_array_free(vector_array) } - return MLXArray(mlx_scatter(ctx, vector_array, updates.ctx, axes, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_scatter(&result, ctx, vector_array, updates.ctx, axes, axes.count, stream.ctx) + return MLXArray(result) } // varaiant with [Int32] argument func reshaped(_ newShape: [Int32], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reshape(ctx, newShape, newShape.count, stream.ctx)) + var result = mlx_array_new() + mlx_reshape(&result, ctx, newShape, newShape.count, stream.ctx) + return MLXArray(result) } } @@ -1117,7 +1165,9 @@ extension MLXArray { /// ### See Also /// - public func abs(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_abs(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_abs(&result, ctx, stream.ctx) + return MLXArray(result) } /// An `and` reduction over the given axes. 
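Each operator overload above is rebuilt with the same three steps: allocate with `mlx_array_new()`, let the C op fill the out-parameter, and wrap the handle in `MLXArray`. Since the public operators are unchanged, call sites like this hedged sketch keep working (values are arbitrary):

```swift
import MLX

let a = MLXArray([1, 2, 3])
let b = MLXArray([4, 5, 6])

let sum = a + b                        // mlx_add(&result, ...)
let product = a * b                    // mlx_multiply(&result, ...)
let mask = sum .> 5                    // mlx_greater(&result, ...)
let inRange = mask .&& (product .< 20) // mlx_logical_and / mlx_less
```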
@@ -1148,7 +1198,9 @@ extension MLXArray { public func all(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_all_axes(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_all_axes(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// An `and` reduction over the given axes. @@ -1166,7 +1218,9 @@ extension MLXArray { public func all(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_all_axis(ctx, axis.int32, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_all_axis(&result, ctx, axis.int32, keepDims, stream.ctx) + return MLXArray(result) } /// An `and` reduction over the given axes. @@ -1181,7 +1235,9 @@ extension MLXArray { /// - ``all(axis:keepDims:stream:)`` /// - ``all(_:axes:keepDims:stream:)`` public func all(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_all_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_all_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Approximate comparison of two arrays. @@ -1221,7 +1277,9 @@ extension MLXArray { stream: StreamOrDevice = .default ) -> MLXArray { let other = other.asMLXArray(dtype: self.dtype) - return MLXArray(mlx_allclose(ctx, other.ctx, rtol, atol, equalNaN, stream.ctx)) + var result = mlx_array_new() + mlx_allclose(&result, ctx, other.ctx, rtol, atol, equalNaN, stream.ctx) + return MLXArray(result) } /// An `or` reduction over the given axes. @@ -1252,7 +1310,9 @@ extension MLXArray { public func any(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_any(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_any(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// An `or` reduction over the given axes. @@ -1270,7 +1330,9 @@ extension MLXArray { public func any(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_any(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_any(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// An `or` reduction over the given axes. @@ -1285,7 +1347,9 @@ extension MLXArray { /// - ``any(axis:keepDims:stream:)`` /// - ``any(_:axes:keepDims:stream:)`` public func any(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_any_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_any_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the maximum values along the axis. @@ -1310,7 +1374,9 @@ extension MLXArray { public func argMax(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argmax(ctx, axis.int32, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmax(&result, ctx, axis.int32, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the maximum value over the entire array. 
@@ -1332,7 +1398,9 @@ extension MLXArray { /// - ``argMin(axis:keepDims:stream:)`` /// - ``argMax(_:axis:keepDims:stream:)`` public func argMax(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argmax_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmax_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the minimum values along the axis. @@ -1357,7 +1425,9 @@ extension MLXArray { public func argMin(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argmin(ctx, axis.int32, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmin(&result, ctx, axis.int32, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the minimum value over the entire array. @@ -1379,7 +1449,9 @@ extension MLXArray { /// - ``argMax(axis:keepDims:stream:)`` /// - ``argMin(_:axis:keepDims:stream:)`` public func argMin(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argmin_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmin_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Array equality check. @@ -1405,7 +1477,9 @@ extension MLXArray { _ other: T, equalNAN: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { let other = other.asMLXArray(dtype: self.dtype) - return MLXArray(mlx_array_equal(ctx, other.ctx, equalNAN, stream.ctx)) + var result = mlx_array_new() + mlx_array_equal(&result, ctx, other.ctx, equalNAN, stream.ctx) + return MLXArray(result) } /// Element-wise complex conjugate of the input. @@ -1414,7 +1488,9 @@ extension MLXArray { /// - /// - ``conjugate(_:stream:)`` public func conjugate(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_conjugate(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_conjugate(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise cosine. @@ -1423,7 +1499,9 @@ extension MLXArray { /// - /// - ``cos(_:stream:)`` public func cos(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_cos(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_cos(&result, ctx, stream.ctx) + return MLXArray(result) } /// Return the cumulative maximum of the elements along the given axis. @@ -1442,7 +1520,9 @@ extension MLXArray { public func cummax( axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cummax(ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cummax(&result, ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative maximum of the elements over the flattened array. @@ -1461,9 +1541,10 @@ extension MLXArray { public func cummax( reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cummax(flat, 0, reverse, inclusive, stream.ctx)) + let flat = self.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cummax(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative minimum of the elements along the given axis. 
@@ -1482,7 +1563,9 @@ extension MLXArray { public func cummin( axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cummin(ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cummin(&result, ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative minimum of the elements over the flattened array. @@ -1501,9 +1584,10 @@ extension MLXArray { public func cummin( reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cummin(flat, 0, reverse, inclusive, stream.ctx)) + let flat = self.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cummin(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative product of the elements along the given axis. @@ -1522,7 +1606,9 @@ extension MLXArray { public func cumprod( axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cumprod(ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cumprod(&result, ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative product of the elements over the flattened array. @@ -1541,9 +1627,10 @@ extension MLXArray { public func cumprod( reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cumprod(flat, 0, reverse, inclusive, stream.ctx)) + let flat = self.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cumprod(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative sum of the elements along the given axis. @@ -1562,7 +1649,9 @@ extension MLXArray { public func cumsum( axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cumsum(ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cumsum(&result, ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative sum of the elements over the flattened array. @@ -1581,9 +1670,10 @@ extension MLXArray { public func cumsum( reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cumsum(flat, 0, reverse, inclusive, stream.ctx)) + let flat = self.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cumsum(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Extract a diagonal or construct a diagonal matrix. @@ -1599,7 +1689,9 @@ extension MLXArray { /// ### See Also /// - ``diagonal(offset:axis1:axis2:stream:)`` public func diag(k: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_diag(ctx, k.int32, stream.ctx)) + var result = mlx_array_new() + mlx_diag(&result, ctx, k.int32, stream.ctx) + return MLXArray(result) } /// Return specified diagonals. 
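The cumulative ops above now flatten through the Swift `reshaped([-1], stream:)` wrapper instead of a raw `mlx_reshape` call. A brief usage sketch of the per-axis and flattened variants (values are illustrative):

```swift
import MLX

let x = MLXArray([3, 1, 4, 1, 5, 9]).reshaped([2, 3])

let runningMax = x.cummax(axis: 1)                       // per-row running maximum
let runningSum = x.cumsum()                              // over the flattened array
let reversedProducts = x.cumprod(axis: 0, reverse: true)
```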
@@ -1623,7 +1715,9 @@ extension MLXArray { public func diagonal( offset: Int = 0, axis1: Int = 0, axis2: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_diagonal(ctx, offset.int32, axis1.int32, axis2.int32, stream.ctx)) + var result = mlx_array_new() + mlx_diagonal(&result, ctx, offset.int32, axis1.int32, axis2.int32, stream.ctx) + return MLXArray(result) } /// Element-wise exponential. @@ -1632,7 +1726,9 @@ extension MLXArray { /// - /// - ``exp(_:stream:)`` public func exp(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_exp(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_exp(&result, ctx, stream.ctx) + return MLXArray(result) } /// Add a size one dimension at the given axis. @@ -1648,7 +1744,9 @@ extension MLXArray { public func expandedDimensions(axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_expand_dims(self.ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_expand_dims(&result, self.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } /// Add a size one dimension at the given axis. @@ -1664,7 +1762,9 @@ extension MLXArray { public func expandedDimensions(axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_expand_dims(self.ctx, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_expand_dims(&result, self.ctx, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Flatten an array. @@ -1696,7 +1796,9 @@ extension MLXArray { public func flattened(start: Int = 0, end: Int = -1, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_flatten(ctx, start.int32, end.int32, stream.ctx)) + var result = mlx_array_new() + mlx_flatten(&result, ctx, start.int32, end.int32, stream.ctx) + return MLXArray(result) } /// Element-wise floor. @@ -1707,7 +1809,9 @@ extension MLXArray { /// - ``floorDivide(_:stream:)`` /// - ``floor(_:stream:)`` public func floor(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_floor(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_floor(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise integer division.. @@ -1733,7 +1837,9 @@ extension MLXArray { -> MLXArray { let other = other.asMLXArray(dtype: self.dtype) - return MLXArray(mlx_floor_divide(ctx, other.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_floor_divide(&result, ctx, other.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise natural logarithm. @@ -1742,7 +1848,9 @@ extension MLXArray { /// - /// - ``log(_:stream:)`` public func log(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise base-2 logarithm. @@ -1754,7 +1862,9 @@ extension MLXArray { /// - ``log1p(stream:)`` /// - ``log2(_:stream:)`` public func log2(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log2(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log2(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise base-10 logarithm. @@ -1766,7 +1876,9 @@ extension MLXArray { /// - ``log1p(stream:)`` /// - ``log10(_:stream:)`` public func log10(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log10(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log10(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise natural log of one plus the array. 
@@ -1778,7 +1890,9 @@ extension MLXArray { /// - ``log10(stream:)`` /// - ``log1p(_:stream:)`` public func log1p(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log1p(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log1p(&result, ctx, stream.ctx) + return MLXArray(result) } /// A `log-sum-exp` reduction over the given axes. @@ -1802,7 +1916,9 @@ extension MLXArray { public func logSumExp(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_logsumexp(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_logsumexp(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `log-sum-exp` reduction over the given axis. @@ -1826,7 +1942,9 @@ extension MLXArray { public func logSumExp(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_logsumexp(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_logsumexp(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `log-sum-exp` reduction over the entire array. @@ -1847,7 +1965,9 @@ extension MLXArray { /// - ``logSumExp(axis:keepDims:stream:)`` /// - ``logSumExp(_:axes:keepDims:stream:)`` public func logSumExp(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_logsumexp_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_logsumexp_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Matrix multiplication. @@ -1879,7 +1999,9 @@ extension MLXArray { /// - /// - ``matmul(_:_:stream:)`` public func matmul(_ other: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_matmul(ctx, other.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_matmul(&result, ctx, other.ctx, stream.ctx) + return MLXArray(result) } /// A `max` reduction over the given axes. @@ -1904,7 +2026,9 @@ extension MLXArray { public func max(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_max(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_max(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `max` reduction over the given axis. @@ -1929,7 +2053,9 @@ extension MLXArray { public func max(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_max(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_max(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `max` reduction over the entire array. @@ -1951,7 +2077,9 @@ extension MLXArray { /// - ``max(axis:keepDims:stream:)`` /// - ``max(_:axes:keepDims:stream:)`` public func max(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_max_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_max_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// A `mean` reduction over the given axes. @@ -1976,7 +2104,9 @@ extension MLXArray { public func mean(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_mean(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_mean(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `mean` reduction over the given axis. 
@@ -2001,7 +2131,9 @@ extension MLXArray { public func mean(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_mean(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_mean(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `mean` reduction over the entire array. @@ -2023,7 +2155,9 @@ extension MLXArray { /// - ``mean(axis:keepDims:stream:)`` /// - ``mean(_:axes:keepDims:stream:)`` public func mean(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_mean_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_mean_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// A `min` reduction over the given axes. @@ -2048,7 +2182,9 @@ extension MLXArray { public func min(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_min(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_min(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `min` reduction over the given axis. @@ -2073,7 +2209,9 @@ extension MLXArray { public func min(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_min(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_min(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `min` reduction over the entire array. @@ -2095,7 +2233,9 @@ extension MLXArray { /// - ``min(axis:keepDims:stream:)`` /// - ``min(_:axes:keepDims:stream:)`` public func min(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_min_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_min_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Move an axis to a new position. @@ -2131,7 +2271,9 @@ extension MLXArray { public func movedAxis(source: Int, destination: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_moveaxis(ctx, source.int32, destination.int32, stream.ctx)) + var result = mlx_array_new() + mlx_moveaxis(&result, ctx, source.int32, destination.int32, stream.ctx) + return MLXArray(result) } /// Element-wise power operation. @@ -2155,7 +2297,9 @@ extension MLXArray { /// - ``pow(_:_:stream:)-8ie9c`` public func pow(_ other: T, stream: StreamOrDevice = .default) -> MLXArray { let other = other.asMLXArray(dtype: self.dtype) - return MLXArray(mlx_power(ctx, other.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_power(&result, ctx, other.ctx, stream.ctx) + return MLXArray(result) } /// A `product` reduction over the given axes. @@ -2180,7 +2324,9 @@ extension MLXArray { public func product(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_prod(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_prod(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `product` reduction over the given axis. @@ -2205,7 +2351,9 @@ extension MLXArray { public func product(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_prod(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_prod(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `product` reduction over the entire array. 
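The reductions above keep their three overloads (axes list, single axis, whole array); only the underlying C call changes. Usage sketch with illustrative shapes:

```swift
import MLX

let x = MLXArray([1, 2, 3, 4, 5, 6]).reshaped([2, 3])

let columnMax = x.max(axis: 0)                    // shape [3]
let rowMeans = x.mean(axes: [1], keepDims: true)  // shape [2, 1]
let overallMin = x.min()                          // scalar result
```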
@@ -2227,7 +2375,9 @@ extension MLXArray { /// - ``product(axis:keepDims:stream:)`` /// - ``product(_:axes:keepDims:stream:)`` public func product(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_prod_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_prod_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Element-wise reciprocal. @@ -2236,7 +2386,9 @@ extension MLXArray { /// - /// - ``reciprocal(_:stream:)`` public func reciprocal(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reciprocal(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_reciprocal(&result, ctx, stream.ctx) + return MLXArray(result) } /// Reshape an array while preserving the size. @@ -2252,7 +2404,9 @@ extension MLXArray { /// - ``reshaped(_:stream:)`` /// - ``reshaped(_:_:stream:)-96lgr`` public func reshaped(_ newShape: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reshape(ctx, newShape.asInt32, newShape.count, stream.ctx)) + var result = mlx_array_new() + mlx_reshape(&result, ctx, newShape.asInt32, newShape.count, stream.ctx) + return MLXArray(result) } /// Reshape an array while preserving the size. @@ -2268,7 +2422,9 @@ extension MLXArray { /// - ``reshaped(_:_:stream:)-5x3y0`` /// - ``reshaped(_:stream:)`` public func reshaped(_ newShape: Int..., stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reshape(ctx, newShape.asInt32, newShape.count, stream.ctx)) + var result = mlx_array_new() + mlx_reshape(&result, ctx, newShape.asInt32, newShape.count, stream.ctx) + return MLXArray(result) } /// Round to the given number of decimals. @@ -2287,7 +2443,9 @@ extension MLXArray { /// - ``floor(stream:)`` /// - ``round(_:decimals:stream:)`` public func round(decimals: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_round(ctx, decimals.int32, stream.ctx)) + var result = mlx_array_new() + mlx_round(&result, ctx, decimals.int32, stream.ctx) + return MLXArray(result) } /// Element-wise reciprocal and square root. @@ -2296,7 +2454,9 @@ extension MLXArray { /// - /// - ``sqrt(_:stream:)`` public func rsqrt(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_rsqrt(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_rsqrt(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise sine. @@ -2305,7 +2465,9 @@ extension MLXArray { /// - /// - ``sin(_:stream:)`` public func sin(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sin(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sin(&result, ctx, stream.ctx) + return MLXArray(result) } /// Split an array into equal size pieces along a given axis. @@ -2334,8 +2496,9 @@ extension MLXArray { /// - ``split(indices:axis:stream:)`` /// - ``split(_:parts:axis:stream:)`` public func split(parts: Int, axis: Int = 0, stream: StreamOrDevice = .default) -> [MLXArray] { - let vec = mlx_split_equal_parts(ctx, parts.int32, axis.int32, stream.ctx)! - defer { mlx_free(vec) } + var vec = mlx_vector_array_new() + mlx_split_equal_parts(&vec, ctx, parts.int32, axis.int32, stream.ctx) + defer { mlx_vector_array_free(vec) } return mlx_vector_array_values(vec) } @@ -2358,8 +2521,9 @@ extension MLXArray { /// - ``split(indices:axis:stream:)`` /// - ``split(_:parts:axis:stream:)`` public func split(axis: Int = 0, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) { - let vec = mlx_split_equal_parts(ctx, 2, axis.int32, stream.ctx)! 
- defer { mlx_free(vec) } + var vec = mlx_vector_array_new() + mlx_split_equal_parts(&vec, ctx, 2, axis.int32, stream.ctx) + defer { mlx_vector_array_free(vec) } let pieces = mlx_vector_array_values(vec) return (pieces[0], pieces[1]) } @@ -2377,8 +2541,9 @@ extension MLXArray { public func split(indices: [Int], axis: Int = 0, stream: StreamOrDevice = .default) -> [MLXArray] { - let vec = mlx_split(ctx, indices.asInt32, indices.count, axis.int32, stream.ctx)! - defer { mlx_free(vec) } + var vec = mlx_vector_array_new() + mlx_split(&vec, ctx, indices.asInt32, indices.count, axis.int32, stream.ctx) + defer { mlx_vector_array_free(vec) } return mlx_vector_array_values(vec) } @@ -2388,7 +2553,9 @@ extension MLXArray { /// - /// - ``sqrt(_:stream:)`` public func sqrt(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sqrt(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sqrt(&result, ctx, stream.ctx) + return MLXArray(result) } /// Element-wise square. @@ -2397,7 +2564,9 @@ extension MLXArray { /// - /// - ``square(_:stream:)`` public func square(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_square(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_square(&result, ctx, stream.ctx) + return MLXArray(result) } /// Remove length one axes from an array. @@ -2411,7 +2580,9 @@ extension MLXArray { /// - ``squeezed(stream:)`` /// - ``squeezed(_:axes:stream:)`` public func squeezed(axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_squeeze(ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_squeeze(&result, ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } /// Remove length one axes from an array. @@ -2425,7 +2596,9 @@ extension MLXArray { /// - ``squeezed(stream:)`` /// - ``squeezed(_:axes:stream:)`` public func squeezed(axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_squeeze(ctx, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_squeeze(&result, ctx, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Remove all length one axes from an array. @@ -2436,7 +2609,9 @@ extension MLXArray { /// - ``squeezed(axis:stream:)`` /// - ``squeezed(_:axes:stream:)`` public func squeezed(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_squeeze_all(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_squeeze_all(&result, ctx, stream.ctx) + return MLXArray(result) } /// Sum reduce the array over the given axes. @@ -2454,7 +2629,9 @@ extension MLXArray { public func sum(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sum(ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_sum(&result, ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// Sum reduce the array over the given axis. @@ -2472,7 +2649,9 @@ extension MLXArray { public func sum(axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sum(ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_sum(&result, ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// Sum reduce the array over all axes. 
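The split wrappers above now receive an `mlx_vector_array` through an out-parameter and release it with `mlx_vector_array_free`; their Swift signatures are untouched. A sketch of the three variants (values are illustrative):

```swift
import MLX

let x = MLXArray([0, 1, 2, 3, 4, 5, 6, 7]).reshaped([4, 2])

let halves = x.split(parts: 2, axis: 0)            // two [2, 2] pieces
let (top, bottom) = x.split(axis: 0)               // tuple convenience, 2 parts
let segments = x.split(indices: [1, 3], axis: 0)   // split before rows 1 and 3
let rowSums = x.sum(axis: 1)                       // shape [4]
```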
@@ -2487,7 +2666,9 @@ extension MLXArray { /// - ``sum(axis:keepDims:stream:)`` /// - ``sum(_:axes:keepDims:stream:)`` public func sum(keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sum_all(ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_sum_all(&result, ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Swap two axes of an array. @@ -2522,7 +2703,9 @@ extension MLXArray { public func swappedAxes(_ axis1: Int, _ axis2: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_swapaxes(ctx, axis1.int32, axis2.int32, stream.ctx)) + var result = mlx_array_new() + mlx_swapaxes(&result, ctx, axis1.int32, axis2.int32, stream.ctx) + return MLXArray(result) } /// Take elements along an axis. @@ -2541,7 +2724,9 @@ extension MLXArray { /// - ``take(_:_:axis:stream:)`` public func take(_ indices: MLXArray, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_take(ctx, indices.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_take(&result, ctx, indices.ctx, axis.int32, stream.ctx) + return MLXArray(result) } /// Take elements from flattened 1-D array. @@ -2551,7 +2736,9 @@ extension MLXArray { /// - ``take(_:_:axis:stream:)`` public func take(_ indices: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { let input = self.reshaped([-1], stream: stream) - return MLXArray(mlx_take(input.ctx, indices.ctx, 0, stream.ctx)) + var result = mlx_array_new() + mlx_take(&result, input.ctx, indices.ctx, 0, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. @@ -2565,11 +2752,15 @@ extension MLXArray { /// - ``transposed(stream:)`` /// - ``transposed(_:axes:stream:)`` public func transposed(axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose(ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_transpose(&result, ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } public func transposed(_ axes: Int..., stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose(ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_transpose(&result, ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. @@ -2582,7 +2773,9 @@ extension MLXArray { /// - ``transposed(stream:)`` /// - ``transposed(_:axes:stream:)`` public func transposed(axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose(ctx, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_transpose(&result, ctx, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. @@ -2595,7 +2788,9 @@ extension MLXArray { /// - ``transposed(axis:stream:)`` /// - ``transposed(_:axes:stream:)`` public func transposed(stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose_all(ctx, stream.ctx)) + var result = mlx_array_new() + mlx_transpose_all(&result, ctx, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. 
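`take` and the transpose family above follow the same out-parameter pattern, with `take(_:)` flattening first via `reshaped([-1])`. A brief sketch with arbitrary values:

```swift
import MLX

let x = MLXArray([10, 20, 30, 40, 50, 60]).reshaped([2, 3])

let picked = x.take(MLXArray([0, 2]), axis: 1)   // columns 0 and 2
let flatPick = x.take(MLXArray([5, 0]))          // indexes the flattened array
let t = x.transposed()                           // shape [3, 2]
let swapped = x.swappedAxes(0, 1)                // same as transposed() for 2-D
```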
@@ -2620,7 +2815,9 @@ extension MLXArray { public func variance( axes: [Int], keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_var(ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_var(&result, ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Compute the variance(s) over the given axes @@ -2638,7 +2835,9 @@ extension MLXArray { public func variance( axis: Int, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_var(ctx, [axis.int32], 1, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_var(&result, ctx, [axis.int32], 1, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Compute the variance(s) over the given axes @@ -2656,7 +2855,9 @@ extension MLXArray { public func variance(keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_var_all(ctx, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_var_all(&result, ctx, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// View the array as a different type. @@ -2673,6 +2874,8 @@ extension MLXArray { /// - stream: stream or device to evaluate on /// - Returns: array with the new type public func view(dtype: DType, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_view(ctx, dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_view(&result, ctx, dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } } diff --git a/Source/MLX/MLXArray.swift b/Source/MLX/MLXArray.swift index 9c42e6b9..b2da7a7a 100644 --- a/Source/MLX/MLXArray.swift +++ b/Source/MLX/MLXArray.swift @@ -13,12 +13,25 @@ public final class MLXArray { /// Initialize with the given +1 context (transfer ownership). /// /// This initializer is for `Cmlx` interoperation. - public init(_ ctx: mlx_array) { + public init(_ ctx: consuming mlx_array) { self.ctx = ctx } + /// return the equivalent of a `.none` MLXArray (for the C API). + /// + /// Not called `.none` to avoid abiguity with `Optional`. This can be used + /// to pass an optional ``MLXArray`` as a non-optional (but possibly empty/null) + /// `mlx_array`: + /// + /// ```swift + /// mlx_func((freqs ?? .mlxNone).ctx) + /// ``` + public static var mlxNone: MLXArray { + .init(mlx_array_new()) + } + deinit { - mlx_free(ctx) + mlx_array_free(ctx) } /// Number of bytes per element @@ -68,7 +81,7 @@ public final class MLXArray { /// print(array.dtype) /// // .int64 (aka Int.dtype) /// ``` - public var dtype: DType { DType(mlx_array_get_dtype(ctx)) } + public var dtype: DType { DType(mlx_array_dtype(ctx)) } /// Dimensions of the array. /// @@ -159,31 +172,88 @@ public final class MLXArray { /// specialized conversion between integer types -- see ``item(_:)`` private func itemInt() -> Int { switch self.dtype { - case .bool: mlx_array_item_bool(self.ctx) ? 
1 : 0 - case .uint8: Int(mlx_array_item_uint8(self.ctx)) - case .uint16: Int(mlx_array_item_uint16(self.ctx)) - case .uint32: Int(mlx_array_item_uint32(self.ctx)) - case .uint64: Int(mlx_array_item_uint64(self.ctx)) - case .int8: Int(mlx_array_item_int8(self.ctx)) - case .int16: Int(mlx_array_item_int16(self.ctx)) - case .int32: Int(mlx_array_item_int32(self.ctx)) - case .int64: Int(mlx_array_item_int64(self.ctx)) - default: fatalError("itemInt expected an integer dtype: \(self.dtype)") + case .bool: + var r = false + mlx_array_item_bool(&r, self.ctx) + return r ? 1 : 0 + case .uint8: + var r: UInt8 = 0 + mlx_array_item_uint8(&r, self.ctx) + return Int(r) + case .uint16: + var r: UInt16 = 0 + mlx_array_item_uint16(&r, self.ctx) + return Int(r) + case .uint32: + var r: UInt32 = 0 + mlx_array_item_uint32(&r, self.ctx) + return Int(r) + case .uint64: + var r: UInt64 = 0 + mlx_array_item_uint64(&r, self.ctx) + return Int(r) + case .int8: + var r: Int8 = 0 + mlx_array_item_int8(&r, self.ctx) + return Int(r) + case .int16: + var r: Int16 = 0 + mlx_array_item_int16(&r, self.ctx) + return Int(r) + case .int32: + var r: Int32 = 0 + mlx_array_item_int32(&r, self.ctx) + return Int(r) + case .int64: + var r: Int64 = 0 + mlx_array_item_int64(&r, self.ctx) + return Int(r) + + default: + fatalError("itemInt expected an integer dtype: \(self.dtype)") } } /// specialized conversion between integer types -- see ``item(_:)`` private func itemUInt() -> UInt { switch self.dtype { - case .bool: mlx_array_item_bool(self.ctx) ? 1 : 0 - case .uint8: UInt(mlx_array_item_uint8(self.ctx)) - case .uint16: UInt(mlx_array_item_uint16(self.ctx)) - case .uint32: UInt(mlx_array_item_uint32(self.ctx)) - case .uint64: UInt(mlx_array_item_uint64(self.ctx)) - case .int8: UInt(mlx_array_item_int8(self.ctx)) - case .int16: UInt(mlx_array_item_int16(self.ctx)) - case .int32: UInt(mlx_array_item_int32(self.ctx)) - case .int64: UInt(mlx_array_item_int64(self.ctx)) + case .bool: + var r = false + mlx_array_item_bool(&r, self.ctx) + return r ? 1 : 0 + case .uint8: + var r: UInt8 = 0 + mlx_array_item_uint8(&r, self.ctx) + return UInt(r) + case .uint16: + var r: UInt16 = 0 + mlx_array_item_uint16(&r, self.ctx) + return UInt(r) + case .uint32: + var r: UInt32 = 0 + mlx_array_item_uint32(&r, self.ctx) + return UInt(r) + case .uint64: + var r: UInt64 = 0 + mlx_array_item_uint64(&r, self.ctx) + return UInt(r) + case .int8: + var r: Int8 = 0 + mlx_array_item_int8(&r, self.ctx) + return UInt(r) + case .int16: + var r: Int16 = 0 + mlx_array_item_int16(&r, self.ctx) + return UInt(r) + case .int32: + var r: Int32 = 0 + mlx_array_item_int32(&r, self.ctx) + return UInt(r) + case .int64: + var r: Int64 = 0 + mlx_array_item_int64(&r, self.ctx) + return UInt(r) + default: fatalError("itemUInt expected an integer dtype: \(self.dtype)") } } @@ -192,9 +262,16 @@ public final class MLXArray { private func itemFloat() -> Float { switch self.dtype { #if !arch(x86_64) - case .float16: Float(mlx_array_item_float16(self.ctx)) + case .float16: + var r: Float16 = 0 + mlx_array_item_float16(&r, self.ctx) + return Float(r) #endif - case .float32: Float(mlx_array_item_float32(self.ctx)) + case .float32: + var r: Float32 = 0 + mlx_array_item_float32(&r, self.ctx) + return Float(r) + default: fatalError("itemFloat expected a floating point dtype: \(self.dtype)") } } @@ -270,21 +347,60 @@ public final class MLXArray { } switch type { - case is Bool.Type: return mlx_array_item_bool(self.ctx) as! T - case is UInt8.Type: return mlx_array_item_uint8(self.ctx) as! 
T - case is UInt16.Type: return mlx_array_item_uint16(self.ctx) as! T - case is UInt32.Type: return mlx_array_item_uint32(self.ctx) as! T - case is UInt64.Type: return mlx_array_item_uint64(self.ctx) as! T - case is Int8.Type: return mlx_array_item_int8(self.ctx) as! T - case is Int16.Type: return mlx_array_item_int16(self.ctx) as! T - case is Int32.Type: return mlx_array_item_int32(self.ctx) as! T - case is Int64.Type: return mlx_array_item_int64(self.ctx) as! T - case is Int.Type: return Int(mlx_array_item_int64(self.ctx)) as! T + case is Bool.Type: + var r: Bool = false + mlx_array_item_bool(&r, self.ctx) + return r as! T + case is UInt8.Type: + var r: UInt8 = 0 + mlx_array_item_uint8(&r, self.ctx) + return r as! T + case is UInt16.Type: + var r: UInt16 = 0 + mlx_array_item_uint16(&r, self.ctx) + return r as! T + case is UInt32.Type: + var r: UInt32 = 0 + mlx_array_item_uint32(&r, self.ctx) + return r as! T + case is UInt64.Type: + var r: UInt64 = 0 + mlx_array_item_uint64(&r, self.ctx) + return r as! T + case is Int8.Type: + var r: Int8 = 0 + mlx_array_item_int8(&r, self.ctx) + return r as! T + case is Int16.Type: + var r: Int16 = 0 + mlx_array_item_int16(&r, self.ctx) + return r as! T + case is Int32.Type: + var r: Int32 = 0 + mlx_array_item_int32(&r, self.ctx) + return r as! T + case is Int64.Type: + var r: Int64 = 0 + mlx_array_item_int64(&r, self.ctx) + return r as! T + case is Int.Type: + var r: Int64 = 0 + mlx_array_item_int64(&r, self.ctx) + return Int(r) as! T #if !arch(x86_64) - case is Float16.Type: return mlx_array_item_float16(self.ctx) as! T + case is Float16.Type: + var r: Float16 = 0 + mlx_array_item_float16(&r, self.ctx) + return r as! T #endif - case is Float32.Type: return mlx_array_item_float32(self.ctx) as! T - case is Float.Type: return mlx_array_item_float32(self.ctx) as! T + case is Float32.Type: + var r: Float32 = 0 + mlx_array_item_float32(&r, self.ctx) + return r as! T + case is Float.Type: + var r: Float = 0 + mlx_array_item_float32(&r, self.ctx) + return r as! T case is Complex.Type: // mlx_array_item_complex64() isn't visible in swift so read the array // contents. call self.eval() as this doesn't end up in item() @@ -330,7 +446,9 @@ public final class MLXArray { /// - public func asType(_ type: DType, stream: StreamOrDevice = .default) -> MLXArray { guard type != self.dtype else { return self } - return MLXArray(mlx_astype(ctx, type.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_astype(&result, ctx, type.cmlxDtype, stream.ctx) + return MLXArray(result) } /// Create a new `MLXArray` with the contents converted to the given type, e.g. `Float.self`. @@ -391,17 +509,14 @@ public final class MLXArray { /// Replace the contents with a reference to a new array. public func update(_ array: MLXArray) { - if array.ctx != self.ctx { - mlx_retain(array.ctx) - mlx_free(ctx) - self.ctx = array.ctx - } + mlx_array_set(&self.ctx, array.ctx) } /// Internal function for copying the backing `mlx::core::array` context. func copyContext() -> MLXArray { - mlx_retain(ctx) - return MLXArray(ctx) + var new = mlx_array_new() + mlx_array_set(&new, self.ctx) + return MLXArray(new) } } @@ -413,6 +528,9 @@ extension MLXArray: Updatable, Evaluatable { extension MLXArray: CustomStringConvertible { public var description: String { - mlx_describe(ctx) ?? String(describing: type(of: self)) + var s = mlx_string_new() + mlx_array_tostring(&s, ctx) + defer { mlx_string_free(s) } + return String(cString: mlx_string_data(s), encoding: .utf8)! 
} } diff --git a/Source/MLX/Ops+Array.swift b/Source/MLX/Ops+Array.swift index c4f8cfad..e3ca95f2 100644 --- a/Source/MLX/Ops+Array.swift +++ b/Source/MLX/Ops+Array.swift @@ -17,7 +17,9 @@ import Foundation /// - /// - ``MLXArray/abs(stream:)`` public func abs(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_abs(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_abs(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// An `and` reduction over the given axes. @@ -49,7 +51,9 @@ public func abs(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra public func all( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_all_axes(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_all_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// An `and` reduction over the given axes. @@ -68,7 +72,9 @@ public func all( public func all( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_all_axis(array.ctx, axis.int32, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_all_axis(&result, array.ctx, axis.int32, keepDims, stream.ctx) + return MLXArray(result) } /// An `and` reduction over the given axes. @@ -86,7 +92,9 @@ public func all( public func all(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_all_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_all_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Approximate comparison of two arrays. @@ -131,7 +139,9 @@ public func allClose( stream: StreamOrDevice = .default ) -> MLXArray { let other = other.asMLXArray(dtype: array.dtype) - return MLXArray(mlx_allclose(array.ctx, other.ctx, rtol, atol, equalNaN, stream.ctx)) + var result = mlx_array_new() + mlx_allclose(&result, array.ctx, other.ctx, rtol, atol, equalNaN, stream.ctx) + return MLXArray(result) } /// An `or` reduction over the given axes. @@ -163,7 +173,9 @@ public func allClose( public func any( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_any(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_any(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// An `or` reduction over the given axes. @@ -182,7 +194,9 @@ public func any( public func any( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_any(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_any(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// An `or` reduction over the given axes. @@ -200,7 +214,9 @@ public func any( public func any(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_any_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_any_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the maximum values along the axis. 
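The scalar accessors above switch from return values to out-parameters: the new mlx-c calls return a status and write the value through a pointer. A minimal caller-level sketch, assuming the existing `MLXArray` scalar initializer and generic `item()`:

let a = MLXArray(Float(3.5))
let v: Float = a.item()   // internally: var r: Float32 = 0; mlx_array_item_float32(&r, a.ctx)
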
@@ -226,7 +242,9 @@ public func any(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevic public func argMax( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_argmax(array.ctx, axis.int32, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmax(&result, array.ctx, axis.int32, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the maximum value over the entire array. @@ -251,7 +269,9 @@ public func argMax( public func argMax(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argmax_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmax_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the minimum values along the axis. @@ -277,7 +297,9 @@ public func argMax(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDe public func argMin( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_argmin(array.ctx, axis.int32, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmin(&result, array.ctx, axis.int32, keepDims, stream.ctx) + return MLXArray(result) } /// Indices of the minimum value over the entire array. @@ -302,7 +324,9 @@ public func argMin( public func argMin(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argmin_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_argmin_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Array equality check. @@ -329,7 +353,9 @@ public func arrayEqual( _ array: MLXArray, _ other: T, equalNAN: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { let other = other.asMLXArray(dtype: array.dtype) - return MLXArray(mlx_array_equal(array.ctx, other.ctx, equalNAN, stream.ctx)) + var result = mlx_array_new() + mlx_array_equal(&result, array.ctx, other.ctx, equalNAN, stream.ctx) + return MLXArray(result) } /// Element-wise bitwise and. @@ -343,7 +369,9 @@ public func bitwiseAnd( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_bitwise_and(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_bitwise_and(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise bitwise or. @@ -357,7 +385,9 @@ public func bitwiseOr( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_bitwise_or(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_bitwise_or(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise bitwise xor. @@ -371,7 +401,9 @@ public func bitwiseXOr( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_bitwise_xor(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_bitwise_xor(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise cosine. @@ -380,7 +412,9 @@ public func bitwiseXOr( /// - /// - ``MLXArray/cos(stream:)`` public func cos(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_cos(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_cos(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise complex conjugate of the input. 
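Every wrapper in this file now repeats the same three steps: allocate an empty handle with `mlx_array_new()`, let the op write into it through the out-parameter, and hand ownership to `MLXArray`. A hypothetical helper (not part of this diff, shown only to make the convention explicit, assuming the imported `mlx_array` value type):

func returningNewArray(_ fill: (inout mlx_array) -> Void) -> MLXArray {
    var result = mlx_array_new()   // empty handle owned by this frame
    fill(&result)                  // the C op writes the real array into it
    return MLXArray(result)        // MLXArray takes ownership; its deinit frees the handle
}

// abs(_:stream:) above could then read:
//     returningNewArray { mlx_abs(&$0, array.ctx, stream.ctx) }
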
@@ -388,7 +422,9 @@ public func cos(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra /// ### See Also /// - public func conjugate(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_conjugate(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_conjugate(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Return the cumulative maximum of the elements along the given axis. @@ -408,7 +444,9 @@ public func cummax( _ array: MLXArray, axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cummax(array.ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cummax(&result, array.ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative maximum of the elements over the flattened array. @@ -428,9 +466,10 @@ public func cummax( _ array: MLXArray, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(array.ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cummax(flat, 0, reverse, inclusive, stream.ctx)) + let flat = array.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cummax(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative minimum of the elements along the given axis. @@ -450,7 +489,9 @@ public func cummin( _ array: MLXArray, axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cummin(array.ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cummin(&result, array.ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative minimum of the elements over the flattened array. @@ -470,9 +511,10 @@ public func cummin( _ array: MLXArray, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(array.ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cummin(flat, 0, reverse, inclusive, stream.ctx)) + let flat = array.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cummin(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative product of the elements along the given axis. @@ -492,7 +534,9 @@ public func cumprod( _ array: MLXArray, axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cumprod(array.ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cumprod(&result, array.ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative product of the elements over the flattened array. @@ -512,9 +556,10 @@ public func cumprod( _ array: MLXArray, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(array.ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cumprod(flat, 0, reverse, inclusive, stream.ctx)) + let flat = array.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cumprod(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative sum of the elements along the given axis. 
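For the flattened cumulative ops the old code built a raw `mlx_reshape` handle and released it with `mlx_free`; the new code goes through `reshaped([-1])`, so the temporary is an ordinary `MLXArray` managed by ARC. Caller-level behaviour is unchanged, e.g.:

let a = MLXArray([2, 1, 3, 2])
let m = cummax(a)   // cumulative max over the flattened array: [2, 2, 3, 3]
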
@@ -534,7 +579,9 @@ public func cumsum( _ array: MLXArray, axis: Int, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_cumsum(array.ctx, axis.int32, reverse, inclusive, stream.ctx)) + var result = mlx_array_new() + mlx_cumsum(&result, array.ctx, axis.int32, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Return the cumulative sum of the elements over the flattened array. @@ -554,9 +601,10 @@ public func cumsum( _ array: MLXArray, reverse: Bool = false, inclusive: Bool = true, stream: StreamOrDevice = .default ) -> MLXArray { - let flat = mlx_reshape(array.ctx, [-1], 1, stream.ctx)! - defer { mlx_free(flat) } - return MLXArray(mlx_cumsum(flat, 0, reverse, inclusive, stream.ctx)) + let flat = array.reshaped([-1], stream: stream) + var result = mlx_array_new() + mlx_cumsum(&result, flat.ctx, 0, reverse, inclusive, stream.ctx) + return MLXArray(result) } /// Extract a diagonal or construct a diagonal matrix. @@ -573,7 +621,9 @@ public func cumsum( /// ### See Also /// - ``diagonal(_:offset:axis1:axis2:stream:)`` public func diag(_ array: MLXArray, k: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_diag(array.ctx, k.int32, stream.ctx)) + var result = mlx_array_new() + mlx_diag(&result, array.ctx, k.int32, stream.ctx) + return MLXArray(result) } /// Return specified diagonals. @@ -599,7 +649,9 @@ public func diagonal( _ array: MLXArray, offset: Int = 0, axis1: Int = 0, axis2: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_diagonal(array.ctx, offset.int32, axis1.int32, axis2.int32, stream.ctx)) + var result = mlx_array_new() + mlx_diagonal(&result, array.ctx, offset.int32, axis1.int32, axis2.int32, stream.ctx) + return MLXArray(result) } /// Element-wise exponential. @@ -609,7 +661,9 @@ public func diagonal( /// - ``MLXArray/exp(stream:)`` /// - ``expm1(_:stream:)`` public func exp(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_exp(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_exp(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Flatten an array. @@ -643,7 +697,9 @@ public func exp(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra public func flattened( _ array: MLXArray, start: Int = 0, end: Int = -1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_flatten(array.ctx, start.int32, end.int32, stream.ctx)) + var result = mlx_array_new() + mlx_flatten(&result, array.ctx, start.int32, end.int32, stream.ctx) + return MLXArray(result) } /// Element-wise floor. @@ -655,7 +711,9 @@ public func flattened( /// - ``ceil(_:stream:)`` /// - ``MLXArray/floor(stream:)`` public func floor(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_floor(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_floor(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise integer division.. @@ -680,7 +738,9 @@ public func floorDivide( _ array: MLXArray, _ other: T, stream: StreamOrDevice = .default ) -> MLXArray { let other = other.asMLXArray(dtype: array.dtype) - return MLXArray(mlx_floor_divide(array.ctx, other.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_floor_divide(&result, array.ctx, other.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise left shift. 
@@ -695,7 +755,9 @@ public func leftShift( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_left_shift(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_left_shift(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise natural logarithm. @@ -704,7 +766,9 @@ public func leftShift( /// - /// - ``MLXArray/log(stream:)`` public func log(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise base-2 logarithm. @@ -716,7 +780,9 @@ public func log(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra /// - ``log1p(_:stream:)`` /// - ``MLXArray/log2(stream:)`` public func log2(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log2(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log2(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise base-10 logarithm. @@ -728,7 +794,9 @@ public func log2(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - ``log1p(_:stream:)`` /// - ``MLXArray/log10(stream:)`` public func log10(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log10(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log10(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise natural log of one plus the array. @@ -740,7 +808,9 @@ public func log10(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// - ``log10(_:stream:)`` /// - ``MLXArray/log1p(stream:)`` public func log1p(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_log1p(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_log1p(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// A `log-sum-exp` reduction over the given axes. @@ -765,7 +835,9 @@ public func log1p(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr public func logSumExp( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_logsumexp(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_logsumexp(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `log-sum-exp` reduction over the given axis. @@ -790,7 +862,9 @@ public func logSumExp( public func logSumExp( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_logsumexp(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_logsumexp(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `log-sum-exp` reduction over the entire array. @@ -814,7 +888,9 @@ public func logSumExp( public func logSumExp(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_logsumexp_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_logsumexp_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Matrix multiplication. 
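The log-family wrappers are mechanical conversions to the out-parameter style; their Swift behaviour does not change. A quick usage sketch, assuming the usual array-literal initializer:

let x = MLXArray([1.0, 2.0, 3.0] as [Float])
let lse = logSumExp(x)   // scalar log(exp(1) + exp(2) + exp(3)), computed stably
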
@@ -854,7 +930,9 @@ public func logSumExp(_ array: MLXArray, keepDims: Bool = false, stream: StreamO /// - ``blockMaskedMM(_:_:blockSize:maskOut:maskLHS:maskRHS:stream:)`` /// - ``MLXArray/matmul(_:stream:)`` public func matmul(_ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_matmul(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_matmul(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// A `max` reduction over the given axes. @@ -880,7 +958,9 @@ public func matmul(_ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .defau public func max( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_max(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_max(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `max` reduction over the given axis. @@ -906,7 +986,9 @@ public func max( public func max( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_max(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_max(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `max` reduction over the entire array. @@ -931,7 +1013,9 @@ public func max( public func max(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_max_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_max_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// A `mean` reduction over the given axes. @@ -957,7 +1041,9 @@ public func max(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevic public func mean( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_mean(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_mean(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `mean` reduction over the given axis. @@ -983,7 +1069,9 @@ public func mean( public func mean( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_mean(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_mean(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `mean` reduction over the entire array. @@ -1008,7 +1096,9 @@ public func mean( public func mean(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_mean_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_mean_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// A `min` reduction over the given axes. @@ -1034,7 +1124,9 @@ public func mean(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevi public func min( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_min(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_min(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `min` reduction over the given axis. 
@@ -1060,7 +1152,9 @@ public func min( public func min( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_min(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_min(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `min` reduction over the entire array. @@ -1085,7 +1179,9 @@ public func min( public func min(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_min_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_min_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Move an axis to a new position. @@ -1121,7 +1217,9 @@ public func min(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevic public func movedAxis( _ array: MLXArray, source: Int, destination: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_moveaxis(array.ctx, source.int32, destination.int32, stream.ctx)) + var result = mlx_array_new() + mlx_moveaxis(&result, array.ctx, source.int32, destination.int32, stream.ctx) + return MLXArray(result) } /// Element-wise power operation. @@ -1144,7 +1242,9 @@ public func movedAxis( /// - ``MLXArray/pow(_:stream:)`` public func pow(_ array: MLXArray, _ other: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_power(array.ctx, other.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_power(&result, array.ctx, other.ctx, stream.ctx) + return MLXArray(result) } public func pow(_ array: MLXArray, _ other: T, stream: StreamOrDevice = .default) @@ -1184,7 +1284,9 @@ public func pow(_ array: T, _ other: MLXArray, stream: StreamO public func product( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_prod(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_prod(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// A `product` reduction over the given axis. @@ -1210,7 +1312,9 @@ public func product( public func product( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_prod(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_prod(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// A `product` reduction over the entire array. @@ -1235,7 +1339,9 @@ public func product( public func product(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_prod_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_prod_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Element-wise reciprocal. @@ -1244,7 +1350,9 @@ public func product(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrD /// - /// - ``MLXArray/reciprocal(stream:)`` public func reciprocal(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reciprocal(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_reciprocal(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Reshape an array while preserving the size. 
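The reduction wrappers keep their Swift signatures; only the C call changes. Example of the axis/keepDims form, assuming the range initializer:

let m = MLXArray(0 ..< 6, [2, 3])
let p = product(m, axis: 1, keepDims: true)   // shape [2, 1]: [[0], [60]]
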
@@ -1262,7 +1370,9 @@ public func reciprocal(_ array: MLXArray, stream: StreamOrDevice = .default) -> public func reshaped(_ array: MLXArray, _ newShape: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reshape(array.ctx, newShape.asInt32, newShape.count, stream.ctx)) + var result = mlx_array_new() + mlx_reshape(&result, array.ctx, newShape.asInt32, newShape.count, stream.ctx) + return MLXArray(result) } /// Reshape an array while preserving the size. @@ -1280,7 +1390,9 @@ public func reshaped(_ array: MLXArray, _ newShape: [Int], stream: StreamOrDevic public func reshaped(_ array: MLXArray, _ newShape: Int..., stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_reshape(array.ctx, newShape.asInt32, newShape.count, stream.ctx)) + var result = mlx_array_new() + mlx_reshape(&result, array.ctx, newShape.asInt32, newShape.count, stream.ctx) + return MLXArray(result) } /// Element-wise right shift. @@ -1295,7 +1407,9 @@ public func rightShift( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_right_shift(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_right_shift(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Round to the given number of decimals. @@ -1316,7 +1430,9 @@ public func rightShift( public func round(_ array: MLXArray, decimals: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_round(array.ctx, decimals.int32, stream.ctx)) + var result = mlx_array_new() + mlx_round(&result, array.ctx, decimals.int32, stream.ctx) + return MLXArray(result) } /// Element-wise reciprocal and square root. @@ -1325,7 +1441,9 @@ public func round(_ array: MLXArray, decimals: Int = 0, stream: StreamOrDevice = /// - /// - ``MLXArray/rsqrt(stream:)`` public func rsqrt(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_rsqrt(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_rsqrt(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise sine. @@ -1334,7 +1452,9 @@ public func rsqrt(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// - /// - ``MLXArray/sin(stream:)`` public func sin(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sin(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sin(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Split an array into equal size pieces along a given axis. @@ -1366,8 +1486,9 @@ public func sin(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra public func split(_ array: MLXArray, parts: Int, axis: Int = 0, stream: StreamOrDevice = .default) -> [MLXArray] { - let vec = mlx_split_equal_parts(array.ctx, parts.int32, axis.int32, stream.ctx)! - defer { mlx_free(vec) } + var vec = mlx_vector_array_new() + mlx_split_equal_parts(&vec, array.ctx, parts.int32, axis.int32, stream.ctx) + defer { mlx_vector_array_free(vec) } return mlx_vector_array_values(vec) } @@ -1394,8 +1515,9 @@ public func split(_ array: MLXArray, parts: Int, axis: Int = 0, stream: StreamOr public func split(_ array: MLXArray, axis: Int = 0, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) { - let vec = mlx_split_equal_parts(array.ctx, 2, axis.int32, stream.ctx)! 
- defer { mlx_free(vec) } + var vec = mlx_vector_array_new() + mlx_split_equal_parts(&vec, array.ctx, 2, axis.int32, stream.ctx) + defer { mlx_vector_array_free(vec) } let pieces = mlx_vector_array_values(vec) return (pieces[0], pieces[1]) } @@ -1414,8 +1536,9 @@ public func split(_ array: MLXArray, axis: Int = 0, stream: StreamOrDevice = .de public func split( _ array: MLXArray, indices: [Int], axis: Int = 0, stream: StreamOrDevice = .default ) -> [MLXArray] { - let vec = mlx_split(array.ctx, indices.asInt32, indices.count, axis.int32, stream.ctx)! - defer { mlx_free(vec) } + var vec = mlx_vector_array_new() + mlx_split(&vec, array.ctx, indices.asInt32, indices.count, axis.int32, stream.ctx) + defer { mlx_vector_array_free(vec) } return mlx_vector_array_values(vec) } @@ -1425,7 +1548,9 @@ public func split( /// - /// - ``MLXArray/sqrt(stream:)`` public func sqrt(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sqrt(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sqrt(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise square. @@ -1434,7 +1559,9 @@ public func sqrt(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``MLXArray/square(stream:)`` public func square(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_square(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_square(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Remove length one axes from an array. @@ -1450,7 +1577,9 @@ public func square(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXA /// - ``MLXArray/squeezed(axes:stream:)`` public func squeezed(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_squeeze(array.ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_squeeze(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } /// Remove length one axes from an array. @@ -1465,7 +1594,9 @@ public func squeezed(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .d /// - ``squeezed(_:stream:)`` /// - ``MLXArray/squeezed(axes:stream:)`` public func squeezed(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_squeeze(array.ctx, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_squeeze(&result, array.ctx, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Remove all length one axes from an array. @@ -1476,7 +1607,9 @@ public func squeezed(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .def /// - ``squeezed(_:axis:stream:)`` /// - ``MLXArray/squeezed(axes:stream:)`` public func squeezed(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_squeeze_all(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_squeeze_all(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Sum reduce the array over the given axes. @@ -1495,7 +1628,9 @@ public func squeezed(_ array: MLXArray, stream: StreamOrDevice = .default) -> ML public func sum( _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_sum(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_sum(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// Sum reduce the array over the given axis. 
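The split wrappers now fill an `mlx_vector_array` out-parameter and release it with the type-specific `mlx_vector_array_free` instead of the old generic `mlx_free`. Caller code is unchanged:

let a = MLXArray(0 ..< 8)
let parts = split(a, parts: 4)   // four arrays of length 2; the temporary vector is freed by the wrapper
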
@@ -1514,7 +1649,9 @@ public func sum( public func sum( _ array: MLXArray, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_sum(array.ctx, [axis.int32], 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_sum(&result, array.ctx, [axis.int32], 1, keepDims, stream.ctx) + return MLXArray(result) } /// Sum reduce the array over all axes. @@ -1532,7 +1669,9 @@ public func sum( public func sum(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sum_all(array.ctx, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_sum_all(&result, array.ctx, keepDims, stream.ctx) + return MLXArray(result) } /// Swap two axes of an array. @@ -1567,7 +1706,9 @@ public func sum(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevic public func swappedAxes( _ array: MLXArray, _ axis1: Int, _ axis2: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_swapaxes(array.ctx, axis1.int32, axis2.int32, stream.ctx)) + var result = mlx_array_new() + mlx_swapaxes(&result, array.ctx, axis1.int32, axis2.int32, stream.ctx) + return MLXArray(result) } /// Take elements along an axis. @@ -1587,7 +1728,9 @@ public func swappedAxes( public func take( _ array: MLXArray, _ indices: MLXArray, axis: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_take(array.ctx, indices.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_take(&result, array.ctx, indices.ctx, axis.int32, stream.ctx) + return MLXArray(result) } /// Take elements from flattened 1-D array. @@ -1599,7 +1742,9 @@ public func take(_ array: MLXArray, _ indices: MLXArray, stream: StreamOrDevice -> MLXArray { let input = array.reshaped([-1], stream: stream) - return MLXArray(mlx_take(input.ctx, indices.ctx, 0, stream.ctx)) + var result = mlx_array_new() + mlx_take(&result, input.ctx, indices.ctx, 0, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. @@ -1616,13 +1761,17 @@ public func take(_ array: MLXArray, _ indices: MLXArray, stream: StreamOrDevice public func transposed(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose(array.ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_transpose(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } public func transposed(_ array: MLXArray, _ axes: Int..., stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose(array.ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_transpose(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. @@ -1636,7 +1785,9 @@ public func transposed(_ array: MLXArray, _ axes: Int..., stream: StreamOrDevice /// - ``MLXArray/transposed(axes:stream:)`` public func transposed(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose(array.ctx, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_transpose(&result, array.ctx, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. 
@@ -1649,7 +1800,9 @@ public func transposed(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .d /// - ``transposed(_:axis:stream:)`` /// - ``MLXArray/transposed(axes:stream:)`` public func transposed(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_transpose_all(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_transpose_all(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Transpose the dimensions of the array. @@ -1676,7 +1829,9 @@ public func variance( _ array: MLXArray, axes: [Int], keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_var(array.ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_var(&result, array.ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Compute the variance(s) over the given axes @@ -1696,7 +1851,9 @@ public func variance( _ array: MLXArray, axis: Int, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_var(array.ctx, [axis.int32], 1, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_var(&result, array.ctx, [axis.int32], 1, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Compute the variance(s) over the given axes @@ -1714,7 +1871,9 @@ public func variance( public func variance( _ array: MLXArray, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_var_all(array.ctx, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_var_all(&result, array.ctx, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// View the array as a different type. @@ -1733,5 +1892,7 @@ public func variance( /// ### See Also ///- ``MLXArray/view(dtype:stream:)`` public func view(_ array: MLXArray, dtype: DType, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_view(array.ctx, dtype.cmlxDtype, stream.ctx)) + var result = mlx_array_new() + mlx_view(&result, array.ctx, dtype.cmlxDtype, stream.ctx) + return MLXArray(result) } diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index 83920d2f..646e0584 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -8,10 +8,11 @@ import Foundation /// Broadcast a vector of arrays against one another. func broadcast(arrays: [MLXArray], stream: StreamOrDevice = .default) -> [MLXArray] { let vector_array = new_mlx_vector_array(arrays) - defer { mlx_free(vector_array) } + defer { mlx_vector_array_free(vector_array) } - let result = mlx_broadcast_arrays(vector_array, stream.ctx)! 
- defer { mlx_free(result) } + var result = mlx_vector_array_new() + mlx_broadcast_arrays(&result, vector_array, stream.ctx) + defer { mlx_vector_array_free(result) } return mlx_vector_array_values(result) } @@ -42,7 +43,9 @@ public func add( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_add(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_add(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } @available(*, deprecated, renamed: "addMM(_:_:_:alpha:beta:stream:)") @@ -82,7 +85,9 @@ public func addMM( ) -> MLXArray { let (a, b) = toArrays(a, b) let (_, c) = toArrays(a, c) - return MLXArray(mlx_addmm(c.ctx, a.ctx, b.ctx, alpha, beta, stream.ctx)) + var result = mlx_array_new() + mlx_addmm(&result, c.ctx, a.ctx, b.ctx, alpha, beta, stream.ctx) + return MLXArray(result) } /// Element-wise inverse cosine. @@ -95,7 +100,9 @@ public func addMM( /// - /// - ``cos(_:stream:)`` public func acos(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arccos(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arccos(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse hyperbolic cosine. @@ -108,7 +115,9 @@ public func acos(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``cosh(_:stream:)`` public func acosh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arccosh(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arccosh(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse sine. @@ -121,7 +130,9 @@ public func acosh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// - /// - ``sin(_:stream:)`` public func asin(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arcsin(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arcsin(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse hyperbolic sine. @@ -134,7 +145,9 @@ public func asin(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``sinh(_:stream:)`` public func asinh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arcsinh(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arcsinh(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse tangent. @@ -147,7 +160,9 @@ public func asinh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// - /// - ``tan(_:stream:)`` public func atan(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arctan(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arctan(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse tangent of the ratio of two arrays. @@ -156,7 +171,9 @@ public func atan(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``atan(_:stream:)`` public func atan2(_ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arctan2(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arctan2(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse hyperbolic tangent. 
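The binary op wrappers in Ops.swift follow the same pattern after `toArrays` promotes scalars. For example, with an `Int` promoted via `ScalarOrArray`:

let x = MLXArray([1, 2, 3])
let y = add(x, 10)   // the scalar is promoted to an MLXArray; result is [11, 12, 13]
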
@@ -169,7 +186,9 @@ public func atan2(_ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .defaul /// - /// - ``tanh(_:stream:)`` public func atanh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_arctanh(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_arctanh(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Convert array to have at least 1 dimension. @@ -177,7 +196,9 @@ public func atanh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// ### See Also /// - public func atLeast1D(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_atleast_1d(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_atleast_1d(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Convert array to have at least 2 dimensions. @@ -185,7 +206,9 @@ public func atLeast1D(_ array: MLXArray, stream: StreamOrDevice = .default) -> M /// ### See Also /// - public func atLeast2D(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_atleast_2d(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_atleast_2d(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Convert array to have at least 3 dimensions. @@ -193,7 +216,9 @@ public func atLeast2D(_ array: MLXArray, stream: StreamOrDevice = .default) -> M /// ### See Also /// - public func atLeast3D(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_atleast_3d(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_atleast_3d(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Returns the indices that partition the array. @@ -226,7 +251,9 @@ public func atLeast3D(_ array: MLXArray, stream: StreamOrDevice = .default) -> M public func argPartition(_ array: MLXArray, kth: Int, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argpartition(array.ctx, kth.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_argpartition(&result, array.ctx, kth.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Returns the indices that partition the flattened array. @@ -244,7 +271,9 @@ public func argPartition(_ array: MLXArray, kth: Int, axis: Int, stream: StreamO /// - ``partitioned(_:kth:axis:stream:)`` public func argPartition(_ array: MLXArray, kth: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argpartition_all(array.ctx, kth.int32, stream.ctx)) + var result = mlx_array_new() + mlx_argpartition_all(&result, array.ctx, kth.int32, stream.ctx) + return MLXArray(result) } /// Returns the indices that sort the array. @@ -268,7 +297,9 @@ public func argPartition(_ array: MLXArray, kth: Int, stream: StreamOrDevice = . /// - /// - ``argSort(_:stream:)`` public func argSort(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argsort(array.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_argsort(&result, array.ctx, axis.int32, stream.ctx) + return MLXArray(result) } /// Returns the indices that sort the array. 
@@ -281,7 +312,9 @@ public func argSort(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .defa /// - /// - ``argSort(_:axis:stream:)`` public func argSort(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_argsort_all(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_argsort_all(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Create a view into the array with the given shape and strides. @@ -339,10 +372,12 @@ public func asStrided( resolvedStrides = result.reversed() } - return MLXArray( - mlx_as_strided( - array.ctx, shape.asInt32, shape.count, resolvedStrides, resolvedStrides.count, offset, - stream.ctx)) + var result = mlx_array_new() + mlx_as_strided( + &result, + array.ctx, shape.asInt32, shape.count, resolvedStrides, resolvedStrides.count, offset, + stream.ctx) + return MLXArray(result) } /// Matrix multiplication with block masking. @@ -379,9 +414,14 @@ public func blockMaskedMM( _ a: MLXArray, _ b: MLXArray, blockSize: Int = 64, maskOut: MLXArray? = nil, maskLHS: MLXArray? = nil, maskRHS: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_block_masked_mm( - a.ctx, b.ctx, blockSize.int32, maskOut?.ctx, maskLHS?.ctx, maskRHS?.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_block_masked_mm( + &result, + a.ctx, b.ctx, blockSize.int32, (maskOut ?? .mlxNone).ctx, (maskLHS ?? .mlxNone).ctx, + (maskRHS ?? .mlxNone).ctx, stream.ctx) + + return MLXArray(result) } /// Broadcast an array to the given shape. @@ -395,7 +435,9 @@ public func blockMaskedMM( public func broadcast(_ array: MLXArray, to shape: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_broadcast_to(array.ctx, shape.asInt32, shape.count, stream.ctx)) + var result = mlx_array_new() + mlx_broadcast_to(&result, array.ctx, shape.asInt32, shape.count, stream.ctx) + return MLXArray(result) } /// Element-wise ceil. @@ -408,7 +450,9 @@ public func broadcast(_ array: MLXArray, to shape: [Int], stream: StreamOrDevice /// - /// - ``floor(_:stream:)`` public func ceil(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_ceil(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_ceil(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Clip the values of the array between the given minimum and maximum. @@ -426,7 +470,11 @@ public func clip( _ array: MLXArray, min: A, stream: StreamOrDevice = .default ) -> MLXArray { let (array, min) = toArrays(array, min) - return MLXArray(mlx_clip(array.ctx, min.ctx, nil, stream.ctx)) + var result = mlx_array_new() + let max = mlx_array_new() + defer { mlx_array_free(max) } + mlx_clip(&result, array.ctx, min.ctx, max, stream.ctx) + return MLXArray(result) } /// Clip the values of the array between the given minimum and maximum. @@ -445,7 +493,9 @@ public func clip( ) -> MLXArray { let (array, min) = toArrays(array, min) let (_, max) = toArrays(array, max) - return MLXArray(mlx_clip(array.ctx, min.ctx, max.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_clip(&result, array.ctx, min.ctx, max.ctx, stream.ctx) + return MLXArray(result) } /// Clip the values of the array up to the given maximum. 
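`mlx_clip` no longer accepts a null pointer for a missing bound, so the one-sided variants above pass a freshly created empty `mlx_array` handle and free it after the call. Caller-level behaviour is unchanged:

let x = MLXArray([-2, -1, 0, 1, 2])
let y = clip(x, min: 0)   // [0, 0, 0, 1, 2]; the absent max bound is an empty handle on the C side
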
@@ -463,7 +513,11 @@ public func clip(_ array: MLXArray, max: A, stream: StreamOrDe -> MLXArray { let (array, max) = toArrays(array, max) - return MLXArray(mlx_clip(array.ctx, nil, max.ctx, stream.ctx)) + var result = mlx_array_new() + let min = mlx_array_new() + defer { mlx_array_free(min) } + mlx_clip(&result, array.ctx, min, max.ctx, stream.ctx) + return MLXArray(result) } /// Concatenate the arrays along the given axis. @@ -479,9 +533,11 @@ public func concatenated(_ arrays: [MLXArray], axis: Int = 0, stream: StreamOrDe -> MLXArray { let vector_array = new_mlx_vector_array(arrays) - defer { mlx_free(vector_array) } + defer { mlx_vector_array_free(vector_array) } - return MLXArray(mlx_concatenate(vector_array, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_concatenate(&result, vector_array, axis.int32, stream.ctx) + return MLXArray(result) } /// 1D convolution over an input with several channels. @@ -506,10 +562,12 @@ public func conv1d( _ array: MLXArray, _ weight: MLXArray, stride: Int = 1, padding: Int = 0, dilation: Int = 1, groups: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv1d( - array.ctx, weight.ctx, stride.int32, padding.int32, dilation.int32, groups.int32, - stream.ctx)) + var result = mlx_array_new() + mlx_conv1d( + &result, + array.ctx, weight.ctx, stride.int32, padding.int32, dilation.int32, groups.int32, + stream.ctx) + return MLXArray(result) } /// 2D convolution over an input with several channels. @@ -550,11 +608,13 @@ public func conv2d( _ array: MLXArray, _ weight: MLXArray, stride: IntOrPair = 1, padding: IntOrPair = 0, dilation: IntOrPair = 1, groups: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv2d( - array.ctx, weight.ctx, stride.first.int32, stride.second.int32, padding.first.int32, - padding.second.int32, dilation.first.int32, dilation.second.int32, groups.int32, - stream.ctx)) + var result = mlx_array_new() + mlx_conv2d( + &result, + array.ctx, weight.ctx, stride.first.int32, stride.second.int32, padding.first.int32, + padding.second.int32, dilation.first.int32, dilation.second.int32, groups.int32, + stream.ctx) + return MLXArray(result) } /// 3D convolution over an input with several channels. @@ -595,13 +655,15 @@ public func conv3d( _ array: MLXArray, _ weight: MLXArray, stride: IntOrTriple = 1, padding: IntOrTriple = 0, dilation: IntOrTriple = 1, groups: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv3d( - array.ctx, weight.ctx, - stride.first.int32, stride.second.int32, stride.third.int32, - padding.first.int32, padding.second.int32, padding.third.int32, - dilation.first.int32, dilation.second.int32, dilation.third.int32, - groups.int32, stream.ctx)) + var result = mlx_array_new() + mlx_conv3d( + &result, + array.ctx, weight.ctx, + stride.first.int32, stride.second.int32, stride.third.int32, + padding.first.int32, padding.second.int32, padding.third.int32, + dilation.first.int32, dilation.second.int32, dilation.third.int32, + groups.int32, stream.ctx) + return MLXArray(result) } /// General convolution over an input with several channels. 
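`concatenated` keeps the `new_mlx_vector_array` bridge but swaps the cleanup to `mlx_vector_array_free`. Usage is unchanged:

let a = MLXArray([1, 2])
let b = MLXArray([3, 4])
let c = concatenated([a, b])   // [1, 2, 3, 4]
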
@@ -638,15 +700,17 @@ public func convGeneral( flip: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv_general( - array.ctx, weight.ctx, - strides.asInt32Array, strides.count, - padding.asInt32Array, padding.count, - padding.asInt32Array, padding.count, - kernelDilation.asInt32Array, kernelDilation.count, - inputDilation.asInt32Array, inputDilation.count, - groups.int32, flip, stream.ctx)) + var result = mlx_array_new() + mlx_conv_general( + &result, + array.ctx, weight.ctx, + strides.asInt32Array, strides.count, + padding.asInt32Array, padding.count, + padding.asInt32Array, padding.count, + kernelDilation.asInt32Array, kernelDilation.count, + inputDilation.asInt32Array, inputDilation.count, + groups.int32, flip, stream.ctx) + return MLXArray(result) } /// General convolution over an input with several channels with a padding pair. @@ -682,15 +746,17 @@ public func convGeneral( flip: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv_general( - array.ctx, weight.ctx, - strides.asInt32Array, strides.count, - [padding.0.int32], 1, - [padding.1.int32], 1, - kernelDilation.asInt32Array, kernelDilation.count, - inputDilation.asInt32Array, inputDilation.count, - groups.int32, flip, stream.ctx)) + var result = mlx_array_new() + mlx_conv_general( + &result, + array.ctx, weight.ctx, + strides.asInt32Array, strides.count, + [padding.0.int32], 1, + [padding.1.int32], 1, + kernelDilation.asInt32Array, kernelDilation.count, + inputDilation.asInt32Array, inputDilation.count, + groups.int32, flip, stream.ctx) + return MLXArray(result) } /// 1D transposed convolution over an input with several channels. @@ -716,10 +782,12 @@ public func convTransposed1d( _ array: MLXArray, _ weight: MLXArray, stride: Int = 1, padding: Int = 0, dilation: Int = 1, groups: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv_transpose1d( - array.ctx, weight.ctx, stride.int32, padding.int32, dilation.int32, groups.int32, - stream.ctx)) + var result = mlx_array_new() + mlx_conv_transpose1d( + &result, + array.ctx, weight.ctx, stride.int32, padding.int32, dilation.int32, groups.int32, + stream.ctx) + return MLXArray(result) } /// 2D transposed convolution over an input with several channels. @@ -761,11 +829,13 @@ public func convTransposed2d( _ array: MLXArray, _ weight: MLXArray, stride: IntOrPair = 1, padding: IntOrPair = 0, dilation: IntOrPair = 1, groups: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv_transpose2d( - array.ctx, weight.ctx, stride.first.int32, stride.second.int32, padding.first.int32, - padding.second.int32, dilation.first.int32, dilation.second.int32, groups.int32, - stream.ctx)) + var result = mlx_array_new() + mlx_conv_transpose2d( + &result, + array.ctx, weight.ctx, stride.first.int32, stride.second.int32, padding.first.int32, + padding.second.int32, dilation.first.int32, dilation.second.int32, groups.int32, + stream.ctx) + return MLXArray(result) } /// 3D transposed convolution over an input with several channels. 
@@ -807,13 +877,15 @@ public func convTransposed3d( _ array: MLXArray, _ weight: MLXArray, stride: IntOrTriple = 1, padding: IntOrTriple = 0, dilation: IntOrTriple = 1, groups: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_conv_transpose3d( - array.ctx, weight.ctx, - stride.first.int32, stride.second.int32, stride.third.int32, - padding.first.int32, padding.second.int32, padding.third.int32, - dilation.first.int32, dilation.second.int32, dilation.third.int32, - groups.int32, stream.ctx)) + var result = mlx_array_new() + mlx_conv_transpose3d( + &result, + array.ctx, weight.ctx, + stride.first.int32, stride.second.int32, stride.third.int32, + padding.first.int32, padding.second.int32, padding.third.int32, + dilation.first.int32, dilation.second.int32, dilation.third.int32, + groups.int32, stream.ctx) + return MLXArray(result) } /// Mode for ``convolve(_:_:mode:stream:)`` @@ -845,10 +917,12 @@ public func convolve( var (input, weight) = a.size < b.size ? (b, a) : (a, b) - weight = MLXArray( - mlx_slice( - weight.ctx, [weight.dim(0) - 1].asInt32, 1, [-weight.dim(0) - 1].asInt32, 1, [-1], 1, - stream.ctx)) + var slice = mlx_array_new() + mlx_slice( + &slice, + weight.ctx, [weight.dim(0) - 1].asInt32, 1, [-weight.dim(0) - 1].asInt32, 1, [-1], 1, + stream.ctx) + weight = MLXArray(slice) weight = weight.reshaped([1, -1, 1], stream: stream) input = input.reshaped([1, -1, 1], stream: stream) @@ -872,8 +946,9 @@ public func convolve( } } - return MLXArray(mlx_conv1d(input.ctx, weight.ctx, 1, padding.int32, 1, 1, stream.ctx)).reshaped( - -1, stream: stream) + var result = mlx_array_new() + mlx_conv1d(&result, input.ctx, weight.ctx, 1, padding.int32, 1, 1, stream.ctx) + return MLXArray(result).reshaped(-1, stream: stream) } /// Element-wise hyperbolic cosine. @@ -886,7 +961,9 @@ public func convolve( /// - /// - ``cos(_:stream:)`` public func cosh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_cosh(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_cosh(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Convert angles from radians to degrees. @@ -899,7 +976,9 @@ public func cosh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``radians(_:stream:)`` public func degrees(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_degrees(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_degrees(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Dequantize the matrix `w` using the provided `scales` and @@ -915,7 +994,9 @@ public func dequantized( _ w: MLXArray, scales: MLXArray, biases: MLXArray, groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_dequantize(w.ctx, scales.ctx, biases.ctx, groupSize.int32, bits.int32, stream.ctx)) + var result = mlx_array_new() + mlx_dequantize(&result, w.ctx, scales.ctx, biases.ctx, groupSize.int32, bits.int32, stream.ctx) + return MLXArray(result) } /// Element-wise division. @@ -943,7 +1024,9 @@ public func divide( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_divide(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_divide(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise quotient and remainder. 
@@ -966,9 +1049,10 @@ public func divmod( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> (MLXArray, MLXArray) { let (a, b) = toArrays(a, b) - let arrays = mlx_divmod(a.ctx, b.ctx, stream.ctx)! - defer { mlx_free(arrays) } - let result = mlx_vector_array_values(arrays) + var vec = mlx_vector_array_new() + mlx_divmod(&vec, a.ctx, b.ctx, stream.ctx) + defer { mlx_vector_array_free(vec) } + let result = mlx_vector_array_values(vec) return (result[0], result[1]) } @@ -993,13 +1077,12 @@ public func einsum(_ subscripts: String, _ operands: MLXArray..., stream: Stream public func einsum(_ subscripts: String, operands: [MLXArray], stream: StreamOrDevice = .default) -> MLXArray { - let subscripts = mlx_string_new(subscripts.cString(using: .utf8))! - defer { mlx_free(subscripts) } - let operands = new_mlx_vector_array(operands) - defer { mlx_free(operands) } + defer { mlx_vector_array_free(operands) } - return MLXArray(mlx_einsum(subscripts, operands, stream.ctx)) + var result = mlx_array_new() + mlx_einsum(&result, subscripts, operands, stream.ctx) + return MLXArray(result) } /// Element-wise equality. @@ -1028,7 +1111,9 @@ public func equal( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_equal(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_equal(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise error function. @@ -1044,7 +1129,9 @@ public func equal( /// - /// - ``erfInverse(_:stream:)`` public func erf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_erf(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_erf(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise inverse of ``erf(_:stream:)``. @@ -1060,7 +1147,9 @@ public func erf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra /// - /// - ``erf(_:stream:)`` public func erfInverse(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_erfinv(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_erfinv(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Add a size one dimension at the given axis. @@ -1076,7 +1165,9 @@ public func erfInverse(_ array: MLXArray, stream: StreamOrDevice = .default) -> public func expandedDimensions(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_expand_dims(array.ctx, axes.asInt32, axes.count, stream.ctx)) + var result = mlx_array_new() + mlx_expand_dims(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } /// Add a size one dimension at the given axis. @@ -1092,7 +1183,9 @@ public func expandedDimensions(_ array: MLXArray, axes: [Int], stream: StreamOrD public func expandedDimensions(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_expand_dims(array.ctx, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_expand_dims(&result, array.ctx, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Element-wise exponential minus 1. 
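`einsum` no longer wraps the subscripts in an `mlx_string`; the new entry point evidently takes a plain C string, so the Swift `String` bridges directly. A quick sketch, assuming the range initializer:

let a = MLXArray(0 ..< 6, [2, 3]).asType(.float32)
let b = MLXArray(0 ..< 6, [3, 2]).asType(.float32)
let c = einsum("ij,jk->ik", operands: [a, b])   // equivalent to matmul(a, b)
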
@@ -1107,7 +1200,9 @@ public func expandedDimensions(_ array: MLXArray, axis: Int, stream: StreamOrDev /// - /// - ``exp(_:stream:)`` public func expm1(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_expm1(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_expm1(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Matrix multiplication with matrix-level gather. @@ -1134,7 +1229,13 @@ public func gatherMatmul( _ a: MLXArray, _ b: MLXArray, lhsIndices: MLXArray? = nil, rhsIndices: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_gather_mm(a.ctx, b.ctx, lhsIndices?.ctx, rhsIndices?.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_gather_mm( + &result, a.ctx, b.ctx, (lhsIndices ?? .mlxNone).ctx, (rhsIndices ?? .mlxNone).ctx, + stream.ctx) + + return MLXArray(result) } /// Perform quantized matrix multiplication with matrix-level gather. @@ -1153,10 +1254,15 @@ public func gatherQuantizedMatmul( transpose: Bool = true, groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_gather_qmm( - x.ctx, w.ctx, scales.ctx, biases.ctx, lhsIndices?.ctx, rhsIndices?.ctx, transpose, - groupSize.int32, bits.int32, stream.ctx)) + var result = mlx_array_new() + + mlx_gather_qmm( + &result, + x.ctx, w.ctx, scales.ctx, biases.ctx, (lhsIndices ?? .mlxNone).ctx, + (rhsIndices ?? .mlxNone).ctx, transpose, + groupSize.int32, bits.int32, stream.ctx) + + return MLXArray(result) } /// Element-wise greater than. @@ -1185,7 +1291,9 @@ public func greater( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_greater(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_greater(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise less greater than or equal. @@ -1214,7 +1322,9 @@ public func greaterEqual( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_greater_equal(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_greater_equal(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Perform the Walsh-Hadamard transform along the final axis. @@ -1230,7 +1340,9 @@ public func hadamardTransform( _ array: MLXArray, scale: Float? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let scale = mlx_optional_float(value: scale ?? 0, has_value: scale != nil) - return MLXArray(mlx_hadamard_transform(array.ctx, scale, stream.ctx)) + var result = mlx_array_new() + mlx_hadamard_transform(&result, array.ctx, scale, stream.ctx) + return MLXArray(result) } /// Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes. @@ -1246,7 +1358,9 @@ public func hadamardTransform( public func inner( _ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_inner(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_inner(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Returns a boolean array where two arrays are element-wise equal within a tolerance. 
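Besides the out-parameter change, `gatherMatmul` and `gatherQuantizedMatmul` above switch optional index arrays from nullable pointers (`lhsIndices?.ctx`) to a sentinel handle, `(lhsIndices ?? .mlxNone).ctx`; callers still pass plain optionals. The comparison and inner-product wrappers in the same chunk keep their `ScalarOrArray` convenience, e.g. (usage sketch, array initializers assumed):

```swift
import MLX

let x = MLXArray([1, 2, 3] as [Float])
let y = MLXArray([4, 5, 6] as [Float])

let mask = greaterEqual(x, 2)   // scalar rhs is promoted: [false, true, true]
let dot = inner(x, y)           // 1*4 + 2*5 + 3*6 = 32
```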
@@ -1277,7 +1391,9 @@ public func isClose( _ a: MLXArray, _ b: MLXArray, rtol: Double = 1e-5, atol: Double = 1e-8, equalNaN: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_isclose(a.ctx, b.ctx, rtol, atol, equalNaN, stream.ctx)) + var result = mlx_array_new() + mlx_isclose(&result, a.ctx, b.ctx, rtol, atol, equalNaN, stream.ctx) + return MLXArray(result) } /// Return a boolean array indicating which elements are NaN. @@ -1290,7 +1406,9 @@ public func isClose( /// ### See Also /// - public func isNaN(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_isnan(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_isnan(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Return a boolean array indicating which elements are infinity. @@ -1303,7 +1421,9 @@ public func isNaN(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// ### See Also /// - public func isInf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_isinf(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_isinf(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Return a boolean array indicating which elements are finite. @@ -1316,7 +1436,9 @@ public func isInf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// ### See Also /// - public func isFinite(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_isfinite(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_isfinite(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Return a boolean array indicating which elements are negative infinity. @@ -1329,7 +1451,9 @@ public func isFinite(_ array: MLXArray, stream: StreamOrDevice = .default) -> ML /// ### See Also /// - public func isNegInf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_isneginf(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_isneginf(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Return a boolean array indicating which elements are positive infinity. @@ -1342,7 +1466,9 @@ public func isNegInf(_ array: MLXArray, stream: StreamOrDevice = .default) -> ML /// ### See Also /// - public func isPosInf(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_isposinf(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_isposinf(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise less than. @@ -1371,7 +1497,9 @@ public func less( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_less(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_less(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise less than or equal. @@ -1400,7 +1528,9 @@ public func lessEqual( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_less_equal(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_less_equal(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } enum LoadSaveError: Error { @@ -1432,12 +1562,12 @@ extension LoadSaveError: LocalizedError { public func loadArray(url: URL, stream: StreamOrDevice = .default) throws -> MLXArray { precondition(url.isFileURL) let path = url.path(percentEncoded: false) - let filename = mlx_string_new(path.cString(using: .utf8))! 
- defer { mlx_free(filename) } switch url.pathExtension { case "npy": - return MLXArray(mlx_load(filename, stream.ctx)) + var result = mlx_array_new() + mlx_load(&result, path.cString(using: .utf8), stream.ctx) + return MLXArray(result) default: throw LoadSaveError.unknownExtension(url.pathExtension) @@ -1458,18 +1588,17 @@ public func loadArray(url: URL, stream: StreamOrDevice = .default) throws -> MLX public func loadArrays(url: URL, stream: StreamOrDevice = .default) throws -> [String: MLXArray] { precondition(url.isFileURL) let path = url.path(percentEncoded: false) - let filename = mlx_string_new(path.cString(using: .utf8))! - defer { mlx_free(filename) } switch url.pathExtension { case "safetensors": - let mlx_safetensors = mlx_load_safetensors(filename, stream.ctx)! - defer { mlx_free(mlx_safetensors) } + var r0 = mlx_map_string_to_array_new() + var r1 = mlx_map_string_to_string_new() - let mlx_arrays = mlx_safetensors_data(mlx_safetensors)! - defer { mlx_free(mlx_arrays) } + mlx_load_safetensors(&r0, &r1, path.cString(using: .utf8), stream.ctx) + defer { mlx_map_string_to_array_free(r0) } + defer { mlx_map_string_to_string_free(r1) } - return mlx_map_array_values(mlx_arrays) + return mlx_map_array_values(r0) default: throw LoadSaveError.unknownExtension(url.pathExtension) } @@ -1489,21 +1618,17 @@ public func loadArraysAndMetadata(url: URL, stream: StreamOrDevice = .default) t ) { precondition(url.isFileURL) let path = url.path(percentEncoded: false) - let filename = mlx_string_new(path.cString(using: .utf8))! - defer { mlx_free(filename) } switch url.pathExtension { case "safetensors": - let mlx_safetensors = mlx_load_safetensors(filename, stream.ctx)! - defer { mlx_free(mlx_safetensors) } - - let mlx_arrays = mlx_safetensors_data(mlx_safetensors)! - defer { mlx_free(mlx_arrays) } + var r0 = mlx_map_string_to_array_new() + var r1 = mlx_map_string_to_string_new() - let mlx_metadata = mlx_safetensors_metadata(mlx_safetensors)! - defer { mlx_free(mlx_metadata) } + mlx_load_safetensors(&r0, &r1, path.cString(using: .utf8), stream.ctx) + defer { mlx_map_string_to_array_free(r0) } + defer { mlx_map_string_to_string_free(r1) } - return (mlx_map_array_values(mlx_arrays), mlx_map_string_values(mlx_metadata)) + return (mlx_map_array_values(r0), mlx_map_string_values(r1)) default: throw LoadSaveError.unknownExtension(url.pathExtension) } @@ -1527,7 +1652,9 @@ public func logAddExp( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_logaddexp(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_logaddexp(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise logical and. @@ -1558,7 +1685,9 @@ public func logicalAnd( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_logical_and(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_logical_and(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise logical not. @@ -1579,7 +1708,9 @@ public func logicalAnd( /// - /// - public func logicalNot(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_logical_not(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_logical_not(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise logical or. 
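The safetensors hunks above replace the old `mlx_safetensors` container (with its separate `_data` and `_metadata` accessors) by a single `mlx_load_safetensors` call that fills two out-maps, freed with the new type-specific `mlx_map_*_free` functions. The public Swift entry points keep their signatures; a usage sketch with a hypothetical file path:

```swift
import Foundation
import MLX

let url = URL(fileURLWithPath: "/tmp/model.safetensors")   // hypothetical path

do {
    // arrays only
    let weights = try loadArrays(url: url)

    // arrays plus the string metadata stored alongside them
    let (arrays, metadata) = try loadArraysAndMetadata(url: url)
    print(weights.count, arrays.count, metadata.count)
} catch {
    print("load failed: \(error)")
}
```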
@@ -1610,7 +1741,9 @@ public func logicalOr( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_logical_or(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_logical_or(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Indexing mode for ``meshGrid(_:sparse:indexing:stream:)``. @@ -1635,15 +1768,14 @@ public func meshGrid( stream: StreamOrDevice = .default ) -> [MLXArray] { let mlxArrays = new_mlx_vector_array(arrays) - defer { mlx_free(mlxArrays) } + defer { mlx_vector_array_free(mlxArrays) } - let indexingString = mlx_string_new(indexing.rawValue.cString(using: .utf8))! - defer { mlx_free(indexingString) } + var vec = mlx_vector_array_new() - let result = mlx_meshgrid(mlxArrays, sparse, indexingString, stream.ctx)! - defer { mlx_free(result) } + mlx_meshgrid(&vec, mlxArrays, sparse, indexing.rawValue.cString(using: .utf8), stream.ctx) + defer { mlx_vector_array_free(vec) } - return mlx_vector_array_values(result) + return mlx_vector_array_values(vec) } /// Element-wise maximum. @@ -1663,7 +1795,9 @@ public func maximum( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_maximum(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_maximum(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise minimum. @@ -1683,7 +1817,9 @@ public func minimum( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_minimum(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_minimum(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise multiplication. @@ -1710,7 +1846,9 @@ public func multiply( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_multiply(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_multiply(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Replace NaN and Inf values with finite numbers. @@ -1733,7 +1871,9 @@ public func nanToNum( ) -> MLXArray { let posInf = mlx_optional_float(value: posInf ?? 0, has_value: posInf != nil) let negInf = mlx_optional_float(value: negInf ?? 0, has_value: negInf != nil) - return MLXArray(mlx_nan_to_num(array.ctx, nan, posInf, negInf, stream.ctx)) + var result = mlx_array_new() + mlx_nan_to_num(&result, array.ctx, nan, posInf, negInf, stream.ctx) + return MLXArray(result) } /// Element-wise negation. @@ -1754,7 +1894,9 @@ public func nanToNum( /// ### See Also /// - public func negative(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_negative(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_negative(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise not equal. @@ -1784,7 +1926,9 @@ public func notEqual( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_not_equal(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_not_equal(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Compute the outer product of two 1-D arrays, if the array's passed are not 1-D a flatten op will be run beforehand. 
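`nanToNum` above still builds its `mlx_optional_float` arguments the same way before the out-parameter call, so behaviour is unchanged. A usage sketch combining it with the element-wise predicates and logical ops from the surrounding hunks (parameter labels as they appear in the hunk; defaults assumed):

```swift
import MLX

let v = MLXArray([1.0, Float.nan, Float.infinity, -Float.infinity])

let nonFinite = logicalOr(isNaN(v), isInf(v))                   // [false, true, true, true]
let cleaned = nanToNum(v, nan: 0.0, posInf: 1e6, negInf: -1e6)  // finite substitutes
let clamped = minimum(maximum(cleaned, -1e6), 1e6)
```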
@@ -1800,7 +1944,9 @@ public func notEqual( public func outer( _ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_outer(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_outer(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Mode for ``padded(_:width:value:stream:)`` @@ -1834,12 +1980,12 @@ public func padded( let highPads = (0 ..< ndim).map { _ in width.second.int32 } let value = value ?? MLXArray(0, dtype: array.dtype) - let mlx_mode = mlx_string_new(mode.rawValue.cString(using: .utf8))! - defer { mlx_free(mlx_mode) } - - return MLXArray( - mlx_pad( - array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, mlx_mode, stream.ctx)) + var result = mlx_array_new() + mlx_pad( + &result, + array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, + mode.rawValue.cString(using: .utf8), stream.ctx) + return MLXArray(result) } /// Pad an array with a constant value. @@ -1864,12 +2010,12 @@ public func padded( let highPads = widths.map { $0.second.int32 } let value = value ?? MLXArray(0, dtype: array.dtype) - let mlx_mode = mlx_string_new(mode.rawValue.cString(using: .utf8))! - defer { mlx_free(mlx_mode) } - - return MLXArray( - mlx_pad( - array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, mlx_mode, stream.ctx)) + var result = mlx_array_new() + mlx_pad( + &result, + array.ctx, axes, ndim, lowPads, ndim, highPads, ndim, value.ctx, + mode.rawValue.cString(using: .utf8), stream.ctx) + return MLXArray(result) } /// Returns a partitioned copy of the array such that the smaller `kth` @@ -1893,7 +2039,9 @@ public func padded( public func partitioned(_ array: MLXArray, kth: Int, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_partition(array.ctx, kth.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_partition(&result, array.ctx, kth.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// @@ -1916,7 +2064,9 @@ public func partitioned(_ array: MLXArray, kth: Int, axis: Int, stream: StreamOr /// - ``argPartition(_:kth:axis:stream:)`` public func partitioned(_ array: MLXArray, kth: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_partition_all(array.ctx, kth.int32, stream.ctx)) + var result = mlx_array_new() + mlx_partition_all(&result, array.ctx, kth.int32, stream.ctx) + return MLXArray(result) } /// Put values along an axis at the specified indices. @@ -1935,7 +2085,9 @@ public func putAlong( _ array: MLXArray, _ indices: MLXArray, values: MLXArray, axis: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_put_along_axis(array.ctx, indices.ctx, values.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_put_along_axis(&result, array.ctx, indices.ctx, values.ctx, axis.int32, stream.ctx) + return MLXArray(result) } /// Put values along an axis at the specified indices in a flattened array. @@ -1955,8 +2107,9 @@ public func putAlong( -> MLXArray { let input = array.reshaped([-1], stream: stream) - let result = MLXArray(mlx_put_along_axis(input.ctx, indices.ctx, values.ctx, 0, stream.ctx)) - return result.reshaped(array.shape, stream: stream) + var result = mlx_array_new() + mlx_put_along_axis(&result, input.ctx, indices.ctx, values.ctx, 0, stream.ctx) + return MLXArray(result).reshaped(array.shape, stream: stream) } /// Quantize the matrix `w` using `bits` bits per element. 
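`partitioned` and the two `putAlong` overloads above are direct conversions; the flattened `putAlong` overload reshapes the input to `[-1]`, scatters, then restores the original shape. A usage sketch (the unlabeled-indices call shape for the flattened overload is assumed from the hunk):

```swift
import MLX

let values = MLXArray([5, 1, 4, 2, 3] as [Int32])

// every element left of position 2 is <= the element at position 2
// (ordering within each block is unspecified)
let part = partitioned(values, kth: 2)

// scatter 99 into flat position 0; the result keeps the original shape
let updated = putAlong(values, MLXArray([0] as [Int32]), values: MLXArray([99] as [Int32]))
```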
@@ -1977,10 +2130,12 @@ public func putAlong( public func quantized( _ w: MLXArray, groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default ) -> (wq: MLXArray, scales: MLXArray, biases: MLXArray) { - let result_tuple = mlx_quantize(w.ctx, groupSize.int32, bits.int32, stream.ctx)! - defer { mlx_free(result_tuple) } + var r1 = mlx_array_new() + var r2 = mlx_array_new() + var r3 = mlx_array_new() + mlx_quantize(&r1, &r2, &r3, w.ctx, groupSize.int32, bits.int32, stream.ctx) - return mlx_tuple_values(result_tuple) + return (MLXArray(r1), MLXArray(r2), MLXArray(r3)) } /// Perform the matrix multiplication with the quantized matrix `w`. The @@ -1995,10 +2150,12 @@ public func quantizedMatmul( _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray, transpose: Bool = true, groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_quantized_matmul( - x.ctx, w.ctx, scales.ctx, biases.ctx, transpose, groupSize.int32, bits.int32, stream.ctx - )) + var result = mlx_array_new() + mlx_quantized_matmul( + &result, + x.ctx, w.ctx, scales.ctx, biases.ctx, transpose, groupSize.int32, bits.int32, stream.ctx + ) + return MLXArray(result) } /// Convert angles from degrees to radians. @@ -2011,7 +2168,9 @@ public func quantizedMatmul( /// - /// - ``degrees(_:stream:)`` public func radians(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_radians(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_radians(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise remainder of division. @@ -2037,7 +2196,9 @@ public func remainder( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_remainder(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_remainder(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Save array to a binary file in `.npy`format. @@ -2090,10 +2251,10 @@ public func save( let path = url.path(percentEncoded: false) let mlx_arrays = new_mlx_array_map(arrays) - defer { mlx_free(mlx_arrays) } + defer { mlx_map_string_to_array_free(mlx_arrays) } let mlx_metadata = new_mlx_string_map(metadata) - defer { mlx_free(mlx_metadata) } + defer { mlx_map_string_to_string_free(mlx_metadata) } switch url.pathExtension { case "safetensors": @@ -2124,7 +2285,9 @@ public func save( /// ### See Also /// - public func sigmoid(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sigmoid(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sigmoid(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise sign. @@ -2136,7 +2299,9 @@ public func sigmoid(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLX /// ### See Also /// - public func sign(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sign(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sign(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise hyperbolic sine. 
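`quantized` above now returns its three outputs through separate out-parameters rather than an `mlx_tuple_array_array_array`, but the Swift-level tuple `(wq, scales, biases)` is unchanged. A round-trip sketch pairing it with `quantizedMatmul` and `dequantized` (random inputs assumed from `MLXRandom`):

```swift
import MLX
import MLXRandom

let w = MLXRandom.normal([256, 256])
let (wq, scales, biases) = quantized(w, groupSize: 64, bits: 4)

let x = MLXRandom.normal([1, 256])
let y = quantizedMatmul(x, wq, scales: scales, biases: biases,
                        transpose: true, groupSize: 64, bits: 4)

// approximate reconstruction of the original weights
let wHat = dequantized(wq, scales: scales, biases: biases, groupSize: 64, bits: 4)
```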
@@ -2149,7 +2314,9 @@ public func sign(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr /// - /// - ``sin(_:stream:)`` public func sinh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sinh(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sinh(&result, array.ctx, stream.ctx) + return MLXArray(result) } @available(*, deprecated, renamed: "softmax(_:axes:precise:stream:)") @@ -2179,7 +2346,9 @@ public func softMax( public func softmax( _ array: MLXArray, axes: [Int], precise: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_softmax(array.ctx, axes.asInt32, axes.count, precise, stream.ctx)) + var result = mlx_array_new() + mlx_softmax(&result, array.ctx, axes.asInt32, axes.count, precise, stream.ctx) + return MLXArray(result) } @available(*, deprecated, renamed: "softmax(_:axis:precise:stream:)") @@ -2209,7 +2378,9 @@ public func softMax( public func softmax( _ array: MLXArray, axis: Int, precise: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_softmax(array.ctx, [axis.int32], 1, precise, stream.ctx)) + var result = mlx_array_new() + mlx_softmax(&result, array.ctx, [axis.int32], 1, precise, stream.ctx) + return MLXArray(result) } @available(*, deprecated, renamed: "softmax(_:axis:precise:stream:)") @@ -2238,7 +2409,9 @@ public func softMax(_ array: MLXArray, precise: Bool = false, stream: StreamOrDe public func softmax(_ array: MLXArray, precise: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_softmax_all(array.ctx, precise, stream.ctx)) + var result = mlx_array_new() + mlx_softmax_all(&result, array.ctx, precise, stream.ctx) + return MLXArray(result) } /// Returns a sorted copy of the array. @@ -2253,7 +2426,9 @@ public func softmax(_ array: MLXArray, precise: Bool = false, stream: StreamOrDe /// - ``sorted(_:stream:)`` /// - ``argSort(_:axis:stream:)`` public func sorted(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sort(array.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_sort(&result, array.ctx, axis.int32, stream.ctx) + return MLXArray(result) } /// Returns a sorted copy of the flattened array. @@ -2267,7 +2442,9 @@ public func sorted(_ array: MLXArray, axis: Int, stream: StreamOrDevice = .defau /// - ``sorted(_:axis:stream:)`` /// - ``argSort(_:axis:stream:)`` public func sorted(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_sort_all(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_sort_all(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Compute the standard deviation(s) over the given axes. @@ -2287,7 +2464,9 @@ public func std( _ array: MLXArray, axes: [Int], keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_std(array.ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_std(&result, array.ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Compute the standard deviation over the given axis. 
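The `softmax` family above keeps the deprecated `softMax` spellings next to the renamed versions; `precise: true` asks for the higher-precision accumulation path. Usage sketch (array initializers assumed):

```swift
import MLX

let logits = MLXArray([2.0, 1.0, 0.1] as [Float])

let probs = softmax(logits, axis: -1)
let probsPrecise = softmax(logits, axis: -1, precise: true)

let ordered = sorted(MLXArray([3, 1, 2] as [Int32]))   // [1, 2, 3]
```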
@@ -2307,7 +2486,9 @@ public func std( _ array: MLXArray, axis: Int, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_std(array.ctx, [axis.int32], 1, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_std(&result, array.ctx, [axis.int32], 1, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Compute the standard deviations over all axes. @@ -2325,7 +2506,9 @@ public func std( public func std( _ array: MLXArray, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_std_all(array.ctx, keepDims, ddof.int32, stream.ctx)) + var result = mlx_array_new() + mlx_std_all(&result, array.ctx, keepDims, ddof.int32, stream.ctx) + return MLXArray(result) } /// Stacks the arrays along a new axis. @@ -2336,8 +2519,10 @@ public func stacked(_ arrays: [MLXArray], axis: Int = 0, stream: StreamOrDevice -> MLXArray { let vector_array = new_mlx_vector_array(arrays) - defer { mlx_free(vector_array) } - return MLXArray(mlx_stack(vector_array, axis.int32, stream.ctx)) + defer { mlx_vector_array_free(vector_array) } + var result = mlx_array_new() + mlx_stack(&result, vector_array, axis.int32, stream.ctx) + return MLXArray(result) } /// Stop gradients from being computed. @@ -2349,7 +2534,9 @@ public func stacked(_ arrays: [MLXArray], axis: Int = 0, stream: StreamOrDevice /// - array: input array /// - stream: stream or device to evaluate on public func stopGradient(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_stop_gradient(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_stop_gradient(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise subtraction. @@ -2376,7 +2563,9 @@ public func subtract( _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_subtract(a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_subtract(&result, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Take values along an axis at the specified indices. @@ -2393,7 +2582,9 @@ public func subtract( public func takeAlong( _ array: MLXArray, _ indices: MLXArray, axis: Int, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_take_along_axis(array.ctx, indices.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_take_along_axis(&result, array.ctx, indices.ctx, axis.int32, stream.ctx) + return MLXArray(result) } /// Take values along an axis at the specified indices from a flattened array. @@ -2410,7 +2601,9 @@ public func takeAlong(_ array: MLXArray, _ indices: MLXArray, stream: StreamOrDe -> MLXArray { let array = array.reshaped([-1], stream: stream) - return MLXArray(mlx_take_along_axis(array.ctx, indices.ctx, 0, stream.ctx)) + var result = mlx_array_new() + mlx_take_along_axis(&result, array.ctx, indices.ctx, 0, stream.ctx) + return MLXArray(result) } /// Element-wise tangent. @@ -2422,7 +2615,9 @@ public func takeAlong(_ array: MLXArray, _ indices: MLXArray, stream: StreamOrDe /// ### See Also /// - public func tan(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_tan(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_tan(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Element-wise hyperbolic tangent. 
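`stacked` above now frees its temporary vector with `mlx_vector_array_free`, and the flattened `takeAlong` overload reshapes the input to `[-1]` before gathering. Usage sketch (array initializers assumed):

```swift
import MLX

let rows = [MLXArray([1, 2] as [Int32]), MLXArray([3, 4] as [Int32])]
let m = stacked(rows, axis: 0)            // [[1, 2], [3, 4]]

let idx = MLXArray([1, 0] as [Int32], [2, 1])
let picked = takeAlong(m, idx, axis: 1)   // [[2], [3]]
```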
@@ -2434,7 +2629,9 @@ public func tan(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra /// ### See Also /// - public func tanh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_tanh(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_tanh(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Compute tensor dot product. @@ -2452,7 +2649,9 @@ public func tanh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr public func tensordot( _ a: MLXArray, _ b: MLXArray, axes: Int = 1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_tensordot_along_axis(a.ctx, b.ctx, axes.int32, stream.ctx)) + var result = mlx_array_new() + mlx_tensordot_along_axis(&result, a.ctx, b.ctx, axes.int32, stream.ctx) + return MLXArray(result) } /// Compute tensor dot product. @@ -2470,10 +2669,12 @@ public func tensordot( public func tensordot( _ a: MLXArray, _ b: MLXArray, axes: ((Int, Int), (Int, Int)), stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_tensordot( - a.ctx, b.ctx, [axes.0.0, axes.0.1].asInt32, 2, [axes.1.0, axes.1.1].asInt32, 2, - stream.ctx)) + var result = mlx_array_new() + mlx_tensordot( + &result, + a.ctx, b.ctx, [axes.0.0, axes.0.1].asInt32, 2, [axes.1.0, axes.1.1].asInt32, 2, + stream.ctx) + return MLXArray(result) } /// Compute tensor dot product. @@ -2491,10 +2692,12 @@ public func tensordot( public func tensordot( _ a: MLXArray, _ b: MLXArray, axes: ([Int], [Int]), stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_tensordot( - a.ctx, b.ctx, axes.0.asInt32, axes.0.count, axes.1.asInt32, axes.1.count, - stream.ctx)) + var result = mlx_array_new() + mlx_tensordot( + &result, + a.ctx, b.ctx, axes.0.asInt32, axes.0.count, axes.1.asInt32, axes.1.count, + stream.ctx) + return MLXArray(result) } /// Construct array by repeating given array the number of times given by `repetitions`. @@ -2511,7 +2714,9 @@ public func tensordot( public func tiled(_ array: MLXArray, repetitions: [Int], stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_tile(array.ctx, repetitions.asInt32, repetitions.count, stream.ctx)) + var result = mlx_array_new() + mlx_tile(&result, array.ctx, repetitions.asInt32, repetitions.count, stream.ctx) + return MLXArray(result) } /// Construct array by repeating given array the number of times given by `repetitions`. @@ -2528,7 +2733,9 @@ public func tiled(_ array: MLXArray, repetitions: [Int], stream: StreamOrDevice public func tiled(_ array: MLXArray, repetitions: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_tile(array.ctx, [repetitions.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_tile(&result, array.ctx, [repetitions.int32], 1, stream.ctx) + return MLXArray(result) } /// Returns the `k` largest elements from the input along a given axis. @@ -2547,7 +2754,9 @@ public func tiled(_ array: MLXArray, repetitions: Int, stream: StreamOrDevice = public func top(_ array: MLXArray, k: Int, axis: Int = -1, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_topk(array.ctx, k.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_topk(&result, array.ctx, k.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Returns the `k` largest elements from the flattened input along a given axis. 
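`tensordot` above offers three axis conventions (a single contraction count, a pair of axes, or two axis lists), while `tiled` and `top` are thin wrappers over `mlx_tile` and `mlx_topk`. Usage sketch (array initializers assumed):

```swift
import MLX

let a = MLXArray([1, 2, 3, 4, 5, 6] as [Float], [2, 3])
let b = MLXArray([1, 0, 0, 1, 1, 1] as [Float], [3, 2])

let product = tensordot(a, b, axes: 1)           // contract a's last axis with b's first: [2, 2]
let sameProduct = tensordot(a, b, axes: ([1], [0]))

let repeated = tiled(a, repetitions: [2, 1])     // shape [4, 3]
let twoLargest = top(a, k: 2, axis: -1)          // two largest entries per row
```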
@@ -2563,7 +2772,9 @@ public func top(_ array: MLXArray, k: Int, axis: Int = -1, stream: StreamOrDevic /// - /// - ``top(_:k:axis:stream:)`` public func top(_ array: MLXArray, k: Int, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_topk_all(array.ctx, k.int32, stream.ctx)) + var result = mlx_array_new() + mlx_topk_all(&result, array.ctx, k.int32, stream.ctx) + return MLXArray(result) } /// Return the sum along a specified diagonal in the given array. @@ -2583,10 +2794,12 @@ public func trace( _ array: MLXArray, offset: Int = 0, axis1: Int = 0, axis2: Int = 1, dtype: DType? = nil, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_trace( - array.ctx, offset.int32, axis1.int32, axis2.int32, (dtype ?? array.dtype).cmlxDtype, - stream.ctx)) + var result = mlx_array_new() + mlx_trace( + &result, + array.ctx, offset.int32, axis1.int32, axis2.int32, (dtype ?? array.dtype).cmlxDtype, + stream.ctx) + return MLXArray(result) } /// Zeros the array above the given diagonal. @@ -2599,7 +2812,9 @@ public func trace( /// ### See Also /// - ``triu(_:k:stream:)`` public func tril(_ array: MLXArray, k: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_tril(array.ctx, k.int32, stream.ctx)) + var result = mlx_array_new() + mlx_tril(&result, array.ctx, k.int32, stream.ctx) + return MLXArray(result) } /// Zeros the array below the given diagonal. @@ -2612,7 +2827,9 @@ public func tril(_ array: MLXArray, k: Int = 0, stream: StreamOrDevice = .defaul /// ### See Also /// - ``tril(_:k:stream:)`` public func triu(_ array: MLXArray, k: Int = 0, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_triu(array.ctx, k.int32, stream.ctx)) + var result = mlx_array_new() + mlx_triu(&result, array.ctx, k.int32, stream.ctx) + return MLXArray(result) } /// Select from `x` or `y` according to `condition`. @@ -2635,7 +2852,9 @@ public func `where`( _ condition: MLXArray, _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_where(condition.ctx, a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_where(&result, condition.ctx, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } /// Alias for ``where(_:_:_:stream:)`` -- select from `x` or `y` according to `condition`. @@ -2656,5 +2875,7 @@ public func which( _ condition: MLXArray, _ a: A, _ b: B, stream: StreamOrDevice = .default ) -> MLXArray { let (a, b) = toArrays(a, b) - return MLXArray(mlx_where(condition.ctx, a.ctx, b.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_where(&result, condition.ctx, a.ctx, b.ctx, stream.ctx) + return MLXArray(result) } diff --git a/Source/MLX/Stream.swift b/Source/MLX/Stream.swift index 093e7b77..3ea4986d 100644 --- a/Source/MLX/Stream.swift +++ b/Source/MLX/Stream.swift @@ -34,7 +34,7 @@ public struct StreamOrDevice: Sendable, CustomStringConvertible, Equatable { /// This will be ``Device/gpu`` unless ``Device/setDefault(device:)`` /// sets it otherwise. public static var `default`: StreamOrDevice { - StreamOrDevice(Stream()) + StreamOrDevice(Device.defaultStream()) } public static func device(_ device: Device) -> StreamOrDevice { @@ -51,11 +51,11 @@ public struct StreamOrDevice: Sendable, CustomStringConvertible, Equatable { public static let gpu = device(.gpu) public static func stream(_ stream: Stream) -> StreamOrDevice { - StreamOrDevice(stream) + StreamOrDevice(Device.defaultStream()) } /// Internal context -- used with Cmlx calls. 
- public var ctx: OpaquePointer { + public var ctx: mlx_stream { stream.ctx } @@ -82,22 +82,35 @@ public final class Stream: @unchecked Sendable, Equatable { let ctx: mlx_stream + public static let gpu = Stream(.gpu) + public static let cpu = Stream(.cpu) + init(_ ctx: mlx_stream) { self.ctx = ctx } public init() { - let dDev = mlx_default_device()! - ctx = mlx_default_stream(dDev) - mlx_free(dDev) + let device = Device.defaultDevice() + var ctx = mlx_stream_new() + mlx_get_default_stream(&ctx, device.ctx) + self.ctx = ctx } + @available(*, deprecated, message: "use init(Device) -- index not supported") public init(index: Int32, _ device: Device) { - ctx = mlx_stream_new(index, device.ctx) + var ctx = mlx_stream_new() + mlx_get_default_stream(&ctx, device.ctx) + self.ctx = ctx + } + + public init(_ device: Device) { + var ctx = mlx_stream_new() + mlx_get_default_stream(&ctx, device.ctx) + self.ctx = ctx } deinit { - mlx_free(ctx) + mlx_stream_free(ctx) } /// Synchronize with the given stream @@ -106,7 +119,11 @@ public final class Stream: @unchecked Sendable, Equatable { } static public func defaultStream(_ device: Device) -> Stream { - return Stream(mlx_default_stream(device.ctx)) + switch device.deviceType { + case .cpu: .cpu + case .gpu: .gpu + default: fatalError("Unexpected device type: \(device)") + } } public static func == (lhs: Stream, rhs: Stream) -> Bool { @@ -116,6 +133,9 @@ public final class Stream: @unchecked Sendable, Equatable { extension Stream: CustomStringConvertible { public var description: String { - mlx_describe(ctx) ?? String(describing: type(of: self)) + var s = mlx_string_new() + mlx_stream_tostring(&s, ctx) + defer { mlx_string_free(s) } + return String(cString: mlx_string_data(s), encoding: .utf8)! } } diff --git a/Source/MLX/Transforms+Compile.swift b/Source/MLX/Transforms+Compile.swift index 99395971..cef7dc5b 100644 --- a/Source/MLX/Transforms+Compile.swift +++ b/Source/MLX/Transforms+Compile.swift @@ -74,21 +74,23 @@ final class CompiledFunction: @unchecked (Sendable) { } let innerClosure = new_mlx_closure(inner(tracers:)) - defer { mlx_free(innerClosure) } + defer { mlx_closure_free(innerClosure) } // note: this will use the cached compile (via the id) // but will be able to re-evaluate with fresh state if needed - let compiled = mlx_detail_compile(innerClosure, id, shapeless, [], 0)! - defer { mlx_free(compiled) } + var compiled = mlx_closure_new() + mlx_detail_compile(&compiled, innerClosure, id, shapeless, [], 0) + defer { mlx_closure_free(compiled) } let innerInputs = arguments + stateInputs let innerInputsVector = new_mlx_vector_array(innerInputs) - defer { mlx_free(innerInputsVector) } + defer { mlx_vector_array_free(innerInputsVector) } // will compile the function (if needed) and evaluate the // compiled graph - let resultVector = mlx_closure_apply(compiled, innerInputsVector)! - defer { mlx_free(resultVector) } + var resultVector = mlx_vector_array_new() + mlx_closure_apply(&resultVector, compiled, innerInputsVector) + defer { mlx_vector_array_free(resultVector) } let resultsPlusStateOutput = mlx_vector_array_values(resultVector) diff --git a/Source/MLX/Transforms+Eval.swift b/Source/MLX/Transforms+Eval.swift index baa67bdf..9b73babb 100644 --- a/Source/MLX/Transforms+Eval.swift +++ b/Source/MLX/Transforms+Eval.swift @@ -10,7 +10,7 @@ import Foundation public func eval(_ arrays: MLXArray...) 
{ let vector_array = new_mlx_vector_array(arrays) mlx_eval(vector_array) - mlx_free(vector_array) + mlx_vector_array_free(vector_array) } /// Evaluate one or more `MLXArray` @@ -20,7 +20,7 @@ public func eval(_ arrays: MLXArray...) { public func eval(_ arrays: [MLXArray]) { let vector_array = new_mlx_vector_array(arrays) mlx_eval(vector_array) - mlx_free(vector_array) + mlx_vector_array_free(vector_array) } /// Evaluate one or more `MLXArray` asynchronously. @@ -31,7 +31,7 @@ public func eval(_ arrays: [MLXArray]) { public func asyncEval(_ arrays: [MLXArray]) { let vector_array = new_mlx_vector_array(arrays) mlx_async_eval(vector_array) - mlx_free(vector_array) + mlx_vector_array_free(vector_array) } /// Evaluate one or more `MLXArray`. diff --git a/Source/MLX/Transforms+Internal.swift b/Source/MLX/Transforms+Internal.swift index 57ca2e6f..7a05fb0d 100644 --- a/Source/MLX/Transforms+Internal.swift +++ b/Source/MLX/Transforms+Internal.swift @@ -9,26 +9,31 @@ private func valueAndGradient(apply valueAndGrad: mlx_closure_value_and_grad, ar -> ([MLXArray], [MLXArray]) { let input_vector = new_mlx_vector_array(arrays) - defer { mlx_free(input_vector) } + defer { mlx_vector_array_free(input_vector) } - let vector_pair = mlx_closure_value_and_grad_apply(valueAndGrad, input_vector)! - defer { mlx_free(vector_pair) } + var r0 = mlx_vector_array_new() + var r1 = mlx_vector_array_new() + mlx_closure_value_and_grad_apply(&r0, &r1, valueAndGrad, input_vector) - let (values, gradient) = mlx_tuple_vectors(vector_pair) - return (values, gradient) + defer { mlx_vector_array_free(r0) } + defer { mlx_vector_array_free(r1) } + + return (mlx_vector_array_values(r0), mlx_vector_array_values(r1)) } func buildGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [Int]) -> ( [MLXArray] ) -> [MLXArray] { { (arrays: [MLXArray]) in + var vag = mlx_closure_value_and_grad_new() + let closure = new_mlx_closure(f) - let valueAndGrad = mlx_value_and_grad( - closure, argumentNumbers.asInt32, argumentNumbers.count)! - defer { mlx_free(valueAndGrad) } - mlx_free(closure) + mlx_value_and_grad(&vag, closure, argumentNumbers.asInt32, argumentNumbers.count) + mlx_closure_free(closure) - return valueAndGradient(apply: valueAndGrad, arrays: arrays).1 + defer { mlx_closure_value_and_grad_free(vag) } + + return valueAndGradient(apply: vag, arrays: arrays).1 } } @@ -36,13 +41,15 @@ func buildValueAndGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNu [MLXArray] ) -> ([MLXArray], [MLXArray]) { { (arrays: [MLXArray]) in + var vag = mlx_closure_value_and_grad_new() + let closure = new_mlx_closure(f) - let valueAndGrad = mlx_value_and_grad( - closure, argumentNumbers.asInt32, argumentNumbers.count)! - defer { mlx_free(valueAndGrad) } - mlx_free(closure) + mlx_value_and_grad(&vag, closure, argumentNumbers.asInt32, argumentNumbers.count) + mlx_closure_free(closure) - return valueAndGradient(apply: valueAndGrad, arrays: arrays) + defer { mlx_closure_value_and_grad_free(vag) } + + return valueAndGradient(apply: vag, arrays: arrays) } } @@ -81,14 +88,16 @@ func buildValueAndGradient( return f(parameters, arrays) } + var vag = mlx_closure_value_and_grad_new() + let closure = new_mlx_closure(inner) - let valueAndGrad = mlx_value_and_grad( - closure, Array(Int32(0) ..< Int32(flattenedArrays.count)), - flattenedArrays.count)! 
- defer { mlx_free(valueAndGrad) } - mlx_free(closure) + mlx_value_and_grad( + &vag, closure, Array(Int32(0) ..< Int32(flattenedArrays.count)), flattenedArrays.count) + mlx_closure_free(closure) + + defer { mlx_closure_value_and_grad_free(vag) } - let (values, flatGradients) = valueAndGradient(apply: valueAndGrad, arrays: flattenedArrays) + let (values, flatGradients) = valueAndGradient(apply: vag, arrays: flattenedArrays) let gradients = unflattened(flatGradients) return (values, gradients) diff --git a/Source/MLX/Transforms.swift b/Source/MLX/Transforms.swift index 5218d877..32d4fa7c 100644 --- a/Source/MLX/Transforms.swift +++ b/Source/MLX/Transforms.swift @@ -17,19 +17,22 @@ import Foundation public func jvp( _ f: @escaping ([MLXArray]) -> [MLXArray], primals: [MLXArray], tangents: [MLXArray] ) -> ([MLXArray], [MLXArray]) { - - let closure = new_mlx_closure(f) let primals_mlx = new_mlx_vector_array(primals) - defer { mlx_free(primals_mlx) } + defer { mlx_vector_array_free(primals_mlx) } let tangents_mlx = new_mlx_vector_array(tangents) - defer { mlx_free(tangents_mlx) } + defer { mlx_vector_array_free(tangents_mlx) } - let vector_pair = mlx_jvp(closure, primals_mlx, tangents_mlx)! - defer { mlx_free(vector_pair) } + var r0 = mlx_vector_array_new() + var r1 = mlx_vector_array_new() + + let closure = new_mlx_closure(f) + mlx_jvp(&r0, &r1, closure, primals_mlx, tangents_mlx) + mlx_closure_free(closure) - mlx_free(closure) + defer { mlx_vector_array_free(r0) } + defer { mlx_vector_array_free(r1) } - return mlx_tuple_vectors(vector_pair) + return (mlx_vector_array_values(r0), mlx_vector_array_values(r1)) } /// Compute the vector-Jacobian product. @@ -46,19 +49,22 @@ public func jvp( public func vjp( _ f: @escaping ([MLXArray]) -> [MLXArray], primals: [MLXArray], cotangents: [MLXArray] ) -> ([MLXArray], [MLXArray]) { - - let closure = new_mlx_closure(f) let primals_mlx = new_mlx_vector_array(primals) - defer { mlx_free(primals_mlx) } + defer { mlx_vector_array_free(primals_mlx) } let cotangents_mlx = new_mlx_vector_array(cotangents) - defer { mlx_free(cotangents_mlx) } + defer { mlx_vector_array_free(cotangents_mlx) } - let vector_pair = mlx_vjp(closure, primals_mlx, cotangents_mlx)! - defer { mlx_free(vector_pair) } + var r0 = mlx_vector_array_new() + var r1 = mlx_vector_array_new() + + let closure = new_mlx_closure(f) + mlx_vjp(&r0, &r1, closure, primals_mlx, cotangents_mlx) + mlx_closure_free(closure) - mlx_free(closure) + defer { mlx_vector_array_free(r0) } + defer { mlx_vector_array_free(r1) } - return mlx_tuple_vectors(vector_pair) + return (mlx_vector_array_values(r0), mlx_vector_array_values(r1)) } /// Returns a function that computes the gradient and result of `f`, computing the gradient with respect to the ``NestedDictionary``. diff --git a/Source/MLXFFT/FFT.swift b/Source/MLXFFT/FFT.swift index 854f6e05..aeafa0e7 100644 --- a/Source/MLXFFT/FFT.swift +++ b/Source/MLXFFT/FFT.swift @@ -19,8 +19,10 @@ import MLX public func fft(_ array: MLXArray, n: Int? = nil, axis: Int = -1, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray( - mlx_fft_fftn(array.ctx, [(n ?? array.dim(axis)).int32], 1, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + + mlx_fft_fftn(&result, array.ctx, [(n ?? array.dim(axis)).int32], 1, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// One dimensional inverse discrete Fourier Transform. @@ -38,7 +40,9 @@ public func fft(_ array: MLXArray, n: Int? = nil, axis: Int = -1, stream: Stream public func ifft( _ array: MLXArray, n: Int? 
= nil, axis: Int = -1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_fft_ifft(array.ctx, (n ?? array.dim(axis)).int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_fft_ifft(&result, array.ctx, (n ?? array.dim(axis)).int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Two dimensional discrete Fourier Transform. @@ -92,25 +96,30 @@ public func ifft2( public func fftn( _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let s, let axes { // both supplied - return MLXArray( - mlx_fft_fft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_fft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let axes { // no n, compute from dim() let n = axes.map { array.dim($0) } - return MLXArray( - mlx_fft_fft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_fft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let s { // axes are the rightmost dimensions matching the number of dimensions of n let axes = Array(-s.count ..< 0) - return MLXArray( - mlx_fft_fft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_fft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else { let axes = Array(0 ..< array.ndim) let n = axes.map { array.dim($0) } - return MLXArray( - mlx_fft_fft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_fft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } } @@ -129,25 +138,30 @@ public func fftn( public func ifftn( _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let s, let axes { // both supplied - return MLXArray( - mlx_fft_ifft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_ifft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let axes { // no n, compute from dim() let n = axes.map { array.dim($0) } - return MLXArray( - mlx_fft_ifft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_ifft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let s { // axes are the rightmost dimensions matching the number of dimensions of n let axes = Array(-s.count ..< 0) - return MLXArray( - mlx_fft_ifft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_ifft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else { let axes = Array(0 ..< array.ndim) let n = axes.map { array.dim($0) } - return MLXArray( - mlx_fft_ifft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_ifft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } } @@ -169,8 +183,10 @@ public func ifftn( public func rfft( _ array: MLXArray, n: Int? = nil, axis: Int = -1, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_fft_rfftn(array.ctx, [(n ?? 
array.dim(axis)).int32], 1, [axis.int32], 1, stream.ctx)) + var result = mlx_array_new() + mlx_fft_rfftn( + &result, array.ctx, [(n ?? array.dim(axis)).int32], 1, [axis.int32], 1, stream.ctx) + return MLXArray(result) } /// Inverse one dimensional discrete Fourier Transform on a real input. @@ -192,7 +208,9 @@ public func irfft( _ array: MLXArray, n: Int? = nil, axis: Int = -1, stream: StreamOrDevice = .default ) -> MLXArray { let n = n ?? (array.dim(axis) - 1) * 2 - return MLXArray(mlx_fft_irfft(array.ctx, n.int32, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_fft_irfft(&result, array.ctx, n.int32, axis.int32, stream.ctx) + return MLXArray(result) } /// Two dimensional real discrete Fourier Transform. @@ -259,25 +277,30 @@ public func irfft2( public func rfftn( _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let s, let axes { // both supplied - return MLXArray( - mlx_fft_rfft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_rfft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let axes { // no n, compute from dim() let n = axes.map { array.dim($0) } - return MLXArray( - mlx_fft_rfft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_rfft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let s { // axes are the rightmost dimensions matching the number of dimensions of n let axes = Array(-s.count ..< 0) - return MLXArray( - mlx_fft_rfft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_rfft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else { let axes = Array(0 ..< array.ndim) let n = axes.map { array.dim($0) } - return MLXArray( - mlx_fft_rfft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_rfft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } } @@ -301,26 +324,31 @@ public func rfftn( public func irfftn( _ array: MLXArray, s: [Int]? = nil, axes: [Int]? 
= nil, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let s, let axes { // both supplied - return MLXArray( - mlx_fft_irfft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_irfft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let axes { // no n, compute from dim() var n = axes.map { array.dim($0) } n[n.count - 1] = (n[n.count - 1] - 1) * 2 - return MLXArray( - mlx_fft_irfft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_irfft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else if let s { // axes are the rightmost dimensions matching the number of dimensions of n let axes = Array(-s.count ..< 0) - return MLXArray( - mlx_fft_irfft2(array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_irfft2(&result, array.ctx, s.asInt32, s.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } else { let axes = Array(0 ..< array.ndim) var n = axes.map { array.dim($0) } n[n.count - 1] = (n[n.count - 1] - 1) * 2 - return MLXArray( - mlx_fft_irfft2(array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx)) + + mlx_fft_irfft2(&result, array.ctx, n.asInt32, n.count, axes.asInt32, axes.count, stream.ctx) + return MLXArray(result) } } diff --git a/Source/MLXFast/Cmlx+Util.swift b/Source/MLXFast/Cmlx+Util.swift index a36e123b..2688582a 100644 --- a/Source/MLXFast/Cmlx+Util.swift +++ b/Source/MLXFast/Cmlx+Util.swift @@ -4,71 +4,17 @@ import Cmlx import Foundation import MLX -@inline(__always) -func mlx_free(_ ptr: OpaquePointer) { - mlx_free(UnsafeMutableRawPointer(ptr)) -} - // return a +1 mlx_vector_array containing the given arrays func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { - let result = mlx_vector_array_new()! - mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) - return result + mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count) } func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { (0 ..< mlx_vector_array_size(vector_array)) .map { index in // ctx is a +1 object, the array takes ownership - let ctx = mlx_vector_array_get(vector_array, index)! + var ctx = mlx_array_new() + mlx_vector_array_get(&ctx, vector_array, index) return MLXArray(ctx) } } - -func new_mlx_vector_string(_ values: [String]) -> mlx_vector_string { - let result = mlx_vector_string_new()! - for value in values { - let mlxString = mlx_string_new(value.cString(using: .utf8))! - mlx_vector_string_add_value(result, mlxString) - mlx_free(mlxString) - } - return result -} - -func new_mlx_vector_vector_int(_ values: [[Int]]) -> mlx_vector_vector_int { - let result = mlx_vector_vector_int_new()! - for value in values { - let vector = mlx_vector_int_from_data(value.map { Int32($0) }, value.count)! - mlx_vector_vector_int_add_value(result, vector) - mlx_free(vector) - } - return result -} - -func new_mlx_vector_array_dtype(_ values: [DType]) -> mlx_vector_array_dtype { - mlx_vector_array_dtype_from_data( - values.map { $0.cmlxDtype }, - values.count - ) -} - -func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { - let a = mlx_tuple_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_get_1(tuple)! 
- return (MLXArray(a), MLXArray(b)) -} - -func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { - let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! - defer { mlx_free(a) } - let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! - defer { mlx_free(b) } - return (mlx_vector_array_values(a), mlx_vector_array_values(b)) -} - -func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { - let a = mlx_tuple_array_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_array_get_1(tuple)! - let c = mlx_tuple_array_array_array_get_2(tuple)! - return (MLXArray(a), MLXArray(b), MLXArray(c)) -} diff --git a/Source/MLXFast/MLXFast.swift b/Source/MLXFast/MLXFast.swift index 79b2cf96..feab3282 100644 --- a/Source/MLXFast/MLXFast.swift +++ b/Source/MLXFast/MLXFast.swift @@ -26,11 +26,13 @@ public func RoPE( _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, offset: Int, freqs: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) - return MLXArray( - mlx_fast_rope( - array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), - freqs?.ctx, stream.ctx)) + mlx_fast_rope( + &result, + array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), + (freqs ?? .mlxNone).ctx, stream.ctx) + return MLXArray(result) } /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` @@ -59,12 +61,14 @@ public func scaledDotProductAttention( queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() let memoryEfficientThreshold = mlx_optional_int( value: Int32(memoryEfficientThreshold ?? 0), has_value: memoryEfficientThreshold != nil) - return MLXArray( - mlx_fast_scaled_dot_product_attention( - queries.ctx, keys.ctx, values.ctx, scale, mask?.ctx, - memoryEfficientThreshold, stream.ctx)) + mlx_fast_scaled_dot_product_attention( + &result, + queries.ctx, keys.ctx, values.ctx, scale, (mask ?? .mlxNone).ctx, + memoryEfficientThreshold, stream.ctx) + return MLXArray(result) } /// Root Mean Square normalization (RMS norm). @@ -80,7 +84,9 @@ public func scaledDotProductAttention( public func rmsNorm(_ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_fast_rms_norm(x.ctx, weight.ctx, eps, stream.ctx)) + var result = mlx_array_new() + mlx_fast_rms_norm(&result, x.ctx, weight.ctx, eps, stream.ctx) + return MLXArray(result) } /// Layer normalization. @@ -99,28 +105,8 @@ public func layerNorm( _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_fast_layer_norm(x.ctx, weight?.ctx, bias?.ctx, eps, stream.ctx)) -} - -/// Quantize the matrix `w` using the provided `scales` and -/// `biases` and the `groupSize` and `bits` configuration. - -/// For details, please see -/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.affine_quantize.html) -/// -/// - Parameters: -/// - w: Matrix to be quantized -/// - scales: The scales to use per `groupSize` elements of `w` -/// - biases: The biases to use per `groupSize` elements of `w` -/// - groupSize: The size of the group in `w` that shares a scale and bias. 
-/// - bits: The number of bits occupied by each element in `w`. -/// - stream: stream or device to evaluate on -/// - Returns: quantized version of `w` -public func affineQuantized( - _ w: MLXArray, scales: MLXArray, biases: MLXArray, groupSize: Int = 64, bits: Int = 4, - stream: StreamOrDevice = .default -) -> MLXArray { - MLXArray( - mlx_fast_affine_quantize( - w.ctx, scales.ctx, biases.ctx, Int32(groupSize), Int32(bits), stream.ctx)) + var result = mlx_array_new() + mlx_fast_layer_norm( + &result, x.ctx, (weight ?? .mlxNone).ctx, (bias ?? .mlxNone).ctx, eps, stream.ctx) + return MLXArray(result) } diff --git a/Source/MLXFast/MLXFastKernel.swift b/Source/MLXFast/MLXFastKernel.swift index 8425a40d..18587512 100644 --- a/Source/MLXFast/MLXFastKernel.swift +++ b/Source/MLXFast/MLXFastKernel.swift @@ -17,43 +17,10 @@ extension Bool: KernelTemplateArg {} extension Int: KernelTemplateArg {} extension DType: KernelTemplateArg {} -/// Add a ``KernelTemplateArg`` to the tuple of vectors -private func add( - name: String, - value: any KernelTemplateArg, - to vector: mlx_vector_tuple_string_variant_int_bool_array_dtype -) { - let name = mlx_string_new(name.cString(using: .utf8))! - defer { mlx_free(name) } - - let value = - switch value { - case let value as Bool: - mlx_variant_int_bool_array_dtype_new_with_bool(value)! - - case let value as Int: - mlx_variant_int_bool_array_dtype_new_with_int(Int32(value))! - - case let value as DType: - mlx_variant_int_bool_array_dtype_new_with_array_dtype(value.cmlxDtype)! - - default: - fatalError("Unable to handle KernelTemplateArg with type: \(type(of: value)).") - } - - defer { mlx_free(value) } - - let tuple = mlx_tuple_string_variant_int_bool_array_dtype_new(name, value)! - defer { mlx_free(tuple) } - - mlx_vector_tuple_string_variant_int_bool_array_dtype_add_value(vector, tuple) -} - /// Container for a kernel created by -/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:)``. +/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)`` /// -/// The ``callAsFunction(inputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:stream:)`` -/// can be used to evaluate the kernel with inputs: +/// The ``callAsFunction(_:stream:)`` can be used to evaluate the kernel with inputs: /// /// ```swift /// let a = normal([2, 2]) @@ -64,47 +31,85 @@ private func add( /// source: """ /// uint elem = thread_position_in_grid.x; /// out1[elem] = a[elem]; -/// """) -/// -/// let out = kernel( -/// inputs: [a], +/// """, /// grid: (4, 1, 1), /// threadGroup: (2, 1, 1), /// outputShapes: [[2, 2]], /// outputDTypes: [.float32]) +/// +/// let out = kernel([a]) /// ``` open class MLXFastKernel { - let kernel: mlx_closure_metal_kernel_function + let kernel: mlx_fast_metal_kernel public let outputNames: [String] init( name: String, inputNames: [String], outputNames: [String], source: String, header: String = "", ensureRowContiguous: Bool = true, - atomicOutputs: Bool = false + atomicOutputs: Bool = false, + template: [(String, KernelTemplateArg)]? = nil, + grid: (Int, Int, Int), + threadGroup: (Int, Int, Int), + outputShapes: [[Int]], + outputDTypes: [DType], + initValue: Float? = nil, + verbose: Bool = false ) { self.outputNames = outputNames - let mlxName = mlx_string_new(name.cString(using: .utf8))! 
- defer { mlx_free(mlxName) } + self.kernel = mlx_fast_metal_kernel_new( + name.cString(using: .utf8), + source.cString(using: .utf8), + header.cString(using: .utf8)) + + for name in inputNames { + mlx_fast_metal_kernel_add_input_name(kernel, name) + } + for name in outputNames { + mlx_fast_metal_kernel_add_output_name(kernel, name) + } + + mlx_fast_metal_kernel_set_contiguous_rows(kernel, ensureRowContiguous) + mlx_fast_metal_kernel_set_atomic_outputs(kernel, atomicOutputs) + + if let template { + for (name, arg) in template { + switch arg { + case let value as Bool: + mlx_fast_metal_kernel_add_template_arg_bool(kernel, name, value) + + case let value as Int: + mlx_fast_metal_kernel_add_template_arg_int(kernel, name, Int32(value)) - let mlxInputNames = new_mlx_vector_string(inputNames) - defer { mlx_free(mlxInputNames) } - let mlxOutputNames = new_mlx_vector_string(outputNames) - defer { mlx_free(mlxOutputNames) } + case let value as DType: + mlx_fast_metal_kernel_add_template_arg_dtype(kernel, name, value.cmlxDtype) - let mlxSource = mlx_string_new(source.cString(using: .utf8))! - defer { mlx_free(mlxSource) } - let mlxHeader = mlx_string_new(header.cString(using: .utf8))! - defer { mlx_free(mlxHeader) } + default: + fatalError( + "Unable to handle KernelTemplateArg \(name) with type: \(type(of: arg)).") + } + } + } + + mlx_fast_metal_kernel_set_grid(kernel, Int32(grid.0), Int32(grid.1), Int32(grid.2)) + mlx_fast_metal_kernel_set_thread_group( + kernel, Int32(threadGroup.0), Int32(threadGroup.1), Int32(threadGroup.2)) - self.kernel = mlx_fast_metal_kernel( - mlxName, mlxInputNames, mlxOutputNames, mlxSource, mlxHeader, ensureRowContiguous, - atomicOutputs) + for (shape, dtype) in zip(outputShapes, outputDTypes) { + mlx_fast_metal_kernel_add_output_arg( + kernel, shape.map { Int32($0) }, shape.count, dtype.cmlxDtype) + } + + if let initValue { + mlx_fast_metal_kernel_set_init_value(kernel, initValue) + } + + mlx_fast_metal_kernel_set_verbose(kernel, verbose) } deinit { - mlx_free(kernel) + mlx_fast_metal_kernel_free(kernel) } /// Call the prepared metal kernel. @@ -125,55 +130,15 @@ open class MLXFastKernel { /// - stream: stream to run on /// - Returns: array of `MLXArray` public func callAsFunction( - inputs: [ScalarOrArray], - template: [(String, KernelTemplateArg)]? = nil, - grid: (Int, Int, Int), - threadGroup: (Int, Int, Int), - outputShapes: [[Int]], - outputDTypes: [DType], - initValue: Float? = nil, - verbose: Bool = false, + _ inputs: [ScalarOrArray], stream: StreamOrDevice = .default ) -> [MLXArray] { - // convert all the inputs into the mlx-c types let inputs = new_mlx_vector_array(inputs.map { $0.asMLXArray(dtype: nil) }) - defer { mlx_free(inputs) } - - let outputShapes = new_mlx_vector_vector_int(outputShapes) - defer { mlx_free(outputShapes) } - - let outputDTypes = new_mlx_vector_array_dtype(outputDTypes) - defer { mlx_free(outputDTypes) } - - let grid = mlx_tuple_int_int_int_new(Int32(grid.0), Int32(grid.1), Int32(grid.2))! - defer { mlx_free(grid) } - - let threadGroup = mlx_tuple_int_int_int_new( - Int32(threadGroup.0), Int32(threadGroup.1), Int32(threadGroup.2))! - defer { mlx_free(threadGroup) } - - let templateVector = mlx_vector_tuple_string_variant_int_bool_array_dtype_new()! - defer { mlx_free(templateVector) } - if let template { - for (name, value) in template { - add(name: name, value: value, to: templateVector) - } - } + defer { mlx_vector_array_free(inputs) } - let initValue = mlx_optional_float(value: initValue ?? 
0, has_value: initValue != nil) - - let result = mlx_closure_metal_kernel_function_apply( - kernel, - inputs, - outputShapes, - outputDTypes, - grid, - threadGroup, - templateVector, - initValue, - verbose, - stream.ctx)! - defer { mlx_free(result) } + var result = mlx_vector_array_new() + mlx_fast_metal_kernel_apply(&result, kernel, inputs, stream.ctx) + defer { mlx_vector_array_free(result) } return mlx_vector_array_values(result) } @@ -198,10 +163,21 @@ public func metalKernel( name: String, inputNames: [String], outputNames: [String], source: String, header: String = "", ensureRowContiguous: Bool = true, - atomicOutputs: Bool = false + atomicOutputs: Bool = false, + template: [(String, KernelTemplateArg)]? = nil, + grid: (Int, Int, Int), + threadGroup: (Int, Int, Int), + outputShapes: [[Int]], + outputDTypes: [DType], + initValue: Float? = nil, + verbose: Bool = false ) -> MLXFastKernel { MLXFastKernel( name: name, inputNames: inputNames, outputNames: outputNames, source: source, header: header, - ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs) + ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs, + template: template, grid: grid, threadGroup: threadGroup, + outputShapes: outputShapes, outputDTypes: outputDTypes, + initValue: initValue, verbose: verbose + ) } diff --git a/Source/MLXLinalg/Cmlx+Util.swift b/Source/MLXLinalg/Cmlx+Util.swift index 06d5024d..2688582a 100644 --- a/Source/MLXLinalg/Cmlx+Util.swift +++ b/Source/MLXLinalg/Cmlx+Util.swift @@ -4,44 +4,17 @@ import Cmlx import Foundation import MLX -@inline(__always) -func mlx_free(_ ptr: OpaquePointer) { - mlx_free(UnsafeMutableRawPointer(ptr)) -} - // return a +1 mlx_vector_array containing the given arrays func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { - let result = mlx_vector_array_new()! - mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) - return result + mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count) } func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { (0 ..< mlx_vector_array_size(vector_array)) .map { index in // ctx is a +1 object, the array takes ownership - let ctx = mlx_vector_array_get(vector_array, index)! + var ctx = mlx_array_new() + mlx_vector_array_get(&ctx, vector_array, index) return MLXArray(ctx) } } - -func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { - let a = mlx_tuple_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_get_1(tuple)! - return (MLXArray(a), MLXArray(b)) -} - -func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { - let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! - defer { mlx_free(a) } - let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! - defer { mlx_free(b) } - return (mlx_vector_array_values(a), mlx_vector_array_values(b)) -} - -func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { - let a = mlx_tuple_array_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_array_get_1(tuple)! - let c = mlx_tuple_array_array_array_get_2(tuple)! - return (MLXArray(a), MLXArray(b), MLXArray(c)) -} diff --git a/Source/MLXLinalg/Linalg.swift b/Source/MLXLinalg/Linalg.swift index 0f9dd58f..93fcf849 100644 --- a/Source/MLXLinalg/Linalg.swift +++ b/Source/MLXLinalg/Linalg.swift @@ -60,14 +60,14 @@ public func norm( _ array: MLXArray, ord: NormKind? 
= nil, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let ord { - let ord_str = mlx_string_new(ord.rawValue.cString(using: .utf8))! - defer { mlx_free(ord_str) } - return MLXArray( - mlx_linalg_norm_ord(array.ctx, ord_str, axes.asInt32, axes.count, keepDims, stream.ctx)) + mlx_linalg_norm_ord( + &result, array.ctx, ord.rawValue, axes.asInt32, axes.count, keepDims, stream.ctx) } else { - return MLXArray(mlx_linalg_norm(array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx)) + mlx_linalg_norm(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) } + return MLXArray(result) } /// Matrix or vector norm. @@ -115,7 +115,9 @@ public func norm( _ array: MLXArray, ord: Double, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_linalg_norm_p(array.ctx, ord, axes.asInt32, axes.count, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_norm_p(&result, array.ctx, ord, axes.asInt32, axes.count, keepDims, stream.ctx) + return MLXArray(result) } /// Matrix or vector norm. @@ -125,13 +127,14 @@ public func norm( _ array: MLXArray, ord: NormKind? = nil, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let ord { - let ord_str = mlx_string_new(ord.rawValue.cString(using: .utf8))! - defer { mlx_free(ord_str) } - return MLXArray( - mlx_linalg_norm_ord(array.ctx, ord_str, [axis].asInt32, 1, keepDims, stream.ctx)) + mlx_linalg_norm_ord( + &result, array.ctx, ord.rawValue, [axis].asInt32, 1, keepDims, stream.ctx) + return MLXArray(result) } else { - return MLXArray(mlx_linalg_norm(array.ctx, [axis].asInt32, 1, keepDims, stream.ctx)) + mlx_linalg_norm(&result, array.ctx, [axis].asInt32, 1, keepDims, stream.ctx) + return MLXArray(result) } } @@ -142,7 +145,9 @@ public func norm( _ array: MLXArray, ord: Double, axis: Int, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_linalg_norm_p(array.ctx, ord, [axis].asInt32, 1, keepDims, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_norm_p(&result, array.ctx, ord, [axis].asInt32, 1, keepDims, stream.ctx) + return MLXArray(result) } /// Matrix or vector norm. @@ -152,15 +157,16 @@ public func norm( _ array: MLXArray, ord: NormKind? = nil, axis: IntOrArray? = nil, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { + var result = mlx_array_new() if let ord { - let ord_str = mlx_string_new(ord.rawValue.cString(using: .utf8))! - defer { mlx_free(ord_str) } - return MLXArray( - mlx_linalg_norm_ord( - array.ctx, ord_str, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx)) + mlx_linalg_norm_ord( + &result, array.ctx, ord.rawValue, axis?.asInt32Array, axis?.count ?? 0, keepDims, + stream.ctx) + return MLXArray(result) } else { - return MLXArray( - mlx_linalg_norm(array.ctx, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx)) + mlx_linalg_norm( + &result, array.ctx, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx) + return MLXArray(result) } } @@ -171,9 +177,11 @@ public func norm( _ array: MLXArray, ord: Double, axis: IntOrArray? = nil, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray( - mlx_linalg_norm_p( - array.ctx, ord, axis?.asInt32Array, axis?.count ?? 0, keepDims, stream.ctx)) + var result = mlx_array_new() + + mlx_linalg_norm_p( + &result, array.ctx, ord, axis?.asInt32Array, axis?.count ?? 
0, keepDims, stream.ctx) + return MLXArray(result) } /// The QR factorization of the input matrix. @@ -184,10 +192,12 @@ public func norm( /// /// - Returns: the `Q` and `R` matrices public func qr(_ array: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) { - let result_tuple = mlx_linalg_qr(array.ctx, stream.ctx)! - defer { mlx_free(result_tuple) } + var r0 = mlx_array_new() + var r1 = mlx_array_new() + + mlx_linalg_qr(&r0, &r1, array.ctx, stream.ctx) - return mlx_tuple_values(result_tuple) + return (MLXArray(r0), MLXArray(r1)) } /// The Singular Value Decomposition (SVD) of the input matrix. @@ -203,10 +213,11 @@ public func qr(_ array: MLXArray, stream: StreamOrDevice = .default) -> (MLXArra public func svd(_ array: MLXArray, stream: StreamOrDevice = .default) -> ( MLXArray, MLXArray, MLXArray ) { - let mlx_arrays = mlx_linalg_svd(array.ctx, stream.ctx)! - defer { mlx_free(mlx_arrays) } + var vec = mlx_vector_array_new() + mlx_linalg_svd(&vec, array.ctx, stream.ctx) + defer { mlx_vector_array_free(vec) } - let arrays = mlx_vector_array_values(mlx_arrays) + let arrays = mlx_vector_array_values(vec) return (arrays[0], arrays[1], arrays[2]) } @@ -221,7 +232,9 @@ public func svd(_ array: MLXArray, stream: StreamOrDevice = .default) -> ( /// - stream: stream or device to evaluate on /// - Returns: `ainv` such that `dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])` public func inv(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_linalg_inv(array.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_inv(&result, array.ctx, stream.ctx) + return MLXArray(result) } /// Compute the inverse of a triangular square matrix. @@ -239,7 +252,9 @@ public func triInv( _ array: MLXArray, upper: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { - MLXArray(mlx_linalg_tri_inv(array.ctx, upper, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_tri_inv(&result, array.ctx, upper, stream.ctx) + return MLXArray(result) } /// Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. @@ -258,7 +273,9 @@ public func triInv( public func cholesky(_ array: MLXArray, upper: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_linalg_cholesky(array.ctx, upper, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_cholesky(&result, array.ctx, upper, stream.ctx) + return MLXArray(result) } /// Compute the inverse of a real symmetric positive semi-definite matrix using it's Cholesky decomposition. @@ -277,7 +294,9 @@ public func cholesky(_ array: MLXArray, upper: Bool = false, stream: StreamOrDev public func choleskyInv(_ array: MLXArray, upper: Bool = false, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_linalg_cholesky_inv(array.ctx, upper, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_cholesky_inv(&result, array.ctx, upper, stream.ctx) + return MLXArray(result) } /// Compute the cross product of two arrays along a specified axis. 
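// The MLXLinalg hunks above all share one migration pattern: mlx-c v0.1.0 functions no
// longer return a fresh handle, they fill an out-parameter, and the Swift wrapper hands
// ownership straight to MLXArray. A minimal sketch of that wrapper shape, assuming a
// hypothetical C op `mlx_linalg_some_op` (placeholder name, not a real mlx-c symbol):

import Cmlx
import MLX

func someOp(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray {
    var result = mlx_array_new()                        // fresh +1 handle for the output
    mlx_linalg_some_op(&result, array.ctx, stream.ctx)  // hypothetical C call writes into `result`
    return MLXArray(result)                             // MLXArray takes ownership; no explicit free
}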
@@ -294,5 +313,7 @@ public func choleskyInv(_ array: MLXArray, upper: Bool = false, stream: StreamOr public func cross(_ a: MLXArray, _ b: MLXArray, axis: Int = -1, stream: StreamOrDevice = .default) -> MLXArray { - MLXArray(mlx_linalg_cross(a.ctx, b.ctx, axis.int32, stream.ctx)) + var result = mlx_array_new() + mlx_linalg_cross(&result, a.ctx, b.ctx, axis.int32, stream.ctx) + return MLXArray(result) } diff --git a/Source/MLXRandom/Cmlx+Util.swift b/Source/MLXRandom/Cmlx+Util.swift index 06d5024d..efc391ba 100644 --- a/Source/MLXRandom/Cmlx+Util.swift +++ b/Source/MLXRandom/Cmlx+Util.swift @@ -4,44 +4,16 @@ import Cmlx import Foundation import MLX -@inline(__always) -func mlx_free(_ ptr: OpaquePointer) { - mlx_free(UnsafeMutableRawPointer(ptr)) -} - -// return a +1 mlx_vector_array containing the given arrays func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { - let result = mlx_vector_array_new()! - mlx_vector_array_add_data(result, arrays.map { $0.ctx }, arrays.count) - return result + mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count) } func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] { (0 ..< mlx_vector_array_size(vector_array)) .map { index in // ctx is a +1 object, the array takes ownership - let ctx = mlx_vector_array_get(vector_array, index)! + var ctx = mlx_array_new() + mlx_vector_array_get(&ctx, vector_array, index) return MLXArray(ctx) } } - -func mlx_tuple_values(_ tuple: mlx_tuple_array_array) -> (MLXArray, MLXArray) { - let a = mlx_tuple_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_get_1(tuple)! - return (MLXArray(a), MLXArray(b)) -} - -func mlx_tuple_vectors(_ tuple: mlx_tuple_vector_array_vector_array) -> ([MLXArray], [MLXArray]) { - let a = mlx_tuple_vector_array_vector_array_get_0(tuple)! - defer { mlx_free(a) } - let b = mlx_tuple_vector_array_vector_array_get_1(tuple)! - defer { mlx_free(b) } - return (mlx_vector_array_values(a), mlx_vector_array_values(b)) -} - -func mlx_tuple_values(_ tuple: mlx_tuple_array_array_array) -> (MLXArray, MLXArray, MLXArray) { - let a = mlx_tuple_array_array_array_get_0(tuple)! - let b = mlx_tuple_array_array_array_get_1(tuple)! - let c = mlx_tuple_array_array_array_get_2(tuple)! - return (MLXArray(a), MLXArray(b), MLXArray(c)) -} diff --git a/Source/MLXRandom/Random.swift b/Source/MLXRandom/Random.swift index a49c782b..08807855 100644 --- a/Source/MLXRandom/Random.swift +++ b/Source/MLXRandom/Random.swift @@ -20,7 +20,9 @@ public func seed(_ seed: UInt64) { /// functions take an optional key -- this will let you control the /// random number generation. public func key(_ seed: UInt64) -> MLXArray { - MLXArray(mlx_random_key(seed)) + var result = mlx_array_new() + mlx_random_key(&result, seed) + return MLXArray(result) } /// Split a PRNG key into sub keys. @@ -28,8 +30,10 @@ public func key(_ seed: UInt64) -> MLXArray { /// ### See Also /// - ``split(key:stream:)`` public func split(key: MLXArray, into num: Int, stream: StreamOrDevice = .default) -> [MLXArray] { - let keys = MLXArray(mlx_random_split_equal_parts(key.ctx, num.int32, stream.ctx)) - return keys.map { $0 } + var keys = mlx_array_new() + mlx_random_split_equal_parts(&keys, key.ctx, num.int32, stream.ctx) + + return MLXArray(keys).map { $0 } } /// Split a PRNG key into two keys and return a tuple. 
@@ -37,9 +41,10 @@ public func split(key: MLXArray, into num: Int, stream: StreamOrDevice = .defaul /// ### See Also /// - ``split(key:into:stream:)`` public func split(key: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray, MLXArray) { - let keys = mlx_random_split(key.ctx, stream.ctx)! - defer { mlx_free(keys) } - return mlx_tuple_values(keys) + var r0 = mlx_array_new() + var r1 = mlx_array_new() + mlx_random_split(&r0, &r1, key.ctx, stream.ctx) + return (MLXArray(r0), MLXArray(r1)) } /// Generate uniformly distributed random numbers with a `RangeExpression`. @@ -63,9 +68,12 @@ public func uniform( let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) let key = key ?? globalState.next() - return MLXArray( - mlx_random_uniform( - lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_uniform( + &result, lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate uniformly distributed random numbers with a `RangeExpression` (specialization). @@ -83,9 +91,12 @@ public func uniform( let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) let key = key ?? globalState.next() - return MLXArray( - mlx_random_uniform( - lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_uniform( + &result, lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate uniformly distributed random numbers between `low` and `high`. @@ -109,9 +120,13 @@ public func uniform( let (low, high) = toArrays(low, high) let shape = shape ?? low.shape let key = key ?? globalState.next() - return MLXArray( - mlx_random_uniform( - low.ctx, high.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_uniform( + &result, low.ctx, high.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, + stream.ctx) + + return MLXArray(result) } /// Generate uniformly distributed random numbers between `low` and `high` with a given `DType`. @@ -135,9 +150,13 @@ public func uniform( let (low, high) = toArrays(low, high) let shape = shape ?? low.shape let key = key ?? globalState.next() - return MLXArray( - mlx_random_uniform( - low.ctx, high.ctx, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_uniform( + &result, low.ctx, high.ctx, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx + ) + + return MLXArray(result) } /// Generate normally distributed random numbers. @@ -167,9 +186,12 @@ public func normal( stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let key = key ?? globalState.next() - return MLXArray( - mlx_random_normal( - shape.asInt32, shape.count, T.dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_normal( + &result, shape.asInt32, shape.count, T.dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate normally distributed random numbers. @@ -199,9 +221,12 @@ public func normal( stream: StreamOrDevice = .default ) -> MLXArray { let key = key ?? 
globalState.next() - return MLXArray( - mlx_random_normal( - shape.asInt32, shape.count, dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_normal( + &result, shape.asInt32, shape.count, dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate jointly-normal random samples given a mean and covariance. @@ -225,10 +250,13 @@ public func multivariateNormal( key: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let key = key ?? globalState.next() - return MLXArray( - mlx_random_multivariate_normal( - mean.ctx, covariance.ctx, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, - stream.ctx)) + var result = mlx_array_new() + + mlx_random_multivariate_normal( + &result, mean.ctx, covariance.ctx, shape.asInt32, shape.count, + dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate random integers from the given interval using a `RangeExpression`. @@ -252,9 +280,12 @@ public func randInt( let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) let key = key ?? globalState.next() - return MLXArray( - mlx_random_randint( - lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_randint( + &result, lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate random integers from the given interval (`low:` and `high:`). @@ -277,10 +308,13 @@ public func randInt( let (low, high) = toArrays(low, high) let shape = shape ?? low.shape let key = key ?? globalState.next() - return MLXArray( - mlx_random_randint( - low.ctx, high.ctx, shape.asInt32, shape.count, low.dtype.cmlxDtype, key.ctx, stream.ctx - )) + var result = mlx_array_new() + + mlx_random_randint( + &result, low.ctx, high.ctx, shape.asInt32, shape.count, low.dtype.cmlxDtype, key.ctx, + stream.ctx + ) + return MLXArray(result) } /// Generate random integers from the given interval (`low:` and `high:`) with a given type, e.g. `Int8.self`. @@ -304,9 +338,13 @@ public func randInt( let (low, high) = toArrays(low, high) let shape = shape ?? low.shape let key = key ?? globalState.next() - return MLXArray( - mlx_random_randint( - low.ctx, high.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_randint( + &result, low.ctx, high.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, + stream.ctx) + + return MLXArray(result) } /// Generate Bernoulli random values with a `p` value of 0.5. @@ -328,7 +366,10 @@ public func bernoulli(_ shape: [Int] = [], key: MLXArray? = nil, stream: StreamO { let p = MLXArray(0.5) let key = key ?? globalState.next() - return MLXArray(mlx_random_bernoulli(p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_random_bernoulli(&result, p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate Bernoulli random values with a given `p` value. @@ -356,7 +397,10 @@ public func bernoulli( let p = p.asMLXArray(dtype: .float32) let shape = shape ?? p.shape let key = key ?? globalState.next() - return MLXArray(mlx_random_bernoulli(p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_random_bernoulli(&result, p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate values from a truncated normal distribution. 
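// Usage sketch for the Bernoulli helpers above, using the signatures shown in these
// hunks: the shape-only form draws fair coin flips, and passing an explicit key makes
// the draw reproducible.

import MLX
import MLXRandom

let flips = MLXRandom.bernoulli([8])                            // eight draws with p = 0.5
let keyed = MLXRandom.bernoulli([2, 2], key: MLXRandom.key(7))  // fixed key, repeatable result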
@@ -384,9 +428,12 @@ public func truncatedNormal( let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) let key = key ?? globalState.next() - return MLXArray( - mlx_random_truncated_normal( - lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_truncated_normal( + &result, lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate values from a truncated normal distribution in a given `RangeExpression`. @@ -404,9 +451,12 @@ public func truncatedNormal( let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) let key = key ?? globalState.next() - return MLXArray( - mlx_random_truncated_normal( - lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_truncated_normal( + &result, lb.ctx, ub.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Generate values from a truncated normal distribution between `low` and `high`. @@ -429,9 +479,13 @@ public func truncatedNormal( let (low, high) = toArrays(low, high) let shape = shape ?? low.shape let key = key ?? globalState.next() - return MLXArray( - mlx_random_truncated_normal( - low.ctx, high.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_truncated_normal( + &result, low.ctx, high.ctx, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, + stream.ctx) + + return MLXArray(result) } /// Generate values from a truncated normal distribution between `low` and `high` with a given `DType`. @@ -454,9 +508,13 @@ public func truncatedNormal( let (low, high) = toArrays(low, high) let shape = shape ?? low.shape let key = key ?? globalState.next() - return MLXArray( - mlx_random_truncated_normal( - low.ctx, high.ctx, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_truncated_normal( + &result, low.ctx, high.ctx, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx + ) + + return MLXArray(result) } /// Sample from the standard Gumbel distribution. @@ -478,8 +536,11 @@ public func gumbel( stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let key = key ?? globalState.next() - return MLXArray( - mlx_random_gumbel(shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_gumbel(&result, shape.asInt32, shape.count, T.dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Sample from the standard Gumbel distribution with a given `DType`. @@ -501,8 +562,11 @@ public func gumbel( stream: StreamOrDevice = .default ) -> MLXArray { let key = key ?? globalState.next() - return MLXArray( - mlx_random_gumbel(shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_gumbel(&result, shape.asInt32, shape.count, dtype.cmlxDtype, key.ctx, stream.ctx) + + return MLXArray(result) } /// Sample from a categorical distribution. @@ -529,11 +593,17 @@ public func categorical( ) -> MLXArray { let key = key ?? 
globalState.next() if let shape { - return MLXArray( - mlx_random_categorical_shape( - logits.ctx, axis.int32, shape.asInt32, shape.count, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_categorical_shape( + &result, logits.ctx, axis.int32, shape.asInt32, shape.count, key.ctx, stream.ctx) + + return MLXArray(result) } else { - return MLXArray(mlx_random_categorical(logits.ctx, axis.int32, key.ctx, stream.ctx)) + var result = mlx_array_new() + mlx_random_categorical(&result, logits.ctx, axis.int32, key.ctx, stream.ctx) + + return MLXArray(result) } } @@ -558,9 +628,12 @@ public func categorical( stream: StreamOrDevice = .default ) -> MLXArray { let key = key ?? globalState.next() - return MLXArray( - mlx_random_categorical_num_samples( - logits.ctx, axis.int32, count.int32, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_categorical_num_samples( + &result, logits.ctx, axis.int32, count.int32, key.ctx, stream.ctx) + + return MLXArray(result) } /// Sample numbers from a Laplace distribution. @@ -575,7 +648,10 @@ public func laplace( key: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let key = key ?? globalState.next() - return MLXArray( - mlx_random_laplace( - shape.asInt32, shape.count, dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx)) + var result = mlx_array_new() + + mlx_random_laplace( + &result, shape.asInt32, shape.count, dtype.cmlxDtype, loc, scale, key.ctx, stream.ctx) + + return MLXArray(result) } diff --git a/Source/Tools/GenerateGrad.swift b/Source/Tools/GenerateGrad.swift deleted file mode 100644 index 7c925a8c..00000000 --- a/Source/Tools/GenerateGrad.swift +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation - -@main -struct GenerateGrad { - - /// up to how many MLXArray tuples should we generate, e.g. 3 == `MLXArray, MLXArray, MLXArray` - static let inputTupleCount = 1 - static let outputTupleCount = 1 - - static func indentLines(_ text: String, lead: String) -> String { - lead - + text - .split(separator: "\n", omittingEmptySubsequences: false) - .joined(separator: "\n\(lead)") - } - - struct MethodInfo { - let methodName: String - let methodDescription: String - let internalDocumentation: String - let seeAlso: String - let arguments: (String, String) -> String - let returnValue: (String, String) -> String - let body: (String) -> String - } - - static let methodInfo = [ - "grad": MethodInfo( - methodName: "grad", - methodDescription: "Returns a function which computes the gradient of `f`.", - internalDocumentation: - """ - Converts the given function `f()` into canonical types, e.g. - (MLXArray) -> MLXArray into the canonical form ([MLXArray]) -> [MLXArray]. - - First use the wrapArguments() and wrapResult() function to transform - it into that form. Then call buildValueAndGradient() to produce a new - function with the same canonical form. - - Finally use unwrapArguments() and unwrapResult() to transform the function - back into the original signature. - - Note: this particular form of the function is already in the canonical - form and the wrap/unwrap calls are identity functions. 
- """, - seeAlso: "See ``grad(_:)-r8dv``", - arguments: { input, returnValue in - if input == "MLXArray" { - return "(_ f: @escaping (\(input)) -> \(returnValue))" - } else { - return - "(_ f: @escaping (\(input)) -> \(returnValue), argumentNumbers: [Int] = [0])" - } - }, - returnValue: { input, returnValue in - "(\(input)) -> \(returnValue)" - }, - body: { input in - let argumentNumbersUse: String - if input == "MLXArray" { - argumentNumbersUse = "[0]" - } else { - argumentNumbersUse = "argumentNumbers" - } - return - """ - let wrappedFunction = wrapResult(wrapArguments(f)) - let gradientFunction = buildGradient(wrappedFunction, argumentNumbers: \(argumentNumbersUse)) - let uag: (\(input)) -> [MLXArray] = unwrapArguments(gradientFunction) - return unwrapResult(uag) - """ - } - ), - "valueAndGrad": MethodInfo( - methodName: "valueAndGrad", - methodDescription: "Returns a function which computes the value and gradient of `f`.", - internalDocumentation: "", - seeAlso: "See ``valueAndGrad(_:)``", - arguments: { input, returnValue in - "(_ f: @escaping (\(input)) -> \(returnValue), argumentNumbers: [Int] = [0])" - }, - returnValue: { input, returnValue in - "(\(input)) -> (\(returnValue), \(returnValue))" - }, - body: { input in - """ - return buildValueAndGradient(f, argumentNumbers: argumentNumbers) - """ - } - ), - ] - - static func emitFunction(name: String, input: String, output: String) -> String { - var result = "" - - let info = methodInfo[name]! - - let firstMethod = input == "[MLXArray]" && output == "[MLXArray]" - let documentationText = firstMethod ? info.methodDescription : info.seeAlso - - result += indentLines(documentationText, lead: "// ") - result += "\n" - - let returnValue: String - if output.contains(",") { - returnValue = "(\(output))" - } else { - returnValue = output - } - - result += - """ - public func \(name)\(info.arguments(input, returnValue)) -> \(info.returnValue(input, returnValue)) { - - """ - - if firstMethod { - result += indentLines(info.internalDocumentation, lead: " // ") - result += "\n" - } - - result += indentLines(info.body(input), lead: " ") - result += "\n}\n" - - return result - } - - /// Tool to generate `Transforms+Variants.swift`. - /// - /// Either: - /// - run this and paste the output into `Transforms+Grad.swift` - /// - or `swift run GenerateGrad > Sources/MLX/Transforms+Grad.swift` - static func main() { - print( - """ - import Foundation - import Cmlx - - // This file is generated by GenerateGrad. - - """ - ) - - // emit the `grad()` variants -- these are the public functions that can be called. - // we emit a variant for each combination of inputs and outputs below - - let baseTypes = [ - "[MLXArray]", - "MLXArray", - ] - var inputs = baseTypes - var outputs = baseTypes - - for i in 2 ..< (inputTupleCount + 1) { - inputs.append(Array(repeating: "MLXArray", count: i).joined(separator: ", ")) - } - for i in 2 ..< (outputTupleCount + 1) { - outputs.append(Array(repeating: "MLXArray", count: i).joined(separator: ", ")) - } - - for input in inputs { - for output in outputs { - print(emitFunction(name: "grad", input: input, output: output)) - } - } - - print(emitFunction(name: "valueAndGrad", input: "[MLXArray]", output: "[MLXArray]")) - - // functions for converting to and from canonical types. For example this function: - // - // public func grad(_ f: @escaping (MLXArray, MLXArray) -> [MLXArray]) -> (MLXArray, MLXArray) -> [MLXArray] { - // - // takes a (MLXArray, MLXArray) -> [MLXArray]. 
We need to convert that to ([MLXArray]) -> [MLXArray] - // and can compute the gradient: - // - // let gradientFunction = buildValueAndGradient(wrapResult(wrapArguments(f))) - // - // the result is ([MLXArray]) -> [MLXArray] and we need to convert that back to - // (MLXArray, MLXArray) -> [MLXArray]: - // - // let uag: (MLXArray, MLXArray) -> [MLXArray] = unwrapArguments(gradientFunction) - // return unwrapResult(uag) - // - // These are all the different wrap/unwrap functions. - - // these are the special cases (NOP and single element tuple) - print( - """ - - // MARK: - Functions to wrap and unwrap types in closures - - @inline(__always) - private func wrapArguments(_ f: @escaping ([MLXArray]) -> Result) -> ([MLXArray]) -> Result { - f - } - - @inline(__always) - private func wrapResult(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> ([MLXArray]) -> [MLXArray] { - f - } - - @inline(__always) - private func wrapResult(_ f: @escaping ([MLXArray]) -> MLXArray) -> ([MLXArray]) -> [MLXArray] { - { (arrays: [MLXArray]) in - [f(arrays)] - } - } - - @inline(__always) - private func unwrapArguments(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> ([MLXArray]) -> [MLXArray] { - f - } - - @inline(__always) - private func unwrapResult(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> ([MLXArray]) -> [MLXArray] { - f - } - - """ - ) - - for c in 1 ..< (inputTupleCount + 1) { - let args = Array(repeating: "MLXArray", count: c).joined(separator: ", ") - let wrapArguments = (0 ..< c).map { "arrays[\($0)]" }.joined(separator: ", ") - - print( - """ - @inline(__always) - private func wrapArguments(_ f: @escaping (\(args)) -> Result) -> ([MLXArray]) -> Result { - { (arrays: [MLXArray]) in - f(\(wrapArguments)) - } - } - - """ - ) - } - - // note: from 2 since we have the 1 special case above - for c in 2 ..< (inputTupleCount + 1) { - let args = Array(repeating: "MLXArray", count: c).joined(separator: ", ") - let wrapResult = (0 ..< c).map { "v.\($0)" }.joined(separator: ", ") - - print( - """ - @inline(__always) - private func wrapResult(_ f: @escaping ([MLXArray]) -> (\(args))) -> ([MLXArray]) -> [MLXArray] { - { (arrays: [MLXArray]) in - let v = f(arrays) - return [\(wrapResult)] - } - } - - """ - ) - } - - for c in 1 ..< (inputTupleCount + 1) { - let args = Array(repeating: "MLXArray", count: c).joined(separator: ", ") - let unrwapArguments1 = (0 ..< c).map { "a\($0): MLXArray" }.joined(separator: ", ") - let unrwapArguments2 = (0 ..< c).map { "a\($0)" }.joined(separator: ", ") - print( - """ - @inline(__always) - private func unwrapArguments(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> (\(args)) -> [MLXArray] { - { (\(unrwapArguments1)) in - f([\(unrwapArguments2)]) - } - } - - """ - ) - } - - // unwrapResult is a little more complicated because we have to handle all the - // input/output pairs. - - // [MLXArray] -> (MLXArray...) - for c in 1 ..< (inputTupleCount + 1) { - let args = Array(repeating: "MLXArray", count: c).joined(separator: ", ") - let unwrapResult = (0 ..< c).map { "v[\($0)]" }.joined(separator: ", ") - - print( - """ - @inline(__always) - private func unwrapResult(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> ([MLXArray]) -> (\(args)) { - { (a0: [MLXArray]) in - let v = f(a0) - return (\(unwrapResult)) - } - } - - """ - ) - } - - // (MLXArray...) 
-> ([MLXArray]) - for c in 1 ..< (inputTupleCount + 1) { - let args = Array(repeating: "MLXArray", count: c).joined(separator: ", ") - let unwrapInputs1 = (0 ..< c).map { "a\($0): MLXArray" }.joined(separator: ", ") - let unwrapInputs2 = (0 ..< c).map { "a\($0)" }.joined(separator: ", ") - - print( - """ - @inline(__always) - private func unwrapResult(_ f: @escaping (\(args)) -> [MLXArray]) -> (\(args)) -> [MLXArray] { - { (\(unwrapInputs1)) in - f(\(unwrapInputs2)) - } - } - - """ - ) - } - - // (MLXArray...) -> (MLXArray...) - for c in 1 ..< (inputTupleCount + 1) { - let output = Array(repeating: "MLXArray", count: c).joined(separator: ", ") - let unwrapResult = (0 ..< c).map { "v[\($0)]" }.joined(separator: ", ") - - for argc in 1 ..< (inputTupleCount + 1) { - let inputArgs = Array(repeating: "MLXArray", count: argc).joined(separator: ", ") - let unwrapInputs1 = (0 ..< argc).map { "a\($0): MLXArray" }.joined(separator: ", ") - let unwrapInputs2 = (0 ..< argc).map { "a\($0)" }.joined(separator: ", ") - - print( - """ - @inline(__always) - private func unwrapResult(_ f: @escaping (\(inputArgs)) -> [MLXArray]) -> (\(inputArgs)) -> (\(output)) { - { (\(unwrapInputs1)) in - let v = f(\(unwrapInputs2)) - return (\(unwrapResult)) - } - } - - """ - ) - } - } - } - -} diff --git a/Tests/CmlxTests/CmlxTests.swift b/Tests/CmlxTests/CmlxTests.swift index 55c59c1c..cd49cf76 100644 --- a/Tests/CmlxTests/CmlxTests.swift +++ b/Tests/CmlxTests/CmlxTests.swift @@ -20,15 +20,15 @@ class CmlxTests: XCTestCase { var data: [Float] = [1, 2, 3, 4, 5, 6] var shape: [Int32] = [2, 3] - let arr = mlx_array_from_data(&data, &shape, 2, MLX_FLOAT32)! + let arr = mlx_array_new_data(&data, &shape, 2, MLX_FLOAT32) + defer { mlx_array_free(arr) } - let str = mlx_tostring(UnsafeMutableRawPointer(arr))! 
- defer { mlx_free(UnsafeMutableRawPointer(str)) } + var str = mlx_string_new() + mlx_array_tostring(&str, arr) + defer { mlx_string_free(str) } let description = String(cString: mlx_string_data(str)) print(description) - - mlx_free(UnsafeMutableRawPointer(arr)) } } diff --git a/Tests/MLXTests/MLXArrayTests.swift b/Tests/MLXTests/MLXArrayTests.swift index 9bc5b18f..dc1fa8b5 100644 --- a/Tests/MLXTests/MLXArrayTests.swift +++ b/Tests/MLXTests/MLXArrayTests.swift @@ -79,20 +79,6 @@ class MLXArrayTests: XCTestCase { XCTAssertEqual(s_arr, expected) } - func testAsArrayNonContiguous3() { - // reversed via strides -- note that the base pointer for the - // storage has an offset applied to it - let a = MLXArray(0 ..< 9, [3, 3]) - - let s = asStrided(a, [3, 3], strides: [-3, -1], offset: 8) - - let expected: [Int32] = [8, 7, 6, 5, 4, 3, 2, 1, 0] - assertEqual(s, MLXArray(expected, [3, 3])) - - let s_arr = s.asArray(Int32.self) - XCTAssertEqual(s_arr, expected) - } - func testAsArrayNonContiguous4() { // buffer with holes (last dimension has stride of 2 and // thus larger storage than it physically needs) diff --git a/Tests/MLXTests/MLXFastKernelTests.swift b/Tests/MLXTests/MLXFastKernelTests.swift index 0e0d22bc..635e3246 100644 --- a/Tests/MLXTests/MLXFastKernelTests.swift +++ b/Tests/MLXTests/MLXFastKernelTests.swift @@ -19,15 +19,14 @@ class MLXFastKernelTests: XCTestCase { source: """ uint elem = thread_position_in_grid.x; out1[elem] = a[elem]; - """) - - let out = kernel( - inputs: [a], + """, grid: (4, 1, 1), threadGroup: (2, 1, 1), outputShapes: [[2, 2]], outputDTypes: [.float32]) + let out = kernel([a]) + XCTAssertTrue(allClose(out[0], a).all().item()) } @@ -50,15 +49,7 @@ class MLXFastKernelTests: XCTestCase { out1[elem] = 1; } out2[elem] = a[1] + b[2] + c[1] - d; - """) - - let out = kernel( - inputs: [ - a, - MLXArray([3, 4, 5]), - c, - 7.3, - ], + """, template: [ ("e", true), ("f", 3), @@ -69,6 +60,13 @@ class MLXFastKernelTests: XCTestCase { outputShapes: [[2, 2], [3, 2]], outputDTypes: [.float32, .int32]) + let out = kernel([ + a, + MLXArray([3, 4, 5]), + c, + 7.3, + ]) + XCTAssertTrue(allClose(out[0], full([2, 2], values: 14.0484)).all().item()) XCTAssertTrue(allClose(out[1], full([3, 2], values: -2)).all().item()) } diff --git a/Tests/MLXTests/OpsTests.swift b/Tests/MLXTests/OpsTests.swift index 7dd3381e..b1013549 100644 --- a/Tests/MLXTests/OpsTests.swift +++ b/Tests/MLXTests/OpsTests.swift @@ -38,14 +38,6 @@ class OpsTests: XCTestCase { assertEqual(b, MLXArray(1 ..< 13, [3, 4])) } - func testAsStridedReverse() { - let a = MLXArray(0 ..< 16, [4, 4]) - let expected = MLXArray((0 ..< 16).reversed(), [4, 4]) - - let b = asStrided(a, [4, 4], strides: [-4, -1], offset: 15) - assertEqual(b, expected) - } - func testTensordot() { let a = MLXArray(0 ..< 60, [3, 4, 5]).asType(.float32) let b = MLXArray(0 ..< 24, [4, 3, 2]).asType(.float32) diff --git a/Tests/MLXTests/Utils.swift b/Tests/MLXTests/Utils.swift index 13c90ec6..73e58e23 100644 --- a/Tests/MLXTests/Utils.swift +++ b/Tests/MLXTests/Utils.swift @@ -11,7 +11,7 @@ func assertEqual( XCTAssertEqual(array1.shape, array2.shape, "shapes differ: \(array1.shape) != \(array2.shape)") XCTAssertTrue( array1.allClose(array2, rtol: rtol, atol: atol).item(Bool.self), - "contents differ:\n\(array1)\n\(array2))))") + "contents differ:\n\(array1)\n\(array2)") } func assertEqual( diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh index b24c9c55..be472d99 100755 --- a/tools/update-mlx.sh +++ b/tools/update-mlx.sh @@ -16,11 +16,6 @@ mkdir build 
cd build cmake ../Source/Cmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0 -# NOTE: -# until mlx supports overriding the METAL_VERSION you will need to edit -# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION -# to "3.0" - # run the cmake build to generate the source files cd mlx/backend/metal make \ @@ -73,4 +68,4 @@ done; rm Source/Cmlx/mlx-generated/*.tmp # Update the headers -./tools/fix-metal-includes.sh \ No newline at end of file +./tools/fix-metal-includes.sh diff --git a/vendor/README.md b/vendor/README.md index 5d4697a2..2703cdcf 100644 --- a/vendor/README.md +++ b/vendor/README.md @@ -14,7 +14,7 @@ This is https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz This comes from https://developer.apple.com/metal/cpp/ specifically: -- https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip +- https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip Note that `metal-cpp.patch` has been applied to the contents of that zip. diff --git a/vendor/metal-cpp.patch b/vendor/metal-cpp.patch index e65d74ee..e13b756d 100644 --- a/vendor/metal-cpp.patch +++ b/vendor/metal-cpp.patch @@ -59,3 +59,16 @@ diff --color -u -r Foundation/NSBundle.hpp metal-cpp/Foundation/NSBundle.hpp _MTL_PRIVATE_DEF_SEL(setSize_, "setSize:"); _MTL_PRIVATE_DEF_SEL(setSlice_, +diff --git a/vendor/metal-cpp/Metal/MTLDevice.hpp b/vendor/metal-cpp/Metal/MTLDevice.hpp +index 17efb88..2b6bc3b 100644 +--- a/vendor/metal-cpp/Metal/MTLDevice.hpp ++++ b/vendor/metal-cpp/Metal/MTLDevice.hpp +@@ -636,7 +636,7 @@ _NS_EXPORT MTL::Device* MTL::CreateSystemDefaultDevice() + + _NS_EXPORT NS::Array* MTL::CopyAllDevices() + { +-#if (defined __IPHONE_18) || (defined __MAC_10_11) ++#if (__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000) || (__MAC_OS_X_VERSION_MIN_REQUIRED >= 101100) + return ::MTLCopyAllDevices(); + #else + return nullptr; diff --git a/vendor/metal-cpp/Foundation/Foundation.hpp b/vendor/metal-cpp/Foundation/Foundation.hpp index 8b64277f..31e8fb3c 100644 --- a/vendor/metal-cpp/Foundation/Foundation.hpp +++ b/vendor/metal-cpp/Foundation/Foundation.hpp @@ -2,7 +2,7 @@ // // Foundation/Foundation.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSArray.hpp b/vendor/metal-cpp/Foundation/NSArray.hpp index 7ccdb804..455f083b 100644 --- a/vendor/metal-cpp/Foundation/NSArray.hpp +++ b/vendor/metal-cpp/Foundation/NSArray.hpp @@ -2,7 +2,7 @@ // // Foundation/NSArray.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSAutoreleasePool.hpp b/vendor/metal-cpp/Foundation/NSAutoreleasePool.hpp index 3008590d..6d01a465 100644 --- a/vendor/metal-cpp/Foundation/NSAutoreleasePool.hpp +++ b/vendor/metal-cpp/Foundation/NSAutoreleasePool.hpp @@ -2,7 +2,7 @@ // // Foundation/NSAutoreleasePool.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/vendor/metal-cpp/Foundation/NSBundle.hpp b/vendor/metal-cpp/Foundation/NSBundle.hpp index af40c111..bc2de8bc 100644 --- a/vendor/metal-cpp/Foundation/NSBundle.hpp +++ b/vendor/metal-cpp/Foundation/NSBundle.hpp @@ -2,7 +2,7 @@ // // Foundation/NSBundle.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSData.hpp b/vendor/metal-cpp/Foundation/NSData.hpp index ddfa6dd0..3ad36060 100644 --- a/vendor/metal-cpp/Foundation/NSData.hpp +++ b/vendor/metal-cpp/Foundation/NSData.hpp @@ -2,7 +2,7 @@ // // Foundation/NSData.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSDate.hpp b/vendor/metal-cpp/Foundation/NSDate.hpp index 61f10a95..0a5ec7dd 100644 --- a/vendor/metal-cpp/Foundation/NSDate.hpp +++ b/vendor/metal-cpp/Foundation/NSDate.hpp @@ -2,7 +2,7 @@ // // Foundation/NSDate.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSDefines.hpp b/vendor/metal-cpp/Foundation/NSDefines.hpp index a042be63..38bbb56b 100644 --- a/vendor/metal-cpp/Foundation/NSDefines.hpp +++ b/vendor/metal-cpp/Foundation/NSDefines.hpp @@ -2,7 +2,7 @@ // // Foundation/NSDefines.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSDictionary.hpp b/vendor/metal-cpp/Foundation/NSDictionary.hpp index 078cd5c8..d4a1519d 100644 --- a/vendor/metal-cpp/Foundation/NSDictionary.hpp +++ b/vendor/metal-cpp/Foundation/NSDictionary.hpp @@ -2,7 +2,7 @@ // // Foundation/NSDictionary.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSEnumerator.hpp b/vendor/metal-cpp/Foundation/NSEnumerator.hpp index eed19dba..5a2500c1 100644 --- a/vendor/metal-cpp/Foundation/NSEnumerator.hpp +++ b/vendor/metal-cpp/Foundation/NSEnumerator.hpp @@ -2,7 +2,7 @@ // // Foundation/NSEnumerator.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSError.hpp b/vendor/metal-cpp/Foundation/NSError.hpp index f19ff861..ea331d46 100644 --- a/vendor/metal-cpp/Foundation/NSError.hpp +++ b/vendor/metal-cpp/Foundation/NSError.hpp @@ -2,7 +2,7 @@ // // Foundation/NSError.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/vendor/metal-cpp/Foundation/NSLock.hpp b/vendor/metal-cpp/Foundation/NSLock.hpp index ca371fba..01df2194 100644 --- a/vendor/metal-cpp/Foundation/NSLock.hpp +++ b/vendor/metal-cpp/Foundation/NSLock.hpp @@ -2,7 +2,7 @@ // // Foundation/NSLock.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSNotification.hpp b/vendor/metal-cpp/Foundation/NSNotification.hpp index 49cf2d4a..6b5be121 100644 --- a/vendor/metal-cpp/Foundation/NSNotification.hpp +++ b/vendor/metal-cpp/Foundation/NSNotification.hpp @@ -2,7 +2,7 @@ // // Foundation/NSNotification.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSNumber.hpp b/vendor/metal-cpp/Foundation/NSNumber.hpp index 13c78024..eec7ceac 100644 --- a/vendor/metal-cpp/Foundation/NSNumber.hpp +++ b/vendor/metal-cpp/Foundation/NSNumber.hpp @@ -2,7 +2,7 @@ // // Foundation/NSNumber.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSObjCRuntime.hpp b/vendor/metal-cpp/Foundation/NSObjCRuntime.hpp index a3860e94..9a5364c2 100644 --- a/vendor/metal-cpp/Foundation/NSObjCRuntime.hpp +++ b/vendor/metal-cpp/Foundation/NSObjCRuntime.hpp @@ -2,7 +2,7 @@ // // Foundation/NSObjCRuntime.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSObject.hpp b/vendor/metal-cpp/Foundation/NSObject.hpp index 489fd36f..907064f0 100644 --- a/vendor/metal-cpp/Foundation/NSObject.hpp +++ b/vendor/metal-cpp/Foundation/NSObject.hpp @@ -2,7 +2,7 @@ // // Foundation/NSObject.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSPrivate.hpp b/vendor/metal-cpp/Foundation/NSPrivate.hpp index af5ffb14..81981eb6 100644 --- a/vendor/metal-cpp/Foundation/NSPrivate.hpp +++ b/vendor/metal-cpp/Foundation/NSPrivate.hpp @@ -2,7 +2,7 @@ // // Foundation/NSPrivate.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -33,6 +33,19 @@ #if defined(NS_PRIVATE_IMPLEMENTATION) +#include + +namespace NS::Private +{ + template + inline _Type const LoadSymbol(const char* pSymbol) + { + const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol)); + + return pAddress ? 
*pAddress : _Type(); + } +} // NS::Private + #ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN #define _NS_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) #else @@ -52,9 +65,16 @@ #define _NS_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _NS_PRIVATE_VISIBILITY = _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) #define _NS_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _NS_PRIVATE_VISIBILITY = _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) #define _NS_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _NS_PRIVATE_VISIBILITY = sel_registerName(symbol) + +#if defined(__MAC_15_0) || defined(__IPHONE_18_0) || defined(__TVOS_18_0) #define _NS_PRIVATE_DEF_CONST(type, symbol) \ _NS_EXTERN type const NS##symbol _NS_PRIVATE_IMPORT; \ - type const NS::symbol = (nullptr != &NS##symbol) ? NS##symbol : nullptr + type const NS::symbol = (nullptr != &NS##symbol) ? NS##symbol : type() +#else +#define _NS_PRIVATE_DEF_CONST(type, symbol) \ + _NS_EXTERN type const MTL##symbol _NS_PRIVATE_IMPORT; \ + type const NS::symbol = Private::LoadSymbol("NS" #symbol) +#endif #else @@ -236,6 +256,8 @@ namespace Private "globallyUniqueString"); _NS_PRIVATE_DEF_SEL(hash, "hash"); + _NS_PRIVATE_DEF_SEL(hasPerformanceProfile_, + "hasPerformanceProfile:"); _NS_PRIVATE_DEF_SEL(hostName, "hostName"); _NS_PRIVATE_DEF_SEL(infoDictionary, @@ -294,6 +316,8 @@ namespace Private "integerValue"); _NS_PRIVATE_DEF_SEL(intValue, "intValue"); + _NS_PRIVATE_DEF_SEL(isDeviceCertified_, + "isDeviceCertifiedFor:"); _NS_PRIVATE_DEF_SEL(isEqual_, "isEqual:"); _NS_PRIVATE_DEF_SEL(isEqualToNumber_, diff --git a/vendor/metal-cpp/Foundation/NSProcessInfo.hpp b/vendor/metal-cpp/Foundation/NSProcessInfo.hpp index 565b5993..09c212d5 100644 --- a/vendor/metal-cpp/Foundation/NSProcessInfo.hpp +++ b/vendor/metal-cpp/Foundation/NSProcessInfo.hpp @@ -2,7 +2,7 @@ // // Foundation/NSProcessInfo.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -36,6 +36,7 @@ namespace NS { _NS_CONST(NotificationName, ProcessInfoThermalStateDidChangeNotification); _NS_CONST(NotificationName, ProcessInfoPowerStateDidChangeNotification); +_NS_CONST(NotificationName, ProcessInfoPerformanceProfileDidChangeNotification); _NS_ENUM(NS::Integer, ProcessInfoThermalState) { ProcessInfoThermalStateNominal = 0, @@ -55,6 +56,13 @@ _NS_OPTIONS(std::uint64_t, ActivityOptions) { ActivityLatencyCritical = 0xFF00000000ULL, }; +typedef NS::Integer DeviceCertification; +_NS_CONST(DeviceCertification, DeviceCertificationiPhonePerformanceGaming); + +typedef NS::Integer ProcessPerformanceProfile; +_NS_CONST(ProcessPerformanceProfile, ProcessPerformanceProfileDefault); +_NS_CONST(ProcessPerformanceProfile, ProcessPerformanceProfileSustained); + class ProcessInfo : public Referencing { public: @@ -101,6 +109,10 @@ class ProcessInfo : public Referencing bool isiOSAppOnMac() const; bool isMacCatalystApp() const; + + bool isDeviceCertified(DeviceCertification performanceTier) const; + bool hasPerformanceProfile(ProcessPerformanceProfile performanceProfile) const; + }; } @@ -109,6 +121,12 @@ class ProcessInfo : public Referencing _NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoThermalStateDidChangeNotification); _NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoPowerStateDidChangeNotification); +// The linker searches for these symbols in the Metal framework, be sure to link it in as well: +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoPerformanceProfileDidChangeNotification); +_NS_PRIVATE_DEF_CONST(NS::DeviceCertification, DeviceCertificationiPhonePerformanceGaming); +_NS_PRIVATE_DEF_CONST(NS::ProcessPerformanceProfile, ProcessPerformanceProfileDefault); +_NS_PRIVATE_DEF_CONST(NS::ProcessPerformanceProfile, ProcessPerformanceProfileSustained); + //------------------------------------------------------------------------------------------------------------------------------------------------------------- _NS_INLINE NS::ProcessInfo* NS::ProcessInfo::processInfo() @@ -352,3 +370,17 @@ _NS_INLINE bool NS::ProcessInfo::isMacCatalystApp() const } //------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isDeviceCertified(DeviceCertification performanceTier) const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isDeviceCertified_), performanceTier); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::hasPerformanceProfile(ProcessPerformanceProfile performanceProfile) const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(hasPerformanceProfile_), performanceProfile); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/vendor/metal-cpp/Foundation/NSRange.hpp b/vendor/metal-cpp/Foundation/NSRange.hpp index 2c5beb5b..8500271d 100644 --- a/vendor/metal-cpp/Foundation/NSRange.hpp +++ b/vendor/metal-cpp/Foundation/NSRange.hpp @@ -2,7 +2,7 @@ // // Foundation/NSRange.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/vendor/metal-cpp/Foundation/NSSet.hpp b/vendor/metal-cpp/Foundation/NSSet.hpp index a4eb0d64..382b6714 100644 --- a/vendor/metal-cpp/Foundation/NSSet.hpp +++ b/vendor/metal-cpp/Foundation/NSSet.hpp @@ -2,7 +2,7 @@ // // Foundation/NSSet.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSSharedPtr.hpp b/vendor/metal-cpp/Foundation/NSSharedPtr.hpp index 761ce2db..ff367a9f 100644 --- a/vendor/metal-cpp/Foundation/NSSharedPtr.hpp +++ b/vendor/metal-cpp/Foundation/NSSharedPtr.hpp @@ -2,7 +2,7 @@ // // Foundation/NSSharedPtr.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSString.hpp b/vendor/metal-cpp/Foundation/NSString.hpp index c601fc01..c48e0689 100644 --- a/vendor/metal-cpp/Foundation/NSString.hpp +++ b/vendor/metal-cpp/Foundation/NSString.hpp @@ -2,7 +2,7 @@ // // Foundation/NSString.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSTypes.hpp b/vendor/metal-cpp/Foundation/NSTypes.hpp index 5f098f67..e6b723e5 100644 --- a/vendor/metal-cpp/Foundation/NSTypes.hpp +++ b/vendor/metal-cpp/Foundation/NSTypes.hpp @@ -2,7 +2,7 @@ // // Foundation/NSTypes.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Foundation/NSURL.hpp b/vendor/metal-cpp/Foundation/NSURL.hpp index a7bc3e6e..d90e5d70 100644 --- a/vendor/metal-cpp/Foundation/NSURL.hpp +++ b/vendor/metal-cpp/Foundation/NSURL.hpp @@ -2,7 +2,7 @@ // // Foundation/NSURL.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/LICENSE.txt b/vendor/metal-cpp/LICENSE.txt index 6c877ff9..d07f885e 100644 --- a/vendor/metal-cpp/LICENSE.txt +++ b/vendor/metal-cpp/LICENSE.txt @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright © 2023 Apple Inc. + Copyright © 2024 Apple Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLAccelerationStructure.hpp b/vendor/metal-cpp/Metal/MTLAccelerationStructure.hpp index cb30db71..d4193f5b 100644 --- a/vendor/metal-cpp/Metal/MTLAccelerationStructure.hpp +++ b/vendor/metal-cpp/Metal/MTLAccelerationStructure.hpp @@ -2,7 +2,7 @@ // // Metal/MTLAccelerationStructure.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -49,6 +49,11 @@ _MTL_OPTIONS(uint32_t, AccelerationStructureInstanceOptions) { AccelerationStructureInstanceOptionNonOpaque = 8, }; +_MTL_ENUM(NS::Integer, MatrixLayout) { + MatrixLayoutColumnMajor = 0, + MatrixLayoutRowMajor = 1, +}; + class AccelerationStructureDescriptor : public NS::Copying { public: @@ -162,6 +167,9 @@ class AccelerationStructureTriangleGeometryDescriptor : public NS::Copying { public: @@ -494,6 +510,15 @@ class InstanceAccelerationStructureDescriptor : public NS::Copying(this, _MTL_PRIVATE_SEL(setTransformationMatrixBufferOffset_), transformationMatrixBufferOffset); } +// property: transformationMatrixLayout +_MTL_INLINE MTL::MatrixLayout MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + // static method: descriptor _MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::descriptor() { @@ -1101,6 +1146,17 @@ _MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::set Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBufferOffset_), transformationMatrixBufferOffset); } +// property: transformationMatrixLayout +_MTL_INLINE MTL::MatrixLayout MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + // static method: descriptor _MTL_INLINE MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::descriptor() { @@ -1657,6 +1713,39 @@ _MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransfor Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCount_), motionTransformCount); } +// property: instanceTransformationMatrixLayout +_MTL_INLINE MTL::MatrixLayout MTL::InstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +// property: motionTransformType +_MTL_INLINE MTL::TransformType MTL::InstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +// property: motionTransformStride +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + 
return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + // static method: descriptor _MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::descriptor() { @@ -1807,6 +1896,39 @@ _MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotion Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBufferOffset_), motionTransformCountBufferOffset); } +// property: instanceTransformationMatrixLayout +_MTL_INLINE MTL::MatrixLayout MTL::IndirectInstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +// property: motionTransformType +_MTL_INLINE MTL::TransformType MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +// property: motionTransformStride +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + // static method: descriptor _MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::descriptor() { diff --git a/vendor/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp index e0b4ccd9..2ac79f81 100644 --- a/vendor/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLAccelerationStructureCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp b/vendor/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp index 146ffc20..13e3ffb9 100644 --- a/vendor/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp +++ b/vendor/metal-cpp/Metal/MTLAccelerationStructureTypes.hpp @@ -2,7 +2,7 @@ // // Metal/MTLAccelerationStructureTypes.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -74,6 +74,39 @@ struct AxisAlignedBoundingBox PackedFloat3 min; PackedFloat3 max; } _MTL_PACKED; + +struct PackedFloatQuaternion +{ + PackedFloatQuaternion(); + PackedFloatQuaternion(float x, float y, float z, float w); + + float& operator[](int idx); + const float& operator[](int idx) const; + + union + { + struct + { + float x; + float y; + float z; + float w; + }; + + float elements[4]; + }; + +} _MTL_PACKED; + +struct ComponentTransform +{ + PackedFloat3 scale; + PackedFloat3 shear; + PackedFloat3 pivot; + PackedFloatQuaternion rotation; + PackedFloat3 translation; +} _MTL_PACKED; + } //------------------------------------------------------------------------------------------------------------------------------------------------------------- @@ -167,3 +200,37 @@ _MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox(PackedFloat3 _mi } //------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloatQuaternion::PackedFloatQuaternion() + : x(0.0f) + , y(0.0f) + , z(0.0f) + , w(0.0f) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloatQuaternion::PackedFloatQuaternion(float x, float y, float z, float w) + : x(x) + , y(y) + , z(z) + , w(w) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float& MTL::PackedFloatQuaternion::operator[](int idx) +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE const float& MTL::PackedFloatQuaternion::operator[](int idx) const +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/vendor/metal-cpp/Metal/MTLAllocation.hpp b/vendor/metal-cpp/Metal/MTLAllocation.hpp new file mode 100644 index 00000000..a1ec3ca1 --- /dev/null +++ b/vendor/metal-cpp/Metal/MTLAllocation.hpp @@ -0,0 +1,43 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAllocation.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +#include + +namespace MTL +{ +class Allocation : public NS::Referencing +{ +public: + NS::UInteger allocatedSize() const; +}; + +} + +// property: allocatedSize +_MTL_INLINE NS::UInteger MTL::Allocation::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} diff --git a/vendor/metal-cpp/Metal/MTLArgument.hpp b/vendor/metal-cpp/Metal/MTLArgument.hpp index 796fa332..e9c97a9d 100644 --- a/vendor/metal-cpp/Metal/MTLArgument.hpp +++ b/vendor/metal-cpp/Metal/MTLArgument.hpp @@ -2,7 +2,7 @@ // // Metal/MTLArgument.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLArgumentEncoder.hpp b/vendor/metal-cpp/Metal/MTLArgumentEncoder.hpp index a81859c7..e23c9132 100644 --- a/vendor/metal-cpp/Metal/MTLArgumentEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLArgumentEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLArgumentEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLBinaryArchive.hpp b/vendor/metal-cpp/Metal/MTLBinaryArchive.hpp index 1c77c078..3626e068 100644 --- a/vendor/metal-cpp/Metal/MTLBinaryArchive.hpp +++ b/vendor/metal-cpp/Metal/MTLBinaryArchive.hpp @@ -2,7 +2,7 @@ // // Metal/MTLBinaryArchive.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -61,6 +61,10 @@ class BinaryArchive : public NS::Referencing bool addTileRenderPipelineFunctions(const class TileRenderPipelineDescriptor* descriptor, NS::Error** error); + bool addMeshRenderPipelineFunctions(const class MeshRenderPipelineDescriptor* descriptor, NS::Error** error); + + bool addLibrary(const class StitchedLibraryDescriptor* descriptor, NS::Error** error); + bool serializeToURL(const NS::URL* url, NS::Error** error); bool addFunction(const class FunctionDescriptor* descriptor, const class Library* library, NS::Error** error); @@ -126,6 +130,18 @@ _MTL_INLINE bool MTL::BinaryArchive::addTileRenderPipelineFunctions(const MTL::T return Object::sendMessage(this, _MTL_PRIVATE_SEL(addTileRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); } +// method: addMeshRenderPipelineFunctionsWithDescriptor:error: +_MTL_INLINE bool MTL::BinaryArchive::addMeshRenderPipelineFunctions(const MTL::MeshRenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addMeshRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +// method: addLibraryWithDescriptor:error: +_MTL_INLINE bool MTL::BinaryArchive::addLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addLibraryWithDescriptor_error_), descriptor, error); +} + // method: serializeToURL:error: _MTL_INLINE bool MTL::BinaryArchive::serializeToURL(const NS::URL* url, NS::Error** error) { diff --git a/vendor/metal-cpp/Metal/MTLBlitCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLBlitCommandEncoder.hpp index 8d4845a8..2701c1d6 100644 --- a/vendor/metal-cpp/Metal/MTLBlitCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLBlitCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLBlitCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLBlitPass.hpp b/vendor/metal-cpp/Metal/MTLBlitPass.hpp index 84987525..135b918c 100644 --- a/vendor/metal-cpp/Metal/MTLBlitPass.hpp +++ b/vendor/metal-cpp/Metal/MTLBlitPass.hpp @@ -2,7 +2,7 @@ // // Metal/MTLBlitPass.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLBuffer.hpp b/vendor/metal-cpp/Metal/MTLBuffer.hpp index 684ee8e2..72b209a7 100644 --- a/vendor/metal-cpp/Metal/MTLBuffer.hpp +++ b/vendor/metal-cpp/Metal/MTLBuffer.hpp @@ -2,7 +2,7 @@ // // Metal/MTLBuffer.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLCaptureManager.hpp b/vendor/metal-cpp/Metal/MTLCaptureManager.hpp index ebe7ddd2..52309a2e 100644 --- a/vendor/metal-cpp/Metal/MTLCaptureManager.hpp +++ b/vendor/metal-cpp/Metal/MTLCaptureManager.hpp @@ -2,7 +2,7 @@ // // Metal/MTLCaptureManager.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
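A brief sketch of the new MTL::BinaryArchive::addLibrary overload introduced above, for orientation only. It is not part of the vendored patch; the function name and error handling are hypothetical, and the archive and stitched-library descriptor are assumed to have been created elsewhere.

    #include <Metal/Metal.hpp>
    #include <cstdio>

    // Hypothetical helper: records a stitched library into an existing binary archive.
    bool archiveStitchedLibrary(MTL::BinaryArchive* pArchive,
                                const MTL::StitchedLibraryDescriptor* pDescriptor)
    {
        NS::Error* pError = nullptr;
        bool ok = pArchive->addLibrary(pDescriptor, &pError);
        if (!ok && pError)
            std::printf("addLibrary failed: %s\n", pError->localizedDescription()->utf8String());
        return ok;
    }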
diff --git a/vendor/metal-cpp/Metal/MTLCaptureScope.hpp b/vendor/metal-cpp/Metal/MTLCaptureScope.hpp index 6d5d1d6f..5e22a79a 100644 --- a/vendor/metal-cpp/Metal/MTLCaptureScope.hpp +++ b/vendor/metal-cpp/Metal/MTLCaptureScope.hpp @@ -2,7 +2,7 @@ // // Metal/MTLCaptureScope.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLCommandBuffer.hpp b/vendor/metal-cpp/Metal/MTLCommandBuffer.hpp index 64bdf35e..889f4baf 100644 --- a/vendor/metal-cpp/Metal/MTLCommandBuffer.hpp +++ b/vendor/metal-cpp/Metal/MTLCommandBuffer.hpp @@ -2,7 +2,7 @@ // // Metal/MTLCommandBuffer.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -80,6 +80,9 @@ class CommandBufferDescriptor : public NS::Copying MTL::CommandBufferErrorOption errorOptions() const; void setErrorOptions(MTL::CommandBufferErrorOption errorOptions); + + class LogState* logState() const; + void setLogState(const class LogState* logState); }; class CommandBufferEncoderInfo : public NS::Referencing @@ -182,6 +185,10 @@ class CommandBuffer : public NS::Referencing void pushDebugGroup(const NS::String* string); void popDebugGroup(); + + void useResidencySet(const class ResidencySet* residencySet); + + void useResidencySets(const class ResidencySet* const residencySets[], NS::UInteger count); }; } @@ -220,6 +227,17 @@ _MTL_INLINE void MTL::CommandBufferDescriptor::setErrorOptions(MTL::CommandBuffe Object::sendMessage(this, _MTL_PRIVATE_SEL(setErrorOptions_), errorOptions); } +// property: logState +_MTL_INLINE MTL::LogState* MTL::CommandBufferDescriptor::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + // property: label _MTL_INLINE NS::String* MTL::CommandBufferEncoderInfo::label() const { @@ -472,3 +490,15 @@ _MTL_INLINE void MTL::CommandBuffer::popDebugGroup() { Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); } + +// method: useResidencySet: +_MTL_INLINE void MTL::CommandBuffer::useResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySet_), residencySet); +} + +// method: useResidencySets:count: +_MTL_INLINE void MTL::CommandBuffer::useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySets_count_), residencySets, count); +} diff --git a/vendor/metal-cpp/Metal/MTLCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLCommandEncoder.hpp index 9a4d97c7..0ea3cde5 100644 --- a/vendor/metal-cpp/Metal/MTLCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
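The MTLCommandBuffer changes above add residency-set support. A minimal sketch of how the new useResidencySet call might be used, illustrative only: the helper name is invented, and the residency set is assumed to have been created and populated elsewhere (see the MTLDevice and MTLResidencySet additions later in this patch).

    #include <Metal/Metal.hpp>

    // Hypothetical helper: keeps a pre-built residency set's allocations resident
    // while this command buffer executes, then commits it.
    void commitWithResidency(MTL::CommandBuffer* pCommandBuffer,
                             const MTL::ResidencySet* pResidencySet)
    {
        pCommandBuffer->useResidencySet(pResidencySet);
        // ... encoding of compute / render / blit work would go here ...
        pCommandBuffer->commit();
    }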
diff --git a/vendor/metal-cpp/Metal/MTLCommandQueue.hpp b/vendor/metal-cpp/Metal/MTLCommandQueue.hpp index 07b29844..6ce9fbf6 100644 --- a/vendor/metal-cpp/Metal/MTLCommandQueue.hpp +++ b/vendor/metal-cpp/Metal/MTLCommandQueue.hpp @@ -2,7 +2,7 @@ // // Metal/MTLCommandQueue.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -43,6 +43,28 @@ class CommandQueue : public NS::Referencing class CommandBuffer* commandBufferWithUnretainedReferences(); void insertDebugCaptureBoundary(); + + void addResidencySet(const class ResidencySet* residencySet); + + void addResidencySets(const class ResidencySet* const residencySets[], NS::UInteger count); + + void removeResidencySet(const class ResidencySet* residencySet); + + void removeResidencySets(const class ResidencySet* const residencySets[], NS::UInteger count); +}; + +class CommandQueueDescriptor : public NS::Copying +{ +public: + static class CommandQueueDescriptor* alloc(); + + class CommandQueueDescriptor* init(); + + NS::UInteger maxCommandBufferCount() const; + void setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount); + + class LogState* logState() const; + void setLogState(const class LogState* logState); }; } @@ -87,3 +109,61 @@ _MTL_INLINE void MTL::CommandQueue::insertDebugCaptureBoundary() { Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugCaptureBoundary)); } + +// method: addResidencySet: +_MTL_INLINE void MTL::CommandQueue::addResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySet_), residencySet); +} + +// method: addResidencySets:count: +_MTL_INLINE void MTL::CommandQueue::addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySets_count_), residencySets, count); +} + +// method: removeResidencySet: +_MTL_INLINE void MTL::CommandQueue::removeResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySet_), residencySet); +} + +// method: removeResidencySets:count: +_MTL_INLINE void MTL::CommandQueue::removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySets_count_), residencySets, count); +} + +// static method: alloc +_MTL_INLINE MTL::CommandQueueDescriptor* MTL::CommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCommandQueueDescriptor)); +} + +// method: init +_MTL_INLINE MTL::CommandQueueDescriptor* MTL::CommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +// property: maxCommandBufferCount +_MTL_INLINE NS::UInteger MTL::CommandQueueDescriptor::maxCommandBufferCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCommandBufferCount)); +} + +_MTL_INLINE void MTL::CommandQueueDescriptor::setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandBufferCount_), maxCommandBufferCount); +} + +// property: logState +_MTL_INLINE MTL::LogState* MTL::CommandQueueDescriptor::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE void MTL::CommandQueueDescriptor::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} diff --git 
a/vendor/metal-cpp/Metal/MTLComputeCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLComputeCommandEncoder.hpp index 50a3b241..43f2a5de 100644 --- a/vendor/metal-cpp/Metal/MTLComputeCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLComputeCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLComputeCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLComputePass.hpp b/vendor/metal-cpp/Metal/MTLComputePass.hpp index bbc71841..8afdeedf 100644 --- a/vendor/metal-cpp/Metal/MTLComputePass.hpp +++ b/vendor/metal-cpp/Metal/MTLComputePass.hpp @@ -2,7 +2,7 @@ // // Metal/MTLComputePass.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLComputePipeline.hpp b/vendor/metal-cpp/Metal/MTLComputePipeline.hpp index 3065b8f2..ae377a62 100644 --- a/vendor/metal-cpp/Metal/MTLComputePipeline.hpp +++ b/vendor/metal-cpp/Metal/MTLComputePipeline.hpp @@ -2,7 +2,7 @@ // // Metal/MTLComputePipeline.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ #include +#include "MTLPipeline.hpp" #include "MTLTypes.hpp" namespace MTL @@ -88,6 +89,9 @@ class ComputePipelineDescriptor : public NS::Copying NS::UInteger maxCallStackDepth() const; void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + MTL::ShaderValidation shaderValidation() const; + void setShaderValidation(MTL::ShaderValidation shaderValidation); }; class ComputePipelineState : public NS::Referencing @@ -116,6 +120,8 @@ class ComputePipelineState : public NS::Referencing class VisibleFunctionTable* newVisibleFunctionTable(const class VisibleFunctionTableDescriptor* descriptor); class IntersectionFunctionTable* newIntersectionFunctionTable(const class IntersectionFunctionTableDescriptor* descriptor); + + MTL::ShaderValidation shaderValidation() const; }; } @@ -300,6 +306,17 @@ _MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxCallStackDepth(NS::UInteg Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); } +// property: shaderValidation +_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + // property: label _MTL_INLINE NS::String* MTL::ComputePipelineState::label() const { @@ -371,3 +388,9 @@ _MTL_INLINE MTL::IntersectionFunctionTable* MTL::ComputePipelineState::newInters { return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionTableWithDescriptor_), descriptor); } + +// property: shaderValidation +_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineState::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} diff --git a/vendor/metal-cpp/Metal/MTLCounters.hpp b/vendor/metal-cpp/Metal/MTLCounters.hpp index c552f7ed..14febd41 100644 --- 
a/vendor/metal-cpp/Metal/MTLCounters.hpp +++ b/vendor/metal-cpp/Metal/MTLCounters.hpp @@ -2,7 +2,7 @@ // // Metal/MTLCounters.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLDefines.hpp b/vendor/metal-cpp/Metal/MTLDefines.hpp index 7dd8ff95..4260a2b1 100644 --- a/vendor/metal-cpp/Metal/MTLDefines.hpp +++ b/vendor/metal-cpp/Metal/MTLDefines.hpp @@ -2,7 +2,7 @@ // // Metal/MTLDefines.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLDepthStencil.hpp b/vendor/metal-cpp/Metal/MTLDepthStencil.hpp index ba8bd8b0..09cd404b 100644 --- a/vendor/metal-cpp/Metal/MTLDepthStencil.hpp +++ b/vendor/metal-cpp/Metal/MTLDepthStencil.hpp @@ -2,7 +2,7 @@ // // Metal/MTLDepthStencil.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLDevice.hpp b/vendor/metal-cpp/Metal/MTLDevice.hpp index 6843131d..2b6bc3b4 100644 --- a/vendor/metal-cpp/Metal/MTLDevice.hpp +++ b/vendor/metal-cpp/Metal/MTLDevice.hpp @@ -2,7 +2,7 @@ // // Metal/MTLDevice.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -115,6 +115,7 @@ _MTL_ENUM(NS::UInteger, DeviceLocation) { _MTL_OPTIONS(NS::UInteger, PipelineOption) { PipelineOptionNone = 0, PipelineOptionArgumentInfo = 1, + PipelineOptionBindingInfo = 1, PipelineOptionBufferTypeInfo = 2, PipelineOptionFailOnBinaryArchiveMiss = 4, }; @@ -315,10 +316,14 @@ class Device : public NS::Referencing NS::UInteger currentAllocatedSize() const; + class LogState* newLogState(const class LogStateDescriptor* descriptor, NS::Error** error); + class CommandQueue* newCommandQueue(); class CommandQueue* newCommandQueue(NS::UInteger maxCommandBufferCount); + class CommandQueue* newCommandQueue(const class CommandQueueDescriptor* descriptor); + MTL::SizeAndAlign heapTextureSizeAndAlign(const class TextureDescriptor* desc); MTL::SizeAndAlign heapBufferSizeAndAlign(NS::UInteger length, MTL::ResourceOptions options); @@ -499,6 +504,8 @@ class Device : public NS::Referencing void setShouldMaximizeConcurrentCompilation(bool shouldMaximizeConcurrentCompilation); NS::UInteger maximumConcurrentCompilationTaskCount() const; + + class ResidencySet* newResidencySet(const class ResidencySetDescriptor* desc, NS::Error** error); }; } @@ -629,11 +636,11 @@ _NS_EXPORT MTL::Device* MTL::CreateSystemDefaultDevice() _NS_EXPORT NS::Array* MTL::CopyAllDevices() { -#if TARGET_OS_OSX +#if (__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000) || (__MAC_OS_X_VERSION_MIN_REQUIRED >= 101100) return ::MTLCopyAllDevices(); #else return nullptr; -#endif // TARGET_OS_OSX +#endif // __IPHONE_18 } _NS_EXPORT NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object** pOutObserver, DeviceNotificationHandlerBlock handler) @@ -657,6 +664,7 @@ _NS_EXPORT NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object** pOutObserver, _NS_EXPORT void MTL::RemoveDeviceObserver(const NS::Object* pObserver) { 
+ (void)pObserver; #if TARGET_OS_OSX ::MTLRemoveDeviceObserver(pObserver); #endif // TARGET_OS_OSX @@ -869,6 +877,12 @@ _MTL_INLINE NS::UInteger MTL::Device::currentAllocatedSize() const return Object::sendMessage(this, _MTL_PRIVATE_SEL(currentAllocatedSize)); } +// method: newLogStateWithDescriptor:error: +_MTL_INLINE MTL::LogState* MTL::Device::newLogState(const MTL::LogStateDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLogStateWithDescriptor_error_), descriptor, error); +} + // method: newCommandQueue _MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue() { @@ -881,6 +895,12 @@ _MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue(NS::UInteger maxComm return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueueWithMaxCommandBufferCount_), maxCommandBufferCount); } +// method: newCommandQueueWithDescriptor: +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue(const MTL::CommandQueueDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueueWithDescriptor_), descriptor); +} + // method: heapTextureSizeAndAlignWithDescriptor: _MTL_INLINE MTL::SizeAndAlign MTL::Device::heapTextureSizeAndAlign(const MTL::TextureDescriptor* desc) { @@ -1425,3 +1445,9 @@ _MTL_INLINE NS::UInteger MTL::Device::maximumConcurrentCompilationTaskCount() co { return Object::sendMessage(this, _MTL_PRIVATE_SEL(maximumConcurrentCompilationTaskCount)); } + +// method: newResidencySetWithDescriptor:error: +_MTL_INLINE MTL::ResidencySet* MTL::Device::newResidencySet(const MTL::ResidencySetDescriptor* desc, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newResidencySetWithDescriptor_error_), desc, error); +} diff --git a/vendor/metal-cpp/Metal/MTLDrawable.hpp b/vendor/metal-cpp/Metal/MTLDrawable.hpp index 58945d03..40b0fa3e 100644 --- a/vendor/metal-cpp/Metal/MTLDrawable.hpp +++ b/vendor/metal-cpp/Metal/MTLDrawable.hpp @@ -2,7 +2,7 @@ // // Metal/MTLDrawable.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLDynamicLibrary.hpp b/vendor/metal-cpp/Metal/MTLDynamicLibrary.hpp index f951d125..d8620451 100644 --- a/vendor/metal-cpp/Metal/MTLDynamicLibrary.hpp +++ b/vendor/metal-cpp/Metal/MTLDynamicLibrary.hpp @@ -2,7 +2,7 @@ // // Metal/MTLDynamicLibrary.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLEvent.hpp b/vendor/metal-cpp/Metal/MTLEvent.hpp index 4ccd42d4..b0f78cd2 100644 --- a/vendor/metal-cpp/Metal/MTLEvent.hpp +++ b/vendor/metal-cpp/Metal/MTLEvent.hpp @@ -2,7 +2,7 @@ // // Metal/MTLEvent.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -60,9 +60,10 @@ class SharedEvent : public NS::Referencing class SharedEventHandle* newSharedEventHandle(); + bool waitUntilSignaledValue(uint64_t value, uint64_t milliseconds); + uint64_t signaledValue() const; void setSignaledValue(uint64_t signaledValue); - bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS); }; class SharedEventHandle : public NS::SecureCoding @@ -130,6 +131,12 @@ _MTL_INLINE MTL::SharedEventHandle* MTL::SharedEvent::newSharedEventHandle() return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEventHandle)); } +// method: waitUntilSignaledValue:timeoutMS: +_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t value, uint64_t milliseconds) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), value, milliseconds); +} + // property: signaledValue _MTL_INLINE uint64_t MTL::SharedEvent::signaledValue() const { @@ -141,11 +148,6 @@ _MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); } -// method: waitUntilSignaledValue -_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) { - return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS); -} - // static method: alloc _MTL_INLINE MTL::SharedEventHandle* MTL::SharedEventHandle::alloc() { diff --git a/vendor/metal-cpp/Metal/MTLFence.hpp b/vendor/metal-cpp/Metal/MTLFence.hpp index c8ef24ca..8d057aeb 100644 --- a/vendor/metal-cpp/Metal/MTLFence.hpp +++ b/vendor/metal-cpp/Metal/MTLFence.hpp @@ -2,7 +2,7 @@ // // Metal/MTLFence.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLFunctionConstantValues.hpp b/vendor/metal-cpp/Metal/MTLFunctionConstantValues.hpp index d23d98fc..b0de8e13 100644 --- a/vendor/metal-cpp/Metal/MTLFunctionConstantValues.hpp +++ b/vendor/metal-cpp/Metal/MTLFunctionConstantValues.hpp @@ -2,7 +2,7 @@ // // Metal/MTLFunctionConstantValues.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLFunctionDescriptor.hpp b/vendor/metal-cpp/Metal/MTLFunctionDescriptor.hpp index ec82d981..189c7b34 100644 --- a/vendor/metal-cpp/Metal/MTLFunctionDescriptor.hpp +++ b/vendor/metal-cpp/Metal/MTLFunctionDescriptor.hpp @@ -2,7 +2,7 @@ // // Metal/MTLFunctionDescriptor.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ _MTL_OPTIONS(NS::UInteger, FunctionOptions) { FunctionOptionNone = 0, FunctionOptionCompileToBinary = 1, FunctionOptionStoreFunctionInMetalScript = 2, + FunctionOptionFailOnBinaryArchiveMiss = 4, }; class FunctionDescriptor : public NS::Copying diff --git a/vendor/metal-cpp/Metal/MTLFunctionHandle.hpp b/vendor/metal-cpp/Metal/MTLFunctionHandle.hpp index 30f71f2e..43e1c844 100644 --- a/vendor/metal-cpp/Metal/MTLFunctionHandle.hpp +++ b/vendor/metal-cpp/Metal/MTLFunctionHandle.hpp @@ -2,7 +2,7 @@ // // Metal/MTLFunctionHandle.hpp // -// Copyright 2020-2023 Apple Inc. 
+// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLFunctionLog.hpp b/vendor/metal-cpp/Metal/MTLFunctionLog.hpp index 2a899525..2dc0cec4 100644 --- a/vendor/metal-cpp/Metal/MTLFunctionLog.hpp +++ b/vendor/metal-cpp/Metal/MTLFunctionLog.hpp @@ -2,7 +2,7 @@ // // Metal/MTLFunctionLog.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLFunctionStitching.hpp b/vendor/metal-cpp/Metal/MTLFunctionStitching.hpp index 0ae7d6a4..06458293 100644 --- a/vendor/metal-cpp/Metal/MTLFunctionStitching.hpp +++ b/vendor/metal-cpp/Metal/MTLFunctionStitching.hpp @@ -2,7 +2,7 @@ // // Metal/MTLFunctionStitching.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,6 +30,12 @@ namespace MTL { +_MTL_OPTIONS(NS::UInteger, StitchedLibraryOptions) { + StitchedLibraryOptionNone = 0, + StitchedLibraryOptionFailOnBinaryArchiveMiss = 1, + StitchedLibraryOptionStoreLibraryInMetalScript = 2, +}; + class FunctionStitchingAttribute : public NS::Referencing { public: @@ -114,6 +120,12 @@ class StitchedLibraryDescriptor : public NS::Copying NS::Array* functions() const; void setFunctions(const NS::Array* functions); + + NS::Array* binaryArchives() const; + void setBinaryArchives(const NS::Array* binaryArchives); + + MTL::StitchedLibraryOptions options() const; + void setOptions(MTL::StitchedLibraryOptions options); }; } @@ -305,3 +317,25 @@ _MTL_INLINE void MTL::StitchedLibraryDescriptor::setFunctions(const NS::Array* f { Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_), functions); } + +// property: binaryArchives +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +// property: options +_MTL_INLINE MTL::StitchedLibraryOptions MTL::StitchedLibraryDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setOptions(MTL::StitchedLibraryOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/vendor/metal-cpp/Metal/MTLHeaderBridge.hpp b/vendor/metal-cpp/Metal/MTLHeaderBridge.hpp index 91408838..6455013f 100644 --- a/vendor/metal-cpp/Metal/MTLHeaderBridge.hpp +++ b/vendor/metal-cpp/Metal/MTLHeaderBridge.hpp @@ -2,7 +2,7 @@ // // Metal/MTLHeaderBridge.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -51,6 +51,7 @@ _MTL_PRIVATE_DEF_CLS(MTLBufferLayoutDescriptorArray); _MTL_PRIVATE_DEF_CLS(MTLCaptureDescriptor); _MTL_PRIVATE_DEF_CLS(MTLCaptureManager); _MTL_PRIVATE_DEF_CLS(MTLCommandBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCommandQueueDescriptor); _MTL_PRIVATE_DEF_CLS(MTLCompileOptions); _MTL_PRIVATE_DEF_CLS(MTLComputePassDescriptor); _MTL_PRIVATE_DEF_CLS(MTLComputePassSampleBufferAttachmentDescriptor); @@ -74,6 +75,7 @@ _MTL_PRIVATE_DEF_CLS(MTLInstanceAccelerationStructureDescriptor); _MTL_PRIVATE_DEF_CLS(MTLIntersectionFunctionDescriptor); _MTL_PRIVATE_DEF_CLS(MTLIntersectionFunctionTableDescriptor); _MTL_PRIVATE_DEF_CLS(MTLLinkedFunctions); +_MTL_PRIVATE_DEF_CLS(MTLLogStateDescriptor); _MTL_PRIVATE_DEF_CLS(MTLMeshRenderPipelineDescriptor); _MTL_PRIVATE_DEF_CLS(MTLMotionKeyframeData); _MTL_PRIVATE_DEF_CLS(MTLPipelineBufferDescriptor); @@ -97,6 +99,7 @@ _MTL_PRIVATE_DEF_CLS(MTLRenderPipelineColorAttachmentDescriptorArray); _MTL_PRIVATE_DEF_CLS(MTLRenderPipelineDescriptor); _MTL_PRIVATE_DEF_CLS(MTLRenderPipelineFunctionsDescriptor); _MTL_PRIVATE_DEF_CLS(MTLRenderPipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTLResidencySetDescriptor); _MTL_PRIVATE_DEF_CLS(MTLResourceStatePassDescriptor); _MTL_PRIVATE_DEF_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptor); _MTL_PRIVATE_DEF_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptorArray); @@ -130,6 +133,7 @@ namespace MTL::Private::Protocol _MTL_PRIVATE_DEF_PRO(MTLAccelerationStructure); _MTL_PRIVATE_DEF_PRO(MTLAccelerationStructureCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLAllocation); _MTL_PRIVATE_DEF_PRO(MTLArgumentEncoder); _MTL_PRIVATE_DEF_PRO(MTLBinaryArchive); _MTL_PRIVATE_DEF_PRO(MTLBinding); @@ -169,11 +173,13 @@ _MTL_PRIVATE_DEF_PRO(MTLIndirectRenderCommand); _MTL_PRIVATE_DEF_PRO(MTLIntersectionFunctionTable); _MTL_PRIVATE_DEF_PRO(MTLLibrary); _MTL_PRIVATE_DEF_PRO(MTLLogContainer); +_MTL_PRIVATE_DEF_PRO(MTLLogState); _MTL_PRIVATE_DEF_PRO(MTLObjectPayloadBinding); _MTL_PRIVATE_DEF_PRO(MTLParallelRenderCommandEncoder); _MTL_PRIVATE_DEF_PRO(MTLRasterizationRateMap); _MTL_PRIVATE_DEF_PRO(MTLRenderCommandEncoder); _MTL_PRIVATE_DEF_PRO(MTLRenderPipelineState); +_MTL_PRIVATE_DEF_PRO(MTLResidencySet); _MTL_PRIVATE_DEF_PRO(MTLResource); _MTL_PRIVATE_DEF_PRO(MTLResourceStateCommandEncoder); _MTL_PRIVATE_DEF_PRO(MTLSamplerState); @@ -204,6 +210,10 @@ _MTL_PRIVATE_DEF_SEL(accelerationStructureSizesWithDescriptor_, "accelerationStructureSizesWithDescriptor:"); _MTL_PRIVATE_DEF_SEL(access, "access"); +_MTL_PRIVATE_DEF_SEL(addAllocation_, + "addAllocation:"); +_MTL_PRIVATE_DEF_SEL(addAllocations_count_, + "addAllocations:count:"); _MTL_PRIVATE_DEF_SEL(addBarrier, "addBarrier"); _MTL_PRIVATE_DEF_SEL(addCompletedHandler_, @@ -214,18 +224,32 @@ _MTL_PRIVATE_DEF_SEL(addDebugMarker_range_, "addDebugMarker:range:"); _MTL_PRIVATE_DEF_SEL(addFunctionWithDescriptor_library_error_, "addFunctionWithDescriptor:library:error:"); +_MTL_PRIVATE_DEF_SEL(addLibraryWithDescriptor_error_, + "addLibraryWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addLogHandler_, + "addLogHandler:"); +_MTL_PRIVATE_DEF_SEL(addMeshRenderPipelineFunctionsWithDescriptor_error_, + "addMeshRenderPipelineFunctionsWithDescriptor:error:"); _MTL_PRIVATE_DEF_SEL(addPresentedHandler_, "addPresentedHandler:"); _MTL_PRIVATE_DEF_SEL(addRenderPipelineFunctionsWithDescriptor_error_, "addRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addResidencySet_, + "addResidencySet:"); +_MTL_PRIVATE_DEF_SEL(addResidencySets_count_, + "addResidencySets:count:"); 
_MTL_PRIVATE_DEF_SEL(addScheduledHandler_, "addScheduledHandler:"); _MTL_PRIVATE_DEF_SEL(addTileRenderPipelineFunctionsWithDescriptor_error_, "addTileRenderPipelineFunctionsWithDescriptor:error:"); _MTL_PRIVATE_DEF_SEL(alignment, "alignment"); +_MTL_PRIVATE_DEF_SEL(allAllocations, + "allAllocations"); _MTL_PRIVATE_DEF_SEL(allocatedSize, "allocatedSize"); +_MTL_PRIVATE_DEF_SEL(allocationCount, + "allocationCount"); _MTL_PRIVATE_DEF_SEL(allowDuplicateIntersectionFunctionInvocation, "allowDuplicateIntersectionFunctionInvocation"); _MTL_PRIVATE_DEF_SEL(allowGPUOptimizedContents, @@ -304,6 +328,8 @@ _MTL_PRIVATE_DEF_SEL(bufferOffset, "bufferOffset"); _MTL_PRIVATE_DEF_SEL(bufferPointerType, "bufferPointerType"); +_MTL_PRIVATE_DEF_SEL(bufferSize, + "bufferSize"); _MTL_PRIVATE_DEF_SEL(bufferStructType, "bufferStructType"); _MTL_PRIVATE_DEF_SEL(buffers, @@ -362,6 +388,8 @@ _MTL_PRIVATE_DEF_SEL(constantDataAtIndex_, "constantDataAtIndex:"); _MTL_PRIVATE_DEF_SEL(constantValues, "constantValues"); +_MTL_PRIVATE_DEF_SEL(containsAllocation_, + "containsAllocation:"); _MTL_PRIVATE_DEF_SEL(contents, "contents"); _MTL_PRIVATE_DEF_SEL(controlDependencies, @@ -528,6 +556,8 @@ _MTL_PRIVATE_DEF_SEL(elementTextureReferenceType, "elementTextureReferenceType"); _MTL_PRIVATE_DEF_SEL(elementType, "elementType"); +_MTL_PRIVATE_DEF_SEL(enableLogging, + "enableLogging"); _MTL_PRIVATE_DEF_SEL(encodeSignalEvent_value_, "encodeSignalEvent:value:"); _MTL_PRIVATE_DEF_SEL(encodeWaitForEvent_value_, @@ -544,6 +574,8 @@ _MTL_PRIVATE_DEF_SEL(endOfFragmentSampleIndex, "endOfFragmentSampleIndex"); _MTL_PRIVATE_DEF_SEL(endOfVertexSampleIndex, "endOfVertexSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endResidency, + "endResidency"); _MTL_PRIVATE_DEF_SEL(enqueue, "enqueue"); _MTL_PRIVATE_DEF_SEL(enqueueBarrier, @@ -682,6 +714,8 @@ _MTL_PRIVATE_DEF_SEL(initWithSampleCount_, "initWithSampleCount:"); _MTL_PRIVATE_DEF_SEL(initWithSampleCount_horizontal_vertical_, "initWithSampleCount:horizontal:vertical:"); +_MTL_PRIVATE_DEF_SEL(initialCapacity, + "initialCapacity"); _MTL_PRIVATE_DEF_SEL(inputPrimitiveTopology, "inputPrimitiveTopology"); _MTL_PRIVATE_DEF_SEL(insertDebugCaptureBoundary, @@ -706,6 +740,8 @@ _MTL_PRIVATE_DEF_SEL(instanceDescriptorStride, "instanceDescriptorStride"); _MTL_PRIVATE_DEF_SEL(instanceDescriptorType, "instanceDescriptorType"); +_MTL_PRIVATE_DEF_SEL(instanceTransformationMatrixLayout, + "instanceTransformationMatrixLayout"); _MTL_PRIVATE_DEF_SEL(instancedAccelerationStructures, "instancedAccelerationStructures"); _MTL_PRIVATE_DEF_SEL(intersectionFunctionTableDescriptor, @@ -804,6 +840,8 @@ _MTL_PRIVATE_DEF_SEL(lodMaxClamp, "lodMaxClamp"); _MTL_PRIVATE_DEF_SEL(lodMinClamp, "lodMinClamp"); +_MTL_PRIVATE_DEF_SEL(logState, + "logState"); _MTL_PRIVATE_DEF_SEL(logs, "logs"); _MTL_PRIVATE_DEF_SEL(magFilter, @@ -814,6 +852,10 @@ _MTL_PRIVATE_DEF_SEL(mapPhysicalToScreenCoordinates_forLayer_, "mapPhysicalToScreenCoordinates:forLayer:"); _MTL_PRIVATE_DEF_SEL(mapScreenToPhysicalCoordinates_forLayer_, "mapScreenToPhysicalCoordinates:forLayer:"); +_MTL_PRIVATE_DEF_SEL(mathFloatingPointFunctions, + "mathFloatingPointFunctions"); +_MTL_PRIVATE_DEF_SEL(mathMode, + "mathMode"); _MTL_PRIVATE_DEF_SEL(maxAnisotropy, "maxAnisotropy"); _MTL_PRIVATE_DEF_SEL(maxArgumentBufferSamplerCount, @@ -926,6 +968,10 @@ _MTL_PRIVATE_DEF_SEL(motionTransformCountBuffer, "motionTransformCountBuffer"); _MTL_PRIVATE_DEF_SEL(motionTransformCountBufferOffset, "motionTransformCountBufferOffset"); +_MTL_PRIVATE_DEF_SEL(motionTransformStride, + 
"motionTransformStride"); +_MTL_PRIVATE_DEF_SEL(motionTransformType, + "motionTransformType"); _MTL_PRIVATE_DEF_SEL(moveTextureMappingsFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, "moveTextureMappingsFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); _MTL_PRIVATE_DEF_SEL(mutability, @@ -966,6 +1012,8 @@ _MTL_PRIVATE_DEF_SEL(newCaptureScopeWithDevice_, "newCaptureScopeWithDevice:"); _MTL_PRIVATE_DEF_SEL(newCommandQueue, "newCommandQueue"); +_MTL_PRIVATE_DEF_SEL(newCommandQueueWithDescriptor_, + "newCommandQueueWithDescriptor:"); _MTL_PRIVATE_DEF_SEL(newCommandQueueWithMaxCommandBufferCount_, "newCommandQueueWithMaxCommandBufferCount:"); _MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithAdditionalBinaryFunctions_error_, @@ -1044,6 +1092,8 @@ _MTL_PRIVATE_DEF_SEL(newLibraryWithStitchedDescriptor_error_, "newLibraryWithStitchedDescriptor:error:"); _MTL_PRIVATE_DEF_SEL(newLibraryWithURL_error_, "newLibraryWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newLogStateWithDescriptor_error_, + "newLogStateWithDescriptor:error:"); _MTL_PRIVATE_DEF_SEL(newRasterizationRateMapWithDescriptor_, "newRasterizationRateMapWithDescriptor:"); _MTL_PRIVATE_DEF_SEL(newRemoteBufferViewForDevice_, @@ -1068,6 +1118,8 @@ _MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithTileDescriptor_options_completion "newRenderPipelineStateWithTileDescriptor:options:completionHandler:"); _MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithTileDescriptor_options_reflection_error_, "newRenderPipelineStateWithTileDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newResidencySetWithDescriptor_error_, + "newResidencySetWithDescriptor:error:"); _MTL_PRIVATE_DEF_SEL(newSamplerStateWithDescriptor_, "newSamplerStateWithDescriptor:"); _MTL_PRIVATE_DEF_SEL(newScratchBufferWithMinimumSize_, @@ -1252,8 +1304,18 @@ _MTL_PRIVATE_DEF_SEL(remoteStorageBuffer, "remoteStorageBuffer"); _MTL_PRIVATE_DEF_SEL(remoteStorageTexture, "remoteStorageTexture"); +_MTL_PRIVATE_DEF_SEL(removeAllAllocations, + "removeAllAllocations"); _MTL_PRIVATE_DEF_SEL(removeAllDebugMarkers, "removeAllDebugMarkers"); +_MTL_PRIVATE_DEF_SEL(removeAllocation_, + "removeAllocation:"); +_MTL_PRIVATE_DEF_SEL(removeAllocations_count_, + "removeAllocations:count:"); +_MTL_PRIVATE_DEF_SEL(removeResidencySet_, + "removeResidencySet:"); +_MTL_PRIVATE_DEF_SEL(removeResidencySets_count_, + "removeResidencySets:count:"); _MTL_PRIVATE_DEF_SEL(renderCommandEncoder, "renderCommandEncoder"); _MTL_PRIVATE_DEF_SEL(renderCommandEncoderWithDescriptor_, @@ -1270,6 +1332,8 @@ _MTL_PRIVATE_DEF_SEL(replaceRegion_mipmapLevel_slice_withBytes_bytesPerRow_bytes "replaceRegion:mipmapLevel:slice:withBytes:bytesPerRow:bytesPerImage:"); _MTL_PRIVATE_DEF_SEL(replaceRegion_mipmapLevel_withBytes_bytesPerRow_, "replaceRegion:mipmapLevel:withBytes:bytesPerRow:"); +_MTL_PRIVATE_DEF_SEL(requestResidency, + "requestResidency"); _MTL_PRIVATE_DEF_SEL(required, "required"); _MTL_PRIVATE_DEF_SEL(reset, @@ -1394,6 +1458,8 @@ _MTL_PRIVATE_DEF_SEL(setBufferOffset_atIndex_, "setBufferOffset:atIndex:"); _MTL_PRIVATE_DEF_SEL(setBufferOffset_attributeStride_atIndex_, "setBufferOffset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferSize_, + "setBufferSize:"); _MTL_PRIVATE_DEF_SEL(setBuffers_offsets_attributeStrides_withRange_, "setBuffers:offsets:attributeStrides:withRange:"); _MTL_PRIVATE_DEF_SEL(setBuffers_offsets_withRange_, @@ -1508,6 +1574,8 @@ 
_MTL_PRIVATE_DEF_SEL(setDestinationRGBBlendFactor_, "setDestinationRGBBlendFactor:"); _MTL_PRIVATE_DEF_SEL(setDispatchType_, "setDispatchType:"); +_MTL_PRIVATE_DEF_SEL(setEnableLogging_, + "setEnableLogging:"); _MTL_PRIVATE_DEF_SEL(setEndOfEncoderSampleIndex_, "setEndOfEncoderSampleIndex:"); _MTL_PRIVATE_DEF_SEL(setEndOfFragmentSampleIndex_, @@ -1604,6 +1672,8 @@ _MTL_PRIVATE_DEF_SEL(setInheritBuffers_, "setInheritBuffers:"); _MTL_PRIVATE_DEF_SEL(setInheritPipelineState_, "setInheritPipelineState:"); +_MTL_PRIVATE_DEF_SEL(setInitialCapacity_, + "setInitialCapacity:"); _MTL_PRIVATE_DEF_SEL(setInputPrimitiveTopology_, "setInputPrimitiveTopology:"); _MTL_PRIVATE_DEF_SEL(setInsertLibraries_, @@ -1624,6 +1694,8 @@ _MTL_PRIVATE_DEF_SEL(setInstanceDescriptorStride_, "setInstanceDescriptorStride:"); _MTL_PRIVATE_DEF_SEL(setInstanceDescriptorType_, "setInstanceDescriptorType:"); +_MTL_PRIVATE_DEF_SEL(setInstanceTransformationMatrixLayout_, + "setInstanceTransformationMatrixLayout:"); _MTL_PRIVATE_DEF_SEL(setInstancedAccelerationStructures_, "setInstancedAccelerationStructures:"); _MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTable_atBufferIndex_, @@ -1662,8 +1734,14 @@ _MTL_PRIVATE_DEF_SEL(setLodMaxClamp_, "setLodMaxClamp:"); _MTL_PRIVATE_DEF_SEL(setLodMinClamp_, "setLodMinClamp:"); +_MTL_PRIVATE_DEF_SEL(setLogState_, + "setLogState:"); _MTL_PRIVATE_DEF_SEL(setMagFilter_, "setMagFilter:"); +_MTL_PRIVATE_DEF_SEL(setMathFloatingPointFunctions_, + "setMathFloatingPointFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMathMode_, + "setMathMode:"); _MTL_PRIVATE_DEF_SEL(setMaxAnisotropy_, "setMaxAnisotropy:"); _MTL_PRIVATE_DEF_SEL(setMaxCallStackDepth_, @@ -1758,6 +1836,10 @@ _MTL_PRIVATE_DEF_SEL(setMotionTransformCountBuffer_, "setMotionTransformCountBuffer:"); _MTL_PRIVATE_DEF_SEL(setMotionTransformCountBufferOffset_, "setMotionTransformCountBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformStride_, + "setMotionTransformStride:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformType_, + "setMotionTransformType:"); _MTL_PRIVATE_DEF_SEL(setMutability_, "setMutability:"); _MTL_PRIVATE_DEF_SEL(setName_, @@ -1816,6 +1898,8 @@ _MTL_PRIVATE_DEF_SEL(setOutputNode_, "setOutputNode:"); _MTL_PRIVATE_DEF_SEL(setOutputURL_, "setOutputURL:"); +_MTL_PRIVATE_DEF_SEL(setOwnerWithIdentity_, + "setOwnerWithIdentity:"); _MTL_PRIVATE_DEF_SEL(setPayloadMemoryLength_, "setPayloadMemoryLength:"); _MTL_PRIVATE_DEF_SEL(setPixelFormat_, @@ -1914,13 +1998,12 @@ _MTL_PRIVATE_DEF_SEL(setSegmentControlPointCount_, "setSegmentControlPointCount:"); _MTL_PRIVATE_DEF_SEL(setSegmentCount_, "setSegmentCount:"); +_MTL_PRIVATE_DEF_SEL(setShaderValidation_, + "setShaderValidation:"); _MTL_PRIVATE_DEF_SEL(setShouldMaximizeConcurrentCompilation_, "setShouldMaximizeConcurrentCompilation:"); _MTL_PRIVATE_DEF_SEL(setSignaledValue_, "setSignaledValue:"); -_MTL_PRIVATE_DEF_SEL( - waitUntilSignaledValue_timeoutMS_, - "waitUntilSignaledValue:timeoutMS:"); _MTL_PRIVATE_DEF_SEL(setSize_, "setSize:"); _MTL_PRIVATE_DEF_SEL(setSlice_, @@ -2069,6 +2152,8 @@ _MTL_PRIVATE_DEF_SEL(setTransformationMatrixBuffer_, "setTransformationMatrixBuffer:"); _MTL_PRIVATE_DEF_SEL(setTransformationMatrixBufferOffset_, "setTransformationMatrixBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixLayout_, + "setTransformationMatrixLayout:"); _MTL_PRIVATE_DEF_SEL(setTriangleCount_, "setTriangleCount:"); _MTL_PRIVATE_DEF_SEL(setTriangleFillMode_, @@ -2159,6 +2244,8 @@ _MTL_PRIVATE_DEF_SEL(setWidth_, "setWidth:"); _MTL_PRIVATE_DEF_SEL(setWriteMask_, "setWriteMask:"); 
+_MTL_PRIVATE_DEF_SEL(shaderValidation, + "shaderValidation"); _MTL_PRIVATE_DEF_SEL(sharedCaptureManager, "sharedCaptureManager"); _MTL_PRIVATE_DEF_SEL(shouldMaximizeConcurrentCompilation, @@ -2357,6 +2444,8 @@ _MTL_PRIVATE_DEF_SEL(transformationMatrixBuffer, "transformationMatrixBuffer"); _MTL_PRIVATE_DEF_SEL(transformationMatrixBufferOffset, "transformationMatrixBufferOffset"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixLayout, + "transformationMatrixLayout"); _MTL_PRIVATE_DEF_SEL(triangleCount, "triangleCount"); _MTL_PRIVATE_DEF_SEL(tryCancel, @@ -2385,6 +2474,10 @@ _MTL_PRIVATE_DEF_SEL(useHeaps_count_, "useHeaps:count:"); _MTL_PRIVATE_DEF_SEL(useHeaps_count_stages_, "useHeaps:count:stages:"); +_MTL_PRIVATE_DEF_SEL(useResidencySet_, + "useResidencySet:"); +_MTL_PRIVATE_DEF_SEL(useResidencySets_count_, + "useResidencySets:count:"); _MTL_PRIVATE_DEF_SEL(useResource_usage_, "useResource:usage:"); _MTL_PRIVATE_DEF_SEL(useResource_usage_stages_, @@ -2439,6 +2532,8 @@ _MTL_PRIVATE_DEF_SEL(waitUntilCompleted, "waitUntilCompleted"); _MTL_PRIVATE_DEF_SEL(waitUntilScheduled, "waitUntilScheduled"); +_MTL_PRIVATE_DEF_SEL(waitUntilSignaledValue_timeoutMS_, + "waitUntilSignaledValue:timeoutMS:"); _MTL_PRIVATE_DEF_SEL(width, "width"); _MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_, diff --git a/vendor/metal-cpp/Metal/MTLHeap.hpp b/vendor/metal-cpp/Metal/MTLHeap.hpp index bb5b8361..dc3bfb58 100644 --- a/vendor/metal-cpp/Metal/MTLHeap.hpp +++ b/vendor/metal-cpp/Metal/MTLHeap.hpp @@ -2,7 +2,7 @@ // // Metal/MTLHeap.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ #include +#include "MTLAllocation.hpp" #include "MTLDevice.hpp" #include "MTLHeap.hpp" #include "MTLResource.hpp" @@ -67,7 +68,7 @@ class HeapDescriptor : public NS::Copying void setType(MTL::HeapType type); }; -class Heap : public NS::Referencing +class Heap : public NS::Referencing { public: NS::String* label() const; diff --git a/vendor/metal-cpp/Metal/MTLIOCommandBuffer.hpp b/vendor/metal-cpp/Metal/MTLIOCommandBuffer.hpp index fb28f3e7..5ad86c73 100644 --- a/vendor/metal-cpp/Metal/MTLIOCommandBuffer.hpp +++ b/vendor/metal-cpp/Metal/MTLIOCommandBuffer.hpp @@ -2,7 +2,7 @@ // // Metal/MTLIOCommandBuffer.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLIOCommandQueue.hpp b/vendor/metal-cpp/Metal/MTLIOCommandQueue.hpp index b5f5d1c2..e76212f5 100644 --- a/vendor/metal-cpp/Metal/MTLIOCommandQueue.hpp +++ b/vendor/metal-cpp/Metal/MTLIOCommandQueue.hpp @@ -2,7 +2,7 @@ // // Metal/MTLIOCommandQueue.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -105,6 +105,8 @@ class IOFileHandle : public NS::Referencing } +_MTL_PRIVATE_DEF_WEAK_CONST(NS::ErrorDomain, IOErrorDomain); + // method: enqueueBarrier _MTL_INLINE void MTL::IOCommandQueue::enqueueBarrier() { diff --git a/vendor/metal-cpp/Metal/MTLIOCompressor.hpp b/vendor/metal-cpp/Metal/MTLIOCompressor.hpp index ad9e92d1..71be4403 100644 --- a/vendor/metal-cpp/Metal/MTLIOCompressor.hpp +++ b/vendor/metal-cpp/Metal/MTLIOCompressor.hpp @@ -2,7 +2,7 @@ // // Metal/MTLIOCompressor.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp b/vendor/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp index ce7ba5d1..03686644 100644 --- a/vendor/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp +++ b/vendor/metal-cpp/Metal/MTLIndirectCommandBuffer.hpp @@ -2,7 +2,7 @@ // // Metal/MTLIndirectCommandBuffer.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp index ff973ef0..3f921ef5 100644 --- a/vendor/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLIndirectCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLIndirectCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp b/vendor/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp index e1c0911e..b62b4930 100644 --- a/vendor/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp +++ b/vendor/metal-cpp/Metal/MTLIntersectionFunctionTable.hpp @@ -2,7 +2,7 @@ // // Metal/MTLIntersectionFunctionTable.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLLibrary.hpp b/vendor/metal-cpp/Metal/MTLLibrary.hpp index 1933ff01..129c73f5 100644 --- a/vendor/metal-cpp/Metal/MTLLibrary.hpp +++ b/vendor/metal-cpp/Metal/MTLLibrary.hpp @@ -2,7 +2,7 @@ // // Metal/MTLLibrary.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -147,6 +147,7 @@ _MTL_ENUM(NS::UInteger, LanguageVersion) { LanguageVersion2_4 = 131076, LanguageVersion3_0 = 196608, LanguageVersion3_1 = 196609, + LanguageVersion3_2 = 196610, }; _MTL_ENUM(NS::Integer, LibraryType) { @@ -164,45 +165,65 @@ _MTL_ENUM(NS::Integer, CompileSymbolVisibility) { CompileSymbolVisibilityHidden = 1, }; +_MTL_ENUM(NS::Integer, MathMode) { + MathModeSafe = 0, + MathModeRelaxed = 1, + MathModeFast = 2, +}; + +_MTL_ENUM(NS::Integer, MathFloatingPointFunctions) { + MathFloatingPointFunctionsFast = 0, + MathFloatingPointFunctionsPrecise = 1, +}; + class CompileOptions : public NS::Copying { public: - static class CompileOptions* alloc(); + static class CompileOptions* alloc(); + + class CompileOptions* init(); + + NS::Dictionary* preprocessorMacros() const; + void setPreprocessorMacros(const NS::Dictionary* preprocessorMacros); + + bool fastMathEnabled() const; + void setFastMathEnabled(bool fastMathEnabled); - class CompileOptions* init(); + MTL::MathMode mathMode() const; + void setMathMode(MTL::MathMode mathMode); - NS::Dictionary* preprocessorMacros() const; - void setPreprocessorMacros(const NS::Dictionary* preprocessorMacros); + MTL::MathFloatingPointFunctions mathFloatingPointFunctions() const; + void setMathFloatingPointFunctions(MTL::MathFloatingPointFunctions mathFloatingPointFunctions); - bool fastMathEnabled() const; - void setFastMathEnabled(bool fastMathEnabled); + MTL::LanguageVersion languageVersion() const; + void setLanguageVersion(MTL::LanguageVersion languageVersion); - MTL::LanguageVersion languageVersion() const; - void setLanguageVersion(MTL::LanguageVersion languageVersion); + MTL::LibraryType libraryType() const; + void setLibraryType(MTL::LibraryType libraryType); - MTL::LibraryType libraryType() const; - void setLibraryType(MTL::LibraryType libraryType); + NS::String* installName() const; + void setInstallName(const NS::String* installName); - NS::String* installName() const; - void setInstallName(const NS::String* installName); + NS::Array* libraries() const; + void setLibraries(const NS::Array* libraries); - NS::Array* libraries() const; - void setLibraries(const NS::Array* libraries); + bool preserveInvariance() const; + void setPreserveInvariance(bool preserveInvariance); - bool preserveInvariance() const; - void setPreserveInvariance(bool preserveInvariance); + MTL::LibraryOptimizationLevel optimizationLevel() const; + void setOptimizationLevel(MTL::LibraryOptimizationLevel optimizationLevel); - MTL::LibraryOptimizationLevel optimizationLevel() const; - void setOptimizationLevel(MTL::LibraryOptimizationLevel optimizationLevel); + MTL::CompileSymbolVisibility compileSymbolVisibility() const; + void setCompileSymbolVisibility(MTL::CompileSymbolVisibility compileSymbolVisibility); - MTL::CompileSymbolVisibility compileSymbolVisibility() const; - void setCompileSymbolVisibility(MTL::CompileSymbolVisibility compileSymbolVisibility); + bool allowReferencingUndefinedSymbols() const; + void setAllowReferencingUndefinedSymbols(bool allowReferencingUndefinedSymbols); - bool allowReferencingUndefinedSymbols() const; - void setAllowReferencingUndefinedSymbols(bool allowReferencingUndefinedSymbols); + NS::UInteger maxTotalThreadsPerThreadgroup() const; + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); - NS::UInteger maxTotalThreadsPerThreadgroup() const; - void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + bool enableLogging() const; + void setEnableLogging(bool enableLogging); }; 
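A minimal usage sketch of the new CompileOptions surface (not from the vendored sources; it assumes a valid MTL::Device* and an NS::String* with shader source, and uses the long-standing newLibrary(source, options, error) accessor):

#include <Metal/Metal.hpp>

// Sketch: compile a library with the new macOS 15 / iOS 18 CompileOptions knobs.
MTL::Library* compileWithRelaxedMath(MTL::Device* device, NS::String* source)
{
    MTL::CompileOptions* options = MTL::CompileOptions::alloc()->init();
    options->setLanguageVersion(MTL::LanguageVersion3_2);   // new enum value in this update
    options->setMathMode(MTL::MathModeRelaxed);             // successor to setFastMathEnabled(true)
    options->setMathFloatingPointFunctions(MTL::MathFloatingPointFunctionsPrecise);
    options->setEnableLogging(true);                        // lets shader logging reach a log state

    NS::Error* error = nullptr;
    MTL::Library* library = device->newLibrary(source, options, &error);
    options->release();
    return library;                                         // nullptr on failure; inspect error
}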
_MTL_ENUM(NS::UInteger, LibraryError) { @@ -494,6 +515,28 @@ _MTL_INLINE void MTL::CompileOptions::setFastMathEnabled(bool fastMathEnabled) Object::sendMessage(this, _MTL_PRIVATE_SEL(setFastMathEnabled_), fastMathEnabled); } +// property: mathMode +_MTL_INLINE MTL::MathMode MTL::CompileOptions::mathMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mathMode)); +} + +_MTL_INLINE void MTL::CompileOptions::setMathMode(MTL::MathMode mathMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMathMode_), mathMode); +} + +// property: mathFloatingPointFunctions +_MTL_INLINE MTL::MathFloatingPointFunctions MTL::CompileOptions::mathFloatingPointFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mathFloatingPointFunctions)); +} + +_MTL_INLINE void MTL::CompileOptions::setMathFloatingPointFunctions(MTL::MathFloatingPointFunctions mathFloatingPointFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMathFloatingPointFunctions_), mathFloatingPointFunctions); +} + // property: languageVersion _MTL_INLINE MTL::LanguageVersion MTL::CompileOptions::languageVersion() const { @@ -593,6 +636,17 @@ _MTL_INLINE void MTL::CompileOptions::setMaxTotalThreadsPerThreadgroup(NS::UInte Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); } +// property: enableLogging +_MTL_INLINE bool MTL::CompileOptions::enableLogging() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(enableLogging)); +} + +_MTL_INLINE void MTL::CompileOptions::setEnableLogging(bool enableLogging) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEnableLogging_), enableLogging); +} + _MTL_INLINE void MTL::Library::newFunction(const NS::String* pFunctionName, const FunctionConstantValues* pConstantValues, const std::function& completionHandler) { __block std::function blockCompletionHandler = completionHandler; diff --git a/vendor/metal-cpp/Metal/MTLLinkedFunctions.hpp b/vendor/metal-cpp/Metal/MTLLinkedFunctions.hpp index 89ee8759..c5cf556d 100644 --- a/vendor/metal-cpp/Metal/MTLLinkedFunctions.hpp +++ b/vendor/metal-cpp/Metal/MTLLinkedFunctions.hpp @@ -2,7 +2,7 @@ // // Metal/MTLLinkedFunctions.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLLogState.hpp b/vendor/metal-cpp/Metal/MTLLogState.hpp new file mode 100644 index 00000000..ede37e86 --- /dev/null +++ b/vendor/metal-cpp/Metal/MTLLogState.hpp @@ -0,0 +1,111 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLogState.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +#include + +#include "MTLLogState.hpp" + +namespace MTL +{ +_MTL_ENUM(NS::Integer, LogLevel) { + LogLevelUndefined = 0, + LogLevelDebug = 1, + LogLevelInfo = 2, + LogLevelNotice = 3, + LogLevelError = 4, + LogLevelFault = 5, +}; + +class LogState : public NS::Referencing +{ +public: + void addLogHandler(void (^block)(NS::String*, NS::String*, MTL::LogLevel, NS::String*)); +}; + +class LogStateDescriptor : public NS::Copying +{ +public: + static class LogStateDescriptor* alloc(); + + class LogStateDescriptor* init(); + + MTL::LogLevel level() const; + void setLevel(MTL::LogLevel level); + + NS::Integer bufferSize() const; + void setBufferSize(NS::Integer bufferSize); +}; + +_MTL_CONST(NS::ErrorDomain, LogStateErrorDomain); + +_MTL_ENUM(NS::UInteger, LogStateError) { + LogStateErrorInvalidSize = 1, + LogStateErrorInvalid = 2, +}; + +} + +_MTL_PRIVATE_DEF_WEAK_CONST(NS::ErrorDomain, LogStateErrorDomain); + +// method: addLogHandler: +_MTL_INLINE void MTL::LogState::addLogHandler(void (^block)(NS::String*, NS::String*, MTL::LogLevel, NS::String*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addLogHandler_), block); +} + +// static method: alloc +_MTL_INLINE MTL::LogStateDescriptor* MTL::LogStateDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLogStateDescriptor)); +} + +// method: init +_MTL_INLINE MTL::LogStateDescriptor* MTL::LogStateDescriptor::init() +{ + return NS::Object::init(); +} + +// property: level +_MTL_INLINE MTL::LogLevel MTL::LogStateDescriptor::level() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(level)); +} + +_MTL_INLINE void MTL::LogStateDescriptor::setLevel(MTL::LogLevel level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevel_), level); +} + +// property: bufferSize +_MTL_INLINE NS::Integer MTL::LogStateDescriptor::bufferSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferSize)); +} + +_MTL_INLINE void MTL::LogStateDescriptor::setBufferSize(NS::Integer bufferSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferSize_), bufferSize); +} diff --git a/vendor/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp index d6d6ad8c..4214be75 100644 --- a/vendor/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLParallelRenderCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLParallelRenderCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLPipeline.hpp b/vendor/metal-cpp/Metal/MTLPipeline.hpp index 1b6859ca..3971fb3c 100644 --- a/vendor/metal-cpp/Metal/MTLPipeline.hpp +++ b/vendor/metal-cpp/Metal/MTLPipeline.hpp @@ -2,7 +2,7 @@ // // Metal/MTLPipeline.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
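A sketch of how the new log-state objects are wired up (not part of the upstream header). MTL::Device::newLogState() is assumed here from the newLogStateWithDescriptor:error: selector registered earlier in this diff; the descriptor and handler calls come from MTLLogState.hpp above, and the handler uses a Clang block as metal-cpp does elsewhere:

#include <Metal/Metal.hpp>
#include <cstdio>

// Sketch: route GPU shader log messages to stdout via the new MTLLogState API.
MTL::LogState* makeLogState(MTL::Device* device)
{
    MTL::LogStateDescriptor* desc = MTL::LogStateDescriptor::alloc()->init();
    desc->setLevel(MTL::LogLevelInfo);
    desc->setBufferSize(16 * 1024);   // bytes reserved for buffered shader log output

    NS::Error* error = nullptr;
    MTL::LogState* logState = device->newLogState(desc, &error); // assumed accessor, see note above
    if (logState)
    {
        logState->addLogHandler(^(NS::String* subsystem, NS::String* category,
                                  MTL::LogLevel level, NS::String* message) {
            std::printf("[shader log] %s\n", message->utf8String());
        });
    }
    desc->release();
    return logState;
}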
@@ -36,6 +36,12 @@ _MTL_ENUM(NS::UInteger, Mutability) { MutabilityImmutable = 2, }; +_MTL_ENUM(NS::Integer, ShaderValidation) { + ShaderValidationDefault = 0, + ShaderValidationEnabled = 1, + ShaderValidationDisabled = 2, +}; + class PipelineBufferDescriptor : public NS::Copying { public: diff --git a/vendor/metal-cpp/Metal/MTLPixelFormat.hpp b/vendor/metal-cpp/Metal/MTLPixelFormat.hpp index 8e28a6be..d8cf3a9a 100644 --- a/vendor/metal-cpp/Metal/MTLPixelFormat.hpp +++ b/vendor/metal-cpp/Metal/MTLPixelFormat.hpp @@ -2,7 +2,7 @@ // // Metal/MTLPixelFormat.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLPrivate.hpp b/vendor/metal-cpp/Metal/MTLPrivate.hpp index e72e6354..34a65915 100644 --- a/vendor/metal-cpp/Metal/MTLPrivate.hpp +++ b/vendor/metal-cpp/Metal/MTLPrivate.hpp @@ -2,7 +2,7 @@ // // Metal/MTLPrivate.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ namespace MTL::Private } } // MTL::Private -#if defined(__MAC_10_16) || defined(__MAC_11_0) || defined(__MAC_12_0) || defined(__MAC_13_0) || defined(__MAC_14_0) || defined(__IPHONE_14_0) || defined(__IPHONE_15_0) || defined(__IPHONE_16_0) || defined(__IPHONE_17_0) || defined(__TVOS_14_0) || defined(__TVOS_15_0) || defined(__TVOS_16_0) || defined(__TVOS_17_0) +#if defined(__MAC_15_0) || defined(__IPHONE_18_0) || defined(__TVOS_18_0) #define _MTL_PRIVATE_DEF_STR(type, symbol) \ _MTL_EXTERN type const MTL##symbol _MTL_PRIVATE_IMPORT; \ @@ -97,7 +97,7 @@ namespace MTL::Private #define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) _MTL_PRIVATE_DEF_CONST(type, symbol) -#endif // defined(__MAC_10_16) || defined(__MAC_11_0) || defined(__MAC_12_0) || defined(__MAC_13_0) || defined(__MAC_14_0) || defined(__IPHONE_14_0) || defined(__IPHONE_15_0) || defined(__IPHONE_16_0) || defined(__IPHONE_17_0) || defined(__TVOS_14_0) || defined(__TVOS_15_0) || defined(__TVOS_16_0) || defined(__TVOS_17_0) +#endif // defined(__MAC_15_0) || defined(__IPHONE_18_0) || defined(__TVOS_18_0) #else diff --git a/vendor/metal-cpp/Metal/MTLRasterizationRate.hpp b/vendor/metal-cpp/Metal/MTLRasterizationRate.hpp index cac23392..5bfe3803 100644 --- a/vendor/metal-cpp/Metal/MTLRasterizationRate.hpp +++ b/vendor/metal-cpp/Metal/MTLRasterizationRate.hpp @@ -2,7 +2,7 @@ // // Metal/MTLRasterizationRate.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLRenderCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLRenderCommandEncoder.hpp index a3b0c420..35ae6dd1 100644 --- a/vendor/metal-cpp/Metal/MTLRenderCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLRenderCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLRenderCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
diff --git a/vendor/metal-cpp/Metal/MTLRenderPass.hpp b/vendor/metal-cpp/Metal/MTLRenderPass.hpp index cb47ceab..a5dce59d 100644 --- a/vendor/metal-cpp/Metal/MTLRenderPass.hpp +++ b/vendor/metal-cpp/Metal/MTLRenderPass.hpp @@ -2,7 +2,7 @@ // // Metal/MTLRenderPass.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLRenderPipeline.hpp b/vendor/metal-cpp/Metal/MTLRenderPipeline.hpp index c4e501d2..44f9a24e 100644 --- a/vendor/metal-cpp/Metal/MTLRenderPipeline.hpp +++ b/vendor/metal-cpp/Metal/MTLRenderPipeline.hpp @@ -2,7 +2,7 @@ // // Metal/MTLRenderPipeline.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ #include +#include "MTLPipeline.hpp" #include "MTLPixelFormat.hpp" #include "MTLRenderCommandEncoder.hpp" #include "MTLRenderPipeline.hpp" @@ -266,6 +267,9 @@ class RenderPipelineDescriptor : public NS::Copying void setMaxFragmentCallStackDepth(NS::UInteger maxFragmentCallStackDepth); void reset(); + + MTL::ShaderValidation shaderValidation() const; + void setShaderValidation(MTL::ShaderValidation shaderValidation); }; class RenderPipelineFunctionsDescriptor : public NS::Copying @@ -321,6 +325,8 @@ class RenderPipelineState : public NS::Referencing class IntersectionFunctionTable* newIntersectionFunctionTable(const class IntersectionFunctionTableDescriptor* descriptor, MTL::RenderStages stage); class RenderPipelineState* newRenderPipelineState(const class RenderPipelineFunctionsDescriptor* additionalBinaryFunctions, NS::Error** error); + + MTL::ShaderValidation shaderValidation() const; }; class RenderPipelineColorAttachmentDescriptorArray : public NS::Referencing @@ -400,6 +406,9 @@ class TileRenderPipelineDescriptor : public NS::Copying @@ -471,6 +480,9 @@ class MeshRenderPipelineDescriptor : public NS::Copying(this, _MTL_PRIVATE_SEL(reset)); } +// property: shaderValidation +_MTL_INLINE MTL::ShaderValidation MTL::RenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + // static method: alloc _MTL_INLINE MTL::RenderPipelineFunctionsDescriptor* MTL::RenderPipelineFunctionsDescriptor::alloc() { @@ -1169,6 +1195,12 @@ _MTL_INLINE MTL::RenderPipelineState* MTL::RenderPipelineState::newRenderPipelin return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithAdditionalBinaryFunctions_error_), additionalBinaryFunctions, error); } +// property: shaderValidation +_MTL_INLINE MTL::ShaderValidation MTL::RenderPipelineState::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + // static method: alloc _MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineColorAttachmentDescriptorArray::alloc() { @@ -1380,6 +1412,17 @@ _MTL_INLINE void MTL::TileRenderPipelineDescriptor::reset() Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); } +// property: shaderValidation +_MTL_INLINE MTL::ShaderValidation MTL::TileRenderPipelineDescriptor::shaderValidation() const 
+{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + // static method: alloc _MTL_INLINE MTL::MeshRenderPipelineDescriptor* MTL::MeshRenderPipelineDescriptor::alloc() { @@ -1614,6 +1657,17 @@ _MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setSupportIndirectCommandBuf Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); } +// property: binaryArchives +_MTL_INLINE NS::Array* MTL::MeshRenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + // property: objectLinkedFunctions _MTL_INLINE MTL::LinkedFunctions* MTL::MeshRenderPipelineDescriptor::objectLinkedFunctions() const { @@ -1652,3 +1706,14 @@ _MTL_INLINE void MTL::MeshRenderPipelineDescriptor::reset() { Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); } + +// property: shaderValidation +_MTL_INLINE MTL::ShaderValidation MTL::MeshRenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} diff --git a/vendor/metal-cpp/Metal/MTLResidencySet.hpp b/vendor/metal-cpp/Metal/MTLResidencySet.hpp new file mode 100644 index 00000000..d7f84f90 --- /dev/null +++ b/vendor/metal-cpp/Metal/MTLResidencySet.hpp @@ -0,0 +1,195 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResidencySet.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
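As an illustration of the per-pipeline ShaderValidation knob added above (not from the vendored sources; it assumes an existing device and vertex/fragment functions):

#include <Metal/Metal.hpp>

// Sketch: request shader validation for a single pipeline and read back what was applied.
MTL::RenderPipelineState* makeValidatedPipeline(MTL::Device* device,
                                                MTL::Function* vertexFn,
                                                MTL::Function* fragmentFn)
{
    MTL::RenderPipelineDescriptor* desc = MTL::RenderPipelineDescriptor::alloc()->init();
    desc->setVertexFunction(vertexFn);
    desc->setFragmentFunction(fragmentFn);
    desc->colorAttachments()->object(0)->setPixelFormat(MTL::PixelFormatBGRA8Unorm);
    desc->setShaderValidation(MTL::ShaderValidationEnabled);   // per-PSO override of the global default

    NS::Error* error = nullptr;
    MTL::RenderPipelineState* pso = device->newRenderPipelineState(desc, &error);
    if (pso && pso->shaderValidation() != MTL::ShaderValidationEnabled)
    {
        // the read-only property reports whether validation actually took effect
    }
    desc->release();
    return pso;
}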
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +#include + +namespace MTL +{ +class ResidencySetDescriptor : public NS::Copying +{ +public: + static class ResidencySetDescriptor* alloc(); + + class ResidencySetDescriptor* init(); + + NS::String* label() const; + void setLabel(const NS::String* label); + + NS::UInteger initialCapacity() const; + void setInitialCapacity(NS::UInteger initialCapacity); +}; + +class ResidencySet : public NS::Referencing +{ +public: + class Device* device() const; + + NS::String* label() const; + + uint64_t allocatedSize() const; + + void requestResidency(); + + void endResidency(); + + void addAllocation(const class Allocation* allocation); + + void addAllocations(const class Allocation* const allocations[], NS::UInteger count); + + void removeAllocation(const class Allocation* allocation); + + void removeAllocations(const class Allocation* const allocations[], NS::UInteger count); + + void removeAllAllocations(); + + bool containsAllocation(const class Allocation* anAllocation); + + NS::Array* allAllocations() const; + + NS::UInteger allocationCount() const; + + void commit(); +}; + +} + +// static method: alloc +_MTL_INLINE MTL::ResidencySetDescriptor* MTL::ResidencySetDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResidencySetDescriptor)); +} + +// method: init +_MTL_INLINE MTL::ResidencySetDescriptor* MTL::ResidencySetDescriptor::init() +{ + return NS::Object::init(); +} + +// property: label +_MTL_INLINE NS::String* MTL::ResidencySetDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::ResidencySetDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +// property: initialCapacity +_MTL_INLINE NS::UInteger MTL::ResidencySetDescriptor::initialCapacity() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initialCapacity)); +} + +_MTL_INLINE void MTL::ResidencySetDescriptor::setInitialCapacity(NS::UInteger initialCapacity) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInitialCapacity_), initialCapacity); +} + +// property: device +_MTL_INLINE MTL::Device* MTL::ResidencySet::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +// property: label +_MTL_INLINE NS::String* MTL::ResidencySet::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +// property: allocatedSize +_MTL_INLINE uint64_t MTL::ResidencySet::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +// method: requestResidency +_MTL_INLINE void MTL::ResidencySet::requestResidency() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(requestResidency)); +} + +// method: endResidency +_MTL_INLINE void MTL::ResidencySet::endResidency() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endResidency)); +} + +// method: addAllocation: +_MTL_INLINE void MTL::ResidencySet::addAllocation(const MTL::Allocation* allocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addAllocation_), allocation); +} + +// method: addAllocations:count: +_MTL_INLINE void MTL::ResidencySet::addAllocations(const MTL::Allocation* const allocations[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addAllocations_count_), 
allocations, count); +} + +// method: removeAllocation: +_MTL_INLINE void MTL::ResidencySet::removeAllocation(const MTL::Allocation* allocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllocation_), allocation); +} + +// method: removeAllocations:count: +_MTL_INLINE void MTL::ResidencySet::removeAllocations(const MTL::Allocation* const allocations[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllocations_count_), allocations, count); +} + +// method: removeAllAllocations +_MTL_INLINE void MTL::ResidencySet::removeAllAllocations() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllAllocations)); +} + +// method: containsAllocation: +_MTL_INLINE bool MTL::ResidencySet::containsAllocation(const MTL::Allocation* anAllocation) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(containsAllocation_), anAllocation); +} + +// property: allAllocations +_MTL_INLINE NS::Array* MTL::ResidencySet::allAllocations() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allAllocations)); +} + +// property: allocationCount +_MTL_INLINE NS::UInteger MTL::ResidencySet::allocationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocationCount)); +} + +// method: commit +_MTL_INLINE void MTL::ResidencySet::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} diff --git a/vendor/metal-cpp/Metal/MTLResource.hpp b/vendor/metal-cpp/Metal/MTLResource.hpp index b39caa13..0548bc19 100644 --- a/vendor/metal-cpp/Metal/MTLResource.hpp +++ b/vendor/metal-cpp/Metal/MTLResource.hpp @@ -2,7 +2,7 @@ // // Metal/MTLResource.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,12 +20,15 @@ #pragma once +#include + #include "MTLDefines.hpp" #include "MTLHeaderBridge.hpp" #include "MTLPrivate.hpp" #include +#include "MTLAllocation.hpp" #include "MTLResource.hpp" namespace MTL @@ -69,7 +72,7 @@ _MTL_OPTIONS(NS::UInteger, ResourceOptions) { ResourceOptionCPUCacheModeWriteCombined = 1, }; -class Resource : public NS::Referencing +class Resource : public NS::Referencing { public: NS::String* label() const; @@ -96,6 +99,8 @@ class Resource : public NS::Referencing void makeAliasable(); bool isAliasable(); + + kern_return_t setOwner(task_id_token_t task_id_token); }; } @@ -176,3 +181,9 @@ _MTL_INLINE bool MTL::Resource::isAliasable() { return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAliasable)); } + +// method: setOwnerWithIdentity: +_MTL_INLINE kern_return_t MTL::Resource::setOwner(task_id_token_t task_id_token) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setOwnerWithIdentity_), task_id_token); +} diff --git a/vendor/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp b/vendor/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp index cf54f387..ba08d4f0 100644 --- a/vendor/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp +++ b/vendor/metal-cpp/Metal/MTLResourceStateCommandEncoder.hpp @@ -2,7 +2,7 @@ // // Metal/MTLResourceStateCommandEncoder.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
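A usage sketch for the new residency-set API (not in the vendored sources). MTL::Device::newResidencySet() and MTL::CommandBuffer::useResidencySet() are assumed from the newResidencySetWithDescriptor:error: and useResidencySet: selectors registered earlier in this diff; the descriptor and set methods come from MTLResidencySet.hpp above, and because MTL::Resource and MTL::Heap now derive from MTL::Allocation, buffers, textures, and heaps can be added directly:

#include <Metal/Metal.hpp>

// Sketch: group per-frame allocations in a residency set instead of per-encoder useResource: calls.
void bindFrameResidency(MTL::Device* device, MTL::CommandBuffer* cmdBuffer,
                        MTL::Buffer* vertexBuffer, MTL::Heap* sceneHeap)
{
    MTL::ResidencySetDescriptor* desc = MTL::ResidencySetDescriptor::alloc()->init();
    desc->setLabel(MTLSTR("frame residency"));
    desc->setInitialCapacity(64);

    NS::Error* error = nullptr;
    MTL::ResidencySet* residency = device->newResidencySet(desc, &error); // assumed accessor
    residency->addAllocation(vertexBuffer);   // any MTL::Allocation: buffers, textures, heaps
    residency->addAllocation(sceneHeap);
    residency->commit();                      // publish pending additions and removals
    residency->requestResidency();            // ask Metal to make the whole set resident

    cmdBuffer->useResidencySet(residency);    // assumed accessor for the useResidencySet: selector
    desc->release();
}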
diff --git a/vendor/metal-cpp/Metal/MTLResourceStatePass.hpp b/vendor/metal-cpp/Metal/MTLResourceStatePass.hpp index 3c759b1a..86adfc82 100644 --- a/vendor/metal-cpp/Metal/MTLResourceStatePass.hpp +++ b/vendor/metal-cpp/Metal/MTLResourceStatePass.hpp @@ -2,7 +2,7 @@ // // Metal/MTLResourceStatePass.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLSampler.hpp b/vendor/metal-cpp/Metal/MTLSampler.hpp index b086744a..bd1aeb32 100644 --- a/vendor/metal-cpp/Metal/MTLSampler.hpp +++ b/vendor/metal-cpp/Metal/MTLSampler.hpp @@ -2,7 +2,7 @@ // // Metal/MTLSampler.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp b/vendor/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp index dbc37467..88eaff22 100644 --- a/vendor/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp +++ b/vendor/metal-cpp/Metal/MTLStageInputOutputDescriptor.hpp @@ -2,7 +2,7 @@ // // Metal/MTLStageInputOutputDescriptor.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLTexture.hpp b/vendor/metal-cpp/Metal/MTLTexture.hpp index 15898f40..2d0b74c7 100644 --- a/vendor/metal-cpp/Metal/MTLTexture.hpp +++ b/vendor/metal-cpp/Metal/MTLTexture.hpp @@ -2,7 +2,7 @@ // // Metal/MTLTexture.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLTypes.hpp b/vendor/metal-cpp/Metal/MTLTypes.hpp index 09a81840..9c1f89f0 100644 --- a/vendor/metal-cpp/Metal/MTLTypes.hpp +++ b/vendor/metal-cpp/Metal/MTLTypes.hpp @@ -2,7 +2,7 @@ // // Metal/MTLTypes.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLVersion.hpp b/vendor/metal-cpp/Metal/MTLVersion.hpp index f811e555..3f97aedc 100644 --- a/vendor/metal-cpp/Metal/MTLVersion.hpp +++ b/vendor/metal-cpp/Metal/MTLVersion.hpp @@ -2,7 +2,7 @@ // // Metal/MTLVersion.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,8 +22,8 @@ //------------------------------------------------------------------------------------------------------------------------------------------------------------- -#define METALCPP_VERSION_MAJOR 354 -#define METALCPP_VERSION_MINOR 0 +#define METALCPP_VERSION_MAJOR 366 +#define METALCPP_VERSION_MINOR 11 #define METALCPP_VERSION_PATCH 0 #define METALCPP_SUPPORTS_VERSION(major, minor, patch) \ diff --git a/vendor/metal-cpp/Metal/MTLVertexDescriptor.hpp b/vendor/metal-cpp/Metal/MTLVertexDescriptor.hpp index 7e92b551..eca7e155 100644 --- a/vendor/metal-cpp/Metal/MTLVertexDescriptor.hpp +++ b/vendor/metal-cpp/Metal/MTLVertexDescriptor.hpp @@ -2,7 +2,7 @@ // // Metal/MTLVertexDescriptor.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/MTLVisibleFunctionTable.hpp b/vendor/metal-cpp/Metal/MTLVisibleFunctionTable.hpp index 88e8fb1e..9ebf6ea7 100644 --- a/vendor/metal-cpp/Metal/MTLVisibleFunctionTable.hpp +++ b/vendor/metal-cpp/Metal/MTLVisibleFunctionTable.hpp @@ -2,7 +2,7 @@ // // Metal/MTLVisibleFunctionTable.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/Metal/Metal.hpp b/vendor/metal-cpp/Metal/Metal.hpp index 6ea7adf3..ca587c35 100644 --- a/vendor/metal-cpp/Metal/Metal.hpp +++ b/vendor/metal-cpp/Metal/Metal.hpp @@ -2,7 +2,7 @@ // // Metal/Metal.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ #include "MTLAccelerationStructure.hpp" #include "MTLAccelerationStructureCommandEncoder.hpp" #include "MTLAccelerationStructureTypes.hpp" +#include "MTLAllocation.hpp" #include "MTLArgument.hpp" #include "MTLArgumentEncoder.hpp" #include "MTLBinaryArchive.hpp" @@ -62,6 +63,7 @@ #include "MTLIOCompressor.hpp" #include "MTLLibrary.hpp" #include "MTLLinkedFunctions.hpp" +#include "MTLLogState.hpp" #include "MTLParallelRenderCommandEncoder.hpp" #include "MTLPipeline.hpp" #include "MTLPixelFormat.hpp" @@ -70,6 +72,7 @@ #include "MTLRenderCommandEncoder.hpp" #include "MTLRenderPass.hpp" #include "MTLRenderPipeline.hpp" +#include "MTLResidencySet.hpp" #include "MTLResource.hpp" #include "MTLResourceStateCommandEncoder.hpp" #include "MTLResourceStatePass.hpp" diff --git a/vendor/metal-cpp/MetalFX/MTLFXDefines.hpp b/vendor/metal-cpp/MetalFX/MTLFXDefines.hpp index 8b452168..320e0aa8 100644 --- a/vendor/metal-cpp/MetalFX/MTLFXDefines.hpp +++ b/vendor/metal-cpp/MetalFX/MTLFXDefines.hpp @@ -2,7 +2,7 @@ // // MetalFX/MTLFXDefines.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/MetalFX/MTLFXPrivate.hpp b/vendor/metal-cpp/MetalFX/MTLFXPrivate.hpp index 104678ad..c0ce716d 100644 --- a/vendor/metal-cpp/MetalFX/MTLFXPrivate.hpp +++ b/vendor/metal-cpp/MetalFX/MTLFXPrivate.hpp @@ -2,7 +2,7 @@ // // MetalFX/MTLFXPrivate.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. 
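Since the version macros jump from 354.0.0 to 366.11.0 here, downstream code that must still build against an older single-header drop can gate the new APIs at compile time. A small sketch, assuming METALCPP_SUPPORTS_VERSION remains a preprocessor-evaluable comparison as in earlier drops (runtime availability still depends on the installed OS):

#include <Metal/Metal.hpp>

// Sketch: compile-time feature test against the bumped metal-cpp version.
bool headersExposeResidencySets()
{
#if METALCPP_SUPPORTS_VERSION(366, 11, 0)
    return true;   // this vendored copy carries the macOS 15 / iOS 18 additions
#else
    return false;  // older drop: keep using useResource:/useHeaps: for residency
#endif
}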
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ namespace MTLFX::Private } } // MTLFX::Private -#if defined( __MAC_13_0 ) || defined( __MAC_14_0 ) || defined( __IPHONE_16_0 ) || defined( __IPHONE_17_0 ) || defined( __TVOS_16_0 ) || defined( __TVOS_17_0 ) +#if defined(__MAC_15_0) || defined(__IPHONE_18_0) || defined(__TVOS_18_0) #define _MTLFX_PRIVATE_DEF_STR( type, symbol ) \ _MTLFX_EXTERN type const MTLFX##symbol _MTLFX_PRIVATE_IMPORT; \ @@ -97,7 +97,7 @@ namespace MTLFX::Private #define _MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) _MTLFX_PRIVATE_DEF_CONST( type, symbol ) -#endif // defined( __MAC_13_0 ) || defined( __MAC_14_0 ) || defined( __IPHONE_16_0 ) || defined( __IPHONE_17_0 ) || defined( __TVOS_16_0 ) || defined( __TVOS_17_0 ) +#endif // defined(__MAC_15_0) || defined(__IPHONE_18_0) || defined(__TVOS_18_0) #else @@ -184,6 +184,8 @@ namespace MTLFX "isDepthReversed" ); _MTLFX_PRIVATE_DEF_SEL( isInputContentPropertiesEnabled, "isInputContentPropertiesEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isReactiveMaskTextureEnabled, + "isReactiveMaskTextureEnabled" ); _MTLFX_PRIVATE_DEF_SEL( jitterOffsetX, "jitterOffsetX" ); _MTLFX_PRIVATE_DEF_SEL( jitterOffsetY, @@ -214,8 +216,16 @@ namespace MTLFX "outputWidth" ); _MTLFX_PRIVATE_DEF_SEL( preExposure, "preExposure" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveMaskTextureFormat, + "reactiveMaskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveTextureUsage, + "reactiveTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveMaskTexture, + "reactiveMaskTexture" ); _MTLFX_PRIVATE_DEF_SEL( reset, "reset" ); + _MTLFX_PRIVATE_DEF_SEL( requiresSynchronousInitialization, + "requiresSynchronousInitialization" ); _MTLFX_PRIVATE_DEF_SEL( setAutoExposureEnabled_, "setAutoExposureEnabled:" ); _MTLFX_PRIVATE_DEF_SEL( setColorProcessingMode_, @@ -270,6 +280,14 @@ namespace MTLFX "setOutputWidth:" ); _MTLFX_PRIVATE_DEF_SEL( setPreExposure_, "setPreExposure:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTexture_, + "setReactiveMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTextureEnabled_, + "setReactiveMaskTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTextureFormat_, + "setReactiveMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setRequiresSynchronousInitialization_, + "setRequiresSynchronousInitialization:" ); _MTLFX_PRIVATE_DEF_SEL( setReset_, "setReset:" ); _MTLFX_PRIVATE_DEF_SEL( supportedInputContentMaxScaleForDevice_, diff --git a/vendor/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp b/vendor/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp index 841898d4..4a79676c 100644 --- a/vendor/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp +++ b/vendor/metal-cpp/MetalFX/MTLFXSpatialScaler.hpp @@ -2,7 +2,7 @@ // // MetalFX/MTLFXSpatialScaler.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp b/vendor/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp index e4782973..6305e832 100644 --- a/vendor/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp +++ b/vendor/metal-cpp/MetalFX/MTLFXTemporalScaler.hpp @@ -2,7 +2,7 @@ // // MetalFX/MTLFXTemporalScaler.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -67,6 +67,15 @@ namespace MTLFX bool isInputContentPropertiesEnabled() const; void setInputContentPropertiesEnabled( bool enabled ); + bool requiresSynchronousInitialization() const; + void setRequiresSynchronousInitialization(bool requiresSynchronousInitialization); + + bool isReactiveMaskTextureEnabled() const; + void setReactiveMaskTextureEnabled( bool enabled ); + + MTL::PixelFormat reactiveMaskTextureFormat() const; + void setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ); + float inputContentMinScale() const; void setInputContentMinScale( float scale ); @@ -125,6 +134,11 @@ namespace MTLFX float motionVectorScaleY() const; void setMotionVectorScaleY( float scale ); + MTL::Texture* reactiveMaskTexture() const; + void setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ); + + MTL::TextureUsage reactiveTextureUsage() const; + bool reset() const; void setReset( bool reset ); @@ -303,6 +317,49 @@ _MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentPropertiesEna Object::sendMessage< void >( this, _MTL_PRIVATE_SEL( setInputContentPropertiesEnabled_ ), enabled ); } + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::requiresSynchronousInitialization() const +{ + return Object::sendMessage< bool >( this, _MTL_PRIVATE_SEL( requiresSynchronousInitialization ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setRequiresSynchronousInitialization(bool requiresSynchronousInitialization) +{ + Object::sendMessage< void >( this, _MTL_PRIVATE_SEL( setRequiresSynchronousInitialization_ ), requiresSynchronousInitialization ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isReactiveMaskTextureEnabled() const +{ + return Object::sendMessage< bool >( this, _MTL_PRIVATE_SEL( isReactiveMaskTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setReactiveMaskTextureEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTL_PRIVATE_SEL( setReactiveMaskTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::reactiveMaskTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTL_PRIVATE_SEL( reactiveMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ) +{ + Object::sendMessage< void >( this, _MTL_PRIVATE_SEL( setReactiveMaskTextureFormat_ ), pixelFormat ); +} + 
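For illustration only (not part of the header): the new reactive-mask hooks are opted into on the temporal scaler descriptor and then fed per frame on the scaler. The sketch assumes an existing MTL::Device* and an MTL::Texture* reactive mask created elsewhere with a usage compatible with reactiveTextureUsage():

#include <Metal/Metal.hpp>
#include <MetalFX/MetalFX.hpp>

// Sketch: opt in to the reactive-mask input when building a MetalFX temporal scaler.
MTLFX::TemporalScaler* makeScalerWithReactiveMask(MTL::Device* device, MTL::Texture* reactiveMask)
{
    MTLFX::TemporalScalerDescriptor* desc = MTLFX::TemporalScalerDescriptor::alloc()->init();
    desc->setInputWidth(1280);
    desc->setInputHeight(720);
    desc->setOutputWidth(2560);
    desc->setOutputHeight(1440);
    desc->setColorTextureFormat(MTL::PixelFormatRGBA16Float);
    desc->setDepthTextureFormat(MTL::PixelFormatDepth32Float);
    desc->setMotionTextureFormat(MTL::PixelFormatRG16Float);
    desc->setOutputTextureFormat(MTL::PixelFormatRGBA16Float);
    desc->setReactiveMaskTextureEnabled(true);                    // new in this update
    desc->setReactiveMaskTextureFormat(MTL::PixelFormatR8Unorm);  // new in this update
    desc->setRequiresSynchronousInitialization(false);            // new in this update

    MTLFX::TemporalScaler* scaler = desc->newTemporalScaler(device);
    if (scaler)
        scaler->setReactiveMaskTexture(reactiveMask);             // per-frame input alongside color/depth/motion
    desc->release();
    return scaler;
}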
//------------------------------------------------------------------------------------------------------------------------------------------------------------- _MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::inputContentMinScale() const @@ -575,6 +632,27 @@ _MTLFX_INLINE void MTLFX::TemporalScaler::setMotionVectorScaleY( float scale ) //------------------------------------------------------------------------------------------------------------------------------------------------------------- +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScaler::reactiveMaskTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTL_PRIVATE_SEL( reactiveMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScaler::setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ) +{ + Object::sendMessage< void >( this, _MTL_PRIVATE_SEL( setReactiveMaskTexture_ ), reactiveMaskTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScaler::reactiveTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTL_PRIVATE_SEL( reactiveTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + _MTLFX_INLINE bool MTLFX::TemporalScaler::reset() const { return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( reset ) ); diff --git a/vendor/metal-cpp/MetalFX/MetalFX.hpp b/vendor/metal-cpp/MetalFX/MetalFX.hpp index 40405cd3..10f6600d 100644 --- a/vendor/metal-cpp/MetalFX/MetalFX.hpp +++ b/vendor/metal-cpp/MetalFX/MetalFX.hpp @@ -2,7 +2,7 @@ // // MetalFX/MetalFX.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/QuartzCore/CADefines.hpp b/vendor/metal-cpp/QuartzCore/CADefines.hpp index 83f3d8fc..b0641de0 100644 --- a/vendor/metal-cpp/QuartzCore/CADefines.hpp +++ b/vendor/metal-cpp/QuartzCore/CADefines.hpp @@ -2,7 +2,7 @@ // // QuartzCore/CADefines.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/QuartzCore/CAMetalDrawable.hpp b/vendor/metal-cpp/QuartzCore/CAMetalDrawable.hpp index 99a3872a..0057773a 100644 --- a/vendor/metal-cpp/QuartzCore/CAMetalDrawable.hpp +++ b/vendor/metal-cpp/QuartzCore/CAMetalDrawable.hpp @@ -2,7 +2,7 @@ // // QuartzCore/CAMetalDrawable.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/QuartzCore/CAMetalLayer.hpp b/vendor/metal-cpp/QuartzCore/CAMetalLayer.hpp index 904e2188..2edc683d 100644 --- a/vendor/metal-cpp/QuartzCore/CAMetalLayer.hpp +++ b/vendor/metal-cpp/QuartzCore/CAMetalLayer.hpp @@ -2,7 +2,7 @@ // // QuartzCore/CAMetalDrawable.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/QuartzCore/CAPrivate.hpp b/vendor/metal-cpp/QuartzCore/CAPrivate.hpp index 701ee082..113e7216 100644 --- a/vendor/metal-cpp/QuartzCore/CAPrivate.hpp +++ b/vendor/metal-cpp/QuartzCore/CAPrivate.hpp @@ -2,7 +2,7 @@ // // QuartzCore/CAPrivate.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/QuartzCore/QuartzCore.hpp b/vendor/metal-cpp/QuartzCore/QuartzCore.hpp index beb57b61..681003ad 100644 --- a/vendor/metal-cpp/QuartzCore/QuartzCore.hpp +++ b/vendor/metal-cpp/QuartzCore/QuartzCore.hpp @@ -2,7 +2,7 @@ // // QuartzCore/QuartzCore.hpp // -// Copyright 2020-2023 Apple Inc. +// Copyright 2020-2024 Apple Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/vendor/metal-cpp/README.md b/vendor/metal-cpp/README.md index 52e1938c..963dad43 100644 --- a/vendor/metal-cpp/README.md +++ b/vendor/metal-cpp/README.md @@ -18,6 +18,7 @@ | Version | Changes | |-|-| +| macOS 15, iOS 18 | Add all the Metal APIs in macOS 15 and iOS 18. | | macOS 14, iOS 17 | Add support for the **MetalFX** framework.
Add all the APIs in macOS 14 and iOS 17. | | macOS 13.3, iOS 16.4 | Add all the APIs in macOS 13.3 and iOS 16.4. | | macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.<br />New optional `NS::SharedPtr` type to assist with memory management.<br />New convenience function to create a `CA::MetalLayer`.<br />New `MTLSTR(str)` macro allows faster string creation from literals.<br />Fix a problem with the signature of functions that take an array of pointers as input.<br />Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.| diff --git a/vendor/metal-cpp/SingleHeader/MakeSingleHeader.py b/vendor/metal-cpp/SingleHeader/MakeSingleHeader.py index 520cb889..c8d3715f 100755 --- a/vendor/metal-cpp/SingleHeader/MakeSingleHeader.py +++ b/vendor/metal-cpp/SingleHeader/MakeSingleHeader.py @@ -4,7 +4,7 @@ # # SingleHeader/MakeSingleHeader.py # -# Copyright 2020-2023 Apple Inc. +# Copyright 2020-2024 Apple Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,7 +36,7 @@ class HeaderPrefix( object ): '//\n' '// {meta_data}\n' '//\n' - '// Copyright 2020-2023 Apple Inc.\n' + '// Copyright 2020-2024 Apple Inc.\n' '//\n' '// Licensed under the Apache License, Version 2.0 (the "License");\n' '// you may not use this file except in compliance with the License.\n'