Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Nov 26, 2024
2 parents 6eaa15f + ce3c760 commit a9f9f7c
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 19 deletions.
40 changes: 36 additions & 4 deletions parser/src/grammar_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ pub struct GrammarBuilder {
placeholder: Node,
strings: HashMap<String, NodeRef>,
curr_grammar_id: u32,
node_refs: HashMap<String, NodeRef>,
nodes: Vec<Node>,
pub regex: RegexBuilder,
at_most_cache: HashMap<(NodeRef, usize), NodeRef>,
repeat_exact_cache: HashMap<(NodeRef, usize), NodeRef>,
}

pub struct RegexBuilder {
Expand Down Expand Up @@ -136,8 +139,11 @@ impl GrammarBuilder {
},
strings: HashMap::new(),
curr_grammar_id: 0,
node_refs: HashMap::new(),
nodes: vec![],
regex: RegexBuilder::new(),
at_most_cache: HashMap::new(),
repeat_exact_cache: HashMap::new(),
}
}

Expand Down Expand Up @@ -174,11 +180,27 @@ impl GrammarBuilder {
}

/// Adds `node` to the builder and returns a reference to it.
///
/// Nodes are deduplicated by their JSON serialization: if a node with the
/// same serialized form was added before, the previously returned
/// `NodeRef` is handed back and nothing is pushed. The placeholder node is
/// never deduplicated, and a node whose serialization fails is simply
/// appended without being recorded for later reuse.
pub fn add_node(&mut self, node: Node) -> NodeRef {
    // Only non-placeholder nodes get a dedup key; serialization is skipped
    // entirely for the placeholder.
    let dedup_key = if node != self.placeholder {
        serde_json::to_string(&node).ok()
    } else {
        None
    };

    // Reuse an existing reference when an identical node is already known.
    // NOTE(review): the key does not include `curr_grammar_id`; confirm that
    // `node_refs` is reset whenever the current grammar changes.
    if let Some(existing) = dedup_key.as_ref().and_then(|k| self.node_refs.get(k)) {
        return *existing;
    }

    // The new node will live at the current end of `nodes`.
    let fresh = NodeRef {
        idx: self.nodes.len(),
        grammar_id: self.curr_grammar_id,
    };
    self.nodes.push(node);

    // Remember the reference so later identical nodes can reuse it.
    if let Some(k) = dedup_key {
        self.node_refs.insert(k, fresh);
    }
    fresh
}

