Skip to content

Commit

Permalink
[WIP] adopt new mlx-c api (#150)
Browse files Browse the repository at this point in the history
- adopt new mlx-c api - ml-explore/mlx-c#38
- mlx v0.21.0
  • Loading branch information
davidkoski authored Dec 4, 2024
1 parent d5fefbf commit bafcb33
Show file tree
Hide file tree
Showing 190 changed files with 10,184 additions and 5,126 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "v0.0.10")
GIT_TAG "v0.1.0")
FetchContent_MakeAvailable(mlx-c)

# swift-numerics
Expand Down
9 changes: 0 additions & 9 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,6 @@ let package = Package(
sources: ["Tutorial.swift"]
),

// ------
// Internal Tools

.executableTarget(
name: "GenerateGrad",
path: "Source/Tools",
sources: ["GenerateGrad.swift"]
),

],
cxxLanguageStandard: .gnucxx17
)
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
Submodule mlx updated 208 files
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx-c
Submodule mlx-c updated 89 files
+17 −13 CMakeLists.txt
+3 −4 README.md
+1 −1 docs/src/conf.py
+6 −10 docs/src/index.rst
+0 −5 docs/src/ioutils.rst
+0 −5 docs/src/object.rst
+70 −35 docs/src/overview.rst
+0 −5 docs/src/tuple.rst
+0 −5 docs/src/variant.rst
+7 −0 examples/CMakeLists.txt
+110 −0 examples/example-closure.c
+79 −77 examples/example-grad.c
+62 −0 examples/example-metal-kernel.c
+19 −21 examples/example-safe-tensors.c
+21 −29 examples/example.c
+472 −141 mlx/c/array.cpp
+107 −52 mlx/c/array.h
+607 −397 mlx/c/closure.cpp
+102 −88 mlx/c/closure.h
+63 −36 mlx/c/compile.cpp
+14 −5 mlx/c/compile.h
+64 −27 mlx/c/device.cpp
+25 −4 mlx/c/device.h
+98 −53 mlx/c/distributed.cpp
+28 −20 mlx/c/distributed.h
+33 −13 mlx/c/distributed_group.cpp
+3 −1 mlx/c/distributed_group.h
+5 −1 mlx/c/error.cpp
+1 −1 mlx/c/error.h
+404 −92 mlx/c/fast.cpp
+106 −36 mlx/c/fast.h
+215 −100 mlx/c/fft.cpp
+74 −45 mlx/c/fft.h
+102 −43 mlx/c/io.cpp
+23 −14 mlx/c/io.h
+0 −29 mlx/c/ioutils.cpp
+0 −43 mlx/c/ioutils.h
+209 −57 mlx/c/linalg.cpp
+64 −25 mlx/c/linalg.h
+179 −203 mlx/c/map.cpp
+62 −115 mlx/c/map.h
+107 −34 mlx/c/metal.cpp
+22 −12 mlx/c/metal.h
+0 −4 mlx/c/mlx.h
+0 −19 mlx/c/object.cpp
+0 −37 mlx/c/object.h
+3,131 −1,019 mlx/c/ops.cpp
+838 −406 mlx/c/ops.h
+46 −7 mlx/c/private/array.h
+392 −84 mlx/c/private/closure.h
+46 −7 mlx/c/private/device.h
+54 −8 mlx/c/private/distributed_group.h
+74 −0 mlx/c/private/enums.h
+82 −22 mlx/c/private/io.h
+204 −69 mlx/c/private/map.h
+13 −0 mlx/c/private/mlx.h
+0 −17 mlx/c/private/object.h
+47 −7 mlx/c/private/stream.h
+48 −9 mlx/c/private/string.h
+0 −84 mlx/c/private/tuple.h
+0 −200 mlx/c/private/utils.h
+0 −37 mlx/c/private/variant.h
+197 −137 mlx/c/private/vector.h
+350 −175 mlx/c/random.cpp
+103 −68 mlx/c/random.h
+90 −25 mlx/c/stream.cpp
+25 −10 mlx/c/stream.h
+37 −10 mlx/c/string.cpp
+19 −2 mlx/c/string.h
+122 −49 mlx/c/transforms.cpp
+28 −18 mlx/c/transforms.h
+40 −46 mlx/c/transforms_impl.cpp
+11 −15 mlx/c/transforms_impl.h
+0 −153 mlx/c/tuple.cpp
+0 −111 mlx/c/tuple.h
+0 −93 mlx/c/variant.cpp
+0 −66 mlx/c/variant.h
+424 −267 mlx/c/vector.cpp
+67 −110 mlx/c/vector.h
+61 −308 python/c.py
+200 −135 python/closure_generator.py
+46 −31 python/generator.py
+212 −165 python/map_generator.py
+374 −0 python/mlxhooks.py
+490 −20 python/mlxtypes.py
+0 −246 python/tuple_generator.py
+78 −0 python/type_private_generator.py
+0 −241 python/variant_generator.py
+171 −105 python/vector_generator.py
37 changes: 20 additions & 17 deletions Source/Cmlx/mlx-generated/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
Expand All @@ -85,12 +85,12 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
Expand All @@ -99,13 +99,17 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
Expand All @@ -116,13 +120,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride;
Expand Down
37 changes: 20 additions & 17 deletions Source/Cmlx/mlx-generated/binary_two.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ template <typename T, typename U, typename Op>
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto a_idx = elem_to_loc_1<size_t, uint>(index, a_stride);
auto b_idx = elem_to_loc_1<size_t, uint>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
Expand All @@ -110,14 +110,14 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
auto a_idx = elem_to_loc_2<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
template <typename T, typename U, typename Op, typename IdxT = size_t>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
Expand All @@ -127,15 +127,19 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto a_idx = elem_to_loc_3<size_t, IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3<size_t, IdxT>(index, b_strides);
IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int N = 1>
template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = size_t>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
Expand All @@ -147,13 +151,12 @@ template <typename T, typename U, typename Op, int N = 1>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(
auto idx = elem_to_loc_2_nd<size_t, IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1];
size_t out_idx =
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
auto a_xstride = a_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
IdxT a_xstride = a_strides[ndim - 1];
IdxT b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
Expand Down
14 changes: 14 additions & 0 deletions Source/Cmlx/mlx-generated/compiled_preamble.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,13 @@ struct Floor {
}
};
// Unary functor that forwards its argument to std::imag.
struct Imag {
  // Returns std::imag(x), implicitly converted back to the argument type T.
  template <typename T>
  T operator()(T x) {
    T imag_part = std::imag(x);
    return imag_part;
  }
};
struct Log {
template <typename T>
T operator()(T x) {
Expand Down Expand Up @@ -635,6 +642,13 @@ struct Negative {
}
};
// Unary functor that forwards its argument to std::real.
struct Real {
  // Returns std::real(x), implicitly converted back to the argument type T.
  template <typename T>
  T operator()(T x) {
    T real_part = std::real(x);
    return real_part;
  }
};
struct Round {
template <typename T>
T operator()(T x) {
Expand Down
Loading

0 comments on commit bafcb33

Please sign in to comment.