Skip to content

Commit

Permalink
fix: use apply_chat_template to find turn boundaries and allow tool_c…
Browse files Browse the repository at this point in the history
…alling field (axolotl-ai-cloud#2179) [skip ci]

* fix: use apply_chat_template to find turn boundaries and allow tool_calling field

* fix: keys to include in turn

* feat(doc): explicitly recommend setting train_on_eos and roles_to_train

* fix: eos not being masked for tool due to template padding

* chore: clear up docs

* fix: default messages format, train_on_eos: turn, and train on all assistant msg

* fix: properly warn if empty content

* feat: parametrize chat_template tests to test different tokenizers

* fix: set proper default for message key

* fix: update defaults to match load function

* fix: change defaults to use new

* feat: add tool_calling dataset

* feat: add tool_calling test

* fix: add handling of edge case of mistral tokenizer with only system prompt

* feat: refactor all test to follow source code

* fix: remove unnecessary eos_token from phi35

* fix test for phi3.5 since eos was dropped from chat_template

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
NanoCode012 and winglian authored Dec 17, 2024
1 parent 339f3c6 commit 10cfecf
Show file tree
Hide file tree
Showing 7 changed files with 929 additions and 357 deletions.
26 changes: 16 additions & 10 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -127,34 +127,40 @@ datasets:
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).

# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".

# Key containing the messages (default: "messages")
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
# Key for role in each message (default: "role")
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
# Key for content in each message (default: "content")
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.

# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
assistant: ["gpt", "assistant"]
system: ["system"]
tool: ["tool"]

## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 4 fields are empty, defaults to training only on the last message.

# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
roles_to_train: ["assistant"] # default
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail


Expand Down
6 changes: 3 additions & 3 deletions docs/dataset-formats/conversation.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ We recommend checking the below examples for other usecases.
datasets:
- path: ...
type: chat_template
roles_to_train:
train_on_eos:
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
Expand All @@ -77,7 +79,7 @@ chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
roles_to_train: ["assistant"] # default value
```

3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
Expand All @@ -87,7 +89,6 @@ chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```

4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
Expand All @@ -99,7 +100,6 @@ chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
Expand Down
163 changes: 102 additions & 61 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def __init__(
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
message_field_role: str = "role",
message_field_content: str = "content",
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
Expand All @@ -41,6 +41,7 @@ def __init__(
"assistant": "assistant",
"gpt": "assistant",
"system": "system",
"tool": "tool",
}

self.message_field_role = message_field_role
Expand Down Expand Up @@ -188,7 +189,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""

_messages = "conversations"
_messages = "messages"

def __init__(
self,
Expand Down Expand Up @@ -279,12 +280,7 @@ def tokenize_prompt(self, prompt):

LOG.debug(f"Should train: {should_train}")

turn_start_idx, turn_end_idx = self.find_turn(
conversation_ids=input_ids, turn=index, turn_content=turn
)

if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)

LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

Expand Down Expand Up @@ -313,8 +309,8 @@ def tokenize_prompt(self, prompt):
LOG.debug(f"Labels after processing turn {index}: {labels}")

# Handle EOS token
eos_idx = self.find_eos_token(input_ids, turn_end_idx)
if eos_idx == turn_end_idx:
eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
Expand All @@ -339,75 +335,120 @@ def tokenize_prompt(self, prompt):
"attention_mask": [1] * len(input_ids),
}

def find_eos_token(self, input_ids, start_idx):
def find_first_eos_token(self, input_ids, start_idx):
eos_token_id = self.tokenizer.eos_token_id
for i in range(start_idx, len(input_ids)):
if input_ids[i] == eos_token_id:
return i
return -1

def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
def find_turn(self, turns: list[dict], turn_idx: int):
"""
Locate the starting and ending indices of the specified turn in a conversation.
"""
content = turn_content.get("content")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
# pylint: disable=too-many-return-statements

LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
if turn_idx >= len(turns):
raise ValueError(f"Turn index {turn_idx} out of range")

if not content_ids:
LOG.warning(f"Empty content for turn {turn}")
# mistral does not output message if it contains only system message
if (
turn_idx == 0
and turns[0].get("role") == "system"
and "mistral" in self.tokenizer.name_or_path.lower()
):
return -1, -1

# For first turn, start from beginning
if turn == 0:
start_search_idx = 0
else:
# For subsequent turns, find the previous EOS token
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0

for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break

# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1

if last_index < start_search_idx:
empty_turn = {
"role": turns[turn_idx].get("role"),
"content": "[[dummy_message]]",
}

# Create conversation versions
turns_with_empty = turns[:turn_idx] + [empty_turn]
turns_with_content = turns[: turn_idx + 1]

# Generate the conversation up to the turn, with final turn replaced with dummy content
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore

# Generate the conversation up to the turn, with final turn included
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore

if not full_ids or not dummy_ids:
LOG.warning(f"Empty template generated for turn {turn_idx}")
return -1, -1

# Find first difference (start of content)
start_idx = None
min_len = min(len(dummy_ids), len(full_ids))
for i in range(min_len):
if dummy_ids[i] != full_ids[i]:
start_idx = i
break

if start_idx is None:
LOG.warning(f"Could not find content start boundary for turn {turn_idx}")
return -1, -1

# Find last difference (end of content)
end_idx = None
for i in range(min_len):
dummy_pos = len(dummy_ids) - 1 - i
full_pos = len(full_ids) - 1 - i
if dummy_ids[dummy_pos] != full_ids[full_pos]:
end_idx = full_pos + 1 # Add one to include the last token when slice
break

if end_idx is None:
LOG.warning(f"Could not find content end boundary for turn {turn_idx}")
return -1, -1

if end_idx < start_idx:
LOG.warning(
f"Content end boundary is before start boundary for turn {turn_idx}"
)
return -1, -1

if end_idx == start_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
f"Content end boundary is the same as start boundary for turn {turn_idx}. This is likely an empty turn."
)
return -1, -1

# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)
LOG.debug(f"Content boundaries: {start_idx}, {end_idx}")
LOG.debug(
f"Content tokens: {self.tokenizer.convert_ids_to_tokens(full_ids[start_idx:end_idx])}"
)

return -1, -1
return start_idx, end_idx

def get_conversation_thread(self, prompt):
turns = [
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
turns = []
optional_keys = [
"tool_calls", # tool that 'assistant' calls
"name", # name of tool given by 'tool'
"tool_call_id", # mistral/mixtral requires this
]
for message in prompt[self.messages]:
turn = {
"role": self.prompter.roles[message[self.prompter.message_field_role]],
"training": message.get(self.prompter.message_field_training),
"training_detail": message.get(
self.prompter.message_field_training_detail
),
}

# do not add content if None as it may conflict with some templates due to tools
content = message.get(self.prompter.message_field_content, None)
if content is not None:
turn["content"] = content

for key in optional_keys:
value = message.get(key, None)
if value is not None:
turn[key] = value

turns.append(turn)

if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
Expand Down Expand Up @@ -446,8 +487,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}

strategy = ChatTemplateStrategy(
Expand Down
Loading

0 comments on commit 10cfecf

Please sign in to comment.