Skip to content

Commit

Permalink
add llg_par_compute_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Nov 12, 2024
1 parent 4de768e commit ae0c43f
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 6 deletions.
10 changes: 6 additions & 4 deletions parser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ jsonschema = { version = "0.24.0", default-features = false, optional = true }
url = "2.5.2"
lazy_static = { version = "1.5.0", optional = true }
regex-syntax = "0.8.5"
rayon = { version = "1.10.0", optional = true }

[features]
default = ["jsonschema_validation", "lark"]
logging = [] # this is extensive debug logging
lark = [] # ~115k (binary)
jsonschema_validation = ["jsonschema", "lazy_static"] # ~2.5M (binary)
default = ["jsonschema_validation", "lark", "rayon"]
logging = [] # this is extensive debug logging
lark = [] # ~115k (binary)
jsonschema_validation = ["dep:jsonschema", "dep:lazy_static"] # ~2.5M (binary)
rayon = ["dep:rayon"]

[lib]
crate-type = ["staticlib", "rlib", "cdylib"]
Expand Down
26 changes: 26 additions & 0 deletions parser/llguidance.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

typedef struct LlgConstraint LlgConstraint;

typedef struct LlgConstraintStep LlgConstraintStep;

typedef struct LlgTokenizer LlgTokenizer;

typedef struct LlgParserLimits {
Expand Down Expand Up @@ -108,6 +110,11 @@ typedef struct LlgCommitResult {
bool is_stop;
} LlgCommitResult;

/**
* Function which llg calls when an operation is done.
*/
typedef void (*LlgCallback)(const void *user_data);

/**
* Tokenization function
* Will not write more than output_tokens_len tokens (which can be 0)
Expand Down Expand Up @@ -226,6 +233,17 @@ struct LlgConstraint *llg_new_constraint_any(const struct LlgConstraintInit *ini
*/
const char *llg_get_error(const struct LlgConstraint *cc);

/**
* Get the current temperature of the constraint.
* It is updated by mask computation.
*/
float llg_get_temperature(const struct LlgConstraint *cc);

/**
* Check if constraint is stopped (cannot be extended further).
*/
bool llg_is_stopped(const struct LlgConstraint *cc);

/**
* Compute mask for the next token sampling
* It typically takes up to a millisecond for a 100k tokenizer, so should be called in background.
Expand All @@ -242,6 +260,14 @@ int32_t llg_compute_mask(struct LlgConstraint *cc, struct LlgMaskResult *res_p);
*/
int32_t llg_commit_token(struct LlgConstraint *cc, LlgToken token, struct LlgCommitResult *res_p);

/**
* Compute mask for several constraints in parallel.
*/
void llg_par_compute_mask(const struct LlgConstraintStep *steps,
size_t n_steps,
const void *user_data,
LlgCallback done_cb);

/**
* Clone the constraint
*/
Expand Down
69 changes: 67 additions & 2 deletions parser/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ pub type LlgTokenizeFn = Option<
) -> usize,
>;

/// Function which llg calls when an operation is done.
pub type LlgCallback = Option<extern "C" fn(user_data: *const c_void)>;

