feat: per call context #711

Merged
merged 6 commits into from
May 21, 2024
Changes from 1 commit
5 changes: 3 additions & 2 deletions kernel/build.sh
@@ -17,6 +17,7 @@ done

cargo build --package extism-runtime-kernel --bin extism-runtime --release --target wasm32-unknown-unknown $CARGO_FLAGS
cp target/wasm32-unknown-unknown/release/extism-runtime.wasm .
- wasm-strip extism-runtime.wasm
- mv extism-runtime.wasm ../runtime/src/extism-runtime.wasm

+ wasm-tools parse extism-context.wat -o extism-context.wasm
+ wasm-merge --enable-reference-types ./extism-runtime.wasm runtime extism-context.wasm context -o ../runtime/src/extism-runtime.wasm
+ wasm-strip ../runtime/src/extism-runtime.wasm
3 changes: 3 additions & 0 deletions kernel/extism-context.wat
@@ -0,0 +1,3 @@
(module
(global (export "extism_context") (mut externref) (ref.null extern))
Contributor Author

We could add a layer of protection around this by saying that "we don't accept any Wasm module that imports extism_context." (This would protect against a malicious plugin that imports this externref, holds onto it across two calls, replacing the global value of the second call with the value from the first call.)

Contributor

Yeah, my first thought is that we don't want people to be able to poke this and use some other plugin-runner's context.

Contributor

Returning an error if a plugin tries to import the global directly sounds like a good idea. I also noticed that the context global gets reset before each call regardless of whether call_with_context is used, which makes me confident that a plugin couldn't access the context from another call.

Contributor Author

Thanks – added that in 68da28b.

This commit also prevents linked modules from importing the extism kernel's memory directly (which I think would lead to similar problems). I don't know of anyone doing that today (it would certainly seem to void the warranty label), but I can narrow that check to just the context if we're worried!

)
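
The import check discussed in the thread above could look roughly like the following. This is a minimal, std-only sketch and not the code from 68da28b: the `Import` type and `reject_context_imports` function are hypothetical, and the kernel's module namespace is left as a parameter because the diff does not show its value. In the real runtime the import list would presumably come from the compiled module (for example via wasmtime's `Module::imports()`).

```rust
/// One entry in a module's import section (hypothetical model type).
#[derive(Debug, Clone, PartialEq)]
pub struct Import {
    pub module: String,
    pub name: String,
}

/// Name of the kernel global, as exported in extism-context.wat.
pub const CONTEXT_GLOBAL: &str = "extism_context";

/// Returns an error for the first import that targets the kernel's context
/// global. `kernel_module` is the namespace the kernel is linked under; it
/// is a parameter here because its value is an assumption, not shown above.
pub fn reject_context_imports(kernel_module: &str, imports: &[Import]) -> Result<(), String> {
    for import in imports {
        if import.module == kernel_module && import.name == CONTEXT_GLOBAL {
            return Err(format!(
                "module may not import {}::{} directly",
                import.module, import.name
            ));
        }
    }
    Ok(())
}
```

Running a check like this before linking would stop a plugin from holding the externref across calls, which is the scenario the first comment describes.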
20 changes: 20 additions & 0 deletions runtime/extism.h
@@ -113,6 +113,12 @@ extern "C" {
*/
const uint8_t *extism_plugin_id(ExtismPlugin *plugin);

/**
* Get the current plugin's associated context data. Returns null if call was made without
* context.
*/
void *extism_current_plugin_context(ExtismCurrentPlugin *plugin);

/**
* Returns a pointer to the memory of the currently running plugin
* NOTE: this should only be called from host functions.
@@ -231,6 +237,20 @@ int32_t extism_plugin_call(ExtismPlugin *plugin,
const uint8_t *data,
ExtismSize data_len);

/**
* Call a function with per-call context.
*
* `func_name`: is the function to call
* `data`: is the input data
* `data_len`: is the length of `data`
* `context`: a pointer to context data that will be available in host functions
*/
int32_t extism_plugin_call_with_context(ExtismPlugin *plugin,
const char *func_name,
const uint8_t *data,
ExtismSize data_len,
void *context);

/**
* Get the error associated with a `Plugin`
*/
17 changes: 17 additions & 0 deletions runtime/src/current_plugin.rs
@@ -186,6 +186,23 @@ impl CurrentPlugin {
anyhow::bail!("{} unable to locate extism memory", self.id)
}

pub fn context<T: Clone + 'static>(&mut self) -> Result<T, Error> {
let (linker, mut store) = self.linker_and_store();
let Some(Extern::Global(xs)) = linker.get(&mut store, EXTISM_ENV_MODULE, "extism_context")
else {
anyhow::bail!("unable to locate an extism kernel global: extism_context",)
};

let Val::ExternRef(Some(xs)) = xs.get(store) else {
anyhow::bail!("expected extism_context to be an externref value",)
};

match xs.data().downcast_ref::<T>().cloned() {
Some(xs) => Ok(xs.clone()),
None => anyhow::bail!("could not downcast extism_context",),
}
}

pub fn memory_alloc(&mut self, n: u64) -> Result<MemoryHandle, Error> {
if n == 0 {
return Ok(MemoryHandle {
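
The typed lookup in `CurrentPlugin::context` relies on `std::any::Any` downcasting. The following std-only sketch isolates that idea. `ContextSlot` is a hypothetical stand-in for the `externref` global, and the error strings are illustrative rather than the runtime's own.

```rust
use std::any::Any;

/// Hypothetical stand-in for the per-call slot; in the diff this role is
/// played by the `extism_context` externref global.
pub struct ContextSlot(Option<Box<dyn Any + Send + Sync>>);

impl ContextSlot {
    /// A slot for a call made without context.
    pub fn empty() -> Self {
        ContextSlot(None)
    }

    /// A slot holding a type-erased value.
    pub fn with<C: Any + Send + Sync>(value: C) -> Self {
        ContextSlot(Some(Box::new(value)))
    }

    /// Clone the stored value out if it has exactly type `T`.
    pub fn get<T: Clone + 'static>(&self) -> Result<T, String> {
        let Some(any) = self.0.as_deref() else {
            return Err("no context was set for this call".to_string());
        };
        match any.downcast_ref::<T>() {
            Some(value) => Ok(value.clone()),
            None => Err("could not downcast context".to_string()),
        }
    }
}
```

Asking for the wrong type is an error rather than a panic, which matches the `Result<T, Error>` return type in the diff.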
Binary file modified runtime/src/extism-runtime.wasm
Binary file not shown.
46 changes: 42 additions & 4 deletions runtime/src/plugin.rs
@@ -1,4 +1,5 @@
use std::{
any::Any,
collections::{BTreeMap, BTreeSet},
path::PathBuf,
};
@@ -462,7 +463,12 @@ impl Plugin {
}

// Store input in memory and re-initialize `Internal` pointer
- pub(crate) fn set_input(&mut self, input: *const u8, mut len: usize) -> Result<(), Error> {
+ pub(crate) fn set_input(
+ &mut self,
+ input: *const u8,
+ mut len: usize,
+ call_context: Option<ExternRef>,
+ ) -> Result<(), Error> {
self.output = Output::default();
self.clear_error()?;
let id = self.id.to_string();
@@ -496,6 +502,13 @@ impl Plugin {
)?;
}

if let Some(Extern::Global(ctxt)) =
self.linker
.get(&mut self.store, EXTISM_ENV_MODULE, "extism_context")
{
ctxt.set(&mut self.store, Val::ExternRef(call_context))?;
}

Ok(())
}

@@ -673,6 +686,7 @@ impl Plugin {
lock: &mut std::sync::MutexGuard<Option<Instance>>,
name: impl AsRef<str>,
input: impl AsRef<[u8]>,
call_context: Option<ExternRef>,
) -> Result<i32, (Error, i32)> {
let name = name.as_ref();
let input = input.as_ref();
@@ -686,7 +700,7 @@ impl Plugin {

self.instantiate(lock).map_err(|e| (e, -1))?;

- self.set_input(input.as_ptr(), input.len())
+ self.set_input(input.as_ptr(), input.len(), call_context)
.map_err(|x| (x, -1))?;

let func = match self.get_func(lock, name) {
@@ -873,11 +887,35 @@ impl Plugin {
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
let data = input.to_bytes()?;
- self.raw_call(&mut lock, name, data)
+ self.raw_call(&mut lock, name, data, None)
.map_err(|e| e.0)
.and_then(move |_| self.output())
}

pub fn call_with_context<'a, 'b, T, U, C>(
&'b mut self,
name: impl AsRef<str>,
input: T,
context: C,
) -> Result<U, Error>
where
T: ToBytes<'a>,
U: FromBytes<'b>,
C: Any + Send + Sync + 'static,
{
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
let data = input.to_bytes()?;
self.raw_call(&mut lock, name, data, Some(ExternRef::new(context)))
.map_err(|e| e.0)
.and_then(move |_| self.output())
}

pub fn invoke(&mut self, name: &str, args: &[wasmtime::Val]) {
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
}
chrisdickinson marked this conversation as resolved.

/// Similar to `Plugin::call`, but returns the Extism error code along with the
/// `Error`. It is assumed if `Ok(_)` is returned that the error code was `0`.
///
@@ -892,7 +930,7 @@ impl Plugin {
let lock = self.instance.clone();
let mut lock = lock.lock().unwrap();
let data = input.to_bytes().map_err(|e| (e, -1))?;
- self.raw_call(&mut lock, name, data)
+ self.raw_call(&mut lock, name, data, None)
.and_then(move |_| self.output().map_err(|e| (e, -1)))
}
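
Taken together, the plugin.rs changes mean the context slot is overwritten at the start of every call, so a plain `call` clears whatever an earlier `call_with_context` stored. The std-only model below sketches that behaviour under simplifying assumptions: `ModelPlugin`, `begin_call`, and `host_function` are hypothetical names, and the Wasm instance, input, and output handling are omitted entirely.

```rust
use std::any::Any;
use std::sync::Arc;

type HostContext = Arc<dyn Any + Send + Sync>;

/// Hypothetical, heavily simplified plugin: only the context slot is modeled.
#[derive(Default)]
pub struct ModelPlugin {
    context: Option<HostContext>,
}

impl ModelPlugin {
    /// Plays the role of `set_input` in the diff: the slot is overwritten on
    /// every call, so `None` clears anything left by the previous call.
    fn begin_call(&mut self, context: Option<HostContext>) {
        self.context = context;
    }

    /// A call without context, like `Plugin::call` passing `None`.
    pub fn call(&mut self) -> Option<String> {
        self.begin_call(None);
        self.host_function()
    }

    /// A call with context, like `Plugin::call_with_context`.
    pub fn call_with_context<C: Any + Send + Sync>(&mut self, context: C) -> Option<String> {
        self.begin_call(Some(Arc::new(context)));
        self.host_function()
    }

    /// Stand-in for a host function that expects a `String` context.
    fn host_function(&self) -> Option<String> {
        self.context.as_ref()?.downcast_ref::<String>().cloned()
    }
}
```

In this model a call made after `call_with_context` sees no context unless it supplies its own, which is the property the review thread relies on.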

50 changes: 49 additions & 1 deletion runtime/src/sdk.rs
@@ -84,6 +84,24 @@ pub unsafe extern "C" fn extism_plugin_id(plugin: *mut Plugin) -> *const u8 {
plugin.id.as_bytes().as_ptr()
}

/// Get the current plugin's associated context data. Returns null if call was made without
/// context.
#[no_mangle]
pub unsafe extern "C" fn extism_current_plugin_context(
plugin: *mut CurrentPlugin,
) -> *mut std::ffi::c_void {
if plugin.is_null() {
return std::ptr::null_mut();
}

let plugin = &mut *plugin;
if let Ok(CVoidContainer(ptr)) = plugin.context::<CVoidContainer>() {
ptr
} else {
std::ptr::null_mut()
}
}

/// Returns a pointer to the memory of the currently running plugin
/// NOTE: this should only be called from host functions.
#[no_mangle]
@@ -464,6 +482,31 @@ pub unsafe extern "C" fn extism_plugin_call(
func_name: *const c_char,
data: *const u8,
data_len: Size,
) -> i32 {
extism_plugin_call_with_context(plugin, func_name, data, data_len, std::ptr::null_mut())
}

#[derive(Clone)]
#[repr(transparent)]
struct CVoidContainer(*mut std::ffi::c_void);

// "You break it, you buy it."
unsafe impl Send for CVoidContainer {}
unsafe impl Sync for CVoidContainer {}

/// Call a function with per-call context.
///
/// `func_name`: is the function to call
/// `data`: is the input data
/// `data_len`: is the length of `data`
/// `context`: a pointer to context data that will be available in host functions
#[no_mangle]
pub unsafe extern "C" fn extism_plugin_call_with_context(
plugin: *mut Plugin,
func_name: *const c_char,
data: *const u8,
data_len: Size,
context: *mut std::ffi::c_void,
) -> i32 {
if plugin.is_null() {
return -1;
@@ -485,7 +528,12 @@ pub unsafe extern "C" fn extism_plugin_call(
name
);
let input = std::slice::from_raw_parts(data, data_len as usize);
- let res = plugin.raw_call(&mut lock, name, input);
+ let res = plugin.raw_call(
+ &mut lock,
+ name,
+ input,
+ Some(ExternRef::new(CVoidContainer(context))),
+ );

match res {
Err((e, rc)) => plugin.return_error(&mut lock, e, rc),
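
The C entry points pass the caller's `void *` through the same type-erased slot by wrapping it in a newtype that is declared `Send` and `Sync`. The std-only sketch below shows that pattern in isolation. `RawContext`, `store_raw`, and `load_raw` are hypothetical names, and the safety contract in the comment is an assumption about what the host must guarantee.

```rust
use std::any::Any;
use std::ffi::c_void;

/// Newtype so a raw pointer can live behind `dyn Any + Send + Sync`.
#[derive(Clone, Copy)]
#[repr(transparent)]
struct RawContext(*mut c_void);

// SAFETY (assumed contract): the host that supplied the pointer is
// responsible for its validity and for any cross-thread use.
unsafe impl Send for RawContext {}
unsafe impl Sync for RawContext {}

/// Wrap a caller-supplied pointer for storage in a type-erased slot.
pub fn store_raw(ptr: *mut c_void) -> Box<dyn Any + Send + Sync> {
    Box::new(RawContext(ptr))
}

/// Recover the pointer, or null if the slot is empty or holds another type.
pub fn load_raw(slot: Option<&(dyn Any + Send + Sync)>) -> *mut c_void {
    slot.and_then(|any| any.downcast_ref::<RawContext>())
        .map(|raw| raw.0)
        .unwrap_or(std::ptr::null_mut())
}
```

Falling back to null on a failed downcast mirrors `extism_current_plugin_context`, which returns a null pointer whenever the typed lookup does not succeed.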
37 changes: 37 additions & 0 deletions runtime/src/tests/runtime.rs
@@ -307,6 +307,43 @@ fn test_toml_manifest() {
assert_eq!(count.get("count").unwrap().as_i64().unwrap(), 1);
}

#[test]
fn test_call_with_context() {
// assert!(set_log_file("stdout", Some(log::Level::Trace)));
#[derive(Clone)]
struct Foo {
message: String,
}

let f = Function::new(
"host_reflect",
[PTR],
[PTR],
UserData::default(),
|current_plugin, _val, ret, _user_data: UserData<()>| {
let foo = current_plugin.context::<Foo>()?;
let hnd = current_plugin.memory_new(foo.message)?;
ret[0] = current_plugin.memory_to_val(hnd);
Ok(())
},
);

let mut plugin = Plugin::new(WASM_REFLECT, [f], true).unwrap();

let message = "hello world";
let output: String = plugin
.call_with_context(
"reflect",
"anything, really",
Contributor

Hah, this was my initial thought: "then what's the input?" But I realized that the context value is only available in host functions (assuming we forbid access to the actual global somehow), which feels like a pretty powerful combination. This seems like a nice stepping-stone to allowing externref arguments while compilers are still catching up!

There could be some confusion around what to pass as input vs. context (especially since we used to have a Context type that meant something totally different).

Contributor

Maybe something like HostContext or HostParam would be clearer; open to other ideas too!

Contributor Author

Good point – I renamed to host_context / HostContext in the applicable places.

Foo {
message: message.to_string(),
},
)
.unwrap();

assert_eq!(output, message);
}

#[test]
fn test_fuzz_reflect_plugin() {
// assert!(set_log_file("stdout", Some(log::Level::Trace)));