Commit e55998b

align API names with Constraint (#63)

* align API names with Constraint
* fix up comments
* remove deprecated methods
* don't return None on stop from commit_token()
* test guidance on llg_040 branch for now

mmoskal authored Nov 22, 2024
1 parent a54c899 commit e55998b
Showing 6 changed files with 38 additions and 41 deletions.
7 changes: 5 additions & 2 deletions parser/src/constraint.rs
@@ -178,11 +178,14 @@ impl Constraint {
         Ok(CommitResult::from_step_result(&self.last_res))
     }
 
-    /// commit_token() is a top-level method in this file and is called indirectly via
-    /// the advance_parser() method of the LLIinterpreter trait in py.rs.
+    /// commit_token() is a top-level method in this file and is called by
+    /// the LLInterpreter::commit_token().
     ///
     /// commit_token() commits the sampled token (if any), and sees if this forces any more tokens
     /// on the output (if ff_tokens are enabled in InferenceCapabilities).
+    ///
+    /// It only returns 'STOP' if previous compute_mask() already returned 'STOP'
+    /// (in which case there's little point calling commit_token()).
     pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
         loginfo!(self.parser.logger, "\ncommit_token({:?})", sampled_token);
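The added doc comment pins down a contract worth noting: commit_token() reports STOP only when the preceding compute_mask() already did. Through the Python binding, the same contract looks roughly like this (a minimal, hypothetical sketch; `interp` stands for an already-constructed LLInterpreter whose last compute_mask() signaled stop):

```python
# Hedged sketch of the STOP contract in the doc comment above; `interp` is
# assumed to be an llguidance LLInterpreter whose last compute_mask() already
# reported stop. Calling commit_token() now is allowed but redundant: per the
# py.rs change further down, it simply returns no backtrack and no tokens.
backtrack, ff_tokens = interp.commit_token(None)
assert backtrack == 0 and ff_tokens == []
```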
2 changes: 1 addition & 1 deletion parser/src/earley/parser.rs
@@ -668,7 +668,7 @@ impl ParserState {
 
     // apply_tokens() "pushes" the bytes in 'tokens' into the lexer and parser. It is a top-level
     // method in this file. It is well below llguidance's top-level methods, but in the llguidance
-    // LLInterpreter interface, it is called indirectly via the advance_parser() method.
+    // LLInterpreter interface, it is called indirectly via the commit_token() method.
     pub fn apply_token(&mut self, shared: &mut SharedState, tok_bytes: &[u8]) -> Result<usize> {
         self.assert_definitive();
6 changes: 3 additions & 3 deletions parser/src/tokenparser.rs
@@ -262,7 +262,7 @@ impl TokenParser {
 
     // advance_parser() is a top-level method in this file.
     // This advance_parser() is called by Constraint::commit_token().
-    // It is accessible via the advance_parser() method of
+    // It is accessible via the commit_token() method of
     // the LLInterpreter interface.
     //
     // The result here *never* includes a mask.
@@ -282,9 +282,9 @@ impl TokenParser {
 
     // mid_process() is a top-level method in this file.
     // mid_process() is called by Constraint::commit_token().
-    // It is also be called by TokenParser::advance_parser()
+    // It is also be called by TokenParser::commit_token()
     // within this file, in which case it is accessible
-    // via the advance_parser() method of the LLInterpreter interface.
+    // via the commit_token() method of the LLInterpreter interface.
     pub fn mid_process(&mut self, mut arg: StepArg) -> StepResult {
         assert!(self.is_fresh == false, "process_prompt() not called");
26 changes: 8 additions & 18 deletions python/llguidance/_lib.pyi
@@ -2,7 +2,6 @@ from typing import List, Tuple, Mapping, Optional, Sequence, Union
 from ._util import TokenId, StopReason
 from ._tokenizer import TokenizerWrapper
 
-
 class LLTokenizer:
     vocab_size: int
     eos_token: TokenId
@@ -57,9 +56,7 @@ class LLTokenizer:
        Decode the tokens into a bytes object.
        """
 
-
 class LLInterpreter:
-
     def __new__(
         cls,
         tokenizer: LLTokenizer,
@@ -86,7 +83,7 @@ class LLInterpreter:
 
     def is_accepting(self) -> bool:
         """
-        Check if the last mid_process() call resulted in overall accepting state
+        Check if the last compute_mask() call resulted in overall accepting state
         of the parser.
         """
 
@@ -111,33 +108,26 @@ class LLInterpreter:
         Returns the adjusted prompt.
         """
 
-    def mid_process(self) -> Tuple[Optional[bytes], str]:
+    def compute_mask(self) -> Tuple[Optional[bytes], str]:
         """
         Perform next parsing step.
         Returns: optional token mask and a JSON string.
         """
 
-    def post_process(
-        self,
-        sampled_token: Optional[TokenId]) -> Tuple[int, List[TokenId]]:
+    def commit_token(
+        self, sampled_token: Optional[TokenId]
+    ) -> Tuple[int, List[TokenId]]:
         """
         Perform any adjustments to the sampled token.
         Returns the number of tokens to remove from the prompt and the
         list of tokens to append.
-        If mid_process() returned None, this should be called immedietly with None.
-        """
-
-    def advance_parser(
-        self,
-        sampled_token: Optional[TokenId]) -> Tuple[int, List[TokenId]]:
-        """
-        Like post_process(), but goes further.
-        This is experimental and breaks tests when used instead of post_process().
+        If compute_mask() returned None mask, this should be called immediately with None.
+        If compute_mask() returned stop, you don't need to call this (but can).
         """
 
     def has_pending_stop(self) -> bool:
         """
-        If true, next mid_process() call will return stop
+        If true, next compute_mask() call will return stop
         """
 
 class JsonCompiler:
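Taken together, the renamed stubs describe one step of constrained generation: compute_mask() returns an optional token mask plus a JSON progress string, and commit_token() commits whatever was sampled, possibly requesting backtracking or fast-forward tokens. Below is a minimal driving loop assembled from the stubs above (a sketch, not the library's own runner; `sample_with_mask` is a hypothetical stand-in for the engine's sampler, and error handling is omitted):

```python
from typing import Callable, List, Tuple

def generation_step(
    interp,  # an already-constructed llguidance LLInterpreter
    sample_with_mask: Callable[[bytes], int],  # hypothetical sampler
) -> Tuple[int, List[int]]:
    # compute_mask() performs the next parsing step and returns an
    # optional token mask plus a JSON progress string.
    mask, _progress_json = interp.compute_mask()
    if mask is None:
        # No token to sample this step: per the commit_token() docstring,
        # call it immediately with None.
        return interp.commit_token(None)
    sampled = sample_with_mask(mask)
    # Returns (tokens to remove from the prompt, tokens to append).
    return interp.commit_token(sampled)
```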
36 changes: 20 additions & 16 deletions rust/src/py.rs
@@ -6,7 +6,9 @@ use llguidance_parser::toktrie::{
     self, InferenceCapabilities, TokRxInfo, TokTrie, TokenId, TokenizerEnv,
 };
 use llguidance_parser::{api::TopLevelGrammar, output::ParserOutput, TokenParser};
-use llguidance_parser::{lark_to_llguidance, Constraint, GrammarBuilder, JsonCompileOptions, Logger};
+use llguidance_parser::{
+    lark_to_llguidance, Constraint, GrammarBuilder, JsonCompileOptions, Logger,
+};
 use pyo3::{exceptions::PyValueError, prelude::*};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -78,7 +80,7 @@ impl LLInterpreter {
         self.inner.process_prompt(prompt)
     }
 
-    fn mid_process(&mut self, py: Python<'_>) -> PyResult<(Option<Cow<[u8]>>, String)> {
+    fn compute_mask(&mut self, py: Python<'_>) -> PyResult<(Option<Cow<[u8]>>, String)> {
         let r = py
             .allow_threads(|| self.inner.compute_mask())
             .map_err(val_error)?
@@ -108,19 +110,19 @@ impl LLInterpreter {
         }
     }
 
-    fn advance_parser(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
+    fn commit_token(
+        &mut self,
+        sampled_token: Option<TokenId>,
+    ) -> PyResult<(u32, Vec<TokenId>)> {
         let pres = self.inner.commit_token(sampled_token).map_err(val_error)?;
 
         if pres.stop {
-            // let the next mid_process() call handle it
-            return Ok((0, vec![]));
+            // inner.commit_token() only returns stop, when compute_mask()
+            // had already returned stop
+            Ok((0, vec![]))
+        } else {
+            Ok((pres.backtrack, pres.ff_tokens))
         }
-
-        Ok((pres.backtrack, pres.ff_tokens))
     }
 
-    fn post_process(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
-        self.advance_parser(sampled_token)
-    }
-
     fn has_pending_stop(&self) -> bool {
@@ -246,18 +248,20 @@ impl TokenizerEnv for LLTokenizer {
 struct JsonCompiler {
     item_separator: String,
     key_separator: String,
-    whitespace_flexible: bool
+    whitespace_flexible: bool,
 }
 
 #[pymethods]
 impl JsonCompiler {
     #[new]
     #[pyo3(signature = (separators = None, whitespace_flexible = false))]
     fn py_new(separators: Option<(String, String)>, whitespace_flexible: bool) -> Self {
-        let (item_separator, key_separator) = separators.unwrap_or_else(|| if whitespace_flexible {
-            (",".to_owned(), ":".to_owned())
-        } else {
-            (", ".to_owned(), ": ".to_owned())
+        let (item_separator, key_separator) = separators.unwrap_or_else(|| {
+            if whitespace_flexible {
+                (",".to_owned(), ":".to_owned())
+            } else {
+                (", ".to_owned(), ": ".to_owned())
+            }
         });
         JsonCompiler {
             item_separator: item_separator,
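The restructured py_new() leaves the separator defaults unchanged: when no explicit separators are given, whitespace_flexible=True picks the compact (",", ":") pair, otherwise (", ", ": "). Illustrated below (a hedged sketch; the import path is an assumption, since this file only defines the pyo3 class):

```python
from llguidance import JsonCompiler  # assumed import path

compact = JsonCompiler(whitespace_flexible=True)  # defaults to (",", ":")
spaced = JsonCompiler()                           # defaults to (", ", ": ")
custom = JsonCompiler(separators=(",", ": "))     # explicit pair wins either way
```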
2 changes: 1 addition & 1 deletion scripts/test-guidance.sh
@@ -32,7 +32,7 @@ else
     if test -f guidance/tests/unit/test_ll.py ; then
         echo "Guidance clone OK"
     else
-        git clone -b main https://github.com/guidance-ai/guidance
+        git clone -b llg_040 https://github.com/guidance-ai/guidance
     fi
     cd guidance
     echo "Branch: $(git branch --show-current), Remote URL: $(git remote get-url origin), HEAD: $(git rev-parse HEAD)"
