diff --git a/tool/internal/pkgload/pkgload.go b/tool/internal/pkgload/pkgload.go index 526fe5694..6022e0adf 100644 --- a/tool/internal/pkgload/pkgload.go +++ b/tool/internal/pkgload/pkgload.go @@ -105,3 +105,26 @@ func ResolveExportFiles(ctx context.Context, importPath string, buildFlags ...st return result, nil } + +// ResolveModuleDir returns the module directory for a given package directory. +func ResolveModuleDir(ctx context.Context, pkgDir string) (string, error) { + pkgs, err := LoadPackages(ctx, packages.NeedModule, nil, pkgDir) + if err != nil { + return "", err + } + if len(pkgs) == 0 { + return "", ex.Newf("no packages found for directory: %s", pkgDir) + } + + pkg := pkgs[0] + if pkg.Module == nil || pkg.Module.Dir == "" || len(pkg.Errors) > 0 { + return "", ex.Newf( + "failed to load module information for package in directory %s: module=%v, errors=%v", + pkgDir, + pkg.Module, + pkg.Errors, + ) + } + + return pkg.Module.Dir, nil +} diff --git a/tool/internal/pkgload/pkgload_test.go b/tool/internal/pkgload/pkgload_test.go index 6a1bb7a36..2e6f70669 100644 --- a/tool/internal/pkgload/pkgload_test.go +++ b/tool/internal/pkgload/pkgload_test.go @@ -4,6 +4,8 @@ package pkgload import ( + "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -103,3 +105,131 @@ func TestResolveExportFiles_NoExportFile(t *testing.T) { assert.Contains(t, err.Error(), "not found or has no export file") assert.Nil(t, archives) } + +func TestResolveModuleDir(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, root string) string + expectedDir string + expectError bool + }{ + { + name: "finds go.mod in current directory", + setup: func(t *testing.T, root string) string { + err := os.WriteFile( + filepath.Join(root, "go.mod"), + []byte("module example.com/test\n"), + 0o644, + ) + require.NoError(t, err) + + err = os.WriteFile( + filepath.Join(root, "main.go"), + []byte("package main\n\nfunc main() {}\n"), + 0o644, + ) + require.NoError(t, err) + + return root + }, + expectedDir: ".", + }, + { + name: "finds go.mod in parent directory", + setup: func(t *testing.T, root string) string { + err := os.WriteFile( + filepath.Join(root, "go.mod"), + []byte("module example.com/test\n"), + 0o644, + ) + require.NoError(t, err) + + nested := filepath.Join(root, "a", "b", "c") + err = os.MkdirAll(nested, 0o755) + require.NoError(t, err) + + err = os.WriteFile( + filepath.Join(nested, "main.go"), + []byte("package main\n\nfunc main() {}\n"), + 0o644, + ) + require.NoError(t, err) + + return nested + }, + expectedDir: ".", + }, + { + name: "returns error when no go.mod exists", + setup: func(t *testing.T, root string) string { + return root + }, + expectError: true, + }, + { + name: "fails for directory without go files", + setup: func(t *testing.T, root string) string { + err := os.WriteFile( + filepath.Join(root, "go.mod"), + []byte("module example.com/test\n"), + 0o644, + ) + require.NoError(t, err) + + emptyDir := filepath.Join(root, "empty") + err = os.MkdirAll(emptyDir, 0o755) + require.NoError(t, err) + + return emptyDir + }, + expectError: true, + }, + { + name: "fails for build-tag-excluded package", + setup: func(t *testing.T, root string) string { + err := os.WriteFile( + filepath.Join(root, "go.mod"), + []byte("module example.com/test\n"), + 0o644, + ) + require.NoError(t, err) + + err = os.WriteFile( + filepath.Join(root, "main.go"), + []byte("//go:build never\n\npackage main\n"), + 0o644, + ) + require.NoError(t, err) + + return root + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + workDir := tt.setup(t, tmpDir) + + t.Chdir(workDir) + + ctx := t.Context() + moduleDir, err := ResolveModuleDir(ctx, workDir) + + if tt.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + + expectedDir := tmpDir + if tt.expectedDir != "." { + expectedDir = tt.expectedDir + } + + require.Equal(t, expectedDir, moduleDir) + }) + } +} diff --git a/tool/internal/setup/setup.go b/tool/internal/setup/setup.go index e3e39a037..f427243f3 100644 --- a/tool/internal/setup/setup.go +++ b/tool/internal/setup/setup.go @@ -75,6 +75,8 @@ var flagsWithPathValues = map[string]bool{ "-toolexec": true, } +const commandLineArgumentsPackage = "command-line-arguments" + // GetBuildPackages loads all packages from the otelc go build/install or otelc setup command arguments. // Returns a list of loaded packages. If no package patterns are found in args, // defaults to loading the current directory package. @@ -87,10 +89,61 @@ var flagsWithPathValues = map[string]bool{ // - args [] returns packages for "." func getBuildPackages(ctx context.Context, args []string) ([]*packages.Package, error) { logger := util.LoggerFromContext(ctx) - mode := packages.NeedName | packages.NeedFiles | packages.NeedModule - buildPkgs := make([]*packages.Package, 0) - found := false + + pkgTargets, fileTargets, err := splitBuildTargets(args) + if err != nil { + return nil, ex.Wrapf(err, "splitting build targets") + } + + var ( + pkgs []*packages.Package + loadErr error + ) + switch { + case len(fileTargets) > 0: + pkgs, loadErr = pkgload.LoadPackages(ctx, mode, nil, fileTargets...) + if loadErr != nil { + return nil, ex.Wrapf(loadErr, "failed to load packages for files %v", fileTargets) + } + + if len(pkgs) > 1 { + return nil, ex.New("multiple packages found for file targets") + } + case len(pkgTargets) > 0: + pkgs, loadErr = pkgload.LoadPackages(ctx, mode, nil, pkgTargets...) + if loadErr != nil { + return nil, ex.Wrapf(loadErr, "failed to load packages for patterns %v", pkgTargets) + } + default: + pkgs, loadErr = pkgload.LoadPackages(ctx, mode, nil, ".") + if loadErr != nil { + return nil, ex.Wrapf(loadErr, "failed to load packages for pattern .") + } + } + + buildPkgs := make([]*packages.Package, 0, len(pkgs)) + for _, pkg := range pkgs { + // file-based builds use synthetic "command-line-arguments" packages + if len(pkg.Errors) > 0 || (pkg.Module == nil && pkg.PkgPath != commandLineArgumentsPackage) { + logger.DebugContext(ctx, "skipping package", "name", pkg.Name, "errors", pkg.Errors, "args", args) + continue + } + + buildPkgs = append(buildPkgs, pkg) + } + + if len(buildPkgs) == 0 { + return nil, ex.New("no valid packages found in build targets") + } + + return buildPkgs, nil +} + +//nolint:revive // if we add named returns then nonamedreturns will complain +func splitBuildTargets(args []string) ([]string, []string, error) { + var pkgs, files []string + for i := len(args) - 1; i >= 0; i-- { arg := args[i] @@ -107,28 +160,38 @@ func getBuildPackages(ctx context.Context, args []string) ([]*packages.Package, break } - pkgs, err := pkgload.LoadPackages(ctx, mode, nil, arg) - if err != nil { - return nil, ex.Wrapf(err, "failed to load packages for pattern %s", arg) - } - for _, pkg := range pkgs { - if pkg.Errors != nil || pkg.Module == nil { - logger.DebugContext(ctx, "skipping package", "pattern", arg, "errors", pkg.Errors, "module", pkg.Module) - continue - } - buildPkgs = append(buildPkgs, pkg) - found = true + if filepath.Ext(arg) == ".go" { + files = append(files, arg) + } else { + pkgs = append(pkgs, arg) } } - if !found { - var err error - buildPkgs, err = pkgload.LoadPackages(ctx, mode, nil, ".") + if len(files) > 0 && len(pkgs) > 0 { + return nil, nil, ex.New("cannot mix .go files and packages") + } + + if len(files) > 0 { + // files are collected in reverse order due to reverse argument traversal. + // files[0] is therefore the last .go file from the original CLI args. + dir, err := filepath.Abs(filepath.Dir(files[0])) if err != nil { - return nil, ex.Wrapf(err, "failed to load packages for pattern .") + return nil, nil, ex.Wrapf(err, "failed to get absolute path for directory containing files") + } + + for _, f := range files[1:] { + fdir, err2 := filepath.Abs(filepath.Dir(f)) + if err2 != nil { + return nil, nil, ex.Wrapf(err2, "failed to get absolute path for directory containing file %s", f) + } + + if fdir != dir { + return nil, nil, ex.New("named files must all be in one directory") + } } } - return buildPkgs, nil + + return pkgs, files, nil } func getPackageDir(pkg *packages.Package) string { @@ -195,15 +258,27 @@ func Setup(ctx context.Context, cmd *cli.Command) error { // Generate otelc.runtime.go for all packages moduleDirs := make(map[string]bool) for _, pkg := range pkgs { - if pkg.Module == nil { + // file-based builds use synthetic "command-line-arguments" packages + if pkg.Module == nil && pkg.PkgPath != commandLineArgumentsPackage { sp.Warn("skipping package without module", "package", pkg.PkgPath) continue } - moduleDir := pkg.Module.Dir + pkgDir := getPackageDir(pkg) if pkgDir == "" { - pkgDir = moduleDir + sp.Warn("skipping package without Go files", "package", pkg.PkgPath) + continue } + + var moduleDir string + if pkg.Module != nil { + moduleDir = pkg.Module.Dir + } else { + if moduleDir, err = pkgload.ResolveModuleDir(ctx, pkgDir); err != nil { + return ex.Wrapf(err, "finding module dir for package %s", pkg.PkgPath) + } + } + // Introduce additional hook code by generating otelc.runtime.go if err = sp.addDeps(matched, pkgDir); err != nil { return ex.Wrapf(err, "adding deps for package at %s", pkgDir) @@ -356,7 +431,16 @@ func BuildWithToolexec(ctx context.Context, cmd *cli.Command) error { // Add "-toolexec=..." newArgs = append(newArgs, insert) // Add the rest - newArgs = append(newArgs, args[1:]...) + restArgs := args[1:] + if _, fileTargets, err2 := splitBuildTargets(restArgs); err2 == nil && len(fileTargets) > 0 { + // add otelc.runtime.go manually to command line for file targets + dir := filepath.Dir(fileTargets[0]) + otelcRuntimePath := filepath.Join(dir, OtelcRuntimeFile) + if util.PathExists(otelcRuntimePath) { + restArgs = append(restArgs, otelcRuntimePath) + } + } + newArgs = append(newArgs, restArgs...) logger.InfoContext(ctx, "Running go build with toolexec", "args", newArgs) // Tell the sub-process the working directory diff --git a/tool/internal/setup/setup_test.go b/tool/internal/setup/setup_test.go index ecc999169..d9ff72ebd 100644 --- a/tool/internal/setup/setup_test.go +++ b/tool/internal/setup/setup_test.go @@ -11,6 +11,8 @@ import ( "testing" "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/tool/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/tools/go/packages" ) @@ -22,59 +24,137 @@ func TestGetPackages(t *testing.T) { args []string expectedCount int expectedPackages []string + expectError bool }{ { name: "single package", args: []string{"-a", "-o", "tmp", "./cmd"}, expectedCount: 1, expectedPackages: []string{"testmodule/cmd"}, + expectError: false, }, { name: "multiple packages", args: []string{"./cmd", "./foo/demo"}, expectedCount: 2, expectedPackages: []string{"testmodule/cmd", "testmodule/foo/demo"}, + expectError: false, }, { name: "wildcard pattern", args: []string{"./cmd/..."}, expectedCount: 1, expectedPackages: []string{"testmodule/cmd"}, + expectError: false, + }, + { + name: "file as a target", + args: []string{"./cmd/main.go"}, + expectedCount: 1, + expectedPackages: []string{commandLineArgumentsPackage}, + expectError: false, + }, + { + name: "file and pkg mixed targets", + args: []string{"./cmd/main.go", "./foo/demo"}, + expectedCount: 0, + expectedPackages: []string{}, + expectError: true, }, { name: "default to current directory", args: []string{}, expectedCount: 1, - expectedPackages: []string{"."}, + expectedPackages: []string{"testmodule"}, + expectError: false, }, { name: "current directory explicit", args: []string{"."}, expectedCount: 1, - expectedPackages: []string{"."}, + expectedPackages: []string{"testmodule"}, + expectError: false, }, { name: "nonexistent package mixed with valid", args: []string{"./cmd", "./nonexistent"}, expectedCount: 1, expectedPackages: []string{"testmodule/cmd"}, + expectError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pkgs, err := getBuildPackages(t.Context(), tt.args) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Len(t, pkgs, tt.expectedCount) - if len(pkgs) != tt.expectedCount { - t.Errorf("Expected %d packages, got %d", tt.expectedCount, len(pkgs)) + if tt.expectedPackages != nil { + pkgIDs := extractPackageIDs(pkgs) + checkPackages(t, pkgIDs, tt.expectedPackages) + } } + }) + } +} - if tt.expectedPackages != nil { - pkgIDs := extractPackageIDs(pkgs) - checkPackages(t, pkgIDs, tt.expectedPackages) +func TestSplitBuildTargets(t *testing.T) { + tests := []struct { + name string + targets []string + pkgTargets []string + fileTargets []string + expectError bool + }{ + { + name: "all package targets", + targets: []string{"./cmd", "./foo/demo"}, + pkgTargets: []string{"./cmd", "./foo/demo"}, + fileTargets: nil, + expectError: false, + }, + { + name: "all file targets", + targets: []string{"./cmd/main.go", "./cmd/util.go"}, + pkgTargets: nil, + fileTargets: []string{"./cmd/main.go", "./cmd/util.go"}, + expectError: false, + }, + { + name: "all file targets from different packages", + targets: []string{"./cmd/main.go", "./util/util.go"}, + pkgTargets: nil, + fileTargets: nil, + expectError: true, + }, + { + name: "mixed package and file targets with valid package", + targets: []string{"./cmd/main.go", "./foo/demo"}, + pkgTargets: nil, + fileTargets: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pkgTargets, fileTargets, err := splitBuildTargets(tt.targets) + if tt.expectError { + require.Error(t, err) + assert.Nil(t, pkgTargets) + assert.Nil(t, fileTargets) + } else { + assert.NoError(t, err) + for _, exp := range tt.pkgTargets { + assert.Contains(t, pkgTargets, exp, "Expected package target %q not found in %v", exp, pkgTargets) + } + for _, exp := range tt.fileTargets { + assert.Contains(t, fileTargets, exp, "Expected file target %q not found in %v", exp, fileTargets) + } } }) } @@ -126,6 +206,11 @@ func setupTestModule(t *testing.T, subDirs []string) { t.Fatalf("Failed to create go.mod: %v", err) } + mainGoPath := filepath.Join(tmpDir, "main.go") + if err := os.WriteFile(mainGoPath, []byte("package main\n\nfunc main() {}\n"), 0o644); err != nil { + t.Fatalf("Failed to create main.go: %v", err) + } + t.Chdir(tmpDir) }