#[repr(C)]
pub struct LlgTokenizerInit {
/// The number of tokens in the vocabulary
Expand Down Expand Up @@ -210,13 +213,35 @@ pub struct LlgConstraintInit {
}

#[derive(Clone)]
pub struct LlgConstraintStep {
/// The constraint to compute mask for.
pub constraint: *mut LlgConstraint,
/// Pointer to memory where the mask should be written.
pub mask_dest: *mut u32,
/// The length of the mask_dest array in bytes (not elements).
pub mask_byte_len: usize,
}

unsafe impl Send for LlgConstraintStep {}

pub struct LlgConstraint {
local_error: Option<String>,
last_logs: String,
constraint: Option<Constraint>,
pub(crate) constraint: Option<Constraint>,
last_commit_result: CommitResult,
}

impl Clone for LlgConstraint {
fn clone(&self) -> Self {
LlgConstraint {
local_error: self.local_error.clone(),
last_logs: self.last_logs.clone(),
constraint: self.constraint.clone(),
last_commit_result: self.last_commit_result.clone(),
}
}
}

impl Default for LlgConstraint {
fn default() -> Self {
LlgConstraint {
Expand Down Expand Up @@ -360,7 +385,7 @@ impl LlgConstraint {
}
}

fn set_error(&mut self, e: &str) {
pub(crate) fn set_error(&mut self, e: &str) {
self.constraint = None;
self.local_error = Some(format!("{e}\0"));
}
Expand Down Expand Up @@ -456,6 +481,21 @@ pub extern "C" fn llg_get_error(cc: &LlgConstraint) -> *const c_char {
cc.get_error()
}

/// Get the current temperature of the constraint.
/// It is updated by mask computation.
#[no_mangle]
pub extern "C" fn llg_get_temperature(cc: &LlgConstraint) -> f32 {
cc.constraint.as_ref().map_or(0.0, |c| c.temperature)
}

/// Check if constraint is stopped (cannot be extended further).
#[no_mangle]
pub extern "C" fn llg_is_stopped(cc: &LlgConstraint) -> bool {
cc.constraint
.as_ref()
.map_or(true, |c| c.step_result().is_stop())
}

/// Compute mask for the next token sampling
/// It typically takes up to a millisecond for a 100k tokenizer, so should be called in background.
/// Returns 0 on success and -1 on error (use llg_get_error() to get the exact error).
Expand Down Expand Up @@ -511,6 +551,31 @@ pub extern "C" fn llg_commit_token(
cc.get_error_code()
}

/// Compute mask for several constraints in parallel.
#[no_mangle]
pub extern "C" fn llg_par_compute_mask(
steps: *const LlgConstraintStep,
n_steps: usize,
user_data: *const c_void,
done_cb: LlgCallback,
) {
if steps.is_null() {
panic!("llg_par_compute_mask: steps is null");
}

#[cfg(feature = "rayon")]
{
let steps = unsafe { std::slice::from_raw_parts(steps, n_steps).to_vec() };
crate::ffi_par::par_compute_mask(steps, user_data, done_cb);
}

#[cfg(not(feature = "rayon"))]
{
let _ = (steps, n_steps, user_data, done_cb);
panic!("llg_par_compute_mask: rayon feature is not enabled");
}
}

/// Clone the constraint
#[no_mangle]
pub extern "C" fn llg_clone_constraint(cc: &LlgConstraint) -> *mut LlgConstraint {
Expand Down
64 changes: 64 additions & 0 deletions parser/src/ffi_par.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::ffi::c_void;

use crate::ffi::{LlgCallback, LlgConstraintStep};

fn par_compute_mask_inner(constraints: Vec<LlgConstraintStep>) {
use rayon::prelude::*;
constraints.into_par_iter().for_each(|step| {
assert!(step.mask_byte_len % 4 == 0);
assert!(!step.mask_dest.is_null());
let mask_elts = step.mask_byte_len / 4;

let cc = unsafe { &mut *step.constraint };
if let Some(constraint) = &mut cc.constraint {
match constraint.compute_mask() {
Ok(r) => {
let mut num_copied = 0;
if let Some(m) = r.sample_mask.as_ref() {
num_copied = std::cmp::min(m.len(), mask_elts);
unsafe {
std::ptr::copy_nonoverlapping(m.as_ptr(), step.mask_dest, num_copied);
}
}
let left = mask_elts - num_copied;
if left > 0 {
unsafe {
std::ptr::write_bytes(step.mask_dest.add(num_copied), 0, left);
}
}
if r.is_stop() {
let eos = constraint.tok_trie().eos_token() as usize;
if eos / 32 < mask_elts {
unsafe {
*step.mask_dest.add(eos / 32) |= 1 << (eos % 32);
}
}
}
}
Err(e) => cc.set_error(&e.to_string()),
}
}
});
}

pub(crate) fn par_compute_mask(
constraints: Vec<LlgConstraintStep>,
user_data: *const c_void,
done_cb: LlgCallback,
) {
struct CbData {
user_data: *const c_void,
}
unsafe impl Send for CbData {}

if let Some(cb) = done_cb {
let ptr = CbData { user_data };
rayon::spawn(move || {
par_compute_mask_inner(constraints);
cb(ptr.user_data);
drop(ptr);
});
} else {
par_compute_mask_inner(constraints);
}
}
2 changes: 2 additions & 0 deletions parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub use logging::Logger;
pub use derivre;

pub mod ffi;
#[cfg(feature = "rayon")]
mod ffi_par;

mod grammar_builder;
mod json;
Expand Down

0 comments on commit ae0c43f

Please sign in to comment.