diff --git a/pkg/executor/BUILD.bazel b/pkg/executor/BUILD.bazel index 1b9c39f1144d4..4f0b9146e178a 100644 --- a/pkg/executor/BUILD.bazel +++ b/pkg/executor/BUILD.bazel @@ -466,6 +466,7 @@ go_test( "//pkg/testkit/testfailpoint", "//pkg/testkit/testmain", "//pkg/testkit/testsetup", + "//pkg/testkit/testutil", "//pkg/types", "//pkg/util", "//pkg/util/benchdaily", diff --git a/pkg/executor/adapter_slow_log.go b/pkg/executor/adapter_slow_log.go index 9c4693f598ac5..ada4f3b3c7d19 100644 --- a/pkg/executor/adapter_slow_log.go +++ b/pkg/executor/adapter_slow_log.go @@ -264,6 +264,9 @@ func SetSlowLogItems(a *ExecStmt, txnTS uint64, hasMoreResults bool, items *vari items.CPUUsages = sessVars.SQLCPUUsages.GetCPUUsages() items.StorageKV = stmtCtx.IsTiKV.Load() items.StorageMPP = stmtCtx.IsTiFlash.Load() + if sessVars.ConnectionInfo != nil && len(sessVars.ConnectionInfo.Attributes) > 0 { + items.SessionConnectAttrs = sessVars.ConnectionInfo.Attributes + } if a.retryCount > 0 { items.ExecRetryTime = items.TimeTotal - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) diff --git a/pkg/executor/cluster_table_test.go b/pkg/executor/cluster_table_test.go index 62acd59b4bac5..bea00ff90176c 100644 --- a/pkg/executor/cluster_table_test.go +++ b/pkg/executor/cluster_table_test.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/server" "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/testutil" "github.com/pingcap/tidb/pkg/util" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -354,3 +355,38 @@ func removeFiles(t *testing.T, fileNames []string) { require.NoError(t, os.Remove(fileName)) } } + +func TestClusterTableSlowQuerySessionConnectAttrs(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + srv := createRPCServer(t, dom) + defer srv.Stop() + + logData := ` +# Time: 2024-01-15T10:00:00.000000+08:00 +# Txn_start_ts: 123456789 +# User@Host: root[root] @ localhost [127.0.0.1] +# Query_time: 0.5 +# Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 +# Is_internal: false +# Succ: true +` + testutil.DefaultSessionConnectAttrsSlowLogLine() + ` +select * from t;` + fileName := "tidb-slow-query-attrs.log" + prepareLogs(t, []string{logData}, []string{fileName}) + defer removeFiles(t, []string{fileName}) + + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.Log.SlowQueryFile = fileName + }) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use information_schema") + + // Verify Session_connect_attrs column is present in cluster_slow_query as well. + clusterRows := tk.MustQuery("select Session_connect_attrs from information_schema.cluster_slow_query " + + "where time > '2024-01-01 00:00:00' and query = 'select * from t;'").Rows() + require.Len(t, clusterRows, 1) + clusterAttrsStr := clusterRows[0][0].(string) + testutil.RequireContainsDefaultSessionConnectAttrs(t, clusterAttrsStr) +} diff --git a/pkg/executor/infoschema_reader_test.go b/pkg/executor/infoschema_reader_test.go index 7b686693603f1..a3477bcaa9b0e 100644 --- a/pkg/executor/infoschema_reader_test.go +++ b/pkg/executor/infoschema_reader_test.go @@ -665,7 +665,7 @@ func TestColumnTable(t *testing.T) { testkit.RowsWithSep("|", "test|tbl1|col_2")) tk.MustQuery(`select count(*) from information_schema.columns;`).Check( - testkit.RowsWithSep("|", "5017")) + testkit.RowsWithSep("|", "5019")) } func TestIndexUsageTable(t *testing.T) { diff --git a/pkg/executor/slow_query.go b/pkg/executor/slow_query.go index 38fe968474a8f..3237e8626e7d6 100644 --- a/pkg/executor/slow_query.go +++ b/pkg/executor/slow_query.go @@ -447,7 +447,7 @@ func (e *slowQueryRetriever) parseSlowLog(ctx context.Context, sctx sessionctx.C startTime := time.Now() var logs [][]string var err error - if !e.extractor.Desc { + if e.extractor == nil || !e.extractor.Desc { logs, err = e.getBatchLog(ctx, reader, &offset, logNum) } else { logs, err = e.getBatchLogForReversedScan(ctx, reader, &offset, logNum) @@ -703,6 +703,12 @@ func (e *slowQueryRetriever) parseLog(ctx context.Context, sctx sessionctx.Conte } else if strings.HasPrefix(line, variable.SlowLogWarnings) { line = line[len(variable.SlowLogWarnings+variable.SlowLogSpaceMarkStr):] valid = e.setColumnValue(sctx, row, tz, variable.SlowLogWarnings, line, e.checker, fileLine) + } else if strings.HasPrefix(line, variable.SlowLogSessionConnectAttrs+variable.SlowLogSpaceMarkStr) { + line = line[len(variable.SlowLogSessionConnectAttrs+variable.SlowLogSpaceMarkStr):] + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogSessionConnectAttrs, line, e.checker, fileLine) + } else if strings.HasPrefix(line, variable.SlowLogDBStr+variable.SlowLogSpaceMarkStr) { + line = line[len(variable.SlowLogDBStr+variable.SlowLogSpaceMarkStr):] + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogDBStr, line, e.checker, fileLine) } else { fields, values := splitByColon(line) for i := 0; i < len(fields); i++ { @@ -885,6 +891,18 @@ func getColumnValueFactoryByName(colName string, columnIdx int) (slowQueryColumn row[columnIdx] = types.NewDatum(v) return true, nil }, nil + case variable.SlowLogSessionConnectAttrs: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { + if len(value) == 0 { + return true, nil + } + bj, err := types.ParseBinaryJSONFromString(value) + if err != nil { + return false, err + } + row[columnIdx] = types.NewDatum(bj) + return true, nil + }, nil } return nil, nil } diff --git a/pkg/executor/slow_query_sql_test.go b/pkg/executor/slow_query_sql_test.go index 10cf326484679..1dc797bc9b03f 100644 --- a/pkg/executor/slow_query_sql_test.go +++ b/pkg/executor/slow_query_sql_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/testkit/external" "github.com/pingcap/tidb/pkg/testkit/testdata" + "github.com/pingcap/tidb/pkg/testkit/testutil" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/stretchr/testify/require" ) @@ -540,3 +541,98 @@ func TestStorageEnginesInSlowQuery(t *testing.T) { "where query like 'select%tablesample%;'"). Check(testkit.Rows("1 0")) } + +func TestSessionConnectAttrsInSlowQuery(t *testing.T) { + originCfg := config.GetGlobalConfig() + newCfg := *originCfg + f, err := os.CreateTemp("", "tidb-slow-*.log") + require.NoError(t, err) + _, err = f.WriteString(`# Time: 2024-01-15T10:00:00.000000+08:00 +# Txn_start_ts: 123456789 +# User@Host: root[root] @ localhost [127.0.0.1] +# Query_time: 0.5 +# Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 +# Is_internal: false +# Succ: true +` + testutil.DefaultSessionConnectAttrsSlowLogLine() + ` +select * from t; +`) + require.NoError(t, err) + require.NoError(t, f.Close()) + newCfg.Log.SlowQueryFile = f.Name() + config.StoreGlobalConfig(&newCfg) + defer func() { + config.StoreGlobalConfig(originCfg) + require.NoError(t, os.Remove(newCfg.Log.SlowQueryFile)) + }() + require.NoError(t, logutil.InitLogger(newCfg.Log.ToLogConfig())) + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("set @@time_zone='+08:00'") + tk.MustExec(fmt.Sprintf("set @@tidb_slow_query_file='%v'", f.Name())) + + // Verify Session_connect_attrs column is present and returns the correct JSON value. + rows := tk.MustQuery("select Session_connect_attrs from information_schema.slow_query " + + "where query = 'select * from t;'").Rows() + require.Len(t, rows, 1) + attrsStr := rows[0][0].(string) + testutil.RequireContainsDefaultSessionConnectAttrs(t, attrsStr) + + // Verify individual keys are accessible via JSON_EXTRACT. + tk.MustQuery("select JSON_EXTRACT(Session_connect_attrs, '$._client_name') from information_schema.slow_query " + + "where query = 'select * from t;'"). + Check(testkit.Rows(`"Go-MySQL-Driver"`)) + tk.MustQuery("select JSON_EXTRACT(Session_connect_attrs, '$.app_name') from information_schema.slow_query " + + "where query = 'select * from t;'"). + Check(testkit.Rows(`"test_app"`)) +} + +func TestSessionConnectAttrsMissingAndTruncatedInSlowQuery(t *testing.T) { + originCfg := config.GetGlobalConfig() + newCfg := *originCfg + f, err := os.CreateTemp("", "tidb-slow-*.log") + require.NoError(t, err) + _, err = f.WriteString(`# Time: 2024-01-15T10:00:00.000000+08:00 +# Txn_start_ts: 123456789 +# User@Host: root[root] @ localhost [127.0.0.1] +# Query_time: 0.5 +# Digest: 1111111111111111111111111111111111111111111111111111111111111111 +# Is_internal: false +# Succ: true +select * from t_no_attrs; +# Time: 2024-01-15T10:00:01.000000+08:00 +# Txn_start_ts: 123456790 +# User@Host: root[root] @ localhost [127.0.0.1] +# Query_time: 0.6 +# Digest: 2222222222222222222222222222222222222222222222222222222222222222 +# Is_internal: false +# Succ: true +# Session_connect_attrs: {"_truncated":"4","app_name":"trunc_case"} +select * from t_truncated; +`) + require.NoError(t, err) + require.NoError(t, f.Close()) + newCfg.Log.SlowQueryFile = f.Name() + config.StoreGlobalConfig(&newCfg) + defer func() { + config.StoreGlobalConfig(originCfg) + require.NoError(t, os.Remove(newCfg.Log.SlowQueryFile)) + }() + require.NoError(t, logutil.InitLogger(newCfg.Log.ToLogConfig())) + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + tk.MustExec("set @@time_zone='+08:00'") + tk.MustExec(fmt.Sprintf("set @@tidb_slow_query_file='%v'", f.Name())) + + // Missing Session_connect_attrs should parse to JSON null-like empty behavior. + tk.MustQuery("select Session_connect_attrs = cast('null' as json), JSON_EXTRACT(Session_connect_attrs, '$._truncated') is null from information_schema.slow_query " + + "where query = 'select * from t_no_attrs;' "). + Check(testkit.Rows("1 1")) + + // Truncation metadata key should be preserved and queryable from JSON. + tk.MustQuery("select JSON_UNQUOTE(JSON_EXTRACT(Session_connect_attrs, '$._truncated')), JSON_UNQUOTE(JSON_EXTRACT(Session_connect_attrs, '$.app_name')) from information_schema.slow_query " + + "where query = 'select * from t_truncated;' "). + Check(testkit.Rows("4 trunc_case")) +} diff --git a/pkg/executor/slow_query_test.go b/pkg/executor/slow_query_test.go index be21032cb37b3..766642ff00ee5 100644 --- a/pkg/executor/slow_query_test.go +++ b/pkg/executor/slow_query_test.go @@ -35,6 +35,7 @@ import ( plannercore "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/testkit/testutil" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/logutil" @@ -178,7 +179,7 @@ select * from t;` `0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,65536,0,0,0,0,0,,` + `Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2 Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2 Cop_backoff_rpcTiKV_total_times: 200 Cop_backoff_rpcTiKV_total_time: 0.2 Cop_backoff_rpcTiKV_max_time: 0.2 Cop_backoff_rpcTiKV_max_addr: 127.0.0.1 Cop_backoff_rpcTiKV_avg_time: 0.2 Cop_backoff_rpcTiKV_p90_time: 0.2,` + `0,0,1,0,1,1,0,default,2.158,2.123,0.05,0.01,0.021,1,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + - `,update t set i = 1;,select * from t;` + `,update t set i = 1;,null,select * from t;` require.Equal(t, expectRecordString, recordString) // Issue 20928 @@ -201,7 +202,7 @@ select * from t;` `0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,65536,0,0,0,0,0,,` + `Cop_backoff_regionMiss_total_times: 200 Cop_backoff_regionMiss_total_time: 0.2 Cop_backoff_regionMiss_max_time: 0.2 Cop_backoff_regionMiss_max_addr: 127.0.0.1 Cop_backoff_regionMiss_avg_time: 0.2 Cop_backoff_regionMiss_p90_time: 0.2 Cop_backoff_rpcPD_total_times: 200 Cop_backoff_rpcPD_total_time: 0.2 Cop_backoff_rpcPD_max_time: 0.2 Cop_backoff_rpcPD_max_addr: 127.0.0.1 Cop_backoff_rpcPD_avg_time: 0.2 Cop_backoff_rpcPD_p90_time: 0.2 Cop_backoff_rpcTiKV_total_times: 200 Cop_backoff_rpcTiKV_total_time: 0.2 Cop_backoff_rpcTiKV_max_time: 0.2 Cop_backoff_rpcTiKV_max_addr: 127.0.0.1 Cop_backoff_rpcTiKV_avg_time: 0.2 Cop_backoff_rpcTiKV_p90_time: 0.2,` + `0,0,1,0,1,1,0,default,2.158,2.123,0.05,0.01,0.021,1,1,,60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4,` + - `,update t set i = 1;,select * from t;` + `,update t set i = 1;,null,select * from t;` require.Equal(t, expectRecordString, recordString) // fix sql contain '# ' bug @@ -254,6 +255,57 @@ select * from t; require.Equal(t, warnings[0].Err.Error(), "Parse slow log at line 2, failed field is Succ, failed value is abc, error is strconv.ParseBool: parsing \"abc\": invalid syntax") } +func TestParseSlowLogSessionConnectAttrs(t *testing.T) { + // Slow log entry that includes Session_connect_attrs JSON. + slowLogStr := `# Time: 2019-04-28T15:24:04.309074+08:00 +# Txn_start_ts: 405888132465033227 +# User@Host: root[root] @ localhost [127.0.0.1] +# Query_time: 0.216905 +# Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 +# Is_internal: false +# Succ: true +` + testutil.DefaultSessionConnectAttrsSlowLogLine() + ` +# Prev_stmt: begin; +select * from t; +` + loc, err := time.LoadLocation("Asia/Shanghai") + require.NoError(t, err) + ctx := mock.NewContext() + ctx.ResetSessionAndStmtTimeZone(loc) + + // Use the retriever directly (without initialize) to avoid reading + // from actual slow log files on disk, which can produce extra rows. + retriever, err := newSlowQueryRetriever() + require.NoError(t, err) + retriever.columnValueFactoryMap = make(map[string]slowQueryColumnValueFactory, len(retriever.outputCols)) + for idx, col := range retriever.outputCols { + factory, err := getColumnValueFactoryByName(col.Name.O, idx) + require.NoError(t, err) + require.NotNil(t, factory, "column %s should have a factory", col.Name.O) + retriever.columnValueFactoryMap[col.Name.O] = factory + } + + reader := bufio.NewReader(bytes.NewBufferString(slowLogStr)) + rows, err := parseLog(retriever, ctx, reader) + require.NoError(t, err) + require.Len(t, rows, 1) + + // Find the Session_connect_attrs column. + colIdx := -1 + for i, col := range retriever.outputCols { + if col.Name.L == strings.ToLower(variable.SlowLogSessionConnectAttrs) { + colIdx = i + break + } + } + require.NotEqual(t, -1, colIdx, "Session_connect_attrs column should exist") + + // Verify the parsed JSON contains the expected keys. + bj := rows[0][colIdx].GetMysqlJSON() + bjStr := bj.String() + testutil.RequireContainsDefaultSessionConnectAttrs(t, bjStr) +} + // It changes variable.MaxOfMaxAllowedPacket, so must be stayed in SerialSuite. func TestParseSlowLogFileSerial(t *testing.T) { loc, err := time.LoadLocation("Asia/Shanghai") diff --git a/pkg/infoschema/tables.go b/pkg/infoschema/tables.go index 748f11086b854..7e1810ed5dbc8 100644 --- a/pkg/infoschema/tables.go +++ b/pkg/infoschema/tables.go @@ -971,6 +971,7 @@ var slowQueryCols = []columnInfo{ {name: variable.SlowLogPlanDigest, tp: mysql.TypeVarchar, size: 128}, {name: variable.SlowLogBinaryPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, {name: variable.SlowLogPrevStmt, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: variable.SlowLogSessionConnectAttrs, tp: mysql.TypeJSON, size: types.UnspecifiedLength}, {name: variable.SlowLogQuerySQLStr, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, } diff --git a/pkg/infoschema/test/clustertablestest/BUILD.bazel b/pkg/infoschema/test/clustertablestest/BUILD.bazel index d61f60d5edf02..e163d2cd2c30f 100644 --- a/pkg/infoschema/test/clustertablestest/BUILD.bazel +++ b/pkg/infoschema/test/clustertablestest/BUILD.bazel @@ -38,6 +38,7 @@ go_test( "//pkg/testkit", "//pkg/testkit/external", "//pkg/testkit/testsetup", + "//pkg/testkit/testutil", "//pkg/types", "//pkg/util", "//pkg/util/dbterror/exeerrors", diff --git a/pkg/infoschema/test/clustertablestest/cluster_tables_test.go b/pkg/infoschema/test/clustertablestest/cluster_tables_test.go index ddb50e7a9bf42..b07cd9d01c6a6 100644 --- a/pkg/infoschema/test/clustertablestest/cluster_tables_test.go +++ b/pkg/infoschema/test/clustertablestest/cluster_tables_test.go @@ -50,6 +50,7 @@ import ( "github.com/pingcap/tidb/pkg/store/mockstore/unistore" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/testkit/external" + "github.com/pingcap/tidb/pkg/testkit/testutil" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" "github.com/pingcap/tidb/pkg/util/logutil" @@ -293,6 +294,55 @@ func TestSelectClusterTable(t *testing.T) { tk.MustQuery("select instance from `CLUSTER_SLOW_QUERY` where time='2019-02-12 19:33:56.571953'").Check(testkit.Rows(instanceAddr)) } +func TestClusterSlowQuerySessionConnectAttrs(t *testing.T) { + // setup suite + s := new(clusterTablesSuite) + s.store, s.dom = testkit.CreateMockStoreAndDomain(t) + s.rpcserver, s.listenAddr = s.setUpRPCService(t, "127.0.0.1:0", nil) + s.httpServer, s.mockAddr = s.setUpMockPDHTTPServer() + s.startTime = time.Now() + defer s.httpServer.Close() + defer s.rpcserver.Stop() + + f, err := os.CreateTemp("", "tidb-cluster-slow-*.log") + require.NoError(t, err) + _, err = f.WriteString(`# Time: 2024-01-15T10:00:00.000000+08:00 +# Txn_start_ts: 123456789 +# User@Host: root[root] @ localhost [127.0.0.1] +# Query_time: 0.5 +# Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 +# Is_internal: false +# Succ: true +` + testutil.DefaultSessionConnectAttrsSlowLogLine() + ` +select * from t; +`) + require.NoError(t, err) + require.NoError(t, f.Close()) + defer func() { require.NoError(t, os.Remove(f.Name())) }() + + defer config.RestoreFunc()() + config.UpdateGlobal(func(conf *config.Config) { + conf.Log.SlowQueryFile = f.Name() + }) + + tk := s.newTestKitWithRoot(t) + tk.MustExec("use information_schema") + tk.MustExec("set time_zone = '+08:00';") + + rows := tk.MustQuery("select Session_connect_attrs from information_schema.cluster_slow_query " + + "where query = 'select * from t;'").Rows() + require.Len(t, rows, 1) + attrsStr := rows[0][0].(string) + testutil.RequireContainsDefaultSessionConnectAttrs(t, attrsStr) + + tk.MustQuery("select JSON_EXTRACT(Session_connect_attrs, '$._client_name') from information_schema.cluster_slow_query " + + "where query = 'select * from t;'"). + Check(testkit.Rows(`"Go-MySQL-Driver"`)) + tk.MustQuery("select JSON_EXTRACT(Session_connect_attrs, '$.app_name') from information_schema.cluster_slow_query " + + "where query = 'select * from t;'"). + Check(testkit.Rows(`"test_app"`)) +} + func TestSelectClusterTablePrivilege(t *testing.T) { // setup suite s := new(clusterTablesSuite) diff --git a/pkg/infoschema/test/clustertablestest/tables_test.go b/pkg/infoschema/test/clustertablestest/tables_test.go index 58c413f15e579..3a4405961b9bf 100644 --- a/pkg/infoschema/test/clustertablestest/tables_test.go +++ b/pkg/infoschema/test/clustertablestest/tables_test.go @@ -460,6 +460,7 @@ func TestSlowQuery(t *testing.T) { "60e9378c746d9a2be1c791047e008967cf252eb6de9167ad3aa6098fa2d523f4", "", "update t set i = 2;", + "null", "select * from t_slim;", }, {"2021-09-08 14:39:54.506967", @@ -544,6 +545,7 @@ func TestSlowQuery(t *testing.T) { "", "", "", + "null", "INSERT INTO ...;", }, } diff --git a/pkg/server/conn_test.go b/pkg/server/conn_test.go index 2bc276efa0400..677d2169bd632 100644 --- a/pkg/server/conn_test.go +++ b/pkg/server/conn_test.go @@ -307,6 +307,83 @@ func TestParseHandshakeResponse(t *testing.T) { require.Equal(t, "test", p.DBName) } +func encodeLengthEncodedIntForHandshake(v uint64) []byte { + switch { + case v < 251: + return []byte{byte(v)} + case v < 1<<16: + return []byte{0xfc, byte(v), byte(v >> 8)} + case v < 1<<24: + return []byte{0xfd, byte(v), byte(v >> 8), byte(v >> 16)} + default: + return []byte{0xfe, byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), byte(v >> 32), byte(v >> 40), byte(v >> 48), byte(v >> 56)} + } +} + +func buildHandshakeResponsePacket(capability uint32, attrsPayload []byte, attrsLenOverride *uint64) []byte { + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, capability) + binary.Write(&buf, binary.LittleEndian, uint32(0)) // max packet size + buf.WriteByte(0) // collation + buf.Write(make([]byte, 23)) // reserved + + buf.WriteString("root") + buf.WriteByte(0) // user null-terminated + buf.WriteByte(0) // auth null-terminated (legacy path) + + if capability&mysql.ClientConnectAtts > 0 { + attrsLen := uint64(len(attrsPayload)) + if attrsLenOverride != nil { + attrsLen = *attrsLenOverride + } + buf.Write(encodeLengthEncodedIntForHandshake(attrsLen)) + buf.Write(attrsPayload) + } + + return buf.Bytes() +} + +func TestHandshakeResponseCompatibilityAndFailurePaths(t *testing.T) { + t.Run("legacy client without connect attrs capability", func(t *testing.T) { + data := buildHandshakeResponsePacket(mysql.ClientProtocol41, nil, nil) + + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + require.Equal(t, "root", p.User) + require.Empty(t, p.Attrs) + }) + + t.Run("malformed connect attrs length declaration", func(t *testing.T) { + attrsPayload := []byte{2, 'a', 'b', 2, 'c', 'd'} + declaredLen := uint64(len(attrsPayload) + 3) + data := buildHandshakeResponsePacket(mysql.ClientProtocol41|mysql.ClientConnectAtts, attrsPayload, &declaredLen) + + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.ErrorIs(t, err, mysql.ErrMalformPacket) + }) + + t.Run("oversize connect attrs declaration rejected", func(t *testing.T) { + declaredLen := uint64(1<<20 + 1) + data := buildHandshakeResponsePacket(mysql.ClientProtocol41|mysql.ClientConnectAtts, nil, &declaredLen) + + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.Error(t, err) + require.Contains(t, err.Error(), "connection refused: session connection attributes exceed the 1 MiB hard limit") + }) +} + func TestIssue1768(t *testing.T) { // this data is from captured handshake packet, using mysql client. // TiDB should handle authorization correctly, even mysql client set @@ -2278,3 +2355,362 @@ func TestIssue54335(t *testing.T) { _ = cc.handleQuery(context.Background(), "select /*+ MAX_EXECUTION_TIME(1)*/ * FROM testTable2;") } } + +func TestParseHandshakeAttrsTruncation(t *testing.T) { + // Save and restore global atomic counters. + origSize := variable.ConnectAttrsSize.Load() + origLongest := variable.ConnectAttrsLongestSeen.Load() + origLost := variable.ConnectAttrsLost.Load() + defer func() { + variable.ConnectAttrsSize.Store(origSize) + variable.ConnectAttrsLongestSeen.Store(origLongest) + variable.ConnectAttrsLost.Store(origLost) + }() + + t.Run("exceeds limit truncation", func(t *testing.T) { + variable.ConnectAttrsSize.Store(5) // very small limit + variable.ConnectAttrsLongestSeen.Store(0) + variable.ConnectAttrsLost.Store(0) + + // Construct payload + // Capability: ClientProtocol41 | ClientConnectAtts + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + + var buf bytes.Buffer + // Header (32 bytes) + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) // MaxPacketSize + buf.WriteByte(0) // Collation + buf.Write(make([]byte, 23)) // Reserved + + // Body + buf.WriteString("root") + buf.WriteByte(0) // User null-term + + buf.WriteByte(0) // Auth null-term + + // Attrs: "ab":"cd" (4), "ef":"gh" (4). Total = 8. Limit = 5. + attrsBuf := bytes.NewBuffer(nil) + // K1 + attrsBuf.WriteByte(2) + attrsBuf.WriteString("ab") + // V1 + attrsBuf.WriteByte(2) + attrsBuf.WriteString("cd") + // K2 + attrsBuf.WriteByte(2) + attrsBuf.WriteString("ef") + // V2 + attrsBuf.WriteByte(2) + attrsBuf.WriteString("gh") + + attrsBytes := attrsBuf.Bytes() + buf.WriteByte(byte(len(attrsBytes))) // 12 bytes total length (includes len enc overhead or just payload?) + // ParseLengthEncodedInt parses the integer. It returns the integer value. + // HandshakeResponseBody uses this value as the length of the *bytes* to consume for attributes. + // The bytes consumed are then passed to parseAttrs. + // parseAttrs expects `[len][str][len][str]`. + // So `attrsBytes` is correct. + buf.Write(attrsBytes) + + data := buf.Bytes() + + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + + // Check truncation + require.Len(t, p.Attrs, 2) + require.Equal(t, "cd", p.Attrs["ab"]) + val, ok := p.Attrs["_truncated"] + require.True(t, ok) + require.Equal(t, "4", val) + + require.Equal(t, int64(1), variable.ConnectAttrsLost.Load()) + require.Equal(t, int64(8), variable.ConnectAttrsLongestSeen.Load()) // 4+4=8 + }) + + t.Run("limit 0 disables collection", func(t *testing.T) { + variable.ConnectAttrsSize.Store(0) + variable.ConnectAttrsLongestSeen.Store(0) + variable.ConnectAttrsLost.Store(0) + + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) + buf.WriteByte(0) + buf.Write(make([]byte, 23)) + buf.WriteString("root") + buf.WriteByte(0) + buf.WriteByte(0) + + attrsBuf := bytes.NewBuffer(nil) + attrsBuf.WriteByte(2) + attrsBuf.WriteString("ab") + attrsBuf.WriteByte(2) + attrsBuf.WriteString("cd") + attrsBytes := attrsBuf.Bytes() + buf.WriteByte(byte(len(attrsBytes))) + buf.Write(attrsBytes) + data := buf.Bytes() + + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + + require.Len(t, p.Attrs, 0) + require.Equal(t, int64(0), variable.ConnectAttrsLost.Load()) + require.Equal(t, int64(0), variable.ConnectAttrsLongestSeen.Load()) + }) + + t.Run("limit 65536 acceptance", func(t *testing.T) { + variable.ConnectAttrsSize.Store(65536) + variable.ConnectAttrsLongestSeen.Store(0) + variable.ConnectAttrsLost.Store(0) + + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) + buf.WriteByte(0) + buf.Write(make([]byte, 23)) + buf.WriteString("root") + buf.WriteByte(0) + buf.WriteByte(0) + + attrsBuf := bytes.NewBuffer(nil) + attrsBuf.WriteByte(2) + attrsBuf.WriteString("ab") + attrsBuf.WriteByte(2) + attrsBuf.WriteString("cd") + attrsBytes := attrsBuf.Bytes() + buf.WriteByte(byte(len(attrsBytes))) + buf.Write(attrsBytes) + data := buf.Bytes() + + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + + require.Len(t, p.Attrs, 1) + require.Equal(t, "cd", p.Attrs["ab"]) + require.Equal(t, int64(0), variable.ConnectAttrsLost.Load()) + require.Equal(t, int64(4), variable.ConnectAttrsLongestSeen.Load()) + }) + + t.Run("limit 1MiB rejection", func(t *testing.T) { + // Construct payload declaring > 1 MiB of attributes. + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + + var buf bytes.Buffer + // Header (32 bytes) + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) + buf.WriteByte(0) + buf.Write(make([]byte, 23)) + // Body + buf.WriteString("root") + buf.WriteByte(0) + buf.WriteByte(0) + + // Encode a length-encoded integer of 1<<20 + 1 (1 MiB + 1 byte). + // 0xfd prefix = 3-byte little-endian integer (range 65536–16777215). + // 1048577 = 0x100001 → LE bytes: 0x01, 0x00, 0x10. + buf.WriteByte(0xfd) + buf.WriteByte(0x01) + buf.WriteByte(0x00) + buf.WriteByte(0x10) + + // The 1 MiB check fires before the bounds-check on + // data[offset : offset+num], so no panic is expected. + + data := buf.Bytes() + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.Error(t, err) + require.Contains(t, err.Error(), "connection refused: session connection attributes exceed the 1 MiB hard limit") + }) + + t.Run("no limit", func(t *testing.T) { + variable.ConnectAttrsSize.Store(-1) // -1 means no limit up to 64KB + variable.ConnectAttrsLongestSeen.Store(0) + variable.ConnectAttrsLost.Store(0) + + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + + var buf bytes.Buffer + // Header (32 bytes) + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) + buf.WriteByte(0) + buf.Write(make([]byte, 23)) + // Body + buf.WriteString("root") + buf.WriteByte(0) + buf.WriteByte(0) + + // Attrs: "ab":"cd", "ef":"gh". Total = 8. + attrsBuf := bytes.NewBuffer(nil) + attrsBuf.WriteByte(2) + attrsBuf.WriteString("ab") + attrsBuf.WriteByte(2) + attrsBuf.WriteString("cd") + attrsBuf.WriteByte(2) + attrsBuf.WriteString("ef") + attrsBuf.WriteByte(2) + attrsBuf.WriteString("gh") + + attrsBytes := attrsBuf.Bytes() + buf.WriteByte(byte(len(attrsBytes))) + buf.Write(attrsBytes) + + data := buf.Bytes() + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + + // All attrs should be accepted, no truncation. + require.Len(t, p.Attrs, 2) + require.Equal(t, "cd", p.Attrs["ab"]) + require.Equal(t, "gh", p.Attrs["ef"]) + _, hasTruncated := p.Attrs["_truncated"] + require.False(t, hasTruncated) + + require.Equal(t, int64(0), variable.ConnectAttrsLost.Load()) + require.Equal(t, int64(8), variable.ConnectAttrsLongestSeen.Load()) + }) + + t.Run("longest_seen not updated for large payloads", func(t *testing.T) { + // Attrs >= 64KB should NOT update LongestSeen. + variable.ConnectAttrsSize.Store(-1) // Limit mapped to 65536 max internally + variable.ConnectAttrsLongestSeen.Store(100) + variable.ConnectAttrsLost.Store(0) + + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) + buf.WriteByte(0) + buf.Write(make([]byte, 23)) + buf.WriteString("root") + buf.WriteByte(0) + buf.WriteByte(0) + + // Build attrs: one key "k" with a 70000-byte value (total > 64KB). + attrsBuf := bytes.NewBuffer(nil) + attrsBuf.WriteByte(1) // key length = 1 + attrsBuf.WriteByte('k') + // value length = 70000 → needs 3-byte length-encoded int: 0xfd + 3 LE bytes + valLen := 70000 + attrsBuf.WriteByte(0xfd) + attrsBuf.WriteByte(byte(valLen)) + attrsBuf.WriteByte(byte(valLen >> 8)) + attrsBuf.WriteByte(byte(valLen >> 16)) + attrsBuf.Write(make([]byte, valLen)) // value payload + + attrsBytes := attrsBuf.Bytes() + // Encode overall attrs length as 3-byte length-encoded int. + attrsLen := len(attrsBytes) + buf.WriteByte(0xfd) + buf.WriteByte(byte(attrsLen)) + buf.WriteByte(byte(attrsLen >> 8)) + buf.WriteByte(byte(attrsLen >> 16)) + buf.Write(attrsBytes) + + data := buf.Bytes() + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + + // Attrs are truncated (totalSize=70001 > effectiveLimit=65536, because + // limit=-1 maps to effectiveLimit=65536 internally), but LongestSeen + // should NOT be updated because totalSize >= 64KB. + require.Equal(t, int64(100), variable.ConnectAttrsLongestSeen.Load(), + "LongestSeen should not be updated for payloads >= 64KB") + }) + + t.Run("underscore-prefixed attrs include reserved key", func(t *testing.T) { + // Keep underscore-prefixed attributes from clients for MySQL parity and + // observability. The client-provided "_truncated" is retained when no + // truncation occurs; if truncation happens, server may overwrite it. + variable.ConnectAttrsSize.Store(-1) // Limit mapped to 65536 max internally + variable.ConnectAttrsLongestSeen.Store(0) + variable.ConnectAttrsLost.Store(0) + + var clientCap uint32 = mysql.ClientProtocol41 | mysql.ClientConnectAtts + + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, clientCap) + binary.Write(&buf, binary.LittleEndian, uint32(0)) + buf.WriteByte(0) + buf.Write(make([]byte, 23)) + buf.WriteString("root") + buf.WriteByte(0) + buf.WriteByte(0) + + // Attrs: + // "_client_name":"Go-MySQL-Driver" → must be kept + // "_os":"linux" → must be kept + // "_program_name":"mysql" → must be kept + // "_custom":"val" → must be kept + // "_truncated":"fake" → kept (may be overwritten by server on truncation) + // "app_name":"myapp" → must be kept + attrsBuf := bytes.NewBuffer(nil) + for _, kv := range [][2]string{ + {"_client_name", "Go-MySQL-Driver"}, + {"_os", "linux"}, + {"_program_name", "mysql"}, + {"_custom", "val"}, + {"_truncated", "fake"}, + {"app_name", "myapp"}, + } { + attrsBuf.WriteByte(byte(len(kv[0]))) + attrsBuf.WriteString(kv[0]) + attrsBuf.WriteByte(byte(len(kv[1]))) + attrsBuf.WriteString(kv[1]) + } + + attrsBytes := attrsBuf.Bytes() + buf.WriteByte(byte(len(attrsBytes))) + buf.Write(attrsBytes) + + data := buf.Bytes() + var p handshake.Response41 + offset, err := parse.HandshakeResponseHeader(context.Background(), &p, data) + require.NoError(t, err) + + err = parse.HandshakeResponseBody(context.Background(), &p, data, offset) + require.NoError(t, err) + + // All keys are kept. + require.Len(t, p.Attrs, 6) + require.Equal(t, "Go-MySQL-Driver", p.Attrs["_client_name"]) + require.Equal(t, "linux", p.Attrs["_os"]) + require.Equal(t, "mysql", p.Attrs["_program_name"]) + require.Equal(t, "val", p.Attrs["_custom"]) + require.Equal(t, "fake", p.Attrs["_truncated"]) + require.Equal(t, "myapp", p.Attrs["app_name"]) + + require.Equal(t, int64(0), variable.ConnectAttrsLost.Load()) + }) +} diff --git a/pkg/server/internal/parse/BUILD.bazel b/pkg/server/internal/parse/BUILD.bazel index aae318fa183df..93e161a1e5068 100644 --- a/pkg/server/internal/parse/BUILD.bazel +++ b/pkg/server/internal/parse/BUILD.bazel @@ -9,7 +9,9 @@ go_library( "//pkg/parser/mysql", "//pkg/server/internal/handshake", "//pkg/server/internal/util", + "//pkg/sessionctx/variable", "//pkg/util/logutil", + "@com_github_pingcap_errors//:errors", "@org_uber_go_zap//:zap", ], ) @@ -23,9 +25,11 @@ go_test( ], embed = [":parse"], flaky = True, + shard_count = 3, deps = [ "//pkg/parser/mysql", "//pkg/server/internal/handshake", + "//pkg/sessionctx/variable", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/server/internal/parse/parse.go b/pkg/server/internal/parse/parse.go index eb79bc0927fbc..b81dffc630ee6 100644 --- a/pkg/server/internal/parse/parse.go +++ b/pkg/server/internal/parse/parse.go @@ -18,10 +18,15 @@ import ( "bytes" "context" "encoding/binary" + "fmt" + "strconv" + "strings" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/server/internal/handshake" util2 "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/util/logutil" "go.uber.org/zap" ) @@ -131,14 +136,29 @@ func HandshakeResponseBody(ctx context.Context, packet *handshake.Response41, da // Defend some ill-formated packet, connection attribute is not important and can be ignored. return nil } - if num, null, intOff := util2.ParseLengthEncodedInt(data[offset:]); !null { - offset += intOff // Length of variable length encoded integer itself in bytes - row := data[offset : offset+int(num)] - attrs, err := parseAttrs(row) + num, null, intOff := util2.ParseLengthEncodedInt(data[offset:]) + offset += intOff // Length of variable length encoded integer itself in bytes + if !null { + if num > 1<<20 { // 1 MiB hard limit + return errors.New("connection refused: session connection attributes exceed the 1 MiB hard limit") + } + end := offset + int(num) + if end > len(data) { + logutil.Logger(ctx).Error("malformed connection attributes packet", + zap.Int("offset", offset), + zap.Uint64("attrLength", num), + zap.Int("dataLen", len(data))) + return mysql.ErrMalformPacket + } + row := data[offset:end] + attrs, warningsText, err := parseAttrs(row) if err != nil { logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err)) return nil } + if warningsText != "" { + logutil.Logger(ctx).Debug(warningsText) + } packet.Attrs = attrs offset += int(num) // Length of attributes } @@ -151,22 +171,142 @@ func HandshakeResponseBody(ctx context.Context, packet *handshake.Response41, da return nil } -func parseAttrs(data []byte) (map[string]string, error) { - attrs := make(map[string]string) +// reservedConnAttrTruncated is injected by TiDB when connection attributes +// are truncated. A client-provided key with the same name may be overwritten +// when truncation happens. +const reservedConnAttrTruncated = "_truncated" + +var standardConnAttrs = map[string]struct{}{ + "_client_name": {}, + "_client_version": {}, + "_os": {}, + "_pid": {}, + "_platform": {}, +} + +type connAttrKV struct { + key string + value string +} + +type decodedConnAttrs struct { + items []connAttrKV + totalSize int64 + hasDeprecatedUnderscoreAttr bool +} + +func parseAttrs(data []byte) (map[string]string, string, error) { + if variable.ConnectAttrsSize.Load() == 0 { + return map[string]string{}, "", nil + } + + decoded, err := decodeConnAttrs(data) + if err != nil { + return map[string]string{}, "", err + } + attrs, warningsText := applyConnAttrsPolicyAndMetrics(decoded, variable.ConnectAttrsSize.Load()) + return attrs, warningsText, nil +} + +func decodeConnAttrs(data []byte) (decodedConnAttrs, error) { + decoded := decodedConnAttrs{items: make([]connAttrKV, 0)} pos := 0 + for pos < len(data) { key, _, off, err := util2.ParseLengthEncodedBytes(data[pos:]) if err != nil { - return attrs, err + return decoded, err } pos += off + value, _, off, err := util2.ParseLengthEncodedBytes(data[pos:]) if err != nil { - return attrs, err + return decoded, err } pos += off - attrs[string(key)] = string(value) + keyStr := string(key) + valueStr := string(value) + + decoded.items = append(decoded.items, connAttrKV{key: keyStr, value: valueStr}) + decoded.totalSize += int64(len(key)) + int64(len(value)) + + if !decoded.hasDeprecatedUnderscoreAttr && strings.HasPrefix(keyStr, "_") { + if _, ok := standardConnAttrs[keyStr]; !ok { + decoded.hasDeprecatedUnderscoreAttr = true + } + } + } + + return decoded, nil +} + +func applyConnAttrsPolicyAndMetrics(decoded decodedConnAttrs, limit int64) (map[string]string, string) { + attrs := make(map[string]string) + effectiveLimit := normalizeConnectAttrsLimit(limit) + + var totalSize int64 + var acceptedSize int64 + truncated := false + + for _, item := range decoded.items { + kvSize := int64(len(item.key)) + int64(len(item.value)) + totalSize += kvSize + if totalSize > effectiveLimit { + if !truncated { + truncated = true + variable.ConnectAttrsLost.Add(1) + } + continue + } + if !truncated { + attrs[item.key] = item.value + acceptedSize += kvSize + } + } + + updateConnectAttrsLongestSeen(decoded.totalSize) + + warnings := make([]string, 0, 2) + if decoded.hasDeprecatedUnderscoreAttr { + warnings = append(warnings, + "custom connection attributes with leading underscore are deprecated and will be rejected in a future release") + } + if truncated { + truncatedBytes := decoded.totalSize - acceptedSize + attrs[reservedConnAttrTruncated] = strconv.FormatInt(truncatedBytes, 10) + warnings = append(warnings, fmt.Sprintf( + "session connection attributes truncated: total size %d bytes exceeds "+ + "performance_schema_session_connect_attrs_size (%d), %d bytes were discarded", + decoded.totalSize, effectiveLimit, truncatedBytes)) + } + warningsText := strings.Join(warnings, "; ") + return attrs, warningsText +} + +func normalizeConnectAttrsLimit(limit int64) int64 { + if limit < 0 { + // In MySQL, -1 means autosizing. We map it to a maximum of 64KB (65536) + // to prevent unconstrained slow log bloating. + return 65536 + } + return limit +} + +func updateConnectAttrsLongestSeen(totalSize int64) { + // Update LongestSeen only for normal-sized payloads (< 64 KiB). + // Abnormally large payloads are still accepted (up to 1 MiB) but should + // not skew this monitoring metric. + if totalSize >= 65536 { + return + } + for { + old := variable.ConnectAttrsLongestSeen.Load() + if totalSize <= old { + break + } + if variable.ConnectAttrsLongestSeen.CompareAndSwap(old, totalSize) { + break + } } - return attrs, nil } diff --git a/pkg/server/internal/parse/parse_test.go b/pkg/server/internal/parse/parse_test.go index cc44038cb50a5..c3f6bdbfd9c8a 100644 --- a/pkg/server/internal/parse/parse_test.go +++ b/pkg/server/internal/parse/parse_test.go @@ -15,9 +15,11 @@ package parse import ( + "bytes" "testing" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/stretchr/testify/require" ) @@ -43,3 +45,70 @@ func TestParseStmtFetchCmd(t *testing.T) { require.Equal(t, tc.err, err) } } + +func TestParseAttrsUnderscoreWarning(t *testing.T) { + origSize := variable.ConnectAttrsSize.Load() + defer variable.ConnectAttrsSize.Store(origSize) + variable.ConnectAttrsSize.Store(-1) + + buildAttrsPayload := func(kvs [][2]string) []byte { + var buf bytes.Buffer + for _, kv := range kvs { + buf.WriteByte(byte(len(kv[0]))) + buf.WriteString(kv[0]) + buf.WriteByte(byte(len(kv[1]))) + buf.WriteString(kv[1]) + } + return buf.Bytes() + } + + t.Run("warn for custom underscore attrs", func(t *testing.T) { + payload := buildAttrsPayload([][2]string{ + {"_client_name", "libmysql"}, + {"_custom", "val"}, + {"_program_name", "mysql"}, + {"app_name", "myapp"}, + }) + + attrs, warning, err := parseAttrs(payload) + require.NoError(t, err) + require.Equal(t, "libmysql", attrs["_client_name"]) + require.Equal(t, "val", attrs["_custom"]) + require.Equal(t, "mysql", attrs["_program_name"]) + require.Equal(t, "myapp", attrs["app_name"]) + require.Contains(t, warning, "custom connection attributes with leading underscore are deprecated and will be rejected in a future release") + }) + + t.Run("no warning for standard underscore attrs", func(t *testing.T) { + payload := buildAttrsPayload([][2]string{ + {"_client_name", "libmysql"}, + {"_client_version", "8.0.33"}, + {"_os", "linux"}, + {"_pid", "123"}, + {"_platform", "x86_64"}, + {"app_name", "myapp"}, + }) + + _, warning, err := parseAttrs(payload) + require.NoError(t, err) + require.Empty(t, warning) + }) + + t.Run("server may overwrite client _truncated on truncation", func(t *testing.T) { + origLost := variable.ConnectAttrsLost.Load() + defer variable.ConnectAttrsLost.Store(origLost) + variable.ConnectAttrsLost.Store(0) + variable.ConnectAttrsSize.Store(20) + + payload := buildAttrsPayload([][2]string{ + {"_truncated", "client-value"}, + {"app_name", "my_service"}, + }) + + attrs, warning, err := parseAttrs(payload) + require.NoError(t, err) + require.Contains(t, warning, "session connection attributes truncated") + require.NotEqual(t, "client-value", attrs["_truncated"]) + require.Equal(t, int64(1), variable.ConnectAttrsLost.Load()) + }) +} diff --git a/pkg/sessionctx/variable/noop.go b/pkg/sessionctx/variable/noop.go index bcfa6d63a0f0a..d8cea0581b8fc 100644 --- a/pkg/sessionctx/variable/noop.go +++ b/pkg/sessionctx/variable/noop.go @@ -130,7 +130,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: BinlogOrderCommits, Value: On, Type: TypeBool}, {Scope: ScopeGlobal, Name: "key_cache_division_limit", Value: "100"}, {Scope: ScopeGlobal | ScopeSession, Name: "max_insert_delayed_threads", Value: "20"}, - {Scope: ScopeNone, Name: "performance_schema_session_connect_attrs_size", Value: "512"}, {Scope: ScopeGlobal, Name: "innodb_max_dirty_pages_pct", Value: "75"}, {Scope: ScopeGlobal, Name: InnodbFilePerTable, Value: On, Type: TypeBool, AutoConvertNegativeBool: true}, {Scope: ScopeGlobal, Name: InnodbLogCompressedPages, Value: "1"}, diff --git a/pkg/sessionctx/variable/session_test.go b/pkg/sessionctx/variable/session_test.go index 6525af2f15462..3a6ca73d0d0e7 100644 --- a/pkg/sessionctx/variable/session_test.go +++ b/pkg/sessionctx/variable/session_test.go @@ -333,11 +333,49 @@ func TestSlowLogFormat(t *testing.T) { seVar.CurrentDBChanged = false logString := seVar.SlowLogFormat(logItems) require.Equal(t, resultFields+"\n"+sql, logString) + require.NotContains(t, logString, variable.SlowLogSessionConnectAttrs) seVar.CurrentDBChanged = true logString = seVar.SlowLogFormat(logItems) require.Equal(t, resultFields+"\n"+"use test;\n"+sql, logString) require.False(t, seVar.CurrentDBChanged) + // Verify SessionConnectAttrs serialization. + logItems.SessionConnectAttrs = map[string]string{ + "_client_name": "libmysql", + "_os": "Linux", + "app_name": "test_svc", + } + logString = seVar.SlowLogFormat(logItems) + // json.Encoder sorts map keys, so the output is deterministic. + expectedAttrsLine := `# Session_connect_attrs: {"_client_name":"libmysql","_os":"Linux","app_name":"test_svc"}` + require.Contains(t, logString, expectedAttrsLine) + seVar.EnableRedactLog = variable.On + logString = seVar.SlowLogFormat(logItems) + require.Contains(t, logString, expectedAttrsLine) + seVar.EnableRedactLog = variable.Off + // Session_connect_attrs should appear after Storage_from_mpp, before Prev_stmt, and before the SQL. + attrsIdx := strings.Index(logString, "Session_connect_attrs") + mppIdx := strings.Index(logString, variable.SlowLogStorageFromMPP) + prevStmtIdx := strings.Index(logString, variable.SlowLogPrevStmt) + sqlIdx := strings.Index(logString, sql) + require.Greater(t, attrsIdx, 0) + require.Greater(t, mppIdx, 0) + require.Greater(t, attrsIdx, mppIdx, "Session_connect_attrs should appear after Storage_from_mpp") + if prevStmtIdx > 0 { + require.Less(t, attrsIdx, prevStmtIdx, "Session_connect_attrs should appear before Prev_stmt") + } + require.Less(t, attrsIdx, sqlIdx, "Session_connect_attrs should appear before the SQL statement") + + // Verify reserved truncation metadata key is serialized as expected. + logItems.SessionConnectAttrs = map[string]string{ + "_truncated": "4", + "app_name": "test_svc", + } + logString = seVar.SlowLogFormat(logItems) + require.Contains(t, logString, `# Session_connect_attrs: {"_truncated":"4","app_name":"test_svc"}`) + // Restore for subsequent assertions. + logItems.SessionConnectAttrs = nil + // test PrepareSlowLogItemsForRules and CompleteSlowLogItemsForRules seVar.SlowLogRules = slowlogrule.NewSessionSlowLogRules(&slowlogrule.SlowLogRules{ Fields: map[string]struct{}{ @@ -696,3 +734,33 @@ func TestUserVars(t *testing.T) { require.True(t, ok) require.Equal(t, types.NewStringDatum("v2"), dt) } +func TestPerformanceSchemaSessionConnectAttrsSizeGlobalSQL(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + + originSize := variable.ConnectAttrsSize.Load() + defer func() { + variable.ConnectAttrsSize.Store(originSize) + tk.MustExec("set global performance_schema_session_connect_attrs_size = " + strconv.FormatInt(originSize, 10)) + }() + + tk.MustExec("set global performance_schema_session_connect_attrs_size = 0") + tk.MustQuery("select @@global.performance_schema_session_connect_attrs_size").Check(testkit.Rows("0")) + require.Equal(t, int64(0), variable.ConnectAttrsSize.Load()) + + tk.MustExec("set global performance_schema_session_connect_attrs_size = 65536") + tk.MustQuery("select @@global.performance_schema_session_connect_attrs_size").Check(testkit.Rows("65536")) + require.Equal(t, int64(65536), variable.ConnectAttrsSize.Load()) + + tk.MustExec("set global performance_schema_session_connect_attrs_size = -1") + tk.MustQuery("select @@global.performance_schema_session_connect_attrs_size").Check(testkit.Rows("-1")) + require.Equal(t, int64(-1), variable.ConnectAttrsSize.Load()) + + tk.MustExec("set global performance_schema_session_connect_attrs_size = 70000") + tk.MustQuery("select @@global.performance_schema_session_connect_attrs_size").Check(testkit.Rows("65536")) + require.Equal(t, int64(65536), variable.ConnectAttrsSize.Load()) + + tk.MustExec("set global performance_schema_session_connect_attrs_size = -2") + tk.MustQuery("select @@global.performance_schema_session_connect_attrs_size").Check(testkit.Rows("-1")) + require.Equal(t, int64(-1), variable.ConnectAttrsSize.Load()) +} diff --git a/pkg/sessionctx/variable/slow_log.go b/pkg/sessionctx/variable/slow_log.go index 98ed77aacb64d..6e5e598c335ef 100644 --- a/pkg/sessionctx/variable/slow_log.go +++ b/pkg/sessionctx/variable/slow_log.go @@ -192,6 +192,8 @@ const ( SlowLogResourceGroup = "Resource_group" // SlowLogCopMVCCReadAmplification is total_keys / processed_keys in coprocessor scan detail. SlowLogCopMVCCReadAmplification = "cop_mvcc_read_amplification" + // SlowLogSessionConnectAttrs is the session connection attributes from the client. + SlowLogSessionConnectAttrs = "Session_connect_attrs" ) // JSONSQLWarnForSlowLog helps to print the SQLWarn through the slow log in JSON format. @@ -269,6 +271,9 @@ type SlowQueryLogItems struct { CPUUsages ppcpuusage.CPUUsages StorageKV bool // query read from TiKV StorageMPP bool // query read from TiFlash + // SessionConnectAttrs holds the client connection attributes (e.g. _client_name, _os). + // This is a shared reference to ConnectionInfo.Attributes and must not be modified. + SessionConnectAttrs map[string]string } const zeroStr = "0" @@ -503,6 +508,17 @@ func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { } writeSlowLogItem(&buf, SlowLogStorageFromKV, strconv.FormatBool(logItems.StorageKV)) writeSlowLogItem(&buf, SlowLogStorageFromMPP, strconv.FormatBool(logItems.StorageMPP)) + if len(logItems.SessionConnectAttrs) > 0 { + // Encode into a temporary buffer first so that a (practically impossible) + // encoding error does not leave a partial line in the main buffer. + var attrsBuf bytes.Buffer + encoder := json.NewEncoder(&attrsBuf) + encoder.SetEscapeHTML(false) + if err := encoder.Encode(logItems.SessionConnectAttrs); err == nil { + buf.WriteString(SlowLogRowPrefixStr + SlowLogSessionConnectAttrs + SlowLogSpaceMarkStr) + buf.Write(attrsBuf.Bytes()) // Encode already appends \n + } + } if logItems.PrevStmt != "" { writeSlowLogItem(&buf, SlowLogPrevStmt, logItems.PrevStmt) } diff --git a/pkg/sessionctx/variable/statusvar.go b/pkg/sessionctx/variable/statusvar.go index f9d790b3fa7c0..643c353897ecc 100644 --- a/pkg/sessionctx/variable/statusvar.go +++ b/pkg/sessionctx/variable/statusvar.go @@ -133,6 +133,8 @@ var defaultStatus = map[string]*StatusVal{ "Ssl_cipher_list": {ScopeGlobal | ScopeSession, ""}, "Ssl_verify_mode": {ScopeGlobal | ScopeSession, 0}, "Ssl_version": {ScopeGlobal | ScopeSession, ""}, + "Performance_schema_session_connect_attrs_longest_seen": {ScopeGlobal, int64(0)}, + "Performance_schema_session_connect_attrs_lost": {ScopeGlobal, int64(0)}, } type defaultStatusStat struct { @@ -149,6 +151,10 @@ func (s defaultStatusStat) Stats(vars *SessionVars) (map[string]any, error) { statusVars[name] = v.Value } + // Read live values from atomic counters for connect attrs status variables. + statusVars["Performance_schema_session_connect_attrs_longest_seen"] = ConnectAttrsLongestSeen.Load() + statusVars["Performance_schema_session_connect_attrs_lost"] = ConnectAttrsLost.Load() + // `vars` may be nil in unit tests. if vars != nil && vars.TLSConnectionState != nil { statusVars["Ssl_cipher"] = util.TLSCipher2String(vars.TLSConnectionState.CipherSuite) diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index baf70bcdd76ef..d2dad495efdeb 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -609,6 +609,16 @@ var defaultSysVars = []*SysVar{ }}, /* The system variables below have GLOBAL scope */ + {Scope: ScopeGlobal, Name: PerformanceSchemaSessionConnectAttrsSize, + Value: strconv.FormatInt(DefConnectAttrsSize, 10), + Type: TypeInt, MinValue: -1, MaxValue: 65536, + GetGlobal: func(_ context.Context, sv *SessionVars) (string, error) { + return strconv.FormatInt(ConnectAttrsSize.Load(), 10), nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + ConnectAttrsSize.Store(TidbOptInt64(val, DefConnectAttrsSize)) + return nil + }}, {Scope: ScopeGlobal, Name: MaxPreparedStmtCount, Value: strconv.FormatInt(DefMaxPreparedStmtCount, 10), Type: TypeInt, MinValue: -1, MaxValue: 1048576, SetGlobal: func(_ context.Context, s *SessionVars, val string) error { num, err := strconv.ParseInt(val, 10, 64) @@ -3911,6 +3921,10 @@ const ( LocalInFile = "local_infile" // PerformanceSchema is the name for 'performance_schema' system variable. PerformanceSchema = "performance_schema" + // PerformanceSchemaSessionConnectAttrsSize is the name for 'performance_schema_session_connect_attrs_size' system variable. + PerformanceSchemaSessionConnectAttrsSize = "performance_schema_session_connect_attrs_size" + // PerfSchemaSessionConnectAttrsSize is kept as a compatibility alias for release-8.5 backports. + PerfSchemaSessionConnectAttrsSize = PerformanceSchemaSessionConnectAttrsSize // Flush is the name for 'flush' system variable. Flush = "flush" // SlaveAllowBatching is the name for 'slave_allow_batching' system variable. diff --git a/pkg/sessionctx/variable/tidb_vars.go b/pkg/sessionctx/variable/tidb_vars.go index 3f3bc36947dc9..1a91094a78995 100644 --- a/pkg/sessionctx/variable/tidb_vars.go +++ b/pkg/sessionctx/variable/tidb_vars.go @@ -1656,6 +1656,9 @@ const ( DefTiDBAdvancerCheckPointLagLimit = 48 * time.Hour DefTiDBIndexLookUpPushDownPolicy = IndexLookUpPushDownPolicyHintOnly DefTiDBCircuitBreakerPDMetaErrorRateRatio = 0.0 + // DefConnectAttrsSize is the default max aggregate byte size of connection attributes per connection. + // This corresponds to performance_schema_session_connect_attrs_size. In TiDB, -1 means no limit up to 64KB. + DefConnectAttrsSize int64 = 4096 ) // Process global variables. @@ -1785,6 +1788,13 @@ var ( AdvancerCheckPointLagLimit = atomic.NewDuration(DefTiDBAdvancerCheckPointLagLimit) CircuitBreakerPDMetadataErrorRateThresholdRatio = atomic.NewFloat64(0.0) + // ConnectAttrsSize is the max aggregate byte size of connection attributes allowed per connection. + // Corresponds to performance_schema_session_connect_attrs_size. Default 4096. + ConnectAttrsSize = atomic.NewInt64(DefConnectAttrsSize) + // ConnectAttrsLongestSeen tracks the largest connection attribute aggregate size seen so far. + ConnectAttrsLongestSeen = atomic.NewInt64(0) + // ConnectAttrsLost counts the number of connections whose attributes were truncated. + ConnectAttrsLost = atomic.NewInt64(0) ) var ( diff --git a/pkg/sessionctx/variable/variable_test.go b/pkg/sessionctx/variable/variable_test.go index 3a1057082b647..94c50692f7921 100644 --- a/pkg/sessionctx/variable/variable_test.go +++ b/pkg/sessionctx/variable/variable_test.go @@ -144,6 +144,35 @@ func TestIntValidation(t *testing.T) { require.Equal(t, "-1", val) } +func TestPerformanceSchemaSessionConnectAttrsSizeValidation(t *testing.T) { + sv := GetSysVar(PerformanceSchemaSessionConnectAttrsSize) + require.NotNil(t, sv) + require.True(t, sv.HasGlobalScope()) + require.False(t, sv.HasSessionScope()) + + vars := NewSessionVars(nil) + + val, err := sv.Validate(vars, "-1", ScopeGlobal) + require.NoError(t, err) + require.Equal(t, "-1", val) + + val, err = sv.Validate(vars, "0", ScopeGlobal) + require.NoError(t, err) + require.Equal(t, "0", val) + + val, err = sv.Validate(vars, "65536", ScopeGlobal) + require.NoError(t, err) + require.Equal(t, "65536", val) + + val, err = sv.Validate(vars, "65537", ScopeGlobal) + require.NoError(t, err) + require.Equal(t, "65536", val) + + val, err = sv.Validate(vars, "-2", ScopeGlobal) + require.NoError(t, err) + require.Equal(t, "-1", val) +} + func TestUintValidation(t *testing.T) { sv := SysVar{Scope: ScopeGlobal | ScopeSession, Name: "mynewsysvar", Value: "123", Type: TypeUnsigned, MinValue: 10, MaxValue: 300, AllowAutoValue: true} vars := NewSessionVars(nil) diff --git a/pkg/testkit/testutil/require.go b/pkg/testkit/testutil/require.go index 165de9e761d2d..9454689fdf251 100644 --- a/pkg/testkit/testutil/require.go +++ b/pkg/testkit/testutil/require.go @@ -77,6 +77,28 @@ func CompareUnorderedStringSlice(a []string, b []string) bool { var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") +const defaultSessionConnectAttrsJSON = `{"_client_name":"Go-MySQL-Driver","_os":"linux","app_name":"test_app"}` + +// DefaultSessionConnectAttrsJSON returns the shared fixture JSON used in slow-log tests. +func DefaultSessionConnectAttrsJSON() string { + return defaultSessionConnectAttrsJSON +} + +// DefaultSessionConnectAttrsSlowLogLine returns the shared slow-log line for Session_connect_attrs fixture. +func DefaultSessionConnectAttrsSlowLogLine() string { + return "# Session_connect_attrs: " + defaultSessionConnectAttrsJSON +} + +// RequireContainsDefaultSessionConnectAttrs verifies the expected fixture keys/values are present. +func RequireContainsDefaultSessionConnectAttrs(t testing.TB, attrsText string) { + require.Contains(t, attrsText, `"_client_name"`) + require.Contains(t, attrsText, `"Go-MySQL-Driver"`) + require.Contains(t, attrsText, `"_os"`) + require.Contains(t, attrsText, `"linux"`) + require.Contains(t, attrsText, `"app_name"`) + require.Contains(t, attrsText, `"test_app"`) +} + // RandStringRunes generate random string of length n. func RandStringRunes(n int) string { b := make([]rune, n)