fix: Memory validation fix + core_memory_replace runaway content repeating fix #1616

Open · wants to merge 17 commits into base: main
8 changes: 4 additions & 4 deletions Dockerfile
@@ -1,5 +1,5 @@
# The builder image, used to build the virtual environment
-FROM python:3.12.2-bookworm as builder
+FROM python:3.12.2-bookworm AS builder
ARG MEMGPT_ENVIRONMENT=PRODUCTION
ENV MEMGPT_ENVIRONMENT=${MEMGPT_ENVIRONMENT}
RUN pip install poetry==1.8.2
@@ -14,15 +14,15 @@ WORKDIR /app
COPY pyproject.toml poetry.lock ./
RUN poetry lock --no-update
RUN if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then \
poetry install --no-root -E "postgres server dev autogen" ; \
poetry install --no-root -E "postgres server dev autogen local" ; \
else \
poetry install --no-root -E "postgres server" && \
rm -rf $POETRY_CACHE_DIR ; \
fi


# The runtime image, used to just run the code provided its virtual environment
-FROM python:3.12.2-slim-bookworm as runtime
+FROM python:3.12.2-slim-bookworm AS runtime
ARG MEMGPT_ENVIRONMENT=PRODUCTION
ENV MEMGPT_ENVIRONMENT=${MEMGPT_ENVIRONMENT}
ENV VIRTUAL_ENV=/app/.venv \
Expand All @@ -37,7 +37,7 @@ EXPOSE 8083
CMD ./memgpt/server/startup.sh

# allow for in-container development and testing
-FROM builder as development
+FROM builder AS development
ARG MEMGPT_ENVIRONMENT=PRODUCTION
ENV MEMGPT_ENVIRONMENT=${MEMGPT_ENVIRONMENT}
ENV VIRTUAL_ENV=/app/.venv \
2 changes: 1 addition & 1 deletion memgpt/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.3.24"
__version__ = "0.3.25"

from memgpt.client.admin import Admin
from memgpt.client.client import create_client
18 changes: 17 additions & 1 deletion memgpt/agent.py
@@ -42,6 +42,7 @@
is_utc_datetime,
parse_json,
printd,
truncate_to_token_limit,
united_diff,
validate_function_response,
verify_first_message_correctness,
@@ -805,6 +806,15 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling

# Get the context window size from the LLM config, or default to maximum if not defined
context_window = self.agent_state.llm_config.context_window
if context_window is None:
context_window = LLM_MAX_TOKENS.get(self.model, LLM_MAX_TOKENS["DEFAULT"])

# Set a token limit for individual messages
MAX_MESSAGE_TOKENS = int(0.4 * context_window)

token_counts = [count_tokens(str(msg)) for msg in self.messages]
message_buffer_token_count = sum(token_counts[1:]) # no system message
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
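A quick sketch of the fallback arithmetic above (illustrative only; the LLM_MAX_TOKENS entries and the 8192 figure are assumed, not taken from this PR):

```python
# Fallback used when the LLM config does not define a context window.
LLM_MAX_TOKENS = {"DEFAULT": 8192, "gpt-4": 8192}  # assumed illustrative entries
model = "some-unlisted-model"
context_window = LLM_MAX_TOKENS.get(model, LLM_MAX_TOKENS["DEFAULT"])  # -> 8192

# No single message may exceed 40% of the context window.
MAX_MESSAGE_TOKENS = int(0.4 * context_window)  # -> 3276
```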
@@ -837,12 +847,18 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]"
)

# Exclude or truncate overly large messages
candidate_messages_to_summarize = [
msg if count_tokens(str(msg)) <= MAX_MESSAGE_TOKENS else truncate_to_token_limit(str(msg), MAX_MESSAGE_TOKENS)
for msg in candidate_messages_to_summarize
]

# Walk down the message buffer (front-to-back) until we hit the target token count
tokens_so_far = 0
cutoff = 0
for i, msg in enumerate(candidate_messages_to_summarize):
cutoff = i
-            tokens_so_far += token_counts[i]
+            tokens_so_far += count_tokens(str(msg))
if tokens_so_far > desired_token_count_to_summarize:
break
# Account for system message
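Why the summation changed: token_counts is computed over self.messages (system message at index 0), while candidate_messages_to_summarize is a filtered and possibly truncated subset, so token_counts[i] could charge the wrong message's tokens to a candidate. A minimal illustration (counts and names are made up):

```python
token_counts = [12, 40, 900]     # indexed over ALL messages, system message first
candidates = ["msg-1", "msg-2"]  # filtered subset: system message excluded

# old: tokens_so_far += token_counts[i]
#      candidate 0 ("msg-1") is charged the system message's 12 tokens
# new: tokens_so_far += count_tokens(str(msg))
#      each candidate is charged its own (post-truncation) token count
```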
2 changes: 2 additions & 0 deletions memgpt/cli/cli.py
@@ -662,8 +662,10 @@ def run(
system_prompt = system if system else None
if human_obj is None:
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
sys.exit(1)
if persona_obj is None:
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
sys.exit(1)

memory = ChatMemory(human=human_obj.text, persona=persona_obj.text, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
6 changes: 3 additions & 3 deletions memgpt/client/client.py
@@ -51,7 +51,7 @@
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer
-from memgpt.utils import get_human_text
+from memgpt.utils import get_human_text, get_persona_text


def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
@@ -259,7 +259,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
-    memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
+    memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
@@ -729,7 +729,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
-    memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
+    memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# system prompt (can be templated)
system_prompt: Optional[str] = None,
# tools
41 changes: 31 additions & 10 deletions memgpt/memory.py
@@ -1,5 +1,6 @@
import datetime
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

@@ -19,6 +20,7 @@
)


# always run validation
class MemoryModule(BaseModel):
"""Base class for memory modules"""

@@ -28,13 +30,16 @@ class MemoryModule(BaseModel):

def __setattr__(self, name, value):
"""Run validation if self.value is updated"""
-        super().__setattr__(name, value)
         if name == "value":
-            # run validation
-            self.__class__.validate(self.dict(exclude_unset=True))
+            # Temporarily set the attribute to run validation
+            temp = self.copy(update={name: value})
+            self.__class__.validate(temp.dict(exclude_unset=True))
+
+        super().__setattr__(name, value)

[Collaborator] I'm a bit confused by this code; what is it doing?

[Author] It creates a copy of itself and runs validation on the value. Otherwise, even if validation fails, the value is still modified and will go beyond the 2000 (or configured) limit.
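For reference, a self-contained sketch of the validate-on-a-copy pattern this hunk introduces (pydantic v1 style, matching the validator/copy/validate calls in the diff; the class and field names here are illustrative):

```python
from pydantic import BaseModel, validator


class Module(BaseModel):
    limit: int = 2000
    value: str = ""

    @validator("value", always=True)
    def check_value_length(cls, v, values):
        limit = values.get("limit", 2000)
        if v is not None and len(v) > limit:
            raise ValueError(f"Value exceeds {limit} character limit (requested {len(v)}).")
        return v

    def __setattr__(self, name, value):
        if name == "value":
            # Validate on a copy first: if validation raises, self stays untouched.
            temp = self.copy(update={name: value})
            self.__class__.validate(temp.dict(exclude_unset=True))
        super().__setattr__(name, value)


m = Module(value="ok")
try:
    m.value = "x" * 3000  # raises (pydantic's ValidationError subclasses ValueError)
except ValueError:
    pass
assert m.value == "ok"  # the failed edit did not leak through
```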

@validator("value", always=True)
@validator("value", always=True, check_fields=False)
def check_value_length(cls, v, values):
# TODO: this doesn't run all the time, should fix
if v is not None:
# Fetching the limit from the values dictionary
limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
@@ -48,10 +53,9 @@ def check_value_length(cls, v, values):
raise ValueError("Value must be either a string or a list of strings.")

if length > limit:
error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
# TODO: add archival memory error?
raise ValueError(error_msg)
return v
raise ValueError(f"Value exceeds {limit} character limit (requested {length}).")

return v

def __len__(self):
return len(str(self))
@@ -71,10 +75,14 @@ def __init__(self):
self.memory = {}

@classmethod
-    def load(cls, state: dict):
+    def load(cls, state: dict, catch_overflow: bool = True):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
# TODO: will cause an error for lists
if catch_overflow and len(value["value"]) >= value["limit"]:
warnings.warn(f"Loaded section {key} exceeds character limit {value['limit']} - increasing specified memory limit.")
value["limit"] = len(value["value"])
obj.memory[key] = MemoryModule(**value)
return obj
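A short usage sketch of the new overflow handling (mirrors the behavior exercised by the tests at the end of this diff; the shape of the state dict is assumed from MemoryModule's fields):

```python
import warnings

from memgpt.memory import BaseMemory  # as in this PR's tests

state = {"persona": {"name": "persona", "value": "x" * 2500, "limit": 2000}}

# catch_overflow=False: MemoryModule validation raises, since 2500 > 2000.
try:
    BaseMemory.load(state, catch_overflow=False)
except ValueError:
    pass

# Default catch_overflow=True: warn and raise the limit to fit the stored value.
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    mem = BaseMemory.load(state)
assert mem.memory["persona"].limit == 2500
```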

@@ -93,6 +101,14 @@ def to_dict(self):
class ChatMemory(BaseMemory):

def __init__(self, persona: str, human: str, limit: int = 2000):
# TODO: clip if needed
# if persona and len(persona) > limit:
# warnings.warn(f"Persona exceeds {limit} character limit (requested {len(persona)}).")
# persona = persona[:limit]

# if human and len(human) > limit:
# warnings.warn(f"Human exceeds {limit} character limit (requested {len(human)}).")
# human = human[:limit]
self.memory = {
"persona": MemoryModule(name="persona", value=persona, limit=limit),
"human": MemoryModule(name="human", value=human, limit=limit),
@@ -124,7 +140,12 @@ def core_memory_replace(self, name: str, old_content: str, new_content: str) ->
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
-        self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
+        if old_content == "":
+            raise ValueError(
+                f"old_content can not be empty. Use core_memory_append to add new content without replacing any existing content."
+            )
+        else:
+            self.memory[name].value = self.memory[name].value.replace(old_content, new_content)

[Collaborator] Maybe give it a hint for the model, like "Use core_memory_append to add new content without replacing any existing content."?

[Author] I updated the ValueError message it gets.
return None
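Context for the guard above: with an empty old_content, Python's str.replace inserts the replacement at every character boundary, so a single call balloons the memory block; this looks like the "runaway content repeating" failure named in the PR title. A quick demonstration:

```python
value = "core memory"
print(value.replace("", "!"))  # -> !c!o!r!e! !m!e!m!o!r!y!
# Repeated calls compound the blow-up until the character limit is hit.
```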


23 changes: 23 additions & 0 deletions memgpt/utils.py
@@ -780,6 +780,29 @@ def count_tokens(s: str, model: str = "gpt-4") -> int:
return len(encoding.encode(s))


def truncate_to_token_limit(message: str, token_limit: int, model: str = "gpt-4") -> str:
"""
Truncate the message to ensure it does not exceed the token limit.

Args:
message (str): The message to be truncated.
token_limit (int): The maximum number of tokens allowed.
model (str): The model to use for token encoding.

Returns:
str: The truncated message.
"""
encoding = tiktoken.encoding_for_model(model)
tokens = encoding.encode(message)
if len(tokens) <= token_limit:
return message

truncated_tokens = tokens[:token_limit]
truncated_message = encoding.decode(truncated_tokens)

return truncated_message
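A usage sketch for the new helper (assuming it is imported from memgpt.utils alongside count_tokens):

```python
from memgpt.utils import count_tokens, truncate_to_token_limit

msg = "hello world " * 5000
clipped = truncate_to_token_limit(msg, token_limit=100)
# Re-encoding a decoded prefix can shift the count slightly at the boundary,
# but the clipped text stays at or near the 100-token budget.
print(count_tokens(clipped, model="gpt-4"))
```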


def printd(*args, **kwargs):
if DEBUG:
print(*args, **kwargs)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pymemgpt"
version = "0.3.24"
version = "0.3.25"
packages = [
{include = "memgpt"}
]
39 changes: 39 additions & 0 deletions tests/test_memory.py
@@ -16,6 +16,23 @@ def test_create_chat_memory():
assert chat_memory.memory["human"].value == "User"


def test_overflow_chat_memory():
"""Test overflowing an instance of ChatMemory"""
chat_memory = ChatMemory(persona="Chat Agent", human="User")
assert chat_memory.memory["persona"].value == "Chat Agent"
assert chat_memory.memory["human"].value == "User"

# try overflowing via core_memory_append
with pytest.raises(ValueError):
persona_limit = chat_memory.memory["persona"].limit
chat_memory.core_memory_append(name="persona", content="x" * (persona_limit + 1))

# try overflowing via core_memory_replace
with pytest.raises(ValueError):
persona_limit = chat_memory.memory["persona"].limit
chat_memory.core_memory_replace(name="persona", old_content="Chat Agent", new_content="x" * (persona_limit + 1))


def test_dump_memory_as_json(sample_memory):
"""Test dumping ChatMemory as JSON compatible dictionary"""
memory_dict = sample_memory.to_dict()
@@ -63,3 +80,25 @@ def test_memory_limit_validation(sample_memory):

with pytest.raises(ValueError):
sample_memory.memory["persona"].value = "x" * 3000


def test_corrupted_memory_limit(sample_memory):
"""Test what happens when a memory is stored with a value over the limit

See: https://github.com/cpacker/MemGPT/issues/1567
"""
with pytest.raises(ValueError):
ChatMemory(persona="x" * 3000, human="y" * 3000)

memory_dict = sample_memory.to_dict()
assert memory_dict["persona"]["limit"] == 2000, memory_dict

# overflow the value
memory_dict["persona"]["value"] = "x" * 2500

# by default, this should throw a value error
with pytest.raises(ValueError):
BaseMemory.load(memory_dict, catch_overflow=False)

# if we have overflow protection on, this shouldn't raise a value error
BaseMemory.load(memory_dict)