
Commit

Merge pull request #704 from guidance-ai/fix-unicode-transformers
Fix #681 and #682
slundberg authored Mar 18, 2024
2 parents 7873d33 + cd250cd commit e2f517a
Showing 2 changed files with 39 additions and 18 deletions.
55 changes: 38 additions & 17 deletions guidance/models/transformers/_transformers.py
@@ -17,23 +17,31 @@ def __init__(self, model, tokenizer, ignore_bos_token=False):

         # build out the set of byte_string tokens
         if hasattr(tokenizer, "byte_decoder"):
-            byte_tokens = []
-            for i in range(len(tokenizer)):
-                byte_coded = bytes([tokenizer.byte_decoder[c] for c in tokenizer.convert_ids_to_tokens(i)])
-                byte_tokens.append(byte_coded)
-        elif hasattr(tokenizer, "convert_tokens_to_string"):
-            byte_tokens = []
-            for i in range(len(tokenizer)):
-                s = tokenizer.convert_tokens_to_string(['a', tokenizer.convert_ids_to_tokens(i)])
-                if s[0] == 'a':
-                    s = s[1:]
-                elif s[1] == 'a':
-                    s = s[2:]
-                else:
-                    raise Exception("Can't determine tokenstring representation!")
-                byte_tokens.append(bytes(s, encoding="utf8"))
+            byte_decoder = tokenizer.byte_decoder
         else:
-            raise Exception("Invalid tokenizer object!")
+            import transformers
+            byte_decoder = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False).byte_decoder # fall back to gpt2 mapping
+
+            # some special tokens may not have their whitespace encoded...
+            byte_decoder[' '] = 32
+            byte_decoder['\n'] = 10
+            byte_decoder['\r'] = 13
+            byte_decoder['\t'] = 9
+
+            # run a quick spot check to verify we can rebuild complex multi-token unicode symbols
+            s = "’•¶∂ƒ˙∆£Ħ爨ൠᅘ∰፨"
+            t = tokenizer
+            reconstructed = b''
+            for id in t(s)["input_ids"]:
+                reconstructed += bytes([byte_decoder[c] for c in t.convert_ids_to_tokens(id)])
+            assert reconstructed.decode() == s, "The passed tokenizer does not have a byte_decoder property and using a standard gpt2 byte_decoder fails!"
+
+        byte_tokens = []
+        for i in range(len(tokenizer)):
+            byte_coded = bytes([byte_decoder[c] for c in tokenizer.convert_ids_to_tokens(i)])
+            byte_tokens.append(byte_coded)
 
         # the superclass does most of the work once we have the tokens
         super().__init__(
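Note (not part of the diff): the fallback path above can be exercised on its own by loading GPT-2's slow tokenizer, borrowing its byte_decoder table, and checking that the per-token byte strings rebuild a multi-byte unicode input, just as the spot check does. The test string below is illustrative.

import transformers

# the slow GPT-2 tokenizer exposes its byte-to-unicode table as `byte_decoder`
tok = transformers.AutoTokenizer.from_pretrained("gpt2", use_fast=False)
byte_decoder = tok.byte_decoder  # maps each mapped character back to its raw byte value

s = "’•¶∂ƒ˙∆£"  # illustrative multi-byte characters
reconstructed = b''
for token_id in tok(s)["input_ids"]:
    token_str = tok.convert_ids_to_tokens(token_id)
    reconstructed += bytes([byte_decoder[c] for c in token_str])
assert reconstructed.decode("utf8") == s  # the byte strings round-trip the original text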
@@ -108,8 +116,21 @@ def _model(self, model, **kwargs):

     def _joint_tokenize(self, token_ids):
         # first_decode = self.tokenizer._orig_tokenizer.decode(token_ids)
-        first_decode = b''.join([self.tokenizer.tokens[id] for id in token_ids]).decode("utf8")
+
+        # the encode/decode cycle might not work if we have partial unicode strings
+        used_tokens = len(token_ids)
+        for _ in range(3):
+            try:
+                first_decode = b''.join([self.tokenizer.tokens[id] for id in token_ids[:used_tokens]]).decode("utf8")
+            except UnicodeDecodeError:
+                if used_tokens == 0:
+                    break
+                else:
+                    used_tokens -= 1
+
         new_ids = self.tokenizer._orig_tokenizer(first_decode, add_special_tokens=False)["input_ids"]
+        if used_tokens < len(token_ids):
+            new_ids += token_ids[used_tokens:]
 
         # HACK: check for a bug in the HuggingFace tokenizer (that will just add extra spaces during an encode-decode cycle)
         second_decode = self.tokenizer._orig_tokenizer.decode(new_ids)
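Side note on the _joint_tokenize change above: the retry loop drops trailing token ids until the joined byte string decodes as UTF-8, re-encodes the decoded prefix, and re-appends the dropped ids unchanged. A minimal standalone sketch of that idea, with hypothetical token_bytes (token id -> byte string) and encode (text -> token ids) as stand-ins for the tokenizer internals:

def joint_retokenize(token_ids, token_bytes, encode):
    # walk back from the end until the byte prefix decodes cleanly
    used = len(token_ids)
    text = ""
    while used > 0:
        try:
            text = b''.join(token_bytes[i] for i in token_ids[:used]).decode("utf8")
            break
        except UnicodeDecodeError:
            used -= 1  # drop one trailing token and try again

    new_ids = encode(text)
    if used < len(token_ids):
        new_ids += token_ids[used:]  # keep the partial-unicode tail ids as they were
    return new_ids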
2 changes: 1 addition & 1 deletion guidance/models/vertexai/_vertexai.py
@@ -50,7 +50,7 @@ def __init__(self, model, tokenizer=None, echo=True, max_streaming_tokens=None,
             found_subclass = vertexai_subclasses.PaLM2Chat
 
         # Gemini2Chat
-        elif re.match("gemini-pro(@[0-9]+)?", model_name):
+        elif re.match("gemini-(pro|ultra)(@[0-9]+)?", model_name):
             found_subclass = vertexai_subclasses.GeminiChat
 
         # convert to any found subclass
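The widened pattern also routes gemini-ultra names (with or without an @version suffix) to GeminiChat. A quick illustrative check:

import re

pattern = "gemini-(pro|ultra)(@[0-9]+)?"
assert re.match(pattern, "gemini-pro")
assert re.match(pattern, "gemini-ultra")
assert re.match(pattern, "gemini-pro@001")
assert re.match(pattern, "text-bison") is None  # non-Gemini names still fall through to the other checks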

