diff --git a/packages/preview2-shim/lib/io/worker-socket-udp.js b/packages/preview2-shim/lib/io/worker-socket-udp.js index 706590f89..e71f31cad 100644 --- a/packages/preview2-shim/lib/io/worker-socket-udp.js +++ b/packages/preview2-shim/lib/io/worker-socket-udp.js @@ -2,10 +2,62 @@ * @typedef {import("../../types/interfaces/wasi-sockets-network").IpAddressFamily} IpAddressFamily */ import { createSocket } from "node:dgram"; -import { openedSockets } from "./worker-thread.js"; + +const symbolSocketUdpIpUnspecified = + Symbol.symbolSocketUdpIpUnspecified ?? + Symbol.for("symbolSocketUdpIpUnspecified"); + +/** @type {Map} */ +export const openedSockets = new Map(); + +/** @type {Map>} */ +const queuedReceivedSocketDatagrams = new Map(); let socketCnt = 0; +export function getSocketOrThrow(socketId) { + const socket = openedSockets.get(socketId); + if (!socket) throw "invalid-state"; + return socket; +} + +export function getSocketByPort(port) { + return Array.from(openedSockets.values()).find( + (socket) => socket.address().port === port + ); +} + +export function getBoundSockets(socketId) { + return Array.from(openedSockets.entries()) + .filter(([id, _socket]) => id !== socketId) // exclude source socket + .map(([_id, socket]) => socket.address()); +} + +export function dequeueReceivedSocketDatagram(socketInfo, maxResults) { + const key = `PORT:${socketInfo.port}`; + const dgrams = queuedReceivedSocketDatagrams + .get(key) + .splice(0, Number(maxResults)); + return dgrams; +} +export function enqueueReceivedSocketDatagram(socketInfo, { data, rinfo }) { + const key = `PORT:${socketInfo.port}`; + const chunk = { + data, + rinfo, // sender/remote socket info (source) + socketInfo, // receiver socket info (targeted socket) + }; + + // create new queue if not exists + if (!queuedReceivedSocketDatagrams.has(key)) { + queuedReceivedSocketDatagrams.set(key, []); + } + + // append to queue + const queue = queuedReceivedSocketDatagrams.get(key); + queue.push(chunk); +} + /** * @param {IpAddressFamily} addressFamily * @returns {NodeJS.Socket} @@ -28,3 +80,138 @@ export function createUdpSocket(addressFamily, reuseAddr) { } }); } + +export function socketUdpBind(id, payload) { + const { localAddress, localPort } = payload; + const socket = getSocketOrThrow(id); + + // Note: even if the client has bound to IPV4_UNSPECIFIED/IPV6_UNSPECIFIED (0.0.0.0 // ::), + // rinfo.address is resolved to IPV4_LOOPBACK/IPV6_LOOPBACK. + // We need to cache the original bound IP type and fix rinfo.address when receiving datagrams (see below) + // See https://github.com/WebAssembly/wasi-sockets/issues/86 + socket[symbolSocketUdpIpUnspecified] = { + isUnspecified: + localAddress === "0.0.0.0" || localAddress === "0:0:0:0:0:0:0:0", + localAddress, + }; + + return new Promise((resolve) => { + socket.bind( + { + address: localAddress, + port: localPort, + }, + () => { + openedSockets.set(id, socket); + resolve(0); + } + ); + + socket.on("message", (data, rinfo) => { + const remoteSocket = getSocketByPort(rinfo.port); + let { address, port } = socket.address(); + + if (remoteSocket[symbolSocketUdpIpUnspecified].isUnspecified) { + // cache original bound address + rinfo._address = + remoteSocket[symbolSocketUdpIpUnspecified].localAddress; + } + + const receiverSocket = { + address, + port, + id, + }; + + enqueueReceivedSocketDatagram(receiverSocket, { data, rinfo }); + }); + + // catch all errors + socket.once("error", (err) => { + resolve(err.errno); + }); + }); +} + +export function socketUdpCheckSend(id) { + const socket = getSocketOrThrow(id); + try { + return socket.getSendBufferSize() - socket.getSendQueueSize(); + } catch (err) { + return err.errno; + } +} + +export function socketUdpSend(id, payload) { + let { remoteHost, remotePort, data } = payload; + const socket = getSocketOrThrow(id); + + return new Promise((resolve) => { + const _callback = (err, _byteLength) => { + if (err) return resolve(err.errno); + resolve(0); // success + }; + + // Note: when remoteHost/remotePort is None, we broadcast to all bound sockets + // except the source socket + if (remotePort === undefined || remoteHost === undefined) { + getBoundSockets(id).forEach((adr) => { + socket.send(data, adr.port, adr.address, _callback); + }); + } else { + socket.send(data, remotePort, remoteHost, _callback); + } + + socket.once("error", (err) => { + resolve(err.errno); + }); + }); +} + +export function SocketUdpReceive(id, payload) { + const { maxResults } = payload; + const socket = getSocketOrThrow(id); + const { address, port } = socket.address(); + + // set target socket info + // we use this to filter out datagrams that are were sent to this socket + const targetSocket = { + address, + port, + }; + + const dgrams = dequeueReceivedSocketDatagram(targetSocket, maxResults); + return Promise.resolve(dgrams); +} + +export function socketUdpConnect(id, payload) { + const socket = getSocketOrThrow(id); + const { remoteAddress, remotePort } = payload; + return new Promise((resolve) => { + socket.connect(remotePort, remoteAddress, () => { + openedSockets.set(id, socket); + resolve(0); + }); + socket.once("error", (err) => { + resolve(err.errno); + }); + }); +} + +export function socketUdpDisconnect(id) { + const socket = getSocketOrThrow(id); + return new Promise((resolve) => { + socket.disconnect(); + resolve(0); + }); +} + +export function socketUdpDispose(id) { + const socket = getSocketOrThrow(id); + return new Promise((resolve) => { + socket.close(() => { + openedSockets.delete(id); + resolve(0); + }); + }); +} diff --git a/packages/preview2-shim/lib/io/worker-thread.js b/packages/preview2-shim/lib/io/worker-thread.js index 24cb3eca5..a0c52e522 100644 --- a/packages/preview2-shim/lib/io/worker-thread.js +++ b/packages/preview2-shim/lib/io/worker-thread.js @@ -73,11 +73,7 @@ import { HTTP_SERVER_SET_OUTGOING_RESPONSE, HTTP_SERVER_CLEAR_OUTGOING_RESPONSE, } from "./calls.js"; -import { createUdpSocket } from "./worker-socket-udp.js"; - -const symbolSocketUdpIpUnspecified = - Symbol.symbolSocketUdpIpUnspecified ?? - Symbol.for("symbolSocketUdpIpUnspecified"); +import { SocketUdpReceive, createUdpSocket, getSocketOrThrow, socketUdpBind, socketUdpCheckSend, socketUdpConnect, socketUdpDisconnect, socketUdpDispose, socketUdpSend } from "./worker-socket-udp.js"; let streamCnt = 0, pollCnt = 0; @@ -88,12 +84,6 @@ export const unfinishedPolls = new Map(); /** @type {Map | null, stream: NodeJS.ReadableStream | NodeJS.WritableStream }>} */ export const unfinishedStreams = new Map(); -/** @type {Map} */ -export const openedSockets = new Map(); - -/** @type {Map>} */ -const queuedReceivedSocketDatagrams = new Map(); - /** @type {Map} */ export const unfinishedFutures = new Map(); @@ -142,48 +132,6 @@ export function getStreamOrThrow(streamId) { return stream; } -export function getSocketOrThrow(socketId) { - const socket = openedSockets.get(socketId); - if (!socket) throw "invalid-state"; - return socket; -} - -export function getSocketByPort(port) { - return Array.from(openedSockets.values()).find( - (socket) => socket.address().port === port - ); -} - -export function getBoundSockets(socketId) { - return Array.from(openedSockets.entries()) - .filter(([id, _socket]) => id !== socketId) // exclude source socket - .map(([_id, socket]) => socket.address()); -} - -export function dequeueReceivedSocketDatagram(socketInfo, maxResults) { - const dgrams = queuedReceivedSocketDatagrams - .get(`PORT:${socketInfo.port}`) - .splice(0, Number(maxResults)); - return dgrams; -} -export function enqueueReceivedSocketDatagram(socketInfo, { data, rinfo }) { - const key = `PORT:${socketInfo.port}`; - const chunk = { - data, - rinfo, // sender/remote socket info (source) - socketInfo, // receiver socket info (targeted socket) - }; - - // create new queue if not exists - if (!queuedReceivedSocketDatagrams.has(key)) { - queuedReceivedSocketDatagrams.set(key, []); - } - - // append to queue - const queue = queuedReceivedSocketDatagrams.get(key); - queue.push(chunk); -} - function subscribeInstant(instant) { const duration = instant - hrtime.bigint(); if (duration <= 0) return Promise.resolve(); @@ -306,130 +254,23 @@ function handle(call, id, payload) { return createFuture(createUdpSocket(addressFamily, reuseAddr)); } - case SOCKET_UDP_BIND: { - const { localAddress, localPort } = payload; - const socket = getSocketOrThrow(id); + case SOCKET_UDP_BIND: + return socketUdpBind(id, payload); - // Note: even if the client has bound to IPV4_UNSPECIFIED/IPV6_UNSPECIFIED (0.0.0.0 // ::), - // rinfo.address is resolved to IPV4_LOOPBACK/IPV6_LOOPBACK. - // We need to cache the original bound IP type and fix rinfo.address when receiving datagrams (see below) - // See https://github.com/WebAssembly/wasi-sockets/issues/86 - socket[symbolSocketUdpIpUnspecified] = { - isUnspecified: - localAddress === "0.0.0.0" || localAddress === "0:0:0:0:0:0:0:0", - localAddress, - }; - - return new Promise((resolve) => { - socket.bind( - { - address: localAddress, - port: localPort, - }, - () => { - openedSockets.set(id, socket); - resolve(0); - } - ); + case SOCKET_UDP_CHECK_SEND: + return socketUdpCheckSend(id); - socket.on("message", (data, rinfo) => { - const remoteSocket = getSocketByPort(rinfo.port); - let { address, port } = socket.address(); + case SOCKET_UDP_SEND: + return socketUdpSend(id, payload); - if (remoteSocket[symbolSocketUdpIpUnspecified].isUnspecified) { - // cache original bound address - rinfo._address = - remoteSocket[symbolSocketUdpIpUnspecified].localAddress; - } - - const receiverSocket = { - address, - port, - id, - }; + case SOCKET_UDP_RECEIVE: + return SocketUdpReceive(id, payload); - enqueueReceivedSocketDatagram(receiverSocket, { data, rinfo }); - }); + case SOCKET_UDP_CONNECT: + return socketUdpConnect(id, payload); - // catch all errors - socket.once("error", (err) => { - resolve(err.errno); - }); - }); - } - - case SOCKET_UDP_CHECK_SEND: { - const socket = getSocketOrThrow(id); - try { - return socket.getSendBufferSize() - socket.getSendQueueSize(); - } catch (err) { - return err.errno; - } - } - - case SOCKET_UDP_SEND: { - let { remoteHost, remotePort, data } = payload; - const socket = getSocketOrThrow(id); - - return new Promise((resolve) => { - const _callback = (err, _byteLength) => { - if (err) return resolve(err.errno); - resolve(0); // success - }; - - // Note: when remoteHost/remotePort is None, we broadcast to all bound sockets - // except the source socket - if (remotePort === undefined || remoteHost === undefined) { - getBoundSockets(id).forEach((adr) => { - socket.send(data, adr.port, adr.address, _callback); - }); - } else { - socket.send(data, remotePort, remoteHost, _callback); - } - - socket.once("error", (err) => { - resolve(err.errno); - }); - }); - } - - case SOCKET_UDP_RECEIVE: { - const { maxResults } = payload; - const socket = getSocketOrThrow(id); - const { address, port } = socket.address(); - - // set target socket info - // we use this to filter out datagrams that are were sent to this socket - const targetSocket = { - address, - port, - }; - - const dgrams = dequeueReceivedSocketDatagram(targetSocket, maxResults); - return Promise.resolve(dgrams); - } - - case SOCKET_UDP_CONNECT: { - const socket = getSocketOrThrow(id); - const { remoteAddress, remotePort } = payload; - return new Promise((resolve) => { - socket.connect(remotePort, remoteAddress, () => { - openedSockets.set(id, socket); - resolve(0); - }); - socket.once("error", (err) => { - resolve(err.errno); - }); - }); - } - - case SOCKET_UDP_DISCONNECT: { - const socket = getSocketOrThrow(id); - return new Promise((resolve) => { - socket.disconnect(); - resolve(0); - }); - } + case SOCKET_UDP_DISCONNECT: + return socketUdpDisconnect(id); case SOCKET_UDP_GET_LOCAL_ADDRESS: { const socket = getSocketOrThrow(id); @@ -465,7 +306,6 @@ function handle(call, id, payload) { try { return socket.getRecvBufferSize(); } catch (err) { - // we are only interested in the errno return err.errno; } } @@ -473,7 +313,7 @@ function handle(call, id, payload) { case SOCKET_UDP_SET_RECEIVE_BUFFER_SIZE: { const socket = getSocketOrThrow(id); try { - return socket.setRecvBufferSize(payload.value); + return socket.setRecvBufferSize(65537); } catch (err) { return err.errno; } @@ -515,15 +355,8 @@ function handle(call, id, payload) { break; } - case SOCKET_UDP_DISPOSE: { - const socket = getSocketOrThrow(id); - return new Promise((resolve) => { - socket.close(() => { - openedSockets.delete(id); - resolve(0); - }); - }); - } + case SOCKET_UDP_DISPOSE: + return socketUdpDispose(id); // Stdio case OUTPUT_STREAM_BLOCKING_FLUSH | STDOUT: diff --git a/packages/preview2-shim/lib/nodejs/sockets/udp-socket-impl.js b/packages/preview2-shim/lib/nodejs/sockets/udp-socket-impl.js index 157531746..35cb61f5b 100644 --- a/packages/preview2-shim/lib/nodejs/sockets/udp-socket-impl.js +++ b/packages/preview2-shim/lib/nodejs/sockets/udp-socket-impl.js @@ -26,11 +26,19 @@ import { SOCKET_UDP_SET_UNICAST_HOP_LIMIT, } from "../../io/calls.js"; import { ioCall, pollableCreate } from "../../io/worker-io.js"; -import { deserializeIpAddress, isIPv4MappedAddress, isWildcardAddress, serializeIpAddress } from "./socket-common.js"; +import { + deserializeIpAddress, + findUnsuedLocalAddress, + isIPv4MappedAddress, + isWildcardAddress, + serializeIpAddress, +} from "./socket-common.js"; const symbolDispose = Symbol.dispose || Symbol.for("dispose"); -const symbolSocketState = Symbol.SocketInternalState || Symbol.for("SocketInternalState"); -const symbolOperations = Symbol.SocketOperationsState || Symbol.for("SocketOperationsState"); +const symbolSocketState = + Symbol.SocketInternalState || Symbol.for("SocketInternalState"); +const symbolOperations = + Symbol.SocketOperationsState || Symbol.for("SocketOperationsState"); // TODO: move to a common const SocketConnectionState = { @@ -261,6 +269,17 @@ export class UdpSocket { return socket; } + #autoBind(network, ipFamily) { + const localAddress = findUnsuedLocalAddress(ipFamily); + this.#socketOptions.localAddress = serializeIpAddress( + localAddress, + this.#socketOptions.family + ); + this.#socketOptions.localPort = localAddress.val.port; + this.startBind(network, localAddress); + this.finishBind(); + } + #cacheBoundAddress() { let { localIpSocketAddress: boundAddress, localPort } = this.#socketOptions; // when port is 0, the OS will assign an ephemeral port @@ -280,20 +299,29 @@ export class UdpSocket { * @throws {invalid-state} The socket is already bound. (EINVAL) */ startBind(network, localAddress) { - if (!this.allowed()) - throw 'access-denied'; + if (!this.allowed()) throw "access-denied"; try { - assert(this[symbolSocketState].isBound, "invalid-state", "The socket is already bound"); + assert( + this[symbolSocketState].isBound, + "invalid-state", + "The socket is already bound" + ); const address = serializeIpAddress(localAddress); const ipFamily = `ipv${isIP(address)}`; assert( - this.#socketOptions.family.toLocaleLowerCase() !== ipFamily.toLocaleLowerCase(), + this.#socketOptions.family.toLocaleLowerCase() !== + ipFamily.toLocaleLowerCase(), "invalid-argument", "The `local-address` has the wrong address family" ); + assert( + isIPv4MappedAddress(localAddress) && this.ipv6Only(), + "invalid-argument" + ); + const { port } = localAddress.val; this.#socketOptions.localIpSocketAddress = localAddress; this.#socketOptions.localAddress = address; @@ -320,9 +348,15 @@ export class UdpSocket { try { assert(this[symbolOperations].bind === 0, "not-in-progress"); - const { localAddress, localIpSocketAddress, localPort } = this.#socketOptions; + const { localAddress, localIpSocketAddress, localPort } = + this.#socketOptions; assert(isIP(localAddress) === 0, "address-not-bindable"); - assert(globalBoundAddresses.has(serializeIpAddress(localIpSocketAddress, true)), "address-in-use"); + assert( + globalBoundAddresses.has( + serializeIpAddress(localIpSocketAddress, true) + ), + "address-in-use" + ); const err = ioCall(SOCKET_UDP_BIND, this.id, { localAddress, @@ -376,7 +410,11 @@ export class UdpSocket { #startConnect(network, remoteAddress = undefined) { this[symbolOperations].connect++; - if (remoteAddress === undefined || this[symbolSocketState].connectionState === SocketConnectionState.Connected) { + if ( + remoteAddress === undefined || + this[symbolSocketState].connectionState === + SocketConnectionState.Connected + ) { this.#socketOptions.remoteAddress = undefined; this.#socketOptions.remotePort = 0; return; @@ -387,14 +425,25 @@ export class UdpSocket { "invalid-argument", "The IP address in `remote-address` is set to INADDR_ANY (`0.0.0.0` / `::`)" ); - assert(isIPv4MappedAddress(remoteAddress) && this.ipv6Only(), "invalid-argument"); - assert(remoteAddress.val.port === 0, "invalid-argument", "The port in `remote-address` is set to 0"); + assert( + isIPv4MappedAddress(remoteAddress) && this.ipv6Only(), + "invalid-argument" + ); + assert( + remoteAddress.val.port === 0, + "invalid-argument", + "The port in `remote-address` is set to 0" + ); const host = serializeIpAddress(remoteAddress); const ipFamily = `ipv${isIP(host)}`; assert(ipFamily.toLocaleLowerCase() === "ipv0", "invalid-argument"); - assert(this.#socketOptions.family.toLocaleLowerCase() !== ipFamily.toLocaleLowerCase(), "invalid-argument"); + assert( + this.#socketOptions.family.toLocaleLowerCase() !== + ipFamily.toLocaleLowerCase(), + "invalid-argument" + ); const { port } = remoteAddress.val; this.#socketOptions.remoteAddress = host; // can be undefined @@ -414,7 +463,11 @@ export class UdpSocket { const { remoteAddress, remotePort } = this.#socketOptions; this[symbolSocketState].connectionState = SocketConnectionState.Connecting; - if (remoteAddress === undefined || this[symbolSocketState].connectionState === SocketConnectionState.Connected) { + if ( + remoteAddress === undefined || + this[symbolSocketState].connectionState === + SocketConnectionState.Connected + ) { return; } @@ -471,15 +524,24 @@ export class UdpSocket { * @throws {connection-refused} The connection was refused. (ECONNREFUSED) */ stream(remoteAddress = undefined) { + assert(this[symbolSocketState].lastErrorState !== null, "invalid-state"); + // Note: to comply with test programs, we cannot throw if the socket is not bound (as required by the spec - see udp.wit) // assert(this[symbolSocketState].isBound === false, "invalid-state"); - if (this[symbolSocketState].connectionState === SocketConnectionState.Connected) { + if ( + this[symbolSocketState].connectionState === + SocketConnectionState.Connected + ) { // stream() can be called multiple times, so we need to disconnect first if we are already connected // Note: disconnect() will also reset the connection state but does not close the socket handle! this.#disconnect(); } + if (remoteAddress) { + this.#connect(this.network, remoteAddress); + } + // reconfigure remote host and port. // Note: remoteAddress can be undefined const host = serializeIpAddress(remoteAddress); @@ -487,9 +549,10 @@ export class UdpSocket { this.#socketOptions.remoteAddress = host; // host can be undefined this.#socketOptions.remotePort = port; - // reconnect to the remote host - // this.#connect(this.network, remoteAddress); - return [incomingDatagramStreamCreate(this.id), outgoingDatagramStreamCreate(this.id)]; + return [ + incomingDatagramStreamCreate(this.id), + outgoingDatagramStreamCreate(this.id), + ]; } /** @@ -523,14 +586,19 @@ export class UdpSocket { */ remoteAddress() { assert( - this[symbolSocketState].connectionState !== SocketConnectionState.Connected, + this[symbolSocketState].connectionState !== + SocketConnectionState.Connected, "invalid-state", "The socket is not streaming to a specific remote address" ); const out = ioCall(SOCKET_UDP_GET_REMOTE_ADDRESS, this.id); - assert(out.address === undefined, "invalid-state", "The socket is not streaming to a specific remote address"); + assert( + out.address === undefined, + "invalid-state", + "The socket is not streaming to a specific remote address" + ); const { address, port, family } = out; this.#socketOptions.remoteAddress = address; @@ -560,7 +628,11 @@ export class UdpSocket { * @throws {not-supported} (get/set) `this` socket is an IPv4 socket. */ ipv6Only() { - assert(this.#socketOptions.family.toLocaleLowerCase() === "ipv4", "not-supported", "Socket is an IPv4 socket."); + assert( + this.#socketOptions.family.toLocaleLowerCase() === "ipv4", + "not-supported", + "Socket is an IPv4 socket." + ); return this[symbolSocketState].ipv6Only; } @@ -575,11 +647,16 @@ export class UdpSocket { */ setIpv6Only(value) { assert( - value === true && this.#socketOptions.family.toLocaleLowerCase() === "ipv4", + value === true && + this.#socketOptions.family.toLocaleLowerCase() === "ipv4", "not-supported", "Socket is an IPv4 socket." ); - assert(this[symbolSocketState].isBound, "invalid-state", "The socket is already bound"); + assert( + this[symbolSocketState].isBound, + "invalid-state", + "The socket is already bound" + ); this[symbolSocketState].ipv6Only = value; } @@ -610,16 +687,19 @@ export class UdpSocket { * @returns {bigint} */ receiveBufferSize() { + // `receiveBufferSize()` would throws EBADF if called on an unbound socket. // TODO: should we throw if the socket is not bound? // assert(this[symbolSocketState].isBound === false, "invalid-state"); + // or we can auto-bind the socket if it's not bound? + if (this[symbolSocketState].isBound === false) { + this.#autoBind(this.network, this.#socketOptions.family); + } + + // Note: on WSL, this may report a different value than the one set! const ret = ioCall(SOCKET_UDP_GET_RECEIVE_BUFFER_SIZE, this.id); - // if (ret === -9) { - // // TODO: handle the case where bad file descriptor (EBADF) is returned - // // This happens when the socket is not bound - // return this[symbolSocketState].receiveBufferSize; - // } + assert(ret === -9, "invalid-state"); // EBADF return ret; } @@ -633,7 +713,7 @@ export class UdpSocket { setReceiveBufferSize(value) { assert(value === 0n, "invalid-argument", "The provided value was 0"); - // value = cappedUint32(value); + value = Number(value); ioCall(SOCKET_UDP_SET_RECEIVE_BUFFER_SIZE, this.id, { value }); } @@ -646,7 +726,6 @@ export class UdpSocket { // assert(this[symbolSocketState].isBound === false, "invalid-state"); const ret = ioCall(SOCKET_UDP_GET_SEND_BUFFER_SIZE, this.id); - // if (ret === -9) { // // TODO: handle the case where bad file descriptor (EBADF) is returned // // This happens when the socket is not bound