Skip to content

Commit

Permalink
ClientBootstrap: allow binding sockets (#1490)
Browse files Browse the repository at this point in the history
  • Loading branch information
weissi authored Jun 8, 2020
1 parent 5fc487b commit 120acb1
Show file tree
Hide file tree
Showing 11 changed files with 314 additions and 48 deletions.
4 changes: 2 additions & 2 deletions Sources/NIO/BaseSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class BaseSocket: BaseSocketProtocol {
/// - name: The name of the option to set.
/// - value: The value for the option.
/// - throws: An `IOError` if the operation failed.
final func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
func setOption<T>(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: T) throws {
if level == .tcp && name == .tcp_nodelay && (try? self.localAddress().protocol) == Optional<NIOBSDSocket.ProtocolFamily>.some(.unix) {
// setting TCP_NODELAY on UNIX domain sockets will fail. Previously we had a bug where we would ignore
// most socket options settings so for the time being we'll just ignore this. Let's revisit for NIO 2.0.
Expand Down Expand Up @@ -387,7 +387,7 @@ class BaseSocket: BaseSocketProtocol {
/// - parameters:
/// - address: The `SocketAddress` to which the socket should be bound.
/// - throws: An `IOError` if the operation failed.
final func bind(to address: SocketAddress) throws {
func bind(to address: SocketAddress) throws {
try self.withUnsafeHandle { fd in
func doBind(ptr: UnsafePointer<sockaddr>, bytes: Int) throws {
try Posix.bind(descriptor: fd, ptr: ptr, bytes: bytes)
Expand Down
4 changes: 0 additions & 4 deletions Sources/NIO/BaseSocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -645,10 +645,6 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
promise?.fail(ChannelError.ioOnClosedChannel)
return
}
guard self.lifecycleManager.isPreRegistered else {
promise?.fail(ChannelError.inappropriateOperationForState)
return
}

executeAndComplete(promise) {
try socket.bind(to: address)
Expand Down
100 changes: 75 additions & 25 deletions Sources/NIO/Bootstrap.swift
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
internal var _channelOptions: ChannelOptions.Storage
private var connectTimeout: TimeAmount = TimeAmount.seconds(10)
private var resolver: Optional<Resolver>
private var bindTarget: Optional<SocketAddress>

/// Create a `ClientBootstrap` on the `EventLoopGroup` `group`.
///
Expand Down Expand Up @@ -458,6 +459,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
self._channelInitializer = { channel in channel.eventLoop.makeSucceededFuture(()) }
self.protocolHandlers = nil
self.resolver = nil
self.bindTarget = nil
}

/// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add
Expand Down Expand Up @@ -523,6 +525,24 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
return self
}

/// Bind the `SocketChannel` to `address`.
///
/// Using `bind` is not necessary unless you need the local address to be bound to a specific address.
///
/// - note: Using `bind` will disable Happy Eyeballs on this `Channel`.
///
/// - parameters:
/// - address: The `SocketAddress` to bind on.
public func bind(to address: SocketAddress) -> ClientBootstrap {
self.bindTarget = address
return self
}

func makeSocketChannel(eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily) throws -> SocketChannel {
return try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily)
}

/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
///
/// - parameters:
Expand All @@ -531,34 +551,51 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
/// - returns: An `EventLoopFuture<Channel>` to deliver the `Channel` when connected.
public func connect(host: String, port: Int) -> EventLoopFuture<Channel> {
let loop = self.group.next()
let connector = HappyEyeballsConnector(resolver: resolver ?? GetaddrinfoResolver(loop: loop, aiSocktype: .stream, aiProtocol: CInt(IPPROTO_TCP)),
let resolver = self.resolver ?? GetaddrinfoResolver(loop: loop,
aiSocktype: .stream,
aiProtocol: CInt(IPPROTO_TCP))
let connector = HappyEyeballsConnector(resolver: resolver,
loop: loop,
host: host,
port: port,
connectTimeout: self.connectTimeout) { eventLoop, protocolFamily in
return self.execute(eventLoop: eventLoop, protocolFamily: protocolFamily) { $0.eventLoop.makeSucceededFuture(()) }
return self.initializeAndRegisterNewChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) {
$0.eventLoop.makeSucceededFuture(())
}
}
return connector.resolveAndConnect()
}

private func connect(freshChannel channel: Channel, address: SocketAddress) -> EventLoopFuture<Void> {
let connectPromise = channel.eventLoop.makePromise(of: Void.self)
channel.connect(to: address, promise: connectPromise)
let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
channel.close(promise: nil)
}

connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
cancelTask.cancel()
}
return connectPromise.futureResult
}

internal func testOnly_connect(injectedChannel: SocketChannel,
to address: SocketAddress) -> EventLoopFuture<Channel> {
return self.initializeAndRegisterChannel(injectedChannel) { channel in
return self.connect(freshChannel: channel, address: address)
}
}

/// Specify the `address` to connect to for the TCP `Channel` that will be established.
///
/// - parameters:
/// - address: The address to connect to.
/// - returns: An `EventLoopFuture<Channel>` to deliver the `Channel` when connected.
public func connect(to address: SocketAddress) -> EventLoopFuture<Channel> {
return execute(eventLoop: group.next(), protocolFamily: address.protocol) { channel in
let connectPromise = channel.eventLoop.makePromise(of: Void.self)
channel.connect(to: address, promise: connectPromise)
let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
channel.close(promise: nil)
}

connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
cancelTask.cancel()
}
return connectPromise.futureResult
return self.initializeAndRegisterNewChannel(eventLoop: self.group.next(),
protocolFamily: address.protocol) { channel in
return self.connect(freshChannel: channel, address: address)
}
}

Expand All @@ -570,9 +607,9 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
public func connect(unixDomainSocketPath: String) -> EventLoopFuture<Channel> {
do {
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
return connect(to: address)
return self.connect(to: address)
} catch {
return group.next().makeFailedFuture(error)
return self.group.next().makeFailedFuture(error)
}
}

Expand Down Expand Up @@ -615,24 +652,35 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
}
}

private func execute(eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily,
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let channelInitializer = self.channelInitializer
let channelOptions = self._channelOptions

private func initializeAndRegisterNewChannel(eventLoop: EventLoop,
protocolFamily: NIOBSDSocket.ProtocolFamily,
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let channel: SocketChannel
do {
channel = try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily)
channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily)
} catch {
return eventLoop.makeFailedFuture(error)
}
return self.initializeAndRegisterChannel(channel, body)
}

