feat: use token limits in core memory instead of character limits #2081

Open · wants to merge 4 commits into base: main
8 changes: 4 additions & 4 deletions letta/config.py
@@ -8,8 +8,8 @@
 import letta
 import letta.utils as utils
 from letta.constants import (
-    CORE_MEMORY_HUMAN_CHAR_LIMIT,
-    CORE_MEMORY_PERSONA_CHAR_LIMIT,
+    CORE_MEMORY_HUMAN_TOKEN_LIMIT,
+    CORE_MEMORY_PERSONA_TOKEN_LIMIT,
     DEFAULT_HUMAN,
     DEFAULT_PERSONA,
     DEFAULT_PRESET,
@@ -88,8 +88,8 @@ class LettaConfig:
     policies_accepted: bool = False
 
     # Default memory limits
-    core_memory_persona_char_limit: int = CORE_MEMORY_PERSONA_CHAR_LIMIT
-    core_memory_human_char_limit: int = CORE_MEMORY_HUMAN_CHAR_LIMIT
+    core_memory_persona_token_limit: int = CORE_MEMORY_PERSONA_TOKEN_LIMIT
+    core_memory_human_token_limit: int = CORE_MEMORY_HUMAN_TOKEN_LIMIT
 
     def __post_init__(self):
         # ensure types
15 changes: 7 additions & 8 deletions letta/constants.py
@@ -133,12 +133,8 @@
 # These serve as in-context examples of how to use functions / what user messages look like
 MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST = 3
 
-# Default memory limits
-CORE_MEMORY_PERSONA_CHAR_LIMIT = 2000
-CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000
-
 # Function return limits
-FUNCTION_RETURN_CHAR_LIMIT = 6000  # ~300 words
+FUNCTION_RETURN_TOKEN_LIMIT = 1500  # ~1100 words
 
 MAX_PAUSE_HEARTBEATS = 360  # in min
 
@@ -155,9 +151,12 @@
 
 RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
 
-# TODO Is this config or constant?
-CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000
-CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000
+# Default memory limits
+CORE_MEMORY_PERSONA_TOKEN_LIMIT: int = 2000
+CORE_MEMORY_HUMAN_TOKEN_LIMIT: int = 2000
 
 MAX_FILENAME_LENGTH = 255
 RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
+
+# Default tokenizer model to use with tiktoken
+DEFAULT_TIKTOKEN_MODEL = "gpt-4"
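A note on what the renamed defaults actually measure: a token averages roughly four characters of English, so a 2000-token budget holds about four times as much text as the old 2000-character one. A quick sketch of the measurement, using the same tiktoken calls as `count_tokens` in `letta/utils.py` (the sample string is made up):

```python
import tiktoken

# Same lookup that count_tokens() in letta/utils.py performs.
encoding = tiktoken.encoding_for_model("gpt-4")

sample = "Core memory stores persistent facts about the user and the agent's persona."
print(len(sample))                   # character count: what the old limit measured
print(len(encoding.encode(sample)))  # token count: what the new limit measures
```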
4 changes: 4 additions & 0 deletions letta/orm/block.py
@@ -3,6 +3,7 @@
 from sqlalchemy import JSON, BigInteger, Integer
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
+from letta.constants import DEFAULT_TIKTOKEN_MODEL
 from letta.orm.mixins import OrganizationMixin
 from letta.orm.sqlalchemy_base import SqlalchemyBase
 from letta.schemas.block import Block as PydanticBlock
@@ -28,6 +29,9 @@ class Block(OrganizationMixin, SqlalchemyBase):
     )
     value: Mapped[str] = mapped_column(doc="Text content of the block for the respective section of core memory.")
     limit: Mapped[BigInteger] = mapped_column(Integer, default=2000, doc="Character limit of the block.")
+    tokenizer_model: Mapped[str] = mapped_column(
+        default=DEFAULT_TIKTOKEN_MODEL, doc="Tokenizer model to use for the block to enforce the token limit of value."
+    )
     metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default={}, doc="arbitrary information related to the block.")
 
     # relationships
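Two side effects of this model change are worth noting. The `limit` column's doc string still reads "Character limit of the block." even though the value is now interpreted as tokens, and adding a column means existing databases need a schema migration, which this diff does not show. A hypothetical Alembic sketch of what that migration could look like (the `block` table name, the revision IDs, and the use of Alembic itself are assumptions, not part of this PR):

```python
"""Hypothetical Alembic migration for the new column (not part of this PR)."""
import sqlalchemy as sa
from alembic import op

# Placeholder revision identifiers; real values come from `alembic revision`.
revision = "abcdef123456"
down_revision = None


def upgrade() -> None:
    # Backfill existing rows with the same default the ORM declares.
    op.add_column(
        "block",
        sa.Column("tokenizer_model", sa.String(), nullable=False, server_default="gpt-4"),
    )


def downgrade() -> None:
    op.drop_column("block", "tokenizer_model")
```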
30 changes: 21 additions & 9 deletions letta/schemas/block.py
@@ -3,7 +3,9 @@
 from pydantic import Field, model_validator
 from typing_extensions import Self
 
+from letta.constants import DEFAULT_TIKTOKEN_MODEL
 from letta.schemas.letta_base import LettaBase
+from letta.utils import count_tokens
 
 # block of the LLM context
 
@@ -15,7 +17,11 @@ class BaseBlock(LettaBase, validate_assignment=True):
 
     # data value
     value: str = Field(..., description="Value of the block.")
-    limit: int = Field(2000, description="Character limit of the block.")
+    limit: int = Field(2000, description="Token limit of the block.")
+    # required to enforce the token limit
+    tokenizer_model: str = Field(
+        DEFAULT_TIKTOKEN_MODEL, description="Tokenizer model to use for the block to enforce the token limit of value."
+    )
 
     # template data (optional)
     template_name: Optional[str] = Field(None, description="Name of the block if it is a template.", alias="name")
@@ -28,17 +34,23 @@ class BaseBlock(LettaBase, validate_assignment=True):
     description: Optional[str] = Field(None, description="Description of the block.")
     metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
 
+    # @model_validator(mode="after")
+    # def verify_char_limit(self) -> Self:
+    #     if len(self.value) > self.limit:
+    #         error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}."
+    #         raise ValueError(error_msg)
+
+    #     return self
+
     @model_validator(mode="after")
-    def verify_char_limit(self) -> Self:
-        if len(self.value) > self.limit:
-            error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self.value)}) - {str(self)}."
+    def verify_token_limit(self) -> Self:
+        token_count = count_tokens(self.value, model=self.tokenizer_model)
+        if token_count > self.limit:
+            error_msg = f"Edit failed: Exceeds {self.limit} token limit (requested {token_count}) - {str(self)}."
             raise ValueError(error_msg)
 
         return self
 
-    def __len__(self):
-        return len(self.value)
+    # def __len__(self):
+    #     return len(self.value)
 
     def __setattr__(self, name, value):
         """Run validation if self.value is updated"""
         super().__setattr__(name, value)
@@ -57,7 +69,7 @@ class Block(BaseBlock):
     Parameters:
         label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block.
         value (str): The value of the block. This is the string that is represented in the context window.
-        limit (int): The character limit of the block.
+        limit (int): The token limit of the block.
         is_template (bool): Whether the block is a template (e.g. saved human/persona options). Non-template blocks are not stored in the database and are ephemeral, while templated blocks are stored in the database.
         label (str): The label of the block (e.g. 'human', 'persona'). This defines a category for the block.
         template_name (str): The name of the block template (if it is a template).
@@ -110,7 +122,7 @@ class CreateHuman(BlockCreate):
 class BlockUpdate(BaseBlock):
     """Update a block"""
 
-    limit: Optional[int] = Field(2000, description="Character limit of the block.")
+    limit: Optional[int] = Field(2000, description="Token limit of the block.")
     value: Optional[str] = Field(None, description="Value of the block.")
 
     class Config:
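The practical effect of swapping `verify_char_limit` for `verify_token_limit`: a block's `value` is now rejected based on tokenized length rather than `len()`, and since the class is declared with `validate_assignment=True`, the check re-runs whenever `value` is reassigned. A small usage sketch, assuming a local Letta install (the strings are made up):

```python
from letta.schemas.block import Block

# A deliberately small limit makes the validator easy to trigger.
block = Block(label="human", value="The user's name is Ada.", limit=10)

try:
    # validate_assignment=True re-runs verify_token_limit on assignment.
    block.value = "A much longer biography of the user, " * 50
except ValueError as err:
    print(err)  # Edit failed: Exceeds 10 token limit (requested ...)
```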
19 changes: 13 additions & 6 deletions letta/schemas/memory.py
@@ -7,6 +7,7 @@
 if TYPE_CHECKING:
     from letta.agent import Agent
 
+from letta.constants import DEFAULT_TIKTOKEN_MODEL
 from letta.schemas.block import Block
 from letta.schemas.message import Message
 from letta.schemas.openai.chat_completion_request import Tool
@@ -67,12 +68,12 @@ class Memory(BaseModel, validate_assignment=True):
     # Memory.template is a Jinja2 template for compiling memory module into a prompt string.
     prompt_template: str = Field(
         default="{% for block in memory.values() %}"
-        '<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
+        '<{{ block.label }} tokens="{{ block.value|length }}/{{ block.limit }}">\n'
         "{{ block.value }}\n"
         "</{{ block.label }}>"
         "{% if not loop.last %}\n{% endif %}"
         "{% endfor %}",
         description="Jinja2 template for compiling memory blocks into a prompt string.",
     )
 
     def get_prompt_template(self) -> str:
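One subtlety in the relabeled template: Jinja's built-in `length` filter counts characters, so the rendered `tokens="…"` attribute still reports a character count against a token limit. Reporting a true token count would need a custom filter along these lines (the `token_count` filter is hypothetical; this diff does not register one):

```python
import tiktoken
from jinja2 import Environment

env = Environment()


# Hypothetical filter so a template could write {{ block.value|token_count }}.
def token_count(s: str, model: str = "gpt-4") -> int:
    return len(tiktoken.encoding_for_model(model).encode(s))


env.filters["token_count"] = token_count

template = env.from_string('<persona tokens="{{ value|token_count }}/{{ limit }}">')
print(template.render(value="You are a helpful assistant.", limit=2000))
# prints something like: <persona tokens="6/2000">
```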
@@ -237,23 +238,29 @@ class ChatMemory(BasicBlockMemory):
     ChatMemory initializes a BaseChatMemory with two default blocks, `human` and `persona`.
     """
 
-    def __init__(self, persona: str, human: str, limit: int = 2000):
+    def __init__(self, persona: str, human: str, limit: int = 2000, tokenizer_model: str = DEFAULT_TIKTOKEN_MODEL):
         """
         Initialize the ChatMemory object with a persona and human string.
 
         Args:
             persona (str): The starter value for the persona block.
             human (str): The starter value for the human block.
-            limit (int): The character limit for each block.
+            limit (int): The token limit for each block.
         """
         super().__init__()
-        self.link_block(block=Block(value=persona, limit=limit, label="persona"))
-        self.link_block(block=Block(value=human, limit=limit, label="human"))
+        self.link_block(block=Block(value=persona, limit=limit, label="persona", tokenizer_model=tokenizer_model))
+        self.link_block(block=Block(value=human, limit=limit, label="human", tokenizer_model=tokenizer_model))
 
 
 class UpdateMemory(BaseModel):
     """Update the memory of the agent"""
 
+    # Memory.memory is a dict mapping from memory block label to memory block.
+    memory: Optional[Dict[str, Block]] = Field(None, description="Mapping from memory block section to memory block.")
+
+    # Memory.template is a Jinja2 template for compiling memory module into a prompt string.
+    prompt_template: Optional[str] = Field(None, description="Jinja2 template for compiling memory blocks into a prompt string.")
+
 
 class ArchivalMemorySummary(BaseModel):
     size: int = Field(..., description="Number of rows in archival memory")
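With the new keyword argument, callers can match the tokenizer to the model backing the agent, and because it defaults to `DEFAULT_TIKTOKEN_MODEL` existing call sites keep working unchanged. A usage sketch (the persona and human strings are made up):

```python
from letta.schemas.memory import ChatMemory

# Two core memory blocks, each capped at 2000 tokens as counted
# by the gpt-4 tokenizer (the default, spelled out here for clarity).
memory = ChatMemory(
    persona="I am a concise, friendly research assistant.",
    human="The user is a software engineer named Sam.",
    limit=2000,
    tokenizer_model="gpt-4",
)
```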
37 changes: 15 additions & 22 deletions letta/utils.py
@@ -26,9 +26,10 @@
 import letta
 from letta.constants import (
     CLI_WARNING_PREFIX,
-    CORE_MEMORY_HUMAN_CHAR_LIMIT,
-    CORE_MEMORY_PERSONA_CHAR_LIMIT,
-    FUNCTION_RETURN_CHAR_LIMIT,
+    CORE_MEMORY_HUMAN_TOKEN_LIMIT,
+    CORE_MEMORY_PERSONA_TOKEN_LIMIT,
+    DEFAULT_TIKTOKEN_MODEL,
+    FUNCTION_RETURN_TOKEN_LIMIT,
     LETTA_DIR,
     MAX_FILENAME_LENGTH,
     TOOL_CALL_ID_MAX_LEN,
@@ -790,7 +791,7 @@ def find_class(self, module, name):
         return super().find_class(module, name)
 
 
-def count_tokens(s: str, model: str = "gpt-4") -> int:
+def count_tokens(s: str, model: str = DEFAULT_TIKTOKEN_MODEL) -> int:
     encoding = tiktoken.encoding_for_model(model)
     return len(encoding.encode(s))
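Since `tokenizer_model` is now stored per block and passed straight through to `tiktoken.encoding_for_model`, an unrecognized model name will surface as a `KeyError` at validation time. A defensive variant one could consider (a sketch, not part of this PR) falls back to the `cl100k_base` encoding:

```python
import tiktoken


def count_tokens_safe(s: str, model: str = "gpt-4") -> int:
    """Sketch: like count_tokens, but tolerant of names tiktoken doesn't know."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # e.g. a custom string stored in Block.tokenizer_model; fall back to
        # the encoding used by gpt-4 and gpt-3.5-turbo.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(s))
```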

@@ -927,11 +928,10 @@ def validate_function_response(function_response_string: any, strict: bool = Fal
 
     # Now check the length and make sure it doesn't go over the limit
     # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
-    if truncate and len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT:
-        print(
-            f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated"
-        )
-        function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]"
+    token_count = count_tokens(function_response_string)
+    if truncate and token_count > FUNCTION_RETURN_TOKEN_LIMIT:
+        print(f"{CLI_WARNING_PREFIX}function return was over limit ({token_count} > {FUNCTION_RETURN_TOKEN_LIMIT}) and was truncated")
+        function_response_string = f"{function_response_string[:FUNCTION_RETURN_TOKEN_LIMIT]}... [NOTE: function output was truncated since it exceeded the token limit ({token_count} > {FUNCTION_RETURN_TOKEN_LIMIT})]"
 
     return function_response_string
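Note that the truncation above still slices characters: `function_response_string[:FUNCTION_RETURN_TOKEN_LIMIT]` keeps the first 1500 characters, not the first 1500 tokens, so over-limit returns are cut much shorter than the limit intends. Truncating in token space would look something like this sketch (`truncate_to_tokens` is an illustrative helper, not a function in this PR):

```python
import tiktoken


def truncate_to_tokens(s: str, max_tokens: int, model: str = "gpt-4") -> str:
    """Sketch: keep at most max_tokens tokens of s, not max_tokens characters."""
    encoding = tiktoken.encoding_for_model(model)
    tokens = encoding.encode(s)
    if len(tokens) <= max_tokens:
        return s
    # Decode the first max_tokens tokens back to text (may cut mid-word).
    return encoding.decode(tokens[:max_tokens])
```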

@@ -994,8 +994,9 @@ def get_human_text(name: str, enforce_limit=True):
         file = os.path.basename(file_path)
         if f"{name}.txt" == file or name == file:
             human_text = open(file_path, "r", encoding="utf-8").read().strip()
-            if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
-                raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
+            token_count = count_tokens(human_text, model=DEFAULT_TIKTOKEN_MODEL)
+            if enforce_limit and token_count > CORE_MEMORY_HUMAN_TOKEN_LIMIT:
+                raise ValueError(f"Contents of {name}.txt is over the token limit ({token_count} > {CORE_MEMORY_HUMAN_TOKEN_LIMIT})")
             return human_text
 
     raise ValueError(f"Human {name}.txt not found")
@@ -1006,22 +1007,14 @@ def get_persona_text(name: str, enforce_limit=True):
         file = os.path.basename(file_path)
         if f"{name}.txt" == file or name == file:
             persona_text = open(file_path, "r", encoding="utf-8").read().strip()
-            if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
-                raise ValueError(
-                    f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
-                )
+            token_count = count_tokens(persona_text, model=DEFAULT_TIKTOKEN_MODEL)
+            if enforce_limit and token_count > CORE_MEMORY_PERSONA_TOKEN_LIMIT:
+                raise ValueError(f"Contents of {name}.txt is over the token limit ({token_count} > {CORE_MEMORY_PERSONA_TOKEN_LIMIT})")
             return persona_text
 
     raise ValueError(f"Persona {name}.txt not found")
 
 
-def get_human_text(name: str):
-    for file_path in list_human_files():
-        file = os.path.basename(file_path)
-        if f"{name}.txt" == file or name == file:
-            return open(file_path, "r", encoding="utf-8").read().strip()
-
-
 def get_schema_diff(schema_a, schema_b):
     # Assuming f_schema and linked_function['json_schema'] are your JSON schemas
     f_schema_json = json_dumps(schema_a)