Skip to content

Commit

Permalink
Implement Request.signal to detect client disconnects
Browse files Browse the repository at this point in the history
  • Loading branch information
npaun committed Feb 10, 2025
1 parent e5b3760 commit d43950e
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 9 deletions.
6 changes: 6 additions & 0 deletions src/workerd/api/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,9 @@ wd_test(
args = ["--experimental"],
data = ["tests/fetch-test.js"],
)

wd_test(
src = "tests/request-client-disconnect.wd-test",
args = ["--experimental"],
data = ["tests/request-client-disconnect.js"],
)
15 changes: 13 additions & 2 deletions src/workerd/api/global-scope.c++
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ kj::Promise<DeferredProxy<void>> ServiceWorkerGlobalScope::request(kj::HttpMetho
kj::HttpService::Response& response,
kj::Maybe<kj::StringPtr> cfBlobJson,
Worker::Lock& lock,
kj::Maybe<ExportedHandler&> exportedHandler) {
kj::Maybe<ExportedHandler&> exportedHandler,
kj::Maybe<jsg::Ref<AbortSignal>> abortSignal) {
TRACE_EVENT("workerd", "ServiceWorkerGlobalScope::request()");
// To construct a ReadableStream object, we're supposed to pass in an Own<AsyncInputStream>, so
// that it can drop the reference whenever it gets GC'ed. But in this case the stream's lifetime
Expand Down Expand Up @@ -190,7 +191,17 @@ kj::Promise<DeferredProxy<void>> ServiceWorkerGlobalScope::request(kj::HttpMetho

auto jsRequest = jsg::alloc<Request>(method, url, Request::Redirect::MANUAL, kj::mv(jsHeaders),
jsg::alloc<Fetcher>(IoContext::NEXT_CLIENT_CHANNEL, Fetcher::RequiresHostAndProtocol::YES),
kj::none /** AbortSignal **/, kj::mv(cf), kj::mv(body));
/* signal */ kj::none, kj::mv(cf), kj::mv(body),
/* thisSignal */ kj::mv(abortSignal), Request::CacheMode::NONE);

// signal vs thisSignal
// --------------------
// The fetch spec definition of Request has a distinction between the
// "signal" (which is an optional AbortSignal passed in with the options), and "this' signal",
// which is an AbortSignal that is always available via the request.signal accessor.
//
// redirect
// --------
// I set the redirect mode to manual here, so that by default scripts that just pass requests
// through to a fetch() call will behave the same as scripts which don't call .respondWith(): if
// the request results in a redirect, the visitor will see that redirect.
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/global-scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ class ServiceWorkerGlobalScope: public WorkerGlobalScope {
kj::HttpService::Response& response,
kj::Maybe<kj::StringPtr> cfBlobJson,
Worker::Lock& lock,
kj::Maybe<ExportedHandler&> exportedHandler);
kj::Maybe<ExportedHandler&> exportedHandler,
kj::Maybe<jsg::Ref<AbortSignal>> abortSignal);
// TODO(cleanup): Factor out the shared code used between old-style event listeners vs. module
// exports and move that code somewhere more appropriate.

Expand Down
5 changes: 3 additions & 2 deletions src/workerd/api/http.c++
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,8 @@ jsg::Ref<Request> Request::constructor(

// TODO(conform): If `init` has a keepalive flag, pass it to the Body constructor.
return jsg::alloc<Request>(method, url, redirect, KJ_ASSERT_NONNULL(kj::mv(headers)),
kj::mv(fetcher), kj::mv(signal), kj::mv(cf), kj::mv(body), cacheMode);
kj::mv(fetcher), kj::mv(signal), kj::mv(cf), kj::mv(body), /* thisSignal */ kj::none,
cacheMode);
}

jsg::Ref<Request> Request::clone(jsg::Lock& js) {
Expand All @@ -1122,7 +1123,7 @@ jsg::Ref<Request> Request::clone(jsg::Lock& js) {
auto bodyClone = Body::clone(js);

return jsg::alloc<Request>(method, url, redirect, kj::mv(headersClone), getFetcher(), getSignal(),
kj::mv(cfClone), kj::mv(bodyClone));
kj::mv(cfClone), kj::mv(bodyClone), /* thisSignal */ kj::none);
}