private func initializeAndRegisterChannel(_ channel: SocketChannel,
_ body: @escaping (Channel) -> EventLoopFuture<Void>) -> EventLoopFuture<Channel> {
let channelInitializer = self.channelInitializer
let channelOptions = self._channelOptions
let eventLoop = channel.eventLoop

@inline(__always)
func setupChannel() -> EventLoopFuture<Channel> {
eventLoop.assertInEventLoop()
return channelOptions.applyAllChannelOptions(to: channel).flatMap {
channelInitializer(channel)
if let bindTarget = self.bindTarget {
return channel.bind(to: bindTarget).flatMap {
channelInitializer(channel)
}
} else {
return channelInitializer(channel)
}
}.flatMap {
eventLoop.assertInEventLoop()
return channel.registerAndDoSynchronously(body)
Expand All @@ -647,7 +695,9 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol {
if eventLoop.inEventLoop {
return setupChannel()
} else {
return eventLoop.submit(setupChannel).flatMap { $0 }
return eventLoop.flatSubmit {
setupChannel()
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions Tests/NIOTests/BootstrapTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ extension BootstrapTest {
("testDatagramBootstrapRejectsNotWorkingELGsCorrectly", testDatagramBootstrapRejectsNotWorkingELGsCorrectly),
("testNIOPipeBootstrapValidatesWorkingELGsCorrectly", testNIOPipeBootstrapValidatesWorkingELGsCorrectly),
("testNIOPipeBootstrapRejectsNotWorkingELGsCorrectly", testNIOPipeBootstrapRejectsNotWorkingELGsCorrectly),
("testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only", testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only),
]
}
}
Expand Down
82 changes: 82 additions & 0 deletions Tests/NIOTests/BootstrapTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,88 @@ class BootstrapTest: XCTestCase {
XCTAssertNil(NIOPipeBootstrap(validatingGroup: elg))
XCTAssertNil(NIOPipeBootstrap(validatingGroup: el))
}

func testClientBindWorksOnSocketsBoundToEitherIPv4OrIPv6Only() {
for isIPv4 in [true, false] {
guard System.supportsIPv6 || isIPv4 else {
continue // need to skip IPv6 tests if we don't support it.
}
let localIP = isIPv4 ? "127.0.0.1" : "::1"
guard let serverLocalAddressChoice = try? SocketAddress(ipAddress: localIP, port: 0),
let clientLocalAddressWholeInterface = try? SocketAddress(ipAddress: localIP, port: 0),
let server1 = (try? ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
.bind(to: serverLocalAddressChoice)
.wait()),
let server2 = (try? ServerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.serverChannelOption(ChannelOptions.maxMessagesPerRead, value: 1)
.bind(to: serverLocalAddressChoice)
.wait()),
let server1LocalAddress = server1.localAddress,
let server2LocalAddress = server2.localAddress else {
XCTFail("can't boot servers even")
return
}
defer {
XCTAssertNoThrow(try server1.close().wait())
XCTAssertNoThrow(try server2.close().wait())
}

// Try 1: Directly connect to 127.0.0.1, this won't do Happy Eyeballs.
XCTAssertNoThrow(try ClientBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.bind(to: clientLocalAddressWholeInterface)
.connect(to: server1LocalAddress)
.wait()
.close()
.wait())

var maybeChannel1: Channel? = nil
// Try 2: Connect to "localhost", this will do Happy Eyeballs.
XCTAssertNoThrow(maybeChannel1 = try ClientBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.bind(to: clientLocalAddressWholeInterface)
.connect(host: "localhost", port: server1LocalAddress.port!)
.wait())
guard let myChannel1 = maybeChannel1, let myChannel1Address = myChannel1.localAddress else {
XCTFail("can't connect channel 1")
return
}
XCTAssertEqual(localIP, maybeChannel1?.localAddress?.ipAddress)
// Try 3: Bind the client to the same address/port as in try 2 but to server 2.
XCTAssertNoThrow(try ClientBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.connectTimeout(.hours(2))
.bind(to: myChannel1Address)
.connect(to: server2LocalAddress)
.map { channel -> Channel in
XCTAssertEqual(myChannel1Address, channel.localAddress)
return channel
}
.wait()
.close()
.wait())
}
}
}

private final class WriteStringOnChannelActive: ChannelInboundHandler {
typealias InboundIn = Never
typealias OutboundOut = ByteBuffer

let string: String

init(_ string: String) {
self.string = string
}

func channelActive(context: ChannelHandlerContext) {
var buffer = context.channel.allocator.buffer(capacity: self.string.utf8.count)
buffer.writeString(string)
context.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
}
}

private final class MakeSureAutoReadIsOffInChannelInitializer: ChannelInboundHandler {
Expand Down
2 changes: 1 addition & 1 deletion Tests/NIOTests/ChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2012,7 +2012,7 @@ public final class ChannelTests: XCTestCase {
XCTAssertFalse(channel.isWritable)
}

withChannel { channel in
withChannel(skipStream: true) { channel in
checkThatItThrowsInappropriateOperationForState {
XCTAssertEqual(0, channel.localAddress?.port ?? 0xffff)
XCTAssertNil(channel.remoteAddress)
Expand Down
13 changes: 2 additions & 11 deletions Tests/NIOTests/MulticastTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ final class MulticastTest: XCTestCase {

struct ReceivedDatagramError: Error { }

private var supportsIPv6: Bool {
do {
let ipv6Loopback = try SocketAddress(ipAddress: "::1", port: 0)
return try System.enumerateInterfaces().contains(where: { $0.address == ipv6Loopback })
} catch {
return false
}
}

private func interfaceForAddress(address: String) throws -> NIONetworkInterface {
let targetAddress = try SocketAddress(ipAddress: address, port: 0)
guard let interface = try System.enumerateInterfaces().lazy.filter({ $0.address == targetAddress }).first else {
Expand Down Expand Up @@ -220,7 +211,7 @@ final class MulticastTest: XCTestCase {
}

func testCanJoinBasicMulticastGroupIPv6() throws {
guard self.supportsIPv6 else {
guard System.supportsIPv6 else {
// Skip on non-IPv6 systems
return
}
Expand Down Expand Up @@ -317,7 +308,7 @@ final class MulticastTest: XCTestCase {
}

func testCanLeaveAnIPv6MulticastGroup() throws {
guard self.supportsIPv6 else {
guard System.supportsIPv6 else {
// Skip on non-IPv6 systems
return
}
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOTests/SALChannelTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ extension SALChannelTest {
("testWritesFromWritabilityNotificationsDoNotGetLostIfWePreviouslyWroteEverything", testWritesFromWritabilityNotificationsDoNotGetLostIfWePreviouslyWroteEverything),
("testWeSurviveIfIgnoringSIGPIPEFails", testWeSurviveIfIgnoringSIGPIPEFails),
("testBasicRead", testBasicRead),
("testBasicConnectWithClientBootstrap", testBasicConnectWithClientBootstrap),
("testClientBootstrapBindIsDoneAfterSocketOptions", testClientBootstrapBindIsDoneAfterSocketOptions),
]
}
}
Expand Down
Loading

0 comments on commit 120acb1

Please sign in to comment.