Added support for Google Gemini 1.5 pro/flash models #1189

Closed
wants to merge 5 commits into from

Changes from all commits
5 changes: 5 additions & 0 deletions .env.template
@@ -3,3 +3,8 @@
 # OPENAI_API_KEY=Your personal OpenAI API key from https://platform.openai.com/account/api-keys
 OPENAI_API_KEY=...
 ANTHROPIC_API_KEY=...
+# GOOGLE_API_KEY=Your personal GOOGLE API key from https://aistudio.google.com/app/apikey
+GOOGLE_API_KEY=...
+
+# If not set Model Name defaults to gpt-4o
+MODEL_NAME=gemini-1.5-pro-latest
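For illustration only (not part of the diff): a minimal sketch of how a .env copied from this template is picked up at runtime, assuming python-dotenv, which main.py already uses for the other keys.

# Sketch: confirm the new template keys load from .env (assumes python-dotenv).
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
assert os.getenv("GOOGLE_API_KEY"), "set GOOGLE_API_KEY in .env"
print(os.getenv("MODEL_NAME", "gpt-4o"))  # "gemini-1.5-pro-latest" with this template, "gpt-4o" if unset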
15 changes: 14 additions & 1 deletion gpt_engineer/applications/cli/main.py
@@ -89,6 +89,19 @@ def load_env_if_needed():
     if os.getenv("ANTHROPIC_API_KEY") is None:
         load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
 
+    if os.getenv("GOOGLE_API_KEY") is None:
+        load_dotenv()
+    if os.getenv("GOOGLE_API_KEY") is None:
+        load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
+
+def model_env():
+    if os.getenv("MODEL_NAME") is None:
+        load_dotenv()
+    if os.getenv("MODEL_NAME") is None:
+        load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
+
+    return os.getenv("MODEL_NAME", default="gpt-4o")
+
 
 def concatenate_paths(base_path, sub_path):
     # Compute the relative path from base_path to sub_path
@@ -281,7 +294,7 @@ def format_installed_packages(packages):
 def main(
     project_path: str = typer.Argument(".", help="path"),
     model: str = typer.Option(
-        os.environ.get("MODEL_NAME", "gpt-4o"), "--model", "-m", help="model id string"
+        os.environ.get("MODEL_NAME", model_env()), "--model", "-m", help="model id string"
     ),
     temperature: float = typer.Option(
         0.1,
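As a rough sketch of the behaviour added above (not part of the diff): the model id now resolves from the --model/-m option first, then MODEL_NAME via the environment or a local .env, and finally the gpt-4o default returned by model_env(). The helper below is hypothetical and only illustrates that precedence.

# Hypothetical helper illustrating the model-id precedence; assumes python-dotenv.
import os
from dotenv import load_dotenv

def pick_model(cli_value=None):
    if cli_value:                        # 1. explicit --model / -m wins
        return cli_value
    if os.getenv("MODEL_NAME") is None:  # 2. otherwise try MODEL_NAME (env or .env)
        load_dotenv()
    return os.getenv("MODEL_NAME", "gpt-4o")  # 3. final fallback

print(pick_model())                           # e.g. "gemini-1.5-pro-latest"
print(pick_model("gemini-1.5-flash-latest"))  # CLI override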
9 changes: 9 additions & 0 deletions gpt_engineer/core/ai.py
@@ -36,6 +36,7 @@
     messages_to_dict,
 )
 from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
 from gpt_engineer.core.token_usage import TokenUsageLog
@@ -362,6 +363,14 @@ def _create_chat_model(self) -> BaseChatModel:
                 streaming=self.streaming,
                 max_tokens_to_sample=4096,
             )
+        elif "gemini" in self.model_name:
+            return ChatGoogleGenerativeAI(
+                model=self.model_name,
+                temperature=self.temperature,
+                streaming=self.streaming,
+                google_api_key=os.getenv('GOOGLE_API_KEY'),
+                callbacks=[StreamingStdOutCallbackHandler()]
+            )
         elif self.vision:
             return ChatOpenAI(
                 model=self.model_name,
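For context, a standalone sketch of what the new branch constructs (not part of the diff; assumes langchain-google-genai is installed and GOOGLE_API_KEY is exported). The model id is only an example; any model name containing "gemini" takes this branch.

# Sketch: the Gemini chat model the new elif branch builds, used directly.
import os
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash-latest",            # example id; "gemini" in the name selects this branch
    temperature=0.1,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)
print(llm.invoke("Reply with one word.").content)  # standard LangChain invoke() call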
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ pillow = "^10.2.0"
 datasets = "^2.17.1"
 black = "23.3.0"
 langchain-community = "^0.2.0"
+langchain-google-genai = "^1.0.7"
 
 [tool.poetry.group.dev.dependencies]
 pytest = ">=7.3.1"
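A quick, optional smoke test (not part of the diff) to confirm the new dependency resolves after poetry install:

# Smoke test: verify langchain-google-genai is importable and report its version.
from importlib.metadata import version
import langchain_google_genai  # raises ImportError if the dependency is missing

print(version("langchain-google-genai"))  # expected to satisfy ^1.0.7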