diff --git a/Sources/NIO/ByteBuffer-core.swift b/Sources/NIO/ByteBuffer-core.swift index e69eead454..b34b312c7b 100644 --- a/Sources/NIO/ByteBuffer-core.swift +++ b/Sources/NIO/ByteBuffer-core.swift @@ -492,14 +492,32 @@ public struct ByteBuffer { return try body(.init(rebasing: self._slicedStorageBuffer.dropFirst(self.writerIndex))) } + /// This vends a pointer of the `ByteBuffer` at the `writerIndex` after ensuring that the buffer has at least `minimumWritableBytes` of writable bytes available. + /// + /// - warning: Do not escape the pointer from the closure for later use. + /// + /// - parameters: + /// - minimumWritableBytes: The number of writable bytes to reserve capacity for before vending the `ByteBuffer` pointer to `body`. + /// - body: The closure that will accept the yielded bytes and return the number of bytes written. + /// - returns: The number of bytes written. @discardableResult @inlinable - public mutating func writeWithUnsafeMutableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> Int) rethrows -> Int { - let bytesWritten = try withUnsafeMutableWritableBytes(body) + public mutating func writeWithUnsafeMutableBytes(minimumWritableBytes: Int, _ body: (UnsafeMutableRawBufferPointer) throws -> Int) rethrows -> Int { + if minimumWritableBytes > 0 { + self.reserveCapacity(self.writerIndex + minimumWritableBytes) + } + let bytesWritten = try self.withUnsafeMutableWritableBytes(body) self._moveWriterIndex(to: self._writerIndex + _toIndex(bytesWritten)) return bytesWritten } + @available(*, deprecated, message: "please use writeWithUnsafeMutableBytes(minimumWritableBytes:_:) instead to ensure sufficient write capacity.") + @discardableResult + @inlinable + public mutating func writeWithUnsafeMutableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> Int) rethrows -> Int { + return try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0, body) + } + /// This vends a pointer to the storage of the `ByteBuffer`. It's marked as _very unsafe_ because it might contain /// uninitialised memory and it's undefined behaviour to read it. In most cases you should use `withUnsafeReadableBytes`. /// diff --git a/Sources/NIO/NonBlockingFileIO.swift b/Sources/NIO/NonBlockingFileIO.swift index 319c6ac641..76bd54c60c 100644 --- a/Sources/NIO/NonBlockingFileIO.swift +++ b/Sources/NIO/NonBlockingFileIO.swift @@ -188,7 +188,7 @@ public struct NonBlockingFileIO { return self.threadPool.runIfActive(eventLoop: eventLoop) { () -> ByteBuffer in var bytesRead = 0 while bytesRead < byteCount { - let n = try buf.writeWithUnsafeMutableBytes { ptr in + let n = try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: byteCount - bytesRead) { ptr in let res = try fileHandle.withUnsafeFileDescriptor { descriptor in try Posix.read(descriptor: descriptor, pointer: ptr.baseAddress!, diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index d5f6e9de96..bb8ce344c0 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -15,7 +15,7 @@ extension ByteBuffer { mutating func withMutableWritePointer(body: (UnsafeMutableRawBufferPointer) throws -> IOResult) rethrows -> IOResult { var singleResult: IOResult! - _ = try self.writeWithUnsafeMutableBytes { ptr in + _ = try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { ptr in let localWriteResult = try body(ptr) singleResult = localWriteResult switch localWriteResult { diff --git a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift index e711c694f1..0e6f15ad87 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift @@ -191,7 +191,7 @@ class HTTPServerClientTest : XCTestCase { case "/massive-response": var buf = context.channel.allocator.buffer(capacity: HTTPServerClientTest.massiveResponseLength) - buf.writeWithUnsafeMutableBytes { targetPtr in + buf.writeWithUnsafeMutableBytes(minimumWritableBytes: HTTPServerClientTest.massiveResponseLength) { targetPtr in return HTTPServerClientTest.massiveResponseBytes.withUnsafeBytes { srcPtr in precondition(targetPtr.count >= srcPtr.count) targetPtr.copyMemory(from: srcPtr) diff --git a/Tests/NIOTests/ByteBufferTest+XCTest.swift b/Tests/NIOTests/ByteBufferTest+XCTest.swift index 68509fb25e..a35d15913e 100644 --- a/Tests/NIOTests/ByteBufferTest+XCTest.swift +++ b/Tests/NIOTests/ByteBufferTest+XCTest.swift @@ -41,6 +41,8 @@ extension ByteBufferTest { ("testCoWWorks", testCoWWorks), ("testWithMutableReadPointerMovesReaderIndexAndReturnsNumBytesConsumed", testWithMutableReadPointerMovesReaderIndexAndReturnsNumBytesConsumed), ("testWithMutableWritePointerMovesWriterIndexAndReturnsNumBytesWritten", testWithMutableWritePointerMovesWriterIndexAndReturnsNumBytesWritten), + ("testWithMutableWritePointerWithMinimumSpecifiedAdjustsCapacity", testWithMutableWritePointerWithMinimumSpecifiedAdjustsCapacity), + ("testWithMutableWritePointerWithMinimumSpecifiedWhileAtMaxCapacity", testWithMutableWritePointerWithMinimumSpecifiedWhileAtMaxCapacity), ("testSetGetInt8", testSetGetInt8), ("testSetGetInt16", testSetGetInt16), ("testSetGetInt32", testSetGetInt32), diff --git a/Tests/NIOTests/ByteBufferTest.swift b/Tests/NIOTests/ByteBufferTest.swift index feeb9326c9..c0ca0fd9c3 100644 --- a/Tests/NIOTests/ByteBufferTest.swift +++ b/Tests/NIOTests/ByteBufferTest.swift @@ -228,11 +228,58 @@ class ByteBufferTest: XCTestCase { func testWithMutableWritePointerMovesWriterIndexAndReturnsNumBytesWritten() { XCTAssertEqual(0, buf.writerIndex) - let bytesWritten = buf.writeWithUnsafeMutableBytes { (_: UnsafeMutableRawBufferPointer) in return 5 } + let bytesWritten = buf.writeWithUnsafeMutableBytes(minimumWritableBytes: 5) { + XCTAssertTrue($0.count >= 5) + return 5 + } XCTAssertEqual(5, bytesWritten) XCTAssertEqual(5, buf.writerIndex) } + func testWithMutableWritePointerWithMinimumSpecifiedAdjustsCapacity() { + XCTAssertEqual(0, buf.writerIndex) + XCTAssertEqual(512, buf.capacity) + + var bytesWritten = buf.writeWithUnsafeMutableBytes(minimumWritableBytes: 256) { + XCTAssertTrue($0.count >= 256) + return 256 + } + XCTAssertEqual(256, bytesWritten) + XCTAssertEqual(256, buf.writerIndex) + XCTAssertEqual(512, buf.capacity) + + bytesWritten += buf.writeWithUnsafeMutableBytes(minimumWritableBytes: 1024) { + XCTAssertTrue($0.count >= 1024) + return 1024 + } + let expectedBytesWritten = 256 + 1024 + XCTAssertEqual(expectedBytesWritten, bytesWritten) + XCTAssertEqual(expectedBytesWritten, buf.writerIndex) + XCTAssertTrue(buf.capacity >= expectedBytesWritten) + } + + func testWithMutableWritePointerWithMinimumSpecifiedWhileAtMaxCapacity() { + XCTAssertEqual(0, buf.writerIndex) + XCTAssertEqual(512, buf.capacity) + + var bytesWritten = buf.writeWithUnsafeMutableBytes(minimumWritableBytes: 512) { + XCTAssertTrue($0.count >= 512) + return 512 + } + XCTAssertEqual(512, bytesWritten) + XCTAssertEqual(512, buf.writerIndex) + XCTAssertEqual(512, buf.capacity) + + bytesWritten += buf.writeWithUnsafeMutableBytes(minimumWritableBytes: 1) { + XCTAssertTrue($0.count >= 1) + return 1 + } + let expectedBytesWritten = 512 + 1 + XCTAssertEqual(expectedBytesWritten, bytesWritten) + XCTAssertEqual(expectedBytesWritten, buf.writerIndex) + XCTAssertTrue(buf.capacity >= expectedBytesWritten) + } + func testSetGetInt8() throws { try setGetInt(index: 0, v: Int8.max) } @@ -598,7 +645,7 @@ class ByteBufferTest: XCTestCase { let cap = buf.capacity var otherBuf = buf XCTAssertEqual(otherBuf, buf) - otherBuf?.writeWithUnsafeMutableBytes { ptr in + otherBuf?.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { ptr in XCTAssertEqual(cap, ptr.count) let intPtr = ptr.baseAddress!.bindMemory(to: UInt8.self, capacity: ptr.count) for i in 0..