diff --git a/parser/src/earley/parser.rs b/parser/src/earley/parser.rs index 58e7859c..90c36c8b 100644 --- a/parser/src/earley/parser.rs +++ b/parser/src/earley/parser.rs @@ -670,6 +670,33 @@ impl ParserState { } } + pub fn bytes_allowed(&mut self, shared: &mut SharedState, tok_bytes: &[u8]) -> bool { + self.assert_definitive(); + let applied_idx = self.byte_to_token_idx.len(); + let tok_bytes = if applied_idx < self.bytes.len() { + let prefix_len = std::cmp::min(tok_bytes.len(), self.bytes.len() - applied_idx); + if self.bytes[applied_idx..applied_idx + prefix_len] != tok_bytes[..prefix_len] { + return false; + } + &tok_bytes[prefix_len..] + } else { + tok_bytes + }; + if tok_bytes.is_empty() { + return true; + } + + self.run_speculative(|s| { + let mut r = ParserRecognizer { shared, state: s }; + for &b in tok_bytes { + if !r.try_push_byte(b) { + return false; + } + } + true + }) + } + // apply_tokens() "pushes" the bytes in 'tokens' into the lexer and parser. It is a top-level // method in this file. It is well below llguidance's top-level methods, but in the llguidance // LLInterpreter interface, it is called indirectly via the commit_token() method. @@ -2001,6 +2028,11 @@ impl Parser { r } + pub fn bytes_allowed(&mut self, tok_bytes: &[u8]) -> bool { + let mut shared = self.shared.lock().unwrap(); + self.state.bytes_allowed(&mut shared, tok_bytes) + } + pub fn filter_max_tokens(&mut self) { self.state.filter_max_tokens(); }