From 50eb993554252ad0cf202481406e2839da20cbbd Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Wed, 4 Dec 2024 16:59:10 -0800 Subject: [PATCH] add json schema tester --- Cargo.lock | 1 + sample_parser/Cargo.toml | 7 +- sample_parser/jsonschema.sh | 5 + sample_parser/src/json_schema_testsuite.rs | 186 +++++++++++++++++++++ 4 files changed, 198 insertions(+), 1 deletion(-) create mode 100755 sample_parser/jsonschema.sh create mode 100644 sample_parser/src/json_schema_testsuite.rs diff --git a/Cargo.lock b/Cargo.lock index 7b3f2506..590d2d69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1560,6 +1560,7 @@ dependencies = [ "anyhow", "lazy_static", "llguidance", + "serde", "serde_json", "toktrie_hf_tokenizers", ] diff --git a/sample_parser/Cargo.toml b/sample_parser/Cargo.toml index f3e3fd1c..e7d060b6 100644 --- a/sample_parser/Cargo.toml +++ b/sample_parser/Cargo.toml @@ -7,14 +7,19 @@ default-run = "sample_parser" [dependencies] llguidance = { workspace = true } toktrie_hf_tokenizers = { workspace = true } -serde_json = "1.0.128" anyhow = "1.0.87" lazy_static = "1.5.0" +serde_json = { version = "1.0.132", features = ["preserve_order"] } +serde = { version = "1.0.210", features = ["derive"] } [[bin]] name = "sample_parser" path = "src/sample_parser.rs" +[[bin]] +name = "json_schema_testsuite" +path = "src/json_schema_testsuite.rs" + [[bin]] name = "schema_tester" path = "src/schema_tester.rs" diff --git a/sample_parser/jsonschema.sh b/sample_parser/jsonschema.sh new file mode 100755 index 00000000..fa39dc5f --- /dev/null +++ b/sample_parser/jsonschema.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +set -e + +cargo run --release --bin json_schema_testsuite ../../../JSON-Schema-Test-Suite/tests/draft2020-12/*.json diff --git a/sample_parser/src/json_schema_testsuite.rs b/sample_parser/src/json_schema_testsuite.rs new file mode 100644 index 00000000..e692d3da --- /dev/null +++ b/sample_parser/src/json_schema_testsuite.rs @@ -0,0 +1,186 @@ +use anyhow::{bail, Result}; +use core::str; +use llguidance::{ + api::{ParserLimits, TopLevelGrammar}, + toktrie::{InferenceCapabilities, TokEnv}, + Constraint, JsonCompileOptions, TokenParser, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::{env, fs::File, io::Read, vec}; + +#[derive(Debug, Serialize, Deserialize)] +struct JsonTest { + description: String, + schema: Value, + tests: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct JsonTestSequence { + description: String, + data: Value, + valid: bool, +} + +impl JsonTestSequence { + fn run_for(&self, obj_str: &str, tok_env: &TokEnv, mut constraint: Constraint) -> Result<()> { + let tokens = tok_env.tokenize(obj_str); + let trie = tok_env.tok_trie(); + + let mut idx = 0; + while idx < tokens.len() { + // println!("idx: {} {}", idx, trie.token_dbg(tokens[idx])); + + let res = constraint.compute_mask()?; + + if res.is_stop() { + if self.valid { + bail!("premature stop in valid test"); + } else { + bail!("premature stop in invalid test"); // ?? + } + } + + let sampled_token = if let Some(mask) = &res.sample_mask { + let sampled_token = tokens[idx]; + if !mask.is_allowed(sampled_token) { + if self.valid { + bail!( + "sampled token {} not allowed by mask", + trie.token_dbg(sampled_token) + ); + } else { + return Ok(()); + } + } + + // let p_stats = constraint.parser.last_step_stats(); + Some(sampled_token) + } else { + None + }; + + let splice = constraint.commit_token(sampled_token)?; + if splice.stop { + if self.valid { + if idx + 1 < tokens.len() { + bail!("premature stop in valid test (commit)"); + } else { + return Ok(()); + } + } else { + bail!("premature stop in invalid test (commit)"); // ?? + } + } + + assert!(splice.backtrack == 0); // we didn't allow backtracking in InferenceCaps + + if tokens[idx..idx + splice.ff_tokens.len()] != splice.ff_tokens { + bail!( + "BAD TEST: ff_tokens mismatch:\n{}\n{}", + trie.tokens_dbg(&tokens[idx..idx + splice.ff_tokens.len()]), + trie.tokens_dbg(&splice.ff_tokens) + ); + } + + idx += splice.ff_tokens.len(); + } + + let accept = constraint.parser.is_accepting(); + + if self.valid { + if accept { + Ok(()) + } else { + bail!("unexpected end of test"); + } + } else { + bail!( + "unexpected end of test for invalid test (accept={})", + accept + ); + } + } + + fn run(&self, grm: &TopLevelGrammar, tok_env: &TokEnv) -> Result<()> { + let stderr_log_level = 1; + let buffer_log_level = 0; + let parser = TokenParser::from_llguidance_json( + tok_env.clone(), + grm.clone(), + llguidance::Logger::new(buffer_log_level, stderr_log_level), + InferenceCapabilities { + ff_tokens: false, // can the engine append multiple tokens? + backtrack: false, // can the engine remove generated tokens? + + conditional_ff_tokens: false, // not used + fork: false, // not used + }, + ParserLimits::default(), + vec![], + )?; + let constraint = Constraint::new(parser); + + let obj_str = serde_json::to_string_pretty(&self.data).unwrap(); + match self.run_for(&obj_str, tok_env, constraint) { + Ok(_) => Ok(()), + Err(e) => { + bail!("{}\n{:?}", e, obj_str) + } + } + } +} + +impl JsonTest { + fn run(&self, tok_env: &TokEnv) -> Result<()> { + let opts = JsonCompileOptions::default(); + let grm = opts.json_to_llg(self.schema.clone())?; + let mut first_err = Ok(()); + for t in &self.tests { + let r = t.run(&grm, tok_env); + if first_err.is_ok() && r.is_err() { + first_err = r; + } + } + first_err + } +} + +fn main() { + let args: Vec = env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + + let tok_env: TokEnv = + toktrie_hf_tokenizers::ByteTokenizerEnv::from_name("meta-llama/Llama-3.2-1B", None) + .unwrap() + .to_env(); + + let t0 = std::time::Instant::now(); + for arg in &args[1..] { + let schema_file = read_file_to_string(arg); + let val: Vec = + serde_json::from_str(&schema_file).expect("Invalid JSON in schema"); + for (idx, t) in val.iter().enumerate() { + print!("Running test: {} ({}) #{} ", arg, t.description, idx); + match t.run(&tok_env) { + Ok(_) => println!("OK"), + Err(e) => println!("ERROR: {}", e), + } + } + } + + let elapsed = t0.elapsed(); + println!("Total time: {} ms", elapsed.as_millis()); +} + +fn read_file_to_string(filename: &str) -> String { + let mut file = File::open(filename).expect("Unable to open file"); + let mut content = String::new(); + file.read_to_string(&mut content) + .expect("Unable to read file"); + content +}