Skip to content

Commit

Permalink
simplify commit_token logic
Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Nov 22, 2024
1 parent 81f6be3 commit 4931fc3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 36 deletions.
6 changes: 1 addition & 5 deletions parser/src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,7 @@ impl Constraint {
}

// now, advance the parser with the sampled token - this should be very quick
let pres = self.parser.advance_parser(StepArg {
backtrack: 0,
tokens: vec![sampled_token],
sampled: Some(sampled_token),
});
let pres = self.parser.advance_parser(sampled_token)?;

// save any logs
self.save_progress_and_result(pres);
Expand Down
55 changes: 24 additions & 31 deletions parser/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
},
infoln, warn, Logger,
};
use anyhow::Result;
use anyhow::{ensure, Result};
use serde_json::json;
use toktrie::{InferenceCapabilities, SimpleVob, StepArg, StepResult, TokEnv, TokenId};

Expand Down Expand Up @@ -37,8 +37,6 @@ pub struct TokenParser {
stop_reason: StopReason,
error_message: Option<String>,

no_bias_this_mid_process: bool,

max_tokens_total: usize,
max_tokens_parser: usize,
compiled_grammars: Vec<Arc<CGrammar>>,
Expand Down Expand Up @@ -97,7 +95,6 @@ impl TokenParser {
last_step_stats: ParserStats::default(),
mid_process_start_time,
mid_process_was_accepting: false,
no_bias_this_mid_process: false,
stop_reason: StopReason::NotStopped,
error_message: None,
pop_tokens: None,
Expand Down Expand Up @@ -265,17 +262,31 @@ impl TokenParser {
//
// The result here *never* includes a mask.
// It's either stop or an unconditional splice (possibly noop).
pub fn advance_parser(&mut self, arg: StepArg) -> StepResult {
assert!(self.inference_caps.ff_tokens);
assert!(!self.test_trace);
pub fn advance_parser(&mut self, token: TokenId) -> Result<StepResult> {
ensure!(self.is_fresh == false, "process_prompt() not called");
ensure!(self.inference_caps.ff_tokens, "ff_tokens required");
ensure!(
self.stop_reason == StopReason::NotStopped,
"commit_token() on stopped parser"
);

self.mid_process_was_accepting = false;

let tokens = &[token];
infoln!(
self,
"commit_token: {}",
self.token_env.tok_trie().tokens_dbg(tokens)
);

self.no_bias_this_mid_process = true;
let r = self.mid_process(arg);
self.no_bias_this_mid_process = false;
let r = match self.commit_tokens_inner(tokens) {
Ok(_) => StepResult::noop(),
Err(r) => r,
};

assert!(r.sample_mask.is_none());

r
Ok(r)
}

// mid_process() is a top-level method in this file.
Expand Down Expand Up @@ -371,20 +382,6 @@ impl TokenParser {
self.stop(&format!("{}{}", pref, err.message()), err.stop_reason())
}

fn log_inital(&mut self, tokens: &[TokenId]) {
let trie = self.token_env.tok_trie();
infoln!(
self,
"{}: {}",
if self.no_bias_this_mid_process {
"commit_token"
} else {
"compute_mask"
},
trie.tokens_dbg(tokens)
);
}

fn maybe_pop_parsers(&mut self, tokens: &[TokenId]) {
if tokens.len() == 1 {
let token = tokens[0];
Expand Down Expand Up @@ -812,15 +809,11 @@ impl TokenParser {
fn mid_process_inner(&mut self, tokens: &[TokenId]) -> Result<(), StepResult> {
self.mid_process_was_accepting = false;

self.log_inital(&tokens);
let trie = self.token_env.tok_trie();
infoln!(self, "compute_mask: {}", trie.tokens_dbg(tokens));

let (token_prefix, inner_accepting) = self.commit_tokens_inner(tokens)?;

if self.no_bias_this_mid_process {
self.no_bias_this_mid_process = false;
return Err(StepResult::noop());
}

let mut allowed_tokens = self.compute_bias(&token_prefix);

if let Some(err) = self.parser.get_error() {
Expand Down

0 comments on commit 4931fc3

Please sign in to comment.