diff --git a/lib/runtime-core-tests/tests/imports.rs b/lib/runtime-core-tests/tests/imports.rs index 61702e5ac..24fc4ad5f 100644 --- a/lib/runtime-core-tests/tests/imports.rs +++ b/lib/runtime-core-tests/tests/imports.rs @@ -12,33 +12,46 @@ fn imported_functions_forms() { (import "env" "memory" (memory 1 1)) (import "env" "callback_fn" (func $callback_fn (type $type))) (import "env" "callback_closure" (func $callback_closure (type $type))) + (import "env" "callback_closure_with_env" (func $callback_closure_with_env (type $type))) (import "env" "callback_fn_with_vmctx" (func $callback_fn_with_vmctx (type $type))) (import "env" "callback_closure_with_vmctx" (func $callback_closure_with_vmctx (type $type))) (import "env" "callback_fn_trap" (func $callback_fn_trap (type $type))) (import "env" "callback_closure_trap" (func $callback_closure_trap (type $type))) (import "env" "callback_fn_trap_with_vmctx" (func $callback_fn_trap_with_vmctx (type $type))) (import "env" "callback_closure_trap_with_vmctx" (func $callback_closure_trap_with_vmctx (type $type))) + (func (export "function_fn") (type $type) get_local 0 call $callback_fn) + (func (export "function_closure") (type $type) get_local 0 call $callback_closure) + + (func (export "function_closure_with_env") (type $type) + get_local 0 + call $callback_closure_with_env) + (func (export "function_fn_with_vmctx") (type $type) get_local 0 call $callback_fn_with_vmctx) + (func (export "function_closure_with_vmctx") (type $type) get_local 0 call $callback_closure_with_vmctx) + (func (export "function_fn_trap") (type $type) get_local 0 call $callback_fn_trap) + (func (export "function_closure_trap") (type $type) get_local 0 call $callback_closure_trap) + (func (export "function_fn_trap_with_vmctx") (type $type) get_local 0 call $callback_fn_trap_with_vmctx) + (func (export "function_closure_trap_with_vmctx") (type $type) get_local 0 call $callback_closure_trap_with_vmctx)) @@ -51,6 +64,7 @@ fn imported_functions_forms() { const SHIFT: i32 = 10; memory.view()[0].set(SHIFT); + let shift = 100; let import_object = imports! { "env" => { @@ -59,6 +73,9 @@ fn imported_functions_forms() { "callback_closure" => Func::new(|n: i32| -> Result { Ok(n + 1) }), + "callback_closure_with_env" => Func::new(move |n: i32| -> Result { + Ok(shift + n + 1) + }), "callback_fn_with_vmctx" => Func::new(callback_fn_with_vmctx), "callback_closure_with_vmctx" => Func::new(|vmctx: &mut vm::Ctx, n: i32| -> Result { let memory = vmctx.memory(0); @@ -134,6 +151,7 @@ fn imported_functions_forms() { call_and_assert!(function_fn, Ok(2)); call_and_assert!(function_closure, Ok(2)); + call_and_assert!(function_closure_with_env, Ok(2 + shift)); call_and_assert!(function_fn_with_vmctx, Ok(2 + SHIFT)); call_and_assert!(function_closure_with_vmctx, Ok(2 + SHIFT)); call_and_assert!( diff --git a/lib/runtime-core/src/backing.rs b/lib/runtime-core/src/backing.rs index f9b320b45..0f57b2bbf 100644 --- a/lib/runtime-core/src/backing.rs +++ b/lib/runtime-core/src/backing.rs @@ -586,18 +586,18 @@ fn import_functions( if *expected_sig == *signature { functions.push(vm::ImportedFunc { func: func.inner(), - func_ctx: { - let _ = match ctx { - Context::External(ctx) => ctx, - Context::Internal => vmctx, - }; - - NonNull::new(Box::into_raw(Box::new(vm::FuncCtx { - vmctx: NonNull::new(vmctx).expect("`vmctx` must not be null."), - func_env: ptr::null_mut(), - }))) - .unwrap() - }, + func_ctx: NonNull::new(Box::into_raw(Box::new(vm::FuncCtx { + vmctx: NonNull::new(vmctx).expect("`vmctx` must not be null."), + func_env: match ctx { + Context::External(ctx) => { + NonNull::new(ctx).map(|pointer| { + pointer.cast() // `*mut vm::FuncEnv` was casted to `*mut vm::Ctx` to fit in `Context::External`. Cast it back. + }) + } + Context::Internal => None, + }, + }))) + .unwrap(), }); } else { link_errors.push(LinkError::IncorrectImportSignature { diff --git a/lib/runtime-core/src/typed_func.rs b/lib/runtime-core/src/typed_func.rs index 049dfc161..584749ecf 100644 --- a/lib/runtime-core/src/typed_func.rs +++ b/lib/runtime-core/src/typed_func.rs @@ -175,7 +175,7 @@ where Args: WasmTypeList, Rets: WasmTypeList, { - fn to_raw(&self) -> NonNull; + fn to_raw(self) -> (NonNull, Option>); } pub trait TrapEarly @@ -208,10 +208,12 @@ where } /// Represents a function that can be used by WebAssembly. +#[allow(dead_code)] pub struct Func<'a, Args = (), Rets = (), Inner: Kind = Wasm> { inner: Inner, - f: NonNull, - ctx: *mut vm::Ctx, + func: NonNull, + func_env: Option>, + vmctx: *mut vm::Ctx, _phantom: PhantomData<(&'a (), Args, Rets)>, } @@ -225,19 +227,20 @@ where { pub(crate) unsafe fn from_raw_parts( inner: Wasm, - f: NonNull, - ctx: *mut vm::Ctx, + func: NonNull, + vmctx: *mut vm::Ctx, ) -> Func<'a, Args, Rets, Wasm> { Func { inner, - f, - ctx, + func, + func_env: None, + vmctx, _phantom: PhantomData, } } pub fn get_vm_func(&self) -> NonNull { - self.f + self.func } } @@ -246,15 +249,18 @@ where Args: WasmTypeList, Rets: WasmTypeList, { - pub fn new(f: F) -> Func<'a, Args, Rets, Host> + pub fn new(func: F) -> Func<'a, Args, Rets, Host> where Kind: ExternalFunctionKind, F: ExternalFunction, { + let (func, func_env) = func.to_raw(); + Func { inner: Host(()), - f: f.to_raw(), - ctx: ptr::null_mut(), + func, + func_env, + vmctx: ptr::null_mut(), _phantom: PhantomData, } } @@ -391,7 +397,7 @@ where Rets: WasmTypeList, { pub fn call(&self, a: A) -> Result { - unsafe { ::call(a, self.f, self.inner, self.ctx) } + unsafe { ::call(a, self.func, self.inner, self.vmctx) } } } @@ -482,57 +488,75 @@ macro_rules! impl_traits { $( $x: WasmExternType, )* Rets: WasmTypeList, Trap: TrapEarly, - FN: Fn(&mut vm::Ctx $( , $x )*) -> Trap, + FN: Fn(&mut vm::Ctx $( , $x )*) -> Trap + 'static, { #[allow(non_snake_case)] - fn to_raw(&self) -> NonNull { - if mem::size_of::() == 0 { - /// This is required for the llvm backend to be able to unwind through this function. - #[cfg_attr(nightly, unwind(allowed))] - extern fn wrap<$( $x, )* Rets, Trap, FN>( - func_ctx: &mut vm::FuncCtx $( , $x: <$x as WasmExternType>::Native )* - ) -> Rets::CStruct - where - $( $x: WasmExternType, )* - Rets: WasmTypeList, - Trap: TrapEarly, - FN: Fn(&mut vm::Ctx, $( $x, )*) -> Trap, - { - let vmctx = unsafe { func_ctx.vmctx.as_mut() }; - let f: FN = unsafe { mem::transmute_copy(&()) }; + fn to_raw(self) -> (NonNull, Option>) { + /// This is required for the llvm backend to be able to unwind through this function. + #[cfg_attr(nightly, unwind(allowed))] + extern fn wrap<$( $x, )* Rets, Trap, FN>( + func_ctx: &mut vm::FuncCtx $( , $x: <$x as WasmExternType>::Native )* + ) -> Rets::CStruct + where + $( $x: WasmExternType, )* + Rets: WasmTypeList, + Trap: TrapEarly, + FN: Fn(&mut vm::Ctx, $( $x, )*) -> Trap, + { + dbg!(func_ctx.vmctx.as_ptr()); - let err = match panic::catch_unwind( - panic::AssertUnwindSafe( - || { - f(vmctx $( , WasmExternType::from_native($x) )* ).report() - } - ) - ) { - Ok(Ok(returns)) => return returns.into_c_struct(), - Ok(Err(err)) => { - let b: Box<_> = err.into(); - b as Box - }, - Err(err) => err, - }; + let vmctx = unsafe { func_ctx.vmctx.as_mut() }; + let func_env = func_ctx.func_env; - unsafe { - (&*vmctx.module).runnable_module.do_early_trap(err) - } + dbg!(func_env); + + let func: &FN = match func_env { + Some(func_env) => unsafe { + let func: NonNull = func_env.cast(); + + &*func.as_ptr() + }, + None => unsafe { mem::transmute_copy(&()) } + }; + + let err = match panic::catch_unwind( + panic::AssertUnwindSafe( + || { + func(vmctx $( , WasmExternType::from_native($x) )* ).report() + } + ) + ) { + Ok(Ok(returns)) => return returns.into_c_struct(), + Ok(Err(err)) => { + let b: Box<_> = err.into(); + b as Box + }, + Err(err) => err, + }; + + unsafe { + (&*vmctx.module).runnable_module.do_early_trap(err) } - - NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap() - } else { - assert_eq!( - mem::size_of::(), - mem::size_of::(), - "you cannot use a closure that captures state for `Func`." - ); - - NonNull::new(unsafe { - mem::transmute_copy::<_, *mut vm::Func>(self) - }).unwrap() } + + let func_env: Option> = + // `FN` is a function pointer, or a closure + // _without_ a captured environment. + if mem::size_of::() == 0 { + None + } + // `FN` is a closure _with_ a captured + // environment. Grab it. + else { + NonNull::new(Box::into_raw(Box::new(self))).map(NonNull::cast) + }; + + dbg!(func_env); + + ( + NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap(), + func_env + ) } } @@ -541,57 +565,75 @@ macro_rules! impl_traits { $( $x: WasmExternType, )* Rets: WasmTypeList, Trap: TrapEarly, - FN: Fn($( $x, )*) -> Trap, + FN: Fn($( $x, )*) -> Trap + 'static, { #[allow(non_snake_case)] - fn to_raw(&self) -> NonNull { - if mem::size_of::() == 0 { - /// This is required for the llvm backend to be able to unwind through this function. - #[cfg_attr(nightly, unwind(allowed))] - extern fn wrap<$( $x, )* Rets, Trap, FN>( - func_ctx: &mut vm::FuncCtx $( , $x: <$x as WasmExternType>::Native )* - ) -> Rets::CStruct - where - $( $x: WasmExternType, )* - Rets: WasmTypeList, - Trap: TrapEarly, - FN: Fn($( $x, )*) -> Trap, - { - let vmctx = unsafe { func_ctx.vmctx.as_mut() }; - let f: FN = unsafe { mem::transmute_copy(&()) }; + fn to_raw(self) -> (NonNull, Option>) { + /// This is required for the llvm backend to be able to unwind through this function. + #[cfg_attr(nightly, unwind(allowed))] + extern fn wrap<$( $x, )* Rets, Trap, FN>( + func_ctx: &mut vm::FuncCtx $( , $x: <$x as WasmExternType>::Native )* + ) -> Rets::CStruct + where + $( $x: WasmExternType, )* + Rets: WasmTypeList, + Trap: TrapEarly, + FN: Fn($( $x, )*) -> Trap, + { + dbg!(func_ctx.vmctx.as_ptr()); - let err = match panic::catch_unwind( - panic::AssertUnwindSafe( - || { - f($( WasmExternType::from_native($x), )* ).report() - } - ) - ) { - Ok(Ok(returns)) => return returns.into_c_struct(), - Ok(Err(err)) => { - let b: Box<_> = err.into(); - b as Box - }, - Err(err) => err, - }; + let vmctx = unsafe { func_ctx.vmctx.as_mut() }; + let func_env = func_ctx.func_env; - unsafe { - (&*vmctx.module).runnable_module.do_early_trap(err) - } + dbg!(func_env); + + let func: &FN = match func_env { + Some(func_env) => unsafe { + let func: NonNull = func_env.cast(); + + &*func.as_ptr() + }, + None => unsafe { mem::transmute_copy(&()) } + }; + + let err = match panic::catch_unwind( + panic::AssertUnwindSafe( + || { + func($( WasmExternType::from_native($x), )* ).report() + } + ) + ) { + Ok(Ok(returns)) => return returns.into_c_struct(), + Ok(Err(err)) => { + let b: Box<_> = err.into(); + b as Box + }, + Err(err) => err, + }; + + unsafe { + (&*vmctx.module).runnable_module.do_early_trap(err) } - - NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap() - } else { - assert_eq!( - mem::size_of::(), - mem::size_of::(), - "you cannot use a closure that captures state for `Func`." - ); - - NonNull::new(unsafe { - mem::transmute_copy::<_, *mut vm::Func>(self) - }).unwrap() } + + let func_env: Option> = + // `FN` is a function pointer, or a closure + // _without_ a captured environment. + if mem::size_of::() == 0 { + None + } + // `FN` is a closure _with_ a captured + // environment. Grab it. + else { + NonNull::new(Box::into_raw(Box::new(self))).map(NonNull::cast) + }; + + dbg!(func_env); + + ( + NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap(), + func_env + ) } } @@ -606,9 +648,9 @@ macro_rules! impl_traits { unsafe { <( $( $x ),* ) as WasmTypeList>::call( ( $( $x ),* ), - self.f, + self.func, self.inner, - self.ctx + self.vmctx ) } } @@ -646,8 +688,11 @@ where Inner: Kind, { fn to_export(&self) -> Export { - let func = unsafe { FuncPointer::new(self.f.as_ptr()) }; - let ctx = Context::Internal; + let func = unsafe { FuncPointer::new(self.func.as_ptr()) }; + let ctx = match self.func_env { + Some(func_env) => Context::External(func_env.cast().as_ptr()), + None => Context::Internal, + }; let signature = Arc::new(FuncSig::new(Args::types(), Rets::types())); Export::Function { diff --git a/lib/runtime-core/src/vm.rs b/lib/runtime-core/src/vm.rs index 1cc228c65..dc92c788f 100644 --- a/lib/runtime-core/src/vm.rs +++ b/lib/runtime-core/src/vm.rs @@ -512,16 +512,16 @@ pub struct FuncEnv { #[derive(Debug)] #[repr(C)] pub struct FuncCtx { - pub vmctx: NonNull, - pub func_env: *mut FuncEnv, + pub(crate) vmctx: NonNull, + pub(crate) func_env: Option>, } /// An imported function, which contains the vmctx that owns this function. #[derive(Debug, Clone)] #[repr(C)] pub struct ImportedFunc { - pub func: *const Func, - pub func_ctx: NonNull, + pub(crate) func: *const Func, + pub(crate) func_ctx: NonNull, } // manually implemented because ImportedFunc contains raw pointers