From ae0c43fc302a44ba8ea0305fbdd4420f09a44861 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 12 Nov 2024 10:56:51 -0800 Subject: [PATCH] add llg_par_compute_mask --- parser/Cargo.toml | 10 ++++--- parser/llguidance.h | 26 ++++++++++++++++ parser/src/ffi.rs | 69 +++++++++++++++++++++++++++++++++++++++++-- parser/src/ffi_par.rs | 64 +++++++++++++++++++++++++++++++++++++++ parser/src/lib.rs | 2 ++ 5 files changed, 165 insertions(+), 6 deletions(-) create mode 100644 parser/src/ffi_par.rs diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 1d7a763..e9c581e 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -15,12 +15,14 @@ jsonschema = { version = "0.24.0", default-features = false, optional = true } url = "2.5.2" lazy_static = { version = "1.5.0", optional = true } regex-syntax = "0.8.5" +rayon = { version = "1.10.0", optional = true } [features] -default = ["jsonschema_validation", "lark"] -logging = [] # this is extensive debug logging -lark = [] # ~115k (binary) -jsonschema_validation = ["jsonschema", "lazy_static"] # ~2.5M (binary) +default = ["jsonschema_validation", "lark", "rayon"] +logging = [] # this is extensive debug logging +lark = [] # ~115k (binary) +jsonschema_validation = ["dep:jsonschema", "dep:lazy_static"] # ~2.5M (binary) +rayon = ["dep:rayon"] [lib] crate-type = ["staticlib", "rlib", "cdylib"] diff --git a/parser/llguidance.h b/parser/llguidance.h index 6646c1c..b26420e 100644 --- a/parser/llguidance.h +++ b/parser/llguidance.h @@ -9,6 +9,8 @@ typedef struct LlgConstraint LlgConstraint; +typedef struct LlgConstraintStep LlgConstraintStep; + typedef struct LlgTokenizer LlgTokenizer; typedef struct LlgParserLimits { @@ -108,6 +110,11 @@ typedef struct LlgCommitResult { bool is_stop; } LlgCommitResult; +/** + * Function which llg calls when an operation is done. + */ +typedef void (*LlgCallback)(const void *user_data); + /** * Tokenization function * Will not write more than output_tokens_len tokens (which can be 0) @@ -226,6 +233,17 @@ struct LlgConstraint *llg_new_constraint_any(const struct LlgConstraintInit *ini */ const char *llg_get_error(const struct LlgConstraint *cc); +/** + * Get the current temperature of the constraint. + * It is updated by mask computation. + */ +float llg_get_temperature(const struct LlgConstraint *cc); + +/** + * Check if constraint is stopped (cannot be extended further). + */ +bool llg_is_stopped(const struct LlgConstraint *cc); + /** * Compute mask for the next token sampling * It typically takes up to a millisecond for a 100k tokenizer, so should be called in background. @@ -242,6 +260,14 @@ int32_t llg_compute_mask(struct LlgConstraint *cc, struct LlgMaskResult *res_p); */ int32_t llg_commit_token(struct LlgConstraint *cc, LlgToken token, struct LlgCommitResult *res_p); +/** + * Compute mask for several constraints in parallel. + */ +void llg_par_compute_mask(const struct LlgConstraintStep *steps, + size_t n_steps, + const void *user_data, + LlgCallback done_cb); + /** * Clone the constraint */ diff --git a/parser/src/ffi.rs b/parser/src/ffi.rs index 7cd1a36..5d65ad9 100644 --- a/parser/src/ffi.rs +++ b/parser/src/ffi.rs @@ -149,6 +149,9 @@ pub type LlgTokenizeFn = Option< ) -> usize, >; +/// Function which llg calls when an operation is done. +pub type LlgCallback = Option; + #[repr(C)] pub struct LlgTokenizerInit { /// The number of tokens in the vocabulary @@ -210,13 +213,35 @@ pub struct LlgConstraintInit { } #[derive(Clone)] +pub struct LlgConstraintStep { + /// The constraint to compute mask for. + pub constraint: *mut LlgConstraint, + /// Pointer to memory where the mask should be written. + pub mask_dest: *mut u32, + /// The length of the mask_dest array in bytes (not elements). + pub mask_byte_len: usize, +} + +unsafe impl Send for LlgConstraintStep {} + pub struct LlgConstraint { local_error: Option, last_logs: String, - constraint: Option, + pub(crate) constraint: Option, last_commit_result: CommitResult, } +impl Clone for LlgConstraint { + fn clone(&self) -> Self { + LlgConstraint { + local_error: self.local_error.clone(), + last_logs: self.last_logs.clone(), + constraint: self.constraint.clone(), + last_commit_result: self.last_commit_result.clone(), + } + } +} + impl Default for LlgConstraint { fn default() -> Self { LlgConstraint { @@ -360,7 +385,7 @@ impl LlgConstraint { } } - fn set_error(&mut self, e: &str) { + pub(crate) fn set_error(&mut self, e: &str) { self.constraint = None; self.local_error = Some(format!("{e}\0")); } @@ -456,6 +481,21 @@ pub extern "C" fn llg_get_error(cc: &LlgConstraint) -> *const c_char { cc.get_error() } +/// Get the current temperature of the constraint. +/// It is updated by mask computation. +#[no_mangle] +pub extern "C" fn llg_get_temperature(cc: &LlgConstraint) -> f32 { + cc.constraint.as_ref().map_or(0.0, |c| c.temperature) +} + +/// Check if constraint is stopped (cannot be extended further). +#[no_mangle] +pub extern "C" fn llg_is_stopped(cc: &LlgConstraint) -> bool { + cc.constraint + .as_ref() + .map_or(true, |c| c.step_result().is_stop()) +} + /// Compute mask for the next token sampling /// It typically takes up to a millisecond for a 100k tokenizer, so should be called in background. /// Returns 0 on success and -1 on error (use llg_get_error() to get the exact error). @@ -511,6 +551,31 @@ pub extern "C" fn llg_commit_token( cc.get_error_code() } +/// Compute mask for several constraints in parallel. +#[no_mangle] +pub extern "C" fn llg_par_compute_mask( + steps: *const LlgConstraintStep, + n_steps: usize, + user_data: *const c_void, + done_cb: LlgCallback, +) { + if steps.is_null() { + panic!("llg_par_compute_mask: steps is null"); + } + + #[cfg(feature = "rayon")] + { + let steps = unsafe { std::slice::from_raw_parts(steps, n_steps).to_vec() }; + crate::ffi_par::par_compute_mask(steps, user_data, done_cb); + } + + #[cfg(not(feature = "rayon"))] + { + let _ = (steps, n_steps, user_data, done_cb); + panic!("llg_par_compute_mask: rayon feature is not enabled"); + } +} + /// Clone the constraint #[no_mangle] pub extern "C" fn llg_clone_constraint(cc: &LlgConstraint) -> *mut LlgConstraint { diff --git a/parser/src/ffi_par.rs b/parser/src/ffi_par.rs new file mode 100644 index 0000000..8482978 --- /dev/null +++ b/parser/src/ffi_par.rs @@ -0,0 +1,64 @@ +use std::ffi::c_void; + +use crate::ffi::{LlgCallback, LlgConstraintStep}; + +fn par_compute_mask_inner(constraints: Vec) { + use rayon::prelude::*; + constraints.into_par_iter().for_each(|step| { + assert!(step.mask_byte_len % 4 == 0); + assert!(!step.mask_dest.is_null()); + let mask_elts = step.mask_byte_len / 4; + + let cc = unsafe { &mut *step.constraint }; + if let Some(constraint) = &mut cc.constraint { + match constraint.compute_mask() { + Ok(r) => { + let mut num_copied = 0; + if let Some(m) = r.sample_mask.as_ref() { + num_copied = std::cmp::min(m.len(), mask_elts); + unsafe { + std::ptr::copy_nonoverlapping(m.as_ptr(), step.mask_dest, num_copied); + } + } + let left = mask_elts - num_copied; + if left > 0 { + unsafe { + std::ptr::write_bytes(step.mask_dest.add(num_copied), 0, left); + } + } + if r.is_stop() { + let eos = constraint.tok_trie().eos_token() as usize; + if eos / 32 < mask_elts { + unsafe { + *step.mask_dest.add(eos / 32) |= 1 << (eos % 32); + } + } + } + } + Err(e) => cc.set_error(&e.to_string()), + } + } + }); +} + +pub(crate) fn par_compute_mask( + constraints: Vec, + user_data: *const c_void, + done_cb: LlgCallback, +) { + struct CbData { + user_data: *const c_void, + } + unsafe impl Send for CbData {} + + if let Some(cb) = done_cb { + let ptr = CbData { user_data }; + rayon::spawn(move || { + par_compute_mask_inner(constraints); + cb(ptr.user_data); + drop(ptr); + }); + } else { + par_compute_mask_inner(constraints); + } +} diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 2dd37dc..faf756c 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -21,6 +21,8 @@ pub use logging::Logger; pub use derivre; pub mod ffi; +#[cfg(feature = "rayon")] +mod ffi_par; mod grammar_builder; mod json;