kj::StringPtr Request::getMethod() {
Expand Down
9 changes: 7 additions & 2 deletions src/workerd/api/http.h
Original file line number Diff line number Diff line change
Expand Up @@ -805,12 +805,13 @@ class Request final: public Body {
Request(kj::HttpMethod method, kj::StringPtr url, Redirect redirect,
jsg::Ref<Headers> headers, kj::Maybe<jsg::Ref<Fetcher>> fetcher,
kj::Maybe<jsg::Ref<AbortSignal>> signal, CfProperty&& cf,
kj::Maybe<Body::ExtractedBody> body, CacheMode cacheMode = CacheMode::NONE)
kj::Maybe<Body::ExtractedBody> body, kj::Maybe<jsg::Ref<AbortSignal>> thisSignal,
CacheMode cacheMode = CacheMode::NONE)
: Body(kj::mv(body), *headers), method(method), url(kj::str(url)),
redirect(redirect), headers(kj::mv(headers)), fetcher(kj::mv(fetcher)),
cacheMode(cacheMode), cf(kj::mv(cf)) {
KJ_IF_SOME(s, signal) {
// If the AbortSignal will never abort, assigning it to thisSignal instead ensures
// If the AbortSignal will never abort, assigning it to thisSignal instead ensures
// that the cancel machinery is not used but the request.signal accessor will still
// do the right thing.
if (s->getNeverAborts()) {
Expand All @@ -819,6 +820,10 @@ class Request final: public Body {
this->signal = kj::mv(s);
}
}

KJ_IF_SOME(s, thisSignal) {
this->thisSignal = kj::mv(s);
}
}
// TODO(conform): Technically, the request's URL should be parsed immediately upon Request
// construction, and any errors encountered should be thrown. Instead, we defer parsing until
Expand Down
128 changes: 128 additions & 0 deletions src/workerd/api/tests/request-client-disconnect.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import { DurableObject, WorkerEntrypoint } from 'cloudflare:workers';
import assert from 'node:assert';
import { scheduler } from 'node:timers/promises';

export class AbortTracker extends DurableObject {
async getAborted(key) {
return this.ctx.storage.get(key);
}
async setAborted(key, value) {
await this.ctx.storage.put(key, value);
}
}

export class Server extends WorkerEntrypoint {
async fetch(req) {
const key = new URL(req.url).pathname.slice(1);
let abortTracker = this.env.AbortTracker.get(
this.env.AbortTracker.idFromName('AbortTracker')
);
await abortTracker.setAborted(key, false);
req.signal.onabort = () => {
this.ctx.waitUntil(abortTracker.setAborted(key, true));
};
return this[key]();
}

async valid() {
return new Response('hello world');
}

async error() {
throw new Error('boom');
}

async hang() {
for (;;) {
await scheduler.wait(86400);
}
}

async hangAfterSendingSomeData() {
const { readable, writable } = new TransformStream();
this.ctx.waitUntil(this.sendSomeData(writable));

return new Response(readable);
}

async sendSomeData(writable) {
const writer = writable.getWriter();
const enc = new TextEncoder();
await writer.write(enc.encode('hello world'));
await this.hang();
}
}

export const noAbortOnSimpleResponse = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

const req = env.Server.fetch('http://example.com/valid');

const res = await req;
assert.strictEqual(await res.text(), 'hello world');
assert.strictEqual(await abortTracker.getAborted('valid'), false);
},
};

export const noAbortIfServerThrows = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

const req = env.Server.fetch('http://example.com/error');

await assert.rejects(() => req, { name: 'Error', message: 'boom' });
assert.strictEqual(await abortTracker.getAborted('error'), false);
},
};

export const abortIfClientAbandonsRequest = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

// This endpoint never generates a response, so we can timeout after an arbitrary time.
const req = env.Server.fetch('http://example.com/hang', {
signal: AbortSignal.timeout(500),
});

await assert.rejects(() => req, {
name: 'TimeoutError',
message: 'The operation was aborted due to timeout',
});
assert.strictEqual(await abortTracker.getAborted('hang'), true);
},
};

export const abortIfClientCancelsReadingResponse = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

// This endpoint begins generating a response but then hangs
const req = env.Server.fetch('http://example.com/hangAfterSendingSomeData');
const res = await req;
const reader = res.body.getReader();

const { value, done } = await reader.read();
assert.strictEqual(new TextDecoder().decode(value), 'hello world');
assert.ok(!done);

// Give up reading
await reader.cancel();

// Waste a bit of time so the server cleans up
await scheduler.wait(0);

