Multimodal support with Phi 3 Vision + Transformers #1020
base: main
Changes from 250 commits
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Any, Generator, Optional, Tuple, Union
+from typing import Any, Generator, Optional, Sequence, Tuple, Union

 import llguidance  # type: ignore[import-untyped]
 import numpy as np
@@ -30,29 +30,11 @@ class TokenParser:

     def __init__(
         self,
-        grammar: Union[GrammarFunction, str],
-        tokenizer: Tokenizer,
-        prompt: bytes = b"",
-        ensure_bos_token: bool = True,
+        ll_interpreter: llguidance.LLInterpreter,
+        prompt_tokens: list[int]
     ):
-        if isinstance(grammar, GrammarFunction):
-            # we can't have a terminal as the root
-            if isinstance(grammar, Terminal):
-                grammar = Join([grammar])
-            serialized_grammar = json.dumps(grammar.ll_serialize())
-        else:
-            serialized_grammar = grammar
-
-        self.tokenizer = tokenizer
-        self.ll_tokenizer = llguidance.LLTokenizer(
-            llguidance.TokenizerWrapper(tokenizer)
-        )
-        self.ll_interpreter = llguidance.LLInterpreter(
-            self.ll_tokenizer,
-            serialized_grammar,
-            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
-        )
-        self._generator = self._parse(prompt, ensure_bos_token)
+        self.ll_interpreter = ll_interpreter
+        self._generator = self._parse(prompt_tokens)
         self._done = False

     def is_accepting(self) -> bool:
@@ -70,28 +52,10 @@ def advance(
             self._done = True
             return None, e.value

-    def _process_prompt(self, prompt: bytes, ensure_bos_token: bool) -> list[int]:
-        prompt_tokens = self.ll_interpreter.process_prompt(
-            self.tokenizer.encode(prompt)
-        )
-        if (
-            ensure_bos_token
-            and self.tokenizer.bos_token is not None
-            and prompt_tokens[:1] != [self.tokenizer.bos_token_id]
-        ):
-            # add the beginning of sequence token if needed
-            prompt_tokens = [self.tokenizer.bos_token_id] + prompt_tokens
-
-        return self.tokenizer.recode(prompt_tokens)
-
     def _parse(
         self,
-        prompt: bytes,
-        ensure_bos_token: bool,
+        tokens: list[int],
     ) -> Generator[Tuple[Optional[GenData], EngineCallResponse], Optional[int], EngineCallResponse]:
-        tokens = self._process_prompt(prompt=prompt, ensure_bos_token=ensure_bos_token)
-
         while True:
             mask, resp = self.ll_interpreter.mid_process()
             r = LLInterpreterResponse.model_validate_json(resp)
@@ -133,6 +97,54 @@ def _parse(
         return response


+def process_prompt(prompt_tokens: Sequence[int], ll_interpreter: llguidance.LLInterpreter, bos_token_id: Optional[int] = None) -> list[int]:
+    # Allows ll_interpreter to make adjustments to prompt tokens, such as token healing
+    processed_tokens = ll_interpreter.process_prompt(prompt_tokens)
+    if (
+        bos_token_id is not None
+        and prompt_tokens[:1] != [bos_token_id]
+    ):
+        # add the beginning of sequence token if needed
+        processed_tokens = [bos_token_id] + processed_tokens
+
+    return processed_tokens
+
+
+def process_grammar(grammar: Union[GrammarFunction, str]) -> str:
Review comment: Can we give this a more informative name? E.g.
+    if isinstance(grammar, GrammarFunction):
+        # we can't have a terminal as the root
+        if isinstance(grammar, Terminal):
+            grammar = Join([grammar])
+        return json.dumps(grammar.ll_serialize())
+    else:
+        return grammar
+
+
+def create_token_parser(
+    grammar: Union[GrammarFunction, str],
+    tokenizer: Tokenizer,
+    prompt: bytes = b"",
+    ensure_bos_token: bool = True,
+    trace: bool = False
+) -> TokenParser:
+    serialized_grammar = process_grammar(grammar)
+    ll_tokenizer = llguidance.LLTokenizer(
+        llguidance.TokenizerWrapper(tokenizer)
+    )
+    ll_interpreter = llguidance.LLInterpreter(
+        ll_tokenizer,
+        serialized_grammar,
+        log_level=2 if trace else int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+    )
+    if ensure_bos_token and tokenizer.bos_token_id is not None:
+        bos_token_id = tokenizer.bos_token_id
+    else:
+        bos_token_id = None
Review comment: Tiny nitpick -- you don't need to check the

Review comment: Do you think we should throw an error in the case where ensure_bos_token is True and tokenizer.bos_token_id is None?

Review comment (reply): Mmm good question. To my knowledge, we're never calling this with ensure_bos_token set to False... In other words, I'm not sure the kwarg is really providing any value. I don't think an exception is really necessary in this case. What do you think?

Review comment (reply): I'm just going to log a warning in that case for now.
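A minimal sketch of what that warning could look like inside create_token_parser; the logging setup, warning message, and exact placement are assumptions for illustration, not part of this diff:

import logging

logger = logging.getLogger(__name__)

# Hypothetical bos_token_id handling: warn rather than raise when a BOS token
# is requested but the tokenizer does not define one.
if ensure_bos_token and tokenizer.bos_token_id is None:
    logger.warning("Tokenizer has no bos_token_id; ensure_bos_token will have no effect.")
bos_token_id = tokenizer.bos_token_id if ensure_bos_token else None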
+    prompt_tokens = tokenizer.encode(prompt)
+    processed_tokens = process_prompt(prompt_tokens, ll_interpreter, bos_token_id)
+    return TokenParser(ll_interpreter, processed_tokens)
+
+
 class ByteParserException(Exception):
     def __init__(self, *args, **kwargs):
         self.current_byte = kwargs.pop("current_byte", None)
@@ -149,7 +161,7 @@ def __init__(
         ensure_bos_token: bool = True,
     ):
         self.tokenizer = ByteTokenizer()
-        self.token_parser = TokenParser(grammar, self.tokenizer, prompt, ensure_bos_token)
+        self.token_parser = create_token_parser(grammar, self.tokenizer, prompt, ensure_bos_token)
         self.bytes = b""
         self.gen_data: Optional[GenData] = None
         self.pos = 0
@@ -289,3 +301,4 @@ def _update_capture(self, response: EngineCallResponse):
                 pass
             self._variables[k] = v
             self._variables_log_probs[k] = response.capture_group_log_probs[k]
+
Review comment: These tokens will likely need to be recoded before being sent to the LLM for logits (I think only in the case that we just added a BOS token). You could probably just throw that line in create_token_parser after calling process_prompt..? Just since you'll have access to a tokenizer there. See tests/model_integration/test_model.py::test_associativity.
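A rough sketch of that suggestion applied inside create_token_parser, assuming tokenizer.recode is the same Tokenizer method the removed _process_prompt used; this is an illustration of the reviewer's idea, not code from the commits above:

    prompt_tokens = tokenizer.encode(prompt)
    processed_tokens = process_prompt(prompt_tokens, ll_interpreter, bos_token_id)
    # Recode so the token ids stay canonical for the model, e.g. after
    # process_prompt has prepended a BOS token.
    processed_tokens = tokenizer.recode(processed_tokens)
    return TokenParser(ll_interpreter, processed_tokens)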