Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 196 additions & 36 deletions providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"maps"
"math"
"slices"
"strings"

"charm.land/fantasy"
Expand Down Expand Up @@ -496,6 +497,106 @@ func reasoningEndProviderMetadata(contentBlock anthropic.ContentBlockUnion) fant
}
}

type providerExecutedOperation struct {
id string
name string
input string
}

func newProviderExecutedOperation(id, name, input string) providerExecutedOperation {
return providerExecutedOperation{id: id, name: name, input: input}
}

func (op providerExecutedOperation) appendContent(content []fantasy.Content) []fantasy.Content {
return append(content, fantasy.ToolCallContent{
ToolCallID: op.id,
ToolName: op.name,
Input: op.input,
ProviderExecuted: true,
})
}

func (op providerExecutedOperation) yieldToolCall(yield func(fantasy.StreamPart) bool) bool {
// Provider-executed operations are internally atomic. The adjacent
// legacy tool-call lifecycle is a compatibility encoding for consumers.
return yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolInputStart,
ID: op.id,
ToolCallName: op.name,
ToolCallInput: "",
ProviderExecuted: true,
}) && yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolInputEnd,
ID: op.id,
ProviderExecuted: true,
}) && yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolCall,
ID: op.id,
ToolCallName: op.name,
ToolCallInput: op.input,
ProviderExecuted: true,
})
}

func appendProviderExecutedOperations(
content []fantasy.Content,
operations map[string]providerExecutedOperation,
) []fantasy.Content {
for _, id := range slices.Sorted(maps.Keys(operations)) {
content = operations[id].appendContent(content)
}
return content
}

func yieldProviderExecutedOperations(
yield func(fantasy.StreamPart) bool,
operations map[string]providerExecutedOperation,
) bool {
for _, id := range slices.Sorted(maps.Keys(operations)) {
if !operations[id].yieldToolCall(yield) {
return false
}
}
return true
}

func providerMetadataForStopReason(stopReason string) fantasy.ProviderMetadata {
if stopReason == "" {
return fantasy.ProviderMetadata{}
}
return fantasy.ProviderMetadata{
Name: &StopReasonMetadata{StopReason: stopReason},
}
}

func firstProviderExecutedOperationID(operations map[string]providerExecutedOperation) string {
for id := range operations {
return id
}
return ""
}

func incompleteProviderExecutedOperationError(id string) error {
if id == "" {
return fmt.Errorf("anthropic provider-executed operation ended without a matching result")
}
return fmt.Errorf("anthropic provider-executed operation %q ended without a matching result", id)
}

func orphanProviderExecutedResultError(id string) error {
if id == "" {
return fmt.Errorf("anthropic provider-executed result arrived without a matching server tool use")
}
return fmt.Errorf("anthropic provider-executed result %q arrived without a matching server tool use", id)
}

func duplicateProviderExecutedOperationError(id string) error {
if id == "" {
return fmt.Errorf("anthropic provider-executed operation used a duplicate ID")
}
return fmt.Errorf("anthropic provider-executed operation used duplicate ID %q", id)
}

type messageBlock struct {
Role fantasy.MessageRole
Messages []fantasy.Message
Expand Down Expand Up @@ -1216,6 +1317,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
}

var content []fantasy.Content
providerExecutedOperations := make(map[string]providerExecutedOperation)
for _, block := range response.Content {
switch block.Type {
case "text":
Expand Down Expand Up @@ -1272,12 +1374,14 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
if b, err := json.Marshal(serverToolUse.Input); err == nil {
inputStr = string(b)
}
content = append(content, fantasy.ToolCallContent{
ToolCallID: serverToolUse.ID,
ToolName: string(serverToolUse.Name),
Input: inputStr,
ProviderExecuted: true,
})
if _, exists := providerExecutedOperations[serverToolUse.ID]; exists {
return nil, duplicateProviderExecutedOperationError(serverToolUse.ID)
}
providerExecutedOperations[serverToolUse.ID] = newProviderExecutedOperation(
serverToolUse.ID,
string(serverToolUse.Name),
inputStr,
)
case "web_search_tool_result":
webSearchResult, ok := block.AsAny().(anthropic.WebSearchToolResultBlock)
if !ok {
Expand Down Expand Up @@ -1318,7 +1422,25 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
},
}
}
operation, ok := providerExecutedOperations[webSearchResult.ToolUseID]
if !ok {
return nil, orphanProviderExecutedResultError(webSearchResult.ToolUseID)
}
content = operation.appendContent(content)
content = append(content, toolResult)
delete(providerExecutedOperations, webSearchResult.ToolUseID)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Remember completed provider operation IDs

Because this is the only state used for duplicate detection, deleting the entry after the first web_search_tool_result lets a later server_tool_use with the same ID pass as a fresh operation. In a response containing two completed pairs with the same srvtoolu_* ID, Generate will persist duplicate provider-executed tool calls/results instead of returning the duplicate-ID error promised by this change, and replaying that assistant message can be rejected by Anthropic. Keep a separate seen-ID set (and apply the same pattern in the streaming path) so duplicates are rejected even after the first pair is completed.

Useful? React with 👍 / 👎.

}
}

if len(providerExecutedOperations) > 0 {
if response.StopReason == "pause_turn" {
// Anthropic pause_turn can end with a server_tool_use that has no
// result yet. Surface the call so callers can round-trip it.
content = appendProviderExecutedOperations(content, providerExecutedOperations)
} else {
return nil, incompleteProviderExecutedOperationError(
firstProviderExecutedOperationID(providerExecutedOperations),
)
}
}