assert.strictEqual(
await abortTracker.getAborted('hangAfterSendingSomeData'),
true
);
},
};
25 changes: 25 additions & 0 deletions src/workerd/api/tests/request-client-disconnect.wd-test
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Workerd = import "/workerd/workerd.capnp";

const unitTests :Workerd.Config = (
services = [
( name = "request-client-disconnect",
worker = (
modules = [
(name = "worker", esModule = embed "request-client-disconnect.js" )
],
compatibilityDate = "2025-01-01",
compatibilityFlags = ["nodejs_compat", "experimental"],
durableObjectNamespaces = [
(className = "AbortTracker", uniqueKey = "badbeef"),
],
durableObjectStorage = (inMemory = void),
bindings = [
(name = "AbortTracker", durableObjectNamespace = "AbortTracker"),
(name = "Server", service = (name = "request-client-disconnect", entrypoint = "Server")),
(name = "defaultExport", service = "request-client-disconnect"),
]
)
)
]
);

19 changes: 18 additions & 1 deletion src/workerd/io/worker-entrypoint.c++
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "worker-entrypoint.h"

#include <workerd/api/basics.h>
#include <workerd/api/global-scope.h>
#include <workerd/api/util.h>
#include <workerd/io/io-context.h>
Expand Down Expand Up @@ -92,6 +93,7 @@ class WorkerEntrypoint final: public WorkerInterface {
kj::Maybe<kj::Promise<void>> proxyTask;
kj::Maybe<kj::Own<WorkerInterface>> failOpenService;
bool loggedExceptionEarlier = false;
kj::Maybe<jsg::Ref<api::AbortController>> abortController;

void init(kj::Own<const Worker> worker,
kj::Maybe<kj::Own<Worker::Actor>> actor,
Expand Down Expand Up @@ -297,10 +299,12 @@ kj::Promise<void> WorkerEntrypoint::request(kj::HttpMethod method,
TRACE_EVENT_END("workerd", PERFETTO_TRACK_FROM_POINTER(&context));
TRACE_EVENT("workerd", "WorkerEntrypoint::request() run", PERFETTO_FLOW_FROM_POINTER(this));
jsg::AsyncContextFrame::StorageScope traceScope = context.makeAsyncTraceScope(lock);
jsg::Ref<api::AbortSignal> signal =
abortController.emplace(jsg::alloc<api::AbortController>())->getSignal();

return lock.getGlobalScope().request(method, url, headers, requestBody, wrappedResponse,
cfBlobJson, lock,
lock.getExportedHandler(entrypointName, kj::mv(props), context.getActor()));
lock.getExportedHandler(entrypointName, kj::mv(props), context.getActor()), kj::mv(signal));
})
.then([this](api::DeferredProxy<void> deferredProxy) {
TRACE_EVENT("workerd", "WorkerEntrypoint::request() deferred proxy step",
Expand Down Expand Up @@ -332,6 +336,19 @@ kj::Promise<void> WorkerEntrypoint::request(kj::HttpMethod method,
failOpenService = context.getSubrequestChannelNoChecks(
IoContext::NEXT_CLIENT_CHANNEL, false, kj::mv(cfBlobJson));
}

if (proxyTask == kj::none && !loggedExceptionEarlier) {
// When the client disconnects, trigger an abort on request.signal, unless the request has
// already completed normally, or failed with an exception.

// TODO(perf): Don't add a task to trigger the abort unless we know it has at least one
// listener.
KJ_IF_SOME(ctrl, abortController) {
context.addWaitUntil(context.run(
[ctrl = kj::mv(ctrl)](Worker::Lock& lock) mutable { ctrl->abort(lock, kj::none); }));
}
}

auto promise = incomingRequest->drain().attach(kj::mv(incomingRequest));
waitUntilTasks.add(maybeAddGcPassForTest(context, kj::mv(promise)));
}))
Expand Down
2 changes: 1 addition & 1 deletion src/workerd/tests/test-fixture.c++
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ TestFixture::Response TestFixture::runRequest(
runInIoContext([&](const TestFixture::Environment& env) {
auto& globalScope = env.lock.getGlobalScope();
return globalScope.request(method, url, requestHeaders, *requestBody, response, "{}"_kj,
env.lock, env.lock.getExportedHandler(kj::none, {}, kj::none));
env.lock, env.lock.getExportedHandler(kj::none, {}, kj::none), /* abortSignal */ kj::none);
});

return {.statusCode = response.statusCode, .body = response.body->str()};
Expand Down

0 comments on commit d43950e

Please sign in to comment.