diff --git a/crates/c-api/include/wasmtime/component.h b/crates/c-api/include/wasmtime/component.h index 99caaaed18ab..c074da38a3e9 100644 --- a/crates/c-api/include/wasmtime/component.h +++ b/crates/c-api/include/wasmtime/component.h @@ -53,6 +53,8 @@ typedef uint8_t wasmtime_component_kind_t; typedef struct wasmtime_component_val_t wasmtime_component_val_t; typedef struct wasmtime_component_val_record_field_t wasmtime_component_val_record_field_t; +typedef struct wasmtime_component_type_t wasmtime_component_type_t; +typedef struct wasmtime_component_type_field_t wasmtime_component_type_field_t; #define WASMTIME_COMPONENT_DECLARE_VEC(name, element) \ typedef struct wasmtime_component_##name##_t { \ @@ -85,6 +87,16 @@ WASMTIME_COMPONENT_DECLARE_VEC(val_record, wasmtime_component_val_record_field_t WASMTIME_COMPONENT_DECLARE_VEC(val_flags, uint32_t); WASMTIME_COMPONENT_DECLARE_VEC_NEW(val_flags, uint32_t); +/// \brief A vector of types +WASMTIME_COMPONENT_DECLARE_VEC(type_vec, wasmtime_component_type_t); + +/// \brief A vector of field types +WASMTIME_COMPONENT_DECLARE_VEC(type_field_vec, wasmtime_component_type_field_t); + +/// \brief A vector of strings +WASMTIME_COMPONENT_DECLARE_VEC(string_vec, wasm_name_t); +WASMTIME_COMPONENT_DECLARE_VEC_NEW(string_vec, wasm_name_t); + #undef WASMTIME_COMPONENT_DECLARE_VEC // A variant contains the discriminant index and an optional value that is held. @@ -157,6 +169,43 @@ wasmtime_component_val_t* wasmtime_component_val_new(); void wasmtime_component_val_delete(wasmtime_component_val_t* val); +typedef struct wasmtime_component_type_field_t { + wasm_name_t name; + wasmtime_component_type_t* ty; +} wasmtime_component_type_field_t; + +WASMTIME_COMPONENT_DECLARE_VEC_NEW(type_field_vec, wasmtime_component_type_field_t); + +typedef struct wasmtime_component_type_result_t { + wasmtime_component_type_t* ok_ty; + wasmtime_component_type_t* err_ty; +} wasmtime_component_type_result_t; + +typedef union wasmtime_component_type_payload_t +{ + wasmtime_component_type_t* list; + wasmtime_component_type_field_vec_t record; + wasmtime_component_type_vec_t tuple; + wasmtime_component_type_field_vec_t variant; + wasmtime_component_string_vec_t enumeration; + wasmtime_component_type_t* option; + wasmtime_component_type_result_t result; + wasmtime_component_string_vec_t flags; +} wasmtime_component_type_payload_t; + +typedef struct wasmtime_component_type_t { + wasmtime_component_kind_t kind; + wasmtime_component_type_payload_t payload; +} wasmtime_component_type_t; + +WASMTIME_COMPONENT_DECLARE_VEC_NEW(type_vec, wasmtime_component_type_t); + +#undef WASMTIME_COMPONENT_DECLARE_VEC_NEW + +wasmtime_component_type_t* wasmtime_component_type_new(); + +void wasmtime_component_type_delete(wasmtime_component_type_t* ty); + typedef struct wasmtime_component_t wasmtime_component_t; wasmtime_error_t * @@ -176,6 +225,19 @@ typedef struct wasmtime_component_instance_t wasmtime_component_instance_t; // declaration from store.h typedef struct wasmtime_context wasmtime_context_t; +typedef wasm_trap_t *(*wasmtime_component_func_callback_t)( + void *env, wasmtime_context_t *context, const wasmtime_component_val_t *args, + size_t nargs, wasmtime_component_val_t *results, size_t nresults); + +wasmtime_error_t *wasmtime_component_linker_define_func( + wasmtime_component_linker_t *linker, const char *path, size_t path_len, + const char *name, size_t name_len, + wasmtime_component_type_t* params_types_buf, size_t params_types_len, + wasmtime_component_type_t* outputs_types_buf, size_t outputs_types_len, + wasmtime_component_func_callback_t cb, void *data, void (*finalizer)(void *)); + +wasmtime_error_t *wasmtime_component_linker_build(wasmtime_component_linker_t *linker); + wasmtime_error_t *wasmtime_component_linker_instantiate( const wasmtime_component_linker_t *linker, wasmtime_context_t *context, const wasmtime_component_t *component, wasmtime_component_instance_t **instance_out); diff --git a/crates/c-api/src/component.rs b/crates/c-api/src/component.rs index 85c2f1e9d174..eb723e2d4517 100644 --- a/crates/c-api/src/component.rs +++ b/crates/c-api/src/component.rs @@ -1,13 +1,16 @@ -use anyhow::{bail, ensure, Context, Result}; -use wasmtime::component::{Component, Func, Instance, Linker, Type, Val}; +use anyhow::{anyhow, bail, ensure, Context, Result}; +use wasmtime::component::{Component, Func, Instance, Linker, LinkerInstance, Type, Val}; use wasmtime::{AsContext, AsContextMut}; use crate::{ - declare_vecs, handle_call_error, handle_result, wasm_byte_vec_t, wasm_config_t, wasm_engine_t, - wasm_name_t, wasm_trap_t, wasmtime_error_t, WasmtimeStoreContextMut, WasmtimeStoreData, + bad_utf8, declare_vecs, handle_call_error, handle_result, to_str, wasm_byte_vec_t, + wasm_config_t, wasm_engine_t, wasm_name_t, wasm_trap_t, wasmtime_error_t, + WasmtimeStoreContextMut, WasmtimeStoreData, }; +use core::ffi::c_void; use std::collections::HashMap; -use std::{mem, mem::MaybeUninit, ptr, slice}; +use std::ops::Deref; +use std::{mem, mem::MaybeUninit, ptr, slice, str}; #[no_mangle] pub extern "C" fn wasmtime_config_component_model_set(c: &mut wasm_config_t, enable: bool) { @@ -346,6 +349,430 @@ impl TryFrom> for Val { } } +// a Val and its associated wasmtime_component_type_t (from the c_api) +struct CTypedVal<'a>(&'a Val, &'a CType); + +impl TryFrom> for wasmtime_component_val_t { + type Error = anyhow::Error; + fn try_from(value: CTypedVal) -> Result { + let (value, ty) = (value.0, value.1); + Ok(match value { + Val::Bool(v) => { + ensure_type!(ty, CType::Bool); + wasmtime_component_val_t::Bool(*v) + } + Val::S8(v) => { + ensure_type!(ty, CType::S8); + wasmtime_component_val_t::S8(*v) + } + Val::U8(v) => { + ensure_type!(ty, CType::U8); + wasmtime_component_val_t::U8(*v) + } + Val::S16(v) => { + ensure_type!(ty, CType::S16); + wasmtime_component_val_t::S16(*v) + } + Val::U16(v) => { + ensure_type!(ty, CType::U16); + wasmtime_component_val_t::U16(*v) + } + Val::S32(v) => { + ensure_type!(ty, CType::S32); + wasmtime_component_val_t::S32(*v) + } + Val::U32(v) => { + ensure_type!(ty, CType::U32); + wasmtime_component_val_t::U32(*v) + } + Val::S64(v) => { + ensure_type!(ty, CType::S64); + wasmtime_component_val_t::S64(*v) + } + Val::U64(v) => { + ensure_type!(ty, CType::U64); + wasmtime_component_val_t::U64(*v) + } + Val::Float32(v) => { + ensure_type!(ty, CType::F32); + wasmtime_component_val_t::F32(*v) + } + Val::Float64(v) => { + ensure_type!(ty, CType::F64); + wasmtime_component_val_t::F64(*v) + } + Val::Char(v) => { + ensure_type!(ty, CType::Char); + wasmtime_component_val_t::Char(*v) + } + Val::String(v) => { + ensure_type!(ty, CType::String); + wasmtime_component_val_t::String(v.clone().into_bytes().into()) + } + Val::List(vec) => { + if let CType::List(ty) = ty { + wasmtime_component_val_t::List( + vec.iter() + .map(|v| CTypedVal(v, ty.deref()).try_into()) + .collect::>>()? + .into(), + ) + } else { + bail!("attempted to create a List for a {}", ty.desc()); + } + } + Val::Record(vec) => { + if let CType::Record(ty) = ty { + let mut field_vals: HashMap<&str, &Val> = + HashMap::from_iter(vec.iter().map(|f| (f.0.as_str(), &f.1))); + + wasmtime_component_val_t::Record( + ty.iter() + .map(|(field_name, field_type)| { + match field_vals.remove(field_name.as_str()) { + Some(v) => Ok(wasmtime_component_val_record_field_t { + name: field_name.clone().into_bytes().into(), + val: CTypedVal(v, field_type).try_into()?, + }), + None => bail!("missing field {} in record", field_name), + } + }) + .collect::>>()? + .into(), + ) + } else { + bail!("attempted to create a Record for a {}", ty.desc()); + } + } + Val::Tuple(vec) => { + if let CType::Tuple(ty) = ty { + wasmtime_component_val_t::Tuple( + vec.iter() + .zip(ty.iter()) + .map(|(v, ty)| CTypedVal(v, ty).try_into()) + .collect::>>()? + .into(), + ) + } else { + bail!("attempted to create a Tuple for a {}", ty.desc()); + } + } + Val::Variant(case, val) => { + if let CType::Variant(ty) = ty { + let index = ty + .iter() + .position(|c| &c.0 == case) + .context(format!("case {case} not found in type"))?; + ensure!( + val.is_some() == ty[index].1.is_some(), + "mismatched variant case {} : value is {}, but type is {}", + case, + if val.is_some() { "some" } else { "none" }, + ty[index].1.as_ref().map(|ty| ty.desc()).unwrap_or("none") + ); + wasmtime_component_val_t::Variant(wasmtime_component_val_variant_t { + discriminant: index as u32, + val: match val { + Some(val) => Some(Box::new( + CTypedVal(val.deref(), ty[index].1.as_ref().unwrap()).try_into()?, + )), + None => None, + }, + }) + } else { + bail!("attempted to create a Variant for a {}", ty.desc()); + } + } + Val::Enum(v) => { + if let CType::Enum(ty) = ty { + let index = ty + .iter() + .position(|s| s == v) + .context(format!("enum value {v} not found in type"))?; + wasmtime_component_val_t::Enum(wasmtime_component_val_enum_t { + discriminant: index as u32, + }) + } else { + bail!("attempted to create a Enum for a {}", ty.desc()); + } + } + Val::Option(val) => { + if let CType::Option(ty) = ty { + wasmtime_component_val_t::Option(match val { + Some(val) => Some(Box::new(CTypedVal(val.deref(), ty.deref()).try_into()?)), + None => None, + }) + } else { + bail!("attempted to create a Option for a {}", ty.desc()); + } + } + Val::Result(val) => { + if let CType::Result(ok_type, err_type) = ty { + wasmtime_component_val_t::Result(match val { + Ok(Some(ok)) => { + let ok_type = ok_type + .as_ref() + .context("some ok result found instead of none")?; + wasmtime_component_val_result_t { + value: Some(Box::new( + CTypedVal(ok.deref(), ok_type.deref()).try_into()?, + )), + error: false, + } + } + Ok(None) => { + ensure!( + ok_type.is_none(), + "none ok result found instead of {}", + ok_type.as_ref().unwrap().desc() + ); + wasmtime_component_val_result_t { + value: None, + error: false, + } + } + Err(Some(err)) => { + let err_type = err_type + .as_ref() + .context("some err result found instead of none")?; + wasmtime_component_val_result_t { + value: Some(Box::new( + CTypedVal(err.deref(), err_type.deref()).try_into()?, + )), + error: true, + } + } + Err(None) => { + ensure!( + err_type.is_none(), + "none err result found instead of {}", + err_type.as_ref().unwrap().desc() + ); + wasmtime_component_val_result_t { + value: None, + error: true, + } + } + }) + } else { + bail!("attempted to create a Result for a {}", ty.desc()); + } + } + Val::Flags(vec) => { + if let CType::Flags(ty) = ty { + let mapping: HashMap<_, _> = ty.iter().zip(0u32..).collect(); + let mut flags: wasmtime_component_val_flags_t = Vec::new().into(); + for name in vec.iter() { + let idx = mapping.get(name).context("expected valid name")?; + wasmtime_component_val_flags_set(&mut flags, *idx, true); + } + wasmtime_component_val_t::Flags(flags) + } else { + bail!("attempted to create a Flags for a {}", ty.desc()); + } + } + Val::Resource(_) => bail!("resources not supported"), + }) + } +} + +// a wasmtime_component_val_t and its associated wasmtime_component_type_t +struct CTypedCVal<'a>(&'a wasmtime_component_val_t, &'a CType); + +impl TryFrom> for Val { + type Error = anyhow::Error; + fn try_from(value: CTypedCVal) -> Result { + let (value, ty) = (value.0, value.1); + Ok(match value { + &wasmtime_component_val_t::Bool(b) => { + ensure_type!(ty, CType::Bool); + Val::Bool(b) + } + &wasmtime_component_val_t::S8(v) => { + ensure_type!(ty, CType::S8); + Val::S8(v) + } + &wasmtime_component_val_t::U8(v) => { + ensure_type!(ty, CType::U8); + Val::U8(v) + } + &wasmtime_component_val_t::S16(v) => { + ensure_type!(ty, CType::S16); + Val::S16(v) + } + &wasmtime_component_val_t::U16(v) => { + ensure_type!(ty, CType::U16); + Val::U16(v) + } + &wasmtime_component_val_t::S32(v) => { + ensure_type!(ty, CType::S32); + Val::S32(v) + } + &wasmtime_component_val_t::U32(v) => { + ensure_type!(ty, CType::U32); + Val::U32(v) + } + &wasmtime_component_val_t::S64(v) => { + ensure_type!(ty, CType::S64); + Val::S64(v) + } + &wasmtime_component_val_t::U64(v) => { + ensure_type!(ty, CType::U64); + Val::U64(v) + } + &wasmtime_component_val_t::F32(v) => { + ensure_type!(ty, CType::F32); + Val::Float32(v) + } + &wasmtime_component_val_t::F64(v) => { + ensure_type!(ty, CType::F64); + Val::Float64(v) + } + &wasmtime_component_val_t::Char(v) => { + ensure_type!(ty, CType::Char); + Val::Char(v) + } + wasmtime_component_val_t::String(v) => { + ensure_type!(ty, CType::String); + Val::String(String::from_utf8(v.as_slice().to_vec())?) + } + wasmtime_component_val_t::List(v) => { + if let CType::List(ty) = ty { + Val::List( + v.as_slice() + .iter() + .map(|v| CTypedCVal(v, ty.deref()).try_into()) + .collect::>>()?, + ) + } else { + bail!("attempted to create a list for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Record(v) => { + if let CType::Record(ty) = ty { + let mut field_vals: HashMap<&[u8], &wasmtime_component_val_t> = + HashMap::from_iter( + v.as_slice().iter().map(|f| (f.name.as_slice(), &f.val)), + ); + Val::Record( + ty.iter() + .map(|tyf| { + if let Some(v) = field_vals.remove(tyf.0.as_bytes()) { + Ok((tyf.0.clone(), CTypedCVal(v, &tyf.1).try_into()?)) + } else { + bail!("record missing field: {}", tyf.0); + } + }) + .collect::>>()?, + ) + } else { + bail!("attempted to create a record for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Tuple(v) => { + if let CType::Tuple(ty) = ty { + Val::Tuple( + ty.iter() + .zip(v.as_slice().iter()) + .map(|(ty, v)| CTypedCVal(v, ty).try_into()) + .collect::>>()?, + ) + } else { + bail!("attempted to create a tuple for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Variant(v) => { + if let CType::Variant(ty) = ty { + let index = v.discriminant as usize; + ensure!(index < ty.len(), "variant index outside range"); + let case = &ty[index]; + let case_name = case.0.clone(); + ensure!( + case.1.is_some() == v.val.is_some(), + "variant type mismatch for case {}: {} instead of {}", + case_name, + if v.val.is_some() { "some" } else { "none" }, + case.1.as_ref().map(|ty| ty.desc()).unwrap_or("none") + ); + if let (Some(t), Some(v)) = (&case.1, &v.val) { + let v = CTypedCVal(v.as_ref(), t.deref()).try_into()?; + Val::Variant(case_name, Some(Box::new(v))) + } else { + Val::Variant(case_name, None) + } + } else { + bail!("attempted to create a variant for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Enum(v) => { + if let CType::Enum(ty) = ty { + let index = v.discriminant as usize; + ensure!(index < ty.as_slice().len(), "variant index outside range"); + Val::Enum(ty[index].clone()) + } else { + bail!("attempted to create an enum for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Option(v) => { + if let CType::Option(ty) = ty { + Val::Option(match v { + Some(v) => Some(Box::new(CTypedCVal(v.as_ref(), ty.deref()).try_into()?)), + None => None, + }) + } else { + bail!("attempted to create an option for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Result(v) => { + if let CType::Result(ok_ty, err_ty) = ty { + if v.error { + match &v.value { + Some(v) => { + let ty = err_ty.as_deref().context("expected err type")?; + Val::Result(Err(Some(Box::new( + CTypedCVal(v.as_ref(), ty).try_into()?, + )))) + } + None => { + ensure!(err_ty.is_none(), "expected no err type"); + Val::Result(Err(None)) + } + } + } else { + match &v.value { + Some(v) => { + let ty = ok_ty.as_deref().context("expected ok type")?; + Val::Result(Ok(Some(Box::new( + CTypedCVal(v.as_ref(), ty).try_into()?, + )))) + } + None => { + ensure!(ok_ty.is_none(), "expected no ok type"); + Val::Result(Ok(None)) + } + } + } + } else { + bail!("attempted to create a result for a {}", ty.desc()); + } + } + wasmtime_component_val_t::Flags(flags) => { + if let CType::Flags(ty) = ty { + let mut set = Vec::new(); + for (idx, name) in ty.iter().enumerate() { + if wasmtime_component_val_flags_test(&flags, idx as u32) { + set.push(name.clone()); + } + } + Val::Flags(set) + } else { + bail!("attempted to create a flags for a {}", ty.desc()); + } + } + }) + } +} + impl TryFrom<(&Val, &Type)> for wasmtime_component_val_t { type Error = anyhow::Error; @@ -565,6 +992,303 @@ pub const WASMTIME_COMPONENT_KIND_OPTION: wasmtime_component_kind_t = 18; pub const WASMTIME_COMPONENT_KIND_RESULT: wasmtime_component_kind_t = 19; pub const WASMTIME_COMPONENT_KIND_FLAGS: wasmtime_component_kind_t = 20; +#[repr(C)] +#[derive(Clone)] +pub struct wasmtime_component_type_field_t { + pub name: wasm_name_t, + pub ty: Option, +} + +impl Default for wasmtime_component_type_field_t { + fn default() -> Self { + Self { + name: Vec::new().into(), + ty: Default::default(), + } + } +} + +#[repr(C)] +#[derive(Clone)] +pub struct wasmtime_component_type_result_t { + pub ok_ty: Option>, + pub err_ty: Option>, +} + +declare_vecs! { + ( + name: wasmtime_component_type_vec_t, + ty: wasmtime_component_type_t, + new: wasmtime_component_type_vec_new, + empty: wasmtime_component_type_vec_new_empty, + uninit: wasmtime_component_type_vec_new_uninitialized, + copy: wasmtime_component_type_vec_copy, + delete: wasmtime_component_type_vec_delete, + ) + ( + name: wasmtime_component_type_field_vec_t, + ty: wasmtime_component_type_field_t, + new: wasmtime_component_type_field_vec_new, + empty: wasmtime_component_type_field_vec_new_empty, + uninit: wasmtime_component_type_field_vec_new_uninitialized, + copy: wasmtime_component_type_field_vec_copy, + delete: wasmtime_component_type_field_vec_delete, + ) + ( + name: wasmtime_component_string_vec_t, + ty: wasmtime_component_string_t, + new: wasmtime_component_string_vec_new, + empty: wasmtime_component_string_vec_new_empty, + uninit: wasmtime_component_string_vec_new_uninitialized, + copy: wasmtime_component_string_vec_copy, + delete: wasmtime_component_string_vec_delete, + ) +} + +#[repr(C, u8)] +#[derive(Clone)] +pub enum wasmtime_component_type_t { + Bool, + S8, + U8, + S16, + U16, + S32, + U32, + S64, + U64, + F32, + F64, + Char, + String, + List(Box), + Record(wasmtime_component_type_field_vec_t), + Tuple(wasmtime_component_type_vec_t), + Variant(wasmtime_component_type_field_vec_t), + Enum(wasmtime_component_string_vec_t), + Option(Box), + Result(wasmtime_component_type_result_t), + Flags(wasmtime_component_string_vec_t), +} + +impl Default for wasmtime_component_type_t { + fn default() -> Self { + Self::Bool + } +} + +#[no_mangle] +pub extern "C" fn wasmtime_component_type_new() -> Box { + Box::new(wasmtime_component_type_t::Bool) +} + +#[no_mangle] +pub unsafe extern "C" fn wasmtime_component_type_delete(_: Box) {} + +#[derive(Clone)] +pub enum CType { + Bool, + S8, + U8, + S16, + U16, + S32, + U32, + S64, + U64, + F32, + F64, + Char, + String, + List(Box), + Record(Vec<(String, CType)>), + Tuple(Vec), + Variant(Vec<(String, Option>)>), + Enum(Vec), + Option(Box), + Result(Option>, Option>), + Flags(Vec), +} + +impl TryFrom<&wasmtime_component_type_t> for CType { + type Error = anyhow::Error; + + fn try_from(ty: &wasmtime_component_type_t) -> Result { + Ok(match ty { + wasmtime_component_type_t::Bool => CType::Bool, + wasmtime_component_type_t::S8 => CType::S8, + wasmtime_component_type_t::U8 => CType::U8, + wasmtime_component_type_t::S16 => CType::S16, + wasmtime_component_type_t::U16 => CType::U16, + wasmtime_component_type_t::S32 => CType::S32, + wasmtime_component_type_t::U32 => CType::U32, + wasmtime_component_type_t::S64 => CType::S64, + wasmtime_component_type_t::U64 => CType::U64, + wasmtime_component_type_t::F32 => CType::F32, + wasmtime_component_type_t::F64 => CType::F64, + wasmtime_component_type_t::Char => CType::Char, + wasmtime_component_type_t::String => CType::String, + wasmtime_component_type_t::List(ty) => CType::List(Box::new(ty.as_ref().try_into()?)), + wasmtime_component_type_t::Record(fields) => CType::Record( + fields + .as_slice() + .iter() + .map(|field| { + let field_name = String::from_utf8(field.name.as_slice().to_vec())?; + let field_type = match &field.ty { + Some(ty) => ty.try_into()?, + None => bail!("missing type of field {} in record", field_name), + }; + Ok((field_name, field_type)) + }) + .collect::>>()?, + ), + wasmtime_component_type_t::Tuple(types) => CType::Tuple( + types + .as_slice() + .iter() + .map(|ty| ty.try_into()) + .collect::>>()?, + ), + wasmtime_component_type_t::Variant(cases) => CType::Variant( + cases + .as_slice() + .iter() + .map(|case| { + let case_name = String::from_utf8(case.name.as_slice().to_vec())?; + let case_type = match &case.ty { + Some(ty) => Some(Box::new(ty.try_into()?)), + None => None, + }; + Ok((case_name, case_type)) + }) + .collect::>>()?, + ), + wasmtime_component_type_t::Enum(enums) => CType::Enum( + enums + .as_slice() + .iter() + .map(|s| Ok(String::from_utf8(s.as_slice().to_vec())?)) + .collect::>>()?, + ), + wasmtime_component_type_t::Option(ty) => { + CType::Option(Box::new(ty.as_ref().try_into()?)) + } + wasmtime_component_type_t::Result(wasmtime_component_type_result_t { + ok_ty, + err_ty, + }) => CType::Result( + match ok_ty { + Some(ty) => Some(Box::new(ty.as_ref().try_into()?)), + None => None, + }, + match err_ty { + Some(ty) => Some(Box::new(ty.as_ref().try_into()?)), + None => None, + }, + ), + wasmtime_component_type_t::Flags(flags) => CType::Flags( + flags + .as_slice() + .iter() + .map(|s| Ok(String::from_utf8(s.as_slice().to_vec())?)) + .collect::>>()?, + ), + }) + } +} + +impl CType { + /// Return a string description of this type + fn desc(&self) -> &'static str { + match self { + CType::Bool => "bool", + CType::S8 => "s8", + CType::U8 => "u8", + CType::S16 => "s16", + CType::U16 => "u16", + CType::S32 => "s32", + CType::U32 => "u32", + CType::S64 => "s64", + CType::U64 => "u64", + CType::F32 => "f32", + CType::F64 => "f64", + CType::Char => "char", + CType::String => "string", + CType::List(_) => "list", + CType::Record(_) => "record", + CType::Tuple(_) => "tuple", + CType::Variant(_) => "variant", + CType::Enum(_) => "enum", + CType::Option(_) => "option", + CType::Result(_, _) => "result", + CType::Flags(_) => "flags", + } + } + + fn default_cval(&self) -> wasmtime_component_val_t { + match self { + CType::Bool => wasmtime_component_val_t::Bool(false), + CType::S8 => wasmtime_component_val_t::S8(0), + CType::U8 => wasmtime_component_val_t::U8(0), + CType::S16 => wasmtime_component_val_t::S16(0), + CType::U16 => wasmtime_component_val_t::U16(0), + CType::S32 => wasmtime_component_val_t::S32(0), + CType::U32 => wasmtime_component_val_t::U32(0), + CType::S64 => wasmtime_component_val_t::S64(0), + CType::U64 => wasmtime_component_val_t::U64(0), + CType::F32 => wasmtime_component_val_t::F32(0.0), + CType::F64 => wasmtime_component_val_t::F64(0.0), + CType::Char => wasmtime_component_val_t::Char('\0'), + CType::String => { + wasmtime_component_val_t::String(wasmtime_component_string_t::default()) + } + CType::List(_) => { + wasmtime_component_val_t::List(wasmtime_component_val_vec_t::default()) + } + CType::Record(fields) => wasmtime_component_val_t::Record( + fields + .iter() + .map(|(name, ty)| wasmtime_component_val_record_field_t { + name: name.clone().into_bytes().into(), + val: ty.default_cval(), + }) + .collect::>() + .into(), + ), + CType::Tuple(tuple) => wasmtime_component_val_t::Tuple( + tuple + .iter() + .map(|ty| ty.default_cval()) + .collect::>() + .into(), + ), + CType::Variant(cases) => { + wasmtime_component_val_t::Variant(wasmtime_component_val_variant_t { + discriminant: 0, + val: match &cases[0].1 { + Some(ty) => Some(Box::new(ty.default_cval())), + None => None, + }, + }) + } + CType::Enum(_) => { + wasmtime_component_val_t::Enum(wasmtime_component_val_enum_t { discriminant: 0 }) + } + CType::Option(_) => wasmtime_component_val_t::Option(None), + CType::Result(_, _) => { + wasmtime_component_val_t::Result(wasmtime_component_val_result_t { + value: None, + error: false, + }) + } + CType::Flags(_) => { + wasmtime_component_val_t::Flags(wasmtime_component_val_flags_t::default()) + } + } + } +} + #[repr(transparent)] pub struct wasmtime_component_t { component: Component, @@ -586,9 +1310,30 @@ pub unsafe extern "C" fn wasmtime_component_from_binary( #[no_mangle] pub unsafe extern "C" fn wasmtime_component_delete(_: Box) {} +pub type wasmtime_component_func_callback_t = extern "C" fn( + *mut c_void, + WasmtimeStoreContextMut<'_>, + *const wasmtime_component_val_t, + usize, + *mut wasmtime_component_val_t, + usize, +) -> Option>; + +struct HostFuncDefinition { + path: Vec, + name: String, + params_types: Vec, + outputs_types: Vec, + callback: wasmtime_component_func_callback_t, + data: *mut c_void, + finalizer: Option, +} + #[repr(C)] pub struct wasmtime_component_linker_t { linker: Linker, + is_built: bool, + functions: Vec, } #[no_mangle] @@ -597,12 +1342,165 @@ pub extern "C" fn wasmtime_component_linker_new( ) -> Box { Box::new(wasmtime_component_linker_t { linker: Linker::new(&engine.engine), + is_built: false, + functions: Vec::new(), }) } #[no_mangle] pub unsafe extern "C" fn wasmtime_component_linker_delete(_: Box) {} +fn to_ctype_vec(buf: *mut wasmtime_component_type_t, len: usize) -> Result> { + if len == 0 { + return Ok(Vec::new()); + } + let v = unsafe { crate::slice_from_raw_parts(buf, len) }; + v.iter().map(|t| t.try_into()).collect::>>() +} + +#[no_mangle] +pub unsafe extern "C" fn wasmtime_component_linker_define_func( + linker: &mut wasmtime_component_linker_t, + path_buf: *const u8, + path_len: usize, + name_buf: *const u8, + name_len: usize, + params_types_buf: *mut wasmtime_component_type_t, + params_types_len: usize, + outputs_types_buf: *mut wasmtime_component_type_t, + outputs_types_len: usize, + callback: wasmtime_component_func_callback_t, + data: *mut c_void, + finalizer: Option, +) -> Option> { + let path = to_str!(path_buf, path_len) + .split('.') + .filter(|s| s.len() > 0) + .map(|s| s.to_string()) + .collect::>(); + let name = to_str!(name_buf, name_len).to_string(); + let params_types = match to_ctype_vec(params_types_buf, params_types_len) { + Err(err) => return Some(Box::new(wasmtime_error_t::from(err))), + Ok(p) => p, + }; + let outputs_types = match to_ctype_vec(outputs_types_buf, outputs_types_len) { + Err(err) => return Some(Box::new(wasmtime_error_t::from(err))), + Ok(p) => p, + }; + + linker.functions.push(HostFuncDefinition { + path, + name, + params_types, + outputs_types, + callback, + data, + finalizer, + }); + None +} + +fn build_closure( + function: &HostFuncDefinition, +) -> impl Fn(WasmtimeStoreContextMut<'_>, &[Val], &mut [Val]) -> Result<()> { + let func = function.callback; + let params_types = function.params_types.clone(); + let outputs_types = function.outputs_types.clone(); + let foreign = crate::ForeignData { + data: function.data, + finalizer: function.finalizer, + }; + move |context, parameters, outputs| { + let _ = &foreign; + let _ = ¶ms_types; + let _ = &outputs_types; + let mut params = Vec::new(); + for param in parameters.iter().zip(params_types.iter()) { + params.push(CTypedVal(param.0, param.1).try_into()?); + } + let mut outs = Vec::new(); + for output_type in outputs_types.iter() { + outs.push(output_type.default_cval()); + } + let res = func( + foreign.data, + context, + params.as_ptr(), + params.len(), + outs.as_mut_ptr(), + outs.len(), + ); + match res { + None => { + for (i, (output, output_type)) in outs.iter().zip(outputs_types.iter()).enumerate() + { + outputs[i] = CTypedCVal(output, output_type).try_into()?; + } + Ok(()) + } + Some(trap) => Err(trap.error), + } + } +} + +#[no_mangle] +pub extern "C" fn wasmtime_component_linker_build( + linker: &mut wasmtime_component_linker_t, +) -> Option> { + if linker.is_built { + return Some(Box::new(wasmtime_error_t::from(anyhow!( + "cannot build an already built linker" + )))); + } + + struct InstanceTree { + children: HashMap, + functions: Vec, + } + + impl InstanceTree { + fn insert(&mut self, depth: usize, function: HostFuncDefinition) { + if function.path.len() == depth { + self.functions.push(function); + } else { + let child = self + .children + .entry(function.path[depth].to_string()) + .or_insert_with(|| InstanceTree { + children: HashMap::new(), + functions: Vec::new(), + }); + child.insert(depth + 1, function); + } + } + fn build(&self, mut instance: LinkerInstance) -> Result<()> { + for function in self.functions.iter() { + instance.func_new(&function.name, build_closure(function))?; + } + for (name, child) in self.children.iter() { + let child_instance = instance.instance(&name)?; + child.build(child_instance)?; + } + Ok(()) + } + } + + let mut root = InstanceTree { + children: HashMap::new(), + functions: Vec::new(), + }; + for function in linker.functions.drain(..) { + root.insert(0, function); + } + match root.build(linker.linker.root()) { + Ok(()) => { + linker.is_built = true; + None + } + Err(err) => Some(Box::new(wasmtime_error_t::from(anyhow!(err)))), + } +} + #[no_mangle] pub extern "C" fn wasmtime_component_linker_instantiate( linker: &wasmtime_component_linker_t, @@ -610,6 +1508,11 @@ pub extern "C" fn wasmtime_component_linker_instantiate( component: &wasmtime_component_t, out: &mut *mut wasmtime_component_instance_t, ) -> Option> { + if !linker.is_built && !linker.functions.is_empty() { + return Some(Box::new(wasmtime_error_t::from(anyhow!( + "cannot instantiate with a linker not built" + )))); + } match linker.linker.instantiate(store, &component.component) { Ok(instance) => { *out = Box::into_raw(Box::new(wasmtime_component_instance_t { instance }));