Expand All @@ -1332,7 +1454,7 @@ func (a languageModel) Generate(ctx context.Context, call fantasy.Call) (*fantas
CacheReadTokens: response.Usage.CacheReadInputTokens,
},
FinishReason: mapFinishReason(string(response.StopReason)),
ProviderMetadata: fantasy.ProviderMetadata{},
ProviderMetadata: providerMetadataForStopReason(string(response.StopReason)),
Warnings: warnings,
}, nil
}
Expand All @@ -1348,6 +1470,7 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S

stream := a.client.Messages.NewStreaming(ctx, *params, reqOpts...)
acc := anthropic.Message{}
providerExecutedOperations := make(map[string]providerExecutedOperation)
var sawMessageStop bool
return func(yield func(fantasy.StreamPart) bool) {
if len(warnings) > 0 {
Expand Down Expand Up @@ -1402,15 +1525,10 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
return
}
case "server_tool_use":
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolInputStart,
ID: chunk.ContentBlock.ID,
ToolCallName: chunk.ContentBlock.Name,
ToolCallInput: "",
ProviderExecuted: true,
}) {
return
}
// Provider-executed calls are buffered at content_block_stop,
// after Anthropic has accumulated any input_json_delta
// fragments. Consumers do not see a successful operation
// until the matching result arrives.
}
case "content_block_stop":
if len(acc.Content)-1 < int(chunk.Index) {
Expand Down Expand Up @@ -1457,29 +1575,35 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
return
}
case "server_tool_use":
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolInputEnd,
ID: contentBlock.ID,
ProviderExecuted: true,
}) {
return
}
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolCall,
ID: contentBlock.ID,
ToolCallName: contentBlock.Name,
ToolCallInput: string(contentBlock.Input),
ProviderExecuted: true,
}) {
operation := newProviderExecutedOperation(
contentBlock.ID,
contentBlock.Name,
string(contentBlock.Input),
)
if _, exists := providerExecutedOperations[operation.id]; exists {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: duplicateProviderExecutedOperationError(operation.id),
})
return
}
providerExecutedOperations[operation.id] = operation

case "web_search_tool_result":
// Read search results directly from the ContentBlockUnion
// struct fields instead of using AsAny(). The Anthropic SDK's
// Accumulate re-marshals the content block at content_block_stop,
// which corrupts JSON.raw for inline union types like
// WebSearchToolResultBlockContentUnion. The struct fields
// themselves remain correctly populated from content_block_start.
operation, ok := providerExecutedOperations[contentBlock.ToolUseID]
if !ok {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: orphanProviderExecutedResultError(contentBlock.ToolUseID),
})
return
}
var metadataResults []WebSearchResultItem
var providerMeta fantasy.ProviderMetadata
if items := contentBlock.Content.OfWebSearchResultBlockArray; len(items) > 0 {
Expand Down Expand Up @@ -1512,6 +1636,10 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
},
}
}
if !operation.yieldToolCall(yield) {
return
}
delete(providerExecutedOperations, contentBlock.ToolUseID)
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolResult,
ID: contentBlock.ToolUseID,
Expand Down Expand Up @@ -1557,6 +1685,9 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
continue
}
contentBlock := acc.Content[int(chunk.Index)]
if contentBlock.Type == "server_tool_use" {
continue
}
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeToolInputDelta,
ID: contentBlock.ID,
Expand All @@ -1571,7 +1702,29 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
}

err := stream.Err()
var incompleteErr error
if len(providerExecutedOperations) > 0 {
if acc.StopReason == "pause_turn" {
// Anthropic pause_turn can end with server_tool_use blocks that
// have no results yet. Surface the calls so callers can round-trip
// the paused response.
if !yieldProviderExecutedOperations(yield, providerExecutedOperations) {
return
}
} else {
incompleteErr = incompleteProviderExecutedOperationError(
firstProviderExecutedOperationID(providerExecutedOperations),
)
}
}
if err == nil || errors.Is(err, io.EOF) {
if incompleteErr != nil {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: incompleteErr,
})
return
}
if !sawMessageStop {
if err == nil {
err = io.EOF
Expand All @@ -1583,26 +1736,33 @@ func (a languageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.S
return
}
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeFinish,
ID: acc.ID,
FinishReason: mapFinishReason(string(acc.StopReason)),
Type: fantasy.StreamPartTypeFinish,
ID: acc.ID,
FinishReason: mapFinishReason(string(acc.StopReason)),
ProviderMetadata: providerMetadataForStopReason(string(acc.StopReason)),
Usage: fantasy.Usage{
InputTokens: acc.Usage.InputTokens,
OutputTokens: acc.Usage.OutputTokens,
TotalTokens: acc.Usage.InputTokens + acc.Usage.OutputTokens,
CacheCreationTokens: acc.Usage.CacheCreationInputTokens,
CacheReadTokens: acc.Usage.CacheReadInputTokens,
},
ProviderMetadata: fantasy.ProviderMetadata{},
})
return
} else { //nolint: revive
}

if incompleteErr != nil {
yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: toProviderErr(err),
Error: fmt.Errorf("%w: %w", incompleteErr, toProviderErr(err)),
})
return
}

yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeError,
Error: toProviderErr(err),
})
}, nil
}

Expand Down
Loading
Loading