Expand Down Expand Up @@ -321,7 +343,10 @@ impl GrammarBuilder {
// at_most() recursively factors the sequence into K-size pieces,
// in an attempt to keep grammar size O(log(n)).
fn at_most(&mut self, elt: NodeRef, n: usize) -> NodeRef {
if n == 0 {
if let Some(r) = self.at_most_cache.get(&(elt, n)) {
return *r;
}
let r = if n == 0 {
// If the max ('n') is 0, an empty rule
self.empty()
} else if n == 1 {
Expand Down Expand Up @@ -378,7 +403,9 @@ impl GrammarBuilder {
// (inclusive) in 'elt_n'. Clearly, the sequences of length at most 'n'
// are the alternation of 'elt_max_nk' and 'elt_n'.
self.select(&[elt_n, elt_max_nk])
}
};
self.at_most_cache.insert((elt, n), r);
r
}

// simple_repeat() "simply" repeats the element ('elt') 'n' times.
Expand All @@ -393,7 +420,10 @@ impl GrammarBuilder {
// Repeat element 'elt' exactly 'n' times, using factoring
// in an attempt to keep grammar size O(log(n)).
fn repeat_exact(&mut self, elt: NodeRef, n: usize) -> NodeRef {
if n > 2 * K {
if let Some(r) = self.repeat_exact_cache.get(&(elt, n)) {
return *r;
}
let r = if n > 2 * K {
// For large 'n', try to keep the number of rules O(log(n))
// by "factoring" the sequence into K-sized pieces

Expand All @@ -418,7 +448,9 @@ impl GrammarBuilder {
// For small 'n' (currently, 8 or less), simply
// repeat 'elt' 'n' times.
self.simple_repeat(elt, n)
}
};
self.repeat_exact_cache.insert((elt, n), r);
r
}

// at_least() accepts a sequence of at least 'n' copies of
Expand Down
16 changes: 13 additions & 3 deletions parser/src/json/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct JsonCompileOptions {
pub item_separator: String,
pub key_separator: String,
pub whitespace_flexible: bool,
pub coerce_one_of: bool,
}

fn json_dumps(target: &serde_json::Value) -> String {
Expand Down Expand Up @@ -98,6 +99,7 @@ impl Default for JsonCompileOptions {
item_separator: ",".to_string(),
key_separator: ":".to_string(),
whitespace_flexible: true,
coerce_one_of: false,
}
}
}
Expand Down Expand Up @@ -244,11 +246,19 @@ impl Compiler {
Ok(self.builder.string(if *value { "true" } else { "false" }))
}
Schema::AnyOf { options } => self.process_any_of(options.clone()),
Schema::OneOf { options } => self.process_any_of(options.clone()),
Schema::OneOf { options } => self.process_one_of(options.clone()),
Schema::Ref { uri, .. } => self.get_definition(uri),
}
}

/// Compiles a `oneOf` schema.
///
/// `oneOf` is only handled when the `coerce_one_of` option is enabled, in
/// which case the options are delegated to `process_any_of`. Otherwise an
/// error is returned telling the caller how to enable the approximation.
fn process_one_of(&mut self, options: Vec<Schema>) -> Result<NodeRef> {
    // Guard: refuse oneOf unless the caller explicitly opted in.
    if !self.options.coerce_one_of {
        return Err(anyhow!("oneOf constraints are not supported. Enable 'coerce_one_of' option to approximate oneOf with anyOf"));
    }
    self.process_any_of(options)
}

fn process_any_of(&mut self, options: Vec<Schema>) -> Result<NodeRef> {
let mut nodes = vec![];
let mut errors = vec![];
Expand All @@ -265,12 +275,12 @@ impl Compiler {
Ok(self.builder.select(&nodes))
} else if let Some(e) = errors.pop() {
Err(anyhow!(UnsatisfiableSchemaError {
message: format!("All options in anyOf/oneOf are unsatisfiable",),
message: format!("All options in anyOf are unsatisfiable",),
})
.context(e))
} else {
Err(anyhow!(UnsatisfiableSchemaError {
message: "No options in anyOf/oneOf".to_string(),
message: "No options in anyOf".to_string(),
}))
}
}
Expand Down
24 changes: 14 additions & 10 deletions parser/src/json/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ fn compile_contents_map(ctx: &Context, mut schemadict: HashMap<&str, &Value>) ->
.iter()
.map(|value| compile_resource(&ctx, ctx.as_resource_ref(value)))
.collect::<Result<Vec<_>>>()?;
let merged = intersect(ctx, options.into_iter().chain(vec![siblings]).collect())?;
let merged = intersect(ctx, vec![siblings].into_iter().chain(options.into_iter()).collect())?;
return Ok(merged);
}

Expand All @@ -439,7 +439,7 @@ fn compile_contents_map(ctx: &Context, mut schemadict: HashMap<&str, &Value>) ->
let options = any_of
.into_iter()
.map(|value| compile_resource(&ctx, ctx.as_resource_ref(value)))
.map(|res| res.and_then(|schema| intersect_two(ctx, schema, siblings.clone())))
.map(|res| res.and_then(|schema| intersect_two(ctx, siblings.clone(), schema)))
.collect::<Result<Vec<_>>>()?;
return Ok(Schema::AnyOf { options });
}
Expand All @@ -457,7 +457,7 @@ fn compile_contents_map(ctx: &Context, mut schemadict: HashMap<&str, &Value>) ->
let options = one_of
.into_iter()
.map(|value| compile_resource(&ctx, ctx.as_resource_ref(value)))
.map(|res| res.and_then(|schema| intersect_two(ctx, schema, siblings.clone())))
.map(|res| res.and_then(|schema| intersect_two(ctx, siblings.clone(), schema)))
.collect::<Result<Vec<_>>>()?;
return Ok(Schema::OneOf { options });
}
Expand All @@ -474,7 +474,7 @@ fn compile_contents_map(ctx: &Context, mut schemadict: HashMap<&str, &Value>) ->
define_ref(ctx, &uri)?;
return Ok(Schema::Ref { uri });
} else {
return intersect_ref(ctx, &uri, siblings);
return intersect_ref(ctx, &uri, siblings, false);
}
}

Expand Down Expand Up @@ -511,7 +511,7 @@ fn define_ref(ctx: &Context, ref_uri: &str) -> Result<()> {
Ok(())
}

