Skip to content

Commit

Permalink
feat: per call context (#711)
Browse files Browse the repository at this point in the history
Add `plugin.call_with_host_context` and `current_plugin.host_context`
methods, enabling per-call context to be looped from the guest invocation
to any host functions it calls. In an HTTP server environment, this enables
re-using a plugin across multiple requests while switching out backing
connections and user information in host functions. (Imagine a host
function, `update_user` -- previously the plugin would have to have been
aware of the user to pass to the host function. Now that information is
ambient.)

This is a backwards-compatible change and requires no changes to
existing plugins.

Implement by adding a global, mutable externref to the extism kernel.
Since most programming languages, including Rust, don't let you define
these natively, we accomplish this by using `wasm-merge` to combine the
kernel Wasm with Wasm generated by a WAT file containing only the
global.

(This pattern might be useful for other Wasm constructs we can't use
directly from Rust, like `v128` in argument parameters.)

Wasmtime requires extern refs to be `Any + Send + Sync + 'static`; we
additionally add `Clone`. I haven't tried this with an `Arc` directly,
but it should work at least for container structs that hold `Arc`'s
themselves.
  • Loading branch information
chrisdickinson committed May 21, 2024
1 parent 75e92c4 commit 5d9c8c5
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/kernel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ jobs:
target: wasm32-unknown-unknown
- uses: Swatinem/rust-cache@v2

- name: install wasm-tools
uses: bytecodealliance/actions/wasm-tools/setup@v1

- name: Install deps
run: |
sudo apt install wabt --yes
Expand Down
7 changes: 5 additions & 2 deletions kernel/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ done

cargo build --package extism-runtime-kernel --bin extism-runtime --release --target wasm32-unknown-unknown $CARGO_FLAGS
cp target/wasm32-unknown-unknown/release/extism-runtime.wasm .
wasm-strip extism-runtime.wasm
mv extism-runtime.wasm ../runtime/src/extism-runtime.wasm

wasm-tools parse extism-context.wat -o extism-context.wasm
wasm-merge --enable-reference-types ./extism-runtime.wasm runtime extism-context.wasm context -o ../runtime/src/extism-runtime.wasm
rm extism-context.wasm
rm extism-runtime.wasm
wasm-strip ../runtime/src/extism-runtime.wasm
3 changes: 3 additions & 0 deletions kernel/extism-context.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(module
(global (export "extism_context") (mut externref) (ref.null extern))
)
20 changes: 20 additions & 0 deletions runtime/extism.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ extern "C" {
*/
const uint8_t *extism_plugin_id(ExtismPlugin *plugin);

/**
* Get the current plugin's associated host context data. Returns null if call was made without
* host context.
*/
void *extism_current_plugin_host_context(ExtismCurrentPlugin *plugin);

/**
* Returns a pointer to the memory of the currently running plugin
* NOTE: this should only be called from host functions.
Expand Down Expand Up @@ -231,6 +237,20 @@ int32_t extism_plugin_call(ExtismPlugin *plugin,
const uint8_t *data,
ExtismSize data_len);

/**
* Call a function with host context.
*
* `func_name`: is the function to call
* `data`: is the input data
* `data_len`: is the length of `data`
* `host_context`: a pointer to context data that will be available in host functions
*/
int32_t extism_plugin_call_with_host_context(ExtismPlugin *plugin,
const char *func_name,
const uint8_t *data,
ExtismSize data_len,
void *host_context);

/**
* Get the error associated with a `Plugin`
*/
Expand Down
17 changes: 17 additions & 0 deletions runtime/src/current_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,23 @@ impl CurrentPlugin {
anyhow::bail!("{} unable to locate extism memory", self.id)
}

pub fn host_context<T: Clone + 'static>(&mut self) -> Result<T, Error> {
let (linker, mut store) = self.linker_and_store();
let Some(Extern::Global(xs)) = linker.get(&mut store, EXTISM_ENV_MODULE, "extism_context")
else {
anyhow::bail!("unable to locate an extism kernel global: extism_context",)
};

let Val::ExternRef(Some(xs)) = xs.get(store) else {
anyhow::bail!("expected extism_context to be an externref value",)
};

match xs.data().downcast_ref::<T>().cloned() {
Some(xs) => Ok(xs.clone()),
None => anyhow::bail!("could not downcast extism_context",),
}
}

pub fn memory_alloc(&mut self, n: u64) -> Result<MemoryHandle, Error> {
if n == 0 {
return Ok(MemoryHandle {
Expand Down
Binary file modified runtime/src/extism-runtime.wasm
Binary file not shown.
47 changes: 43 additions & 4 deletions runtime/src/plugin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
any::Any,
collections::{BTreeMap, BTreeSet},
path::PathBuf,
};
Expand Down Expand Up @@ -193,6 +194,12 @@ fn add_module<T: 'static>(
}

for import in module.imports() {
let module = import.module();

if module == EXTISM_ENV_MODULE && !matches!(import.ty(), ExternType::Func(_)) {
anyhow::bail!("linked modules cannot access non-function exports of extism kernel");
}

if !linked.contains(import.module()) {
if let Some(m) = modules.get(import.module()) {
add_module(
Expand Down Expand Up @@ -462,7 +469,12 @@ impl Plugin {
}

// Store input in memory and re-initialize `Internal` pointer
pub(crate) fn set_input(&mut self, input: *const u8, mut len: usize) -> Result<(), Error> {
pub(crate) fn set_input(
&mut self,
input: *const u8,
mut len: usize,
host_context: Option<ExternRef>,
) -> Result<(), Error> {
self.output = Output::default();
self.clear_error()?;
let id = self.id.to_string();
Expand Down Expand Up @@ -496,6 +508,13 @@ impl Plugin {
)?;
}

if let Some(Extern::Global(ctxt)) =
self.linker
.get(&mut self.store, EXTISM_ENV_MODULE, "extism_context")
{
ctxt.set(&mut self.store, Val::ExternRef(host_context))?;
}

Ok(())
}

Expand Down Expand Up @@ -673,6 +692,7 @@ impl Plugin {
lock: &mut std::sync::MutexGuard<Option<Instance>>,
name: impl AsRef<str>,
input: impl AsRef<[u8]>,
host_context: Option<ExternRef>,
) -> Result<i32, (Error, i32)> {
let name = name.as_ref();
let input = input.as_ref();
Expand All @@ -686,7 +706,7 @@ impl Plugin {

self.instantiate(lock).map_err(|e| (e, -1))?;

self.set_input(input.as_ptr(), input.len())
self.set_input(input.as_ptr(), input.len(), host_context)
.map_err(|x| (x, -1))?;

let func = match self.get_func(lock, name) {
Expand Down Expand Up @@ -873,7 +893,26 @@ impl Plugin {
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
let data = input.to_bytes()?;
self.raw_call(&mut lock, name, data)
self.raw_call(&mut lock, name, data, None)
.map_err(|e| e.0)
.and_then(move |_| self.output())
}

pub fn call_with_host_context<'a, 'b, T, U, C>(
&'b mut self,
name: impl AsRef<str>,
input: T,
host_context: C,
) -> Result<U, Error>
where
T: ToBytes<'a>,
U: FromBytes<'b>,
C: Any + Send + Sync + 'static,
{
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
let data = input.to_bytes()?;
self.raw_call(&mut lock, name, data, Some(ExternRef::new(host_context)))
.map_err(|e| e.0)
.and_then(move |_| self.output())
}
Expand All @@ -892,7 +931,7 @@ impl Plugin {
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
let data = input.to_bytes().map_err(|e| (e, -1))?;
self.raw_call(&mut lock, name, data)
self.raw_call(&mut lock, name, data, None)
.and_then(move |_| self.output().map_err(|e| (e, -1)))
}

Expand Down
50 changes: 49 additions & 1 deletion runtime/src/sdk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ pub unsafe extern "C" fn extism_plugin_id(plugin: *mut Plugin) -> *const u8 {
plugin.id.as_bytes().as_ptr()
}

/// Get the current plugin's associated host context data. Returns null if call was made without
/// host context.
#[no_mangle]
pub unsafe extern "C" fn extism_current_plugin_host_context(
plugin: *mut CurrentPlugin,
) -> *mut std::ffi::c_void {
if plugin.is_null() {
return std::ptr::null_mut();
}

let plugin = &mut *plugin;
if let Ok(CVoidContainer(ptr)) = plugin.host_context::<CVoidContainer>() {
ptr
} else {
std::ptr::null_mut()
}
}

/// Returns a pointer to the memory of the currently running plugin
/// NOTE: this should only be called from host functions.
#[no_mangle]
Expand Down Expand Up @@ -464,6 +482,31 @@ pub unsafe extern "C" fn extism_plugin_call(
func_name: *const c_char,
data: *const u8,
data_len: Size,
) -> i32 {
extism_plugin_call_with_host_context(plugin, func_name, data, data_len, std::ptr::null_mut())
}

#[derive(Clone)]
#[repr(transparent)]
struct CVoidContainer(*mut std::ffi::c_void);

// "You break it, you buy it."
unsafe impl Send for CVoidContainer {}
unsafe impl Sync for CVoidContainer {}

/// Call a function with host context.
///
/// `func_name`: is the function to call
/// `data`: is the input data
/// `data_len`: is the length of `data`
/// `host_context`: a pointer to context data that will be available in host functions
#[no_mangle]
pub unsafe extern "C" fn extism_plugin_call_with_host_context(
plugin: *mut Plugin,
func_name: *const c_char,
data: *const u8,
data_len: Size,
host_context: *mut std::ffi::c_void,
) -> i32 {
if plugin.is_null() {
return -1;
Expand All @@ -485,7 +528,12 @@ pub unsafe extern "C" fn extism_plugin_call(
name
);
let input = std::slice::from_raw_parts(data, data_len as usize);
let res = plugin.raw_call(&mut lock, name, input);
let res = plugin.raw_call(
&mut lock,
name,
input,
Some(ExternRef::new(CVoidContainer(host_context))),
);

match res {
Err((e, rc)) => plugin.return_error(&mut lock, e, rc),
Expand Down
36 changes: 36 additions & 0 deletions runtime/src/tests/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,42 @@ fn test_toml_manifest() {
assert_eq!(count.get("count").unwrap().as_i64().unwrap(), 1);
}

#[test]
fn test_call_with_host_context() {
#[derive(Clone)]
struct Foo {
message: String,
}

let f = Function::new(
"host_reflect",
[PTR],
[PTR],
UserData::default(),
|current_plugin, _val, ret, _user_data: UserData<()>| {
let foo = current_plugin.host_context::<Foo>()?;
let hnd = current_plugin.memory_new(foo.message)?;
ret[0] = current_plugin.memory_to_val(hnd);
Ok(())
},
);

let mut plugin = Plugin::new(WASM_REFLECT, [f], true).unwrap();

let message = "hello world";
let output: String = plugin
.call_with_host_context(
"reflect",
"anything, really",
Foo {
message: message.to_string(),
},
)
.unwrap();

assert_eq!(output, message);
}

#[test]
fn test_fuzz_reflect_plugin() {
// assert!(set_log_file("stdout", Some(log::Level::Trace)));
Expand Down

0 comments on commit 5d9c8c5

Please sign in to comment.