diff --git a/pkg/instrumentation/databasesql/client.go b/pkg/instrumentation/databasesql/client.go index 955437f70..db28aacdc 100644 --- a/pkg/instrumentation/databasesql/client.go +++ b/pkg/instrumentation/databasesql/client.go @@ -40,11 +40,15 @@ func (n dbClientEnabler) Enable() bool { var clientEnabler = dbClientEnabler{} func beforeOpenInstrumentation(ictx inst.HookContext, driverName, dataSourceName string) { - addr, err := parseDSN(driverName, dataSourceName) - if err != nil { + info := ParseDSN(driverName, dataSourceName) + addr := info.Addr() + if addr == "" { addr = "unknown" } - dbName := ParseDbName(dataSourceName) + dbName := info.DBName + if dbName == "" { + dbName = ParseDbName(dataSourceName) + } ictx.SetData(map[string]string{ "endpoint": addr, "driver": driverName, diff --git a/pkg/instrumentation/databasesql/dsnparse/parse.go b/pkg/instrumentation/databasesql/dsnparse/parse.go new file mode 100644 index 000000000..6cf2083b8 --- /dev/null +++ b/pkg/instrumentation/databasesql/dsnparse/parse.go @@ -0,0 +1,454 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package dsnparse + +import ( + nurl "net/url" + "strings" +) + +// DSNInfo holds the parsed server address components and database name from a +// data source name. +type DSNInfo struct { + Host string + Port string + DBName string +} + +// Addr returns the host:port pair. When Port is empty only the host is returned. +func (d DSNInfo) Addr() string { + if d.Host == "" { + return "" + } + if d.Port == "" { + return d.Host + } + return d.Host + ":" + d.Port +} + +// ParseDSN parses a driver-specific data source name and returns structured +// connection information. It tries multiple well-known formats in order and +// never panics. Unrecognised drivers return a zero-value DSNInfo. 
+func ParseDSN(driverName, dsn string) DSNInfo { + switch driverName { + case "postgres", "pgx", "postgresql": + return parsePostgresDSN(dsn) + case "mysql": + return parseMySQLDSN(dsn) + case "sqlite3", "sqlite": + return parseSQLiteDSN(dsn) + case "sqlserver", "mssql": + return parseSQLServerDSN(dsn) + case "clickhouse": + return parseClickHouseDSN(dsn) + case "godror", "oracle", "oci8", "go-oci8": + return parseOracleDSN(dsn) + } + return DSNInfo{} +} + +// ParseDbName extracts the database name from a generic DSN by finding the +// last '/' and trimming any query-string suffix. Retained for backward +// compatibility; prefer ParseDSN when the driver name is known. +func ParseDbName(dsn string) string { + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + dbname := dsn[i+1:] + if idx := strings.IndexAny(dbname, "?&"); idx >= 0 { + dbname = dbname[:idx] + } + if unesc, err := nurl.PathUnescape(dbname); err == nil { + return unesc + } + return dbname + } + } + return "" +} + +// LegacyParseDSN wraps ParseDSN and preserves the original (addr, error) shape +// used by the db package's beforeOpenInstrumentation. +func LegacyParseDSN(driverName, dsn string) (string, error) { + info := ParseDSN(driverName, dsn) + if addr := info.Addr(); addr != "" { + return addr, nil + } + return driverName, nil +} + +// ---- PostgreSQL --------------------------------------------------------------- + +// parsePostgresDSN handles both RFC 3986 URL format and the PostgreSQL libpq +// key=value connection string format. 
+// +// URL examples: +// +// postgres://user:pass@host:5432/mydb +// postgresql://host/mydb +// +// Key-value examples: +// +// host=localhost port=5432 dbname=mydb +// host='db.example.com' port=5432 dbname=prod user=alice +func parsePostgresDSN(dsn string) DSNInfo { + if strings.Contains(dsn, "://") { + if u, err := nurl.Parse(dsn); err == nil && + (u.Scheme == "postgres" || u.Scheme == "postgresql") { + host := u.Hostname() + port := u.Port() + if port == "" { + port = "5432" + } + dbName := strings.TrimPrefix(u.Path, "/") + return DSNInfo{Host: host, Port: port, DBName: dbName} + } + } + return parseLibpqKV(dsn) +} + +// parseLibpqKV parses the PostgreSQL libpq keyword=value connection string. +// Values may be unquoted (terminated by whitespace) or single-quoted (with +// backslash escape support inside quotes). +func parseLibpqKV(dsn string) DSNInfo { + info := DSNInfo{Host: "localhost", Port: "5432"} + rest := strings.TrimSpace(dsn) + for len(rest) > 0 { + rest = strings.TrimLeft(rest, " \t\n\r") + if rest == "" { + break + } + eqIdx := strings.IndexByte(rest, '=') + if eqIdx < 0 { + break + } + key := strings.TrimSpace(rest[:eqIdx]) + rest = rest[eqIdx+1:] + + var val string + if strings.HasPrefix(rest, "'") { + rest = rest[1:] + var b strings.Builder + for len(rest) > 0 { + switch { + case rest[0] == '\\' && len(rest) > 1: + b.WriteByte(rest[1]) + rest = rest[2:] + case rest[0] == '\'': + rest = rest[1:] + val = b.String() + goto nextKV + default: + b.WriteByte(rest[0]) + rest = rest[1:] + } + } + val = b.String() + } else { + end := strings.IndexAny(rest, " \t\n\r") + if end < 0 { + end = len(rest) + } + val = rest[:end] + rest = rest[end:] + } + nextKV: + switch key { + case "host": + info.Host = val + case "port": + info.Port = val + case "dbname": + info.DBName = val + } + } + return info +} + +// ---- MySQL ------------------------------------------------------------------- + +// parseMySQLDSN handles the go-sql-driver/mysql DSN format: +// +// 
[user[:password]@][protocol[(address)]]/dbname[?params] +// +// It also handles the non-standard form where the address is not wrapped in +// parentheses (e.g. user:pass@tcp:3306/dbname or user:pass@host:3306/dbname). +func parseMySQLDSN(dsn string) DSNInfo { + // Find the last @ that precedes the first '(' so we skip '@' inside passwords. + atIdx := -1 + for i := 0; i < len(dsn); i++ { + if dsn[i] == '@' { + atIdx = i + } + if dsn[i] == '(' { + break + } + } + rest := dsn + if atIdx >= 0 { + rest = dsn[atIdx+1:] + } + + var addrStr, dbPart string + + if lp := strings.IndexByte(rest, '('); lp >= 0 { + // Standard parenthesised address: proto(host:port)/dbname + if rp := strings.IndexByte(rest[lp:], ')'); rp >= 0 { + addrStr = rest[lp+1 : lp+rp] + dbPart = rest[lp+rp+1:] + } + } else { + // Non-standard: no parentheses around the address. + // Locate the '/' that separates the address from the database name. + if sl := strings.IndexByte(rest, '/'); sl >= 0 { + addrStr = rest[:sl] + dbPart = rest[sl:] + } else { + addrStr = rest + } + addrStr = mysqlStripProtocol(addrStr) + } + + host, port := mysqlSplitAddr(addrStr) + dbName := mysqlDBName(dbPart) + return DSNInfo{Host: host, Port: port, DBName: dbName} +} + +// mysqlKnownProtocols lists the transport keywords used by go-sql-driver/mysql. +var mysqlKnownProtocols = map[string]bool{ + "tcp": true, "unix": true, "udp": true, "pipe": true, +} + +// mysqlStripProtocol removes a leading "proto:" prefix when proto is a known +// MySQL transport keyword, normalising a bare port number to "localhost:port". +func mysqlStripProtocol(addr string) string { + colon := strings.IndexByte(addr, ':') + if colon < 0 { + return addr + } + proto := addr[:colon] + if !mysqlKnownProtocols[proto] { + return addr + } + rest := addr[colon+1:] + if rest == "" { + return "" + } + // If rest is pure digits it is just a port number with an implicit localhost. 
+ allDigits := true + for _, c := range rest { + if c < '0' || c > '9' { + allDigits = false + break + } + } + if allDigits { + return "localhost:" + rest + } + return rest +} + +// mysqlSplitAddr splits a "host:port" string, defaulting to port 3306 when no +// port is present. IPv6 addresses enclosed in brackets are handled correctly. +func mysqlSplitAddr(addr string) (host, port string) { + if addr == "" { + return "", "" + } + if strings.HasPrefix(addr, "[") { + if rb := strings.LastIndexByte(addr, ']'); rb >= 0 { + host = addr[1:rb] + rest := addr[rb+1:] + if strings.HasPrefix(rest, ":") { + return host, rest[1:] + } + return host, "3306" + } + } + if colon := strings.LastIndexByte(addr, ':'); colon >= 0 { + return addr[:colon], addr[colon+1:] + } + // Unix socket paths contain '/' but no port. + if strings.ContainsRune(addr, '/') { + return addr, "" + } + return addr, "3306" +} + +// mysqlDBName trims the leading '/' and any query string from a dbname segment. +func mysqlDBName(s string) string { + s = strings.TrimPrefix(s, "/") + if i := strings.IndexByte(s, '?'); i >= 0 { + s = s[:i] + } + return s +} + +// ---- SQLite ------------------------------------------------------------------ + +// parseSQLiteDSN handles sqlite3 / sqlite DSN strings. SQLite is always +// file-local, so Host and Port are not populated. DBName is set to the +// filename (or ":memory:" for in-memory databases). +func parseSQLiteDSN(dsn string) DSNInfo { + return DSNInfo{Host: "sqlite3", DBName: sqliteDBName(dsn)} +} + +// sqliteDBName extracts the database name from a SQLite DSN. It handles +// file: URI schemes, the :memory: shorthand, and bare filenames. 
+func sqliteDBName(dsn string) string { + if strings.HasPrefix(dsn, "file:") { + path := dsn[len("file:"):] + if i := strings.IndexByte(path, '?'); i >= 0 { + path = path[:i] + } + bare := strings.TrimPrefix(path, "//") + if bare == ":memory:" { + return ":memory:" + } + if i := strings.LastIndexByte(path, '/'); i >= 0 { + return path[i+1:] + } + return path + } + if dsn == ":memory:" { + return ":memory:" + } + if i := strings.IndexByte(dsn, '?'); i >= 0 { + dsn = dsn[:i] + } + if i := strings.LastIndexByte(dsn, '/'); i >= 0 { + return dsn[i+1:] + } + return dsn +} + +// ---- SQL Server -------------------------------------------------------------- + +// parseSQLServerDSN handles SQL Server DSNs in both URL format and the +// ADO.NET semicolon-delimited key=value connection-string format. +// +// URL examples: +// +// sqlserver://user:pass@host:1433?database=mydb +// mssql://host:1433?database=mydb +// +// ADO.NET examples: +// +// server=host;port=1433;database=mydb;user id=sa;password=secret +// Server=host,1433;Database=mydb;User Id=sa;Password=secret +func parseSQLServerDSN(dsn string) DSNInfo { + if strings.Contains(dsn, "://") { + if u, err := nurl.Parse(dsn); err == nil { + host := u.Hostname() + port := u.Port() + if port == "" { + port = "1433" + } + dbName := u.Query().Get("database") + return DSNInfo{Host: host, Port: port, DBName: dbName} + } + } + return parseSQLServerKV(dsn) +} + +// parseSQLServerKV parses the ADO.NET semicolon-delimited key=value format. +func parseSQLServerKV(dsn string) DSNInfo { + var host, port, dbName string + for _, pair := range strings.Split(dsn, ";") { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + key, val, ok := strings.Cut(pair, "=") + if !ok { + continue + } + key = strings.ToLower(strings.TrimSpace(key)) + val = strings.TrimSpace(val) + switch key { + case "server", "data source": + // Accept both "host,port" and "host\instance" forms. 
+ if comma := strings.IndexByte(val, ','); comma >= 0 { + host = val[:comma] + port = val[comma+1:] + } else if bs := strings.IndexByte(val, '\\'); bs >= 0 { + host = val[:bs] + } else { + host = val + } + case "port": + port = val + case "database", "initial catalog": + dbName = val + } + } + if host != "" && port == "" { + port = "1433" + } + return DSNInfo{Host: host, Port: port, DBName: dbName} +} + +// ---- ClickHouse -------------------------------------------------------------- + +// parseClickHouseDSN handles ClickHouse DSNs, which are always URL-formatted. +func parseClickHouseDSN(dsn string) DSNInfo { + u, err := nurl.Parse(dsn) + if err != nil { + return DSNInfo{} + } + host := u.Hostname() + port := u.Port() + if port == "" { + switch u.Scheme { + case "http": + port = "8123" + case "https": + port = "8443" + default: // tcp, native, clickhouse + port = "9000" + } + } + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = u.Query().Get("database") + } + return DSNInfo{Host: host, Port: port, DBName: dbName} +} + +// ---- Oracle ------------------------------------------------------------------ + +// parseOracleDSN handles Oracle DSNs in URL format or the traditional +// user/pass@host:port/service notation. 
+func parseOracleDSN(dsn string) DSNInfo { + if strings.Contains(dsn, "://") { + if u, err := nurl.Parse(dsn); err == nil && u.Host != "" { + host := u.Hostname() + port := u.Port() + if port == "" { + port = "1521" + } + dbName := strings.TrimPrefix(u.Path, "/") + return DSNInfo{Host: host, Port: port, DBName: dbName} + } + } + atIdx := strings.IndexByte(dsn, '@') + if atIdx < 0 { + return DSNInfo{} + } + connStr := strings.TrimPrefix(dsn[atIdx+1:], "//") + var hostPort, dbName string + if sl := strings.IndexByte(connStr, '/'); sl >= 0 { + hostPort = connStr[:sl] + dbName = connStr[sl+1:] + } else { + hostPort = connStr + } + host, port := oracleSplitHostPort(hostPort) + return DSNInfo{Host: host, Port: port, DBName: dbName} +} + +func oracleSplitHostPort(hp string) (host, port string) { + if i := strings.LastIndexByte(hp, ':'); i >= 0 { + return hp[:i], hp[i+1:] + } + return hp, "1521" +} diff --git a/pkg/instrumentation/databasesql/dsnparse/parse_test.go b/pkg/instrumentation/databasesql/dsnparse/parse_test.go new file mode 100644 index 000000000..a1ea73aaf --- /dev/null +++ b/pkg/instrumentation/databasesql/dsnparse/parse_test.go @@ -0,0 +1,536 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package dsnparse + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseDSN_Postgres(t *testing.T) { + tests := []struct { + name string + dsn string + wantHost string + wantPort string + wantDB string + }{ + // ---- URL format ---- + { + name: "url with credentials and port", + dsn: "postgres://user:pass@localhost:5432/mydb", + wantHost: "localhost", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "url no credentials", + dsn: "postgres://localhost:5432/mydb", + wantHost: "localhost", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "url no port uses default 5432", + dsn: "postgres://localhost/mydb", + wantHost: "localhost", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "postgresql 
scheme", + dsn: "postgresql://user:pass@db.example.com:5433/prod", + wantHost: "db.example.com", + wantPort: "5433", + wantDB: "prod", + }, + { + name: "url with query params", + dsn: "postgres://user:pass@host:5432/mydb?sslmode=require&connect_timeout=5", + wantHost: "host", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "url with ip address", + dsn: "postgres://10.0.0.1:5432/mydb", + wantHost: "10.0.0.1", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "url with encoded password", + dsn: "postgres://user:p%40ss@host:5432/mydb", + wantHost: "host", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "url empty dbname", + dsn: "postgres://host:5432/", + wantHost: "host", + wantPort: "5432", + wantDB: "", + }, + // ---- Libpq key=value format (the previously failing case) ---- + { + name: "libpq all fields", + dsn: "host=localhost port=5432 dbname=mydb", + wantHost: "localhost", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "libpq no port uses default", + dsn: "host=db.example.com dbname=prod", + wantHost: "db.example.com", + wantPort: "5432", + wantDB: "prod", + }, + { + name: "libpq with extra fields", + dsn: "host=localhost port=5432 dbname=mydb user=alice password=secret sslmode=disable", + wantHost: "localhost", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "libpq single-quoted host", + dsn: "host='db.example.com' port=5432 dbname=prod", + wantHost: "db.example.com", + wantPort: "5432", + wantDB: "prod", + }, + { + name: "libpq only dbname uses default host and port", + dsn: "dbname=mydb", + wantHost: "localhost", + wantPort: "5432", + wantDB: "mydb", + }, + { + name: "libpq custom port", + dsn: "host=replica.internal port=5433 dbname=analytics", + wantHost: "replica.internal", + wantPort: "5433", + wantDB: "analytics", + }, + { + name: "libpq single-quoted dbname with space", + dsn: "host=localhost port=5432 dbname='my db'", + wantHost: "localhost", + wantPort: "5432", + wantDB: "my db", + }, + // pgx uses the same DSN 
format + { + name: "pgx driver url format", + dsn: "postgres://user:pass@host:5432/mydb", + wantHost: "host", + wantPort: "5432", + wantDB: "mydb", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, driver := range []string{"postgres", "pgx", "postgresql"} { + got := ParseDSN(driver, tt.dsn) + assert.Equal(t, tt.wantHost, got.Host, "driver=%s Host", driver) + assert.Equal(t, tt.wantPort, got.Port, "driver=%s Port", driver) + assert.Equal(t, tt.wantDB, got.DBName, "driver=%s DBName", driver) + } + }) + } +} + +func TestParseDSN_MySQL(t *testing.T) { + tests := []struct { + name string + dsn string + wantHost string + wantPort string + wantDB string + }{ + // ---- Standard go-sql-driver/mysql format (parenthesised address) ---- + { + name: "full credentials tcp", + dsn: "user:pass@tcp(localhost:3306)/mydb", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "no password", + dsn: "user@tcp(localhost:3306)/mydb", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "no credentials", + dsn: "tcp(localhost:3306)/mydb", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "ip address", + dsn: "user:pass@tcp(127.0.0.1:3306)/mydb", + wantHost: "127.0.0.1", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "with query params", + dsn: "user:pass@tcp(localhost:3306)/mydb?charset=utf8&timeout=5s", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "remote host", + dsn: "user:pass@tcp(db.example.com:3306)/prod", + wantHost: "db.example.com", + wantPort: "3306", + wantDB: "prod", + }, + { + name: "tcp no port", + dsn: "user:pass@tcp(localhost)/mydb", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "password with @ symbol uses last @", + dsn: "user:p@ss@tcp(host:3306)/mydb", + wantHost: "host", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "unix socket", + dsn: 
"user:pass@unix(/tmp/mysql.sock)/mydb", + wantHost: "/tmp/mysql.sock", + wantPort: "", + wantDB: "mydb", + }, + { + name: "just dbname via slash", + dsn: "/mydb", + wantHost: "", + wantPort: "", + wantDB: "mydb", + }, + { + name: "credentials and dbname only", + dsn: "user:pass@/mydb", + wantHost: "", + wantPort: "", + wantDB: "mydb", + }, + // ---- Non-standard format without parentheses (previously failing case) ---- + { + name: "tcp protocol with bare port", + dsn: "user:pass@tcp:3306/mydb", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "no protocol host:port", + dsn: "user:pass@localhost:3306/mydb", + wantHost: "localhost", + wantPort: "3306", + wantDB: "mydb", + }, + { + name: "ip host:port no protocol", + dsn: "user:pass@10.0.0.5:3306/proddb", + wantHost: "10.0.0.5", + wantPort: "3306", + wantDB: "proddb", + }, + { + name: "empty db after slash", + dsn: "user:pass@tcp(host:3306)/", + wantHost: "host", + wantPort: "3306", + wantDB: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ParseDSN("mysql", tt.dsn) + assert.Equal(t, tt.wantHost, got.Host, "Host") + assert.Equal(t, tt.wantPort, got.Port, "Port") + assert.Equal(t, tt.wantDB, got.DBName, "DBName") + }) + } +} + +func TestParseDSN_SQLite(t *testing.T) { + tests := []struct { + name string + dsn string + wantDB string + }{ + // ---- file: URI format (the previously failing case) ---- + { + name: "file URI with query params", + dsn: "file:test.db?cache=shared", + wantDB: "test.db", + }, + { + name: "file URI no params", + dsn: "file:test.db", + wantDB: "test.db", + }, + { + name: "file URI absolute path", + dsn: "file:/var/lib/myapp/data.db", + wantDB: "data.db", + }, + { + name: "file URI relative path with dirs", + dsn: "file:../data/test.db?mode=ro", + wantDB: "test.db", + }, + { + name: "file URI in-memory", + dsn: "file::memory:?cache=shared", + wantDB: ":memory:", + }, + { + name: "file URI double-slash in-memory", + dsn: 
"file://:memory:", + wantDB: ":memory:", + }, + // ---- :memory: shorthand ---- + { + name: "in-memory shorthand", + dsn: ":memory:", + wantDB: ":memory:", + }, + // ---- bare filename ---- + { + name: "bare filename", + dsn: "test.db", + wantDB: "test.db", + }, + { + name: "bare filename with query", + dsn: "test.db?cache=shared", + wantDB: "test.db", + }, + { + name: "bare path", + dsn: "/var/lib/data.db", + wantDB: "data.db", + }, + { + name: "production db name", + dsn: "file:production.sqlite3?mode=ro&cache=shared", + wantDB: "production.sqlite3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, driver := range []string{"sqlite3", "sqlite"} { + got := ParseDSN(driver, tt.dsn) + assert.Equal(t, "sqlite3", got.Host, "driver=%s Host should always be sqlite3", driver) + assert.Equal(t, tt.wantDB, got.DBName, "driver=%s DBName", driver) + } + }) + } +} + +func TestParseDSN_SQLServer(t *testing.T) { + tests := []struct { + name string + dsn string + wantHost string + wantPort string + wantDB string + }{ + // ---- URL format ---- + { + name: "sqlserver url", + dsn: "sqlserver://user:pass@host:1433?database=mydb", + wantHost: "host", + wantPort: "1433", + wantDB: "mydb", + }, + { + name: "mssql url", + dsn: "mssql://user:pass@host:1433?database=mydb", + wantHost: "host", + wantPort: "1433", + wantDB: "mydb", + }, + { + name: "sqlserver url default port", + dsn: "sqlserver://user:pass@db.example.com?database=prod", + wantHost: "db.example.com", + wantPort: "1433", + wantDB: "prod", + }, + { + name: "sqlserver url ip", + dsn: "sqlserver://sa:secret@10.0.0.1:1433?database=sales", + wantHost: "10.0.0.1", + wantPort: "1433", + wantDB: "sales", + }, + // ---- ADO.NET semicolon key=value format ---- + { + name: "ado.net lowercase keys", + dsn: "server=host;port=1433;database=mydb;user id=sa;password=secret", + wantHost: "host", + wantPort: "1433", + wantDB: "mydb", + }, + { + name: "ado.net mixed case keys", + dsn: 
"Server=db.example.com;Database=prod;User Id=sa;Password=secret", + wantHost: "db.example.com", + wantPort: "1433", + wantDB: "prod", + }, + { + name: "ado.net server with comma port", + dsn: "Server=host,1434;Database=mydb;User Id=sa;Password=secret", + wantHost: "host", + wantPort: "1434", + wantDB: "mydb", + }, + { + name: "ado.net initial catalog key", + dsn: "server=host;initial catalog=mydb;user id=sa", + wantHost: "host", + wantPort: "1433", + wantDB: "mydb", + }, + { + name: "ado.net data source key", + dsn: "data source=db.example.com;initial catalog=sales;user id=sa;password=p", + wantHost: "db.example.com", + wantPort: "1433", + wantDB: "sales", + }, + { + name: "ado.net server with backslash instance", + dsn: "Server=host\\SQLEXPRESS;Database=mydb;User Id=sa;Password=p", + wantHost: "host", + wantPort: "1433", + wantDB: "mydb", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, driver := range []string{"sqlserver", "mssql"} { + got := ParseDSN(driver, tt.dsn) + assert.Equal(t, tt.wantHost, got.Host, "driver=%s Host", driver) + assert.Equal(t, tt.wantPort, got.Port, "driver=%s Port", driver) + assert.Equal(t, tt.wantDB, got.DBName, "driver=%s DBName", driver) + } + }) + } +} + +func TestParseDSN_UnknownDriver(t *testing.T) { + got := ParseDSN("someunknowndriver", "host=localhost") + assert.Equal(t, DSNInfo{}, got) + assert.Equal(t, "", got.Addr()) +} + +func TestDSNInfo_Addr(t *testing.T) { + tests := []struct { + info DSNInfo + want string + }{ + {DSNInfo{Host: "localhost", Port: "5432"}, "localhost:5432"}, + {DSNInfo{Host: "localhost", Port: ""}, "localhost"}, + {DSNInfo{Host: "", Port: "5432"}, ""}, + {DSNInfo{}, ""}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, tt.info.Addr()) + } +} + +func TestParseDbName(t *testing.T) { + tests := []struct { + name string + dsn string + want string + }{ + { + name: "standard url dbname", + dsn: "postgres://user:pass@host:5432/mydb", + want: "mydb", + }, + { + 
name: "mysql standard", + dsn: "user:pass@tcp(host:3306)/mydb", + want: "mydb", + }, + { + name: "with query string", + dsn: "user:pass@tcp(host:3306)/mydb?charset=utf8", + want: "mydb", + }, + { + name: "url encoded dbname", + dsn: "postgres://host/my%20db", + want: "my db", + }, + { + name: "no slash returns empty", + dsn: "host=localhost dbname=mydb", + want: "", + }, + { + name: "sqlite bare filename", + dsn: "file:test.db?cache=shared", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, ParseDbName(tt.dsn)) + }) + } +} + +func TestParseDSN_SQLitePreviouslyFailing(t *testing.T) { + // Regression test: prior to this fix, file:test.db?cache=shared returned + // DBName="" because ParseDbName found no '/' in the string. + got := ParseDSN("sqlite3", "file:test.db?cache=shared") + assert.Equal(t, "test.db", got.DBName, "DBName must be the filename, not empty") + assert.Equal(t, "sqlite3", got.Host) +} + +func TestParseDSN_PostgresLibpqPreviouslyFailing(t *testing.T) { + // Regression test: prior to this fix, the libpq KV format fell through the + // URL check (wrong scheme) and returned addr="unknown" / dbname="". + got := ParseDSN("postgres", "host=localhost port=5432 dbname=mydb") + assert.Equal(t, "localhost", got.Host) + assert.Equal(t, "5432", got.Port) + assert.Equal(t, "mydb", got.DBName) +} + +func TestParseDSN_MySQLNoParensPreviouslyFailing(t *testing.T) { + // Regression test: prior to this fix, user:pass@tcp:3306/dbname returned + // an error because the parser required parentheses around the address. 
+ got := ParseDSN("mysql", "user:pass@tcp:3306/mydb") + assert.Equal(t, "localhost", got.Host) + assert.Equal(t, "3306", got.Port) + assert.Equal(t, "mydb", got.DBName) +} diff --git a/pkg/instrumentation/databasesql/go.mod b/pkg/instrumentation/databasesql/go.mod index 264af310d..c413281d4 100644 --- a/pkg/instrumentation/databasesql/go.mod +++ b/pkg/instrumentation/databasesql/go.mod @@ -36,7 +36,7 @@ require ( go.opentelemetry.io/contrib/exporters/autoexport v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/runtime v0.64.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.14.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.14.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.19.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect @@ -46,10 +46,10 @@ require ( go.opentelemetry.io/otel/exporters/stdout/stdoutlog v0.14.0 // indirect go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0 // indirect go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.38.0 // indirect - go.opentelemetry.io/otel/log v0.14.0 // indirect + go.opentelemetry.io/otel/log v0.19.0 // indirect go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect - go.opentelemetry.io/otel/sdk/log v0.14.0 // indirect + go.opentelemetry.io/otel/sdk/log v0.19.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect golang.org/x/net v0.52.0 // indirect diff --git a/pkg/instrumentation/databasesql/go.sum b/pkg/instrumentation/databasesql/go.sum index ac7216fac..1707e1fb8 100644 --- a/pkg/instrumentation/databasesql/go.sum +++ b/pkg/instrumentation/databasesql/go.sum @@ -59,8 +59,8 @@ 
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.14.0 h1:OMqPldHt79PqWKOMYIAQs3CxAi7RLgPxwfFSwr4ZxtM= go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.14.0/go.mod h1:1biG4qiqTxKiUCtoWDPpL3fB3KxVwCiGw81j3nKMuHE= -go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.14.0 h1:QQqYw3lkrzwVsoEX0w//EhH/TCnpRdEenKBOOEIMjWc= -go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.14.0/go.mod h1:gSVQcr17jk2ig4jqJ2DX30IdWH251JcNAecvrqTxH1s= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.19.0 h1:HIBTQ3VO5aupLKjC90JgMqpezVXwFuq6Ryjn0/izoag= +go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.19.0/go.mod h1:ji9vId85hMxqfvICA0Jt8JqEdrXaAkcpkI9HPXya0ro= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0 h1:vl9obrcoWVKp/lwl8tRE33853I8Xru9HFbw/skNeLs8= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.38.0/go.mod h1:GAXRxmLJcVM3u22IjTg74zWBrRCKq8BnOqUVLodpcpw= go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 h1:w1K+pCJoPpQifuVpsKamUdn9U0zM3xUziVOqsGksUrY= @@ -79,16 +79,16 @@ go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0 h1:wm/Q0GAAykXv83 go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.38.0/go.mod h1:ra3Pa40+oKjvYh+ZD3EdxFZZB0xdMfuileHAm4nNN7w= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.38.0 h1:kJxSDN4SgWWTjG/hPp3O7LCGLcHXFlvS2/FFOrwL+SE= go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.38.0/go.mod h1:mgIOzS7iZeKJdeB8/NYHrJ48fdGc71Llo5bJ1J4DWUE= -go.opentelemetry.io/otel/log v0.14.0 h1:2rzJ+pOAZ8qmZ3DDHg73NEKzSZkhkGIua9gXtxNGgrM= -go.opentelemetry.io/otel/log v0.14.0/go.mod h1:5jRG92fEAgx0SU/vFPxmJvhIuDU9E1SUnEQrMlJpOno= +go.opentelemetry.io/otel/log v0.19.0 h1:KUZs/GOsw79TBBMfDWsXS+KZ4g2Ckzksd1ymzsIEbo4= 
+go.opentelemetry.io/otel/log v0.19.0/go.mod h1:5DQYeGmxVIr4n0/BcJvF4upsraHjg6vudJJpnkL6Ipk= go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= -go.opentelemetry.io/otel/sdk/log v0.14.0 h1:JU/U3O7N6fsAXj0+CXz21Czg532dW2V4gG1HE/e8Zrg= -go.opentelemetry.io/otel/sdk/log v0.14.0/go.mod h1:imQvII+0ZylXfKU7/wtOND8Hn4OpT3YUoIgqJVksUkM= -go.opentelemetry.io/otel/sdk/log/logtest v0.14.0 h1:Ijbtz+JKXl8T2MngiwqBlPaHqc4YCaP/i13Qrow6gAM= -go.opentelemetry.io/otel/sdk/log/logtest v0.14.0/go.mod h1:dCU8aEL6q+L9cYTqcVOk8rM9Tp8WdnHOPLiBgp0SGOA= +go.opentelemetry.io/otel/sdk/log v0.19.0 h1:scYVLqT22D2gqXItnWiocLUKGH9yvkkeql5dBDiXyko= +go.opentelemetry.io/otel/sdk/log v0.19.0/go.mod h1:vFBowwXGLlW9AvpuF7bMgnNI95LiW10szrOdvzBHlAg= +go.opentelemetry.io/otel/sdk/log/logtest v0.19.0 h1:BEbF7ZBB6qQloV/Ub1+3NQoOUnVtcGkU3XX4Ws3GQfk= +go.opentelemetry.io/otel/sdk/log/logtest v0.19.0/go.mod h1:Lua81/3yM0wOmoHTokLj9y9ADeA02v1naRrVrkAZuKk= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= diff --git a/pkg/instrumentation/databasesql/parse.go b/pkg/instrumentation/databasesql/parse.go index a7cf15841..18204e2df 100644 --- a/pkg/instrumentation/databasesql/parse.go +++ b/pkg/instrumentation/databasesql/parse.go @@ -4,232 +4,28 @@ package db import ( - "errors" - "fmt" - nurl "net/url" - "strings" + "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/databasesql/dsnparse" ) -func ParseDbName(dsn string) string { - var name string - var 
err error - for i := len(dsn) - 1; i >= 0; i-- { - if dsn[i] == '/' { - dbname := dsn[i+1:] - queryIndex := strings.Index(dbname, "?") - if queryIndex > 0 { - dbname = dbname[:queryIndex] - } - if name, err = nurl.PathUnescape(dbname); err != nil { - return "" - } - - break - } - } - return name -} - -func parseDSN(driverName, dsn string) (addr string, err error) { - // TODO: need a more delegate DFA - switch driverName { - case "mysql": - return parseMySQL(dsn) - case "postgres": - fallthrough - case "postgresql": - return parsePostgres(dsn) - case "clickhouse": - return parseClickHouse(dsn) - case "sqlite3": - return "sqlite3", nil - case "godror", "oracle", "oci8", "go-oci8": - return parseOracle(dsn) - case "mssql", "sqlserver": - return parseSQLServer(dsn) - } - - return driverName, errors.New("invalid DSN") -} - -func parsePostgres(url string) (addr string, err error) { - u, err := nurl.Parse(url) - if err != nil { - return "", err - } - - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - if u.Port() != "" { - return u.Host, nil - } - return u.Hostname() + ":5432", nil -} - -func parseMySQL(dsn string) (addr string, err error) { - // MySQL DSN format: [username[:password]@][protocol[(address)]]/dbname[?params] - // We need to find the protocol part after @ to avoid special chars in password - - // Find @ symbol to locate where credentials end - atIndex := strings.LastIndex(dsn, "@") - var searchStart int - if atIndex >= 0 { - // Start searching for ( after @ - searchStart = atIndex - } else { - // No credentials, search from beginning - searchStart = 0 - } - - // Now find the ( and ) after the @ symbol - n := len(dsn) - i, j := -1, -1 - for k := searchStart; k < n; k++ { - if dsn[k] == '(' { - i = k - } - if dsn[k] == ')' && i >= 0 { - // Only accept ) if we've already found ( - j = k - break - } - } - if i >= 0 && j > i { - return dsn[i+1 : j], nil - } - return "", 
errors.New("invalid MySQL DSN") -} - -func parseClickHouse(dsn string) (addr string, err error) { - // ClickHouse DSN formats: - // tcp://host:port?database=dbname&username=user&password=pass - // http://host:port?database=dbname - // clickhouse://host:port/database?username=user&password=pass - u, err := nurl.Parse(dsn) - if err != nil { - return "", err - } +// DSNInfo is an alias for dsnparse.DSNInfo so callers that import the db +// package can use the type without an additional import. +type DSNInfo = dsnparse.DSNInfo - // Return host with port - if u.Port() != "" { - return u.Host, nil - } - - // Default ports based on scheme - switch u.Scheme { - case "tcp", "native": - return u.Hostname() + ":9000", nil - case "http": - return u.Hostname() + ":8123", nil - case "https": - return u.Hostname() + ":8443", nil - case "clickhouse": - // Default to native port - return u.Hostname() + ":9000", nil - } - - return u.Host, nil +// ParseDSN forwards to dsnparse.ParseDSN, which selects a parser by driverName +// and extracts the host, port and database name from dsn (documented there as +// never panicking). Unrecognised drivers return a zero-value DSNInfo. 
+func ParseDSN(driverName, dsn string) DSNInfo { + return dsnparse.ParseDSN(driverName, dsn) } -func parseOracle(dsn string) (addr string, err error) { - // Oracle DSN formats: - // user/password@host:port/service_name - // user/password@host:port/sid - // user/password@//host:port/service_name - // oracle://user:password@host:port/service_name - - // Try URL format first - if strings.Contains(dsn, "://") { - u, err := nurl.Parse(dsn) - if err == nil && u.Host != "" { - if u.Port() != "" { - return u.Host, nil - } - return u.Hostname() + ":1521", nil // Oracle default port - } - } - - // Parse traditional Oracle format: user/password@host:port/service - atIndex := strings.Index(dsn, "@") - if atIndex < 0 { - return "", errors.New("invalid Oracle DSN") - } - - connStr := dsn[atIndex+1:] - // Remove leading // - connStr = strings.TrimPrefix(connStr, "//") - - // Extract host:port before / - slashIndex := strings.Index(connStr, "/") - var hostPort string - if slashIndex > 0 { - hostPort = connStr[:slashIndex] - } else { - hostPort = connStr - } - - // If no port specified, add default - if !strings.Contains(hostPort, ":") { - hostPort = hostPort + ":1521" - } - - return hostPort, nil +// ParseDbName forwards to dsnparse.ParseDbName, which returns the text after the last +// '/' in dsn, cut at the first '?' or '&' and path-unescaped when possible ("" if no '/'). +// Retained for backward compatibility; prefer ParseDSN when the driver name is known. 
+func ParseDbName(dsn string) string { + return dsnparse.ParseDbName(dsn) } -func parseSQLServer(dsn string) (addr string, err error) { - // SQL Server DSN formats: - // sqlserver://username:password@host:port?database=dbname - // server=host;port=1433;database=dbname;user id=user;password=pass - // Server=host,port;Database=dbname;User Id=user;Password=pass - - // Try URL format first - if strings.HasPrefix(dsn, "sqlserver://") || strings.HasPrefix(dsn, "mssql://") { - u, err := nurl.Parse(dsn) - if err != nil { - return "", err - } - if u.Port() != "" { - return u.Host, nil - } - return u.Hostname() + ":1433", nil // SQL Server default port - } - - // Parse connection string format (key=value pairs) - dsn = strings.ToLower(dsn) - var host, port string - - // Split by semicolon - pairs := strings.Split(dsn, ";") - for _, pair := range pairs { - pair = strings.TrimSpace(pair) - if strings.HasPrefix(pair, "server=") { - serverVal := strings.TrimPrefix(pair, "server=") - // Handle Server=host,port format - if strings.Contains(serverVal, ",") { - parts := strings.Split(serverVal, ",") - host = parts[0] - if len(parts) > 1 { - port = parts[1] - } - } else { - host = serverVal - } - } else if strings.HasPrefix(pair, "port=") { - port = strings.TrimPrefix(pair, "port=") - } else if strings.HasPrefix(pair, "host=") { - host = strings.TrimPrefix(pair, "host=") - } - } - - if host == "" { - return "", errors.New("invalid SQL Server DSN") - } - - if port == "" { - port = "1433" // SQL Server default port - } - - return host + ":" + port, nil +// parseDSN keeps the legacy (addr, error) shape via dsnparse.LegacyParseDSN; the error is always nil and the address falls back to driverName. +func parseDSN(driverName, dsn string) (string, error) { + return dsnparse.LegacyParseDSN(driverName, dsn) }