From 00e95cd8489498c978647c52329280ad6dc236c6 Mon Sep 17 00:00:00 2001 From: Jonathan Haines Date: Mon, 17 Feb 2025 11:30:48 +1100 Subject: [PATCH] fix(proxy): include original request signal in proxy request --- src/helper/proxy/index.test.ts | 26 ++++++++++++++++++++++++++ src/helper/proxy/index.ts | 1 + 2 files changed, 27 insertions(+) diff --git a/src/helper/proxy/index.test.ts b/src/helper/proxy/index.test.ts index 9c291ff43..9a092c0df 100644 --- a/src/helper/proxy/index.test.ts +++ b/src/helper/proxy/index.test.ts @@ -7,6 +7,18 @@ describe('Proxy Middleware', () => { global.fetch = vi.fn().mockImplementation(async (req) => { if (req.url === 'https://example.com/ok') { return Promise.resolve(new Response('ok')) + } else if (req.url === 'https://example.com/disconnect') { + const reader = req.body.getReader() + let response + + req.signal.addEventListener('abort', () => { + response = req.signal.reason + reader.cancel() + }) + + await reader.read() + + return Promise.resolve(new Response(response)) } else if (req.url === 'https://example.com/compressed') { return Promise.resolve( new Response('ok', { @@ -200,5 +212,19 @@ describe('Proxy Middleware', () => { const req = (global.fetch as ReturnType).mock.calls[0][0] expect(req.headers.get('Authorization')).toBeNull() }) + + it('client disconnect', async () => { + const app = new Hono() + const controller = new AbortController() + app.post('/proxy/:path', (c) => proxy(`https://example.com/${c.req.param('path')}`, c.req)) + const resPromise = app.request('/proxy/disconnect', { + method: 'POST', + body: 'test', + signal: controller.signal, + }) + controller.abort('client disconnect') + const res = await resPromise + expect(await res.text()).toBe('client disconnect') + }) }) }) diff --git a/src/helper/proxy/index.ts b/src/helper/proxy/index.ts index 96d19f8c2..ead7f8198 100644 --- a/src/helper/proxy/index.ts +++ b/src/helper/proxy/index.ts @@ -46,6 +46,7 @@ const buildRequestInitFromRequest = ( body: request.body, duplex: request.body ? 'half' : undefined, headers, + signal: request.signal, } }