Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ This rule wraps function calls at call sites with instrumentation code. Unlike t
| `append_args` | `[]string` | No (one of `replace`/`append_args` required) | Go expression strings appended as additional arguments to the matched call |
| `variadic_type` | string | No | Element type for the ellipsis IIFE wrapper (e.g. `grpc.DialOption`). Required when any matched call uses `...` spread. |
| `imports` | map[string]string | No | Additional imports needed for injected code (alias: path). Packages must be in the target module's `go.mod`. |
| `path` | string | No | Import path or local path containing helper functions referenced by unqualified calls in `replace` or `append_args`. |

**`replace` and `append_args` are independent and can both be set.** When both are present, `append_args` is applied first (arguments are appended to the call), then `replace` wraps the modified call.

Expand Down Expand Up @@ -232,6 +233,7 @@ wrap_http_get:
target: myapp/server
function_call: net/http.Get
replace: "tracedGet({{ . }})"
path: "github.com/my-org/my-repo/instrumentation/http"
```

In the `myapp/server` package, this transforms:
Expand All @@ -246,7 +248,7 @@ func fetchData(url string) {
}
```

**Note:** The `tracedGet` function must be available in the target package, either defined locally or imported.
**Note:** The `tracedGet` function can be defined locally in the target package, or supplied by `path` as a helper function to compile into the target package.

**What gets wrapped:** Only `http.Get()` calls where `http` is imported from `"net/http"`

