Skip to content

Commit

Permalink
fix to_array behavior with bool -- was not promoting types correctly (#…
Browse files Browse the repository at this point in the history
…127)

* fix to_array behavior with bool -- was not promoting types correctly
* treat Int as Int32 to match python
* add specialized handling for item() to avoid conversions
  • Loading branch information
davidkoski authored Aug 27, 2024
1 parent f6b9bdc commit 86ad75a
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 11 deletions.
64 changes: 56 additions & 8 deletions Source/MLX/DType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,32 +131,86 @@ extension Bool: HasDType {

extension Int: HasDType {
static public var dtype: DType { .int64 }

public func asMLXArray(dtype: DType?) -> MLXArray {
// callers can use Int64() to get explicit .int64 behavior
let dtype = dtype ?? .int32
return MLXArray(self, dtype: dtype == .bool ? .int32 : dtype)
}
}

extension Int8: HasDType {
static public var dtype: DType { .int8 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension Int16: HasDType {
static public var dtype: DType { .int16 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension Int32: HasDType {
static public var dtype: DType { .int32 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension Int64: HasDType {
static public var dtype: DType { .int64 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}

extension UInt8: HasDType {
static public var dtype: DType { .uint8 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension UInt16: HasDType {
static public var dtype: DType { .uint16 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension UInt32: HasDType {
static public var dtype: DType { .uint32 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension UInt64: HasDType {
static public var dtype: DType { .uint64 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}
extension UInt: HasDType {
static public var dtype: DType { .uint64 }

public func asMLXArray(dtype: DType?) -> MLXArray {
let dtype = dtype ?? Self.dtype
return MLXArray(self, dtype: dtype == .bool ? Self.dtype : dtype)
}
}

#if !arch(x86_64)
Expand Down Expand Up @@ -201,16 +255,10 @@ public protocol ScalarOrArray {
func asMLXArray(dtype: DType?) -> MLXArray
}

extension Int: ScalarOrArray {
public func asMLXArray(dtype: DType?) -> MLXArray {
// callers can use Int64() to get explicit .int64 behavior
MLXArray(Int32(self), dtype: dtype ?? .int32)
}
}

extension Double: ScalarOrArray {
public func asMLXArray(dtype: DType?) -> MLXArray {
MLXArray(Float(self), dtype: dtype ?? .float32)
let dtype = dtype ?? .float32
return MLXArray(Float(self), dtype: dtype.isFloatingPoint ? dtype : .float32)
}
}

Expand Down
100 changes: 97 additions & 3 deletions Source/MLX/MLXArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,49 @@ public final class MLXArray {
item(T.self)
}

/// specialized conversion between integer types -- see ``item(_:)``
private func itemInt() -> Int {
switch self.dtype {
case .bool: mlx_array_item_bool(self.ctx) ? 1 : 0
case .uint8: Int(mlx_array_item_uint8(self.ctx))
case .uint16: Int(mlx_array_item_uint16(self.ctx))
case .uint32: Int(mlx_array_item_uint32(self.ctx))
case .uint64: Int(mlx_array_item_uint64(self.ctx))
case .int8: Int(mlx_array_item_int8(self.ctx))
case .int16: Int(mlx_array_item_int16(self.ctx))
case .int32: Int(mlx_array_item_int32(self.ctx))
case .int64: Int(mlx_array_item_int64(self.ctx))
default: fatalError("itemInt expected an integer dtype: \(self.dtype)")
}
}

/// specialized conversion between integer types -- see ``item(_:)``
private func itemUInt() -> UInt {
switch self.dtype {
case .bool: mlx_array_item_bool(self.ctx) ? 1 : 0
case .uint8: UInt(mlx_array_item_uint8(self.ctx))
case .uint16: UInt(mlx_array_item_uint16(self.ctx))
case .uint32: UInt(mlx_array_item_uint32(self.ctx))
case .uint64: UInt(mlx_array_item_uint64(self.ctx))
case .int8: UInt(mlx_array_item_int8(self.ctx))
case .int16: UInt(mlx_array_item_int16(self.ctx))
case .int32: UInt(mlx_array_item_int32(self.ctx))
case .int64: UInt(mlx_array_item_int64(self.ctx))
default: fatalError("itemUInt expected an integer dtype: \(self.dtype)")
}
}

/// specialized conversion between float types -- see ``item(_:)``
private func itemFloat() -> Float {
switch self.dtype {
#if !arch(x86_64)
case .float16: Float(mlx_array_item_float16(self.ctx))
#endif
case .float32: Float(mlx_array_item_float32(self.ctx))
default: fatalError("itemFloat expected a floating point dtype: \(self.dtype)")
}
}

/// Return the scalar value of the array.
///
/// It is a contract violation to call this on an array with more than one element.
Expand All @@ -163,12 +206,62 @@ public final class MLXArray {
public func item<T: HasDType>(_ type: T.Type) -> T {
precondition(self.size == 1)

// special cases for reading integers and floats from (roughly)
// same typed arrays -- this avoids doing a conversion which
// might end up as an unexpected operation that would mess up
// async evaluation
switch type {
case is Int.Type, is Int8.Type, is Int16.Type, is Int32.Type, is Int64.Type:
if self.dtype.isInteger {
switch type {
case is Int.Type: return Int(itemInt()) as! T
case is Int8.Type: return Int8(itemInt()) as! T
case is Int16.Type: return Int16(itemInt()) as! T
case is Int32.Type: return Int32(itemInt()) as! T
case is Int64.Type: return Int64(itemInt()) as! T
default:
// fall through to default handling
break
}
}
case is UInt8.Type, is UInt16.Type, is UInt32.Type, is UInt64.Type, is UInt.Type:
if self.dtype.isInteger {
switch type {
case is UInt8.Type: return UInt8(itemUInt()) as! T
case is UInt16.Type: return UInt16(itemUInt()) as! T
case is UInt32.Type: return UInt32(itemUInt()) as! T
case is UInt64.Type: return UInt64(itemUInt()) as! T
case is UInt.Type: return UInt(itemUInt()) as! T
default:
// fall through to default handling
break
}
}
#if !arch(x86_64)
case is Float.Type, is Float32.Type, is Float16.Type:
switch self.dtype {
case .float16, .float32:
switch type {
case is Float.Type: return Float(itemFloat()) as! T
case is Float32.Type: return Float32(itemFloat()) as! T
case is Float16.Type: return Float16(itemFloat()) as! T
default:
// fall through to default handling
break
}
default:
break
}
#endif
default:
break
}

// default handling -- convert the type if needed
if type.dtype != self.dtype {
return self.asType(type).item(type)
}

self.eval()

switch type {
case is Bool.Type: return mlx_array_item_bool(self.ctx) as! T
case is UInt8.Type: return mlx_array_item_uint8(self.ctx) as! T
Expand All @@ -187,7 +280,8 @@ public final class MLXArray {
case is Float.Type: return mlx_array_item_float32(self.ctx) as! T
case is Complex<Float32>.Type:
// mlx_array_item_complex64() isn't visible in swift so read the array
// contents
// contents. call self.eval() as this doesn't end up in item()
self.eval()
let ptr = UnsafePointer<Complex<Float32>>(mlx_array_data_complex64(ctx))!
return ptr.pointee as! T
default:
Expand Down
32 changes: 32 additions & 0 deletions Tests/MLXTests/OpsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,36 @@ class OpsTests: XCTestCase {
assertEqual(b, expected)
}

func testConvertScalarInt() {
let a = MLXArray(0 ..< 10)
let b = a .< (a + 1)
let c = b * 25
XCTAssertEqual(b.dtype, .bool)
XCTAssertEqual(c.dtype, .int32)
}

func testConvertScalarFloat16() {
let a = MLXArray(0 ..< 10)
let b = a .< (a + 1)
let c = b * Float16(2.5)
XCTAssertEqual(b.dtype, .bool)
XCTAssertEqual(c.dtype, .float16)
}

func testConvertScalarFloat() {
let a = MLXArray(0 ..< 10)
let b = a .< (a + 1)
let c = b * Float(2.5)
XCTAssertEqual(b.dtype, .bool)
XCTAssertEqual(c.dtype, .float32)
}

func testConvertScalarDouble() {
let a = MLXArray(0 ..< 10)
let b = a .< (a + 1)
let c = b * 2.5
XCTAssertEqual(b.dtype, .bool)
XCTAssertEqual(c.dtype, .float32)
}

}

0 comments on commit 86ad75a

Please sign in to comment.