align API names with Constraint #63

Merged · 5 commits · Nov 22, 2024
7 changes: 5 additions & 2 deletions parser/src/constraint.rs
@@ -178,11 +178,14 @@ impl Constraint {
Ok(CommitResult::from_step_result(&self.last_res))
}

- /// commit_token() is a top-level method in this file and is called indirectly via
- /// the advance_parser() method of the LLIinterpreter trait in py.rs.
+ /// commit_token() is a top-level method in this file and is called by
+ /// the LLInterpreter::commit_token().
///
/// commit_token() commits the sampled token (if any), and sees if this forces any more tokens
/// on the output (if ff_tokens are enabled in InferenceCapabilities).
+ ///
+ /// It only returns 'STOP' if previous compute_mask() already returned 'STOP'
+ /// (in which case there's little point calling commit_token()).
pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
loginfo!(self.parser.logger, "\ncommit_token({:?})", sampled_token);

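Taken together with the Python-side rename later in this PR, the intended call sequence is compute_mask(), sample under the returned mask, then commit_token(). A minimal sketch of that loop against the renamed LLInterpreter API; the sample_with_mask() helper and the pre-built interp are assumptions for illustration, not part of this diff:

from typing import Callable, List

def run_constrained(interp, sample_with_mask: Callable[[bytes], int]) -> List[int]:
    # interp: an assumed llguidance.LLInterpreter, already fed the prompt
    # via process_prompt().
    tokens: List[int] = []
    while not interp.has_pending_stop():
        mask, _progress_json = interp.compute_mask()
        if mask is None:
            # Per the new docstrings: if compute_mask() returned no mask,
            # call commit_token(None) immediately.
            interp.commit_token(None)
            continue
        sampled = sample_with_mask(mask)  # hypothetical sampler honoring the mask
        backtrack, ff_tokens = interp.commit_token(sampled)
        if backtrack:
            del tokens[-backtrack:]  # drop tokens the parser backtracked over
        tokens.extend(ff_tokens)     # sampled token plus any forced tokens
    return tokens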
2 changes: 1 addition & 1 deletion parser/src/earley/parser.rs
@@ -668,7 +668,7 @@ impl ParserState {

// apply_tokens() "pushes" the bytes in 'tokens' into the lexer and parser. It is a top-level
// method in this file. It is well below llguidance's top-level methods, but in the llguidance
- // LLInterpreter interface, it is called indirectly via the advance_parser() method.
+ // LLInterpreter interface, it is called indirectly via the commit_token() method.
pub fn apply_token(&mut self, shared: &mut SharedState, tok_bytes: &[u8]) -> Result<usize> {
self.assert_definitive();

6 changes: 3 additions & 3 deletions parser/src/tokenparser.rs
@@ -262,7 +262,7 @@ impl TokenParser {

// advance_parser() is a top-level method in this file.
// This advance_parser() is called by Constraint::commit_token().
- // It is accessible via the advance_parser() method of
+ // It is accessible via the commit_token() method of
// the LLInterpreter interface.
//
// The result here *never* includes a mask.
@@ -282,9 +282,9 @@

// mid_process() is a top-level method in this file.
// mid_process() is called by Constraint::commit_token().
- // It is also be called by TokenParser::advance_parser()
+ // It is also be called by TokenParser::commit_token()
// within this file, in which case it is accessible
- // via the advance_parser() method of the LLInterpreter interface.
+ // via the commit_token() method of the LLInterpreter interface.
pub fn mid_process(&mut self, mut arg: StepArg) -> StepResult {
assert!(self.is_fresh == false, "process_prompt() not called");

26 changes: 8 additions & 18 deletions python/llguidance/_lib.pyi
@@ -2,7 +2,6 @@ from typing import List, Tuple, Mapping, Optional, Sequence, Union
from ._util import TokenId, StopReason
from ._tokenizer import TokenizerWrapper

-
class LLTokenizer:
vocab_size: int
eos_token: TokenId
@@ -57,9 +56,7 @@ class LLTokenizer:
Decode the tokens into a bytes object.
"""

-
class LLInterpreter:
-
def __new__(
cls,
tokenizer: LLTokenizer,
@@ -86,7 +83,7 @@ class LLInterpreter:

def is_accepting(self) -> bool:
"""
- Check if the last mid_process() call resulted in overall accepting state
+ Check if the last compute_mask() call resulted in overall accepting state
of the parser.
"""

@@ -111,33 +108,26 @@
Returns the adjusted prompt.
"""

- def mid_process(self) -> Tuple[Optional[bytes], str]:
+ def compute_mask(self) -> Tuple[Optional[bytes], str]:
"""
Perform next parsing step.
Returns: optional token mask and a JSON string.
"""

def post_process(
self,
sampled_token: Optional[TokenId]) -> Tuple[int, List[TokenId]]:
def commit_token(
self, sampled_token: Optional[TokenId]
) -> Tuple[int, List[TokenId]]:
"""
Perform any adjustments to the sampled token.
Returns the number of tokens to remove from the prompt and the
list of tokens to append.
- If mid_process() returned None, this should be called immedietly with None.
- """
-
- def advance_parser(
- self,
- sampled_token: Optional[TokenId]) -> Tuple[int, List[TokenId]]:
- """
- Like post_process(), but goes further.
- This is experimental and breaks tests when used instead of post_process().
+ If compute_mask() returned None mask, this should be called immediately with None.
+ If compute_mask() returned stop, you don't need to call this (but can).
"""

def has_pending_stop(self) -> bool:
"""
- If true, next mid_process() call will return stop
+ If true, next compute_mask() call will return stop
"""

class JsonCompiler:
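One consequence of the new contract deserves emphasis: commit_token() only reports a stop that compute_mask() already reported, so committing after a stop is an optional no-op. A hedged sketch (the interp argument is an assumed LLInterpreter that has just been driven to a stop):

def finish(interp) -> None:
    # Optional per the docstring above: after compute_mask() returned stop,
    # commit_token(None) may still be called.
    backtrack, ff_tokens = interp.commit_token(None)
    # Per the rust/src/py.rs change below, a post-stop commit returns (0, []):
    assert (backtrack, ff_tokens) == (0, [])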
36 changes: 20 additions & 16 deletions rust/src/py.rs
@@ -6,7 +6,9 @@ use llguidance_parser::toktrie::{
self, InferenceCapabilities, TokRxInfo, TokTrie, TokenId, TokenizerEnv,
};
use llguidance_parser::{api::TopLevelGrammar, output::ParserOutput, TokenParser};
- use llguidance_parser::{lark_to_llguidance, Constraint, GrammarBuilder, JsonCompileOptions, Logger};
+ use llguidance_parser::{
+ lark_to_llguidance, Constraint, GrammarBuilder, JsonCompileOptions, Logger,
+ };
use pyo3::{exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -78,7 +80,7 @@ impl LLInterpreter {
self.inner.process_prompt(prompt)
}

- fn mid_process(&mut self, py: Python<'_>) -> PyResult<(Option<Cow<[u8]>>, String)> {
+ fn compute_mask(&mut self, py: Python<'_>) -> PyResult<(Option<Cow<[u8]>>, String)> {
let r = py
.allow_threads(|| self.inner.compute_mask())
.map_err(val_error)?
@@ -108,19 +110,19 @@
}
}

- fn advance_parser(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
+ fn commit_token(
+ &mut self,
+ sampled_token: Option<TokenId>,
+ ) -> PyResult<(u32, Vec<TokenId>)> {
let pres = self.inner.commit_token(sampled_token).map_err(val_error)?;

if pres.stop {
- // let the next mid_process() call handle it
- return Ok((0, vec![]));
+ // inner.commit_token() only returns stop, when compute_mask()
+ // had already returned stop
+ Ok((0, vec![]))
+ } else {
+ Ok((pres.backtrack, pres.ff_tokens))
}
-
- Ok((pres.backtrack, pres.ff_tokens))
}
-
- fn post_process(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
- self.advance_parser(sampled_token)
- }

fn has_pending_stop(&self) -> bool {
@@ -246,18 +248,20 @@ impl TokenizerEnv for LLTokenizer {
struct JsonCompiler {
item_separator: String,
key_separator: String,
- whitespace_flexible: bool
+ whitespace_flexible: bool,
}

#[pymethods]
impl JsonCompiler {
#[new]
#[pyo3(signature = (separators = None, whitespace_flexible = false))]
fn py_new(separators: Option<(String, String)>, whitespace_flexible: bool) -> Self {
- let (item_separator, key_separator) = separators.unwrap_or_else(|| if whitespace_flexible {
- (",".to_owned(), ":".to_owned())
- } else {
- (", ".to_owned(), ": ".to_owned())
+ let (item_separator, key_separator) = separators.unwrap_or_else(|| {
+ if whitespace_flexible {
+ (",".to_owned(), ":".to_owned())
+ } else {
+ (", ".to_owned(), ": ".to_owned())
+ }
});
JsonCompiler {
item_separator: item_separator,
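The separator defaults in py_new() above key off whitespace_flexible. A small usage sketch matching the #[pyo3(signature = ...)] attribute in the diff; the top-level import path is an assumption:

from llguidance import JsonCompiler  # import path assumed

# Defaults track whitespace_flexible: compact separators when the grammar
# already permits flexible whitespace, padded ones otherwise.
flexible = JsonCompiler(whitespace_flexible=True)  # ("," and ":")
default = JsonCompiler()                           # (", " and ": ")
explicit = JsonCompiler(separators=(",", ":"))     # overrides either default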
2 changes: 1 addition & 1 deletion scripts/test-guidance.sh
@@ -32,7 +32,7 @@ else
if test -f guidance/tests/unit/test_ll.py ; then
echo "Guidance clone OK"
else
- git clone -b main https://github.com/guidance-ai/guidance
+ git clone -b llg_040 https://github.com/guidance-ai/guidance
fi
cd guidance
echo "Branch: $(git branch --show-current), Remote URL: $(git remote get-url origin), HEAD: $(git rev-parse HEAD)"