Skip to content

Commit

Permalink
HTTPDecoder: Don't invoke left overs strategy on EOF (#1061)
Browse files Browse the repository at this point in the history
Motivation:

Do as we documented and invoke the leftovers strategy only when the
handler is removed.

Modifications:

Don't invoke leftovers on EOF (which isn't something the user can
control).

Result:

- Fewer crashes
- should fix Kitura/Kitura-WebSocket-NIO#35
  • Loading branch information
weissi authored and Lukasa committed Jul 10, 2019
1 parent ed3bca3 commit 7a250a2
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 7 deletions.
20 changes: 13 additions & 7 deletions Sources/NIOHTTP1/HTTPDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,9 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
/// Creates a new instance of `HTTPDecoder`.
///
/// - parameters:
/// - leftOverBytesStrategy: the strategy to use when removing the decoder from the pipeline and an upgrade was detected
/// - leftOverBytesStrategy: The strategy to use when removing the decoder from the pipeline and an upgrade was,
/// detected. Note that this does not affect what happens on EOF (in which case an
/// `ByteToMessageDecoderError.leftoverDataWhenDone` error is fired.)
public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) {
self.headers.reserveCapacity(16)
if In.self == HTTPServerRequestPart.self {
Expand Down Expand Up @@ -619,13 +621,17 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
}
}
if buffer.readableBytes > 0 {
switch self.leftOverBytesStrategy {
case .dropBytes:
()
case .fireError:
if seenEOF {
context.fireErrorCaught(ByteToMessageDecoderError.leftoverDataWhenDone(buffer))
case .forwardBytes:
context.fireChannelRead(NIOAny(buffer))
} else {
switch self.leftOverBytesStrategy {
case .dropBytes:
()
case .fireError:
context.fireErrorCaught(ByteToMessageDecoderError.leftoverDataWhenDone(buffer))
case .forwardBytes:
context.fireChannelRead(NIOAny(buffer))
}
}
}
return .needMoreData
Expand Down
4 changes: 4 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ extension HTTPDecoderTest {
("testDoesNotDeliverLeftoversUnnecessarily", testDoesNotDeliverLeftoversUnnecessarily),
("testHTTPResponseWithoutHeaders", testHTTPResponseWithoutHeaders),
("testBasicVerifications", testBasicVerifications),
("testErrorFiredOnEOFForLeftOversInAllLeftOversModes", testErrorFiredOnEOFForLeftOversInAllLeftOversModes),
("testBytesCanBeForwardedWhenHandlerRemoved", testBytesCanBeForwardedWhenHandlerRemoved),
("testBytesCanBeFiredAsErrorWhenHandlerRemoved", testBytesCanBeFiredAsErrorWhenHandlerRemoved),
("testBytesCanBeDroppedWhenHandlerRemoved", testBytesCanBeDroppedWhenHandlerRemoved),
]
}
}
Expand Down
172 changes: 172 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPDecoderTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -593,4 +593,176 @@ class HTTPDecoderTest: XCTestCase {
XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: expectedInOutsBB,
decoderFactory: { HTTPRequestDecoder() }))
}

func testErrorFiredOnEOFForLeftOversInAllLeftOversModes() throws {
class Receiver: ChannelInboundHandler {
typealias InboundIn = HTTPServerRequestPart

private let errorReceivedPromise: EventLoopPromise<ByteToMessageDecoderError>
private var numberOfErrors = 0

init(errorReceivedPromise: EventLoopPromise<ByteToMessageDecoderError>) {
self.errorReceivedPromise = errorReceivedPromise
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
self.numberOfErrors += 1
if self.numberOfErrors == 1, let error = error as? ByteToMessageDecoderError {
self.errorReceivedPromise.succeed(error)
} else {
XCTFail("illegal: number of errors: \(self.numberOfErrors), error: \(error)")
}
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = self.unwrapInboundIn(data)
switch part {
case .head(let head):
XCTAssertEqual(.OPTIONS, head.method)
case .body:
XCTFail("unexpected .body part")
case .end:
()
}
}
}

for leftOverBytesStrategy in [RemoveAfterUpgradeStrategy.dropBytes, .fireError, .forwardBytes] {
let channel = EmbeddedChannel()
let errorReceivedPromise: EventLoopPromise<ByteToMessageDecoderError> = channel.eventLoop.makePromise()
var buffer = channel.allocator.buffer(capacity: 64)
buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: L\r\nUpgrade: P\r\nConnection: upgrade\r\n\r\nXXXX")

let decoder = HTTPRequestDecoder(leftOverBytesStrategy: leftOverBytesStrategy)
XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(decoder)).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(Receiver(errorReceivedPromise: errorReceivedPromise)).wait())
XCTAssertNoThrow(try channel.writeInbound(buffer))
XCTAssertNoThrow(XCTAssert(try channel.finish().isClean))

