Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions tool/internal/instrument/optimize.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,19 @@ func removeBeforeTrampolineCall(targetFile *dst.File, tjump *TJump) error {
return ex.Newf("can not remove Before trampoline function")
}

func removeAfterTrampolineDecl(targetFile *dst.File, tjump *TJump) error {
afterTrampolineName := makeName(tjump.rule, tjump.target, false)
for i, decl := range targetFile.Decls {
if funcDecl, ok := decl.(*dst.FuncDecl); ok {
if funcDecl.Name.Name == afterTrampolineName {
targetFile.Decls = append(targetFile.Decls[:i], targetFile.Decls[i+1:]...)
return nil
}
}
}
return ex.Newf("can not remove After trampoline function")
}

// canFlattenTJump checks if the tjump can be safely flattened based on
// the hook function's usage of HookContext. Returns true if:
// 1. SetSkipCall is never called (so skip is always false)
Expand Down Expand Up @@ -302,7 +315,6 @@ func flattenTJump(tjump *TJump, removedOnExit bool) error {
// block, at this point, all lhs are unused, replace assignment to simple
// function call
ifStmt.Init = ast.ExprStmt(initStmt.Rhs[0])
// TODO: Remove After declaration as well
} else {
// Otherwise, mark skipCall identifier as unused
skipCallIdent := util.AssertType[*dst.Ident](initStmt.Lhs[1])
Expand All @@ -316,6 +328,27 @@ func stripTJumpLabel(tjump *TJump) {
ifStmt.Decs.If = nil
}

func (ip *InstrumentPhase) flattenTJumpIfPossible(tjump *TJump, removedOnExit bool) error {
if tjump.rule.Before == "" {
return nil
}

hookFunc, err := getHookFunc(tjump.rule, true)
if err != nil {
return err
}
if !canFlattenTJump(hookFunc) {
return nil
}
if err = flattenTJump(tjump, removedOnExit); err != nil {
return err
}
if removedOnExit {
return removeAfterTrampolineDecl(ip.target, tjump)
}
return nil
}

func (ip *InstrumentPhase) optimizeTJumps() error {
for _, tjump := range ip.tjumps {
mustTJump(tjump.ifStmt)
Expand Down Expand Up @@ -353,17 +386,8 @@ func (ip *InstrumentPhase) optimizeTJumps() error {
// memory aware and may generate memory SSA values during compilation.
// This further simplifies the trampoline-jump-if and gives more chances
// for optimization passes to kick in.
if rule.Before != "" {
hookFunc, err := getHookFunc(tjump.rule, true)
if err != nil {
return err
}
if canFlattenTJump(hookFunc) {
err1 := flattenTJump(tjump, removedOnExit)
if err1 != nil {
return err1
}
}
if err := ip.flattenTJumpIfPossible(tjump, removedOnExit); err != nil {
return err
}
}
return nil
Expand Down
25 changes: 25 additions & 0 deletions tool/internal/instrument/optimize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,31 @@ func TestRemoveBeforeTrampolineCall(t *testing.T) {
assert.Len(t, tjump.ifStmt.Body.List, 1)
}

func TestRemoveAfterTrampolineDecl(t *testing.T) {
funcSrc := `package main
func testFunc(param1 string) {}`

targetFunc := parseFunc(t, funcSrc)
tjump := &TJump{
target: targetFunc,
rule: &rule.InstFuncRule{
Func: targetFunc.Name.Name,
Before: "beforeHook",
},
}
afterFuncName := makeName(tjump.rule, tjump.target, false)
fileSrc := fmt.Sprintf(`package main
func testFunc(param1 string) {}
func %s() {}`, afterFuncName)
targetFile, err := ast.NewAstParser().ParseSource(fileSrc)
require.NoError(t, err)
require.NotNil(t, ast.FindFuncDeclWithoutRecv(targetFile, afterFuncName))

err = removeAfterTrampolineDecl(targetFile, tjump)
require.NoError(t, err)
assert.Nil(t, ast.FindFuncDeclWithoutRecv(targetFile, afterFuncName))
}

func TestFlattenTJump(t *testing.T) {
tests := []struct {
name string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,5 @@ func OtelBeforeTrampoline_Func12350319093(param0 *string, param1 *int) (hookCont
return hookContext, hookContext.skipCall
}

func OtelAfterTrampoline_Func12350319093(hookContext HookContext) {
defer func() {
if err := recover(); err != nil {
println("failed to exec After hook", "")
if e, ok := err.(error); ok {
println(e.Error())
}
fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl
if fetchStack != nil && printStack != nil {
printStack(fetchStack())
}
}
}()
hookContext.(*HookContextImpl2350319093).returnVals = []interface{}{}
}

//go:linkname H1Before testdata/golden/before-only.H1Before
func H1Before(hookContext HookContext, param0 string, param1 int)
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,5 @@ func OtelBeforeTrampoline_EllipsisFunc2350319093(param0 *[]string) (hookContext
return hookContext, hookContext.skipCall
}

func OtelAfterTrampoline_EllipsisFunc2350319093(hookContext HookContext) {
defer func() {
if err := recover(); err != nil {
println("failed to exec After hook", "")
if e, ok := err.(error); ok {
println(e.Error())
}
fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl
if fetchStack != nil && printStack != nil {
printStack(fetchStack())
}
}
}()
hookContext.(*HookContextImpl2350319093).returnVals = []interface{}{}
}

//go:linkname H9Before testdata/golden/ellipsis-syntax.H9Before
func H9Before(hookContext HookContext, param0 ...string)
Original file line number Diff line number Diff line change
Expand Up @@ -295,21 +295,5 @@ func OtelBeforeTrampoline_Func32216990050(recv0 *T) (hookContext *HookContextImp
return hookContext, hookContext.skipCall
}

func OtelAfterTrampoline_Func32216990050(hookContext HookContext) {
defer func() {
if err := recover(); err != nil {
println("failed to exec After hook", "")
if e, ok := err.(error); ok {
println(e.Error())
}
fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl
if fetchStack != nil && printStack != nil {
printStack(fetchStack())
}
}
}()
hookContext.(*HookContextImpl2216990050).returnVals = []interface{}{}
}

//go:linkname H11Before testdata/golden/method-receiver.H11Before
func H11Before(hookContext HookContext, recv0 interface{})
Original file line number Diff line number Diff line change
Expand Up @@ -395,21 +395,5 @@ func OtelBeforeTrampoline_OptGood3887151894() (hookContext *HookContextImpl38871
return hookContext, hookContext.skipCall
}

func OtelAfterTrampoline_OptGood3887151894(hookContext HookContext) {
defer func() {
if err := recover(); err != nil {
println("failed to exec After hook", "")
if e, ok := err.(error); ok {
println(e.Error())
}
fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl
if fetchStack != nil && printStack != nil {
printStack(fetchStack())
}
}
}()
hookContext.(*HookContextImpl3887151894).returnVals = []interface{}{}
}

//go:linkname H5Before testdata/golden/opt-multiple-funcs.H5Before
func H5Before(hookContext HookContext)
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,5 @@ func OtelBeforeTrampoline_UnderscoreFunc2694310588(param0 *int, param1 *float32)
return hookContext, hookContext.skipCall
}

func OtelAfterTrampoline_UnderscoreFunc2694310588(hookContext HookContext) {
defer func() {
if err := recover(); err != nil {
println("failed to exec After hook", "")
if e, ok := err.(error); ok {
println(e.Error())
}
fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl
if fetchStack != nil && printStack != nil {
printStack(fetchStack())
}
}
}()
hookContext.(*HookContextImpl2694310588).returnVals = []interface{}{}
}

//go:linkname H10Before testdata/golden/underscore-syntax.H10Before
func H10Before(hookContext HookContext, param0 int, param1 float32)
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,5 @@ func OtelBeforeTrampoline_Unnamed3742873013(param0 *int, param1 *float32) (hookC
return hookContext, hookContext.skipCall
}

func OtelAfterTrampoline_Unnamed3742873013(hookContext HookContext) {
defer func() {
if err := recover(); err != nil {
println("failed to exec After hook", "")
if e, ok := err.(error); ok {
println(e.Error())
}
fetchStack, printStack := OtelGetStackImpl, OtelPrintStackImpl
if fetchStack != nil && printStack != nil {
printStack(fetchStack())
}
}
}()
hookContext.(*HookContextImpl3742873013).returnVals = []interface{}{}
}

//go:linkname H13Before testdata/golden/unnamed-param.H13Before
func H13Before(hookContext HookContext, param0 int, param1 float32)
Loading