From 7f02cd874e380027540f305d9a8b36a40ef712b8 Mon Sep 17 00:00:00 2001 From: Valentin Roussellet Date: Wed, 4 Dec 2024 14:28:54 -0800 Subject: [PATCH] MLXArray (Data, Dtype) constructor , DType size. (#158) * Array init, Dtype codable and tests Signed-off-by: Valentin Roussellet --- Source/MLX/DType.swift | 19 ++++++++++++++++- Source/MLX/MLXArray+Init.swift | 27 +++++++++++++++++++++++++ Tests/MLXTests/MLXArray+InitTests.swift | 26 +++++++++++++++++++++++- Tests/MLXTests/MLXArrayTests.swift | 7 +++++++ 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/Source/MLX/DType.swift b/Source/MLX/DType.swift index 013b94e0..2f0e3569 100644 --- a/Source/MLX/DType.swift +++ b/Source/MLX/DType.swift @@ -24,7 +24,7 @@ import Numerics /// - ``MLXArray/asType(_:stream:)-6d44y`` /// - ``MLXArray/asType(_:stream:)-4eqoc`` /// - ``MLXArray/init(_:dtype:)`` -public enum DType: Hashable, Sendable { +public enum DType: Hashable, Sendable, CaseIterable { case bool case uint8 case uint16 @@ -105,6 +105,23 @@ public enum DType: Hashable, Sendable { default: false } } + + public var size: Int { + mlx_dtype_size(cmlxDtype) + } +} + +extension DType: Encodable { + public func encode(to encoder: any Encoder) throws { + try self.cmlxDtype.rawValue.encode(to: encoder) + } +} + +extension DType: Decodable { + public init(from decoder: any Decoder) throws { + let rawValue = try UInt32(from: decoder) + self.init(mlx_dtype(rawValue: rawValue)) + } } /// Protocol for types that can provide a ``DType`` diff --git a/Source/MLX/MLXArray+Init.swift b/Source/MLX/MLXArray+Init.swift index 260598df..3bab563c 100644 --- a/Source/MLX/MLXArray+Init.swift +++ b/Source/MLX/MLXArray+Init.swift @@ -11,6 +11,13 @@ private func shapePrecondition(shape: [Int]?, count: Int) { } } +private func shapePrecondition(shape: [Int]?, byteCount: Int, type: DType) { + if let shape { + let total = shape.reduce(1, *) * type.size + precondition(total == byteCount, "shape \(shape) total \(total)B != \(byteCount)B (actual)") + } +} + extension MLXArray { /// Initalizer allowing creation of scalar (0-dimension) `MLXArray` from an `Int32`. @@ -436,6 +443,25 @@ extension MLXArray { }) } + /// Initalizer allowing creation of `MLXArray` from a `Data` buffer values with + /// an optional shape and an explicit DType. + /// ### See Also + /// - + public convenience init(_ data: Data, _ shape: [Int]? = nil, dtype: DType) { + self.init( + data.withUnsafeBytes { ptr in + shapePrecondition(shape: shape, byteCount: data.count, type: dtype) + precondition(data.count % dtype.size == 0) + let shape = shape ?? [data.count / dtype.size] + return mlx_array_new_data( + ptr.baseAddress!, shape.asInt32, shape.count.int32, dtype.cmlxDtype) + }) + } + + public convenience init(data: MLXArrayData) { + self.init(data.data, data.shape, dtype: data.dType) + } + /// Create a ``DType/complex64`` scalar. /// - Parameters: /// - real: real part @@ -448,6 +474,7 @@ extension MLXArray { public convenience init(_ value: Complex) { self.init(real: value.real, imaginary: value.imaginary) } + } // MARK: - Expressible by literals diff --git a/Tests/MLXTests/MLXArray+InitTests.swift b/Tests/MLXTests/MLXArray+InitTests.swift index 99130be1..5347eacd 100644 --- a/Tests/MLXTests/MLXArray+InitTests.swift +++ b/Tests/MLXTests/MLXArray+InitTests.swift @@ -12,8 +12,30 @@ class MLXArrayInitTests: XCTestCase { setDefaultDevice() } - // MARK: - Creation + // MARK: - Dtype + func testDtypeSize() { + // Checking that the size of the dtype matches the array's itemsize + for dtype in DType.allCases { + XCTAssertEqual(MLXArray(Data(), dtype: dtype).itemSize, dtype.size) + } + } + func testDtypeCodable() { + let encoder = JSONEncoder() + let decoder = JSONDecoder() + // Test encoding / decoding round trip + for dtype in DType.allCases { + do { + let json: Data = try encoder.encode(dtype) + let decoded = try decoder.decode(DType.self, from: json) + XCTAssertEqual(decoded, dtype) + } catch { + XCTFail("Encoding / decoding failed") + } + } + } + + // MARK: - Creation func testInt() { // array creation with Int -- we want it to produce .int32 let a1 = MLXArray(500) @@ -87,8 +109,10 @@ class MLXArrayInitTests: XCTestCase { func testData() { let data = Data([1, 2, 3, 4]) let a = MLXArray(data, [2, 2], type: UInt8.self) + let b = MLXArray(data, [2, 2], dtype: DType.uint8) let expected = MLXArray(UInt8(1) ... 4, [2, 2]) assertEqual(a, expected) + assertEqual(b, expected) } func testUnsafeRawPointer() { diff --git a/Tests/MLXTests/MLXArrayTests.swift b/Tests/MLXTests/MLXArrayTests.swift index dc1fa8b5..715a7759 100644 --- a/Tests/MLXTests/MLXArrayTests.swift +++ b/Tests/MLXTests/MLXArrayTests.swift @@ -129,6 +129,13 @@ class MLXArrayTests: XCTestCase { } } + func testAsDataRoundTrip() { + let a = MLXArray(0 ..< 16, [4, 4]) + let arrayData = a.asData(access: .copy) + let result = MLXArray(arrayData.data, arrayData.shape, dtype: arrayData.dType) + assertEqual(a, result) + } + func testAsDataNonContiguous() { // buffer with holes (last dimension has stride of 2 and // thus larger storage than it physically needs)