Expand Down Expand Up @@ -372,7 +374,7 @@ grpc.Dial(addr, func(v ...grpc.DialOption) []grpc.DialOption {

- The `{{ . }}` placeholder in `replace` represents the original function call.
- `replace` must be a valid Go expression that includes the placeholder; the result may be any expression type.
- Replacement code can only reference packages and functions that are already imported or defined in the target file.
- Replacement code can reference packages imported through `imports`, functions defined in the target file, or unqualified helper functions found under `path`.
- Call rules only affect call sites in the target package, not the function definition itself.
- Multiple calls to the same function will all be wrapped independently.
- Use the qualified format `package/path.FunctionName` for functions.
Expand Down
198 changes: 198 additions & 0 deletions tool/internal/instrument/apply_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@ package instrument

import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"

"github.com/dave/dst"
"github.com/dave/dst/dstutil"

"github.com/open-telemetry/opentelemetry-go-compile-instrumentation/tool/ex"
"github.com/open-telemetry/opentelemetry-go-compile-instrumentation/tool/internal/ast"
"github.com/open-telemetry/opentelemetry-go-compile-instrumentation/tool/internal/rule"
"github.com/open-telemetry/opentelemetry-go-compile-instrumentation/tool/util"
)
Expand All @@ -32,6 +39,10 @@ func (ip *InstrumentPhase) applyCallRule(ctx context.Context, r *rule.InstCallRu

util.Assert(appendModified || replaceModified, "call rule did not match any call")

if err := ip.applyCallRuleHelpers(ctx, r, root); err != nil {
return err
}

if err := ip.addRuleImports(ctx, root, r.Imports, r.Name); err != nil {
return err
}
Expand Down Expand Up @@ -102,6 +113,193 @@ func (*InstrumentPhase) applyCallReplace(
return true, nil
}

func (ip *InstrumentPhase) applyCallRuleHelpers(
ctx context.Context,
r *rule.InstCallRule,
root *dst.File,
) error {
if strings.TrimSpace(r.Path) == "" {
return nil
}

helperNames, err := callRuleHelperNames(r)
if err != nil {
return err
}
removeLocalFuncNames(helperNames, root)
if len(helperNames) == 0 {
return nil
}

files, err := callRuleHelperFiles(r.Path, helperNames)
if err != nil {
return ex.Wrapf(err, "finding helper files for call rule %s at path %s", r.Name, r.Path)
}
for _, file := range files {
if err = ip.addCallRuleHelperFile(ctx, r, file, root.Name.Name); err != nil {
return err
}
}
return nil
}

func callRuleHelperNames(r *rule.InstCallRule) (map[string]bool, error) {
names := make(map[string]bool)
if strings.TrimSpace(r.Replace) != "" {
tmpl, err := newCallTemplate(r.Replace)
if err != nil {
return nil, err
}
expr, err := tmpl.compileExpression(&dst.CallExpr{
Fun: &dst.SelectorExpr{
X: &dst.Ident{Name: "pkg", Path: r.ImportPath},
Sel: &dst.Ident{Name: r.FuncName},
},
})
if err != nil {
return nil, err
}
collectUnqualifiedCallNames(expr, names)
}
for _, arg := range r.AppendArgs {
expr, err := parseGoExpression(arg)
if err != nil {
return nil, err
}
collectUnqualifiedCallNames(expr, names)
}
return names, nil
}

func collectUnqualifiedCallNames(node dst.Node, names map[string]bool) {
dst.Inspect(node, func(n dst.Node) bool {
call, ok := n.(*dst.CallExpr)
if !ok {
return true
}
ident, ok := call.Fun.(*dst.Ident)
if !ok || isBuiltinCall(ident.Name) {
return true
}
names[ident.Name] = true
return true
})
}

func isBuiltinCall(name string) bool {
switch name {
case "append", "cap", "clear", "close", "complex", "copy", "delete", "imag",
"len", "make", "max", "min", "new", "panic", "print", "println", "real", "recover":
return true
default:
return false
}
}

func removeLocalFuncNames(names map[string]bool, root *dst.File) {
for _, decl := range root.Decls {
funcDecl, ok := decl.(*dst.FuncDecl)
if !ok || funcDecl.Recv != nil {
continue
}
delete(names, funcDecl.Name.Name)
}
}

func callRuleHelperFiles(path string, names map[string]bool) ([]string, error) {
files, err := listRuleFiles(path)
if err != nil {
return nil, err
}

remaining := make(map[string]bool, len(names))
for name := range names {
remaining[name] = true
}

var matched []string
for _, file := range files {
if !util.IsGoFile(file) {
continue
}
root, err := ast.ParseFileFast(file)
if err != nil {
return nil, err
}
var found bool
for _, decl := range root.Decls {
funcDecl, ok := decl.(*dst.FuncDecl)
if !ok || funcDecl.Recv != nil || !remaining[funcDecl.Name.Name] {
continue
}
delete(remaining, funcDecl.Name.Name)
found = true
}
if found {
matched = append(matched, file)
}
}

if len(remaining) != 0 {
missing := mapsKeys(remaining)
sort.Strings(missing)
return nil, ex.Newf("helper function(s) %v not found", missing)
}
return matched, nil
}

func mapsKeys(m map[string]bool) []string {
keys := make([]string, 0, len(m))
for key := range m {
keys = append(keys, key)
}
return keys
}

var nonIdentifierChars = regexp.MustCompile(`[^A-Za-z0-9_]+`)

func callRuleHelperOutputName(ruleName, file string) string {
base := filepath.Base(file)
ext := filepath.Ext(base)
name := strings.TrimSuffix(base, ext)
rulePart := nonIdentifierChars.ReplaceAllString(ruleName, "_")
filePart := nonIdentifierChars.ReplaceAllString(name, "_")
return fmt.Sprintf("otelc.%s.%s.go", rulePart, filePart)
}

func (ip *InstrumentPhase) addCallRuleHelperFile(
ctx context.Context,
r *rule.InstCallRule,
file string,
pkgName string,
) error {
data, err := os.ReadFile(file)
if err != nil {
return ex.Wrapf(err, "reading call rule helper file %s", file)
}
root, err := ast.NewAstParser().ParseSource(stripBuildIgnoreTag(string(data)))
if err != nil {
return ex.Wrapf(err, "parsing call rule helper file %s", file)
}
root.Name.Name = pkgName
if err = ip.updateImportConfigForFile(ctx, root, r.Name); err != nil {
return err
}

newFile := filepath.Join(ip.workDir, callRuleHelperOutputName(r.Name, file))
if !util.PathExists(newFile) {
if err = ast.WriteFile(newFile, root); err != nil {
return ex.Wrapf(err, "writing call rule helper file %s", newFile)
}
ip.keepForDebug(newFile)
}
if !ip.hasCompileArg(newFile) {
ip.addCompileArg(newFile)
}
ip.Info("Apply call rule helper file", "rule", r, "helper", file, "new", newFile)
return nil
}

func (ip *InstrumentPhase) applyCallAppendArgs(
r *rule.InstCallRule,
root *dst.File,
Expand Down
57 changes: 57 additions & 0 deletions tool/internal/instrument/apply_call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ package instrument
import (
"context"
"go/token"
"os"
"path/filepath"
"strings"
"testing"

"github.com/dave/dst"
Expand Down Expand Up @@ -98,6 +101,60 @@ func TestApplyCallRule_InvalidTemplate(t *testing.T) {
assert.Contains(t, err.Error(), "rule has no compiled replacement template")
}

func TestApplyCallRule_AddsHelperFileFromPath(t *testing.T) {
helperDir := t.TempDir()
helperFile := filepath.Join(helperDir, "helper.go")
require.NoError(t, os.WriteFile(helperFile, []byte(`package helper

import "fmt"

func Wrapper(resp any) any {
fmt.Println("wrapped")
return resp
}
`), 0o600))

file := makeCallFile(httpGetCall())
r := httpGetRule("Wrapper({{ . }})")
r.Path = helperDir

phase := newTestPhase()
phase.workDir = t.TempDir()

err := phase.applyCallRule(context.Background(), r, file)

require.NoError(t, err)
generated := filepath.Join(phase.workDir, "otelc.wrap_get.helper.go")
assert.Contains(t, phase.compileArgs, generated)
content, err := os.ReadFile(generated)
require.NoError(t, err)
assert.Contains(t, string(content), "package main")
assert.Contains(t, string(content), "func Wrapper(resp any) any")
assert.Contains(t, string(content), `"fmt"`)
}

func TestApplyCallRule_PathWithMissingHelper(t *testing.T) {
helperDir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(helperDir, "helper.go"), []byte(`package helper

func Other(resp any) any {
return resp
}
`), 0o600))

file := makeCallFile(httpGetCall())
r := httpGetRule("Wrapper({{ . }})")
r.Path = helperDir

phase := newTestPhase()
phase.workDir = t.TempDir()

err := phase.applyCallRule(context.Background(), r, file)

require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "Wrapper"), err.Error())
}

// --- matchesCallRule tests ---

func TestMatchesCallRule_QualifiedCallMatches(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions tool/internal/instrument/apply_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,23 @@ func (ip *InstrumentPhase) addCompileArg(newArg string) {
ip.compileArgs = append(ip.compileArgs, newArg)
}

func (ip *InstrumentPhase) hasCompileArg(arg string) bool {
absArg, err := filepath.Abs(arg)
if err != nil {
return false
}
for _, compileArg := range ip.compileArgs {
absCompileArg, err := filepath.Abs(compileArg)
if err != nil {
continue
}
if absCompileArg == absArg {
return true
}
}
return false
}

//go:embed api.tmpl
var templateAPI string

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@ package main

import "unsafe"

func Wrapper(size uintptr) uintptr {
println("Wrapped!")
return size
}

func CallSizeof() {
x := 42
size := Wrapper(unsafe.Sizeof(x))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package main

func Wrapper(size uintptr) uintptr {
println("Wrapped!")
return size
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package helper

func Wrapper(size uintptr) uintptr {
println("Wrapped!")
return size
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ wrap_sizeof_call:
target: main
function_call: unsafe.Sizeof
replace: "Wrapper({{ . }})"
path: testdata/golden/call-rule-only
Loading
Loading