From 36d63a1fc386a551df14f5b67df1756dc17d2ebc Mon Sep 17 00:00:00 2001 From: David Koski <46639364+davidkoski@users.noreply.github.com> Date: Tue, 4 Jun 2024 21:24:44 -0700 Subject: [PATCH] handle non-contiguous backing when reading out MLXArray (#96) * handle non-contiguous backing when reading out MLXArray - fixes #83 - mlx::core::array can have non-contiguous backing - handle those cases and simplify the readout --- Source/MLX/MLXArray.swift | 232 ++++++++++++++++++----------- Tests/MLXTests/MLXArrayTests.swift | 78 ++++++++++ 2 files changed, 219 insertions(+), 91 deletions(-) diff --git a/Source/MLX/MLXArray.swift b/Source/MLX/MLXArray.swift index a7374abe..4a326d04 100644 --- a/Source/MLX/MLXArray.swift +++ b/Source/MLX/MLXArray.swift @@ -160,35 +160,30 @@ public final class MLXArray { /// let value = array[1].item(Float.self) /// ``` public func item(_ type: T.Type) -> T { - self.eval() + precondition(self.size == 1) - var array_ctx = self.ctx - var free = false if type.dtype != self.dtype { - array_ctx = mlx_astype(self.ctx, type.dtype.cmlxDtype, StreamOrDevice.default.ctx) - mlx_array_eval(array_ctx) - free = true + return self.asType(type).item(type) } - // can't do it inside the else as it will free at the end of the block - defer { if free { mlx_free(array_ctx) } } + self.eval() switch type { - case is Bool.Type: return mlx_array_item_bool(array_ctx) as! T - case is UInt8.Type: return mlx_array_item_uint8(array_ctx) as! T - case is UInt16.Type: return mlx_array_item_uint16(array_ctx) as! T - case is UInt32.Type: return mlx_array_item_uint32(array_ctx) as! T - case is UInt64.Type: return mlx_array_item_uint64(array_ctx) as! T - case is Int8.Type: return mlx_array_item_int8(array_ctx) as! T - case is Int16.Type: return mlx_array_item_int16(array_ctx) as! T - case is Int32.Type: return mlx_array_item_int32(array_ctx) as! T - case is Int64.Type: return mlx_array_item_int64(array_ctx) as! T - case is Int.Type: return Int(mlx_array_item_int64(array_ctx)) as! T + case is Bool.Type: return mlx_array_item_bool(self.ctx) as! T + case is UInt8.Type: return mlx_array_item_uint8(self.ctx) as! T + case is UInt16.Type: return mlx_array_item_uint16(self.ctx) as! T + case is UInt32.Type: return mlx_array_item_uint32(self.ctx) as! T + case is UInt64.Type: return mlx_array_item_uint64(self.ctx) as! T + case is Int8.Type: return mlx_array_item_int8(self.ctx) as! T + case is Int16.Type: return mlx_array_item_int16(self.ctx) as! T + case is Int32.Type: return mlx_array_item_int32(self.ctx) as! T + case is Int64.Type: return mlx_array_item_int64(self.ctx) as! T + case is Int.Type: return Int(mlx_array_item_int64(self.ctx)) as! T #if !arch(x86_64) - case is Float16.Type: return mlx_array_item_float16(array_ctx) as! T + case is Float16.Type: return mlx_array_item_float16(self.ctx) as! T #endif - case is Float32.Type: return mlx_array_item_float32(array_ctx) as! T - case is Float.Type: return mlx_array_item_float32(array_ctx) as! T + case is Float32.Type: return mlx_array_item_float32(self.ctx) as! T + case is Float.Type: return mlx_array_item_float32(self.ctx) as! T case is Complex.Type: // mlx_array_item_complex64() isn't visible in swift so read the array // contents @@ -246,6 +241,109 @@ public final class MLXArray { asType(T.dtype, stream: stream) } + /// Return the dimension where the storage is contiguous. + /// + /// If this returns 0 then the whole storage is contiguous. If it returns ndmin + 1 then none of it is contiguous. 
+    func contiguousToDimension() -> Int {
+        let shape = self.shape
+        let strides = self.strides
+
+        var expectedStride = 1
+
+        for (dimension, (shape, stride)) in zip(shape, strides).enumerated().reversed() {
+            // as long as the actual strides match the expected (contiguous) strides
+            // the backing is contiguous in these dimensions
+            if stride != expectedStride {
+                return dimension + 1
+            }
+            expectedStride *= shape
+        }
+
+        return 0
+    }
+
+    /// Return the physical size of the backing (assuming it is evaluated) in elements
+    var physicalSize: Int {
+        // nbytes is the logical size of the input, not the physical size
+        return zip(self.shape, self.strides)
+            .map { Swift.abs($0.0 * $0.1) }
+            .max()
+            ?? self.size
+    }
+
+    func copy(from: UnsafeRawBufferPointer, to output: UnsafeMutableRawBufferPointer) {
+        let contiguousDimension = self.contiguousToDimension()
+
+        if contiguousDimension == 0 {
+            // entire backing is contiguous
+            from.copyBytes(to: output)
+
+        } else {
+            // only part of the backing is contiguous (possibly a single element)
+            // iterate the non-contiguous parts and copy the contiguous chunks into
+            // the output.
+
+            // these are the parts to iterate
+            let shape = self.shape.prefix(upTo: contiguousDimension)
+            let strides = self.strides.prefix(upTo: contiguousDimension)
+            let ndim = contiguousDimension
+            let itemSize = self.itemSize
+
+            // the size of each chunk that we copy. this computes the stride of
+            // (contiguousDimension - 1) if it were contiguous
+            let destItemSize: Int
+            if contiguousDimension == self.ndim {
+                // nothing contiguous
+                destItemSize = itemSize
+            } else {
+                destItemSize =
+                    self.strides[contiguousDimension] * self.shape[contiguousDimension] * itemSize
+            }
+
+            // the index of the current source item
+            var index = Array.init(repeating: 0, count: ndim)
+
+            // output pointer
+            var dest = output.baseAddress!
+
+            while true {
+                // compute the source index by multiplying the index by the
+                // stride for each dimension
+
+                // note: in the case where the array has negative strides / offset
+                // the base pointer we have will have the offset already applied,
+                // e.g. asStrided(a, [3, 3], strides: [-3, -1], offset: 8)
+
+                let sourceIndex = zip(index, strides).reduce(0) { $0 + ($1.0 * $1.1) }
+
+                // convert to byte pointer
+                let src = from.baseAddress! + sourceIndex * itemSize
+                dest.copyMemory(from: src, byteCount: destItemSize)
+
+                // next output address
+                dest += destItemSize
+
+                // increment the index
+                for dimension in Swift.stride(from: ndim - 1, through: 0, by: -1) {
+                    // do we need to "carry" into the next dimension?
+                    if index[dimension] == (shape[dimension] - 1) {
+                        if dimension == 0 {
+                            // all done
+                            return
+                        }
+
+                        index[dimension] = 0
+                    } else {
+                        // just increment the dimension and we are done
+                        index[dimension] += 1
+                        break
+                    }
+                }
+            }
+
+        }
+    }
+
     /// Return the contents as a single contiguous 1d `Swift.Array`.
/// /// Note: because the number of dimensions is dynamic, this cannot produce a multi-dimensional @@ -255,53 +353,17 @@ public final class MLXArray { /// - /// - ``asData(noCopy:)`` public func asArray(_ type: T.Type) -> [T] { - self.eval() - - var array_ctx = self.ctx - var free = false if type.dtype != self.dtype { - array_ctx = mlx_astype(self.ctx, type.dtype.cmlxDtype, StreamOrDevice.default.ctx) - mlx_array_eval(array_ctx) - free = true + return self.asType(type).asArray(type) } - // can't do it inside the else as it will free at the end of the block - defer { if free { mlx_free(array_ctx) } } + self.eval() - func convert(_ ptr: UnsafePointer) -> [T] { - Array(UnsafeBufferPointer(start: ptr, count: self.size)) - } - - switch type { - case is Bool.Type: return convert(mlx_array_data_bool(array_ctx) as! UnsafePointer) - case is UInt8.Type: return convert(mlx_array_data_uint8(array_ctx) as! UnsafePointer) - case is UInt16.Type: return convert(mlx_array_data_uint16(array_ctx) as! UnsafePointer) - case is UInt32.Type: return convert(mlx_array_data_uint32(array_ctx) as! UnsafePointer) - case is UInt64.Type: return convert(mlx_array_data_uint64(array_ctx) as! UnsafePointer) - case is Int8.Type: return convert(mlx_array_data_int8(array_ctx) as! UnsafePointer) - case is Int16.Type: return convert(mlx_array_data_int16(array_ctx) as! UnsafePointer) - case is Int32.Type: return convert(mlx_array_data_int32(array_ctx) as! UnsafePointer) - case is Int64.Type: return convert(mlx_array_data_int64(array_ctx) as! UnsafePointer) - case is Int.Type: - // Int and Int64 are the same bits but distinct types. coerce pointers as needed - let pointer = mlx_array_data_int64(array_ctx) - let bufferPointer = UnsafeBufferPointer(start: pointer, count: self.size) - return bufferPointer.withMemoryRebound(to: Int.self) { buffer in - Array(buffer) as! [T] - } - #if !arch(x86_64) - case is Float16.Type: - return convert(mlx_array_data_float16(array_ctx) as! UnsafePointer) - #endif - case is Float32.Type: return convert(mlx_array_data_float32(array_ctx) as! UnsafePointer) - case is Float.Type: return convert(mlx_array_data_float32(array_ctx) as! UnsafePointer) - case is Complex.Type: - let ptr = UnsafeBufferPointer( - start: UnsafePointer>(mlx_array_data_complex64(ctx)), - count: self.size) - return Array(ptr) as! [T] - default: - fatalError("Unable to get item() as \(type)") + return [T](unsafeUninitializedCapacity: self.size) { destination, initializedCount in + let source = UnsafeRawBufferPointer( + start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) + copy(from: source, to: UnsafeMutableRawBufferPointer(destination)) + initializedCount = self.size } } @@ -317,34 +379,22 @@ public final class MLXArray { public func asData(noCopy: Bool = false) -> Data { self.eval() - func convert(_ ptr: UnsafePointer) -> Data { - if noCopy { - Data( - bytesNoCopy: UnsafeMutableRawPointer(mutating: ptr), count: self.nbytes, - deallocator: .none) - } else { - Data(buffer: UnsafeBufferPointer(start: ptr, count: self.size)) + if noCopy && self.contiguousToDimension() == 0 { + // the backing is contiguous, we can provide a wrapper + // for the contents without a copy (if requested) + let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))! 
+ return Data( + bytesNoCopy: source, count: self.nbytes, + deallocator: .none) + } else { + let source = UnsafeRawBufferPointer( + start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) + + var data = Data(count: self.nbytes) + data.withUnsafeMutableBytes { destination in + copy(from: source, to: destination) } - } - - switch self.dtype { - case .bool: return convert(mlx_array_data_bool(ctx)) - case .uint8: return convert(mlx_array_data_uint8(ctx)) - case .uint16: return convert(mlx_array_data_uint16(ctx)) - case .uint32: return convert(mlx_array_data_uint32(ctx)) - case .uint64: return convert(mlx_array_data_uint64(ctx)) - case .int8: return convert(mlx_array_data_int8(ctx)) - case .int16: return convert(mlx_array_data_int16(ctx)) - case .int32: return convert(mlx_array_data_int32(ctx)) - case .int64: return convert(mlx_array_data_int64(ctx)) - #if !arch(x86_64) - case .float16: return convert(mlx_array_data_float16(ctx)) - #endif - case .float32: return convert(mlx_array_data_float32(ctx)) - case .complex64: - return convert(UnsafePointer>(mlx_array_data_complex64(ctx))) - default: - fatalError("Unable to get asData() for \(self.dtype)") + return data } } diff --git a/Tests/MLXTests/MLXArrayTests.swift b/Tests/MLXTests/MLXArrayTests.swift index e8bc962a..e2844ce8 100644 --- a/Tests/MLXTests/MLXArrayTests.swift +++ b/Tests/MLXTests/MLXArrayTests.swift @@ -30,4 +30,82 @@ class MLXArrayTests: XCTestCase { XCTAssertEqual(a[1][2].item(Int.self), 5) } + func testAsArrayContiguous() { + // read array from contiguous memory + let a = MLXArray(0 ..< 12, [4, 3]) + let b = a.asArray(Int.self) + XCTAssertEqual(b, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + } + + func testAsArrayNonContiguous1() { + // skipping elements via slicing + let a = MLXArray(0 ..< 9, [3, 3]) + + let s = a[0 ..< 2, 1 ..< 3] + assertEqual(s, MLXArray([1, 2, 4, 5], [2, 2])) + + XCTAssertEqual(s.shape, [2, 2]) + + // size and nbytes are the logical size + XCTAssertEqual(s.size, 2 * 2) + XCTAssertEqual(s.nbytes, 2 * 2 * s.itemSize) + + // internal property for counting the physical size of the backing. + // note that the physical size doesn't include the row that is + // sliced out + XCTAssertEqual(s.physicalSize, 3 * 2) + + // evaluating s (the comparison above) will realize the strides. 
+        // if we examine these before, they might be [2, 1], which are the
+        // "logical" strides
+        XCTAssertEqual(s.strides, [3, 1])
+
+        let s_arr = s.asArray(Int32.self)
+        XCTAssertEqual(s_arr, [1, 2, 4, 5])
+    }
+
+    func testAsArrayNonContiguous2() {
+        // a transpose via strides
+        let a = MLXArray(0 ..< 12, [4, 3])
+
+        let s = asStrided(a, [3, 4], strides: [1, 3])
+
+        let expected: [Int32] = [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]
+        assertEqual(s, MLXArray(expected, [3, 4]))
+
+        // Note: be careful to use the matching type -- if we transcode
+        // to a different type it will be converted to contiguous
+        let s_arr = s.asArray(Int32.self)
+        XCTAssertEqual(s_arr, expected)
+    }
+
+    func testAsArrayNonContiguous3() {
+        // reversed via strides -- note that the base pointer for the
+        // storage has an offset applied to it
+        let a = MLXArray(0 ..< 9, [3, 3])
+
+        let s = asStrided(a, [3, 3], strides: [-3, -1], offset: 8)
+
+        let expected: [Int32] = [8, 7, 6, 5, 4, 3, 2, 1, 0]
+        assertEqual(s, MLXArray(expected, [3, 3]))
+
+        let s_arr = s.asArray(Int32.self)
+        XCTAssertEqual(s_arr, expected)
+    }
+
+    func testAsArrayNonContiguous4() {
+        // buffer with holes (last dimension has stride of 2 and
+        // thus larger storage than it physically needs)
+        let a = MLXArray(0 ..< 16, [4, 4])
+        let s = a[0..., .stride(by: 2)]
+
+        let expected: [Int32] = [0, 2, 4, 6, 8, 10, 12, 14]
+        assertEqual(s, MLXArray(expected, [4, 2]))
+
+        XCTAssertEqual(s.strides, [4, 2])
+
+        let s_arr = s.asArray(Int32.self)
+        XCTAssertEqual(s_arr, expected)
+    }
+
 }
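The simplified `item(_:)` recurses through `asType` when the requested type does not match the array's dtype, instead of hand-managing a temporary `mlx_astype` handle. A minimal usage sketch, assuming the scalar `MLXArray(Float)` initializer and MLX's truncating float-to-int conversion:

```swift
import MLX

// item(_:) reads the single element out of a size-1 array; when the requested
// type differs from the dtype it converts via asType and recurses
let x = MLXArray(Float32(3.5))      // float32 scalar (initializer assumed)
let asFloat = x.item(Float.self)    // 3.5 -- dtype matches, read directly
let asInt = x.item(Int32.self)      // 3   -- goes through asType(Int32.self) first
print(asFloat, asInt)
```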
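`contiguousToDimension()` walks shape and strides from the innermost dimension outward and returns the first dimension whose stride breaks row-major contiguity (0 means the whole backing is contiguous, `ndim` means none of it is). A standalone sketch of the same check on plain arrays; `contiguousToDimension(shape:strides:)` here is a free function written for illustration, not part of MLX:

```swift
// Mirrors the stride check used by the new method, on plain Swift arrays.
func contiguousToDimension(shape: [Int], strides: [Int]) -> Int {
    var expectedStride = 1
    for (dimension, (extent, stride)) in zip(shape, strides).enumerated().reversed() {
        // a dimension is contiguous when its stride equals the product of the
        // extents of the dimensions to its right
        if stride != expectedStride {
            return dimension + 1
        }
        expectedStride *= extent
    }
    return 0
}

// [3, 3] with strides [3, 1] is fully contiguous
print(contiguousToDimension(shape: [3, 3], strides: [3, 1]))  // 0
// a [2, 2] slice of a [3, 3] array keeps row stride 3: only the last dimension is contiguous
print(contiguousToDimension(shape: [2, 2], strides: [3, 1]))  // 1
// a transpose built from strides [1, 3] has nothing contiguous
print(contiguousToDimension(shape: [3, 4], strides: [1, 3]))  // 2
```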
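`copy(from:to:)` iterates only the non-contiguous leading dimensions with an odometer-style index and copies one contiguous chunk of `destItemSize` bytes per step. The sketch below shows the same traversal element by element on a plain `[Int]` buffer; `gather(_:shape:strides:offset:)` is invented for illustration (the real code works on raw bytes and receives a base pointer that already has any offset applied):

```swift
// Element-by-element version of the strided traversal, for illustration only.
func gather(_ buffer: [Int], shape: [Int], strides: [Int], offset: Int = 0) -> [Int] {
    precondition(!shape.isEmpty && shape.count == strides.count)
    var result: [Int] = []
    var index = Array(repeating: 0, count: shape.count)

    outer: while true {
        // source position = offset + dot(index, strides)
        let source = zip(index, strides).reduce(offset) { $0 + $1.0 * $1.1 }
        result.append(buffer[source])

        // odometer-style increment, innermost dimension first
        for dimension in Swift.stride(from: shape.count - 1, through: 0, by: -1) {
            if index[dimension] == shape[dimension] - 1 {
                if dimension == 0 { break outer }  // every position visited
                index[dimension] = 0               // carry into the next dimension
            } else {
                index[dimension] += 1
                continue outer
            }
        }
    }
    return result
}

// reversed 3x3 view of 0 ..< 9, the same case as testAsArrayNonContiguous3
let flat = Array(0 ..< 9)
print(gather(flat, shape: [3, 3], strides: [-3, -1], offset: 8))
// [8, 7, 6, 5, 4, 3, 2, 1, 0]
```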
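With the strided readout in place, `asArray(_:)` and `asData(noCopy:)` return the logical contents of a non-contiguous view; `noCopy` is only honored when `contiguousToDimension() == 0`. A usage sketch mirroring `testAsArrayNonContiguous2` (expected values come from that test):

```swift
import MLX

// a transpose-like view built purely from strides; none of the backing is
// contiguous, so the readout has to follow the strides
let a = MLXArray(0 ..< 12, [4, 3])
let t = asStrided(a, [3, 4], strides: [1, 3])

// use the element type that matches the dtype (Int32 here) so the call does
// not fall back to an asType conversion that materializes a contiguous copy
let values = t.asArray(Int32.self)
// expected: [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]

// noCopy can only be honored when the whole backing is contiguous; for this
// strided view the bytes are still gathered into a fresh Data of nbytes size
let bytes = t.asData(noCopy: true)
print(values.count, bytes.count)  // 12 48
```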