Merge branch 'main' into json_extension
hudson-ai committed Nov 19, 2024
2 parents 5a884e5 + 7a98040 commit d822980
Showing 26 changed files with 909 additions and 114 deletions.
29 changes: 13 additions & 16 deletions README.md
@@ -3,8 +3,8 @@
This library implements constrained decoding (also called constrained sampling or
structured outputs) for Large Language Models (LLMs).
It can enforce an arbitrary context-free grammar on the output of an LLM
and is fast (on the order of 1ms of CPU time per token
(for 100k tokenizer) with negligible startup costs).
and is fast - on the order of 1ms of CPU time per token
(for 100k tokenizer) with negligible startup costs.

The following grammar formats are supported:
- `llguidance` - [internal (JSON-based) format](./parser/src/api.rs)
@@ -36,20 +36,17 @@ The integration is ongoing in:

## Technical details

Given a context-free grammar, a tokenizer, and prefix of tokens,
llguidance computes a token mask (set of tokens from the tokenizer)
that when added to current prefix of token can lead to a valid string in
the language of the grammar.
Computing a mask takes on the order of 1ms of single-core CPU time
for a tokenizer with 100k tokens.
While this depends on the exact grammar, it holds eg. for grammars resulting from JSON schemas.
There is also no significant startup cost.

The library implements a context-free grammar parser with Earley's algorithm
on top of a lexer which uses [derivatives of regular expressions](https://github.com/microsoft/derivre).
A lot of
[low-level optimizations](https://github.com/microsoft/toktrie/blob/main/implementation.md)
are implemented.
Given a context-free grammar, a tokenizer, and a prefix of tokens, llguidance computes a token mask - a set of tokens from the tokenizer - that, when added to the current token prefix, can lead to a valid string in the language defined by the grammar. Mask computation takes approximately 1ms of single-core CPU time for a tokenizer with 100k tokens. While this timing depends on the exact grammar, it holds, for example, for grammars derived from JSON schemas. There is no significant startup cost.

The library implements a context-free grammar parser using Earley’s algorithm on top of a lexer based on [derivatives of regular expressions](https://github.com/microsoft/derivre). Mask computation is achieved by traversing the prefix tree (trie) of all possible tokens, leveraging [highly optimized](https://github.com/microsoft/toktrie/blob/main/implementation.md) code.
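
As a deliberately simplified sketch of that traversal (illustrative names only - `TrieNode` and `ByteParser` are not the crate's actual types): descend the token trie byte by byte, prune branches the parser rejects, and mark every complete token reached along an accepted path.

```rust
/// Illustrative token-trie node: one byte per edge; a node carries a
/// token id if the bytes from the root spell a complete token.
struct TrieNode {
    byte: u8,
    token_id: Option<usize>,
    children: Vec<TrieNode>,
}

/// Minimal parser interface: try to consume one more byte, or backtrack.
trait ByteParser {
    fn try_push_byte(&mut self, b: u8) -> bool;
    fn pop_byte(&mut self);
}

/// Mark every token whose bytes keep the output inside the grammar.
fn compute_mask<P: ByteParser>(parser: &mut P, node: &TrieNode, mask: &mut [bool]) {
    for child in &node.children {
        if parser.try_push_byte(child.byte) {
            if let Some(tok) = child.token_id {
                mask[tok] = true; // this token is allowed at the current position
            }
            compute_mask(parser, child, mask); // longer tokens share this prefix
            parser.pop_byte(); // backtrack before trying the next sibling
        }
    }
}
```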

### Comparison

[LM-format-enforcer](https://github.com/noamgat/lm-format-enforcer) and [llama.cpp grammars](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) are similar to llguidance in that they dynamically build token masks for every step of the decoding process. Both are significantly slower - the former because it is implemented in pure Python, and the latter because it lacks a lexer and uses a backtracking parser which, while elegant, is inefficient.

[Outlines](https://github.com/dottxt-ai/outlines) builds an automaton from constraints and then pre-computes token masks for all automaton states, making sampling fast but inherently limiting constraint complexity and introducing significant startup cost and memory overhead. Llguidance computes token masks on the fly and has essentially no startup cost. The lexer’s automata are built lazily and are typically much smaller, as the context-free grammar imposes the top-level structure.

In llguidance, online mask computation takes approximately 1ms of CPU time per sequence in a batch. Thus, with 16 cores and a 10ms forward pass, the library can handle batch sizes up to 160 without slowing down the model. (Note that a 10ms forward pass for small batch sizes typically increases to 20ms+ for batch sizes of 100-200.)
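
A quick back-of-the-envelope check of that claim, under the stated assumptions (1ms of CPU per mask, 16 cores, a 10ms forward pass):

```rust
fn main() {
    let mask_cpu_ms = 1.0;      // ~1ms of single-core CPU time per mask
    let cores = 16.0;           // CPU cores available for mask computation
    let forward_pass_ms = 10.0; // model forward pass per decoding step
    // CPU budget available while one forward pass runs:
    let budget_core_ms = cores * forward_pass_ms; // 160 core-ms
    let max_batch = budget_core_ms / mask_cpu_ms;
    println!("masks computable per forward pass: {}", max_batch); // 160
}
```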

## Building

52 changes: 52 additions & 0 deletions parser/Cargo.lock


10 changes: 6 additions & 4 deletions parser/Cargo.toml
@@ -17,12 +17,14 @@ lazy_static = { version = "1.5.0", optional = true }
regex-syntax = "0.8.5"
indexmap = "2.6.0"
referencing = "0.26.1"
rayon = { version = "1.10.0", optional = true }

[features]
default = ["jsonschema_validation", "lark"]
logging = [] # this is extensive debug logging
lark = [] # ~115k (binary)
jsonschema_validation = ["jsonschema", "lazy_static"] # ~2.5M (binary)
default = ["jsonschema_validation", "lark", "rayon"]
logging = [] # this is extensive debug logging
lark = [] # ~115k (binary)
jsonschema_validation = ["dep:jsonschema", "dep:lazy_static"] # ~2.5M (binary)
rayon = ["dep:rayon"]

[lib]
crate-type = ["staticlib", "rlib", "cdylib"]
44 changes: 44 additions & 0 deletions parser/llguidance.h
@@ -28,6 +28,11 @@ typedef struct LlgParserLimits {
* Default: 500_000 (~10ms)
*/
uint64_t step_lexer_fuel;
/**
* Maximum number of Earley items created for the whole token mask.
* Default: 100_000 (~3ms)
*/
size_t step_max_items;
/**
* Maximum number of lexer states.
* Default: 10_000
@@ -108,6 +113,26 @@ typedef struct LlgCommitResult {
bool is_stop;
} LlgCommitResult;

typedef struct LlgConstraintStep {
/**
* The constraint to compute mask for.
*/
struct LlgConstraint *constraint;
/**
* Pointer to memory where the mask should be written.
*/
uint32_t *mask_dest;
/**
* The length of the mask_dest array in bytes (not elements).
*/
size_t mask_byte_len;
} LlgConstraintStep;

/**
* Function which llg calls when an operation is done.
*/
typedef void (*LlgCallback)(const void *user_data);

/**
* Tokenization function
* Will not write more than output_tokens_len tokens (which can be 0)
Expand Down Expand Up @@ -226,6 +251,17 @@ struct LlgConstraint *llg_new_constraint_any(const struct LlgConstraintInit *ini
*/
const char *llg_get_error(const struct LlgConstraint *cc);

/**
* Get the current temperature of the constraint.
* It is updated by mask computation.
*/
float llg_get_temperature(const struct LlgConstraint *cc);

/**
* Check if constraint is stopped (cannot be extended further).
*/
bool llg_is_stopped(const struct LlgConstraint *cc);

/**
* Compute mask for the next token sampling
* It typically takes up to a millisecond for a 100k tokenizer, so should be called in background.
@@ -242,6 +278,14 @@ int32_t llg_compute_mask(struct LlgConstraint *cc, struct LlgMaskResult *res_p);
*/
int32_t llg_commit_token(struct LlgConstraint *cc, LlgToken token, struct LlgCommitResult *res_p);

/**
* Compute mask for several constraints in parallel.
*/
void llg_par_compute_mask(const struct LlgConstraintStep *steps,
size_t n_steps,
const void *user_data,
LlgCallback done_cb);
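
For illustration, a sketch of how a host runtime might drive this API from Rust over FFI. The binding declarations mirror the header above; the mask layout (one bit per vocab token, packed into `u32` words) and the wait-on-callback pattern are assumptions, and constraint creation is elided:

```rust
use std::os::raw::c_void;
use std::sync::atomic::{AtomicBool, Ordering};

#[repr(C)]
pub struct LlgConstraint { _private: [u8; 0] } // opaque handle

#[repr(C)]
pub struct LlgConstraintStep {
    pub constraint: *mut LlgConstraint,
    pub mask_dest: *mut u32,
    pub mask_byte_len: usize, // bytes, not elements
}

pub type LlgCallback = Option<unsafe extern "C" fn(user_data: *const c_void)>;

extern "C" {
    fn llg_par_compute_mask(
        steps: *const LlgConstraintStep,
        n_steps: usize,
        user_data: *const c_void,
        done_cb: LlgCallback,
    );
}

unsafe extern "C" fn on_done(user_data: *const c_void) {
    (*(user_data as *const AtomicBool)).store(true, Ordering::Release);
}

/// Compute masks for a batch of constraints in parallel, then wait.
unsafe fn masks_for_batch(constraints: &[*mut LlgConstraint], n_vocab: usize) -> Vec<Vec<u32>> {
    let words = (n_vocab + 31) / 32; // assumed: one bit per token, packed in u32s
    let mut masks = vec![vec![0u32; words]; constraints.len()];
    let steps: Vec<LlgConstraintStep> = constraints
        .iter()
        .zip(masks.iter_mut())
        .map(|(&c, m)| LlgConstraintStep {
            constraint: c,
            mask_dest: m.as_mut_ptr(),
            mask_byte_len: m.len() * 4, // length in bytes
        })
        .collect();
    let done = AtomicBool::new(false);
    llg_par_compute_mask(
        steps.as_ptr(),
        steps.len(),
        &done as *const AtomicBool as *const c_void,
        Some(on_done),
    );
    while !done.load(Ordering::Acquire) {
        std::hint::spin_loop(); // a real host would use a proper signal instead
    }
    masks
}
```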

/**
* Clone the constraint
*/
39 changes: 36 additions & 3 deletions parser/src/api.rs
@@ -1,4 +1,4 @@
use std::fmt::Debug;
use std::fmt::{Debug, Display};

use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -19,6 +19,9 @@ pub const DEFAULT_CONTEXTUAL: bool = true;

#[derive(Serialize, Deserialize, Clone, Default)]
pub struct GrammarWithLexer {
/// The name of this grammar, can be used in GenGrammar nodes.
pub name: Option<String>,

/// The start symbol is at nodes[0]
/// When nodes is empty, then one of json_schema or lark_grammar must be set.
#[serde(default)]
@@ -257,6 +260,31 @@ impl RegexSpec {
}
}

#[derive(Serialize, Deserialize, Hash, PartialEq, Eq, Clone, Debug)]
#[serde(untagged)]
pub enum GrammarId {
Index(usize),
Name(String),
}

impl GrammarId {
pub fn to_index(&self) -> Option<usize> {
match self {
GrammarId::Index(i) => Some(*i),
GrammarId::Name(_) => None,
}
}
}

impl Display for GrammarId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GrammarId::Index(i) => write!(f, "@#{}", i),
GrammarId::Name(s) => write!(f, "@{}", s),
}
}
}
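
A small sketch of what the untagged representation and the `Display` impl above give (assuming `serde_json` is available; the grammar name and import path are made up): a plain JSON integer parses as `Index`, a JSON string as `Name`.

```rust
use llguidance_parser::api::GrammarId; // path assumed for illustration

fn main() {
    let by_index: GrammarId = serde_json::from_str("0").unwrap();
    assert_eq!(by_index, GrammarId::Index(0));

    let by_name: GrammarId = serde_json::from_str("\"json_grammar\"").unwrap();
    assert_eq!(by_name, GrammarId::Name("json_grammar".to_string()));

    // Display is used in diagnostics, e.g. the duplicate-name error message.
    assert_eq!(by_index.to_string(), "@#0");
    assert_eq!(by_name.to_string(), "@json_grammar");
}
```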

macro_rules! id_type {
($name:ident) => {
#[derive(Serialize, Deserialize, Hash, PartialEq, Eq, Clone, Copy, Debug)]
@@ -265,7 +293,6 @@ macro_rules! id_type {
};
}

id_type!(GrammarId);
id_type!(NodeId);
id_type!(RegexId);

@@ -286,7 +313,7 @@ impl Node {
impl Default for GenGrammarOptions {
fn default() -> Self {
GenGrammarOptions {
grammar: GrammarId(0),
grammar: GrammarId::Index(0),
temperature: None,
max_tokens_grm: usize::MAX,
}
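
Combined with the new `name` field on `GrammarWithLexer`, a `GenGrammar` node can now reference a sub-grammar by name instead of by position. A sketch (the import path and grammar name are assumed):

```rust
use llguidance_parser::api::{GenGrammarOptions, GrammarId}; // path assumed

fn by_name_example() -> GenGrammarOptions {
    GenGrammarOptions {
        grammar: GrammarId::Name("json_grammar".to_string()), // instead of Index(0)
        ..GenGrammarOptions::default()
    }
}
```
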
Expand Down Expand Up @@ -351,6 +378,10 @@ pub struct ParserLimits {
/// Default: 500_000 (~10ms)
pub step_lexer_fuel: u64,

/// Maximum number of Earley items created for the whole token mask.
/// Default: 100_000 (~3ms)
pub step_max_items: usize,

/// Maximum number of lexer states.
/// Default: 10_000
pub max_lexer_states: usize,
@@ -368,6 +399,7 @@ impl Default for ParserLimits {
step_lexer_fuel: 500_000, // 500k => 10ms
max_lexer_states: 10_000, // ?
max_grammar_size: 500_000, // fhir schema => 200k
step_max_items: 100_000, // 100k => ~3ms
}
}
}
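
Because `ParserLimits` implements `Default`, callers can tighten just the new `step_max_items` knob. A sketch (the import path is assumed):

```rust
use llguidance_parser::api::ParserLimits; // path assumed for illustration

fn tighter_limits() -> ParserLimits {
    ParserLimits {
        step_max_items: 50_000, // halve the ~3ms default Earley-item budget
        ..ParserLimits::default()
    }
}
```
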
@@ -376,6 +408,7 @@ impl TopLevelGrammar {
pub fn from_regex(rx: RegexNode) -> Self {
TopLevelGrammar {
grammars: vec![GrammarWithLexer {
name: Some("regex_grammar".to_string()),
nodes: vec![Node::Lexeme {
rx: RegexSpec::RegexId(RegexId(0)),
contextual: None,
26 changes: 20 additions & 6 deletions parser/src/earley/from_guidance.rs
@@ -1,10 +1,11 @@
use std::collections::HashMap;
use std::fmt::Write;
use std::{sync::Arc, vec};

use super::{grammar::SymbolProps, lexerspec::LexerSpec, CGrammar, Grammar};
use crate::api::{
GrammarWithLexer, Node, ParserLimits, RegexId, RegexNode, RegexSpec, TopLevelGrammar,
DEFAULT_CONTEXTUAL,
GrammarId, GrammarWithLexer, Node, ParserLimits, RegexId, RegexNode, RegexSpec,
TopLevelGrammar, DEFAULT_CONTEXTUAL,
};
use crate::{lark_to_llguidance, loginfo, JsonCompileOptions, Logger};
use anyhow::{bail, ensure, Result};
@@ -304,14 +305,26 @@ pub fn grammars_from_json(
extra_lexemes: Vec<String>,
) -> Result<Vec<Arc<CGrammar>>> {
let t0 = Instant::now();
let grammars = input
let mut grammars = input
.grammars
.into_iter()
.map(|g| grammar_from_json(tok_env, &mut limits, g))
.collect::<Result<Vec<_>>>()?;

for (_, g) in &grammars {
g.validate_grammar_refs(&grammars)?;
let mut grammar_by_idx = HashMap::new();
for (idx, (_, g)) in grammars.iter().enumerate() {
grammar_by_idx.insert(GrammarId::Index(idx), idx);
if let Some(n) = g.name() {
let n = GrammarId::Name(n.to_string());
if grammar_by_idx.contains_key(&n) {
bail!("duplicate grammar name: {}", n);
}
grammar_by_idx.insert(n, idx);
}
}

for (_, g) in grammars.iter_mut() {
g.validate_grammar_refs(&grammar_by_idx)?;
}

let t1 = Instant::now();
@@ -327,8 +340,9 @@
if log_grammar {
writeln!(
logger.info_logger(),
"Grammar #{}:\n{:?}\n{:?}\n",
"Grammar #{} {}:\n{:?}\n{:?}\n",
idx,
grm.name().unwrap_or(""),
lex,
grm
)
