Skip to content

Commit

Permalink
Implement Request.signal to detect client disconnects
Browse files Browse the repository at this point in the history
  • Loading branch information
npaun committed Feb 12, 2025
1 parent e5b3760 commit 2d76a37
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/workerd/api/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -600,3 +600,9 @@ wd_test(
args = ["--experimental"],
data = ["tests/fetch-test.js"],
)

# Runs tests/request-client-disconnect.js (the Request.signal client-disconnect test) using the
# request-client-disconnect.wd-test config, with the --experimental flag passed as an argument.
wd_test(
src = "tests/request-client-disconnect.wd-test",
args = ["--experimental"],
data = ["tests/request-client-disconnect.js"],
)
12 changes: 12 additions & 0 deletions src/workerd/api/basics.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,21 @@ class AbortSignal final: public EventTarget {
tracker.trackField("reason", reason);
}

// Returns the current value of the `ignoreForSubrequests` flag.
bool isIgnoredForSubrequests() {
return ignoreForSubrequests;
}

// Turns the `ignoreForSubrequests` flag on. This setter only ever sets the flag to true.
void setIgnoredForSubrequests() {
ignoreForSubrequests = true;
}

private:
IoOwn<RefcountedCanceler> canceler;
Flag flag;

// If set, fetch() drops this AbortSignal from the Request it was given, so that a signal
// inherited from an incoming request does not abort a passed-through subrequest.
bool ignoreForSubrequests = false;

kj::Maybe<jsg::JsRef<jsg::JsValue>> reason;
kj::Maybe<jsg::JsRef<jsg::JsValue>> onAbortHandler;

Expand Down
15 changes: 13 additions & 2 deletions src/workerd/api/global-scope.c++
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ kj::Promise<DeferredProxy<void>> ServiceWorkerGlobalScope::request(kj::HttpMetho
kj::HttpService::Response& response,
kj::Maybe<kj::StringPtr> cfBlobJson,
Worker::Lock& lock,
kj::Maybe<ExportedHandler&> exportedHandler) {
kj::Maybe<ExportedHandler&> exportedHandler,
kj::Maybe<jsg::Ref<AbortSignal>> abortSignal) {
TRACE_EVENT("workerd", "ServiceWorkerGlobalScope::request()");
// To construct a ReadableStream object, we're supposed to pass in an Own<AsyncInputStream>, so
// that it can drop the reference whenever it gets GC'ed. But in this case the stream's lifetime
Expand Down Expand Up @@ -190,7 +191,17 @@ kj::Promise<DeferredProxy<void>> ServiceWorkerGlobalScope::request(kj::HttpMetho

auto jsRequest = jsg::alloc<Request>(method, url, Request::Redirect::MANUAL, kj::mv(jsHeaders),
jsg::alloc<Fetcher>(IoContext::NEXT_CLIENT_CHANNEL, Fetcher::RequiresHostAndProtocol::YES),
kj::none /** AbortSignal **/, kj::mv(cf), kj::mv(body));
/* signal */ kj::none, kj::mv(cf), kj::mv(body),
/* thisSignal */ kj::mv(abortSignal), Request::CacheMode::NONE);

// signal vs thisSignal
// --------------------
// The fetch spec definition of Request has a distinction between the
// "signal" (which is an optional AbortSignal passed in with the options), and "this' signal",
// which is an AbortSignal that is always available via the request.signal accessor.
//
// redirect
// --------
// I set the redirect mode to manual here, so that by default scripts that just pass requests
// through to a fetch() call will behave the same as scripts which don't call .respondWith(): if
// the request results in a redirect, the visitor will see that redirect.
Expand Down
3 changes: 2 additions & 1 deletion src/workerd/api/global-scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ class ServiceWorkerGlobalScope: public WorkerGlobalScope {
kj::HttpService::Response& response,
kj::Maybe<kj::StringPtr> cfBlobJson,
Worker::Lock& lock,
kj::Maybe<ExportedHandler&> exportedHandler);
kj::Maybe<ExportedHandler&> exportedHandler,
kj::Maybe<jsg::Ref<AbortSignal>> abortSignal);
// TODO(cleanup): Factor out the shared code used between old-style event listeners vs. module
// exports and move that code somewhere more appropriate.

