Skip to content

Commit

Permalink
bit of cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
robherley committed Dec 29, 2024
1 parent 081ac6d commit 8ba6c14
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 86 deletions.
91 changes: 38 additions & 53 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
use actix_web::{
get, http, middleware::Logger, post, web, App, HttpResponse, HttpServer, Responder,
http, middleware::Logger, post, rt::spawn, web, App, HttpResponse, HttpServer, Responder,
ResponseError, Result,
};
use func_gg::{
runtime::Sandbox,
streams::{ReceiverStream, SenderStream},
streams::{InputStream, OutputStream},
};
use futures::StreamExt;
use log::{info, warn};
use tokio::sync::mpsc::channel;
use log::error;

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("runtime: {0}")]
Runtime(#[from] func_gg::runtime::Error),
#[error("payload: {0}")]
Payload(#[from] actix_web::error::PayloadError),
#[error("send: {0}")]
Send(String),
}

impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error {
fn from(err: tokio::sync::mpsc::error::SendError<T>) -> Self {
Self::Send(err.to_string())
}
}

impl ResponseError for Error {
Expand All @@ -22,60 +31,41 @@ impl ResponseError for Error {
}
}

// tokio_util::sync::CancellationToken
// https://tokio.rs/tokio/topics/shutdown
#[post("/")] // note: default payload limit is 256kB from actix-web, but is configurable with PayloadConfig
async fn handle(mut body: web::Payload) -> Result<impl Responder, Error> {
let binary = include_bytes!("/Users/robherley/dev/webfunc-handler/dist/main.wasm");
let mut sandbox = Sandbox::new(binary.to_vec())?;

let (stdin, req_tx) = ReceiverStream::new();
let (stdin, input_tx) = InputStream::new();

actix_web::rt::spawn(async move {
// collect input from request body
spawn(async move {
while let Some(item) = body.next().await {
match item {
Ok(chunk) => {
if let Err(e) = req_tx.send(chunk).await {
warn!("unable to send chunk: {:?}", e);
break;
}
}
Err(e) => {
warn!("payload error: {:?}", e);
break;
}
}
input_tx.send(item?).await?;
}
Ok::<(), Error>(())
});

let (stdout, mut res_rx) = SenderStream::new();

let (body_tx, body_rx) = channel::<Result<actix_web::web::Bytes, actix_web::Error>>(1);
let stream = tokio_stream::wrappers::ReceiverStream::new(body_rx);

actix_web::rt::spawn(async move {
while let Some(item) = res_rx.recv().await {
let chunk = actix_web::web::Bytes::from(item);
info!("sending chunk: {:?}", chunk);
if let Err(e) = body_tx.send(Ok(chunk)).await {
warn!("unable to send chunk: {:?}", e);
break;
}
}
});
let (stdout, output_rx, mut first_write_rx) = OutputStream::new();

actix_web::rt::spawn(async move {
if let Err(e) = sandbox.handler(stdin, stdout).await {
warn!("handler error: {:?}", e);
}
// invoke the function
spawn(async move {
sandbox.call(stdin, stdout).await?;
Ok::<(), Error>(())
});

// TODO(robherley): join handles? and proper error handling

Ok(HttpResponse::Ok().streaming(stream))
}
let content_type = match first_write_rx.recv().await {
Some(b'{') => "application/json",
Some(b'<') => "text/html",
_ => "text/plain",
};

#[get("/")]
async fn hello() -> impl Responder {
"Hello world!"
Ok(HttpResponse::Ok().content_type(content_type).streaming(
tokio_stream::wrappers::ReceiverStream::new(output_rx)
.map(|item| Ok::<_, Error>(actix_web::web::Bytes::from(item))),
))
}

#[actix_web::main]
Expand All @@ -88,13 +78,8 @@ async fn main() -> std::io::Result<()> {
std::env::var("PORT").unwrap_or("8080".into()),
);

HttpServer::new(|| {
App::new()
.wrap(Logger::new("%r %s %Dms"))
.service(hello)
.service(handle)
})
.bind(addr)?
.run()
.await
HttpServer::new(|| App::new().wrap(Logger::new("%r %s %Dms")).service(handle))
.bind(addr)?
.run()
.await
}
31 changes: 21 additions & 10 deletions src/runtime/sandbox.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::sync::Arc;

use log::info;
use tokio::spawn;
use tokio::sync::Mutex;
use wasmtime::*;
use wasmtime_wasi::preview1::{self, WasiP1Ctx};
use wasmtime_wasi::{AsyncStdinStream, WasiCtxBuilder};

use crate::streams::{ReceiverStream, SenderStream};
use crate::streams::{InputStream, OutputStream};

#[derive(thiserror::Error, Debug)]
pub enum Error {
Expand Down Expand Up @@ -37,6 +41,7 @@ impl Sandbox {
let mut config = wasmtime::Config::default();
config.debug_info(true);
config.async_support(true);
config.epoch_interruption(true);

let engine = Engine::new(&config)?;

Expand All @@ -62,19 +67,17 @@ impl Sandbox {
})
}

pub async fn handler(
&mut self,
stdin: ReceiverStream,
stdout: SenderStream,
) -> Result<(), Error> {
pub async fn call(&mut self, stdin: InputStream, stdout: OutputStream) -> Result<(), Error> {
let wasi_ctx = WasiCtxBuilder::new()
.env("FUNC_GG", "1")
.stdin(AsyncStdinStream::from(stdin))
.stdout(stdout)
.inherit_stderr() // TODO(robherley): pipe stderr to a log stream
.build_p1();

// NOTE: if store changes, we need to recompile the module
let mut store = Store::new(&self.engine, wasi_ctx);
store.set_epoch_deadline(1);

let func = self
.linker
Expand All @@ -83,7 +86,17 @@ impl Sandbox {
.get_default(&mut store, "")?
.typed::<(), ()>(&store)?;

let result = func.call_async(&mut store, ()).await.or_else(|err| {
let engine = Arc::new(Mutex::new(self.engine.clone()));
spawn({
let engine = Arc::clone(&engine);
async move {
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
info!("cancelling request");
engine.lock().await.increment_epoch();
}
});

func.call_async(&mut store, ()).await.or_else(|err| {
match err.downcast_ref::<wasmtime_wasi::I32Exit>() {
Some(e) => {
if e.0 != 0 {
Expand All @@ -94,8 +107,6 @@ impl Sandbox {
}
_ => Err(err.into()),
}
})?;

Ok(result)
})
}
}
16 changes: 8 additions & 8 deletions src/streams/receiver.rs → src/streams/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ use tokio::{
};
use wasmtime_wasi::{pipe::AsyncReadStream, AsyncStdinStream};

pub struct ReceiverStream {
pub struct InputStream {
rx: Receiver<Bytes>,
xtra: Option<Vec<u8>>,
}

impl ReceiverStream {
impl InputStream {
pub fn new() -> (Self, Sender<Bytes>) {
let (tx, rx) = channel::<Bytes>(1);
(Self { rx, xtra: None }, tx)
}
}

impl AsyncRead for ReceiverStream {
impl AsyncRead for InputStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
Expand Down Expand Up @@ -48,8 +48,8 @@ impl AsyncRead for ReceiverStream {
}
}

impl From<ReceiverStream> for AsyncStdinStream {
fn from(stream: ReceiverStream) -> Self {
impl From<InputStream> for AsyncStdinStream {
fn from(stream: InputStream) -> Self {
let rs = AsyncReadStream::new(stream);
AsyncStdinStream::new(rs)
}
Expand All @@ -61,7 +61,7 @@ mod tests {

#[tokio::test]
async fn test_read() {
let (mut stream, tx) = ReceiverStream::new();
let (mut stream, tx) = InputStream::new();

tx.send(Bytes::from("hello world")).await.unwrap();

Expand All @@ -77,7 +77,7 @@ mod tests {

#[tokio::test]
async fn test_into_leftover() {
let (mut stream, tx) = ReceiverStream::new();
let (mut stream, tx) = InputStream::new();

tx.send(Bytes::from("hello world")).await.unwrap();

Expand All @@ -94,7 +94,7 @@ mod tests {

#[tokio::test]
async fn test_from_leftover() {
let (mut stream, tx) = ReceiverStream::new();
let (mut stream, tx) = InputStream::new();
stream.xtra = Some(Bytes::from("hello ").to_vec());

tx.send(Bytes::from("world")).await.unwrap();
Expand Down
8 changes: 4 additions & 4 deletions src/streams/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod receiver;
mod sender;
mod input;
pub use input::InputStream;

pub use receiver::ReceiverStream;
pub use sender::SenderStream;
mod output;
pub use output::OutputStream;
41 changes: 30 additions & 11 deletions src/streams/sender.rs → src/streams/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,44 @@ use wasmtime_wasi::{
};

#[derive(Clone)]
pub struct SenderStream {
pub struct OutputStream {
tx: Sender<Bytes>,
first_tx: Option<Sender<u8>>,
}

impl SenderStream {
pub fn new() -> (Self, Receiver<Bytes>) {
impl OutputStream {
pub fn new() -> (Self, Receiver<Bytes>, Receiver<u8>) {
let (tx, rx) = channel::<Bytes>(1);
(Self { tx }, rx)
let (first_tx, first_rx) = channel::<u8>(1);
(
Self {
tx,
first_tx: Some(first_tx),
},
rx,
first_rx,
)
}
}

#[async_trait]
impl Subscribe for SenderStream {
impl Subscribe for OutputStream {
async fn ready(&mut self) {}
}

#[async_trait]
impl HostOutputStream for SenderStream {
impl HostOutputStream for OutputStream {
fn write(&mut self, buf: Bytes) -> StreamResult<()> {
if buf.is_empty() {
return Ok(());
}

if let Some(first_tx) = self.first_tx.take() {
if let Err(err) = first_tx.try_send(buf[0]) {
return Err(StreamError::LastOperationFailed(err.into()));
}
}

match self.tx.try_send(Bytes::from(buf)) {
Ok(()) => Ok(()),
Err(err) => Err(StreamError::LastOperationFailed(err.into())),
Expand All @@ -39,7 +58,7 @@ impl HostOutputStream for SenderStream {
}
}

impl StdoutStream for SenderStream {
impl StdoutStream for OutputStream {
fn stream(&self) -> Box<dyn HostOutputStream> {
Box::new(self.clone())
}
Expand All @@ -55,7 +74,7 @@ mod tests {

#[tokio::test]
async fn test_write() {
let (mut stream, mut rx) = SenderStream::new();
let (mut stream, mut rx, _) = OutputStream::new();

let data = Bytes::from("hello");
stream.write(data.clone()).unwrap();
Expand All @@ -66,19 +85,19 @@ mod tests {

#[tokio::test]
async fn test_flush() {
let (mut stream, _rx) = SenderStream::new();
let (mut stream, _, _) = OutputStream::new();
assert!(stream.flush().is_ok());
}

#[tokio::test]
async fn test_check_write() {
let (mut stream, _rx) = SenderStream::new();
let (mut stream, _, _) = OutputStream::new();
assert_eq!(stream.check_write().unwrap(), usize::MAX);
}

#[tokio::test]
async fn test_isatty() {
let (stream, _rx) = SenderStream::new();
let (stream, _, _) = OutputStream::new();
assert!(!stream.isatty());
}
}

0 comments on commit 8ba6c14

Please sign in to comment.