diff --git a/parser/src/ffi_par.rs b/parser/src/ffi_par.rs index 8482978..b871180 100644 --- a/parser/src/ffi_par.rs +++ b/parser/src/ffi_par.rs @@ -11,32 +11,35 @@ fn par_compute_mask_inner(constraints: Vec) { let cc = unsafe { &mut *step.constraint }; if let Some(constraint) = &mut cc.constraint { + let mut num_copied = 0; + let mut add_eos = false; + let eos = constraint.tok_trie().eos_token() as usize; match constraint.compute_mask() { Ok(r) => { - let mut num_copied = 0; if let Some(m) = r.sample_mask.as_ref() { num_copied = std::cmp::min(m.len(), mask_elts); unsafe { std::ptr::copy_nonoverlapping(m.as_ptr(), step.mask_dest, num_copied); } } - let left = mask_elts - num_copied; - if left > 0 { - unsafe { - std::ptr::write_bytes(step.mask_dest.add(num_copied), 0, left); - } - } - if r.is_stop() { - let eos = constraint.tok_trie().eos_token() as usize; - if eos / 32 < mask_elts { - unsafe { - *step.mask_dest.add(eos / 32) |= 1 << (eos % 32); - } - } - } + add_eos = r.is_stop(); } Err(e) => cc.set_error(&e.to_string()), } + + let left = mask_elts - num_copied; + if left > 0 { + unsafe { + std::ptr::write_bytes(step.mask_dest.add(num_copied), 0, left); + } + } + if add_eos { + if eos / 32 < mask_elts { + unsafe { + *step.mask_dest.add(eos / 32) |= 1 << (eos % 32); + } + } + } } }); }