Skip to content

Commit

Permalink
Add args to enable/disable backtracking and fast-forwarding in LLInte…
Browse files Browse the repository at this point in the history
…rpreter (#30)

* Add args to enable/disable backtracking and fast-forwarding in LLInterpreter

* Remove conditional_ff_tokens args

---------

Co-authored-by: Loc Huynh <[email protected]>
  • Loading branch information
JC1DA and lochuynh1412 authored Oct 28, 2024
1 parent 29c5861 commit f9b0589
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 4 additions & 0 deletions python/llguidance/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,17 @@ class LLInterpreter:
cls,
tokenizer: LLTokenizer,
llguidance_json: str,
enable_backtrack: bool = True,
enable_ff_tokens: bool = True,
log_level: int = 1,
) -> "LLInterpreter":
"""
Create a new interpreter.
Args:
tokenizer: LLTokenizer - the tokenizer to use
llguidance_json: str - the JSON representation of the AG2 grammar/constraint
enable_backtrack: bool - whether to enable backtracking in the interpreter
enable_ff_tokens: bool - whether to enable fast-forwarded tokens in the interpreter
log_level: int - the verbosity level of the interpreter
0 is silent, 1 is warnings, 2 is verbose
"""
Expand Down
8 changes: 5 additions & 3 deletions rust/src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ impl LLInterpreter {
fn py_new(
tokenizer: &LLTokenizer,
llguidance_json: &str,
enable_backtrack: Option<bool>,
enable_ff_tokens: Option<bool>,
log_level: Option<isize>,
) -> PyResult<Self> {
let env = tokenizer.clone();
let arg: TopLevelGrammar = serde_json::from_str(llguidance_json).map_err(val_error)?;
let log_level = log_level.unwrap_or(1);
let inference_caps = InferenceCapabilities {
backtrack: true,
ff_tokens: true,
conditional_ff_tokens: true,
backtrack: enable_backtrack.unwrap_or(true),
ff_tokens: enable_ff_tokens.unwrap_or(true),
conditional_ff_tokens: enable_ff_tokens.unwrap_or(true),
fork: false,
};
let logger = Logger::new(0, std::cmp::max(0, log_level) as u32);
Expand Down

0 comments on commit f9b0589

Please sign in to comment.