Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion pkg/instrumentation/databasesql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@ package db
import (
"errors"
"fmt"
"net"
nurl "net/url"
"regexp"
"strings"
)

var dbNamePattern = regexp.MustCompile(`(?i)(^|[\s;?&])(?:dbname|database)\s*=\s*('([^']*)'|"([^"]*)"|[^;\s&]+)`)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This regex is too tricky, e.g. from user:pass@tcp(host)/mydb?timeout=30s&database=other matches database=other

what about making parseDSN return also the dbname?


func ParseDbName(dsn string) string {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add some unit tests to cover non-happy path scenarios?

if name := parseDBNameFromKeyValue(dsn); name != "" {
return name
}

var name string
var err error
for i := len(dsn) - 1; i >= 0; i-- {
Expand All @@ -30,6 +38,30 @@ func ParseDbName(dsn string) string {
return name
}

func parseDBNameFromKeyValue(dsn string) string {
match := dbNamePattern.FindStringSubmatch(dsn)
if len(match) == 0 {
return ""
}

value := strings.TrimSpace(match[2])
if value == "" {
return ""
}

if len(value) >= 2 {
if (value[0] == '\'' && value[len(value)-1] == '\'') || (value[0] == '"' && value[len(value)-1] == '"') {
value = value[1 : len(value)-1]
}
}

name, err := nurl.PathUnescape(value)
if err != nil {
return ""
}
return name
}

func parseDSN(driverName, dsn string) (addr string, err error) {
// TODO: need a more delegate DFA
switch driverName {
Expand All @@ -53,6 +85,17 @@ func parseDSN(driverName, dsn string) (addr string, err error) {
}

func parsePostgres(url string) (addr string, err error) {
if !strings.Contains(url, "://") {
host, port, ok := parseConnectionStringPair(url, []string{"host"})
if !ok || host == "" {
return "", errors.New("invalid Postgres DSN")
}
if port == "" {
port = "5432"
}
return net.JoinHostPort(strings.Trim(host, "[]"), port), nil
}
Comment on lines +88 to +97

u, err := nurl.Parse(url)
if err != nil {
return "", err
Expand Down Expand Up @@ -97,11 +140,49 @@ func parseMySQL(dsn string) (addr string, err error) {
}
}
if i >= 0 && j > i {
return dsn[i+1 : j], nil
addr := dsn[i+1 : j]
if strings.TrimSpace(addr) == "" {
return "localhost:3306", nil
}
return addr, nil
}

if strings.Contains(dsn, "@/") || strings.HasPrefix(dsn, "/") {
return "localhost:3306", nil
}
return "", errors.New("invalid MySQL DSN")
}

func parseConnectionStringPair(dsn string, keys []string) (host string, port string, ok bool) {
pairs := strings.Fields(dsn)
var hostFound bool

for _, pair := range pairs {
parts := strings.SplitN(pair, "=", 2)
if len(parts) != 2 {
continue
}
key := strings.ToLower(strings.TrimSpace(parts[0]))
value := strings.Trim(strings.TrimSpace(parts[1]), `"'`)

for _, k := range keys {
if key == k {
host = value
hostFound = true
break
}
}
if key == "port" {
port = value
}
}

if !hostFound {
return "", "", false
}
return host, port, true
}

func parseClickHouse(dsn string) (addr string, err error) {
// ClickHouse DSN formats:
// tcp://host:port?database=dbname&username=user&password=pass
Expand Down
5 changes: 5 additions & 0 deletions test/apps/dbclient/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ var (

func init() {
sql.Register("testdb", &testDriver{})
sql.Register("mysql", &testDriver{})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the value of these new registrations in the test?

sql.Register("postgres", &testDriver{})
sql.Register("postgresql", &testDriver{})
sql.Register("mssql", &testDriver{})
sql.Register("sqlserver", &testDriver{})
}

func main() {
Expand Down
88 changes: 82 additions & 6 deletions test/integration/db_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
func TestDBClientPing(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildAndRun("dbclient", "-op=ping")
f.BuildSharedAndRun("dbclient", "-op=ping")

span := f.RequireSingleSpan()
require.Equal(t, "PING", span.Name())
Expand All @@ -32,7 +32,7 @@ func TestDBClientPing(t *testing.T) {
func TestDBClientExec(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildAndRun("dbclient", "-op=exec")
f.BuildSharedAndRun("dbclient", "-op=exec")

span := f.RequireSingleSpan()
require.Equal(t, "INSERT", span.Name())
Expand All @@ -47,7 +47,7 @@ func TestDBClientExec(t *testing.T) {
func TestDBClientQuery(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildAndRun("dbclient", "-op=query")
f.BuildSharedAndRun("dbclient", "-op=query")

span := f.RequireSingleSpan()
require.Equal(t, "SELECT", span.Name())
Expand All @@ -62,7 +62,7 @@ func TestDBClientQuery(t *testing.T) {
func TestDBClientPrepareAndQuery(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildAndRun("dbclient", "-op=prepare")
f.BuildSharedAndRun("dbclient", "-op=prepare")

// PrepareContext doesn't create a span directly, but stmt.QueryContext does
spans := testutil.AllSpans(f.Traces())
Expand All @@ -76,7 +76,7 @@ func TestDBClientPrepareAndQuery(t *testing.T) {
func TestDBClientTransaction(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildAndRun("dbclient", "-op=tx")
f.BuildSharedAndRun("dbclient", "-op=tx")

spans := testutil.AllSpans(f.Traces())
// BeginTx -> ExecContext -> Commit = 3 spans
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestDBClientTransaction(t *testing.T) {
func TestDBClientAll(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildAndRun("dbclient",
f.BuildSharedAndRun("dbclient",
"-driver=testdb",
"-dsn=user:pass@tcp(127.0.0.1:3306)/testdb?charset=utf8",
"-op=all",
Expand Down Expand Up @@ -181,3 +181,79 @@ func TestDBClientAll(t *testing.T) {
"testdb",
)
}

func TestDBClientPostgresLibpqDSN(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildSharedAndRun("dbclient",
"-driver=postgres",
"-dsn=host=localhost port=5432 dbname=mydb user=postgres",
"-op=ping",
)

span := f.RequireSingleSpan()
require.Equal(t, "PING", span.Name())
testutil.RequireDBClientSemconv(t, span,
"PING",
"ping",
"localhost", 5432,
"mydb",
)
}

func TestDBClientPostgresLibpqIPv6DSN(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildSharedAndRun("dbclient",
"-driver=postgres",
"-dsn=host=::1 port=5432 dbname=mydb user=postgres",
"-op=ping",
)

span := f.RequireSingleSpan()
require.Equal(t, "PING", span.Name())
testutil.RequireDBClientSemconv(t, span,
"PING",
"ping",
"::1", 5432,
"mydb",
)
}

func TestDBClientMySQLDefaultHostDSN(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildSharedAndRun("dbclient",
"-driver=mysql",
"-dsn=user:pass@/mydb",
"-op=ping",
)

span := f.RequireSingleSpan()
require.Equal(t, "PING", span.Name())
testutil.RequireDBClientSemconv(t, span,
"PING",
"ping",
"localhost", 3306,
"mydb",
)
}

func TestDBClientSQLServerDatabaseKeyDSN(t *testing.T) {
f := testutil.NewTestFixture(t)

f.BuildSharedAndRun("dbclient",
"-driver=sqlserver",
"-dsn=Server=localhost,1433;Database=myDB;User Id=sa;Password=Pass123",
"-op=ping",
)

span := f.RequireSingleSpan()
require.Equal(t, "PING", span.Name())
testutil.RequireDBClientSemconv(t, span,
"PING",
"ping",
"localhost", 1433,
"myDB",
)
}
11 changes: 11 additions & 0 deletions test/testutil/fixture.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ func (f *TestFixture) Build(appName string) {
Build(f.t, f.resolveAppPath(appName), "go", "build", "-a")
}

// BuildShared builds a test application once and reuses the binary across tests.
func (f *TestFixture) BuildShared(appName string) {
BuildShared(f.t, f.resolveAppPath(appName), "go", "build", "-a")
}

// Server represents a running server process.
type Server struct {
t *testing.T
Expand Down Expand Up @@ -126,6 +131,12 @@ func (f *TestFixture) BuildAndRun(appName string, args ...string) string {
return f.Run(appName, args...)
}

// BuildSharedAndRun builds a test application once and runs it.
func (f *TestFixture) BuildSharedAndRun(appName string, args ...string) string {
f.BuildShared(appName)
return f.Run(appName, args...)
}

// RequireTraceCount asserts the expected number of traces were collected.
func (f *TestFixture) RequireTraceCount(expected int) {
stats := AnalyzeTraces(f.t, f.collector.GetTraces())
Expand Down
63 changes: 50 additions & 13 deletions test/testutil/infra.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ package testutil

import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"sync"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -34,31 +36,66 @@ func newCmd(ctx context.Context, dir string, args ...string) *exec.Cmd {
return cmd
}

// Build builds the application with the instrumentation tool.
func Build(t *testing.T, appDir string, args ...string) {
type sharedBuild struct {
once sync.Once
err error
}

var sharedBuilds sync.Map

func appBinaryName() string {
name := appBinName
if util.IsWindows() {
name += ".exe"
}
return name
}

func appBinaryPath(appDir string) string {
return filepath.Join(appDir, appBinaryName())
}

func buildApp(ctx context.Context, appDir string, args ...string) error {
binName := otelcBinName
if util.IsWindows() {
binName += ".exe"
}
pwd, err := os.Getwd()
require.NoError(t, err)
otelcPath := filepath.Join(pwd, "..", "..", binName)

// Use a consistent binary name for all test apps
outputName := appBinName
if util.IsWindows() {
outputName += ".exe"
if err != nil {
return err
}
otelcPath := filepath.Join(pwd, "..", "..", binName)

args = append(args, "-o", outputName)
args = append(args, "-o", appBinaryName())
args = append([]string{otelcPath}, args...)

cmd := newCmd(t.Context(), appDir, args...)
cmd := newCmd(ctx, appDir, args...)
out, err := cmd.CombinedOutput()
require.NoError(t, err, string(out))
if err != nil {
return fmt.Errorf("build failed: %w: %s", err, string(out))
}
return nil
}

// Build builds the application with the instrumentation tool.
func Build(t *testing.T, appDir string, args ...string) {
if err := buildApp(t.Context(), appDir, args...); err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
os.Remove(filepath.Join(appDir, outputName))
_ = os.Remove(appBinaryPath(appDir))
})
}

// BuildShared builds the app once per process and keeps the binary for reuse.
func BuildShared(t *testing.T, appDir string, args ...string) {
t.Helper()
entry, _ := sharedBuilds.LoadOrStore(appDir, &sharedBuild{})
build := entry.(*sharedBuild)
build.once.Do(func() {
build.err = buildApp(t.Context(), appDir, args...)
})
require.NoError(t, build.err)
}

// Run runs the application and returns the output.
Expand Down
Loading