fn intersect_ref(ctx: &Context, ref_uri: &str, schema: Schema) -> Result<Schema> {
fn intersect_ref(ctx: &Context, ref_uri: &str, schema: Schema, ref_first: bool) -> Result<Schema> {
define_ref(ctx, ref_uri)?;
let resolved_schema = ctx
.get_ref_cloned(ref_uri)
Expand All @@ -525,7 +525,11 @@ fn intersect_ref(ctx: &Context, ref_uri: &str, schema: Schema) -> Result<Schema>
ref_uri
)
})?;
intersect_two(ctx, schema, resolved_schema)
if ref_first {
intersect_two(ctx, resolved_schema, schema)
} else {
intersect_two(ctx, schema, resolved_schema)
}
}

fn compile_const(instance: &Value) -> Result<Schema> {
Expand Down Expand Up @@ -833,8 +837,8 @@ fn intersect_two(ctx: &Context, schema0: Schema, schema1: Schema) -> Result<Sche
(schema0, Schema::Any) => schema0,
(Schema::Unsatisfiable { reason }, _) => Schema::Unsatisfiable { reason },
(_, Schema::Unsatisfiable { reason }) => Schema::Unsatisfiable { reason },
(Schema::Ref { uri }, schema1) => intersect_ref(ctx, &uri, schema1)?,
(schema0, Schema::Ref { uri }) => intersect_ref(ctx, &uri, schema0)?,
(Schema::Ref { uri }, schema1) => intersect_ref(ctx, &uri, schema1, true)?,
(schema0, Schema::Ref { uri }) => intersect_ref(ctx, &uri, schema0, false)?,
(Schema::OneOf { options }, schema1) => Schema::OneOf {
options: options
.into_iter()
Expand Down Expand Up @@ -953,8 +957,8 @@ fn intersect_two(ctx: &Context, schema0: Schema, schema1: Schema) -> Result<Sche
max_items: opt_min(max1, max2),
prefix_items: {
let len = prefix1.len().max(prefix2.len());
prefix1.resize(len, items2.as_deref().cloned().unwrap_or(Schema::Any));
prefix2.resize(len, items1.as_deref().cloned().unwrap_or(Schema::Any));
prefix1.resize_with(len, || items1.as_deref().cloned().unwrap_or(Schema::Any));
prefix2.resize_with(len, || items2.as_deref().cloned().unwrap_or(Schema::Any));
prefix1
.into_iter()
.zip(prefix2.into_iter())
Expand Down
1 change: 1 addition & 0 deletions python/llguidance/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class JsonCompiler:
cls,
separators: Optional[Tuple[str, str]] = None,
whitespace_flexible: bool = False,
coerce_one_of: bool = False,
) -> "JsonCompiler":
"""
Create a new JSON compiler.
Expand Down
7 changes: 5 additions & 2 deletions rust/src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,14 @@ struct JsonCompiler {
item_separator: String,
key_separator: String,
whitespace_flexible: bool,
coerce_one_of: bool,
}

#[pymethods]
impl JsonCompiler {
#[new]
#[pyo3(signature = (separators = None, whitespace_flexible = false))]
fn py_new(separators: Option<(String, String)>, whitespace_flexible: bool) -> Self {
#[pyo3(signature = (separators = None, whitespace_flexible = false, coerce_one_of = false))]
fn py_new(separators: Option<(String, String)>, whitespace_flexible: bool, coerce_one_of: bool) -> Self {
let (item_separator, key_separator) = separators.unwrap_or_else(|| {
if whitespace_flexible {
(",".to_owned(), ":".to_owned())
Expand All @@ -263,6 +264,7 @@ impl JsonCompiler {
item_separator: item_separator,
key_separator: key_separator,
whitespace_flexible,
coerce_one_of,
}
}
fn compile(&self, schema: &str) -> PyResult<String> {
Expand All @@ -271,6 +273,7 @@ impl JsonCompiler {
item_separator: self.item_separator.clone(),
key_separator: self.key_separator.clone(),
whitespace_flexible: self.whitespace_flexible,
coerce_one_of: self.coerce_one_of,
};
let grammar = compile_options.json_to_llg(schema).map_err(val_error)?;
serde_json::to_string(&grammar).map_err(val_error)
Expand Down

0 comments on commit a9f9f7c

Please sign in to comment.