diff --git a/src/main.rs b/src/main.rs index d8d826d..96d0786 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,28 @@ use actix_web::{ - get, http, middleware::Logger, post, web, App, HttpResponse, HttpServer, Responder, + http, middleware::Logger, post, rt::spawn, web, App, HttpResponse, HttpServer, Responder, ResponseError, Result, }; use func_gg::{ runtime::Sandbox, - streams::{ReceiverStream, SenderStream}, + streams::{InputStream, OutputStream}, }; use futures::StreamExt; -use log::{info, warn}; -use tokio::sync::mpsc::channel; +use log::error; #[derive(thiserror::Error, Debug)] pub enum Error { #[error("runtime: {0}")] Runtime(#[from] func_gg::runtime::Error), + #[error("payload: {0}")] + Payload(#[from] actix_web::error::PayloadError), + #[error("send: {0}")] + Send(String), +} + +impl From> for Error { + fn from(err: tokio::sync::mpsc::error::SendError) -> Self { + Self::Send(err.to_string()) + } } impl ResponseError for Error { @@ -22,60 +31,41 @@ impl ResponseError for Error { } } +// tokio_util::sync::CancellationToken +// https://tokio.rs/tokio/topics/shutdown #[post("/")] // note: default payload limit is 256kB from actix-web, but is configurable with PayloadConfig async fn handle(mut body: web::Payload) -> Result { let binary = include_bytes!("/Users/robherley/dev/webfunc-handler/dist/main.wasm"); let mut sandbox = Sandbox::new(binary.to_vec())?; - let (stdin, req_tx) = ReceiverStream::new(); + let (stdin, input_tx) = InputStream::new(); - actix_web::rt::spawn(async move { + // collect input from request body + spawn(async move { while let Some(item) = body.next().await { - match item { - Ok(chunk) => { - if let Err(e) = req_tx.send(chunk).await { - warn!("unable to send chunk: {:?}", e); - break; - } - } - Err(e) => { - warn!("payload error: {:?}", e); - break; - } - } + input_tx.send(item?).await?; } + Ok::<(), Error>(()) }); - let (stdout, mut res_rx) = SenderStream::new(); - - let (body_tx, body_rx) = channel::>(1); - let stream = tokio_stream::wrappers::ReceiverStream::new(body_rx); - - actix_web::rt::spawn(async move { - while let Some(item) = res_rx.recv().await { - let chunk = actix_web::web::Bytes::from(item); - info!("sending chunk: {:?}", chunk); - if let Err(e) = body_tx.send(Ok(chunk)).await { - warn!("unable to send chunk: {:?}", e); - break; - } - } - }); + let (stdout, output_rx, mut first_write_rx) = OutputStream::new(); - actix_web::rt::spawn(async move { - if let Err(e) = sandbox.handler(stdin, stdout).await { - warn!("handler error: {:?}", e); - } + // invoke the function + spawn(async move { + sandbox.call(stdin, stdout).await?; + Ok::<(), Error>(()) }); - // TODO(robherley): join handles? and proper error handling - - Ok(HttpResponse::Ok().streaming(stream)) -} + let content_type = match first_write_rx.recv().await { + Some(b'{') => "application/json", + Some(b'<') => "text/html", + _ => "text/plain", + }; -#[get("/")] -async fn hello() -> impl Responder { - "Hello world!" + Ok(HttpResponse::Ok().content_type(content_type).streaming( + tokio_stream::wrappers::ReceiverStream::new(output_rx) + .map(|item| Ok::<_, Error>(actix_web::web::Bytes::from(item))), + )) } #[actix_web::main] @@ -88,13 +78,8 @@ async fn main() -> std::io::Result<()> { std::env::var("PORT").unwrap_or("8080".into()), ); - HttpServer::new(|| { - App::new() - .wrap(Logger::new("%r %s %Dms")) - .service(hello) - .service(handle) - }) - .bind(addr)? - .run() - .await + HttpServer::new(|| App::new().wrap(Logger::new("%r %s %Dms")).service(handle)) + .bind(addr)? + .run() + .await } diff --git a/src/runtime/sandbox.rs b/src/runtime/sandbox.rs index e7bd986..f039753 100644 --- a/src/runtime/sandbox.rs +++ b/src/runtime/sandbox.rs @@ -1,9 +1,13 @@ +use std::sync::Arc; + use log::info; +use tokio::spawn; +use tokio::sync::Mutex; use wasmtime::*; use wasmtime_wasi::preview1::{self, WasiP1Ctx}; use wasmtime_wasi::{AsyncStdinStream, WasiCtxBuilder}; -use crate::streams::{ReceiverStream, SenderStream}; +use crate::streams::{InputStream, OutputStream}; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -37,6 +41,7 @@ impl Sandbox { let mut config = wasmtime::Config::default(); config.debug_info(true); config.async_support(true); + config.epoch_interruption(true); let engine = Engine::new(&config)?; @@ -62,11 +67,7 @@ impl Sandbox { }) } - pub async fn handler( - &mut self, - stdin: ReceiverStream, - stdout: SenderStream, - ) -> Result<(), Error> { + pub async fn call(&mut self, stdin: InputStream, stdout: OutputStream) -> Result<(), Error> { let wasi_ctx = WasiCtxBuilder::new() .env("FUNC_GG", "1") .stdin(AsyncStdinStream::from(stdin)) @@ -74,7 +75,9 @@ impl Sandbox { .inherit_stderr() // TODO(robherley): pipe stderr to a log stream .build_p1(); + // NOTE: if store changes, we need to recompile the module let mut store = Store::new(&self.engine, wasi_ctx); + store.set_epoch_deadline(1); let func = self .linker @@ -83,7 +86,17 @@ impl Sandbox { .get_default(&mut store, "")? .typed::<(), ()>(&store)?; - let result = func.call_async(&mut store, ()).await.or_else(|err| { + let engine = Arc::new(Mutex::new(self.engine.clone())); + spawn({ + let engine = Arc::clone(&engine); + async move { + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + info!("cancelling request"); + engine.lock().await.increment_epoch(); + } + }); + + func.call_async(&mut store, ()).await.or_else(|err| { match err.downcast_ref::() { Some(e) => { if e.0 != 0 { @@ -94,8 +107,6 @@ impl Sandbox { } _ => Err(err.into()), } - })?; - - Ok(result) + }) } } diff --git a/src/streams/receiver.rs b/src/streams/input.rs similarity index 89% rename from src/streams/receiver.rs rename to src/streams/input.rs index 9a69f37..ef771df 100644 --- a/src/streams/receiver.rs +++ b/src/streams/input.rs @@ -6,19 +6,19 @@ use tokio::{ }; use wasmtime_wasi::{pipe::AsyncReadStream, AsyncStdinStream}; -pub struct ReceiverStream { +pub struct InputStream { rx: Receiver, xtra: Option>, } -impl ReceiverStream { +impl InputStream { pub fn new() -> (Self, Sender) { let (tx, rx) = channel::(1); (Self { rx, xtra: None }, tx) } } -impl AsyncRead for ReceiverStream { +impl AsyncRead for InputStream { fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -48,8 +48,8 @@ impl AsyncRead for ReceiverStream { } } -impl From for AsyncStdinStream { - fn from(stream: ReceiverStream) -> Self { +impl From for AsyncStdinStream { + fn from(stream: InputStream) -> Self { let rs = AsyncReadStream::new(stream); AsyncStdinStream::new(rs) } @@ -61,7 +61,7 @@ mod tests { #[tokio::test] async fn test_read() { - let (mut stream, tx) = ReceiverStream::new(); + let (mut stream, tx) = InputStream::new(); tx.send(Bytes::from("hello world")).await.unwrap(); @@ -77,7 +77,7 @@ mod tests { #[tokio::test] async fn test_into_leftover() { - let (mut stream, tx) = ReceiverStream::new(); + let (mut stream, tx) = InputStream::new(); tx.send(Bytes::from("hello world")).await.unwrap(); @@ -94,7 +94,7 @@ mod tests { #[tokio::test] async fn test_from_leftover() { - let (mut stream, tx) = ReceiverStream::new(); + let (mut stream, tx) = InputStream::new(); stream.xtra = Some(Bytes::from("hello ").to_vec()); tx.send(Bytes::from("world")).await.unwrap(); diff --git a/src/streams/mod.rs b/src/streams/mod.rs index 217c18c..31e653a 100644 --- a/src/streams/mod.rs +++ b/src/streams/mod.rs @@ -1,5 +1,5 @@ -mod receiver; -mod sender; +mod input; +pub use input::InputStream; -pub use receiver::ReceiverStream; -pub use sender::SenderStream; +mod output; +pub use output::OutputStream; diff --git a/src/streams/sender.rs b/src/streams/output.rs similarity index 60% rename from src/streams/sender.rs rename to src/streams/output.rs index 8d1711d..0e8f966 100644 --- a/src/streams/sender.rs +++ b/src/streams/output.rs @@ -5,25 +5,44 @@ use wasmtime_wasi::{ }; #[derive(Clone)] -pub struct SenderStream { +pub struct OutputStream { tx: Sender, + first_tx: Option>, } -impl SenderStream { - pub fn new() -> (Self, Receiver) { +impl OutputStream { + pub fn new() -> (Self, Receiver, Receiver) { let (tx, rx) = channel::(1); - (Self { tx }, rx) + let (first_tx, first_rx) = channel::(1); + ( + Self { + tx, + first_tx: Some(first_tx), + }, + rx, + first_rx, + ) } } #[async_trait] -impl Subscribe for SenderStream { +impl Subscribe for OutputStream { async fn ready(&mut self) {} } #[async_trait] -impl HostOutputStream for SenderStream { +impl HostOutputStream for OutputStream { fn write(&mut self, buf: Bytes) -> StreamResult<()> { + if buf.is_empty() { + return Ok(()); + } + + if let Some(first_tx) = self.first_tx.take() { + if let Err(err) = first_tx.try_send(buf[0]) { + return Err(StreamError::LastOperationFailed(err.into())); + } + } + match self.tx.try_send(Bytes::from(buf)) { Ok(()) => Ok(()), Err(err) => Err(StreamError::LastOperationFailed(err.into())), @@ -39,7 +58,7 @@ impl HostOutputStream for SenderStream { } } -impl StdoutStream for SenderStream { +impl StdoutStream for OutputStream { fn stream(&self) -> Box { Box::new(self.clone()) } @@ -55,7 +74,7 @@ mod tests { #[tokio::test] async fn test_write() { - let (mut stream, mut rx) = SenderStream::new(); + let (mut stream, mut rx, _) = OutputStream::new(); let data = Bytes::from("hello"); stream.write(data.clone()).unwrap(); @@ -66,19 +85,19 @@ mod tests { #[tokio::test] async fn test_flush() { - let (mut stream, _rx) = SenderStream::new(); + let (mut stream, _, _) = OutputStream::new(); assert!(stream.flush().is_ok()); } #[tokio::test] async fn test_check_write() { - let (mut stream, _rx) = SenderStream::new(); + let (mut stream, _, _) = OutputStream::new(); assert_eq!(stream.check_write().unwrap(), usize::MAX); } #[tokio::test] async fn test_isatty() { - let (stream, _rx) = SenderStream::new(); + let (stream, _, _) = OutputStream::new(); assert!(!stream.isatty()); } }