diff --git a/providers/anthropic/anthropic.go b/providers/anthropic/anthropic.go index 4be919e09..370fccf6a 100644 --- a/providers/anthropic/anthropic.go +++ b/providers/anthropic/anthropic.go @@ -11,6 +11,7 @@ import ( "io" "maps" "math" + "slices" "strings" "charm.land/fantasy" @@ -496,6 +497,106 @@ func reasoningEndProviderMetadata(contentBlock anthropic.ContentBlockUnion) fant } } +type providerExecutedOperation struct { + id string + name string + input string +} + +func newProviderExecutedOperation(id, name, input string) providerExecutedOperation { + return providerExecutedOperation{id: id, name: name, input: input} +} + +func (op providerExecutedOperation) appendContent(content []fantasy.Content) []fantasy.Content { + return append(content, fantasy.ToolCallContent{ + ToolCallID: op.id, + ToolName: op.name, + Input: op.input, + ProviderExecuted: true, + }) +} + +func (op providerExecutedOperation) yieldToolCall(yield func(fantasy.StreamPart) bool) bool { + // Provider-executed operations are internally atomic. The adjacent + // legacy tool-call lifecycle is a compatibility encoding for consumers. + return yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputStart, + ID: op.id, + ToolCallName: op.name, + ToolCallInput: "", + ProviderExecuted: true, + }) && yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolInputEnd, + ID: op.id, + ProviderExecuted: true, + }) && yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeToolCall, + ID: op.id, + ToolCallName: op.name, + ToolCallInput: op.input, + ProviderExecuted: true, + }) +} + +func appendProviderExecutedOperations( + content []fantasy.Content, + operations map[string]providerExecutedOperation, +) []fantasy.Content { + for _, id := range slices.Sorted(maps.Keys(operations)) { + content = operations[id].appendContent(content) + } + return content +} + +func yieldProviderExecutedOperations( + yield func(fantasy.StreamPart) bool, + operations map[string]providerExecutedOperation, +) bool { + for _, id := range slices.Sorted(maps.Keys(operations)) { + if !operations[id].yieldToolCall(yield) { + return false + } + } + return true +} + +func providerMetadataForStopReason(stopReason string) fantasy.ProviderMetadata { + if stopReason == "" { + return fantasy.ProviderMetadata{} + } + return fantasy.ProviderMetadata{ + Name: &StopReasonMetadata{StopReason: stopReason}, + } +} + +func firstProviderExecutedOperationID(operations map[string]providerExecutedOperation) string { + for id := range operations { + return id + } + return "" +} + +func incompleteProviderExecutedOperationError(id string) error { + if id == "" { + return fmt.Errorf("anthropic provider-executed operation ended without a matching result") + } + return fmt.Errorf("anthropic provider-executed operation %q ended without a matching result", id) +} + +func orphanProviderExecutedResultError(id string) error { + if id == "" { + return fmt.Errorf("anthropic provider-executed result arrived without a matching server tool use") + } + return fmt.Errorf("anthropic provider-executed result %q arrived without a matching server tool use", id) +} + +func duplicateProviderExecutedOperationError(id string) error { + if id == "" { + return fmt.Errorf("anthropic provider-executed operation used a duplicate ID") + } + return fmt.Errorf("anthropic provider-executed operation used duplicate ID %q", id) +} + type messageBlock struct { Role fantasy.MessageRole Messages []fantasy.Message @@ -1216,6 +1317,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas } var content []fantasy.Content + providerExecutedOperations := make(map[string]providerExecutedOperation) for _, block := range response.Content { switch block.Type { case "text": @@ -1272,12 +1374,14 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas if b, err := json.Marshal(serverToolUse.Input); err == nil { inputStr = string(b) } - content = append(content, fantasy.ToolCallContent{ - ToolCallID: serverToolUse.ID, - ToolName: string(serverToolUse.Name), - Input: inputStr, - ProviderExecuted: true, - }) + if _, exists := providerExecutedOperations[serverToolUse.ID]; exists { + return nil, duplicateProviderExecutedOperationError(serverToolUse.ID) + } + providerExecutedOperations[serverToolUse.ID] = newProviderExecutedOperation( + serverToolUse.ID, + string(serverToolUse.Name), + inputStr, + ) case "web_search_tool_result": webSearchResult, ok := block.AsAny().(anthropic.WebSearchToolResultBlock) if !ok { @@ -1318,7 +1422,25 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas }, } } + operation, ok := providerExecutedOperations[webSearchResult.ToolUseID] + if !ok { + return nil, orphanProviderExecutedResultError(webSearchResult.ToolUseID) + } + content = operation.appendContent(content) content = append(content, toolResult) + delete(providerExecutedOperations, webSearchResult.ToolUseID) + } + } + + if len(providerExecutedOperations) > 0 { + if response.StopReason == "pause_turn" { + // Anthropic pause_turn can end with a server_tool_use that has no + // result yet. Surface the call so callers can round-trip it. + content = appendProviderExecutedOperations(content, providerExecutedOperations) + } else { + return nil, incompleteProviderExecutedOperationError( + firstProviderExecutedOperationID(providerExecutedOperations), + ) } } @@ -1332,7 +1454,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas CacheReadTokens: response.Usage.CacheReadInputTokens, }, FinishReason: mapFinishReason(string(response.StopReason)), - ProviderMetadata: fantasy.ProviderMetadata{}, + ProviderMetadata: providerMetadataForStopReason(string(response.StopReason)), Warnings: warnings, }, nil } @@ -1348,6 +1470,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S stream := a.client.Messages.NewStreaming(ctx, *params, reqOpts...) acc := anthropic.Message{} + providerExecutedOperations := make(map[string]providerExecutedOperation) var sawMessageStop bool return func(yield func(fantasy.StreamPart) bool) { if len(warnings) > 0 { @@ -1402,15 +1525,10 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S return } case "server_tool_use": - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputStart, - ID: chunk.ContentBlock.ID, - ToolCallName: chunk.ContentBlock.Name, - ToolCallInput: "", - ProviderExecuted: true, - }) { - return - } + // Provider-executed calls are buffered at content_block_stop, + // after Anthropic has accumulated any input_json_delta + // fragments. Consumers do not see a successful operation + // until the matching result arrives. } case "content_block_stop": if len(acc.Content)-1 < int(chunk.Index) { @@ -1457,22 +1575,20 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S return } case "server_tool_use": - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputEnd, - ID: contentBlock.ID, - ProviderExecuted: true, - }) { - return - } - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolCall, - ID: contentBlock.ID, - ToolCallName: contentBlock.Name, - ToolCallInput: string(contentBlock.Input), - ProviderExecuted: true, - }) { + operation := newProviderExecutedOperation( + contentBlock.ID, + contentBlock.Name, + string(contentBlock.Input), + ) + if _, exists := providerExecutedOperations[operation.id]; exists { + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: duplicateProviderExecutedOperationError(operation.id), + }) return } + providerExecutedOperations[operation.id] = operation + case "web_search_tool_result": // Read search results directly from the ContentBlockUnion // struct fields instead of using AsAny(). The Anthropic SDK's @@ -1480,6 +1596,14 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S // which corrupts JSON.raw for inline union types like // WebSearchToolResultBlockContentUnion. The struct fields // themselves remain correctly populated from content_block_start. + operation, ok := providerExecutedOperations[contentBlock.ToolUseID] + if !ok { + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: orphanProviderExecutedResultError(contentBlock.ToolUseID), + }) + return + } var metadataResults []WebSearchResultItem var providerMeta fantasy.ProviderMetadata if items := contentBlock.Content.OfWebSearchResultBlockArray; len(items) > 0 { @@ -1512,6 +1636,10 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S }, } } + if !operation.yieldToolCall(yield) { + return + } + delete(providerExecutedOperations, contentBlock.ToolUseID) if !yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeToolResult, ID: contentBlock.ToolUseID, @@ -1557,6 +1685,9 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S continue } contentBlock := acc.Content[int(chunk.Index)] + if contentBlock.Type == "server_tool_use" { + continue + } if !yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeToolInputDelta, ID: contentBlock.ID, @@ -1571,7 +1702,29 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S } err := stream.Err() + var incompleteErr error + if len(providerExecutedOperations) > 0 { + if acc.StopReason == "pause_turn" { + // Anthropic pause_turn can end with server_tool_use blocks that + // have no results yet. Surface the calls so callers can round-trip + // the paused response. + if !yieldProviderExecutedOperations(yield, providerExecutedOperations) { + return + } + } else { + incompleteErr = incompleteProviderExecutedOperationError( + firstProviderExecutedOperationID(providerExecutedOperations), + ) + } + } if err == nil || errors.Is(err, io.EOF) { + if incompleteErr != nil { + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: incompleteErr, + }) + return + } if !sawMessageStop { if err == nil { err = io.EOF @@ -1583,9 +1736,10 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S return } yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeFinish, - ID: acc.ID, - FinishReason: mapFinishReason(string(acc.StopReason)), + Type: fantasy.StreamPartTypeFinish, + ID: acc.ID, + FinishReason: mapFinishReason(string(acc.StopReason)), + ProviderMetadata: providerMetadataForStopReason(string(acc.StopReason)), Usage: fantasy.Usage{ InputTokens: acc.Usage.InputTokens, OutputTokens: acc.Usage.OutputTokens, @@ -1593,16 +1747,22 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S CacheCreationTokens: acc.Usage.CacheCreationInputTokens, CacheReadTokens: acc.Usage.CacheReadInputTokens, }, - ProviderMetadata: fantasy.ProviderMetadata{}, }) return - } else { //nolint: revive + } + + if incompleteErr != nil { yield(fantasy.StreamPart{ Type: fantasy.StreamPartTypeError, - Error: toProviderErr(err), + Error: fmt.Errorf("%w: %w", incompleteErr, toProviderErr(err)), }) return } + + yield(fantasy.StreamPart{ + Type: fantasy.StreamPartTypeError, + Error: toProviderErr(err), + }) }, nil } diff --git a/providers/anthropic/anthropic_provider_operation_test.go b/providers/anthropic/anthropic_provider_operation_test.go new file mode 100644 index 000000000..e8e631ecb --- /dev/null +++ b/providers/anthropic/anthropic_provider_operation_test.go @@ -0,0 +1,498 @@ +package anthropic + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/require" +) + +func TestGenerate_WebSearchResponseRejectsDuplicateProviderOperationIDs(t *testing.T) { + t.Parallel() + + response := mockAnthropicWebSearchResponse() + content, ok := response["content"].([]any) + require.True(t, ok) + response["content"] = []any{content[0], content[0], content[1]} + + server, _ := newAnthropicJSONServer(response) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + _, err = model.Generate(context.Background(), fantasy.Call{ + Prompt: testPrompt(), + Tools: []fantasy.Tool{ + WebSearchTool(nil), + }, + }) + require.ErrorContains(t, err, "duplicate ID") +} + +func TestGenerate_WebSearchResponsePausedTurnEmitsUnmatchedProviderOperation(t *testing.T) { + t.Parallel() + + response := mockAnthropicWebSearchResponse() + content, ok := response["content"].([]any) + require.True(t, ok) + response["content"] = []any{content[0]} + response["stop_reason"] = "pause_turn" + + server, _ := newAnthropicJSONServer(response) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + result, err := model.Generate(context.Background(), fantasy.Call{ + Prompt: testPrompt(), + Tools: []fantasy.Tool{ + WebSearchTool(nil), + }, + }) + require.NoError(t, err) + require.Equal(t, fantasy.FinishReasonStop, result.FinishReason) + require.True(t, IsPauseTurnStopReason(result.ProviderMetadata)) + + toolCalls := result.Content.ToolCalls() + require.Len(t, toolCalls, 1) + require.True(t, toolCalls[0].ProviderExecuted) + require.Equal(t, "srvtoolu_01", toolCalls[0].ToolCallID) + require.Equal(t, "web_search", toolCalls[0].ToolName) + require.JSONEq(t, `{"query":"latest AI news"}`, toolCalls[0].Input) + require.Empty(t, result.Content.ToolResults()) +} + +func TestStream_WebSearchResponseRejectsOrphanResultBeforeSources(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchToolResultChunks( + t, + 0, + "srvtoolu_01", + anthropicWebSearchResultItem( + "https://example.com/orphan", + "Orphan Result", + "encrypted_orphan", + "1 hour ago", + ), + ), + anthropicWebSearchMessageStopChunks(), + )) + + var providerToolEvents []fantasy.StreamPart + var sourceParts []fantasy.StreamPart + var errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart, + fantasy.StreamPartTypeToolInputEnd, + fantasy.StreamPartTypeToolCall, + fantasy.StreamPartTypeToolResult: + if part.ProviderExecuted { + providerToolEvents = append(providerToolEvents, part) + } + case fantasy.StreamPartTypeSource: + sourceParts = append(sourceParts, part) + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + require.Empty(t, providerToolEvents) + require.Empty(t, sourceParts) + require.Len(t, errorParts, 1) + require.ErrorContains(t, errorParts[0].Error, "without a matching server tool use") +} + +func TestStream_WebSearchResponseSkipsProviderInputDeltas(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks( + 0, + "srvtoolu_01", + `{"query":"latest `, + `AI news"}`, + ), + anthropicWebSearchToolResultChunks( + t, + 1, + "srvtoolu_01", + anthropicWebSearchResultItem( + "https://example.com/ai-news", + "Latest AI News", + "encrypted_abc123", + "2 hours ago", + ), + ), + anthropicWebSearchMessageStopChunks(), + )) + + var inputDeltaParts []fantasy.StreamPart + var providerToolCalls []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputDelta: + inputDeltaParts = append(inputDeltaParts, part) + case fantasy.StreamPartTypeToolCall: + if part.ProviderExecuted { + providerToolCalls = append(providerToolCalls, part) + } + } + } + + require.Empty(t, inputDeltaParts) + require.Len(t, providerToolCalls, 1) + require.JSONEq(t, `{"query":"latest AI news"}`, providerToolCalls[0].ToolCallInput) +} + +func TestStream_WebSearchResponseHandlesMultipleProviderOperations(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks(0, "srvtoolu_01"), + anthropicWebSearchToolResultChunks( + t, + 1, + "srvtoolu_01", + anthropicWebSearchResultItem( + "https://example.com/first", + "First Result", + "encrypted_first", + "1 hour ago", + ), + ), + anthropicWebSearchServerToolUseChunks(2, "srvtoolu_02"), + anthropicWebSearchToolResultChunks( + t, + 3, + "srvtoolu_02", + anthropicWebSearchResultItem( + "https://example.com/second", + "Second Result", + "encrypted_second", + "3 hours ago", + ), + ), + anthropicWebSearchMessageStopChunks(), + )) + + var providerToolCallIDs []string + var providerToolResultIDs []string + for _, part := range parts { + if !part.ProviderExecuted { + continue + } + switch part.Type { + case fantasy.StreamPartTypeToolCall: + providerToolCallIDs = append(providerToolCallIDs, part.ID) + case fantasy.StreamPartTypeToolResult: + providerToolResultIDs = append(providerToolResultIDs, part.ID) + } + } + + require.Equal(t, []string{"srvtoolu_01", "srvtoolu_02"}, providerToolCallIDs) + require.Equal(t, []string{"srvtoolu_01", "srvtoolu_02"}, providerToolResultIDs) +} + +func TestStream_WebSearchResponsePausedTurnEmitsUnmatchedProviderOperation(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks(0, "srvtoolu_01"), + anthropicWebSearchMessageDeltaChunks("pause_turn"), + anthropicWebSearchMessageStopChunks(), + )) + + var providerToolEvents []fantasy.StreamPart + var errorParts []fantasy.StreamPart + var finishParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart, + fantasy.StreamPartTypeToolInputEnd, + fantasy.StreamPartTypeToolCall, + fantasy.StreamPartTypeToolResult: + if part.ProviderExecuted { + providerToolEvents = append(providerToolEvents, part) + } + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + case fantasy.StreamPartTypeFinish: + finishParts = append(finishParts, part) + } + } + + require.Empty(t, errorParts) + require.Len(t, providerToolEvents, 3) + require.Equal(t, fantasy.StreamPartTypeToolInputStart, providerToolEvents[0].Type) + require.Equal(t, fantasy.StreamPartTypeToolInputEnd, providerToolEvents[1].Type) + require.Equal(t, fantasy.StreamPartTypeToolCall, providerToolEvents[2].Type) + require.Equal(t, "srvtoolu_01", providerToolEvents[2].ID) + require.JSONEq(t, `{}`, providerToolEvents[2].ToolCallInput) + require.Len(t, finishParts, 1) + require.Equal(t, fantasy.FinishReasonStop, finishParts[0].FinishReason) + require.True(t, IsPauseTurnStopReason(finishParts[0].ProviderMetadata)) +} + +func TestStream_WebSearchResponsePausedTurnWithTransportErrorDoesNotReportIncompleteOperation(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks(0, "srvtoolu_01"), + anthropicWebSearchMessageDeltaChunks("pause_turn"), + []string{ + "event: content_block_delta\n", + "data: {not-json}\n\n", + }, + )) + + var providerToolEvents []fantasy.StreamPart + var errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart, + fantasy.StreamPartTypeToolInputEnd, + fantasy.StreamPartTypeToolCall, + fantasy.StreamPartTypeToolResult: + if part.ProviderExecuted { + providerToolEvents = append(providerToolEvents, part) + } + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + require.Len(t, providerToolEvents, 3) + require.Len(t, errorParts, 1) + require.NotContains(t, errorParts[0].Error.Error(), "ended without a matching result") + require.ErrorContains(t, errorParts[0].Error, "invalid character") +} + +func TestStream_WebSearchResponseRejectsDuplicateProviderOperationIDs(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks(0, "srvtoolu_01"), + anthropicWebSearchServerToolUseChunks(1, "srvtoolu_01"), + anthropicWebSearchMessageStopChunks(), + )) + + var providerToolEvents []fantasy.StreamPart + var errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart, + fantasy.StreamPartTypeToolInputEnd, + fantasy.StreamPartTypeToolCall, + fantasy.StreamPartTypeToolResult: + if part.ProviderExecuted { + providerToolEvents = append(providerToolEvents, part) + } + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + require.Empty(t, providerToolEvents) + require.Len(t, errorParts, 1) + require.ErrorContains(t, errorParts[0].Error, "duplicate ID") +} + +func TestStream_WebSearchResponseSurfacesIncompleteOperationOnStreamError(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks(0, "srvtoolu_01"), + []string{ + "event: content_block_delta\n", + "data: {not-json}\n\n", + }, + )) + + var providerToolEvents []fantasy.StreamPart + var errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart, + fantasy.StreamPartTypeToolInputEnd, + fantasy.StreamPartTypeToolCall, + fantasy.StreamPartTypeToolResult: + if part.ProviderExecuted { + providerToolEvents = append(providerToolEvents, part) + } + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + require.Empty(t, providerToolEvents) + require.Len(t, errorParts, 1) + require.ErrorContains(t, errorParts[0].Error, "ended without a matching result") + require.ErrorContains(t, errorParts[0].Error, "invalid character") +} + +func TestStream_WebSearchResponsePreservesProviderMetadata(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, concatAnthropicChunkSets( + anthropicWebSearchMessageStartChunks(), + anthropicWebSearchServerToolUseChunks(0, "srvtoolu_01"), + anthropicWebSearchToolResultChunks( + t, + 1, + "srvtoolu_01", + anthropicWebSearchResultItem( + "https://example.com/ai-news", + "Latest AI News", + "encrypted_abc123", + "2 hours ago", + ), + anthropicWebSearchResultItem( + "https://example.com/ml-update", + "ML Update", + "encrypted_def456", + "", + ), + ), + anthropicWebSearchMessageStopChunks(), + )) + + var providerToolResults []fantasy.StreamPart + for _, part := range parts { + if part.Type == fantasy.StreamPartTypeToolResult && part.ProviderExecuted { + providerToolResults = append(providerToolResults, part) + } + } + + require.Len(t, providerToolResults, 1) + searchMeta, ok := providerToolResults[0].ProviderMetadata[Name] + require.True(t, ok) + webMeta, ok := searchMeta.(*WebSearchResultMetadata) + require.True(t, ok) + require.Len(t, webMeta.Results, 2) + require.Equal(t, "https://example.com/ai-news", webMeta.Results[0].URL) + require.Equal(t, "encrypted_abc123", webMeta.Results[0].EncryptedContent) + require.Equal(t, "2 hours ago", webMeta.Results[0].PageAge) + require.Equal(t, "https://example.com/ml-update", webMeta.Results[1].URL) + require.Equal(t, "encrypted_def456", webMeta.Results[1].EncryptedContent) + require.Empty(t, webMeta.Results[1].PageAge) +} + +func concatAnthropicChunkSets(chunkSets ...[]string) []string { + var chunks []string + for _, chunkSet := range chunkSets { + chunks = append(chunks, chunkSet...) + } + return chunks +} + +func anthropicWebSearchMessageStartChunks() []string { + return []string{ + "event: message_start\n", + `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n", + } +} + +func anthropicWebSearchMessageDeltaChunks(stopReason string) []string { + return []string{ + "event: message_delta\n", + fmt.Sprintf( + `data: {"type":"message_delta","delta":{"stop_reason":%q},"usage":{"output_tokens":5}}`, + stopReason, + ) + "\n\n", + } +} + +func anthropicWebSearchMessageStopChunks() []string { + return []string{ + "event: message_stop\n", + `data: {"type":"message_stop"}` + "\n\n", + } +} + +func anthropicWebSearchServerToolUseChunks(index int, id string, partialJSONDeltas ...string) []string { + chunks := []string{ + "event: content_block_start\n", + fmt.Sprintf( + `data: {"type":"content_block_start","index":%d,"content_block":{"type":"server_tool_use","id":%q,"name":"web_search","input":{}}}`, + index, + id, + ) + "\n\n", + } + for _, delta := range partialJSONDeltas { + chunks = append(chunks, + "event: content_block_delta\n", + fmt.Sprintf( + `data: {"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":%q}}`, + index, + delta, + )+"\n\n", + ) + } + chunks = append(chunks, + "event: content_block_stop\n", + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, index)+"\n\n", + ) + return chunks +} + +func anthropicWebSearchToolResultChunks(t *testing.T, index int, toolUseID string, items ...map[string]any) []string { + t.Helper() + + if items == nil { + items = []map[string]any{} + } + content, err := json.Marshal(items) + require.NoError(t, err) + + return []string{ + "event: content_block_start\n", + fmt.Sprintf( + `data: {"type":"content_block_start","index":%d,"content_block":{"type":"web_search_tool_result","tool_use_id":%q,"content":%s}}`, + index, + toolUseID, + string(content), + ) + "\n\n", + "event: content_block_stop\n", + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, index) + "\n\n", + } +} + +func anthropicWebSearchResultItem(url, title, encryptedContent, pageAge string) map[string]any { + item := map[string]any{ + "type": "web_search_result", + "url": url, + "title": title, + "encrypted_content": encryptedContent, + } + if pageAge != "" { + item["page_age"] = pageAge + } + return item +} diff --git a/providers/anthropic/anthropic_test.go b/providers/anthropic/anthropic_test.go index 176b4f276..a35c03090 100644 --- a/providers/anthropic/anthropic_test.go +++ b/providers/anthropic/anthropic_test.go @@ -2994,3 +2994,209 @@ func TestStream_ComputerUseTool(t *testing.T) { require.Contains(t, h, "computer-use-2025-01-24", "request %d", i) } } + +func TestGenerate_WebSearchResponseRejectsUnpairedProviderOperations(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + content func([]any) []any + wantErrText string + }{ + { + name: "missing result", + content: func(content []any) []any { + return []any{content[0]} + }, + wantErrText: "ended without a matching result", + }, + { + name: "missing server tool use", + content: func(content []any) []any { + return []any{content[1]} + }, + wantErrText: "without a matching server tool use", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + response := mockAnthropicWebSearchResponse() + content, ok := response["content"].([]any) + require.True(t, ok) + response["content"] = tc.content(content) + + server, _ := newAnthropicJSONServer(response) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + _, err = model.Generate(context.Background(), fantasy.Call{ + Prompt: testPrompt(), + Tools: []fantasy.Tool{ + WebSearchTool(nil), + }, + }) + require.ErrorContains(t, err, tc.wantErrText) + }) + } +} + +func TestStream_WebSearchResponseRejectsUnpairedProviderOperations(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chunks []string + wantErrText string + }{ + { + name: "missing result", + chunks: []string{ + "event: message_start\n", + `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n", + "event: content_block_start\n", + `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n", + "event: content_block_stop\n", + `data: {"type":"content_block_stop","index":0}` + "\n\n", + "event: message_stop\n", + `data: {"type":"message_stop"}` + "\n\n", + }, + wantErrText: "ended without a matching result", + }, + { + name: "missing server tool use", + chunks: []string{ + "event: message_start\n", + `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n", + "event: content_block_start\n", + `data: {"type":"content_block_start","index":0,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":[]}}` + "\n\n", + "event: content_block_stop\n", + `data: {"type":"content_block_stop","index":0}` + "\n\n", + "event: message_stop\n", + `data: {"type":"message_stop"}` + "\n\n", + }, + wantErrText: "without a matching server tool use", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + parts := collectAnthropicStreamPartsFromChunks(t, tc.chunks) + + var providerToolEvents []fantasy.StreamPart + var errorParts []fantasy.StreamPart + for _, part := range parts { + switch part.Type { + case fantasy.StreamPartTypeToolInputStart, + fantasy.StreamPartTypeToolInputEnd, + fantasy.StreamPartTypeToolCall, + fantasy.StreamPartTypeToolResult: + if part.ProviderExecuted { + providerToolEvents = append(providerToolEvents, part) + } + case fantasy.StreamPartTypeError: + errorParts = append(errorParts, part) + } + } + + require.Empty(t, providerToolEvents) + require.Len(t, errorParts, 1) + require.ErrorContains(t, errorParts[0].Error, tc.wantErrText) + }) + } +} + +func TestStream_WebSearchResponseEmitsProviderOperationAdjacently(t *testing.T) { + t.Parallel() + + webSearchResultContent, _ := json.Marshal([]any{ + map[string]any{ + "type": "web_search_result", + "url": "https://example.com/ai-news", + "title": "Latest AI News", + "encrypted_content": "encrypted_abc123", + }, + }) + parts := collectAnthropicStreamPartsFromChunks(t, []string{ + "event: message_start\n", + `data: {"type":"message_start","message":{"id":"msg_01WebSearch","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"usage":{"input_tokens":100,"output_tokens":0}}}` + "\n\n", + "event: content_block_start\n", + `data: {"type":"content_block_start","index":0,"content_block":{"type":"server_tool_use","id":"srvtoolu_01","name":"web_search","input":{}}}` + "\n\n", + "event: content_block_stop\n", + `data: {"type":"content_block_stop","index":0}` + "\n\n", + "event: content_block_start\n", + `data: {"type":"content_block_start","index":1,"content_block":{"type":"web_search_tool_result","tool_use_id":"srvtoolu_01","content":` + string(webSearchResultContent) + `}}` + "\n\n", + "event: content_block_stop\n", + `data: {"type":"content_block_stop","index":1}` + "\n\n", + "event: message_stop\n", + `data: {"type":"message_stop"}` + "\n\n", + }) + + startIndex := -1 + endIndex := -1 + callIndex := -1 + resultIndex := -1 + for i, part := range parts { + if !part.ProviderExecuted { + continue + } + switch part.Type { + case fantasy.StreamPartTypeToolInputStart: + startIndex = i + case fantasy.StreamPartTypeToolInputEnd: + endIndex = i + case fantasy.StreamPartTypeToolCall: + callIndex = i + case fantasy.StreamPartTypeToolResult: + resultIndex = i + } + } + require.NotEqual(t, -1, startIndex) + require.Equal(t, startIndex+1, endIndex) + require.Equal(t, endIndex+1, callIndex) + require.Equal(t, callIndex+1, resultIndex) +} + +func collectAnthropicStreamPartsFromChunks(t *testing.T, chunks []string) []fantasy.StreamPart { + t.Helper() + + server, calls := newAnthropicStreamingServer(chunks) + defer server.Close() + + provider, err := New( + WithAPIKey("test-api-key"), + WithBaseURL(server.URL), + ) + require.NoError(t, err) + + model, err := provider.LanguageModel(context.Background(), "claude-sonnet-4-20250514") + require.NoError(t, err) + + stream, err := model.Stream(context.Background(), fantasy.Call{ + Prompt: testPrompt(), + Tools: []fantasy.Tool{ + WebSearchTool(nil), + }, + }) + require.NoError(t, err) + + var parts []fantasy.StreamPart + stream(func(part fantasy.StreamPart) bool { + parts = append(parts, part) + return true + }) + _ = awaitAnthropicCall(t, calls) + return parts +} diff --git a/providers/anthropic/provider_options.go b/providers/anthropic/provider_options.go index 14ae8111b..0f55d9ea8 100644 --- a/providers/anthropic/provider_options.go +++ b/providers/anthropic/provider_options.go @@ -33,6 +33,7 @@ const ( TypeReasoningOptionMetadata = Name + ".reasoning_metadata" TypeProviderCacheControl = Name + ".cache_control_options" TypeWebSearchResultMetadata = Name + ".web_search_result_metadata" + TypeStopReasonMetadata = Name + ".stop_reason_metadata" ) // Register Anthropic provider-specific types with the global registry. @@ -65,6 +66,13 @@ func init() { } return &v, nil }) + fantasy.RegisterProviderType(TypeStopReasonMetadata, func(data []byte) (fantasy.ProviderOptionsData, error) { + var v StopReasonMetadata + if err := json.Unmarshal(data, &v); err != nil { + return nil, err + } + return &v, nil + }) } // ProviderOptions represents additional options for the Anthropic provider. @@ -191,6 +199,50 @@ func (m *WebSearchResultMetadata) UnmarshalJSON(data []byte) error { return nil } +// StopReasonMetadata stores the raw Anthropic stop reason on a response. +// This preserves provider-specific terminal states such as pause_turn that +// collapse into the generic fantasy.FinishReason values. +type StopReasonMetadata struct { + StopReason string `json:"stop_reason"` +} + +// Options implements the ProviderOptions interface. +func (*StopReasonMetadata) Options() {} + +// MarshalJSON implements custom JSON marshaling with type info for StopReasonMetadata. +func (m StopReasonMetadata) MarshalJSON() ([]byte, error) { + type plain StopReasonMetadata + return fantasy.MarshalProviderType(TypeStopReasonMetadata, plain(m)) +} + +// UnmarshalJSON implements custom JSON unmarshaling with type info for StopReasonMetadata. +func (m *StopReasonMetadata) UnmarshalJSON(data []byte) error { + type plain StopReasonMetadata + var p plain + if err := fantasy.UnmarshalProviderType(data, &p); err != nil { + return err + } + *m = StopReasonMetadata(p) + return nil +} + +// GetStopReasonMetadata extracts Anthropic stop-reason metadata. +func GetStopReasonMetadata(providerMetadata fantasy.ProviderMetadata) *StopReasonMetadata { + if metadata, ok := providerMetadata[Name]; ok { + if stopReason, ok := metadata.(*StopReasonMetadata); ok { + return stopReason + } + } + return nil +} + +// IsPauseTurnStopReason reports whether provider metadata identifies an +// Anthropic pause_turn response. +func IsPauseTurnStopReason(providerMetadata fantasy.ProviderMetadata) bool { + metadata := GetStopReasonMetadata(providerMetadata) + return metadata != nil && metadata.StopReason == "pause_turn" +} + // CacheControl represents cache control settings for the Anthropic provider. type CacheControl struct { Type string `json:"type"` diff --git a/providers/openai/computer_use.go b/providers/openai/computer_use.go index a5a871f7c..4d398ff75 100644 --- a/providers/openai/computer_use.go +++ b/providers/openai/computer_use.go @@ -11,12 +11,12 @@ import ( const computerUseToolID = "openai.computer_use" -// Type identifier for computer use metadata, registered in -// responses_options.go init(). +// TypeComputerUseMetadata is the type identifier for computer use metadata, +// registered in responses_options.go init(). const TypeComputerUseMetadata = Name + ".responses.computer_use_metadata" -// Type identifier for computer call output options, registered in -// responses_options.go init(). +// TypeComputerCallOutputOptions is the type identifier for computer call +// output options, registered in responses_options.go init(). const TypeComputerCallOutputOptions = Name + ".responses.computer_call_output_options" // ComputerUseMetadata stores the raw wire-format JSON of a computer_call diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index c344e6fad..59ad2f000 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -3184,7 +3184,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3212,7 +3212,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3251,7 +3251,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3274,7 +3274,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3298,7 +3298,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3331,7 +3331,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3364,7 +3364,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -4015,7 +4015,7 @@ func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) { t.Run("store false skips item reference", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system instructions", false) + input, warnings, err := toResponsesPrompt(prompt, "system instructions", false, false) require.NoError(t, err) @@ -4029,7 +4029,7 @@ func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) { t.Run("store true skips item reference", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system instructions", true) + input, warnings, err := toResponsesPrompt(prompt, "system instructions", true, false) require.NoError(t, err) @@ -4083,7 +4083,7 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { t.Run("store true emits item_reference for reasoning", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4134,7 +4134,7 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(noIDPrompt, "system", true) + input, warnings, err := toResponsesPrompt(noIDPrompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4151,7 +4151,7 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { t.Run("store false skips reasoning", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4216,7 +4216,7 @@ func TestResponsesToPrompt_ReasoningWithWebSearchCombined(t *testing.T) { t.Run("store true pairs reasoning and web search", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4236,7 +4236,7 @@ func TestResponsesToPrompt_ReasoningWithWebSearchCombined(t *testing.T) { t.Run("store false skips both reasoning and provider tool call", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4272,7 +4272,7 @@ func TestResponsesToPrompt_WebSearchRequiresReasoningReference(t *testing.T) { fantasy.TextPart{Text: "Search completed."}, }, }, - }, "system", true) + }, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4329,7 +4329,7 @@ func TestResponsesToPrompt_ReasoningWithFunctionCallCombined(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4905,7 +4905,7 @@ func TestComputerUseGenerateRoundTrip_NonImageResult(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) // Should warn about non-image result. diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 077c27325..690f16b08 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -177,7 +177,13 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res } storeEnabled := openaiOptions != nil && openaiOptions.Store != nil && *openaiOptions.Store - input, inputWarnings, err := toResponsesPrompt(call.Prompt, modelConfig.systemMessageMode, storeEnabled) + allowOrphanFunctionOutputs := openaiOptions != nil && openaiOptions.PreviousResponseID != nil && *openaiOptions.PreviousResponseID != "" + input, inputWarnings, err := toResponsesPrompt( + call.Prompt, + modelConfig.systemMessageMode, + storeEnabled, + allowOrphanFunctionOutputs, + ) warnings = append(warnings, inputWarnings...) if err != nil { return nil, warnings, err @@ -400,7 +406,12 @@ func responsesUsage(resp responses.Response) fantasy.Usage { return usage } -func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bool) (responses.ResponseInputParam, []fantasy.CallWarning, error) { +func toResponsesPrompt( + prompt fantasy.Prompt, + systemMessageMode string, + store bool, + allowOrphanFunctionOutputs bool, +) (responses.ResponseInputParam, []fantasy.CallWarning, error) { var input responses.ResponseInputParam var warnings []fantasy.CallWarning @@ -741,7 +752,7 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bo } } - if err := validateResponsesInput(input); err != nil { + if err := validateResponsesInput(input, allowOrphanFunctionOutputs); err != nil { return nil, warnings, err } @@ -753,14 +764,14 @@ func isResponsesWebSearchToolCall(toolCallPart fantasy.ToolCallPart) bool { toolCallPart.ToolName == "web_search_preview" } -func validateResponsesInput(input responses.ResponseInputParam) error { - if err := validateResponsesFunctionCallOutputs(input); err != nil { +func validateResponsesInput(input responses.ResponseInputParam, allowOrphanFunctionOutputs bool) error { + if err := validateResponsesFunctionCallOutputs(input, allowOrphanFunctionOutputs); err != nil { return err } return validateResponsesItemReferences(input) } -func validateResponsesFunctionCallOutputs(input responses.ResponseInputParam) error { +func validateResponsesFunctionCallOutputs(input responses.ResponseInputParam, allowOrphanFunctionOutputs bool) error { type callState struct { calls int outputs int @@ -818,7 +829,10 @@ func validateResponsesFunctionCallOutputs(input responses.ResponseInputParam) er for _, callID := range outputIDs { state := states[callID] if state.calls == 0 { - return fmt.Errorf("openai responses prompt has function_call_output without function_call for call_id %q", callID) + if !allowOrphanFunctionOutputs { + return fmt.Errorf("openai responses prompt has function_call_output without function_call for call_id %q", callID) + } + continue } if state.firstOutput < state.firstCall { return fmt.Errorf("openai responses prompt has function_call_output before function_call for call_id %q", callID) diff --git a/providers/openai/responses_params_test.go b/providers/openai/responses_params_test.go index 800eff6ba..cebb4754c 100644 --- a/providers/openai/responses_params_test.go +++ b/providers/openai/responses_params_test.go @@ -148,7 +148,41 @@ func TestPrepareParams_PreviousResponseID_Validation(t *testing.T) { testTextMessage(fantasy.MessageRoleUser, "hello"), }, opts)) require.NoError(t, err) - _ = warnings + require.Empty(t, warnings) + }) + + t.Run("allows orphan function call output", func(t *testing.T) { + t.Parallel() + + _, warnings, err := lm.prepareParams(testCall(fantasy.Prompt{ + testResponsesToolResultMessage("call_orphan", "done"), + testTextMessage(fantasy.MessageRoleUser, "hello"), + }, opts)) + require.NoError(t, err) + require.Empty(t, warnings) + }) + + t.Run("rejects duplicate orphan function call outputs", func(t *testing.T) { + t.Parallel() + + _, warnings, err := lm.prepareParams(testCall(fantasy.Prompt{ + { + Role: fantasy.MessageRoleTool, + Content: []fantasy.MessagePart{ + fantasy.ToolResultPart{ + ToolCallID: "call_duplicate", + Output: fantasy.ToolResultOutputContentText{Text: "first"}, + }, + fantasy.ToolResultPart{ + ToolCallID: "call_duplicate", + Output: fantasy.ToolResultOutputContentText{Text: "second"}, + }, + }, + }, + testTextMessage(fantasy.MessageRoleUser, "hello"), + }, opts)) + require.EqualError(t, err, `openai responses prompt has duplicate function_call_output for call_id "call_duplicate"`) + require.Empty(t, warnings) }) t.Run("rejects without store", func(t *testing.T) { @@ -471,7 +505,7 @@ func TestPrepareParams_ValidatesFunctionCallOutputPairing(t *testing.T) { input, warnings, err := toResponsesPrompt(fantasy.Prompt{ testResponsesProviderToolResultMessage("ws_01"), - }, "system", false) + }, "system", false, false) require.NoError(t, err) require.Empty(t, warnings) require.Empty(t, input) @@ -498,7 +532,7 @@ func TestValidateResponsesInput_WebSearchReferenceRequiresReasoning(t *testing.T err := validateResponsesInput(responses.ResponseInputParam{ responses.ResponseInputItemParamOfItemReference("rs_valid"), responses.ResponseInputItemParamOfItemReference("ws_valid"), - }) + }, false) require.NoError(t, err) }) @@ -507,7 +541,7 @@ func TestValidateResponsesInput_WebSearchReferenceRequiresReasoning(t *testing.T err := validateResponsesInput(responses.ResponseInputParam{ responses.ResponseInputItemParamOfItemReference("ws_orphan"), - }) + }, false) require.EqualError(t, err, `openai responses prompt has web_search_call item_reference without preceding reasoning item_reference for item_id "ws_orphan"`) }) @@ -518,7 +552,16 @@ func TestValidateResponsesInput_WebSearchReferenceRequiresReasoning(t *testing.T responses.ResponseInputItemParamOfItemReference("rs_valid"), responses.ResponseInputItemParamOfMessage("text", responses.EasyInputMessageRoleAssistant), responses.ResponseInputItemParamOfItemReference("ws_orphan"), - }) + }, false) + require.EqualError(t, err, `openai responses prompt has web_search_call item_reference without preceding reasoning item_reference for item_id "ws_orphan"`) + }) + + t.Run("web search references are checked when orphan outputs are allowed", func(t *testing.T) { + t.Parallel() + + err := validateResponsesInput(responses.ResponseInputParam{ + responses.ResponseInputItemParamOfItemReference("ws_orphan"), + }, true) require.EqualError(t, err, `openai responses prompt has web_search_call item_reference without preceding reasoning item_reference for item_id "ws_orphan"`) }) } diff --git a/providertests/openai_computer_use_test.go b/providertests/openai_computer_use_test.go index eda625920..efd5c312d 100644 --- a/providertests/openai_computer_use_test.go +++ b/providertests/openai_computer_use_test.go @@ -69,6 +69,10 @@ func testComputerUseTool(t *testing.T) (fantasy.ExecutableProviderTool, *bool) { return tool, called } +func int64Ptr(v int64) *int64 { + return &v +} + // TestOpenAIComputerUse tests OpenAI computer use tool support via the // agent using the Responses API. Cassettes are stored under // testdata/TestOpenAIComputerUse/. @@ -95,7 +99,7 @@ func TestOpenAIComputerUse(t *testing.T) { result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: "Take a screenshot of the desktop", - MaxOutputTokens: new(int64(4000)), + MaxOutputTokens: int64Ptr(4000), ProviderOptions: providerOpts, }) require.NoError(t, err) @@ -134,7 +138,7 @@ func TestOpenAIComputerUse(t *testing.T) { result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{ Prompt: "Take a screenshot of the desktop", - MaxOutputTokens: new(int64(4000)), + MaxOutputTokens: int64Ptr(4000), ProviderOptions: providerOpts, }) require.NoError(t, err) @@ -238,7 +242,7 @@ func TestOpenAIComputerUse_AllActions(t *testing.T) { result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: prompt, - MaxOutputTokens: new(int64(16000)), + MaxOutputTokens: int64Ptr(16000), ProviderOptions: providerOpts, }) require.NoError(t, err)