diff --git a/Cargo.lock b/Cargo.lock index 725a550..5153db9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,7 +138,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -149,7 +149,7 @@ checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -257,7 +257,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.77", + "syn", ] [[package]] @@ -332,13 +332,13 @@ checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" [[package]] name = "bytemuck_derive" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" +checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -367,7 +367,7 @@ dependencies = [ "quote", "serde", "serde_json", - "syn 2.0.77", + "syn", "tempfile", "toml", ] @@ -440,7 +440,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim 0.11.1", + "strsim", ] [[package]] @@ -452,7 +452,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -551,9 +551,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.14.4" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" dependencies = [ "darling_core", "darling_macro", @@ -561,58 +561,58 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.14.4" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", - "strsim 0.10.0", - "syn 1.0.109", + "strsim", + "syn", ] [[package]] name = "darling_macro" -version = "0.14.4" +version = "0.20.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = "derive_builder" -version = "0.12.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" dependencies = [ "derive_builder_macro", ] [[package]] name = "derive_builder_core" -version = "0.12.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ "darling", "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = "derive_builder_macro" -version = "0.12.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 1.0.109", + "syn", ] [[package]] @@ -669,7 +669,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1147,7 +1147,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1543,7 +1543,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1727,7 +1727,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1826,7 +1826,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1857,7 +1857,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1900,7 +1900,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" dependencies = [ "proc-macro2", - "syn 2.0.77", + "syn", ] [[package]] @@ -1959,7 +1959,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -1972,7 +1972,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -2082,7 +2082,7 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -2288,7 +2288,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -2401,12 +2401,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - [[package]] name = "strsim" version = "0.11.1" @@ -2419,17 +2413,6 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.77" @@ -2461,7 +2444,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -2500,7 +2483,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -2530,12 +2513,11 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" -version = "0.15.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d" +checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" dependencies = [ "aho-corasick", - "clap", "derive_builder", "esaxx-rs", "getrandom", @@ -2588,7 +2570,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -2909,7 +2891,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.77", + "syn", "wasm-bindgen-shared", ] @@ -2931,7 +2913,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3149,7 +3131,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", "synstructure", ] @@ -3171,7 +3153,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] [[package]] @@ -3191,7 +3173,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", "synstructure", ] @@ -3220,5 +3202,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.77", + "syn", ] diff --git a/llgtrt/src/async_exec.rs b/llgtrt/src/async_exec.rs index 4fb57f3..fa7f75f 100644 --- a/llgtrt/src/async_exec.rs +++ b/llgtrt/src/async_exec.rs @@ -110,7 +110,9 @@ impl PendingSeq { log::trace!("Tokens: {}", llg.tok_trie().tokens_dbg(tokens)); - if tokens.len() > self.prompt_len { + let step_res = if tokens.len() > self.prompt_len { + // if we're past the prompt, commit last token + // and compute mask let tok = *tokens.last().unwrap(); let r = llg.commit_token(Some(tok))?; @@ -122,16 +124,25 @@ impl PendingSeq { assert!(r.ff_tokens.len() == 1); assert!(r.ff_tokens[0] == tok); - } - - let res = llg.compute_mask()?; + llg.compute_mask()? + } else { + // if we're still in prompt + if !llg.has_current_step_result() { + // first time, compute the mask + llg.compute_mask()? + } else { + // if trtllm wants to call us multiple times for the prompt + // (this happens due to chunked prefill), we re-use the first mask + llg.step_result() + } + }; - if res.is_stop() { + if step_res.is_stop() { self.stop = true; return Ok(()); } - let mask = res.sample_mask.as_ref().expect("No mask"); + let mask = step_res.sample_mask.as_ref().expect("No mask"); self.entry.out_mask_pointer = copy_mask(mask); self.entry.temperature = llg.temperature; @@ -178,6 +189,12 @@ extern "C" fn logits_processor(logits: *mut TlcLogitsEntry, num_logits: u32) { let mut pending_assignments = vec![]; for (idx, entry) in entries.iter().enumerate() { if let Some(rd) = exec.req_data.get_mut(&entry.client_req_id()) { + log::debug!( + "llg: {}: {} tokens ({} prompt tokens)", + entry.req_id(), + entry._num_tokens, + rd.prompt_len + ); if let Some(llg_idx) = rd .llg_infos .iter() diff --git a/llgtrt/src/routes/completions.rs b/llgtrt/src/routes/completions.rs index 80634ba..228a9b4 100644 --- a/llgtrt/src/routes/completions.rs +++ b/llgtrt/src/routes/completions.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, bail, ensure, Result}; use async_stream::try_stream; use axum::extract::State; use axum::http::HeaderMap; -use axum::response::sse::{Event, KeepAlive, Sse}; +use axum::response::sse::{Event, Sse}; use axum::response::{IntoResponse, Response}; use axum::Json; use futures_core::Stream; diff --git a/llguidance b/llguidance index 7663176..1108f6e 160000 --- a/llguidance +++ b/llguidance @@ -1 +1 @@ -Subproject commit 7663176a55490790521e4e4892b59f8f0a2680b1 +Subproject commit 1108f6edddf96ae143036cc4609da0560249e41c diff --git a/scripts/req.py b/scripts/req.py index da8c61d..9d8bd15 100644 --- a/scripts/req.py +++ b/scripts/req.py @@ -7,10 +7,12 @@ import time import argparse -PROMPT_SIZE = 50_000 +PROMPT_SIZE = 1_00 NUM_THREADS = 10 NUM_REPS = 3 LLG = False +MAX_TOKENS = 50 +NUM_JOKES = 10 TRT_API_BASE = os.getenv("TRT_API_BASE") if TRT_API_BASE is None or TRT_API_BASE == "": @@ -74,30 +76,35 @@ def llg_data(): "grammar": json.loads(grammar), }, "messages": joke_msg(), - "max_tokens": 50, + "max_tokens": MAX_TOKENS, "temperature": 0.8, } def req_data(): + properties = {} + required = [] + for idx in range(NUM_JOKES): + properties[f"joke_{idx}"] = {"type": "string"} + properties[f"rating_{idx}"] = {"type": "number"} + required.extend([f"joke_{idx}", f"rating_{idx}"]) return { "model": "model", "messages": joke_msg(), ("response_format" if LLG else "ignore_me"): { "type": "json_schema", - "strict": True, - "schema": { - "type": "object", - "properties": { - "joke": {"type": "string"}, - "rating": {"type": "number"}, + "json_schema": { + "strict": True, + "schema": { + "type": "object", + "properties": properties, + "additionalProperties": False, + "required": required, }, - "additionalProperties": False, - "required": ["joke", "rating"], }, }, # "llg_log_level": "json", - "max_tokens": 50, + "max_tokens": MAX_TOKENS, "temperature": 0.8, } @@ -116,6 +123,7 @@ def finalize(self): self.completion_tokens = self.usage.get("completion_tokens", 0) self.completion_tokens2 = len(self.tbt) self.text = "".join(self.text_chunks) + print(self.text) if not self.tbt: self.avg_tbt = 0 self.med_tbt = 0 @@ -174,7 +182,9 @@ def send_one_stream(data: dict) -> list[Results]: if data["object"] == "initial-run": continue - idx: int = data["choices"][0]["index"] if not is_run else data["forks"][0]["index"] + idx: int = ( + data["choices"][0]["index"] if not is_run else data["forks"][0]["index"] + ) res = results[idx] now = time.monotonic() @@ -240,12 +250,23 @@ def main(): random.seed(0) parser = argparse.ArgumentParser() parser.add_argument("--max_threads", type=int, default=0) + parser.add_argument("--sessions", type=int, default=0) args = parser.parse_args() - if args.max_threads > 0: - global LLG - global NUM_THREADS + global LLG, NUM_REPS, NUM_THREADS, MAX_TOKENS, NUM_JOKES, PROMPT_SIZE + + if args.sessions > 0: + LLG = True + NUM_THREADS = args.sessions + PROMPT_SIZE = 2600 + NUM_REPS = 1 + NUM_JOKES = 100 + MAX_TOKENS = 4000 + one_round() + return + + if args.max_threads > 0: thr = 1 def csv_line(lst): @@ -280,7 +301,7 @@ def csv_line(lst): return - #d = llg_data() + # d = llg_data() d = req_data() d["n"] = 1 d["temperature"] = 1.0 diff --git a/toktrie b/toktrie index 972825d..6172936 160000 --- a/toktrie +++ b/toktrie @@ -1 +1 @@ -Subproject commit 972825d6c2090de141e948154a48ed31816c3217 +Subproject commit 6172936f8c965d2050a53d14de0e3410ecc78ad1