diff --git a/tool/internal/instrument/optimize.go b/tool/internal/instrument/optimize.go index 02c3f77a..49717ae5 100644 --- a/tool/internal/instrument/optimize.go +++ b/tool/internal/instrument/optimize.go @@ -178,6 +178,19 @@ func removeBeforeTrampolineCall(targetFile *dst.File, tjump *TJump) error { return ex.Newf("can not remove Before trampoline function") } +func removeAfterTrampolineDecl(targetFile *dst.File, tjump *TJump) error { + afterTrampolineName := makeName(tjump.rule, tjump.target, false) + for i, decl := range targetFile.Decls { + if funcDecl, ok := decl.(*dst.FuncDecl); ok { + if funcDecl.Name.Name == afterTrampolineName { + targetFile.Decls = append(targetFile.Decls[:i], targetFile.Decls[i+1:]...) + return nil + } + } + } + return ex.Newf("can not remove After trampoline function") +} + // canFlattenTJump checks if the tjump can be safely flattened based on // the hook function's usage of HookContext. Returns true if: // 1. SetSkipCall is never called (so skip is always false) @@ -302,7 +315,6 @@ func flattenTJump(tjump *TJump, removedOnExit bool) error { // block, at this point, all lhs are unused, replace assignment to simple // function call ifStmt.Init = ast.ExprStmt(initStmt.Rhs[0]) - // TODO: Remove After declaration as well } else { // Otherwise, mark skipCall identifier as unused skipCallIdent := util.AssertType[*dst.Ident](initStmt.Lhs[1]) @@ -316,6 +328,27 @@ func stripTJumpLabel(tjump *TJump) { ifStmt.Decs.If = nil } +func (ip *InstrumentPhase) flattenTJumpIfPossible(tjump *TJump, removedOnExit bool) error { + if tjump.rule.Before == "" { + return nil + } + + hookFunc, err := getHookFunc(tjump.rule, true) + if err != nil { + return err + } + if !canFlattenTJump(hookFunc) { + return nil + } + if err = flattenTJump(tjump, removedOnExit); err != nil { + return err + } + if removedOnExit { + return removeAfterTrampolineDecl(ip.target, tjump) + } + return nil +} + func (ip *InstrumentPhase) optimizeTJumps() error { for _, tjump := range ip.tjumps { mustTJump(tjump.ifStmt) @@ -353,17 +386,8 @@ func (ip *InstrumentPhase) optimizeTJumps() error { // memory aware and may generate memory SSA values during compilation. // This further simplifies the trampoline-jump-if and gives more chances // for optimization passes to kick in. - if rule.Before != "" { - hookFunc, err := getHookFunc(tjump.rule, true) - if err != nil { - return err - } - if canFlattenTJump(hookFunc) { - err1 := flattenTJump(tjump, removedOnExit) - if err1 != nil { - return err1 - } - } + if err := ip.flattenTJumpIfPossible(tjump, removedOnExit); err != nil { + return err } } return nil diff --git a/tool/internal/instrument/optimize_test.go b/tool/internal/instrument/optimize_test.go index 95437bcb..4cc55a9d 100644 --- a/tool/internal/instrument/optimize_test.go +++ b/tool/internal/instrument/optimize_test.go @@ -386,6 +386,31 @@ func TestRemoveBeforeTrampolineCall(t *testing.T) { assert.Len(t, tjump.ifStmt.Body.List, 1) } +func TestRemoveAfterTrampolineDecl(t *testing.T) { + funcSrc := `package main + func testFunc(param1 string) {}` + + targetFunc := parseFunc(t, funcSrc) + tjump := &TJump{ + target: targetFunc, + rule: &rule.InstFuncRule{ + Func: targetFunc.Name.Name, + Before: "beforeHook", + }, + } + afterFuncName := makeName(tjump.rule, tjump.target, false) + fileSrc := fmt.Sprintf(`package main + func testFunc(param1 string) {} + func %s() {}`, afterFuncName) + targetFile, err := ast.NewAstParser().ParseSource(fileSrc) + require.NoError(t, err) + require.NotNil(t, ast.FindFuncDeclWithoutRecv(targetFile, afterFuncName)) + + err = removeAfterTrampolineDecl(targetFile, tjump) + require.NoError(t, err) + assert.Nil(t, ast.FindFuncDeclWithoutRecv(targetFile, afterFuncName)) +} + func TestFlattenTJump(t *testing.T) { tests := []struct { name string diff --git a/tool/internal/instrument/testdata/golden/before-only/before_only.main.go.golden b/tool/internal/instrument/testdata/golden/before-only/before_only.main.go.golden index 38784563..91c9332f 100644 --- a/tool/internal/instrument/testdata/golden/before-only/before_only.main.go.golden +++ b/tool/internal/instrument/testdata/golden/before-only/before_only.main.go.golden @@ -162,21 +162,5 @@ func OtelBeforeTrampoline_Func12350319093(param0 *string, param1 *int) (hookCont return hookContext, hookContext.skipCall } -func OtelAfterTrampoline_Func12350319093(hookContext HookContext) { - defer func() { - if err := recover(); err != nil { - println("failed to exec After hook", "") - if e, ok := err.(error); ok { - println(e.Error()) - } - fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl - if fetchStack != nil && printStack != nil { - printStack(fetchStack()) - } - } - }() - hookContext.(*HookContextImpl2350319093).returnVals = []interface{}{} -} - //go:linkname H1Before testdata/golden/before-only.H1Before func H1Before(hookContext HookContext, param0 string, param1 int) diff --git a/tool/internal/instrument/testdata/golden/ellipsis-syntax/ellipsis_syntax.main.go.golden b/tool/internal/instrument/testdata/golden/ellipsis-syntax/ellipsis_syntax.main.go.golden index 49eb0b73..28708c81 100644 --- a/tool/internal/instrument/testdata/golden/ellipsis-syntax/ellipsis_syntax.main.go.golden +++ b/tool/internal/instrument/testdata/golden/ellipsis-syntax/ellipsis_syntax.main.go.golden @@ -150,21 +150,5 @@ func OtelBeforeTrampoline_EllipsisFunc2350319093(param0 *[]string) (hookContext return hookContext, hookContext.skipCall } -func OtelAfterTrampoline_EllipsisFunc2350319093(hookContext HookContext) { - defer func() { - if err := recover(); err != nil { - println("failed to exec After hook", "") - if e, ok := err.(error); ok { - println(e.Error()) - } - fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl - if fetchStack != nil && printStack != nil { - printStack(fetchStack()) - } - } - }() - hookContext.(*HookContextImpl2350319093).returnVals = []interface{}{} -} - //go:linkname H9Before testdata/golden/ellipsis-syntax.H9Before func H9Before(hookContext HookContext, param0 ...string) diff --git a/tool/internal/instrument/testdata/golden/method-receiver/method_receiver.main.go.golden b/tool/internal/instrument/testdata/golden/method-receiver/method_receiver.main.go.golden index e04c5bd6..44d06d8d 100644 --- a/tool/internal/instrument/testdata/golden/method-receiver/method_receiver.main.go.golden +++ b/tool/internal/instrument/testdata/golden/method-receiver/method_receiver.main.go.golden @@ -295,21 +295,5 @@ func OtelBeforeTrampoline_Func32216990050(recv0 *T) (hookContext *HookContextImp return hookContext, hookContext.skipCall } -func OtelAfterTrampoline_Func32216990050(hookContext HookContext) { - defer func() { - if err := recover(); err != nil { - println("failed to exec After hook", "") - if e, ok := err.(error); ok { - println(e.Error()) - } - fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl - if fetchStack != nil && printStack != nil { - printStack(fetchStack()) - } - } - }() - hookContext.(*HookContextImpl2216990050).returnVals = []interface{}{} -} - //go:linkname H11Before testdata/golden/method-receiver.H11Before func H11Before(hookContext HookContext, recv0 interface{}) diff --git a/tool/internal/instrument/testdata/golden/opt-multiple-funcs/opt_multiple_funcs.main.go.golden b/tool/internal/instrument/testdata/golden/opt-multiple-funcs/opt_multiple_funcs.main.go.golden index 9e6f8635..730aecd5 100644 --- a/tool/internal/instrument/testdata/golden/opt-multiple-funcs/opt_multiple_funcs.main.go.golden +++ b/tool/internal/instrument/testdata/golden/opt-multiple-funcs/opt_multiple_funcs.main.go.golden @@ -395,21 +395,5 @@ func OtelBeforeTrampoline_OptGood3887151894() (hookContext *HookContextImpl38871 return hookContext, hookContext.skipCall } -func OtelAfterTrampoline_OptGood3887151894(hookContext HookContext) { - defer func() { - if err := recover(); err != nil { - println("failed to exec After hook", "") - if e, ok := err.(error); ok { - println(e.Error()) - } - fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl - if fetchStack != nil && printStack != nil { - printStack(fetchStack()) - } - } - }() - hookContext.(*HookContextImpl3887151894).returnVals = []interface{}{} -} - //go:linkname H5Before testdata/golden/opt-multiple-funcs.H5Before func H5Before(hookContext HookContext) diff --git a/tool/internal/instrument/testdata/golden/underscore-syntax/underscore_syntax.main.go.golden b/tool/internal/instrument/testdata/golden/underscore-syntax/underscore_syntax.main.go.golden index 0be6f505..17cab532 100644 --- a/tool/internal/instrument/testdata/golden/underscore-syntax/underscore_syntax.main.go.golden +++ b/tool/internal/instrument/testdata/golden/underscore-syntax/underscore_syntax.main.go.golden @@ -154,21 +154,5 @@ func OtelBeforeTrampoline_UnderscoreFunc2694310588(param0 *int, param1 *float32) return hookContext, hookContext.skipCall } -func OtelAfterTrampoline_UnderscoreFunc2694310588(hookContext HookContext) { - defer func() { - if err := recover(); err != nil { - println("failed to exec After hook", "") - if e, ok := err.(error); ok { - println(e.Error()) - } - fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl - if fetchStack != nil && printStack != nil { - printStack(fetchStack()) - } - } - }() - hookContext.(*HookContextImpl2694310588).returnVals = []interface{}{} -} - //go:linkname H10Before testdata/golden/underscore-syntax.H10Before func H10Before(hookContext HookContext, param0 int, param1 float32) diff --git a/tool/internal/instrument/testdata/golden/unnamed-param/unnamed_param.main.go.golden b/tool/internal/instrument/testdata/golden/unnamed-param/unnamed_param.main.go.golden index 54e57333..9f2cc60d 100644 --- a/tool/internal/instrument/testdata/golden/unnamed-param/unnamed_param.main.go.golden +++ b/tool/internal/instrument/testdata/golden/unnamed-param/unnamed_param.main.go.golden @@ -154,21 +154,5 @@ func OtelBeforeTrampoline_Unnamed3742873013(param0 *int, param1 *float32) (hookC return hookContext, hookContext.skipCall } -func OtelAfterTrampoline_Unnamed3742873013(hookContext HookContext) { - defer func() { - if err := recover(); err != nil { - println("failed to exec After hook", "") - if e, ok := err.(error); ok { - println(e.Error()) - } - fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl - if fetchStack != nil && printStack != nil { - printStack(fetchStack()) - } - } - }() - hookContext.(*HookContextImpl3742873013).returnVals = []interface{}{} -} - //go:linkname H13Before testdata/golden/unnamed-param.H13Before func H13Before(hookContext HookContext, param0 int, param1 float32)