From f27763bef455d76f9455e9dfc6704a6b2859fa26 Mon Sep 17 00:00:00 2001 From: David Koski <46639364+davidkoski@users.noreply.github.com> Date: Tue, 10 Sep 2024 08:00:44 -0700 Subject: [PATCH] replacement for asData() to handle strided backing. (#130) * replacement for asData() to handle strided backing. - fixes #117 - asData() now returns a struct with stride information - takes a policy enum to indicate whether the consumer wants contiguous data --- Source/MLX/MLXArray+Bytes.swift | 314 +++++++++++++++++++++++++++++ Source/MLX/MLXArray.swift | 205 ++----------------- Tests/MLXTests/MLXArrayTests.swift | 75 ++++++- 3 files changed, 400 insertions(+), 194 deletions(-) create mode 100644 Source/MLX/MLXArray+Bytes.swift diff --git a/Source/MLX/MLXArray+Bytes.swift b/Source/MLX/MLXArray+Bytes.swift new file mode 100644 index 00000000..99816402 --- /dev/null +++ b/Source/MLX/MLXArray+Bytes.swift @@ -0,0 +1,314 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx +import Foundation +import Metal + +// MARK: - Backing / Bytes + +extension MLXArray { + + /// Return the dimension where the storage is contiguous. + /// + /// If this returns 0 then the whole storage is contiguous. If it returns ndmin + 1 then none of it is contiguous. + func contiguousToDimension() -> Int { + let shape = self.shape + let strides = self.internalStrides + + var expectedStride = 1 + + for (dimension, (shape, stride)) in zip(shape, strides).enumerated().reversed() { + // as long as the actual strides match the expected (contiguous) strides + // the backing is contiguous in these dimensions + if stride != expectedStride { + return dimension + 1 + } + expectedStride *= shape + } + + return 0 + } + + /// Return the physical size of the backing (assuming it is evaluated) in elements. This should + /// only be used when accessing the backing directly, e.g. via `mlx_array_data_uint8()` + var physicalSize: Int { + // nbytes is the logical size of the input, not the physical size + return zip(self.shape, self.internalStrides) + .map { Swift.abs($0.0 * $0.1) } + .max() + ?? self.size + } + + func copy(from: UnsafeRawBufferPointer, toContiguous output: UnsafeMutableRawBufferPointer) { + let contiguousDimension = self.contiguousToDimension() + let shape = self.shape + let strides = self.internalStrides + + if contiguousDimension == 0 { + // entire backing is contiguous + from.copyBytes(to: output) + + } else { + // only part of the backing is contiguous (possibly a single element) + // iterate the non-contiguous parts and copy the contiguous chunks into + // the output. + + // these are the parts to iterate + let ndim = contiguousDimension + let itemSize = self.itemSize + + // the size of each chunk that we copy. this computes the stride of + // (contiguousDimension - 1) if it were contiguous + let destItemSize: Int + if contiguousDimension == self.ndim { + // nothing contiguous + destItemSize = itemSize + } else { + destItemSize = strides[contiguousDimension] * shape[contiguousDimension] * itemSize + } + + // the index of the current source item + var index = Array.init(repeating: 0, count: ndim) + + // output pointer + var dest = output.baseAddress! + + while true { + // compute the source index by multiplying the index by the + // stride for each dimension + + // note: in the case where the array has negative strides / offset + // the base pointer we have will have the offset already applied, + // e.g. asStrided(a, [3, 3], strides: [-3, -1], offset: 8) + + let sourceIndex = zip(index, strides).reduce(0) { $0 + ($1.0 * $1.1) } + + // convert to byte pointer + let src = from.baseAddress! + sourceIndex * itemSize + dest.copyMemory(from: src, byteCount: destItemSize) + + // next output address + dest += destItemSize + + // increment the index + for dimension in Swift.stride(from: ndim - 1, through: 0, by: -1) { + // do we need to "carry" into the next dimension? + if index[dimension] == (shape[dimension] - 1) { + if dimension == 0 { + // all done + return + } + + index[dimension] = 0 + } else { + // just increment the dimension and we are done + index[dimension] += 1 + break + } + } + } + + } + } + + /// Return the contents as a single contiguous 1d `Swift.Array`. + /// + /// Note: because the number of dimensions is dynamic, this cannot produce a multi-dimensional + /// array. + /// + /// ### See Also + /// - + /// - ``asData(noCopy:)`` + /// - ``asMTLBuffer(device:noCopy:)`` + public func asArray(_ type: T.Type) -> [T] { + if type.dtype != self.dtype { + return self.asType(type).asArray(type) + } + + self.eval() + + return [T](unsafeUninitializedCapacity: self.size) { destination, initializedCount in + let source = UnsafeRawBufferPointer( + start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) + copy(from: source, toContiguous: UnsafeMutableRawBufferPointer(destination)) + initializedCount = self.size + } + } + + /// How to access backing data with ``asData(access:)`` -- this controls how + /// ``MLXArrayData`` is produced. + public enum AccessMethod { + /// Create a contiguous copy of the data backing the ``MLXArray``. The lifetime of this data + /// independent of the lifetime of the ``MLXArray``. + case copy + + /// Return contiguous data from the backing of the ``MLXArray``, avoiding a copy if possible. + /// The lifetime of the result is valid only while keeping the ``MLXArray`` alive. This is best + /// for temporary copies, e.g. creating and writing an image to disk. + case noCopyIfContiguous + + /// Return a wrapper around the backing of the ``MLXArray``. This might not be + /// contiguous -- the caller must examine ``MLXArrayData/strides`` + /// The lifetime of the result is valid only while keeping the ``MLXArray`` alive. This is best + /// for temporary copies, e.g. creating and writing an image to disk. + case noCopy + } + + /// Container for ``Data`` backing of ``MLXArray``. + /// + /// ### See Also + /// - ``MLXArray/asData(access:)`` + public struct MLXArrayData { + /// The bytes backing the ``MLXArray``. + /// + /// The ``MLXArray/AccessMethod`` passed to ``MLXArray/asData(access:)`` + /// controls the lifetime and potential layout of this data (e.g. contiguous vs non-contiguous). + public let data: Data + + /// Dimensions of the data when viewed as ``dType`` + public let shape: [Int] + + /// Strides of the data in terms of the ``dType`` + public let strides: [Int] + + /// The layout of a single items in the ``data`` + public let dType: DType + } + + /// return a copy of the backing in contiguous layout + private func asDataCopy() -> MLXArrayData { + // point into the possibly non-contiguous backing + let source = UnsafeRawBufferPointer( + start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) + + var data = Data(count: self.nbytes) + data.withUnsafeMutableBytes { destination in + copy(from: source, toContiguous: destination) + } + return MLXArrayData( + data: data, + shape: self.shape, strides: contiguousStrides(shape: self.shape), + dType: self.dtype) + } + + /// Return the contents as ``Data`` bytes in the native ``dtype``. + /// + /// > If you use ``AccessMethod/noCopy`` or ``AccessMethod/noCopyIfContiguous`` you + /// must guarantee that the lifetime of the ``MLXArray`` exceeds the lifetime of the result. + /// + /// Callers can specify an ``AccessMethod`` to cause the data to be in _contiguous_ memory + /// vs. whatever layout the backing actually has. Callers can use ``MLXArrayData/strides`` + /// if the data is not contiguous. + /// + /// By default it will use ``AccessMethod/copy`` and will return a copy of the data in + /// contiguous memory. + /// + /// ### See Also + /// - + /// - ``asArray(_:)`` + /// - ``asMTLBuffer(device:noCopy:)`` + public func asData(access: AccessMethod = .copy) -> MLXArrayData { + self.eval() + + switch access { + case .copy: + return asDataCopy() + + case .noCopyIfContiguous: + if self.contiguousToDimension() == 0 { + // the backing is contiguous, we can provide a wrapper + // for the contents without a copy + let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))! + let data = Data( + bytesNoCopy: source, count: size * itemSize, + deallocator: .none) + + return MLXArrayData( + data: data, + shape: self.shape, strides: contiguousStrides(shape: self.shape), + dType: self.dtype) + + } else { + // not contiguous + return asDataCopy() + } + + case .noCopy: + let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))! + let data = Data( + bytesNoCopy: source, count: nbytes, + deallocator: .none) + + let strides: [Int] + if ndim == 0 { + strides = [] + } else { + let internalStrides = mlx_array_strides(ctx)! + strides = (0 ..< ndim).map { Int(internalStrides[$0]) } + } + + return MLXArrayData( + data: data, + shape: self.shape, strides: strides, + dType: self.dtype) + } + } + + /// Return the contents as contiguous bytes in the native ``dtype``. + /// + /// > If you can guarantee the lifetime of the ``MLXArray`` will exceed the Data and that + /// the array will not be mutated (e.g. using indexing or other means) it is possible to pass `noCopy: true` + /// to reference the backing bytes. + /// + /// **Replaced with** ``asData(access:)`` + /// + /// ### See Also + /// - + /// - ``asArray(_:)`` + /// - ``asMTLBuffer(device:noCopy:)`` + @available(*, deprecated, message: "use asData(acccess: .copy)") + public func asData(noCopy: Bool = false, disambiguate: Bool = false) -> Data { + self.eval() + + return asData(access: noCopy ? .noCopyIfContiguous : .copy) + .data + } + + /// Return the contents as a Metal buffer in the native ``dtype``. + /// + /// > If you can guarantee the lifetime of the ``MLXArray`` will exceed the MTLBuffer and that + /// the array will not be mutated (e.g. using indexing or other means) it is possible to pass `noCopy: true` + /// to reference the backing bytes. + /// + /// ### See Also + /// - + /// - ``asArray(_:)`` + /// - ``asData(noCopy:)`` + public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? { + self.eval() + + if noCopy && self.contiguousToDimension() == 0 { + // the backing is contiguous, we can provide a wrapper + // for the contents without a copy (if requested) + let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))! + return device.makeBuffer(bytesNoCopy: source, length: self.nbytes) + } else { + let source = UnsafeRawBufferPointer( + start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) + return device.makeBuffer(bytes: source.baseAddress!, length: self.nbytes) + } + } + +} + +/// Return the strides for contiguous memory +func contiguousStrides(shape: [Int]) -> [Int] { + var result = [Int]() + var current = 1 + for d in shape.reversed() { + result.append(current) + current *= d + } + result.reverse() + return result +} diff --git a/Source/MLX/MLXArray.swift b/Source/MLX/MLXArray.swift index 4caab774..9c42e6b9 100644 --- a/Source/MLX/MLXArray.swift +++ b/Source/MLX/MLXArray.swift @@ -120,13 +120,10 @@ public final class MLXArray { return (Int(cShape[0]), Int(cShape[1]), Int(cShape[2]), Int(cShape[3])) } - /// Strides of the array. - /// - /// ```swift - /// let array = MLXArray(0 ..< 12, [3, 4]) - /// print(array.strides) - /// // [4, 1] - /// ``` + /// Strides of the array. Note: do not use this as it changes + /// before and after evaluation. See also ``asData(access:)`` + /// and ``MLXArray/MLXArrayData/strides``. + @available(*, deprecated, message: "Do not use -- see asData(access:)") public var strides: [Int] { let ndim = mlx_array_ndim(ctx) guard ndim > 0 else { return [] } @@ -134,6 +131,16 @@ public final class MLXArray { return (0 ..< ndim).map { Int(strides[$0]) } } + /// Strides of the array backing. + /// + /// Note: this is only stable once the array is evaluated. + var internalStrides: [Int] { + let ndim = mlx_array_ndim(ctx) + guard ndim > 0 else { return [] } + let strides = mlx_array_strides(ctx)! + return (0 ..< ndim).map { Int(strides[$0]) } + } + /// Return the scalar value of the array. /// /// It is a contract violation to call this on an array with more than one element @@ -336,190 +343,6 @@ public final class MLXArray { asType(T.dtype, stream: stream) } - /// Return the dimension where the storage is contiguous. - /// - /// If this returns 0 then the whole storage is contiguous. If it returns ndmin + 1 then none of it is contiguous. - func contiguousToDimension() -> Int { - let shape = self.shape - let strides = self.strides - - var expectedStride = 1 - - for (dimension, (shape, stride)) in zip(shape, strides).enumerated().reversed() { - // as long as the actual strides match the expected (contiguous) strides - // the backing is contiguous in these dimensions - if stride != expectedStride { - return dimension + 1 - } - expectedStride *= shape - } - - return 0 - } - - /// Return the physical size of the backing (assuming it is evaluated) in elements - var physicalSize: Int { - // nbytes is the logical size of the input, not the physical size - return zip(self.shape, self.strides) - .map { Swift.abs($0.0 * $0.1) } - .max() - ?? self.size - } - - func copy(from: UnsafeRawBufferPointer, to output: UnsafeMutableRawBufferPointer) { - let contiguousDimension = self.contiguousToDimension() - - if contiguousDimension == 0 { - // entire backing is contiguous - from.copyBytes(to: output) - - } else { - // only part of the backing is contiguous (possibly a single element) - // iterate the non-contiguous parts and copy the contiguous chunks into - // the output. - - // these are the parts to iterate - let shape = self.shape.prefix(upTo: contiguousDimension) - let strides = self.strides.prefix(upTo: contiguousDimension) - let ndim = contiguousDimension - let itemSize = self.itemSize - - // the size of each chunk that we copy. this computes the stride of - // (contiguousDimension - 1) if it were contiguous - let destItemSize: Int - if contiguousDimension == self.ndim { - // nothing contiguous - destItemSize = itemSize - } else { - destItemSize = - self.strides[contiguousDimension] * self.shape[contiguousDimension] * itemSize - } - - // the index of the current source item - var index = Array.init(repeating: 0, count: ndim) - - // output pointer - var dest = output.baseAddress! - - while true { - // compute the source index by multiplying the index by the - // stride for each dimension - - // note: in the case where the array has negative strides / offset - // the base pointer we have will have the offset already applied, - // e.g. asStrided(a, [3, 3], strides: [-3, -1], offset: 8) - - let sourceIndex = zip(index, strides).reduce(0) { $0 + ($1.0 * $1.1) } - - // convert to byte pointer - let src = from.baseAddress! + sourceIndex * itemSize - dest.copyMemory(from: src, byteCount: destItemSize) - - // next output address - dest += destItemSize - - // increment the index - for dimension in Swift.stride(from: ndim - 1, through: 0, by: -1) { - // do we need to "carry" into the next dimension? - if index[dimension] == (shape[dimension] - 1) { - if dimension == 0 { - // all done - return - } - - index[dimension] = 0 - } else { - // just increment the dimension and we are done - index[dimension] += 1 - break - } - } - } - - } - } - - /// Return the contents as a single contiguous 1d `Swift.Array`. - /// - /// Note: because the number of dimensions is dynamic, this cannot produce a multi-dimensional - /// array. - /// - /// ### See Also - /// - - /// - ``asData(noCopy:)`` - /// - ``asMTLBuffer(device:noCopy:)`` - public func asArray(_ type: T.Type) -> [T] { - if type.dtype != self.dtype { - return self.asType(type).asArray(type) - } - - self.eval() - - return [T](unsafeUninitializedCapacity: self.size) { destination, initializedCount in - let source = UnsafeRawBufferPointer( - start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) - copy(from: source, to: UnsafeMutableRawBufferPointer(destination)) - initializedCount = self.size - } - } - - /// Return the contents as contiguous bytes in the native ``dtype``. - /// - /// > If you can guarantee the lifetime of the ``MLXArray`` will exceed the Data and that - /// the array will not be mutated (e.g. using indexing or other means) it is possible to pass `noCopy: true` - /// to reference the backing bytes. - /// - /// ### See Also - /// - - /// - ``asArray(_:)`` - /// - ``asMTLBuffer(device:noCopy:)`` - public func asData(noCopy: Bool = false) -> Data { - self.eval() - - if noCopy && self.contiguousToDimension() == 0 { - // the backing is contiguous, we can provide a wrapper - // for the contents without a copy (if requested) - let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))! - return Data( - bytesNoCopy: source, count: self.nbytes, - deallocator: .none) - } else { - let source = UnsafeRawBufferPointer( - start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) - - var data = Data(count: self.nbytes) - data.withUnsafeMutableBytes { destination in - copy(from: source, to: destination) - } - return data - } - } - - /// Return the contents as a Metal buffer in the native ``dtype``. - /// - /// > If you can guarantee the lifetime of the ``MLXArray`` will exceed the MTLBuffer and that - /// the array will not be mutated (e.g. using indexing or other means) it is possible to pass `noCopy: true` - /// to reference the backing bytes. - /// - /// ### See Also - /// - - /// - ``asArray(_:)`` - /// - ``asData(noCopy:)`` - public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? { - self.eval() - - if noCopy && self.contiguousToDimension() == 0 { - // the backing is contiguous, we can provide a wrapper - // for the contents without a copy (if requested) - let source = UnsafeMutableRawPointer(mutating: mlx_array_data_uint8(self.ctx))! - return device.makeBuffer(bytesNoCopy: source, length: self.nbytes) - } else { - let source = UnsafeRawBufferPointer( - start: mlx_array_data_uint8(self.ctx), count: physicalSize * itemSize) - return device.makeBuffer(bytes: source.baseAddress!, length: self.nbytes) - } - } - /// Convert the real array into a ``DType/complex64`` imaginary part. /// /// This is equivalent to the following in python: diff --git a/Tests/MLXTests/MLXArrayTests.swift b/Tests/MLXTests/MLXArrayTests.swift index e2844ce8..9bc5b18f 100644 --- a/Tests/MLXTests/MLXArrayTests.swift +++ b/Tests/MLXTests/MLXArrayTests.swift @@ -56,9 +56,9 @@ class MLXArrayTests: XCTestCase { XCTAssertEqual(s.physicalSize, 3 * 2) // evaluating s (the comparison above) will realize the strides. - // if we eamine these before they might be [2, 1] which are the + // if we examine these before they might be [2, 1] which are the // "logical" strides - XCTAssertEqual(s.strides, [3, 1]) + XCTAssertEqual(s.internalStrides, [3, 1]) let s_arr = s.asArray(Int32.self) XCTAssertEqual(s_arr, [1, 2, 4, 5]) @@ -102,10 +102,79 @@ class MLXArrayTests: XCTestCase { let expected: [Int32] = [0, 2, 4, 6, 8, 10, 12, 14] assertEqual(s, MLXArray(expected, [4, 2])) - XCTAssertEqual(s.strides, [4, 2]) + XCTAssertEqual(s.internalStrides, [4, 2]) let s_arr = s.asArray(Int32.self) XCTAssertEqual(s_arr, expected) } + func testContiguousStrides() { + XCTAssertEqual(contiguousStrides(shape: [1, 1, 1]), [1, 1, 1]) + XCTAssertEqual(contiguousStrides(shape: [4, 4]), [4, 1]) + XCTAssertEqual(contiguousStrides(shape: [4, 2]), [2, 1]) + XCTAssertEqual(contiguousStrides(shape: [3, 2, 1, 5]), [10, 5, 5, 1]) + } + + func testAsDataContiguous() { + // contiguous source + let a = MLXArray(0 ..< 16, [4, 4]) + + let result = a.asData() + XCTAssertEqual(result.shape, [4, 4]) + XCTAssertEqual(result.strides, [4, 1]) + XCTAssertEqual(result.dType, .int32) + } + + func testAsDataContiguousNoCopy() { + // contiguous source + let a = MLXArray(0 ..< 16, [4, 4]) + + do { + let result = a.asData(access: .noCopy) + XCTAssertEqual(result.shape, [4, 4]) + XCTAssertEqual(result.strides, [4, 1]) + XCTAssertEqual(result.dType, .int32) + } + do { + let result = a.asData(access: .noCopyIfContiguous) + XCTAssertEqual(result.shape, [4, 4]) + XCTAssertEqual(result.strides, [4, 1]) + XCTAssertEqual(result.dType, .int32) + } + } + + func testAsDataNonContiguous() { + // buffer with holes (last dimension has stride of 2 and + // thus larger storage than it physically needs) + let a = MLXArray(0 ..< 16, [4, 4]) + let s = a[0..., .stride(by: 2)] + + let result = s.asData() + XCTAssertEqual(result.shape, [4, 2]) + XCTAssertEqual(result.strides, [2, 1]) + XCTAssertEqual(result.dType, .int32) + } + + func testAsDataNonContiguousNoCopy() { + // buffer with holes (last dimension has stride of 2 and + // thus larger storage than it physically needs) + let a = MLXArray(0 ..< 16, [4, 4]) + let s = a[0..., .stride(by: 2)] + + do { + // the strides will match the strided array + let result = s.asData(access: .noCopy) + XCTAssertEqual(result.shape, [4, 2]) + XCTAssertEqual(result.strides, [4, 2]) + XCTAssertEqual(result.dType, .int32) + } + do { + // it isn't contiguous so we will get a contiguous copy + let result = s.asData(access: .noCopyIfContiguous) + XCTAssertEqual(result.shape, [4, 2]) + XCTAssertEqual(result.strides, [2, 1]) + XCTAssertEqual(result.dType, .int32) + } + } + }