Expand Down
31 changes: 25 additions & 6 deletions src/workerd/api/http.c++
Original file line number Diff line number Diff line change
Expand Up @@ -995,7 +995,7 @@ jsg::Ref<Request> Request::constructor(
cacheMode = oldRequest->getCacheMode();
redirect = oldRequest->getRedirectEnum();
fetcher = oldRequest->getFetcher();
signal = oldRequest->getSignal();
signal = oldRequest->getThisSignal();
}
}

Expand Down Expand Up @@ -1093,7 +1093,7 @@ jsg::Ref<Request> Request::constructor(
redirect = otherRequest->redirect;
cacheMode = otherRequest->cacheMode;
fetcher = otherRequest->getFetcher();
signal = otherRequest->getSignal();
signal = otherRequest->getThisSignal();
headers = jsg::alloc<Headers>(*otherRequest->headers);
cf = otherRequest->cf.deepClone(js);
KJ_IF_SOME(b, otherRequest->getBody()) {
Expand All @@ -1112,7 +1112,8 @@ jsg::Ref<Request> Request::constructor(

// TODO(conform): If `init` has a keepalive flag, pass it to the Body constructor.
return jsg::alloc<Request>(method, url, redirect, KJ_ASSERT_NONNULL(kj::mv(headers)),
kj::mv(fetcher), kj::mv(signal), kj::mv(cf), kj::mv(body), cacheMode);
kj::mv(fetcher), kj::mv(signal), kj::mv(cf), kj::mv(body), /* thisSignal */ kj::none,
cacheMode);
}

jsg::Ref<Request> Request::clone(jsg::Lock& js) {
Expand All @@ -1121,8 +1122,13 @@ jsg::Ref<Request> Request::clone(jsg::Lock& js) {
auto cfClone = cf.deepClone(js);
auto bodyClone = Body::clone(js);

return jsg::alloc<Request>(method, url, redirect, kj::mv(headersClone), getFetcher(), getSignal(),
kj::mv(cfClone), kj::mv(bodyClone));
return jsg::alloc<Request>(method, url, redirect, kj::mv(headersClone), getFetcher(),
/* signal */ getThisSignal(), kj::mv(cfClone), kj::mv(bodyClone), /* thisSignal */ kj::none);

// signal
//-------
// The fetch spec states: "Let clonedSignal be the result of creating a dependent abort signal
// from « this’s signal », using AbortSignal and this’s relevant realm."
}

kj::StringPtr Request::getMethod() {
Expand Down Expand Up @@ -1166,7 +1172,7 @@ jsg::Optional<jsg::JsObject> Request::getCf(jsg::Lock& js) {
// that's a bit silly and unnecessary.
// The name "thisSignal" is derived from the fetch spec, which draws a
// distinction between the "signal" and "this' signal".
jsg::Ref<AbortSignal> Request::getThisSignal(jsg::Lock& js) {
jsg::Ref<AbortSignal> Request::getThisSignal() {
KJ_IF_SOME(s, signal) {
return s.addRef();
}
Expand All @@ -1178,6 +1184,14 @@ jsg::Ref<AbortSignal> Request::getThisSignal(jsg::Lock& js) {
return newSignal;
}

void Request::clearSignalIfIgnoredForSubrequest() {
KJ_IF_SOME(s, signal) {
if (s->isIgnoredForSubrequests()) {
signal = kj::none;
}
}
}

kj::Maybe<Request::Redirect> Request::tryParseRedirect(kj::StringPtr redirect) {
if (strcasecmp(redirect.cStr(), "follow") == 0) {
return Redirect::FOLLOW;
Expand Down Expand Up @@ -2206,6 +2220,11 @@ jsg::Promise<jsg::Ref<Response>> fetchImplNoOutputLock(jsg::Lock& js,
// front is robust, and won't add significant overhead compared to the rest of fetch().
auto jsRequest = Request::constructor(js, kj::mv(requestOrUrl), kj::mv(requestInit));

// Clear the request's signal if the 'ignoreForSubrequests' flag is set. This happens when
// a request from an incoming fetch is passed-through to another fetch. We want to avoid
// aborting the subrequest in that case.
jsRequest->clearSignalIfIgnoredForSubrequest();

// This URL list keeps track of redirections and becomes a source for Response's URL list. The
// first URL in the list is the Request's URL (visible to JS via Request::getUrl()). The last URL
// in the list is the Request's "current" URL (eventually visible to JS via Response::getUrl()).
Expand Down
14 changes: 11 additions & 3 deletions src/workerd/api/http.h
Original file line number Diff line number Diff line change
Expand Up @@ -805,12 +805,13 @@ class Request final: public Body {
Request(kj::HttpMethod method, kj::StringPtr url, Redirect redirect,
jsg::Ref<Headers> headers, kj::Maybe<jsg::Ref<Fetcher>> fetcher,
kj::Maybe<jsg::Ref<AbortSignal>> signal, CfProperty&& cf,
kj::Maybe<Body::ExtractedBody> body, CacheMode cacheMode = CacheMode::NONE)
kj::Maybe<Body::ExtractedBody> body, kj::Maybe<jsg::Ref<AbortSignal>> thisSignal,
CacheMode cacheMode = CacheMode::NONE)
: Body(kj::mv(body), *headers), method(method), url(kj::str(url)),
redirect(redirect), headers(kj::mv(headers)), fetcher(kj::mv(fetcher)),
cacheMode(cacheMode), cf(kj::mv(cf)) {
KJ_IF_SOME(s, signal) {
// If the AbortSignal will never abort, assigning it to thisSignal instead ensures
// If the AbortSignal will never abort, assigning it to thisSignal instead ensures
// that the cancel machinery is not used but the request.signal accessor will still
// do the right thing.
if (s->getNeverAborts()) {
Expand All @@ -819,6 +820,10 @@ class Request final: public Body {
this->signal = kj::mv(s);
}
}

KJ_IF_SOME(s, thisSignal) {
this->thisSignal = kj::mv(s);
}
}
// TODO(conform): Technically, the request's URL should be parsed immediately upon Request
// construction, and any errors encountered should be thrown. Instead, we defer parsing until
Expand Down Expand Up @@ -871,7 +876,10 @@ class Request final: public Body {
// used on this request.
kj::Maybe<jsg::Ref<AbortSignal>> getSignal();

jsg::Ref<AbortSignal> getThisSignal(jsg::Lock& js);
jsg::Ref<AbortSignal> getThisSignal();

// Clears this request's signal if that signal reports isIgnoredForSubrequests(); otherwise does
// nothing.
void clearSignalIfIgnoredForSubrequest();


// Returns the `cf` field containing Cloudflare feature flags.
jsg::Optional<jsg::JsObject> getCf(jsg::Lock& js);
Expand Down
188 changes: 188 additions & 0 deletions src/workerd/api/tests/request-client-disconnect.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import { DurableObject, WorkerEntrypoint } from 'cloudflare:workers';
import assert from 'node:assert';
import { scheduler } from 'node:timers/promises';

// Durable Object that keeps one stored value per key, used by the tests to record whether the
// request identified by that key saw an abort.
export class AbortTracker extends DurableObject {
  // Resolves to whatever is currently stored under `key`.
  async getAborted(key) {
    const { storage } = this.ctx;
    return storage.get(key);
  }

  // Stores `value` under `key`; resolves after the storage put() has resolved.
  async setAborted(key, value) {
    const { storage } = this.ctx;
    await storage.put(key, value);
  }
}
// Slow upstream used as the subrequest target: waits 300 ms, then answers 'completed'.
export class OtherServer extends WorkerEntrypoint {
  async fetch() {
    const delayMs = 300;
    await scheduler.wait(delayMs);

    const reply = new Response('completed');
    return reply;
  }
}

/**
 * Test server. Each incoming request is dispatched to the method named by its URL path, after an
 * `abort` handler has been installed on `req.signal` that records the abort in the AbortTracker
 * Durable Object under that same name.
 */
export class Server extends WorkerEntrypoint {
  /**
   * Resets the tracker entry for the requested endpoint to `false`, installs the abort handler,
   * then runs the endpoint method and returns its result.
   */
  async fetch(req) {
    // The path without its leading '/' is both the tracker key and the name of the method to run.
    const key = new URL(req.url).pathname.slice(1);
    const abortTracker = this.env.AbortTracker.get(
      this.env.AbortTracker.idFromName('AbortTracker')
    );
    await abortTracker.setAborted(key, false);

    // The handler is synchronous, so the storage write is handed to waitUntil() instead of being
    // awaited here.
    req.signal.onabort = () => {
      this.ctx.waitUntil(abortTracker.setAborted(key, true));
    };

    return this[key](req);
  }

  /** Endpoint that responds immediately. */
  async valid() {
    return new Response('hello world');
  }

  /** Endpoint that fails with an Error whose message is 'boom'. */
  async error() {
    throw new Error('boom');
  }

  /**
   * Never resolves. scheduler.wait() takes milliseconds, so each iteration sleeps about 86
   * seconds; the endless loop is what makes the hang indefinite.
   */
  async hang() {
    for (;;) {
      await scheduler.wait(86400);
    }
  }

  /** Endpoint that returns a streaming response whose body gets one chunk and then stalls. */
  async hangAfterSendingSomeData() {
    const { readable, writable } = new TransformStream();
    this.ctx.waitUntil(this.sendSomeData(writable));

    return new Response(readable);
  }

  /** Writes 'hello world' to `writable`, then hangs without ever closing the stream. */
  async sendSomeData(writable) {
    const writer = writable.getWriter();
    const enc = new TextEncoder();
    await writer.write(enc.encode('hello world'));
    await this.hang();
  }

  /** Endpoint that starts a subrequest in the background and then hangs. */
  async triggerSubrequest(req) {
    this.ctx.waitUntil(this.callOtherServer(req));
    await this.hang();
  }

  /**
   * Copies the incoming request, forwards the copy to OtherServer, and records `false` under the
   * 'subrequest' key once the response body reads 'completed'. An abort observed on the copy's
   * signal records `true` under the same key.
   */
  async callOtherServer(req) {
    const key = 'subrequest';

    const abortTracker = this.env.AbortTracker.get(
      this.env.AbortTracker.idFromName('AbortTracker')
    );

    const passedThroughReq = new Request(req);
    // Abort events are delivered to the request's AbortSignal, not to the Request itself, so the
    // listener has to be attached to `.signal`. addEventListener() is used rather than assigning
    // `onabort` because the copy may share its signal with `req`, and assigning would then
    // replace the handler installed in fetch().
    passedThroughReq.signal.addEventListener('abort', () => {
      this.ctx.waitUntil(abortTracker.setAborted(key, true));
    });

    const res = await this.env.OtherServer.fetch(passedThroughReq);
    const text = await res.text();

    if (text === 'completed') {
      await abortTracker.setAborted(key, false);
    }
  }
}

export const noAbortOnSimpleResponse = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

const req = env.Server.fetch('http://example.com/valid');

const res = await req;
assert.strictEqual(await res.text(), 'hello world');
assert.strictEqual(await abortTracker.getAborted('valid'), false);
},
};

export const noAbortIfServerThrows = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

const req = env.Server.fetch('http://example.com/error');

await assert.rejects(() => req, { name: 'Error', message: 'boom' });
assert.strictEqual(await abortTracker.getAborted('error'), false);
},
};

export const abortIfClientAbandonsRequest = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

// This endpoint never generates a response, so we can timeout after an arbitrary time.
const req = env.Server.fetch('http://example.com/hang', {
signal: AbortSignal.timeout(500),
});

await assert.rejects(() => req, {
name: 'TimeoutError',
message: 'The operation was aborted due to timeout',
});
assert.strictEqual(await abortTracker.getAborted('hang'), true);
},
};

export const abortIfClientCancelsReadingResponse = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

// This endpoint begins generating a response but then hangs
const req = env.Server.fetch('http://example.com/hangAfterSendingSomeData');
const res = await req;
const reader = res.body.getReader();

const { value, done } = await reader.read();
assert.strictEqual(new TextDecoder().decode(value), 'hello world');
assert.ok(!done);

// Give up reading
await reader.cancel();

// Waste a bit of time so the server cleans up
await scheduler.wait(0);

assert.strictEqual(
await abortTracker.getAborted('hangAfterSendingSomeData'),
true
);
},
};

export const abortedRequestDoesNotAbortSubrequest = {
async test(ctrl, env, ctx) {
let abortTracker = env.AbortTracker.get(
env.AbortTracker.idFromName('AbortTracker')
);

// This endpoint calls another endpoint that eventually completes after wasting 300 ms
// So, we abort the initial request quickly...
const req = env.Server.fetch('http://example.com/triggerSubrequest', {
signal: AbortSignal.timeout(100),
});

await assert.rejects(() => req, {
name: 'TimeoutError',
message: 'The operation was aborted due to timeout',
});
assert.strictEqual(
await abortTracker.getAborted('triggerSubrequest'),
true
);

// Then make sure that the subrequest wasn't also aborted
await scheduler.wait(500);
assert.strictEqual(await abortTracker.getAborted('subrequest'), false);
},
};
Loading

0 comments on commit 2d76a37

Please sign in to comment.