-
Notifications
You must be signed in to change notification settings - Fork 7
/
sample_parser.rs
154 lines (127 loc) · 5.04 KB
/
sample_parser.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
use std::{env, fs::File, hint::black_box, io::Read, vec};
use llguidance_parser::{
api::{ParserLimits, TopLevelGrammar},
lark_to_llguidance,
toktrie::{InferenceCapabilities, TokEnv},
Constraint, JsonCompileOptions, TokenParser,
};
/// Example driver for the llguidance constrained-generation parser.
///
/// Usage: `sample_parser <grammar-file> <sample-file>`
///
/// Loads a grammar (format picked by file extension), tokenizes the sample
/// document with a HuggingFace tokenizer, and replays those tokens through
/// the `Constraint` as if a model had sampled them, printing per-step stats.
/// Panics/exits on any setup error — acceptable for a sample binary.
fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() != 3 {
        eprintln!("Usage: {} <schema.ll.json> <sample.json>", args[0]);
        std::process::exit(1);
    }

    // Grammar format is selected by the schema file's extension:
    //   .ll.json     -> native llguidance grammar JSON (deserialized directly)
    //   .schema.json -> JSON Schema, compiled to an LLG grammar
    //   .lark        -> Lark grammar text, converted to an LLG grammar
    let schema_file = read_file_to_string(&args[1]);
    let schema: TopLevelGrammar = if args[1].ends_with(".ll.json") {
        serde_json::from_str(&schema_file).expect("Invalid JSON in schema")
    } else if args[1].ends_with(".schema.json") {
        let opts = JsonCompileOptions::default();
        let val = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");
        opts.json_to_llg(val)
            .expect("Failed to convert JSON to LLG")
    } else if args[1].ends_with(".lark") {
        lark_to_llguidance(&schema_file).expect("Failed to convert lark to LLG")
    } else {
        panic!("Unknown schema file extension")
    };

    // The document whose tokens we will replay through the constraint.
    let obj_str = read_file_to_string(&args[2]);

    // you can implement TokEnv yourself, if you have the tokenizer
    // see the ByteTokenizerEnv for an example
    // NOTE(review): from_name presumably downloads/loads the HF tokenizer by
    // model id — requires the tokenizer to be available at runtime.
    let tok_env: TokEnv =
        toktrie_hf_tokenizers::ByteTokenizerEnv::from_name("microsoft/Phi-3.5-mini-instruct", None)
            .unwrap()
            .to_env();

    // Pre-tokenize the whole sample; the loop below feeds these tokens back
    // one step at a time, standing in for a model's sampler.
    let tokens = tok_env.tokenize(&obj_str);

    // set to 2 for more output; 1 is warnings only
    let stderr_log_level = 1;
    // typically set to 2, to send info-level output to the user
    let buffer_log_level = 2;

    let parser = TokenParser::from_llguidance_json(
        tok_env.clone(),
        schema,
        llguidance_parser::Logger::new(buffer_log_level, stderr_log_level),
        InferenceCapabilities {
            ff_tokens: true,              // can the engine append multiple tokens?
            backtrack: false,             // can the engine remove generated tokens?
            conditional_ff_tokens: false, // not used
            fork: false,                  // not used
        },
        ParserLimits::default(),
        vec![],
    )
    .unwrap();
    let mut constraint = Constraint::new(parser);

    // enable sending parser results back via the logs (constraint.flush_logs())
    constraint.log_json_progress = true;

    let trie = tok_env.tok_trie();
    eprintln!("Parsing tokens: {}", trie.tokens_dbg(&tokens));

    // Main replay loop: one compute_mask/commit_token round per model step.
    let mut idx = 0;
    while idx < tokens.len() {
        let res = constraint.compute_mask().unwrap();

        if res.is_stop() {
            // stop sequence
            break;
        }

        let sampled_token = if let Some(mask) = &res.sample_mask {
            // Simulate sampling - it should use the mask and temperature.
            // black_box keeps the otherwise-unused values from being
            // optimized away; a real engine would sample under `mask`.
            black_box(mask);
            black_box(constraint.temperature);
            // "Sample" by taking the next pre-tokenized input token.
            let sampled_token = tokens[idx];

            let p_stats = constraint.parser.last_step_stats();
            println!(
                "SAMPLE {}: {} {}; stats: {} lex, {} items, {} us",
                idx,
                sampled_token,
                tok_env.tok_trie().token_dbg(sampled_token),
                p_stats.lexer_cost,
                p_stats.all_items,
                p_stats.compute_time_us,
            );
            Some(sampled_token)
        } else {
            // sampling not required - the parser can make progress without a token
            println!("NO SAMPLE");
            None
        };

        let splice = constraint.commit_token(sampled_token).unwrap();
        if splice.stop {
            // stop sequence
            break;
        }

        assert!(splice.backtrack == 0); // we didn't allow backtracking in InferenceCaps

        // The splice contains the tokens (possibly more than one since we enabled ff_tokens
        // in InferenceCaps) that the parser wants to append to the output.

        // if this fails, our test data is broken: the parser's forced tokens
        // must match the next tokens of the pre-tokenized input
        if tokens[idx..idx + splice.ff_tokens.len()] != splice.ff_tokens {
            panic!(
                "BAD TEST: ff_tokens mismatch:\n{}\n{}",
                trie.tokens_dbg(&tokens[idx..idx + splice.ff_tokens.len()]),
                trie.tokens_dbg(&splice.ff_tokens)
            );
        }

        if splice.ff_tokens.len() > 1 {
            println!("FF: {}", trie.tokens_dbg(&splice.ff_tokens));
        }

        // Advance past everything the parser consumed this step.
        idx += splice.ff_tokens.len();

        // send output to the user
        send_output(&constraint.flush_logs());
    }

    // flush any output
    send_output(&constraint.flush_logs());
    // the stop reason should be likely also sent to the user
    println!("Stop reason: {:?}", constraint.parser.stop_reason());
    println!("Max step stats: {:?}", constraint.parser.max_step_stats());
}
/// Reads the entire file at `filename` into a `String`.
///
/// # Panics
/// Panics with the filename and the underlying I/O error if the file cannot
/// be opened or read (appropriate for this sample binary, which treats any
/// missing/unreadable input as fatal).
fn read_file_to_string(filename: &str) -> String {
    // std::fs::read_to_string sizes the buffer from file metadata and
    // replaces the manual File::open + Read::read_to_string dance.
    std::fs::read_to_string(filename)
        .unwrap_or_else(|e| panic!("Unable to read file {}: {}", filename, e))
}
/// Delivers buffered parser logs to the "user".
///
/// In this sample the output is suppressed; flip `ECHO` to `true` to have
/// the logs printed to stdout instead.
fn send_output(user_output: &str) {
    const ECHO: bool = false;
    if ECHO {
        println!("{}", user_output);
    }
}