diff --git a/pkg/instrumentation/databasesql/client.go b/pkg/instrumentation/databasesql/client.go index 955437f70..1b9903c8e 100644 --- a/pkg/instrumentation/databasesql/client.go +++ b/pkg/instrumentation/databasesql/client.go @@ -42,9 +42,11 @@ var clientEnabler = dbClientEnabler{} func beforeOpenInstrumentation(ictx inst.HookContext, driverName, dataSourceName string) { addr, err := parseDSN(driverName, dataSourceName) if err != nil { + logger.Warn("could not determine server address from DSN; server.address will be omitted", + "driver", driverName, "error", err) addr = "unknown" } - dbName := ParseDbName(dataSourceName) + dbName := parseDbName(dataSourceName) ictx.SetData(map[string]string{ "endpoint": addr, "driver": driverName, diff --git a/pkg/instrumentation/databasesql/internal/dsnparse/parse.go b/pkg/instrumentation/databasesql/internal/dsnparse/parse.go new file mode 100644 index 000000000..f2b9914d5 --- /dev/null +++ b/pkg/instrumentation/databasesql/internal/dsnparse/parse.go @@ -0,0 +1,275 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package dsnparse + +import ( + "errors" + "fmt" + nurl "net/url" + "strings" + "sync" +) + +// DSNParser parses a driver-specific DSN and returns the server address (host:port). +type DSNParser func(dsn string) (addr string, err error) + +var ( + parserMu sync.RWMutex + parserRegistry = map[string]DSNParser{} +) + +// RegisterDSNParser registers a custom DSN parser for the given driver name. +// Built-in parsers are registered automatically during package initialization. +// Calling RegisterDSNParser for an already-registered name overwrites the previous parser. +// It is safe to call from package init() functions. +func RegisterDSNParser(driverName string, parser DSNParser) { + parserMu.Lock() + defer parserMu.Unlock() + parserRegistry[driverName] = parser +} + +func init() { + // Register all built-in DSN parsers. + RegisterDSNParser("mysql", ParseMySQL) + RegisterDSNParser("postgres", ParsePostgres) + RegisterDSNParser("postgresql", ParsePostgres) + RegisterDSNParser("pgx", ParsePostgres) // pgx uses the standard postgres URL format + RegisterDSNParser("lib/pq", ParsePostgres) // lib/pq uses the standard postgres URL format + RegisterDSNParser("clickhouse", ParseClickHouse) + RegisterDSNParser("sqlite3", func(_ string) (string, error) { return "sqlite3", nil }) + RegisterDSNParser("godror", ParseOracle) + RegisterDSNParser("oracle", ParseOracle) + RegisterDSNParser("oci8", ParseOracle) + RegisterDSNParser("go-oci8", ParseOracle) + RegisterDSNParser("mssql", ParseSQLServer) + RegisterDSNParser("sqlserver", ParseSQLServer) +} + +// ParseDSN parses driverName and dsn into a server address (host:port). +// It falls back to best-effort URL parsing for unregistered drivers. +func ParseDSN(driverName, dsn string) (addr string, err error) { + parserMu.RLock() + parser, ok := parserRegistry[driverName] + parserMu.RUnlock() + + if ok { + return parser(dsn) + } + + // Best-effort: try standard URL parsing for drivers not in the registry. + return BestEffortParse(dsn) +} + +// BestEffortParse attempts to extract a host:port from a DSN using standard URL parsing. +// It is used as a fallback for drivers that have no registered parser. +func BestEffortParse(dsn string) (string, error) { + u, err := nurl.Parse(dsn) + if err == nil && u.Host != "" { + return u.Host, nil + } + return "", errors.New("no DSN parser registered for this driver; best-effort URL parse also failed") +} + +// ParseDbName extracts the database name from a DSN using best-effort URL parsing. +// For URL-based DSNs it uses the path component; for MySQL-style DSNs it uses +// the segment between the closing parenthesis and the first '?'. Returns an +// empty string when the name cannot be determined. +func ParseDbName(dsn string) string { + // MySQL style: user:pass@tcp(host:port)/dbname?params + if i := strings.LastIndex(dsn, ")/"); i >= 0 { + rest := dsn[i+2:] + if j := strings.IndexByte(rest, '?'); j >= 0 { + rest = rest[:j] + } + return rest + } + // URL style: scheme://user:pass@host/dbname?params + u, err := nurl.Parse(dsn) + if err == nil && u.Scheme != "" && u.Path != "" { + return strings.TrimPrefix(u.Path, "/") + } + return "" +} + +func ParsePostgres(url string) (addr string, err error) { + u, err := nurl.Parse(url) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + if u.Port() != "" { + return u.Host, nil + } + return u.Hostname() + ":5432", nil +} + +func ParseMySQL(dsn string) (addr string, err error) { + // MySQL DSN format: [username[:password]@][protocol[(address)]]/dbname[?params] + // We need to find the protocol part after @ to avoid special chars in password + + // Find @ symbol to locate where credentials end + atIndex := strings.LastIndex(dsn, "@") + var searchStart int + if atIndex >= 0 { + // Start searching for ( after @ + searchStart = atIndex + } else { + // No credentials, search from beginning + searchStart = 0 + } + + // Now find the ( and ) after the @ symbol + n := len(dsn) + i, j := -1, -1 + for k := searchStart; k < n; k++ { + if dsn[k] == '(' { + i = k + } + if dsn[k] == ')' && i >= 0 { + // Only accept ) if we've already found ( + j = k + break + } + } + if i >= 0 && j > i { + return dsn[i+1 : j], nil + } + return "", errors.New("invalid MySQL DSN") +} + +func ParseClickHouse(dsn string) (addr string, err error) { + // ClickHouse DSN formats: + // tcp://host:port?database=dbname&username=user&password=pass + // http://host:port?database=dbname + // clickhouse://host:port/database?username=user&password=pass + u, err := nurl.Parse(dsn) + if err != nil { + return "", err + } + + // Return host with port + if u.Port() != "" { + return u.Host, nil + } + + // Default ports based on scheme + switch u.Scheme { + case "tcp", "native": + return u.Hostname() + ":9000", nil + case "http": + return u.Hostname() + ":8123", nil + case "https": + return u.Hostname() + ":8443", nil + case "clickhouse": + // Default to native port + return u.Hostname() + ":9000", nil + } + + return u.Host, nil +} + +func ParseOracle(dsn string) (addr string, err error) { + // Oracle DSN formats: + // user/password@host:port/service_name + // user/password@host:port/sid + // user/password@//host:port/service_name + // oracle://user:password@host:port/service_name + + // Try URL format first + if strings.Contains(dsn, "://") { + u, err := nurl.Parse(dsn) + if err == nil && u.Host != "" { + if u.Port() != "" { + return u.Host, nil + } + return u.Hostname() + ":1521", nil // Oracle default port + } + } + + // Parse traditional Oracle format: user/password@host:port/service + atIndex := strings.Index(dsn, "@") + if atIndex < 0 { + return "", errors.New("invalid Oracle DSN") + } + + connStr := dsn[atIndex+1:] + // Remove leading // + connStr = strings.TrimPrefix(connStr, "//") + + // Extract host:port before / + slashIndex := strings.Index(connStr, "/") + var hostPort string + if slashIndex > 0 { + hostPort = connStr[:slashIndex] + } else { + hostPort = connStr + } + + // If no port specified, add default + if !strings.Contains(hostPort, ":") { + hostPort = hostPort + ":1521" + } + + return hostPort, nil +} + +func ParseSQLServer(dsn string) (addr string, err error) { + // SQL Server DSN formats: + // sqlserver://username:password@host:port?database=dbname + // server=host;port=1433;database=dbname;user id=user;password=pass + // Server=host,port;Database=dbname;User Id=user;Password=pass + + // Try URL format first + if strings.HasPrefix(dsn, "sqlserver://") || strings.HasPrefix(dsn, "mssql://") { + u, err := nurl.Parse(dsn) + if err != nil { + return "", err + } + if u.Port() != "" { + return u.Host, nil + } + return u.Hostname() + ":1433", nil // SQL Server default port + } + + // Parse connection string format (key=value pairs) + dsn = strings.ToLower(dsn) + var host, port string + + // Split by semicolon + pairs := strings.Split(dsn, ";") + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + if strings.HasPrefix(pair, "server=") { + serverVal := strings.TrimPrefix(pair, "server=") + // Handle Server=host,port format + if strings.Contains(serverVal, ",") { + parts := strings.Split(serverVal, ",") + host = parts[0] + if len(parts) > 1 { + port = parts[1] + } + } else { + host = serverVal + } + } else if strings.HasPrefix(pair, "port=") { + port = strings.TrimPrefix(pair, "port=") + } else if strings.HasPrefix(pair, "host=") { + host = strings.TrimPrefix(pair, "host=") + } + } + + if host == "" { + return "", errors.New("invalid SQL Server DSN") + } + + if port == "" { + port = "1433" // SQL Server default port + } + + return host + ":" + port, nil +} diff --git a/pkg/instrumentation/databasesql/internal/dsnparse/parse_test.go b/pkg/instrumentation/databasesql/internal/dsnparse/parse_test.go new file mode 100644 index 000000000..cd245cfd9 --- /dev/null +++ b/pkg/instrumentation/databasesql/internal/dsnparse/parse_test.go @@ -0,0 +1,159 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package dsnparse + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRegisterDSNParser(t *testing.T) { + const driver = "testdriver-register" + + // Not registered yet — should fall back to BestEffortParse. + _, err := ParseDSN(driver, "not-a-url") + assert.Error(t, err, "unregistered driver with unparseable DSN should return error") + + // Register a custom parser. + called := false + RegisterDSNParser(driver, func(dsn string) (string, error) { + called = true + return "custom-host:1234", nil + }) + + addr, err := ParseDSN(driver, "anything") + require.NoError(t, err) + assert.True(t, called, "registered parser should have been called") + assert.Equal(t, "custom-host:1234", addr) +} + +func TestRegisterDSNParser_Overwrite(t *testing.T) { + const driver = "testdriver-overwrite" + + RegisterDSNParser(driver, func(_ string) (string, error) { return "first:1111", nil }) + RegisterDSNParser(driver, func(_ string) (string, error) { return "second:2222", nil }) + + addr, err := ParseDSN(driver, "anything") + require.NoError(t, err) + assert.Equal(t, "second:2222", addr, "second registration should overwrite the first") +} + +func TestBestEffortParse(t *testing.T) { + tests := []struct { + name string + dsn string + want string + wantErr bool + }{ + { + name: "valid URL with host and port", + dsn: "somedriver://user:pass@db.example.com:9999/mydb", + want: "db.example.com:9999", + }, + { + name: "valid URL with host only", + dsn: "somedriver://db.example.com/mydb", + want: "db.example.com", + }, + { + name: "no host in URL", + dsn: "not-a-url-at-all", + wantErr: true, + }, + { + name: "empty string", + dsn: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := BestEffortParse(tt.dsn) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestParseDSN_PgxAndLibPqAliases(t *testing.T) { + const pgDSN = "postgres://user:pass@pg.example.com:5432/mydb" + + for _, driver := range []string{"pgx", "lib/pq"} { + t.Run(driver, func(t *testing.T) { + addr, err := ParseDSN(driver, pgDSN) + require.NoError(t, err) + assert.Equal(t, "pg.example.com:5432", addr, "driver %q should parse postgres DSN", driver) + }) + } +} + +func TestParseDSN_PostgresDefaultPort(t *testing.T) { + // When no port is in the DSN the parser should append :5432. + addr, err := ParseDSN("postgres", "postgres://user:pass@pg.example.com/mydb") + require.NoError(t, err) + assert.Equal(t, "pg.example.com:5432", addr) +} + +func TestParseDSN_UnknownDriverFallback(t *testing.T) { + // Unknown driver with a parseable URL should succeed via BestEffortParse. + addr, err := ParseDSN("unknown-driver", "somedb://host.example.com:9876/mydb") + require.NoError(t, err) + assert.Equal(t, "host.example.com:9876", addr) + + // Unknown driver with an unparseable DSN should return an error. + _, err = ParseDSN("unknown-driver", "not-a-url") + assert.Error(t, err) +} + +func TestParseDbName(t *testing.T) { + tests := []struct { + name string + dsn string + want string + }{ + { + name: "postgres URL", + dsn: "postgres://user:pass@host:5432/mydb", + want: "mydb", + }, + { + name: "mysql style", + dsn: "user:pass@tcp(host:3306)/mydb?charset=utf8", + want: "mydb", + }, + { + name: "mysql style no params", + dsn: "user:pass@tcp(host:3306)/mydb", + want: "mydb", + }, + { + name: "clickhouse URL", + dsn: "clickhouse://host:9000/analytics", + want: "analytics", + }, + { + name: "unparseable DSN", + dsn: "not-a-dsn", + want: "", + }, + { + name: "empty", + dsn: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, ParseDbName(tt.dsn)) + }) + } +} diff --git a/pkg/instrumentation/databasesql/parse.go b/pkg/instrumentation/databasesql/parse.go index a7cf15841..0398c5ba8 100644 --- a/pkg/instrumentation/databasesql/parse.go +++ b/pkg/instrumentation/databasesql/parse.go @@ -4,232 +4,24 @@ package db import ( - "errors" - "fmt" - nurl "net/url" - "strings" + "github.com/open-telemetry/opentelemetry-go-compile-instrumentation/pkg/instrumentation/databasesql/internal/dsnparse" ) -func ParseDbName(dsn string) string { - var name string - var err error - for i := len(dsn) - 1; i >= 0; i-- { - if dsn[i] == '/' { - dbname := dsn[i+1:] - queryIndex := strings.Index(dbname, "?") - if queryIndex > 0 { - dbname = dbname[:queryIndex] - } - if name, err = nurl.PathUnescape(dbname); err != nil { - return "" - } +// DSNParser parses a driver-specific DSN and returns the server address (host:port). +type DSNParser = dsnparse.DSNParser - break - } - } - return name +// RegisterDSNParser registers a custom DSN parser for the given driver name. +// Built-in parsers are registered automatically during package initialization. +// Calling RegisterDSNParser for an already-registered name overwrites the previous parser. +// It is safe to call from package init() functions. +func RegisterDSNParser(driverName string, parser dsnparse.DSNParser) { + dsnparse.RegisterDSNParser(driverName, parser) } -func parseDSN(driverName, dsn string) (addr string, err error) { - // TODO: need a more delegate DFA - switch driverName { - case "mysql": - return parseMySQL(dsn) - case "postgres": - fallthrough - case "postgresql": - return parsePostgres(dsn) - case "clickhouse": - return parseClickHouse(dsn) - case "sqlite3": - return "sqlite3", nil - case "godror", "oracle", "oci8", "go-oci8": - return parseOracle(dsn) - case "mssql", "sqlserver": - return parseSQLServer(dsn) - } - - return driverName, errors.New("invalid DSN") -} - -func parsePostgres(url string) (addr string, err error) { - u, err := nurl.Parse(url) - if err != nil { - return "", err - } - - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) - } - - if u.Port() != "" { - return u.Host, nil - } - return u.Hostname() + ":5432", nil -} - -func parseMySQL(dsn string) (addr string, err error) { - // MySQL DSN format: [username[:password]@][protocol[(address)]]/dbname[?params] - // We need to find the protocol part after @ to avoid special chars in password - - // Find @ symbol to locate where credentials end - atIndex := strings.LastIndex(dsn, "@") - var searchStart int - if atIndex >= 0 { - // Start searching for ( after @ - searchStart = atIndex - } else { - // No credentials, search from beginning - searchStart = 0 - } - - // Now find the ( and ) after the @ symbol - n := len(dsn) - i, j := -1, -1 - for k := searchStart; k < n; k++ { - if dsn[k] == '(' { - i = k - } - if dsn[k] == ')' && i >= 0 { - // Only accept ) if we've already found ( - j = k - break - } - } - if i >= 0 && j > i { - return dsn[i+1 : j], nil - } - return "", errors.New("invalid MySQL DSN") -} - -func parseClickHouse(dsn string) (addr string, err error) { - // ClickHouse DSN formats: - // tcp://host:port?database=dbname&username=user&password=pass - // http://host:port?database=dbname - // clickhouse://host:port/database?username=user&password=pass - u, err := nurl.Parse(dsn) - if err != nil { - return "", err - } - - // Return host with port - if u.Port() != "" { - return u.Host, nil - } - - // Default ports based on scheme - switch u.Scheme { - case "tcp", "native": - return u.Hostname() + ":9000", nil - case "http": - return u.Hostname() + ":8123", nil - case "https": - return u.Hostname() + ":8443", nil - case "clickhouse": - // Default to native port - return u.Hostname() + ":9000", nil - } - - return u.Host, nil +func parseDSN(driverName, dsn string) (string, error) { + return dsnparse.ParseDSN(driverName, dsn) } -func parseOracle(dsn string) (addr string, err error) { - // Oracle DSN formats: - // user/password@host:port/service_name - // user/password@host:port/sid - // user/password@//host:port/service_name - // oracle://user:password@host:port/service_name - - // Try URL format first - if strings.Contains(dsn, "://") { - u, err := nurl.Parse(dsn) - if err == nil && u.Host != "" { - if u.Port() != "" { - return u.Host, nil - } - return u.Hostname() + ":1521", nil // Oracle default port - } - } - - // Parse traditional Oracle format: user/password@host:port/service - atIndex := strings.Index(dsn, "@") - if atIndex < 0 { - return "", errors.New("invalid Oracle DSN") - } - - connStr := dsn[atIndex+1:] - // Remove leading // - connStr = strings.TrimPrefix(connStr, "//") - - // Extract host:port before / - slashIndex := strings.Index(connStr, "/") - var hostPort string - if slashIndex > 0 { - hostPort = connStr[:slashIndex] - } else { - hostPort = connStr - } - - // If no port specified, add default - if !strings.Contains(hostPort, ":") { - hostPort = hostPort + ":1521" - } - - return hostPort, nil -} - -func parseSQLServer(dsn string) (addr string, err error) { - // SQL Server DSN formats: - // sqlserver://username:password@host:port?database=dbname - // server=host;port=1433;database=dbname;user id=user;password=pass - // Server=host,port;Database=dbname;User Id=user;Password=pass - - // Try URL format first - if strings.HasPrefix(dsn, "sqlserver://") || strings.HasPrefix(dsn, "mssql://") { - u, err := nurl.Parse(dsn) - if err != nil { - return "", err - } - if u.Port() != "" { - return u.Host, nil - } - return u.Hostname() + ":1433", nil // SQL Server default port - } - - // Parse connection string format (key=value pairs) - dsn = strings.ToLower(dsn) - var host, port string - - // Split by semicolon - pairs := strings.Split(dsn, ";") - for _, pair := range pairs { - pair = strings.TrimSpace(pair) - if strings.HasPrefix(pair, "server=") { - serverVal := strings.TrimPrefix(pair, "server=") - // Handle Server=host,port format - if strings.Contains(serverVal, ",") { - parts := strings.Split(serverVal, ",") - host = parts[0] - if len(parts) > 1 { - port = parts[1] - } - } else { - host = serverVal - } - } else if strings.HasPrefix(pair, "port=") { - port = strings.TrimPrefix(pair, "port=") - } else if strings.HasPrefix(pair, "host=") { - host = strings.TrimPrefix(pair, "host=") - } - } - - if host == "" { - return "", errors.New("invalid SQL Server DSN") - } - - if port == "" { - port = "1433" // SQL Server default port - } - - return host + ":" + port, nil +func parseDbName(dsn string) string { + return dsnparse.ParseDbName(dsn) } diff --git a/pkg/instrumentation/databasesql/semconv/db.go b/pkg/instrumentation/databasesql/semconv/db.go index 00f4f2ec6..3b8d70be3 100644 --- a/pkg/instrumentation/databasesql/semconv/db.go +++ b/pkg/instrumentation/databasesql/semconv/db.go @@ -42,12 +42,18 @@ func DbClientRequestTraceAttrs(req DatabaseSqlRequest) []attribute.KeyValue { } switch req.DriverName { - case "mysql": + case "mysql", "mariadb": attrs = append(attrs, semconv.DBSystemNameMySQL) - case "postgres": + case "postgres", "postgresql", "pgx", "lib/pq": attrs = append(attrs, semconv.DBSystemNamePostgreSQL) case "sqlite3": attrs = append(attrs, semconv.DBSystemNameSQLite) + case "clickhouse": + attrs = append(attrs, semconv.DBSystemNameClickHouse) + case "godror", "oracle", "oci8", "go-oci8": + attrs = append(attrs, semconv.DBSystemNameOracleDB) + case "mssql", "sqlserver": + attrs = append(attrs, semconv.DBSystemNameMicrosoftSQLServer) default: attrs = append(attrs, semconv.DBSystemNameOtherSQL) } diff --git a/pkg/instrumentation/databasesql/semconv/db_test.go b/pkg/instrumentation/databasesql/semconv/db_test.go index c6526d804..e049a612f 100644 --- a/pkg/instrumentation/databasesql/semconv/db_test.go +++ b/pkg/instrumentation/databasesql/semconv/db_test.go @@ -78,7 +78,7 @@ func TestDbClientRequestTraceAttrs(t *testing.T) { }, }, { - name: "unknown driver falls back to other_sql", + name: "clickhouse driver maps to clickhouse system name", req: DatabaseSqlRequest{ OpType: "SELECT", Sql: "SELECT 1", @@ -88,7 +88,7 @@ func TestDbClientRequestTraceAttrs(t *testing.T) { DbName: "default", }, expected: map[string]interface{}{ - "db.system.name": "other_sql", + "db.system.name": "clickhouse", "db.operation.name": "SELECT", "db.namespace": "default", "server.address": "localhost", @@ -97,6 +97,26 @@ func TestDbClientRequestTraceAttrs(t *testing.T) { "db.query.text": "SELECT 1", }, }, + { + name: "unknown driver falls back to other_sql", + req: DatabaseSqlRequest{ + OpType: "SELECT", + Sql: "SELECT 1", + Endpoint: "localhost:9999", + DriverName: "exoticdb", + Dsn: "exoticdb://localhost:9999/mydb", + DbName: "mydb", + }, + expected: map[string]interface{}{ + "db.system.name": "other_sql", + "db.operation.name": "SELECT", + "db.namespace": "mydb", + "server.address": "localhost", + "server.port": int64(9999), + "network.transport": "tcp", + "db.query.text": "SELECT 1", + }, + }, { name: "empty fields", req: DatabaseSqlRequest{