Skip to content

Commit

Permalink
return eos mask on error
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Nov 12, 2024
1 parent ae0c43f commit 84e5f41
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions parser/src/ffi_par.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,35 @@ fn par_compute_mask_inner(constraints: Vec<LlgConstraintStep>) {

let cc = unsafe { &mut *step.constraint };
if let Some(constraint) = &mut cc.constraint {
let mut num_copied = 0;
let mut add_eos = false;
let eos = constraint.tok_trie().eos_token() as usize;
match constraint.compute_mask() {
Ok(r) => {
let mut num_copied = 0;
if let Some(m) = r.sample_mask.as_ref() {
num_copied = std::cmp::min(m.len(), mask_elts);
unsafe {
std::ptr::copy_nonoverlapping(m.as_ptr(), step.mask_dest, num_copied);
}
}
let left = mask_elts - num_copied;
if left > 0 {
unsafe {
std::ptr::write_bytes(step.mask_dest.add(num_copied), 0, left);
}
}
if r.is_stop() {
let eos = constraint.tok_trie().eos_token() as usize;
if eos / 32 < mask_elts {
unsafe {
*step.mask_dest.add(eos / 32) |= 1 << (eos % 32);
}
}
}
add_eos = r.is_stop();
}
Err(e) => cc.set_error(&e.to_string()),
}

let left = mask_elts - num_copied;
if left > 0 {
unsafe {
std::ptr::write_bytes(step.mask_dest.add(num_copied), 0, left);
}
}
if add_eos {
if eos / 32 < mask_elts {
unsafe {
*step.mask_dest.add(eos / 32) |= 1 << (eos % 32);
}
}
}
}
});
}
Expand Down

0 comments on commit 84e5f41

Please sign in to comment.