Skip to content

Commit

Permalink
Merge pull request #26 from northes/feature/builtin_function_web_search
Browse files Browse the repository at this point in the history
feature: builtin function web search
  • Loading branch information
northes authored Sep 4, 2024
2 parents 938f70a + a397c94 commit d6cbffd
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 33 deletions.
89 changes: 73 additions & 16 deletions api_chat_completions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/northes/go-moonshot"
"github.com/northes/go-moonshot/internal/httpx"
"github.com/northes/go-moonshot/test"
"github.com/stretchr/testify/require"
)

func TestChat(t *testing.T) {
Expand Down Expand Up @@ -179,24 +180,20 @@ func TestUseTools(t *testing.T) {
// check tool calls
if len(resp.Choices) != 0 {
if resp.Choices[0].FinishReason == moonshot.FinishReasonToolCalls {
for _, toolCall := range resp.Choices[0].Message.ToolCalls {
t.Logf("should tool calls: %v", test.MarshalJsonToStringX(toolCall))
if strings.HasPrefix(toolCall.ID, functionName) {
for _, tool := range resp.Choices[0].Message.ToolCalls {
t.Logf("should tool calls: %v", test.MarshalJsonToStringX(tool))
if strings.HasPrefix(tool.ID, functionName) {
// tool calls
ipInfo, err := IPLocate(ip)
if err != nil {
t.Fatal(err)
}
b, err := json.Marshal(ipInfo)
if err != nil {
t.Fatal(err)
}

builder.AddMessageFromChoices(resp.Choices)

t.Logf("tool calls result: %s", test.MarshalJsonToStringX(ipInfo))
t.Logf("tool calls result: %s", ipInfo)

builder.AddToolContent(string(b), functionName, resp.Choices[0].Message.ToolCalls[0].ID)
builder.AddToolContent(ipInfo, functionName, tool.ID)
}
}
}
Expand Down Expand Up @@ -231,17 +228,77 @@ type IPLocateInfoResponse struct {
Data *IPLocateInfo `json:"data"`
}

func IPLocate(ip string) (*IPLocateInfo, error) {
func IPLocate(ip string) (string, error) {
response, err := httpx.NewClient(fmt.Sprintf("https://apihut.co/ip/%s", ip)).Get(context.Background())
if err != nil {
return nil, err
return "", err
}
defer func() {
_ = response.Raw().Body.Close()
}()

body, err := io.ReadAll(response.Raw().Body)
if err != nil {
return "", err
}

return string(body), nil
}

func TestBuiltinFunctionWebSearch(t *testing.T) {
if test.IsGithubActions() {
return
}

cli, err := NewTestClient()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()

builder := moonshot.NewChatCompletionsBuilder()
builder.SetModel(moonshot.ModelMoonshotV1128K)
builder.AddUserContent("请搜索 Moonshot AI Context Caching 技术,并告诉我它是什么。")
builder.SetTool(&moonshot.ChatCompletionsTool{
Type: moonshot.ChatCompletionsToolTypeBuiltinFunction,
Function: &moonshot.ChatCompletionsToolFunction{
Name: moonshot.BuiltinFunctionWebSearch,
},
})

resp, err := cli.Chat().Completions(ctx, builder.ToRequest())
if err != nil {
t.Fatal(err)
}

if len(resp.Choices) != 0 {
choice := resp.Choices[0]
if choice.FinishReason == moonshot.FinishReasonToolCalls {
for _, tool := range choice.Message.ToolCalls {
t.Logf("tool calls: %v", test.MarshalJsonToStringX(tool))
if tool.Function.Name == moonshot.BuiltinFunctionWebSearch {
// web search
arguments := new(moonshot.ChatCompletionsToolBuiltinFunctionWebSearchArguments)
if err = json.Unmarshal([]byte(tool.Function.Arguments), arguments); err != nil {
t.Errorf("unmarshal tool arguments error: %v", err)
continue
}

t.Logf("tool calls result: search_id: %s, total_tokens: %d", arguments.SearchResult.SearchId, arguments.Usage.TotalTokens)

builder.AddMessageFromChoices(resp.Choices)
builder.AddToolContent(tool.Function.Arguments, tool.Function.Name, tool.ID)
}
}
}
}

respData := new(IPLocateInfoResponse)
err = response.Unmarshal(respData)
t.Logf("builder: %v", test.MarshalJsonToStringX(builder.ToRequest()))

resp, err = cli.Chat().Completions(ctx, builder.ToRequest())
if err != nil {
return nil, err
t.Fatal(err)
}

return respData.Data, nil
t.Log(test.MarshalJsonToStringX(resp))
}
13 changes: 11 additions & 2 deletions api_chat_completions_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ type ChatCompletionsTool struct {

type ChatCompletionsToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters *ChatCompletionsToolFunctionParameters `json:"parameters"`
Description string `json:"description,omitempty"`
Parameters *ChatCompletionsToolFunctionParameters `json:"parameters,omitempty"`
}

type ChatCompletionsToolFunctionParameters struct {
Expand All @@ -34,3 +34,12 @@ type ChatCompletionsResponseToolCallsFunction struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}

type ChatCompletionsToolBuiltinFunctionWebSearchArguments struct {
SearchResult struct {
SearchId string `json:"search_id"`
} `json:"search_result"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
18 changes: 5 additions & 13 deletions api_context_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package moonshot_test

import (
"context"
"os"
"testing"
"time"

Expand All @@ -17,7 +16,7 @@ import (
// https://github.com/MoonshotAI/moonpalace

func TestContextCache(t *testing.T) {
if isGithubActions() {
if test.IsGithubActions() {
return
}
cli, err := NewTestClient()
Expand Down Expand Up @@ -93,7 +92,7 @@ func TestContextCache(t *testing.T) {
}

func TestContextCache_Create(t *testing.T) {
if isGithubActions() {
if test.IsGithubActions() {
return
}
cli, err := NewTestClient()
Expand Down Expand Up @@ -124,7 +123,7 @@ func TestContextCache_Create(t *testing.T) {
}

func TestContextCache_Delete(t *testing.T) {
if isGithubActions() {
if test.IsGithubActions() {
return
}
cli, err := NewTestClient()
Expand All @@ -144,7 +143,7 @@ func TestContextCache_Delete(t *testing.T) {
}

func TestContextCache_List(t *testing.T) {
if isGithubActions() {
if test.IsGithubActions() {
return
}
cli, err := NewTestClient()
Expand All @@ -162,7 +161,7 @@ func TestContextCache_List(t *testing.T) {
}

func TestContextCache_CreateTag(t *testing.T) {
if isGithubActions() {
if test.IsGithubActions() {
return
}
cli, err := NewTestClient()
Expand All @@ -182,10 +181,3 @@ func TestContextCache_CreateTag(t *testing.T) {
}
assert.Equal(t, "MyCacheTag", createResponse.Tag)
}

func isGithubActions() bool {
if val, ok := os.LookupEnv("GITHUB_ACTIONS"); !ok || val != "true" {
return false
}
return true
}
5 changes: 5 additions & 0 deletions enum_builtin_function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package moonshot

const (
BuiltinFunctionWebSearch string = "$web_search"
)
3 changes: 2 additions & 1 deletion enum_chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func (c ChatCompletionsFinishReason) String() string {
type ChatCompletionsToolType string

const (
ChatCompletionsToolTypeFunction ChatCompletionsToolType = "function"
ChatCompletionsToolTypeFunction ChatCompletionsToolType = "function"
ChatCompletionsToolTypeBuiltinFunction ChatCompletionsToolType = "builtin_function"
)

func (c ChatCompletionsToolType) String() string {
Expand Down
12 changes: 12 additions & 0 deletions test/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package test

import (
"os"
)

func IsGithubActions() bool {
if val, ok := os.LookupEnv("GITHUB_ACTIONS"); !ok || val != "true" {
return false
}
return true
}
2 changes: 1 addition & 1 deletion test/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

func GenerateTestContent() []byte {
return []byte("夕阳无限好")
return []byte("夕阳无限好,麦当劳汉堡")
}

func GenerateTestFile(content []byte) (string, error) {
Expand Down

0 comments on commit d6cbffd

Please sign in to comment.