Skip to content

Commit

Permalink
MLXArray (Data, Dtype) constructor , DType size. (#158)
Browse files Browse the repository at this point in the history
* Array init, Dtype codable and tests

Signed-off-by: Valentin Roussellet <[email protected]>
  • Loading branch information
louen authored Dec 4, 2024
1 parent bafcb33 commit 7f02cd8
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
19 changes: 18 additions & 1 deletion Source/MLX/DType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Numerics
/// - ``MLXArray/asType(_:stream:)-6d44y``
/// - ``MLXArray/asType(_:stream:)-4eqoc``
/// - ``MLXArray/init(_:dtype:)``
public enum DType: Hashable, Sendable {
public enum DType: Hashable, Sendable, CaseIterable {
case bool
case uint8
case uint16
Expand Down Expand Up @@ -105,6 +105,23 @@ public enum DType: Hashable, Sendable {
default: false
}
}

public var size: Int {
mlx_dtype_size(cmlxDtype)
}
}

extension DType: Encodable {
public func encode(to encoder: any Encoder) throws {
try self.cmlxDtype.rawValue.encode(to: encoder)
}
}

extension DType: Decodable {
public init(from decoder: any Decoder) throws {
let rawValue = try UInt32(from: decoder)
self.init(mlx_dtype(rawValue: rawValue))
}
}

/// Protocol for types that can provide a ``DType``
Expand Down
27 changes: 27 additions & 0 deletions Source/MLX/MLXArray+Init.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ private func shapePrecondition(shape: [Int]?, count: Int) {
}
}

private func shapePrecondition(shape: [Int]?, byteCount: Int, type: DType) {
if let shape {
let total = shape.reduce(1, *) * type.size
precondition(total == byteCount, "shape \(shape) total \(total)B != \(byteCount)B (actual)")
}
}

extension MLXArray {

/// Initalizer allowing creation of scalar (0-dimension) `MLXArray` from an `Int32`.
Expand Down Expand Up @@ -436,6 +443,25 @@ extension MLXArray {
})
}

/// Initalizer allowing creation of `MLXArray` from a `Data` buffer values with
/// an optional shape and an explicit DType.
/// ### See Also
/// - <doc:initialization>
public convenience init(_ data: Data, _ shape: [Int]? = nil, dtype: DType) {
self.init(
data.withUnsafeBytes { ptr in
shapePrecondition(shape: shape, byteCount: data.count, type: dtype)
precondition(data.count % dtype.size == 0)
let shape = shape ?? [data.count / dtype.size]
return mlx_array_new_data(
ptr.baseAddress!, shape.asInt32, shape.count.int32, dtype.cmlxDtype)
})
}

public convenience init(data: MLXArrayData) {
self.init(data.data, data.shape, dtype: data.dType)
}

/// Create a ``DType/complex64`` scalar.
/// - Parameters:
/// - real: real part
Expand All @@ -448,6 +474,7 @@ extension MLXArray {
public convenience init(_ value: Complex<Float>) {
self.init(real: value.real, imaginary: value.imaginary)
}

}

// MARK: - Expressible by literals
Expand Down
26 changes: 25 additions & 1 deletion Tests/MLXTests/MLXArray+InitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,30 @@ class MLXArrayInitTests: XCTestCase {
setDefaultDevice()
}

// MARK: - Creation
// MARK: - Dtype
func testDtypeSize() {
// Checking that the size of the dtype matches the array's itemsize
for dtype in DType.allCases {
XCTAssertEqual(MLXArray(Data(), dtype: dtype).itemSize, dtype.size)
}
}

func testDtypeCodable() {
let encoder = JSONEncoder()
let decoder = JSONDecoder()
// Test encoding / decoding round trip
for dtype in DType.allCases {
do {
let json: Data = try encoder.encode(dtype)
let decoded = try decoder.decode(DType.self, from: json)
XCTAssertEqual(decoded, dtype)
} catch {
XCTFail("Encoding / decoding failed")
}
}
}

// MARK: - Creation
func testInt() {
// array creation with Int -- we want it to produce .int32
let a1 = MLXArray(500)
Expand Down Expand Up @@ -87,8 +109,10 @@ class MLXArrayInitTests: XCTestCase {
func testData() {
let data = Data([1, 2, 3, 4])
let a = MLXArray(data, [2, 2], type: UInt8.self)
let b = MLXArray(data, [2, 2], dtype: DType.uint8)
let expected = MLXArray(UInt8(1) ... 4, [2, 2])
assertEqual(a, expected)
assertEqual(b, expected)
}

func testUnsafeRawPointer() {
Expand Down
7 changes: 7 additions & 0 deletions Tests/MLXTests/MLXArrayTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ class MLXArrayTests: XCTestCase {
}
}

func testAsDataRoundTrip() {
let a = MLXArray(0 ..< 16, [4, 4])
let arrayData = a.asData(access: .copy)
let result = MLXArray(arrayData.data, arrayData.shape, dtype: arrayData.dType)
assertEqual(a, result)
}

func testAsDataNonContiguous() {
// buffer with holes (last dimension has stride of 2 and
// thus larger storage than it physically needs)
Expand Down

0 comments on commit 7f02cd8

Please sign in to comment.