switch Result(catching: { try errorReceivedPromise.futureResult.wait() }) {
case .success(ByteToMessageDecoderError.leftoverDataWhenDone(let buffer)):
XCTAssertEqual("XXXX", String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self))
case .failure(let error):
XCTFail("unexpected error: \(error)")
case .success(let error):
XCTFail("unexpected error: \(error)")
}
}
}

func testBytesCanBeForwardedWhenHandlerRemoved() throws {
class Receiver: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = self.unwrapInboundIn(data)
switch part {
case .head(let head):
XCTAssertEqual(.OPTIONS, head.method)
case .body:
XCTFail("unexpected .body part")
case .end:
()
}
}
}

let channel = EmbeddedChannel()
var buffer = channel.allocator.buffer(capacity: 64)
buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: L\r\nUpgrade: P\r\nConnection: upgrade\r\n\r\nXXXX")

let receiver = Receiver()
let decoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes))
XCTAssertNoThrow(try channel.pipeline.addHandler(decoder).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(receiver).wait())
XCTAssertNoThrow(try channel.writeInbound(buffer))
let removalFutures = [ channel.pipeline.removeHandler(receiver), channel.pipeline.removeHandler(decoder) ]
channel.embeddedEventLoop.run()
try removalFutures.forEach {
XCTAssertNoThrow(try $0.wait())
}
XCTAssertNoThrow(XCTAssertEqual("XXXX", try channel.readInbound(as: ByteBuffer.self).map {
String(decoding: $0.readableBytesView, as: Unicode.UTF8.self)
}))
XCTAssertNoThrow(XCTAssert(try channel.finish().isClean))
}

func testBytesCanBeFiredAsErrorWhenHandlerRemoved() throws {
class Receiver: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = self.unwrapInboundIn(data)
switch part {
case .head(let head):
XCTAssertEqual(.OPTIONS, head.method)
case .body:
XCTFail("unexpected .body part")
case .end:
()
}
}
}

let channel = EmbeddedChannel()
var buffer = channel.allocator.buffer(capacity: 64)
buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: L\r\nUpgrade: P\r\nConnection: upgrade\r\n\r\nXXXX")

let receiver = Receiver()
let decoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .fireError))
XCTAssertNoThrow(try channel.pipeline.addHandler(decoder).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(receiver).wait())
XCTAssertNoThrow(try channel.writeInbound(buffer))
let removalFutures = [ channel.pipeline.removeHandler(receiver), channel.pipeline.removeHandler(decoder) ]
channel.embeddedEventLoop.run()
try removalFutures.forEach {
XCTAssertNoThrow(try $0.wait())
}
XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in
switch error as? ByteToMessageDecoderError {
case .some(ByteToMessageDecoderError.leftoverDataWhenDone(let buffer)):
XCTAssertEqual("XXXX", String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self))
case .some(let error):
XCTFail("unexpected error: \(error)")
case .none:
XCTFail("unexpected error")
}
}
XCTAssertNoThrow(XCTAssert(try channel.finish().isClean))
}

func testBytesCanBeDroppedWhenHandlerRemoved() throws {
class Receiver: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = self.unwrapInboundIn(data)
switch part {
case .head(let head):
XCTAssertEqual(.OPTIONS, head.method)
case .body:
XCTFail("unexpected .body part")
case .end:
()
}
}
}

let channel = EmbeddedChannel()
var buffer = channel.allocator.buffer(capacity: 64)
buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: L\r\nUpgrade: P\r\nConnection: upgrade\r\n\r\nXXXX")

let receiver = Receiver()
let decoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .dropBytes))
XCTAssertNoThrow(try channel.pipeline.addHandler(decoder).wait())
XCTAssertNoThrow(try channel.pipeline.addHandler(receiver).wait())
XCTAssertNoThrow(try channel.writeInbound(buffer))
let removalFutures = [ channel.pipeline.removeHandler(receiver), channel.pipeline.removeHandler(decoder) ]
channel.embeddedEventLoop.run()
try removalFutures.forEach {
XCTAssertNoThrow(try $0.wait())
}
XCTAssertNoThrow(XCTAssert(try channel.finish().isClean))
}

}

0 comments on commit 7a250a2

Please sign in to comment.