diff --git a/parser/src/tokenparser.rs b/parser/src/tokenparser.rs
index b489241..d5f37d5 100644
--- a/parser/src/tokenparser.rs
+++ b/parser/src/tokenparser.rs
@@ -254,12 +254,32 @@ impl TokenParser {
         Ok(())
     }
 
+    pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
+        self.check_initialized("validate_token")?;
+        let bytes = self.tok_trie().decode_raw(&[token]);
+        let n_valid = self.parser.validate_bytes(&bytes);
+        assert!(n_valid <= bytes.len());
+        Ok(n_valid == bytes.len())
+    }
+
     /// Returns how many of the passed tokens can be accepted by the parser.
     /// It does not tokenize forced bytes, so will accept non-canonical tokenizations.
     /// If called with more than one token, it may ignore max_tokens constraints.
     pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
         self.check_initialized("validate_tokens_raw")?;
 
+        if tokens.is_empty() {
+            return Ok(0);
+        }
+
+        if tokens.len() == 1 {
+            return if self.validate_token(tokens[0])? {
+                Ok(1)
+            } else {
+                Ok(0)
+            };
+        }
+
         let bytes = self.tok_trie().decode_raw(tokens);
         let n_valid = self.parser.validate_bytes(&bytes);
         assert!(n_valid <= bytes.len());
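
For reviewers, a minimal usage sketch (not part of the diff). It assumes `TokenParser`, `TokenId`, and the crate's anyhow-style `Result` alias are in scope, and that the parser has already been initialized so `check_initialized` succeeds; the helper name `trim_to_valid` is hypothetical.

```rust
// Minimal sketch, not part of this PR: trim a speculated token sequence
// to the prefix the grammar accepts. Assumes `parser` is an initialized
// TokenParser and `tokens` came from the matching tokenizer.
fn trim_to_valid<'a>(
    parser: &mut TokenParser,
    tokens: &'a [TokenId],
) -> Result<&'a [TokenId]> {
    // Returns how many leading tokens the parser accepts; with a single
    // token this now goes through the validate_token() fast path above.
    let n_valid = parser.validate_tokens_raw(tokens)?;
    Ok(&tokens[..n_valid])
}
```

The empty- and single-token early returns keep the common one-token case cheap, and `validate_token` gives callers a boolean answer directly instead of requiring them to compare a returned count against the slice length.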