diff --git a/go/base/context.go b/go/base/context.go index 2c8d28d56..737069c83 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -175,6 +175,9 @@ type MigrationContext struct { CutOverType CutOver ReplicaServerId uint + // Number of workers used by the trx coordinator + NumWorkers int + Hostname string AssumeMasterHostname string ApplierTimeZone string diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index d690a9f65..5fc3aa67b 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -1,6 +1,6 @@ /* Copyright 2022 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE + See https://github.com/github/gh-ost/blob/master/LICENSE */ package binlog @@ -11,7 +11,6 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" - "github.com/github/gh-ost/go/sql" "time" @@ -85,59 +84,17 @@ func (this *GoMySQLReader) GetCurrentBinlogCoordinates() mysql.BinlogCoordinates return this.currentCoordinates.Clone() } -func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEvent *replication.RowsEvent, entriesChannel chan<- *BinlogEntry) error { - currentCoords := this.GetCurrentBinlogCoordinates() - dml := ToEventDML(ev.Header.EventType.String()) - if dml == NotDML { - return fmt.Errorf("Unknown DML type: %s", ev.Header.EventType.String()) - } - for i, row := range rowsEvent.Rows { - if dml == UpdateDML && i%2 == 1 { - // An update has two rows (WHERE+SET) - // We do both at the same time - continue - } - binlogEntry := NewBinlogEntryAt(currentCoords) - binlogEntry.DmlEvent = NewBinlogDMLEvent( - string(rowsEvent.Table.Schema), - string(rowsEvent.Table.Table), - dml, - ) - switch dml { - case InsertDML: - { - binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(row) - } - case UpdateDML: - { - binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) - binlogEntry.DmlEvent.NewColumnValues = sql.ToColumnValues(rowsEvent.Rows[i+1]) - } - case DeleteDML: - { - 
binlogEntry.DmlEvent.WhereColumnValues = sql.ToColumnValues(row) - } - } - - // The channel will do the throttling. Whoever is reading from the channel - // decides whether action is taken synchronously (meaning we wait before - // next iteration) or asynchronously (we keep pushing more events) - // In reality, reads will be synchronous - entriesChannel <- binlogEntry - } - return nil -} - -// StreamEvents -func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesChannel chan<- *BinlogEntry) error { - if canStopStreaming() { - return nil - } +// StreamEvents reads binlog events and sends them to the given channel. +// It is blocking and should be executed in a goroutine. +func (this *GoMySQLReader) StreamEvents(ctx context.Context, canStopStreaming func() bool, eventChannel chan<- *replication.BinlogEvent) error { for { if canStopStreaming() { - break + return nil + } + if err := ctx.Err(); err != nil { + return err } - ev, err := this.binlogStreamer.GetEvent(context.Background()) + ev, err := this.binlogStreamer.GetEvent(ctx) if err != nil { return err } @@ -159,45 +116,38 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha switch event := ev.Event.(type) { case *replication.GTIDEvent: - if !this.migrationContext.UseGTIDs { - continue - } - sid, err := uuid.FromBytes(event.SID) - if err != nil { - return err - } - this.currentCoordinatesMutex.Lock() - if this.LastTrxCoords != nil { - this.currentCoordinates = this.LastTrxCoords.Clone() + if this.migrationContext.UseGTIDs { + sid, err := uuid.FromBytes(event.SID) + if err != nil { + return err + } + this.currentCoordinatesMutex.Lock() + if this.LastTrxCoords != nil { + this.currentCoordinates = this.LastTrxCoords.Clone() + } + coords := this.currentCoordinates.(*mysql.GTIDBinlogCoordinates) + trxGset := gomysql.NewUUIDSet(sid, gomysql.Interval{Start: event.GNO, Stop: event.GNO + 1}) + coords.GTIDSet.AddSet(trxGset) + this.currentCoordinatesMutex.Unlock() } - 
coords := this.currentCoordinates.(*mysql.GTIDBinlogCoordinates) - trxGset := gomysql.NewUUIDSet(sid, gomysql.Interval{Start: event.GNO, Stop: event.GNO + 1}) - coords.GTIDSet.AddSet(trxGset) - this.currentCoordinatesMutex.Unlock() case *replication.RotateEvent: - if this.migrationContext.UseGTIDs { - continue + if !this.migrationContext.UseGTIDs { + this.currentCoordinatesMutex.Lock() + coords := this.currentCoordinates.(*mysql.FileBinlogCoordinates) + coords.LogFile = string(event.NextLogName) + this.migrationContext.Log.Infof("rotate to next log from %s:%d to %s", coords.LogFile, int64(ev.Header.LogPos), event.NextLogName) + this.currentCoordinatesMutex.Unlock() } - this.currentCoordinatesMutex.Lock() - coords := this.currentCoordinates.(*mysql.FileBinlogCoordinates) - coords.LogFile = string(event.NextLogName) - this.migrationContext.Log.Infof("rotate to next log from %s:%d to %s", coords.LogFile, int64(ev.Header.LogPos), event.NextLogName) - this.currentCoordinatesMutex.Unlock() case *replication.XIDEvent: if this.migrationContext.UseGTIDs { this.LastTrxCoords = &mysql.GTIDBinlogCoordinates{GTIDSet: event.GSet.(*gomysql.MysqlGTIDSet)} } else { this.LastTrxCoords = this.currentCoordinates.Clone() } - case *replication.RowsEvent: - if err := this.handleRowsEvent(ev, event, entriesChannel); err != nil { - return err - } } - } - this.migrationContext.Log.Debugf("done streaming events") - return nil + eventChannel <- ev + } } func (this *GoMySQLReader) Close() error { diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 567137fd5..9faa3af38 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -112,6 +112,7 @@ func main() { flag.BoolVar(&migrationContext.PanicOnWarnings, "panic-on-warnings", false, "Panic when SQL warnings are encountered when copying a batch indicating data loss") cutOverLockTimeoutSeconds := flag.Int64("cut-over-lock-timeout-seconds", 3, "Max number of seconds to hold locks on tables while attempting to cut-over (retry 
attempted when lock exceeds timeout) or attempting instant DDL") niceRatio := flag.Float64("nice-ratio", 0, "force being 'nice', imply sleep time per chunk time; range: [0.0..100.0]. Example values: 0 is aggressive. 1: for every 1ms spent copying rows, sleep additional 1ms (effectively doubling runtime); 0.7: for every 10ms spend in a rowcopy chunk, spend 7ms sleeping immediately after") + flag.IntVar(&migrationContext.NumWorkers, "workers", 8, "Number of concurrent workers for applying DML events. Each worker uses one goroutine.") maxLagMillis := flag.Int64("max-lag-millis", 1500, "replication lag at which to throttle operation") replicationLagQuery := flag.String("replication-lag-query", "", "Deprecated. gh-ost uses an internal, subsecond resolution query") diff --git a/go/logic/applier.go b/go/logic/applier.go index 709fd08da..ddf1c277c 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -109,12 +109,14 @@ func (this *Applier) compileMigrationKeyWarningRegex() (*regexp.Regexp, error) { return migrationKeyRegex, nil } -func (this *Applier) InitDBConnections() (err error) { +func (this *Applier) InitDBConnections(maxConns int) (err error) { applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri) if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil { return err } + this.db.SetMaxOpenConns(maxConns) + this.db.SetMaxIdleConns(maxConns) singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil { return err diff --git a/go/logic/applier_test.go b/go/logic/applier_test.go index a9ecb889d..dcdedba08 100644 --- a/go/logic/applier_test.go +++ b/go/logic/applier_test.go @@ -333,7 +333,7 @@ func (suite *ApplierTestSuite) TestInitDBConnections() { applier := NewApplier(migrationContext) defer applier.Teardown() - err = 
applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) mysqlVersion, _ := strings.CutPrefix(testMysqlContainerImage, "mysql:") @@ -374,7 +374,7 @@ func (suite *ApplierTestSuite) TestApplyDMLEventQueries() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) dmlEvents := []*binlog.BinlogDMLEvent{ @@ -431,7 +431,7 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTables() { applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.ValidateOrDropExistingTables() @@ -463,7 +463,7 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTablesWithGhostTableExi applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.ValidateOrDropExistingTables() @@ -494,7 +494,7 @@ func (suite *ApplierTestSuite) TestValidateOrDropExistingTablesWithGhostTableExi applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.ValidateOrDropExistingTables() @@ -532,7 +532,7 @@ func (suite *ApplierTestSuite) TestCreateGhostTable() { applier := NewApplier(migrationContext) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) err = applier.CreateGhostTable() @@ -586,7 +586,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsInApplyIterationInsertQuerySuc suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(8) suite.Require().NoError(err) _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT 
INTO %s (id, item_id) VALUES (123456, 42);", getTestTableName())) @@ -676,7 +676,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsInApplyIterationInsertQueryFai } applier := NewApplier(migrationContext) - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) err = applier.CreateChangelogTable() @@ -743,7 +743,7 @@ func (suite *ApplierTestSuite) TestWriteCheckpoint() { applier := NewApplier(migrationContext) - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) err = applier.CreateChangelogTable() @@ -825,7 +825,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsWithDuplicateKeyOnNonMigration suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table (simulating bulk copy phase) @@ -914,7 +914,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsWithDuplicateCompositeUniqueKe suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table (simulating bulk copy phase) @@ -1016,7 +1016,7 @@ func (suite *ApplierTestSuite) TestUpdateModifyingUniqueKeyWithDuplicateOnOtherI suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Setup: Insert initial rows into ghost table @@ -1111,7 +1111,7 @@ func (suite *ApplierTestSuite) TestNormalUpdateWithPanicOnWarnings() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Setup: Insert initial rows into ghost table @@ -1191,7 +1191,7 @@ func (suite *ApplierTestSuite) 
TestDuplicateOnMigrationKeyAllowedInBinlogReplay( suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table (simulating bulk copy phase) @@ -1282,7 +1282,7 @@ func (suite *ApplierTestSuite) TestRegexMetacharactersInIndexName() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows @@ -1384,7 +1384,7 @@ func (suite *ApplierTestSuite) TestPanicOnWarningsDisabled() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table @@ -1473,7 +1473,7 @@ func (suite *ApplierTestSuite) TestMultipleDMLEventsInBatch() { suite.Require().NoError(applier.prepareQueries()) defer applier.Teardown() - err = applier.InitDBConnections() + err = applier.InitDBConnections(1) suite.Require().NoError(err) // Insert initial rows into ghost table diff --git a/go/logic/coordinator.go b/go/logic/coordinator.go new file mode 100644 index 000000000..dc865b4cd --- /dev/null +++ b/go/logic/coordinator.go @@ -0,0 +1,669 @@ +package logic + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + "errors" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/binlog" + "github.com/github/gh-ost/go/mysql" + "github.com/github/gh-ost/go/sql" + "github.com/go-mysql-org/go-mysql/replication" + drivermysql "github.com/go-sql-driver/mysql" +) + +type Coordinator struct { + migrationContext *base.MigrationContext + + binlogReader *binlog.GoMySQLReader + + onChangelogEvent func(dmlEvent *binlog.BinlogDMLEvent) error + + applier *Applier + + throttler *Throttler + + // Atomic counter for 
number of active workers (not in workerQueue) + busyWorkers atomic.Int64 + + // Mutex to protect the fields below + mu sync.Mutex + + // list of workers + workers []*Worker + + // The low water mark. We maintain that all transactions with + // sequence number <= lowWaterMark have been completed. + lowWaterMark int64 + + // This is a set of completed jobs keyed by their sequence numbers. + // This is used when updating the low water mark. + // Only membership is recorded; the empty-struct value carries no data. + completedJobs map[int64]struct{} + + // These are the jobs that are waiting for a previous job to complete. + // They are indexed by the sequence number of the job they are waiting for. + waitingJobs map[int64][]chan struct{} + + events chan *replication.BinlogEvent + + workerQueue chan *Worker + + // fatalErr stores the first fatal error from any worker goroutine. + fatalErr error + fatalErrMu sync.Mutex + // failedCh is closed on the first fatal worker error; all blocking + // coordinator and worker operations select on this to unblock. + failedCh chan struct{} + + finishedMigrating atomic.Bool +} + +// Worker takes jobs from the Coordinator and applies the job's DML events. +type Worker struct { + id int + coordinator *Coordinator + eventQueue chan *replication.BinlogEvent + + executedJobs atomic.Int64 + dmlEventsApplied atomic.Int64 + waitTimeNs atomic.Int64 + busyTimeNs atomic.Int64 +} + +// stats is a snapshot of one worker's counters, as built by GetWorkerStats. +type stats struct { + // DML events and transactions applied per second of busy time + dmlRate float64 + trxRate float64 + + // Number of DML events applied + dmlEventsApplied int64 + + // Number of transactions processed + executedJobs int64 + + // Time spent applying DML events + busyTime time.Duration + + // Time spent waiting on transaction dependencies + // or waiting on events to arrive in queue. + waitTime time.Duration +} + +// isRetryableError returns true for MySQL errors that are safe to retry +// (deadlock and lock wait timeout). 
+func isRetryableError(err error) bool { + var mysqlErr *drivermysql.MySQLError + if errors.As(err, &mysqlErr) { + switch mysqlErr.Number { + case 1213, 1205: // deadlock, lock wait timeout + return true + } + } + return false +} + +// setFatalError records the first fatal error and closes failedCh so +// all blocking operations in the coordinator and workers unblock. +func (c *Coordinator) setFatalError(err error) { + c.fatalErrMu.Lock() + defer c.fatalErrMu.Unlock() + if c.fatalErr == nil { + c.fatalErr = err + close(c.failedCh) + } +} + +// getFatalError returns the first fatal error, or nil. +func (c *Coordinator) getFatalError() error { + c.fatalErrMu.Lock() + defer c.fatalErrMu.Unlock() + return c.fatalErr +} + +// ProcessEvents is the worker's main loop. For each transaction it receives a +// GTID event followed by the transaction's row events from eventQueue, applies +// the DML in batches, marks the transaction completed, handles any changelog +// event, and puts the worker back on the coordinator's workerQueue. +// It returns nil once finishedMigrating is set, or an error on a fatal failure. +func (w *Worker) ProcessEvents() error { + databaseName := w.coordinator.migrationContext.DatabaseName + originalTableName := w.coordinator.migrationContext.OriginalTableName + changelogTableName := w.coordinator.migrationContext.GetChangelogTableName() + + for { + if w.coordinator.finishedMigrating.Load() { + return nil + } + + // Wait for first event (GTID), interruptible by fatal error + waitStart := time.Now() + var ev *replication.BinlogEvent + select { + case ev = <-w.eventQueue: + case <-w.coordinator.failedCh: + return fmt.Errorf("aborting: %w", w.coordinator.getFatalError()) + } + w.waitTimeNs.Add(time.Since(waitStart).Nanoseconds()) + + // Verify this is a GTID Event. Its SequenceNumber is dereferenced once + // the transaction completes, so carrying on with a nil gtidEvent would + // panic; abort with an error instead. + gtidEvent, ok := ev.Event.(*replication.GTIDEvent) + if !ok { + // The coordinator counted this worker as busy before dispatching. + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("worker %d: expected GTID event at start of transaction, got %T", w.id, ev.Event) + } + + // Dependency wait is done by the coordinator before dispatch + // (coordinator-side scheduling, matching MySQL applier semantics). 
+ + // Process the transaction + var changelogEvent *binlog.BinlogDMLEvent + var txErr error + dmlEvents := make([]*binlog.BinlogDMLEvent, 0, int(atomic.LoadInt64(&w.coordinator.migrationContext.DMLBatchSize))) + events: + for { + // wait for next event in the transaction + waitStart := time.Now() + var ev *replication.BinlogEvent + select { + case ev = <-w.eventQueue: + case <-w.coordinator.failedCh: + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("aborting: %w", w.coordinator.getFatalError()) + } + w.waitTimeNs.Add(time.Since(waitStart).Nanoseconds()) + + if ev == nil { + break events + } + + switch binlogEvent := ev.Event.(type) { + case *replication.RowsEvent: + dml := binlog.ToEventDML(ev.Header.EventType.String()) + if dml == binlog.NotDML { + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String()) + } + + if !strings.EqualFold(databaseName, string(binlogEvent.Table.Schema)) { + continue + } + + if !strings.EqualFold(originalTableName, string(binlogEvent.Table.Table)) && !strings.EqualFold(changelogTableName, string(binlogEvent.Table.Table)) { + continue + } + + for i, row := range binlogEvent.Rows { + if dml == binlog.UpdateDML && i%2 == 1 { + // An update has two rows (WHERE+SET) + // We do both at the same time + continue + } + dmlEvent := binlog.NewBinlogDMLEvent( + string(binlogEvent.Table.Schema), + string(binlogEvent.Table.Table), + dml, + ) + switch dml { + case binlog.InsertDML: + { + dmlEvent.NewColumnValues = sql.ToColumnValues(row) + } + case binlog.UpdateDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + dmlEvent.NewColumnValues = sql.ToColumnValues(binlogEvent.Rows[i+1]) + } + case binlog.DeleteDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + } + } + + if strings.EqualFold(changelogTableName, string(binlogEvent.Table.Table)) { + changelogEvent = dmlEvent + } else { + dmlEvents = append(dmlEvents, dmlEvent) + + if len(dmlEvents) == cap(dmlEvents) { + if 
err := w.applyDMLEvents(dmlEvents); err != nil { + txErr = err + break events + } + dmlEvents = dmlEvents[:0] + } + } + } + case *replication.XIDEvent: + if len(dmlEvents) > 0 { + if err := w.applyDMLEvents(dmlEvents); err != nil { + txErr = err + break events + } + } + + w.executedJobs.Add(1) + break events + } + } + + if txErr != nil { + // Fatal: DML failed after retries. Decrement busyWorkers + // since we won't reach the normal cleanup path below. + w.coordinator.busyWorkers.Add(-1) + return txErr + } + + w.coordinator.MarkTransactionCompleted(gtidEvent.SequenceNumber, int64(ev.Header.LogPos), int64(ev.Header.EventSize)) + + // Did we see a changelog event? + // Handle it now + if changelogEvent != nil { + // wait for all transactions before this point + clWaitCh := w.coordinator.WaitForTransaction(gtidEvent.SequenceNumber - 1) + if clWaitCh != nil { + waitStart := time.Now() + select { + case <-clWaitCh: + case <-w.coordinator.failedCh: + w.coordinator.busyWorkers.Add(-1) + return fmt.Errorf("aborting: %w", w.coordinator.getFatalError()) + } + w.waitTimeNs.Add(time.Since(waitStart).Nanoseconds()) + } + w.coordinator.HandleChangeLogEvent(changelogEvent) + } + + w.coordinator.workerQueue <- w + w.coordinator.busyWorkers.Add(-1) + } +} + +func (w *Worker) applyDMLEvents(dmlEvents []*binlog.BinlogDMLEvent) error { + if w.coordinator.throttler != nil { + w.coordinator.throttler.throttle(nil) + } + // Deadlocks between parallel workers are expected due to InnoDB gap locks + // on secondary indexes. Use a generous retry limit with jittered backoff + // to handle contention between workers. 
+ const maxDeadlockRetries = 100 + var err error + for attempt := 0; attempt < maxDeadlockRetries; attempt++ { + if attempt > 0 { + // Jittered exponential backoff: base * 2^min(attempt,7) + random jitter + base := time.Duration(10) * time.Millisecond + backoff := base * (1 << min(attempt, 7)) + jitter := time.Duration(rand.Int63n(int64(backoff))) + time.Sleep(backoff + jitter) + } + busyStart := time.Now() + err = w.coordinator.applier.ApplyDMLEventQueries(dmlEvents) + w.busyTimeNs.Add(time.Since(busyStart).Nanoseconds()) + if err == nil { + w.dmlEventsApplied.Add(int64(len(dmlEvents))) + return nil + } + if !isRetryableError(err) { + return err + } + if attempt > 0 && attempt%10 == 0 { + w.coordinator.migrationContext.Log.Infof("Worker %d: DML batch retry attempt %d after deadlock", w.id, attempt) + } + } + return fmt.Errorf("DML batch failed after %d deadlock retries: %w", maxDeadlockRetries, err) +} + +// NewCoordinator builds a Coordinator around the given applier and throttler, +// with a fresh binlog reader and empty transaction-tracking state. +// onChangelogEvent is the callback invoked through HandleChangeLogEvent. +func NewCoordinator(migrationContext *base.MigrationContext, applier *Applier, throttler *Throttler, onChangelogEvent func(dmlEvent *binlog.BinlogDMLEvent) error) *Coordinator { + return &Coordinator{ + migrationContext: migrationContext, + + onChangelogEvent: onChangelogEvent, + + // Workers apply DML through the coordinator's applier; leaving it unset + // would make applyDMLEvents dereference a nil pointer. + applier: applier, + + throttler: throttler, + + binlogReader: binlog.NewGoMySQLReader(migrationContext), + + lowWaterMark: -1, + completedJobs: make(map[int64]struct{}), + waitingJobs: make(map[int64][]chan struct{}), + + events: make(chan *replication.BinlogEvent, 1000), + failedCh: make(chan struct{}), + } +} + +// StartStreaming connects the binlog reader at coords and streams events into +// c.events until ctx is done or canStopStreaming returns true, reconnecting +// from the reader's current coordinates after a streaming error. +func (c *Coordinator) StartStreaming(ctx context.Context, coords mysql.BinlogCoordinates, canStopStreaming func() bool) error { + err := c.binlogReader.ConnectBinlogStreamer(coords) + if err != nil { + return err + } + defer c.binlogReader.Close() + + var retries int64 + for { + if err := ctx.Err(); err != nil { + return err + } + if canStopStreaming() { + return nil + } + if err := c.binlogReader.StreamEvents(ctx, canStopStreaming, c.events); err != nil { + if errors.Is(err, context.Canceled) { + 
return err + } + + c.migrationContext.Log.Infof("StreamEvents encountered unexpected error: %+v", err) + c.migrationContext.MarkPointOfInterest() + + // We reconnect from the event that was last emitted to the stream. + // This ensures we don't miss any events, and we don't process any events twice. + // Processing events twice messes up the transaction tracking and + // will cause data corruption. + // The coordinates are refreshed before being logged or reported so the + // messages below show the actual resume position, not the initial one. + coords = c.binlogReader.GetCurrentBinlogCoordinates() + + if retries >= c.migrationContext.MaxRetries() { + return fmt.Errorf("%d successive failures in streamer reconnect at coordinates %+v", retries, coords) + } + c.migrationContext.Log.Infof("Reconnecting... Will resume at %+v", coords) + + if err := c.binlogReader.ConnectBinlogStreamer(coords); err != nil { + return err + } + retries += 1 + } + } +} + +// ProcessEventsUntilNextChangelogEvent consumes c.events until it finds a rows +// event on the changelog table of the migrated database, and returns the DML +// event built from that event's first row. It returns (nil, nil) if the events +// channel is closed first. +func (c *Coordinator) ProcessEventsUntilNextChangelogEvent() (*binlog.BinlogDMLEvent, error) { + databaseName := c.migrationContext.DatabaseName + changelogTableName := c.migrationContext.GetChangelogTableName() + + for ev := range c.events { + switch binlogEvent := ev.Event.(type) { + case *replication.RowsEvent: + dml := binlog.ToEventDML(ev.Header.EventType.String()) + if dml == binlog.NotDML { + return nil, fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String()) + } + + if !strings.EqualFold(databaseName, string(binlogEvent.Table.Schema)) { + continue + } + + if !strings.EqualFold(changelogTableName, string(binlogEvent.Table.Table)) { + continue + } + + for i, row := range binlogEvent.Rows { + if dml == binlog.UpdateDML && i%2 == 1 { + // An update has two rows (WHERE+SET) + // We do both at the same time + continue + } + dmlEvent := binlog.NewBinlogDMLEvent( + string(binlogEvent.Table.Schema), + string(binlogEvent.Table.Table), + dml, + ) + switch dml { + case binlog.InsertDML: + { + dmlEvent.NewColumnValues = sql.ToColumnValues(row) + } + case binlog.UpdateDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + dmlEvent.NewColumnValues = 
sql.ToColumnValues(binlogEvent.Rows[i+1]) + } + case binlog.DeleteDML: + { + dmlEvent.WhereColumnValues = sql.ToColumnValues(row) + } + } + + return dmlEvent, nil + } + } + } + + //nolint:nilnil + return nil, nil +} + +// ProcessEventsUntilDrained reads binlog events and sends them to the workers to process. +// It exits when the event queue is empty and all the workers are returned to the workerQueue. +func (c *Coordinator) ProcessEventsUntilDrained() error { + for { + // Check for fatal worker error first + select { + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + default: + } + + select { + // Read events from the binlog and submit them to the next worker + case ev := <-c.events: + { + if c.finishedMigrating.Load() { + return nil + } + + switch binlogEvent := ev.Event.(type) { + case *replication.GTIDEvent: + c.mu.Lock() + if c.lowWaterMark < 0 && binlogEvent.SequenceNumber > 0 { + c.lowWaterMark = binlogEvent.SequenceNumber - 1 + } + c.mu.Unlock() + + // Coordinator-side dependency wait: don't schedule this + // transaction until all its dependencies are complete. + // This matches MySQL's replication applier coordinator + // semantics (schedule iff lwm >= lastCommitted). + waitChannel := c.WaitForTransaction(binlogEvent.LastCommitted) + if waitChannel != nil { + select { + case <-waitChannel: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + } + case *replication.RotateEvent: + c.migrationContext.Log.Infof("rotate to next log in %s", binlogEvent.NextLogName) + // Binlog rotation resets sequence numbers. We must + // drain all workers (old file) and reset the lwm so + // that dependency checking uses the new file's + // sequence number space. 
+ c.mu.Lock() + needsReset := c.lowWaterMark >= 0 + c.mu.Unlock() + if needsReset { + for c.busyWorkers.Load() > 0 { + select { + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + default: + } + time.Sleep(time.Millisecond) + } + c.mu.Lock() + c.lowWaterMark = -1 + c.completedJobs = make(map[int64]struct{}) + c.waitingJobs = make(map[int64][]chan struct{}) + c.mu.Unlock() + } + continue + default: // ignore all other events + continue + } + + // Acquire a worker, interruptible by fatal error + var worker *Worker + select { + case worker = <-c.workerQueue: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + c.busyWorkers.Add(1) + + // Send GTID to worker, interruptible + select { + case worker.eventQueue <- ev: + case <-c.failedCh: + c.busyWorkers.Add(-1) + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + + ev = <-c.events + + switch binlogEvent := ev.Event.(type) { + case *replication.QueryEvent: + if bytes.Equal([]byte("BEGIN"), binlogEvent.Query) { + } else { + worker.eventQueue <- nil + continue + } + default: + worker.eventQueue <- nil + continue + } + + events: + for { + ev = <-c.events + switch ev.Event.(type) { + case *replication.RowsEvent: + select { + case worker.eventQueue <- ev: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + case *replication.XIDEvent: + select { + case worker.eventQueue <- ev: + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + } + + // We're done with this transaction + break events + } + } + } + + // No events in the queue. 
Check if all workers are sleeping now + default: + { + select { + case <-c.failedCh: + return fmt.Errorf("worker error: %w", c.getFatalError()) + default: + } + if c.busyWorkers.Load() == 0 { + return nil + } + } + } + } +} + +func (c *Coordinator) InitializeWorkers(count int) { + c.workerQueue = make(chan *Worker, count) + for i := 0; i < count; i++ { + w := &Worker{id: i, coordinator: c, eventQueue: make(chan *replication.BinlogEvent, 1000)} + + c.mu.Lock() + c.workers = append(c.workers, w) + c.mu.Unlock() + + c.workerQueue <- w + go func() { + if err := w.ProcessEvents(); err != nil { + c.migrationContext.Log.Errorf("Worker %d fatal error: %v", w.id, err) + c.setFatalError(err) + } + }() + } +} + +// GetWorkerStats collects profiling stats for ProcessEvents from each worker. +func (c *Coordinator) GetWorkerStats() []stats { + c.mu.Lock() + defer c.mu.Unlock() + statSlice := make([]stats, 0, len(c.workers)) + for _, w := range c.workers { + stat := stats{} + stat.dmlEventsApplied = w.dmlEventsApplied.Load() + stat.executedJobs = w.executedJobs.Load() + stat.busyTime = time.Duration(w.busyTimeNs.Load()) + stat.waitTime = time.Duration(w.waitTimeNs.Load()) + if stat.busyTime.Milliseconds() > 0 { + stat.dmlRate = 1000.0 * float64(stat.dmlEventsApplied) / float64(stat.busyTime.Milliseconds()) + stat.trxRate = 1000.0 * float64(stat.executedJobs) / float64(stat.busyTime.Milliseconds()) + } + statSlice = append(statSlice, stat) + } + return statSlice +} + +func (c *Coordinator) WaitForTransaction(lastCommitted int64) chan struct{} { + c.mu.Lock() + defer c.mu.Unlock() + + if lastCommitted <= c.lowWaterMark { + return nil + } + + // Buffered so MarkTransactionCompleted never blocks if the waiter + // already exited (e.g. via failedCh). 
+ waitChannel := make(chan struct{}, 1) + c.waitingJobs[lastCommitted] = append(c.waitingJobs[lastCommitted], waitChannel) + + return waitChannel +} + +// HandleChangeLogEvent passes a changelog event to the onChangelogEvent callback +// while holding the coordinator mutex. A callback error is logged rather than +// silently dropped. +func (c *Coordinator) HandleChangeLogEvent(event *binlog.BinlogDMLEvent) { + c.mu.Lock() + defer c.mu.Unlock() + if err := c.onChangelogEvent(event); err != nil { + c.migrationContext.Log.Errorf("changelog event handler failed: %v", err) + } +} + +// MarkTransactionCompleted records sequenceNumber as done, advances the low +// water mark over any contiguous completed sequence numbers, and wakes waiters +// whose awaited sequence number is now at or below it. logPos and eventSize are +// currently unused. +func (c *Coordinator) MarkTransactionCompleted(sequenceNumber, logPos, eventSize int64) { + var channelsToNotify []chan struct{} + + func() { + c.mu.Lock() + defer c.mu.Unlock() + + // Mark the job as completed + c.completedJobs[sequenceNumber] = struct{}{} + + // Then, update the low water mark if possible + for { + if _, ok := c.completedJobs[c.lowWaterMark+1]; ok { + c.lowWaterMark++ + delete(c.completedJobs, c.lowWaterMark) + } else { + break + } + } + channelsToNotify = make([]chan struct{}, 0) + + // Schedule any jobs that were waiting for this job to complete or for the low watermark + for waitingForSequenceNumber, channels := range c.waitingJobs { + if waitingForSequenceNumber <= c.lowWaterMark { + channelsToNotify = append(channelsToNotify, channels...) 
+ delete(c.waitingJobs, waitingForSequenceNumber) + } + } + }() + + for _, waitChannel := range channelsToNotify { + waitChannel <- struct{}{} + } +} + +func (c *Coordinator) Teardown() { + c.finishedMigrating.Store(true) +} diff --git a/go/logic/coordinator_test.go b/go/logic/coordinator_test.go new file mode 100644 index 000000000..b38b7ff28 --- /dev/null +++ b/go/logic/coordinator_test.go @@ -0,0 +1,382 @@ +package logic + +import ( + "context" + gosql "database/sql" + "fmt" + "math/rand/v2" + "os" + "testing" + "time" + + "path/filepath" + "runtime" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/binlog" + "github.com/github/gh-ost/go/mysql" + "github.com/github/gh-ost/go/sql" + "github.com/stretchr/testify/suite" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" + "golang.org/x/sync/errgroup" +) + +type CoordinatorTestSuite struct { + suite.Suite + + mysqlContainer testcontainers.Container + db *gosql.DB + concurrentTransactions int + transactionsPerWorker int + transactionSize int +} + +func (suite *CoordinatorTestSuite) SetupSuite() { + ctx := context.Background() + req := testcontainers.ContainerRequest{ + Image: "mysql:8.0.40", + Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root-password"}, + WaitingFor: wait.ForListeningPort("3306/tcp"), + } + + mysqlContainer, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + suite.Require().NoError(err) + + suite.mysqlContainer = mysqlContainer + + dsn, err := GetDSN(ctx, mysqlContainer) + suite.Require().NoError(err) + + db, err := gosql.Open("mysql", dsn) + suite.Require().NoError(err) + + suite.db = db + suite.concurrentTransactions = 8 + suite.transactionsPerWorker = 1000 + suite.transactionSize = 10 + + db.SetMaxOpenConns(suite.concurrentTransactions) +} + +func (suite *CoordinatorTestSuite) SetupTest() { + ctx := context.Background() + _, err := 
suite.db.ExecContext(ctx, "RESET MASTER") + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, "SET @@GLOBAL.binlog_transaction_dependency_tracking = WRITESET") + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET @@GLOBAL.max_connections = %d", suite.concurrentTransactions*2)) + suite.Require().NoError(err) + + _, err = suite.db.ExecContext(ctx, "CREATE DATABASE test") + suite.Require().NoError(err) +} + +func (suite *CoordinatorTestSuite) TearDownTest() { + ctx := context.Background() + _, err := suite.db.ExecContext(ctx, "DROP DATABASE test") + suite.Require().NoError(err) +} + +func (suite *CoordinatorTestSuite) TeardownSuite() { + ctx := context.Background() + + suite.Assert().NoError(suite.db.Close()) + suite.Assert().NoError(suite.mysqlContainer.Terminate(ctx)) +} + +func (suite *CoordinatorTestSuite) TestApplyDML() { + ctx := context.Background() + + connectionConfig, err := GetConnectionConfig(ctx, suite.mysqlContainer) + suite.Require().NoError(err) + + _ = os.Remove("/tmp/gh-ost.sock") + + _, err = suite.db.Exec("CREATE TABLE test.gh_ost_test (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(255)) ENGINE=InnoDB") + suite.Require().NoError(err) + + _, err = suite.db.Exec("CREATE TABLE test._gh_ost_test_gho (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(255))") + suite.Require().NoError(err) + + migrationContext := base.NewMigrationContext() + migrationContext.DatabaseName = "test" + migrationContext.OriginalTableName = "gh_ost_test" + migrationContext.AlterStatement = "ALTER TABLE gh_ost_test ENGINE=InnoDB" + migrationContext.AllowedRunningOnMaster = true + migrationContext.ReplicaServerId = 99999 + migrationContext.HeartbeatIntervalMilliseconds = 100 + migrationContext.ThrottleHTTPIntervalMillis = 100 + migrationContext.DMLBatchSize = 10 + + migrationContext.ApplierConnectionConfig = connectionConfig + migrationContext.InspectorConnectionConfig = connectionConfig + + 
migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.GhostTableColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "name"}) + migrationContext.UniqueKey = &sql.UniqueKey{ + Name: "PRIMARY", + Columns: *sql.NewColumnList([]string{"id"}), + IsAutoIncrement: true, + } + + migrationContext.SetConnectionConfig("innodb") + migrationContext.SkipPortValidation = true + migrationContext.NumWorkers = 4 + + //nolint:dogsled + _, filename, _, _ := runtime.Caller(0) + migrationContext.ServeSocketFile = filepath.Join(filepath.Dir(filename), "../../tmp/gh-ost.sock") + + applier := NewApplier(migrationContext) + err = applier.InitDBConnections(migrationContext.NumWorkers) + suite.Require().NoError(err) + + err = applier.prepareQueries() + suite.Require().NoError(err) + + err = applier.CreateChangelogTable() + suite.Require().NoError(err) + + g, _ := errgroup.WithContext(ctx) + for i := range suite.concurrentTransactions { + g.Go(func() error { + r := rand.New(rand.NewPCG(uint64(0), uint64(i))) + maxID := int64(1) + for range suite.transactionsPerWorker { + tx, txErr := suite.db.Begin() + if txErr != nil { + return txErr + } + + // generate random write queries + for range r.IntN(suite.transactionSize) + 1 { + switch r.IntN(5) { + case 0: + _, txErr = tx.Exec(fmt.Sprintf("DELETE FROM test.gh_ost_test WHERE id=%d", r.Int64N(maxID))) + if txErr != nil { + return txErr + } + case 1, 2: + _, txErr = tx.Exec(fmt.Sprintf("UPDATE test.gh_ost_test SET name='test-%d' WHERE id=%d", r.Int(), r.Int64N(maxID))) + if txErr != nil { + return txErr + } + default: + res, txErr := tx.Exec(fmt.Sprintf("INSERT INTO test.gh_ost_test (name) VALUES ('test-%d')", r.Int())) + if txErr != nil { + return txErr + } + lastID, err := res.LastInsertId() + if err != nil { + return err + } + maxID = lastID + 1 + } + } 
+ txErr = tx.Commit() + if txErr != nil { + return txErr + } + } + return nil + }) + } + + _, err = applier.WriteChangelogState("completed") + suite.Require().NoError(err) + + ctx, cancel := context.WithCancel(context.Background()) + + coord := NewCoordinator(migrationContext, applier, nil, + func(dmlEvent *binlog.BinlogDMLEvent) error { + fmt.Printf("Received Changelog DML event: %+v\n", dmlEvent) + fmt.Printf("Rowdata: %v - %v\n", dmlEvent.NewColumnValues, dmlEvent.WhereColumnValues) + + cancel() + + return nil + }) + coord.applier = applier + coord.InitializeWorkers(4) + + streamCtx, cancelStreaming := context.WithCancel(context.Background()) + canStopStreaming := func() bool { + return streamCtx.Err() != nil + } + go func() { + streamErr := coord.StartStreaming(streamCtx, &mysql.FileBinlogCoordinates{ + LogFile: "binlog.000001", + LogPos: int64(4), + }, canStopStreaming) + suite.Require().Equal(context.Canceled, streamErr) + }() + + // Give streamer some time to start + time.Sleep(1 * time.Second) + + startAt := time.Now() + + for { + if ctx.Err() != nil { + cancelStreaming() + break + } + + err = coord.ProcessEventsUntilDrained() + suite.Require().NoError(err) + } + + //err = g.Wait() + //suite.Require().NoError(err) + g.Wait() // there will be deadlock errors + + fmt.Printf("Time taken: %s\n", time.Since(startAt)) + + result, err := suite.db.Exec(`SELECT * FROM ( + SELECT t1.id, + CRC32(CONCAT_WS(';',t1.id,t1.name)) AS checksum1, + CRC32(CONCAT_WS(';',t2.id,t2.name)) AS checksum2 + FROM test.gh_ost_test t1 + LEFT JOIN test._gh_ost_test_gho t2 + ON t1.id = t2.id +) AS checksums +WHERE checksums.checksum1 != checksums.checksum2`) + suite.Require().NoError(err) + + count, err := result.RowsAffected() + suite.Require().NoError(err) + suite.Require().Zero(count) +} + +func TestCoordinator(t *testing.T) { + suite.Run(t, new(CoordinatorTestSuite)) +} + +// TestRotationResetsLowWaterMark is a deterministic unit test verifying that +// after a simulated binlog 
rotation the coordinator's lowWaterMark is reset +// so that transactions from the new file are properly ordered. +// This is the regression test for the root cause of the MTR data inconsistency: +// MySQL's logical clock (last_committed, sequence_number) resets per-binlog-file, +// but without resetting lwm, post-rotation transactions with small lastCommitted +// values would pass the WaitForTransaction check against the stale high lwm. +func TestRotationResetsLowWaterMark(t *testing.T) { + // Simulate a coordinator that has processed transactions from the first binlog file. + c := &Coordinator{ + lowWaterMark: -1, + completedJobs: make(map[int64]struct{}), + waitingJobs: make(map[int64][]chan struct{}), + failedCh: make(chan struct{}), + } + + // --- First binlog file: sequence numbers 1..5 --- + + // Initialize lwm (simulates first GTID event setting lwm = seqNo - 1 = 0) + c.mu.Lock() + c.lowWaterMark = 0 + c.mu.Unlock() + + // Complete transactions 1 through 5 + for seq := int64(1); seq <= 5; seq++ { + c.MarkTransactionCompleted(seq, 0, 0) + } + + // Verify lwm advanced to 5 + c.mu.Lock() + if c.lowWaterMark != 5 { + t.Fatalf("expected lwm=5 after completing seqs 1-5, got %d", c.lowWaterMark) + } + c.mu.Unlock() + + // A transaction with lastCommitted=3 should pass immediately (3 <= 5) + ch := c.WaitForTransaction(3) + if ch != nil { + t.Fatal("expected WaitForTransaction(3) to return nil when lwm=5") + } + + // --- Simulate binlog rotation: reset coordinator state --- + // This is what the RotateEvent handler does after draining workers. 
+ c.mu.Lock() + c.lowWaterMark = -1 + c.completedJobs = make(map[int64]struct{}) + c.waitingJobs = make(map[int64][]chan struct{}) + c.mu.Unlock() + + // --- Second binlog file: sequence numbers restart at 1 --- + + // Initialize lwm for new file (first GTID sets lwm = seqNo - 1 = 0) + c.mu.Lock() + c.lowWaterMark = 0 + c.mu.Unlock() + + // BUG SCENARIO (before fix): if lwm was still 5 from the old file, + // WaitForTransaction(3) would return nil → tx executes out of order. + // After fix: lwm=0, so WaitForTransaction(3) must block. + ch = c.WaitForTransaction(3) + if ch == nil { + t.Fatal("expected WaitForTransaction(3) to block when lwm=0 in new binlog file, but it returned nil (stale lwm bug!)") + } + + // Complete transactions 1, 2, 3 in the new file + c.MarkTransactionCompleted(1, 0, 0) + c.MarkTransactionCompleted(2, 0, 0) + c.MarkTransactionCompleted(3, 0, 0) + + // Now the wait channel should be notified (lwm advances to 3) + select { + case <-ch: + // success + case <-time.After(time.Second): + t.Fatal("WaitForTransaction(3) was not notified after completing seqs 1-3") + } + + // Verify lwm is now 3 + c.mu.Lock() + if c.lowWaterMark != 3 { + t.Fatalf("expected lwm=3, got %d", c.lowWaterMark) + } + c.mu.Unlock() +} + +// TestBufferedWaitChannelNoDeadlock verifies that if a waiter exits early +// (e.g., via failedCh), MarkTransactionCompleted does not block forever. +func TestBufferedWaitChannelNoDeadlock(t *testing.T) { + c := &Coordinator{ + lowWaterMark: 0, + completedJobs: make(map[int64]struct{}), + waitingJobs: make(map[int64][]chan struct{}), + failedCh: make(chan struct{}), + } + + // Create a waiter for lastCommitted=3 + ch := c.WaitForTransaction(3) + if ch == nil { + t.Fatal("expected a wait channel") + } + + // Simulate the waiter exiting early (not reading from ch) + // This mimics what happens when a worker exits via failedCh. 
+ + // MarkTransactionCompleted should NOT block even though nobody reads ch + done := make(chan struct{}) + go func() { + c.MarkTransactionCompleted(1, 0, 0) + c.MarkTransactionCompleted(2, 0, 0) + c.MarkTransactionCompleted(3, 0, 0) + close(done) + }() + + select { + case <-done: + // success — did not deadlock + case <-time.After(2 * time.Second): + t.Fatal("MarkTransactionCompleted deadlocked because wait channel is unbuffered") + } +} diff --git a/go/logic/inspect.go b/go/logic/inspect.go index 97895890d..7d1f1a373 100644 --- a/go/logic/inspect.go +++ b/go/logic/inspect.go @@ -19,6 +19,8 @@ import ( "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" + gomysql "github.com/go-mysql-org/go-mysql/mysql" + "github.com/openark/golib/sqlutils" ) @@ -951,6 +953,37 @@ func (this *Inspector) readChangelogState(hint string) (string, error) { return result, err } +// readCurrentBinlogCoordinates reads master status from hooked server +func (this *Inspector) readCurrentBinlogCoordinates() (mysql.BinlogCoordinates, error) { + var coords mysql.BinlogCoordinates + query := fmt.Sprintf(`show /* gh-ost readCurrentBinlogCoordinates */ %s`, mysql.ReplicaTermFor(this.migrationContext.InspectorMySQLVersion, "master status")) + foundMasterStatus := false + err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { + if this.migrationContext.UseGTIDs { + execGtidSet := m.GetString("Executed_Gtid_Set") + gtidSet, err := gomysql.ParseMysqlGTIDSet(execGtidSet) + if err != nil { + return err + } + coords = &mysql.GTIDBinlogCoordinates{GTIDSet: gtidSet.(*gomysql.MysqlGTIDSet)} + } else { + coords = &mysql.FileBinlogCoordinates{ + LogFile: m.GetString("File"), + LogPos: m.GetInt64("Position"), + } + } + foundMasterStatus = true + return nil + }) + if err != nil { + return nil, err + } + if !foundMasterStatus { + return nil, fmt.Errorf("Got no results from SHOW MASTER STATUS. 
Bailing out") + } + return coords, nil +} + func (this *Inspector) getMasterConnectionConfig() (applierConfig *mysql.ConnectionConfig, err error) { this.migrationContext.Log.Infof("Recursively searching for replication master") visitedKeys := mysql.NewInstanceKeyMap() diff --git a/go/logic/migrator.go b/go/logic/migrator.go index e282147ae..c44444a5f 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -11,11 +11,12 @@ import ( "fmt" "io" "math" - "os" "strings" "sync/atomic" "time" + "os" + "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" "github.com/github/gh-ost/go/mysql" @@ -48,23 +49,6 @@ type lockProcessedStruct struct { state string coords mysql.BinlogCoordinates } - -type applyEventStruct struct { - writeFunc *tableWriteFunc - dmlEvent *binlog.BinlogDMLEvent - coords mysql.BinlogCoordinates -} - -func newApplyEventStructByFunc(writeFunc *tableWriteFunc) *applyEventStruct { - result := &applyEventStruct{writeFunc: writeFunc} - return result -} - -func newApplyEventStructByDML(dmlEntry *binlog.BinlogEntry) *applyEventStruct { - result := &applyEventStruct{dmlEvent: dmlEntry.DmlEvent, coords: dmlEntry.Coordinates} - return result -} - type PrintStatusRule int const ( @@ -81,7 +65,6 @@ type Migrator struct { parser *sql.AlterTableParser inspector *Inspector applier *Applier - eventsStreamer *EventsStreamer server *Server throttler *Throttler hooksExecutor *HooksExecutor @@ -96,10 +79,10 @@ type Migrator struct { rowCopyCompleteFlag int64 // copyRowsQueue should not be buffered; if buffered some non-damaging but // excessive work happens at the end of the iteration as new copy-jobs arrive before realizing the copy is complete - copyRowsQueue chan tableWriteFunc - applyEventsQueue chan *applyEventStruct + copyRowsQueue chan tableWriteFunc finishedMigrating int64 + trxCoordinator *Coordinator } func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { @@ -108,13 +91,12 @@ func NewMigrator(context 
*base.MigrationContext, appVersion string) *Migrator { hooksExecutor: NewHooksExecutor(context), migrationContext: context, parser: sql.NewAlterTableParser(), - ghostTableMigrated: make(chan bool), + ghostTableMigrated: make(chan bool, 1), firstThrottlingCollected: make(chan bool, 3), rowCopyComplete: make(chan error), allEventsUpToLockProcessed: make(chan *lockProcessedStruct), copyRowsQueue: make(chan tableWriteFunc), - applyEventsQueue: make(chan *applyEventStruct, base.MaxEventsBatchSize), finishedMigrating: 0, } return migrator @@ -251,20 +233,20 @@ func (this *Migrator) canStopStreaming() bool { } // onChangelogEvent is called when a binlog event operation on the changelog table is intercepted. -func (this *Migrator) onChangelogEvent(dmlEntry *binlog.BinlogEntry) (err error) { +func (this *Migrator) onChangelogEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { // Hey, I created the changelog table, I know the type of columns it has! - switch hint := dmlEntry.DmlEvent.NewColumnValues.StringColumn(2); hint { + switch hint := dmlEvent.NewColumnValues.StringColumn(2); hint { case "state": - return this.onChangelogStateEvent(dmlEntry) + return this.onChangelogStateEvent(dmlEvent) case "heartbeat": - return this.onChangelogHeartbeatEvent(dmlEntry) + return this.onChangelogHeartbeatEvent(dmlEvent) default: return nil } } -func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err error) { - changelogStateString := dmlEntry.DmlEvent.NewColumnValues.StringColumn(3) +func (this *Migrator) onChangelogStateEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { + changelogStateString := dmlEvent.NewColumnValues.StringColumn(3) changelogState := ReadChangelogState(changelogStateString) this.migrationContext.Log.Infof("Intercepted changelog state %s", changelogState) switch changelogState { @@ -274,20 +256,15 @@ func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err e // Use helper to prevent deadlock if migration aborts before 
receiver is ready _ = base.SendWithContext(this.migrationContext.GetContext(), this.ghostTableMigrated, true) case AllEventsUpToLockProcessed: - var applyEventFunc tableWriteFunc = func() error { - return base.SendWithContext(this.migrationContext.GetContext(), this.allEventsUpToLockProcessed, &lockProcessedStruct{ - state: changelogStateString, - coords: dmlEntry.Coordinates.Clone(), - }) - } // at this point we know all events up to lock have been read from the streamer, // because the streamer works sequentially. So those events are either already handled, - // or have event functions in applyEventsQueue. - // So as not to create a potential deadlock, we write this func to applyEventsQueue - // asynchronously, understanding it doesn't really matter. + // or are being processed by the coordinator. + // So as not to create a potential deadlock, we send this asynchronously. go func() { - // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits - _ = base.SendWithContext(this.migrationContext.GetContext(), this.applyEventsQueue, newApplyEventStructByFunc(&applyEventFunc)) + _ = base.SendWithContext(this.migrationContext.GetContext(), this.allEventsUpToLockProcessed, &lockProcessedStruct{ + state: changelogStateString, + coords: this.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates(), + }) }() default: return fmt.Errorf("Unknown changelog state: %+v", changelogState) @@ -296,8 +273,8 @@ func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err e return nil } -func (this *Migrator) onChangelogHeartbeatEvent(dmlEntry *binlog.BinlogEntry) (err error) { - changelogHeartbeatString := dmlEntry.DmlEvent.NewColumnValues.StringColumn(3) +func (this *Migrator) onChangelogHeartbeatEvent(dmlEvent *binlog.BinlogDMLEvent) (err error) { + changelogHeartbeatString := dmlEvent.NewColumnValues.StringColumn(3) heartbeatTime, err := time.Parse(time.RFC3339Nano, changelogHeartbeatString) if err != nil { @@ -305,7 +282,7 @@ func (this 
*Migrator) onChangelogHeartbeatEvent(dmlEntry *binlog.BinlogEntry) (e } else { this.migrationContext.SetLastHeartbeatOnChangelogTime(heartbeatTime) this.applier.CurrentCoordinatesMutex.Lock() - this.applier.CurrentCoordinates = dmlEntry.Coordinates + this.applier.CurrentCoordinates = this.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates() this.applier.CurrentCoordinatesMutex.Unlock() return nil } @@ -449,6 +426,9 @@ func (this *Migrator) Migrate() (err error) { if err := this.initiateInspector(); err != nil { return err } + + this.trxCoordinator = NewCoordinator(this.migrationContext, this.applier, this.throttler, this.onChangelogEvent) + if err := this.checkAbort(); err != nil { return err } @@ -470,6 +450,9 @@ func (this *Migrator) Migrate() (err error) { if err := this.checkAbort(); err != nil { return err } + + this.trxCoordinator.applier = this.applier + if err := this.createFlagFiles(); err != nil { return err } @@ -495,10 +478,28 @@ func (this *Migrator) Migrate() (err error) { } } + this.migrationContext.Log.Infof("starting %d applier workers", this.migrationContext.NumWorkers) + this.trxCoordinator.InitializeWorkers(this.migrationContext.NumWorkers) + initialLag, _ := this.inspector.getReplicationLag() if !this.migrationContext.Resume { this.migrationContext.Log.Infof("Waiting for ghost table to be migrated. Current lag is %+v", initialLag) - <-this.ghostTableMigrated + + waitForGhostTable: + for { + select { + case <-this.ghostTableMigrated: + break waitForGhostTable + default: + dmlEvent, err := this.trxCoordinator.ProcessEventsUntilNextChangelogEvent() + if err != nil { + return err + } + + this.onChangelogEvent(dmlEvent) + } + } + this.migrationContext.Log.Debugf("ghost table migrated") } // Yay! We now know the Ghost and Changelog tables are good to examine! 
@@ -553,9 +554,7 @@ func (this *Migrator) Migrate() (err error) { if err := this.countTableRows(); err != nil { return err } - if err := this.addDMLEventsListener(); err != nil { - return err - } + if err := this.applier.ReadMigrationRangeValues(); err != nil { return err } @@ -589,6 +588,7 @@ func (this *Migrator) Migrate() (err error) { return err } this.printStatus(ForcePrintStatusRule) + this.printWorkerStats() if this.migrationContext.IsCountingTableRows() { this.migrationContext.Log.Info("stopping query for exact row count, because that can accidentally lock out the cut over") @@ -704,14 +704,11 @@ func (this *Migrator) Revert() error { return err } defer this.server.RemoveSocketFile() - if err := this.addDMLEventsListener(); err != nil { - return err - } this.initiateThrottler() go this.initiateStatus() go func() { - if err := this.executeDMLWriteFuncs(); err != nil { + if err := this.executeWriteFuncs(); err != nil { // Send error to PanicAbort to trigger abort _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } @@ -1037,10 +1034,13 @@ func (this *Migrator) atomicCutOver() (err error) { // initiateServer begins listening on unix socket/tcp for incoming interactive commands func (this *Migrator) initiateServer() (err error) { - var f printStatusFunc = func(rule PrintStatusRule, writer io.Writer) { + var printStatus printStatusFunc = func(rule PrintStatusRule, writer io.Writer) { this.printStatus(rule, writer) } - this.server = NewServer(this.migrationContext, this.hooksExecutor, f) + var printWorkers printWorkersFunc = func(writer io.Writer) { + this.printWorkerStats(writer) + } + this.server = NewServer(this.migrationContext, this.hooksExecutor, printStatus, printWorkers) if err := this.server.BindSocketFile(); err != nil { return err } @@ -1318,6 +1318,29 @@ func (this *Migrator) shouldPrintMigrationStatusHint(rule PrintStatusRule, elaps return shouldPrint } +// printWorkerStats prints cumulative stats from 
the trxCoordinator workers. +func (this *Migrator) printWorkerStats(writers ...io.Writer) { + writers = append(writers, os.Stdout) + mw := io.MultiWriter(writers...) + + busyWorkers := this.trxCoordinator.busyWorkers.Load() + totalWorkers := cap(this.trxCoordinator.workerQueue) + fmt.Fprintf(mw, "# %d/%d workers are busy\n", busyWorkers, totalWorkers) + + stats := this.trxCoordinator.GetWorkerStats() + for id, stat := range stats { + fmt.Fprintf(mw, + "Worker %d; Waited: %s; Busy: %s; DML Applied: %d (%.2f/s), Trx Applied: %d (%.2f/s)\n", + id, + base.PrettifyDurationOutput(stat.waitTime), + base.PrettifyDurationOutput(stat.busyTime), + stat.dmlEventsApplied, + stat.dmlRate, + stat.executedJobs, + stat.trxRate) + } +} + // printStatus prints the progress status, and optionally additionally detailed // dump of configuration. // `rule` indicates the type of output expected. @@ -1356,12 +1379,12 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { return } - currentBinlogCoordinates := this.eventsStreamer.GetCurrentBinlogCoordinates() + currentBinlogCoordinates := this.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates() status := fmt.Sprintf("Copy: %d/%d %.1f%%; Applied: %d; Backlog: %d/%d; Time: %+v(total), %+v(copy); streamer: %+v; Lag: %.2fs, HeartbeatLag: %.2fs, State: %s; ETA: %s", totalRowsCopied, rowsEstimate, progressPct, atomic.LoadInt64(&this.migrationContext.TotalDMLEventsApplied), - len(this.applyEventsQueue), cap(this.applyEventsQueue), + len(this.trxCoordinator.events), cap(this.trxCoordinator.events), base.PrettifyDurationOutput(elapsedTime), base.PrettifyDurationOutput(this.migrationContext.ElapsedRowCopyTime()), currentBinlogCoordinates.DisplayString(), this.migrationContext.GetCurrentLagDuration().Seconds(), @@ -1392,22 +1415,15 @@ func (this *Migrator) printStatus(rule PrintStatusRule, writers ...io.Writer) { // initiateStreaming begins streaming of binary log events and registers listeners for such events func 
(this *Migrator) initiateStreaming() error { - this.eventsStreamer = NewEventsStreamer(this.migrationContext) - if err := this.eventsStreamer.InitDBConnections(); err != nil { + initialCoords, err := this.inspector.readCurrentBinlogCoordinates() + if err != nil { return err } - this.eventsStreamer.AddListener( - false, - this.migrationContext.DatabaseName, - this.migrationContext.GetChangelogTableName(), - func(dmlEntry *binlog.BinlogEntry) error { - return this.onChangelogEvent(dmlEntry) - }, - ) go func() { - this.migrationContext.Log.Debugf("Beginning streaming") - err := this.eventsStreamer.StreamEvents(this.canStopStreaming) + this.migrationContext.Log.Debugf("Beginning streaming at coordinates: %+v", initialCoords) + ctx := context.TODO() + err := this.trxCoordinator.StartStreaming(ctx, initialCoords, this.canStopStreaming) if err != nil { // Use helper to prevent deadlock if listenOnPanicAbort already exited _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) @@ -1422,28 +1438,12 @@ func (this *Migrator) initiateStreaming() error { if atomic.LoadInt64(&this.finishedMigrating) > 0 { return } - this.migrationContext.SetRecentBinlogCoordinates(this.eventsStreamer.GetCurrentBinlogCoordinates()) + this.migrationContext.SetRecentBinlogCoordinates(this.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates()) } }() return nil } -// addDMLEventsListener begins listening for binlog events on the original table, -// and creates & enqueues a write task per such event. 
-func (this *Migrator) addDMLEventsListener() error { - err := this.eventsStreamer.AddListener( - false, - this.migrationContext.DatabaseName, - this.migrationContext.OriginalTableName, - func(dmlEntry *binlog.BinlogEntry) error { - // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits - // This is critical because this callback blocks the event streamer - return base.SendWithContext(this.migrationContext.GetContext(), this.applyEventsQueue, newApplyEventStructByDML(dmlEntry)) - }, - ) - return err -} - // initiateThrottler kicks in the throttling collection and the throttling checks. func (this *Migrator) initiateThrottler() { this.throttler = NewThrottler(this.migrationContext, this.applier, this.inspector, this.appVersion) @@ -1459,7 +1459,7 @@ func (this *Migrator) initiateThrottler() { func (this *Migrator) initiateApplier() error { this.applier = NewApplier(this.migrationContext) - if err := this.applier.InitDBConnections(); err != nil { + if err := this.applier.InitDBConnections(this.migrationContext.NumWorkers); err != nil { return err } if this.migrationContext.Revert { @@ -1614,67 +1614,11 @@ func (this *Migrator) iterateChunks() error { } } -func (this *Migrator) onApplyEventStruct(eventStruct *applyEventStruct) error { - handleNonDMLEventStruct := func(eventStruct *applyEventStruct) error { - if eventStruct.writeFunc != nil { - if err := this.retryOperation(*eventStruct.writeFunc); err != nil { - return this.migrationContext.Log.Errore(err) - } - } - return nil - } - if eventStruct.dmlEvent == nil { - return handleNonDMLEventStruct(eventStruct) - } - if eventStruct.dmlEvent != nil { - dmlEvents := [](*binlog.BinlogDMLEvent){} - dmlEvents = append(dmlEvents, eventStruct.dmlEvent) - var nonDmlStructToApply *applyEventStruct - - availableEvents := len(this.applyEventsQueue) - batchSize := int(atomic.LoadInt64(&this.migrationContext.DMLBatchSize)) - if availableEvents > batchSize-1 { - // The "- 1" is because we already consumed one 
event: the original event that led to this function getting called. - // So, if DMLBatchSize==1 we wish to not process any further events - availableEvents = batchSize - 1 - } - for i := 0; i < availableEvents; i++ { - additionalStruct := <-this.applyEventsQueue - if additionalStruct.dmlEvent == nil { - // Not a DML. We don't group this, and we don't batch any further - nonDmlStructToApply = additionalStruct - break - } - dmlEvents = append(dmlEvents, additionalStruct.dmlEvent) - } - // Create a task to apply the DML event; this will be execute by executeWriteFuncs() - var applyEventFunc tableWriteFunc = func() error { - return this.applier.ApplyDMLEventQueries(dmlEvents) - } - if err := this.retryOperation(applyEventFunc); err != nil { - return this.migrationContext.Log.Errore(err) - } - // update applier coordinates - this.applier.CurrentCoordinatesMutex.Lock() - this.applier.CurrentCoordinates = eventStruct.coords - this.applier.CurrentCoordinatesMutex.Unlock() - - if nonDmlStructToApply != nil { - // We pulled DML events from the queue, and then we hit a non-DML event. Wait! - // We need to handle it! - if err := handleNonDMLEventStruct(nonDmlStructToApply); err != nil { - return this.migrationContext.Log.Errore(err) - } - } - } - return nil -} - // Checkpoint attempts to write a checkpoint of the Migrator's current state. // It gets the binlog coordinates of the last received trx and waits until the // applier reaches that trx. At that point it's safe to resume from these coordinates. 
func (this *Migrator) Checkpoint(ctx context.Context) (*Checkpoint, error) { - coords := this.eventsStreamer.GetCurrentBinlogCoordinates() + coords := this.trxCoordinator.binlogReader.GetCurrentBinlogCoordinates() this.applier.LastIterationRangeMutex.Lock() if this.applier.LastIterationRangeMaxValues == nil || this.applier.LastIterationRangeMinValues == nil { this.applier.LastIterationRangeMutex.Unlock() @@ -1774,6 +1718,7 @@ func (this *Migrator) executeWriteFuncs() error { this.migrationContext.Log.Debugf("Noop operation; not really executing write funcs") return nil } + for { if err := this.checkAbort(); err != nil { return err @@ -1784,66 +1729,37 @@ func (this *Migrator) executeWriteFuncs() error { this.throttler.throttle(nil) - // We give higher priority to event processing, then secondary priority to - // rowcopy - select { - case eventStruct := <-this.applyEventsQueue: - { - if err := this.onApplyEventStruct(eventStruct); err != nil { - return err - } - } - default: - { - select { - case copyRowsFunc := <-this.copyRowsQueue: - { - copyRowsStartTime := time.Now() - // Retries are handled within the copyRowsFunc - if err := copyRowsFunc(); err != nil { - return this.migrationContext.Log.Errore(err) - } - if niceRatio := this.migrationContext.GetNiceRatio(); niceRatio > 0 { - copyRowsDuration := time.Since(copyRowsStartTime) - sleepTimeNanosecondFloat64 := niceRatio * float64(copyRowsDuration.Nanoseconds()) - sleepTime := time.Duration(int64(sleepTimeNanosecondFloat64)) * time.Nanosecond - time.Sleep(sleepTime) - } - } - default: - { - // Hmmmmm... nothing in the queue; no events, but also no row copy. - // This is possible upon load. Let's just sleep it over. - this.migrationContext.Log.Debugf("Getting nothing in the write queue. 
Sleeping...") - time.Sleep(time.Second) - } - } - } - } - } -} - -func (this *Migrator) executeDMLWriteFuncs() error { - if this.migrationContext.Noop { - this.migrationContext.Log.Debugf("Noop operation; not really executing DML write funcs") - return nil - } - for { - if atomic.LoadInt64(&this.finishedMigrating) > 0 { - return nil + // We give higher priority to event processing. + // ProcessEventsUntilDrained will process all events in the queue, and then return once no more events are available. + if err := this.trxCoordinator.ProcessEventsUntilDrained(); err != nil { + return this.migrationContext.Log.Errore(err) } this.throttler.throttle(nil) + // And secondary priority to rowcopy select { - case eventStruct := <-this.applyEventsQueue: + case copyRowsFunc := <-this.copyRowsQueue: { - if err := this.onApplyEventStruct(eventStruct); err != nil { - return err + copyRowsStartTime := time.Now() + // Retries are handled within the copyRowsFunc + if err := copyRowsFunc(); err != nil { + return this.migrationContext.Log.Errore(err) + } + if niceRatio := this.migrationContext.GetNiceRatio(); niceRatio > 0 { + copyRowsDuration := time.Since(copyRowsStartTime) + sleepTimeNanosecondFloat64 := niceRatio * float64(copyRowsDuration.Nanoseconds()) + sleepTime := time.Duration(int64(sleepTimeNanosecondFloat64)) * time.Nanosecond + time.Sleep(sleepTime) } } - case <-time.After(time.Second): - continue + default: + { + // Hmmmmm... nothing in the queue; no events, but also no row copy. + // This is possible upon load. Let's just sleep it over. + this.migrationContext.Log.Debugf("Getting nothing in the write queue. 
Sleeping...") + time.Sleep(time.Second) + } } } } @@ -1865,10 +1781,6 @@ func (this *Migrator) finalCleanup() error { this.migrationContext.Log.Errore(err) } } - if err := this.eventsStreamer.Close(); err != nil { - this.migrationContext.Log.Errore(err) - } - if err := this.retryOperation(this.applier.DropChangelogTable); err != nil { return err } @@ -1899,6 +1811,16 @@ func (this *Migrator) finalCleanup() error { func (this *Migrator) teardown() { atomic.StoreInt64(&this.finishedMigrating, 1) + if this.trxCoordinator != nil { + this.migrationContext.Log.Infof("Tearing down coordinator") + this.trxCoordinator.Teardown() + } + + if this.throttler != nil { + this.migrationContext.Log.Infof("Tearing down throttler") + this.throttler.Teardown() + } + if this.inspector != nil { this.migrationContext.Log.Infof("Tearing down inspector") this.inspector.Teardown() @@ -1908,14 +1830,4 @@ func (this *Migrator) teardown() { this.migrationContext.Log.Infof("Tearing down applier") this.applier.Teardown() } - - if this.eventsStreamer != nil { - this.migrationContext.Log.Infof("Tearing down streamer") - this.eventsStreamer.Teardown() - } - - if this.throttler != nil { - this.migrationContext.Log.Infof("Tearing down throttler") - this.throttler.Teardown() - } } diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go deleted file mode 100644 index f731035e1..000000000 --- a/go/logic/migrator_test.go +++ /dev/null @@ -1,1423 +0,0 @@ -/* - Copyright 2022 GitHub Inc. 
- See https://github.com/github/gh-ost/blob/master/LICENSE -*/ - -package logic - -import ( - "bytes" - "context" - gosql "database/sql" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" - - "runtime" - - "github.com/github/gh-ost/go/base" - "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" - "github.com/github/gh-ost/go/sql" - "github.com/testcontainers/testcontainers-go" -) - -func TestMigratorOnChangelogEvent(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - migrator.applier = NewApplier(migrationContext) - - t.Run("heartbeat", func(t *testing.T) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "heartbeat", - "2022-08-16T00:45:10.52Z", - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) - - t.Run("state-AllEventsUpToLockProcessed", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - go func(wg *sync.WaitGroup) { - defer wg.Done() - es := <-migrator.applyEventsQueue - require.NotNil(t, es) - require.NotNil(t, es.writeFunc) - }(&wg) - - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - AllEventsUpToLockProcessed, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - wg.Wait() - }) - - 
t.Run("state-GhostTableMigrated", func(t *testing.T) { - go func() { - require.True(t, <-migrator.ghostTableMigrated) - }() - - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - GhostTableMigrated, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) - - t.Run("state-Migrated", func(t *testing.T) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - Migrated, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) - - t.Run("state-ReadMigrationRangeValues", func(t *testing.T) { - columnValues := sql.ToColumnValues([]interface{}{ - 123, - time.Now().Unix(), - "state", - ReadMigrationRangeValues, - }) - require.Nil(t, migrator.onChangelogEvent(&binlog.BinlogEntry{ - DmlEvent: &binlog.BinlogDMLEvent{ - DatabaseName: "test", - DML: binlog.InsertDML, - NewColumnValues: columnValues}, - Coordinates: mysql.NewFileBinlogCoordinates("mysql-bin.000004", int64(4)), - })) - }) -} - -func TestMigratorValidateStatement(t *testing.T) { - t.Run("add-column", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test ADD test_new VARCHAR(64) NOT NULL`)) - - require.Nil(t, migrator.validateAlterStatement()) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) - - t.Run("drop-column", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - 
require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test DROP abc`)) - - require.Nil(t, migrator.validateAlterStatement()) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 1) - _, exists := migrator.migrationContext.DroppedColumnsMap["abc"] - require.True(t, exists) - }) - - t.Run("rename-column", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) - - err := migrator.validateAlterStatement() - require.Error(t, err) - require.True(t, strings.HasPrefix(err.Error(), "gh-ost believes the ALTER statement renames columns")) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) - - t.Run("rename-column-approved", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - migrator.migrationContext.ApproveRenamedColumns = true - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test CHANGE test123 test1234 bigint unsigned`)) - - require.Nil(t, migrator.validateAlterStatement()) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) - - t.Run("rename-table", func(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.parser.ParseAlterStatement(`ALTER TABLE test RENAME TO test_new`)) - - err := migrator.validateAlterStatement() - require.Error(t, err) - require.True(t, errors.Is(err, ErrMigratorUnsupportedRenameAlter)) - require.Len(t, migrator.migrationContext.DroppedColumnsMap, 0) - }) -} - -func TestMigratorCreateFlagFiles(t *testing.T) { - tmpdir, err := os.MkdirTemp("", t.Name()) - if err != nil { - panic(err) - } - defer os.RemoveAll(tmpdir) - - migrationContext := base.NewMigrationContext() - migrationContext.PostponeCutOverFlagFile = filepath.Join(tmpdir, 
"cut-over.flag") - migrator := NewMigrator(migrationContext, "1.2.3") - require.Nil(t, migrator.createFlagFiles()) - require.Nil(t, migrator.createFlagFiles()) // twice to test already-exists - - _, err = os.Stat(migrationContext.PostponeCutOverFlagFile) - require.NoError(t, err) -} - -func TestMigratorGetProgressPercent(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - - { - require.Equal(t, float64(100.0), migrator.getProgressPercent(0)) - } - { - migrationContext.TotalRowsCopied = 250 - require.Equal(t, float64(25.0), migrator.getProgressPercent(1000)) - } -} - -func TestMigratorGetMigrationStateAndETA(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - now := time.Now() - migrationContext.RowCopyStartTime = now.Add(-time.Minute) - migrationContext.RowCopyEndTime = now - - { - migrationContext.TotalRowsCopied = 456 - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "migrating", state) - require.Equal(t, "4h29m44s", eta) - require.Equal(t, "4h29m44s", etaDuration.String()) - } - { - // Test using rows-per-second added data. 
- migrationContext.TotalRowsCopied = 456 - migrationContext.EtaRowsPerSecond = 100 - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "migrating", state) - require.Equal(t, "20m30s", eta) - require.Equal(t, "20m30s", etaDuration.String()) - } - { - migrationContext.TotalRowsCopied = 456 - state, eta, etaDuration := migrator.getMigrationStateAndETA(456) - require.Equal(t, "migrating", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } - { - migrationContext.TotalRowsCopied = 123456 - state, eta, etaDuration := migrator.getMigrationStateAndETA(456) - require.Equal(t, "migrating", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } - { - atomic.StoreInt64(&migrationContext.CountingRowsFlag, 1) - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "counting rows", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } - { - atomic.StoreInt64(&migrationContext.CountingRowsFlag, 0) - atomic.StoreInt64(&migrationContext.IsPostponingCutOver, 1) - state, eta, etaDuration := migrator.getMigrationStateAndETA(123456) - require.Equal(t, "postponing cut-over", state) - require.Equal(t, "due", eta) - require.Equal(t, "0s", etaDuration.String()) - } -} - -func TestMigratorShouldPrintStatus(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.2.3") - - require.True(t, migrator.shouldPrintStatus(NoPrintStatusRule, 10, time.Second)) // test 'rule != HeuristicPrintStatusRule' return - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 10, time.Second)) // test 'etaDuration.Seconds() <= 60' - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 90, time.Second)) // test 'etaDuration.Seconds() <= 60' again - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 90, time.Minute)) // test 
'etaDuration.Seconds() <= 180' - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 60, 90*time.Second)) // test 'elapsedSeconds <= 180' - require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 61, 90*time.Second)) // test 'elapsedSeconds <= 180' - require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 99, 210*time.Second)) // test 'elapsedSeconds <= 180' - require.False(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 12345, 86400*time.Second)) // test 'else' - require.True(t, migrator.shouldPrintStatus(HeuristicPrintStatusRule, 30030, 86400*time.Second)) // test 'else' again -} - -type MigratorTestSuite struct { - suite.Suite - - mysqlContainer testcontainers.Container - db *gosql.DB -} - -func (suite *MigratorTestSuite) SetupSuite() { - ctx := context.Background() - mysqlContainer, err := testmysql.Run(ctx, - testMysqlContainerImage, - testmysql.WithDatabase(testMysqlDatabase), - testmysql.WithUsername(testMysqlUser), - testmysql.WithPassword(testMysqlPass), - testmysql.WithConfigFile("my.cnf.test"), - ) - suite.Require().NoError(err) - - suite.mysqlContainer = mysqlContainer - dsn, err := mysqlContainer.ConnectionString(ctx) - suite.Require().NoError(err) - - db, err := gosql.Open("mysql", dsn) - suite.Require().NoError(err) - - suite.db = db -} - -func (suite *MigratorTestSuite) TeardownSuite() { - suite.Assert().NoError(suite.db.Close()) - suite.Assert().NoError(testcontainers.TerminateContainer(suite.mysqlContainer)) -} - -func (suite *MigratorTestSuite) SetupTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+testMysqlDatabase) - suite.Require().NoError(err) - - os.Remove("/tmp/gh-ost.sock") -} - -func (suite *MigratorTestSuite) TearDownTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF 
EXISTS "+getTestGhostTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestRevertedTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestOldTableName()) - suite.Require().NoError(err) -} - -func (suite *MigratorTestSuite) TestMigrateEmpty() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(64))", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.InitiallyDropOldTable = true - - migrationContext.AlterStatementOptions = "ADD COLUMN foobar varchar(255), ENGINE=InnoDB" - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - suite.Require().NoError(err) - - // Verify the new column was added - var tableName, createTableSQL string - //nolint:execinquery - err = suite.db.QueryRow("SHOW CREATE TABLE "+getTestTableName()).Scan(&tableName, &createTableSQL) - suite.Require().NoError(err) - - suite.Require().Equal("testing", tableName) - suite.Require().Equal("CREATE TABLE `testing` (\n `id` int NOT NULL,\n `name` varchar(64) DEFAULT NULL,\n `foobar` varchar(255) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createTableSQL) - - // Verify the changelog table was claned up - //nolint:execinquery - err = suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_ghc'").Scan(&tableName) - suite.Require().Error(err) - suite.Require().Equal(gosql.ErrNoRows, err) - - // Verify the old table was renamed - //nolint:execinquery - err = 
suite.db.QueryRow("SHOW TABLES IN test LIKE '_testing_del'").Scan(&tableName) - suite.Require().NoError(err) - suite.Require().Equal("_testing_del", tableName) -} - -func (suite *MigratorTestSuite) TestRetryBatchCopyWithHooks() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "CREATE TABLE test.test_retry_batch (id INT PRIMARY KEY AUTO_INCREMENT, name TEXT)") - suite.Require().NoError(err) - - const initStride = 1000 - const totalBatches = 3 - for i := 0; i < totalBatches; i++ { - dataSize := 50 * i - for j := 0; j < initStride; j++ { - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO test.test_retry_batch (name) VALUES ('%s')", strings.Repeat("a", dataSize))) - suite.Require().NoError(err) - } - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET GLOBAL max_binlog_cache_size = %d", 1024*8)) - suite.Require().NoError(err) - defer func() { - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("SET GLOBAL max_binlog_cache_size = %d", 1024*1024*1024)) - suite.Require().NoError(err) - }() - - tmpDir, err := os.MkdirTemp("", "gh-ost-hooks") - suite.Require().NoError(err) - defer os.RemoveAll(tmpDir) - - hookScript := filepath.Join(tmpDir, "gh-ost-on-batch-copy-retry") - hookContent := `#!/bin/bash -# Mock hook that reduces chunk size on binlog cache error -ERROR_MSG="$GH_OST_LAST_BATCH_COPY_ERROR" -SOCKET_PATH="/tmp/gh-ost.sock" - -if ! [[ "$ERROR_MSG" =~ "max_binlog_cache_size" ]]; then - echo "Nothing to do for error: $ERROR_MSG" - exit 0 -fi - -CHUNK_SIZE=$(echo "chunk-size=?" 
| nc -U $SOCKET_PATH | tr -d '\n') - -MIN_CHUNK_SIZE=10 -NEW_CHUNK_SIZE=$(( CHUNK_SIZE * 8 / 10 )) -if [ $NEW_CHUNK_SIZE -lt $MIN_CHUNK_SIZE ]; then - NEW_CHUNK_SIZE=$MIN_CHUNK_SIZE -fi - -if [ $CHUNK_SIZE -eq $NEW_CHUNK_SIZE ]; then - echo "Chunk size unchanged: $CHUNK_SIZE" - exit 0 -fi - -echo "[gh-ost-on-batch-copy-retry]: Changing chunk size from $CHUNK_SIZE to $NEW_CHUNK_SIZE" -echo "chunk-size=$NEW_CHUNK_SIZE" | nc -U $SOCKET_PATH -echo "[gh-ost-on-batch-copy-retry]: Done, exiting..." -` - err = os.WriteFile(hookScript, []byte(hookContent), 0755) - suite.Require().NoError(err) - - origStdout := os.Stdout - origStderr := os.Stderr - - rOut, wOut, _ := os.Pipe() - rErr, wErr, _ := os.Pipe() - os.Stdout = wOut - os.Stderr = wErr - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := base.NewMigrationContext() - migrationContext.AllowedRunningOnMaster = true - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.DatabaseName = "test" - migrationContext.SkipPortValidation = true - migrationContext.OriginalTableName = "test_retry_batch" - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatementOptions = "MODIFY name LONGTEXT, ENGINE=InnoDB" - migrationContext.ReplicaServerId = 99999 - migrationContext.HeartbeatIntervalMilliseconds = 100 - migrationContext.ThrottleHTTPIntervalMillis = 100 - migrationContext.ThrottleHTTPTimeoutMillis = 1000 - migrationContext.HooksPath = tmpDir - migrationContext.ChunkSize = 1000 - migrationContext.SetDefaultNumRetries(10) - migrationContext.ServeSocketFile = "/tmp/gh-ost.sock" - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - suite.Require().NoError(err) - - wOut.Close() - wErr.Close() - os.Stdout = origStdout - os.Stderr = origStderr - - var bufOut, bufErr bytes.Buffer - io.Copy(&bufOut, rOut) - 
io.Copy(&bufErr, rErr) - - outStr := bufOut.String() - errStr := bufErr.String() - - suite.Assert().Contains(outStr, "chunk-size: 1000") - suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 1000 to 800") - suite.Assert().Contains(outStr, "chunk-size: 800") - - suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 800 to 640") - suite.Assert().Contains(outStr, "chunk-size: 640") - - suite.Assert().Contains(errStr, "[gh-ost-on-batch-copy-retry]: Changing chunk size from 640 to 512") - suite.Assert().Contains(outStr, "chunk-size: 512") - - var count int - err = suite.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM test.test_retry_batch").Scan(&count) - suite.Require().NoError(err) - suite.Assert().Equal(3000, count) -} - -func (suite *MigratorTestSuite) TestCopierIntPK() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(64), age INT);", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - migrationContext.AlterStatementOptions = "ENGINE=InnoDB" - migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "name", "age"}) - migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "name", "age"}) - migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "name", "age"}) - migrationContext.UniqueKey = &sql.UniqueKey{ - Name: "PRIMARY", - NameInGhostTable: "PRIMARY", - Columns: *sql.NewColumnList([]string{"id"}), - } - - chunkSize := int64(73) - migrationContext.ChunkSize = chunkSize - - // fill with some rows - numRows := int64(3421) - for i := range 
numRows { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (id, name, age) VALUES (%d, 'user-%d', %d)", getTestTableName(), i, i, i%99)) - suite.Require().NoError(err) - } - - migrator := NewMigrator(migrationContext, "0.0.0") - suite.Require().NoError(migrator.initiateApplier()) - suite.Require().NoError(migrator.applier.prepareQueries()) - suite.Require().NoError(migrator.applier.ReadMigrationRangeValues()) - - go migrator.iterateChunks() - go func() { - if err := <-migrator.rowCopyComplete; err != nil { - migrator.migrationContext.PanicAbort <- err - } - atomic.StoreInt64(&migrator.rowCopyCompleteFlag, 1) - }() - - for { - if atomic.LoadInt64(&migrator.rowCopyCompleteFlag) == 1 { - suite.Assert().Equal((numRows/chunkSize)+1, migrator.migrationContext.GetIteration()) - return - } - select { - case copyRowsFunc := <-migrator.copyRowsQueue: - { - suite.Require().NoError(copyRowsFunc()) - - // check ghost table has expected number of rows - var ghostRows int64 - suite.db.QueryRowContext(ctx, - fmt.Sprintf(`SELECT COUNT(*) FROM %s`, getTestGhostTableName()), - ).Scan(&ghostRows) - suite.Assert().Equal(migrator.migrationContext.TotalRowsCopied, ghostRows) - } - default: - time.Sleep(time.Second) - } - } -} - -func (suite *MigratorTestSuite) TestCopierCompositePK() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT UNSIGNED, t CHAR(32), PRIMARY KEY (t, id));", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - migrationContext.AlterStatementOptions = "ENGINE=InnoDB" - migrationContext.OriginalTableColumns = sql.NewColumnList([]string{"id", "t"}) - 
migrationContext.SharedColumns = sql.NewColumnList([]string{"id", "t"}) - migrationContext.MappedSharedColumns = sql.NewColumnList([]string{"id", "t"}) - migrationContext.UniqueKey = &sql.UniqueKey{ - Name: "PRIMARY", - NameInGhostTable: "PRIMARY", - Columns: *sql.NewColumnList([]string{"t", "id"}), - } - - chunkSize := int64(100) - migrationContext.ChunkSize = chunkSize - - // fill with some rows - numRows := int64(2049) - for i := range numRows { - query := fmt.Sprintf(`INSERT INTO %s (id, t) VALUES (FLOOR(100000000 * RAND(%d)), MD5(RAND(%d)))`, getTestTableName(), i, i) - _, err = suite.db.ExecContext(ctx, query) - suite.Require().NoError(err) - } - - migrator := NewMigrator(migrationContext, "0.0.0") - suite.Require().NoError(migrator.initiateApplier()) - suite.Require().NoError(migrator.applier.prepareQueries()) - suite.Require().NoError(migrator.applier.ReadMigrationRangeValues()) - - go migrator.iterateChunks() - go func() { - if err := <-migrator.rowCopyComplete; err != nil { - migrator.migrationContext.PanicAbort <- err - } - atomic.StoreInt64(&migrator.rowCopyCompleteFlag, 1) - }() - - for { - if atomic.LoadInt64(&migrator.rowCopyCompleteFlag) == 1 { - suite.Assert().Equal((numRows/chunkSize)+1, migrator.migrationContext.GetIteration()) - return - } - select { - case copyRowsFunc := <-migrator.copyRowsQueue: - { - suite.Require().NoError(copyRowsFunc()) - - // check ghost table has expected number of rows - var ghostRows int64 - suite.db.QueryRowContext(ctx, - fmt.Sprintf(`SELECT COUNT(*) FROM %s`, getTestGhostTableName()), - ).Scan(&ghostRows) - suite.Assert().Equal(migrator.migrationContext.TotalRowsCopied, ghostRows) - } - default: - time.Sleep(time.Second) - } - } -} - -func TestMigratorRetry(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrator := NewMigrator(migrationContext, "1.2.3") - - var 
sleeps = 0 - RetrySleepFn = func(duration time.Duration) { - assert.Equal(t, 1*time.Second, duration) - sleeps++ - } - - var tries = 0 - retryable := func() error { - tries++ - if tries < int(migrationContext.MaxRetries()) { - return errors.New("Backoff") - } - return nil - } - - result := migrator.retryOperation(retryable, false) - assert.NoError(t, result) - assert.Equal(t, sleeps, 99) - assert.Equal(t, tries, 100) -} - -func TestMigratorRetryWithExponentialBackoff(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrationContext.SetExponentialBackoffMaxInterval(42) - migrator := NewMigrator(migrationContext, "1.2.3") - - var sleeps = 0 - expected := []int{ - 1, 2, 4, 8, 16, 32, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, - 42, 42, 42, 42, 42, 42, - } - RetrySleepFn = func(duration time.Duration) { - assert.Equal(t, time.Duration(expected[sleeps])*time.Second, duration) - sleeps++ - } - - var tries = 0 - retryable := func() error { - tries++ - if tries < int(migrationContext.MaxRetries()) { - return errors.New("Backoff") - } - return nil - } - - result := migrator.retryOperationWithExponentialBackoff(retryable, false) - assert.NoError(t, result) - assert.Equal(t, sleeps, 99) - assert.Equal(t, tries, 100) -} - -func TestMigratorRetryAbortsOnContextCancellation(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrator := NewMigrator(migrationContext, "1.2.3") - - 
RetrySleepFn = func(duration time.Duration) { - // No sleep needed for this test - } - - var tries = 0 - retryable := func() error { - tries++ - if tries == 5 { - // Cancel context on 5th try - migrationContext.CancelContext() - } - return errors.New("Simulated error") - } - - result := migrator.retryOperation(retryable, false) - assert.Error(t, result) - // Should abort after 6 tries: 5 failures + 1 checkAbort detection - assert.True(t, tries <= 6, "Expected tries <= 6, got %d", tries) - // Verify we got context cancellation error - assert.Contains(t, result.Error(), "context canceled") -} - -func TestMigratorRetryWithExponentialBackoffAbortsOnContextCancellation(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrationContext.SetExponentialBackoffMaxInterval(42) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - // No sleep needed for this test - } - - var tries = 0 - retryable := func() error { - tries++ - if tries == 5 { - // Cancel context on 5th try - migrationContext.CancelContext() - } - return errors.New("Simulated error") - } - - result := migrator.retryOperationWithExponentialBackoff(retryable, false) - assert.Error(t, result) - // Should abort after 6 tries: 5 failures + 1 checkAbort detection - assert.True(t, tries <= 6, "Expected tries <= 6, got %d", tries) - // Verify we got context cancellation error - assert.Contains(t, result.Error(), "context canceled") -} - -func TestMigratorRetrySkipsRetriesForWarnings(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - t.Fatal("Should not sleep/retry for warning 
errors") - } - - var tries = 0 - retryable := func() error { - tries++ - return errors.New("warnings detected in statement 1 of 1: [Warning: Duplicate entry 'test' for key 'idx' (1062)]") - } - - result := migrator.retryOperation(retryable, false) - assert.Error(t, result) - // Should only try once - no retries for warnings - assert.Equal(t, 1, tries, "Expected exactly 1 try (no retries) for warning error") - assert.Contains(t, result.Error(), "warnings detected") -} - -func TestMigratorRetryWithExponentialBackoffSkipsRetriesForWarnings(t *testing.T) { - oldRetrySleepFn := RetrySleepFn - defer func() { RetrySleepFn = oldRetrySleepFn }() - - migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(100) - migrationContext.SetExponentialBackoffMaxInterval(42) - migrator := NewMigrator(migrationContext, "1.2.3") - - RetrySleepFn = func(duration time.Duration) { - t.Fatal("Should not sleep/retry for warning errors") - } - - var tries = 0 - retryable := func() error { - tries++ - return errors.New("warnings detected in statement 1 of 1: [Warning: Duplicate entry 'test' for key 'idx' (1062)]") - } - - result := migrator.retryOperationWithExponentialBackoff(retryable, false) - assert.Error(t, result) - // Should only try once - no retries for warnings - assert.Equal(t, 1, tries, "Expected exactly 1 try (no retries) for warning error") - assert.Contains(t, result.Error(), "warnings detected") -} - -func (suite *MigratorTestSuite) TestCutOverLossDataCaseLockGhostBeforeRename() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(64))", getTestTableName())) - suite.Require().NoError(err) - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("insert into %s values(1,'a')", getTestTableName())) - suite.Require().NoError(err) - - done := make(chan error, 1) - go func() { - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - if err != nil { - done 
<- err - return - } - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AllowSetupMetadataLockInstruments = true - migrationContext.AlterStatementOptions = "ADD COLUMN foobar varchar(255)" - migrationContext.HeartbeatIntervalMilliseconds = 100 - migrationContext.CutOverLockTimeoutSeconds = 4 - - _, filename, _, _ := runtime.Caller(0) - migrationContext.PostponeCutOverFlagFile = filepath.Join(filepath.Dir(filename), "../../tmp/ghost.postpone.flag") - - migrator := NewMigrator(migrationContext, "0.0.0") - - //nolint:contextcheck - done <- migrator.Migrate() - }() - - time.Sleep(2 * time.Second) - //nolint:dogsled - _, filename, _, _ := runtime.Caller(0) - err = os.Remove(filepath.Join(filepath.Dir(filename), "../../tmp/ghost.postpone.flag")) - if err != nil { - suite.Require().NoError(err) - } - time.Sleep(1 * time.Second) - go func() { - holdConn, err := suite.db.Conn(ctx) - suite.Require().NoError(err) - _, err = holdConn.ExecContext(ctx, "SELECT *, sleep(2) FROM test._testing_gho WHERE id = 1") - suite.Require().NoError(err) - }() - - dmlConn, err := suite.db.Conn(ctx) - suite.Require().NoError(err) - - _, err = dmlConn.ExecContext(ctx, fmt.Sprintf("insert into %s (id, name) values(2,'b')", getTestTableName())) - fmt.Println("insert into table original table") - suite.Require().NoError(err) - - migrateErr := <-done - suite.Require().NoError(migrateErr) - - // Verify the new column was added - var delValue, OriginalValue int64 - err = suite.db.QueryRow( - fmt.Sprintf("select count(*) from %s._%s_del", testMysqlDatabase, testMysqlTableName), - ).Scan(&delValue) - suite.Require().NoError(err) - - err = suite.db.QueryRow("select count(*) from " + getTestTableName()).Scan(&OriginalValue) - suite.Require().NoError(err) - - suite.Require().LessOrEqual(delValue, OriginalValue) - - var 
tableName, createTableSQL string - //nolint:execinquery - err = suite.db.QueryRow("SHOW CREATE TABLE "+getTestTableName()).Scan(&tableName, &createTableSQL) - suite.Require().NoError(err) - - suite.Require().Equal(testMysqlTableName, tableName) - suite.Require().Equal("CREATE TABLE `testing` (\n `id` int NOT NULL,\n `name` varchar(64) DEFAULT NULL,\n `foobar` varchar(255) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci", createTableSQL) -} - -func (suite *MigratorTestSuite) TestRevertEmpty() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, s CHAR(32))", getTestTableName())) - suite.Require().NoError(err) - - var oldTableName string - - // perform original migration - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatement = "ADD COLUMN newcol CHAR(32)" - migrationContext.Checkpoint = true - migrationContext.CheckpointIntervalSeconds = 10 - migrationContext.DropServeSocket = true - migrationContext.InitiallyDropOldTable = true - migrationContext.UseGTIDs = true - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - oldTableName = migrationContext.GetOldTableName() - suite.Require().NoError(err) - } - - // revert the original migration - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.DropServeSocket = true - migrationContext.UseGTIDs = true - migrationContext.Revert = true - migrationContext.OkToDropTable = true 
- migrationContext.OldTableName = oldTableName - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Revert() - suite.Require().NoError(err) - } -} - -func (suite *MigratorTestSuite) TestRevert() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, s CHAR(32))", getTestTableName())) - suite.Require().NoError(err) - - numRows := 0 - for range 100 { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (id, s) VALUES (%d, MD5('%d'))", getTestTableName(), numRows, numRows)) - suite.Require().NoError(err) - numRows += 1 - } - - var oldTableName string - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - // perform original migration - { - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatement = "ADD INDEX idx1 (s)" - migrationContext.Checkpoint = true - migrationContext.CheckpointIntervalSeconds = 10 - migrationContext.DropServeSocket = true - migrationContext.InitiallyDropOldTable = true - migrationContext.UseGTIDs = true - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Migrate() - oldTableName = migrationContext.GetOldTableName() - suite.Require().NoError(err) - } - - // do some writes - for range 100 { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("INSERT INTO %s (id, s) VALUES (%d, MD5('%d'))", getTestTableName(), numRows, numRows)) - suite.Require().NoError(err) - numRows += 1 - } - for i := 0; i < numRows; i += 7 { - _, err = suite.db.ExecContext(ctx, - fmt.Sprintf("UPDATE %s SET s=MD5('%d') where id=%d", getTestTableName(), 2*i, i)) - suite.Require().NoError(err) - } - - // revert the original migration - { - migrationContext := newTestMigrationContext() - 
migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.DropServeSocket = true - migrationContext.UseGTIDs = true - migrationContext.Revert = true - migrationContext.OldTableName = oldTableName - - migrator := NewMigrator(migrationContext, "0.0.0") - - err = migrator.Revert() - oldTableName = migrationContext.GetOldTableName() - suite.Require().NoError(err) - } - - // checksum original and reverted table - var _tableName, checksum1, checksum2 string - rows, err := suite.db.Query(fmt.Sprintf("CHECKSUM TABLE %s, %s", testMysqlTableName, oldTableName)) - suite.Require().NoError(err) - defer rows.Close() - suite.Require().True(rows.Next()) - suite.Require().NoError(rows.Scan(&_tableName, &checksum1)) - suite.Require().True(rows.Next()) - suite.Require().NoError(rows.Scan(&_tableName, &checksum2)) - suite.Require().NoError(rows.Err()) - - suite.Require().Equal(checksum1, checksum2) -} - -func TestMigrator(t *testing.T) { - suite.Run(t, new(MigratorTestSuite)) -} - -func TestPanicAbort_PropagatesError(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Send an error to PanicAbort - testErr := errors.New("test abort error") - go func() { - migrationContext.PanicAbort <- testErr - }() - - // Wait a bit for error to be processed - time.Sleep(100 * time.Millisecond) - - // Verify error was stored - got := migrationContext.GetAbortError() - if got != testErr { //nolint:errorlint // Testing pointer equality for sentinel error - t.Errorf("Expected error %v, got %v", testErr, got) - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - context was cancelled - default: - t.Error("Expected context to be cancelled") - } -} - -func 
TestPanicAbort_FirstErrorWins(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Send first error - err1 := errors.New("first error") - go func() { - migrationContext.PanicAbort <- err1 - }() - - // Wait for first error to be processed - time.Sleep(50 * time.Millisecond) - - // Try to send second error (should be ignored) - err2 := errors.New("second error") - migrationContext.SetAbortError(err2) - - // Verify only first error is stored - got := migrationContext.GetAbortError() - if got != err1 { //nolint:errorlint // Testing pointer equality for sentinel error - t.Errorf("Expected first error %v, got %v", err1, got) - } -} - -func TestAbort_AfterRowCopy(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Give listenOnPanicAbort time to start - time.Sleep(20 * time.Millisecond) - - // Simulate row copy error by sending to rowCopyComplete in a goroutine - // (unbuffered channel, so send must be async) - testErr := errors.New("row copy failed") - go func() { - migrator.rowCopyComplete <- testErr - }() - - // Consume the error (simulating what Migrate() does) - // This is a blocking call that waits for the error - migrator.consumeRowCopyComplete() - - // Wait for the error to be processed by listenOnPanicAbort - time.Sleep(50 * time.Millisecond) - - // Check that error was stored - if got := migrationContext.GetAbortError(); got == nil { - t.Fatal("Expected abort error to be stored after row copy error") - } else if got.Error() != "row copy failed" { - t.Errorf("Expected 'row copy failed', got %v", got) - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - case <-time.After(1 * time.Second): - t.Error("Expected context to 
be cancelled after row copy error") - } -} - -func TestAbort_DuringInspection(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Simulate error during inspection phase - testErr := errors.New("inspection failed") - go func() { - time.Sleep(10 * time.Millisecond) - select { - case migrationContext.PanicAbort <- testErr: - case <-migrationContext.GetContext().Done(): - } - }() - - // Wait for abort to be processed - time.Sleep(50 * time.Millisecond) - - // Call checkAbort (simulating what Migrate() does after initiateInspector) - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error after abort during inspection") - } - - if err.Error() != "inspection failed" { - t.Errorf("Expected 'inspection failed', got %v", err) - } -} - -func TestAbort_DuringStreaming(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Simulate error from streaming goroutine - testErr := errors.New("streaming error") - go func() { - time.Sleep(10 * time.Millisecond) - // Use select pattern like actual code does - select { - case migrationContext.PanicAbort <- testErr: - case <-migrationContext.GetContext().Done(): - } - }() - - // Wait for abort to be processed - time.Sleep(50 * time.Millisecond) - - // Verify error stored and context cancelled - if got := migrationContext.GetAbortError(); got == nil { - t.Fatal("Expected abort error to be stored") - } else if got.Error() != "streaming error" { - t.Errorf("Expected 'streaming error', got %v", got) - } - - // Verify checkAbort catches it - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error after streaming abort") - } -} - -func TestRetryExhaustion_TriggersAbort(t *testing.T) { - 
migrationContext := base.NewMigrationContext() - migrationContext.SetDefaultNumRetries(2) // Only 2 retries - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Operation that always fails - callCount := 0 - operation := func() error { - callCount++ - return errors.New("persistent failure") - } - - // Call retryOperation (with notFatalHint=false so it sends to PanicAbort) - err := migrator.retryOperation(operation) - - // Should have called operation MaxRetries times - if callCount != 2 { - t.Errorf("Expected 2 retry attempts, got %d", callCount) - } - - // Should return the error - if err == nil { - t.Fatal("Expected retryOperation to return error") - } - - // Wait for abort to be processed - time.Sleep(100 * time.Millisecond) - - // Verify error was sent to PanicAbort and stored - if got := migrationContext.GetAbortError(); got == nil { - t.Error("Expected abort error to be stored after retry exhaustion") - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - default: - t.Error("Expected context to be cancelled after retry exhaustion") - } -} - -func TestRevert_AbortsOnError(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrationContext.Revert = true - migrationContext.OldTableName = "_test_del" - migrationContext.OriginalTableName = "test" - migrationContext.DatabaseName = "testdb" - migrator := NewMigrator(migrationContext, "1.0.0") - - // Start listenOnPanicAbort - go migrator.listenOnPanicAbort() - - // Simulate error during revert - testErr := errors.New("revert failed") - go func() { - time.Sleep(10 * time.Millisecond) - select { - case migrationContext.PanicAbort <- testErr: - case <-migrationContext.GetContext().Done(): - } - }() - - // Wait for abort to be processed - time.Sleep(50 * time.Millisecond) - - // Verify checkAbort catches it - err := migrator.checkAbort() - if err == nil { - 
t.Fatal("Expected checkAbort to return error during revert") - } - - if err.Error() != "revert failed" { - t.Errorf("Expected 'revert failed', got %v", err) - } - - // Verify context was cancelled - ctx := migrationContext.GetContext() - select { - case <-ctx.Done(): - // Success - default: - t.Error("Expected context to be cancelled during revert abort") - } -} - -func TestCheckAbort_ReturnsNilWhenNoError(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // No error has occurred - err := migrator.checkAbort() - if err != nil { - t.Errorf("Expected no error, got %v", err) - } -} - -func TestCheckAbort_DetectsContextCancellation(t *testing.T) { - migrationContext := base.NewMigrationContext() - migrator := NewMigrator(migrationContext, "1.0.0") - - // Cancel context directly (without going through PanicAbort) - migrationContext.CancelContext() - - // checkAbort should detect the cancellation - err := migrator.checkAbort() - if err == nil { - t.Fatal("Expected checkAbort to return error when context is cancelled") - } -} - -func (suite *MigratorTestSuite) TestPanicOnWarningsDuplicateDuringCutoverWithHighRetries() { - ctx := context.Background() - - // Create table with email column (no unique constraint initially) - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY AUTO_INCREMENT, email VARCHAR(100))", getTestTableName())) - suite.Require().NoError(err) - - // Insert initial rows with unique email values - passes pre-flight validation - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user1@example.com')", getTestTableName())) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user2@example.com')", getTestTableName())) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (email) VALUES ('user3@example.com')", 
getTestTableName())) - suite.Require().NoError(err) - - // Verify we have 3 rows - var count int - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count) - suite.Require().NoError(err) - suite.Require().Equal(3, count) - - // Create postpone flag file - tmpDir, err := os.MkdirTemp("", "gh-ost-postpone-test") - suite.Require().NoError(err) - defer os.RemoveAll(tmpDir) - postponeFlagFile := filepath.Join(tmpDir, "postpone.flag") - err = os.WriteFile(postponeFlagFile, []byte{}, 0644) - suite.Require().NoError(err) - - // Start migration in goroutine - done := make(chan error, 1) - go func() { - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - if err != nil { - done <- err - return - } - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - migrationContext.AlterStatementOptions = "ADD UNIQUE KEY unique_email_idx (email)" - migrationContext.HeartbeatIntervalMilliseconds = 100 - migrationContext.PostponeCutOverFlagFile = postponeFlagFile - migrationContext.PanicOnWarnings = true - - // High retry count + exponential backoff means retries will take a long time and fail the test if not properly aborted - migrationContext.SetDefaultNumRetries(30) - migrationContext.CutOverExponentialBackoff = true - migrationContext.SetExponentialBackoffMaxInterval(128) - - migrator := NewMigrator(migrationContext, "0.0.0") - - //nolint:contextcheck - done <- migrator.Migrate() - }() - - // Wait for migration to reach postponed state - // TODO replace this with an actual check for postponed state - time.Sleep(3 * time.Second) - - // Now insert a duplicate email value while migration is postponed - // This simulates data arriving during migration that would violate the unique constraint - _, err = suite.db.ExecContext(ctx, 
fmt.Sprintf("INSERT INTO %s (email) VALUES ('user1@example.com')", getTestTableName())) - suite.Require().NoError(err) - - // Verify we now have 4 rows (including the duplicate) - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count) - suite.Require().NoError(err) - suite.Require().Equal(4, count) - - // Unpostpone the migration - gh-ost will now try to apply binlog events with the duplicate - err = os.Remove(postponeFlagFile) - suite.Require().NoError(err) - - // Wait for Migrate() to return - with timeout to detect if it hangs - select { - case migrateErr := <-done: - // Success - Migrate() returned - // It should return an error due to the duplicate - suite.Require().Error(migrateErr, "Expected migration to fail due to duplicate key violation") - suite.Require().Contains(migrateErr.Error(), "Duplicate entry", "Error should mention duplicate entry") - case <-time.After(5 * time.Minute): - suite.FailNow("Migrate() hung and did not return within 5 minutes - failure to abort on warnings in retry loop") - } - - // Verify all 4 rows are still in the original table (no silent data loss) - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s", getTestTableName())).Scan(&count) - suite.Require().NoError(err) - suite.Require().Equal(4, count, "Original table should still have all 4 rows") - - // Verify both user1@example.com entries still exist - var duplicateCount int - err = suite.db.QueryRowContext(ctx, fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE email = 'user1@example.com'", getTestTableName())).Scan(&duplicateCount) - suite.Require().NoError(err) - suite.Require().Equal(2, duplicateCount, "Should have 2 duplicate email entries") -} diff --git a/go/logic/server.go b/go/logic/server.go index 74097acb7..c819c2206 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -32,6 +32,7 @@ var ( ) type printStatusFunc func(PrintStatusRule, io.Writer) +type printWorkersFunc func(io.Writer) // Server 
listens for requests on a socket file or via TCP type Server struct { @@ -40,14 +41,16 @@ type Server struct { tcpListener net.Listener hooksExecutor *HooksExecutor printStatus printStatusFunc + printWorkers printWorkersFunc isCPUProfiling int64 } -func NewServer(migrationContext *base.MigrationContext, hooksExecutor *HooksExecutor, printStatus printStatusFunc) *Server { +func NewServer(migrationContext *base.MigrationContext, hooksExecutor *HooksExecutor, printStatus printStatusFunc, printWorkers printWorkersFunc) *Server { return &Server{ migrationContext: migrationContext, hooksExecutor: hooksExecutor, printStatus: printStatus, + printWorkers: printWorkers, } } @@ -243,6 +246,9 @@ help # This message return ForcePrintStatusOnlyRule, nil case "info", "status": return ForcePrintStatusAndHintRule, nil + case "worker-stats": + this.printWorkers(writer) + return NoPrintStatusRule, nil case "cpu-profile": cpuProfile, err := this.runCPUProfile(arg) if err == nil { diff --git a/go/logic/streamer.go b/go/logic/streamer.go index 1c2635138..86fe3d985 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -1,244 +1,10 @@ /* Copyright 2022 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE + See https://github.com/github/gh-ost/blob/master/LICENSE */ package logic -import ( - gosql "database/sql" - "fmt" - "strings" - "sync" - "time" - - "github.com/github/gh-ost/go/base" - "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" - - gomysql "github.com/go-mysql-org/go-mysql/mysql" - "github.com/openark/golib/sqlutils" -) - -type BinlogEventListener struct { - async bool - databaseName string - tableName string - onDmlEvent func(event *binlog.BinlogEntry) error -} - -const ( - EventsChannelBufferSize = 1 - ReconnectStreamerSleepSeconds = 1 -) - -// EventsStreamer reads data from binary logs and streams it on. It acts as a publisher, -// and interested parties may subscribe for per-table events. 
-type EventsStreamer struct { - connectionConfig *mysql.ConnectionConfig - db *gosql.DB - dbVersion string - migrationContext *base.MigrationContext - initialBinlogCoordinates mysql.BinlogCoordinates - listeners [](*BinlogEventListener) - listenersMutex *sync.Mutex - eventsChannel chan *binlog.BinlogEntry - binlogReader *binlog.GoMySQLReader - name string -} - -func NewEventsStreamer(migrationContext *base.MigrationContext) *EventsStreamer { - return &EventsStreamer{ - connectionConfig: migrationContext.InspectorConnectionConfig, - migrationContext: migrationContext, - listeners: [](*BinlogEventListener){}, - listenersMutex: &sync.Mutex{}, - eventsChannel: make(chan *binlog.BinlogEntry, EventsChannelBufferSize), - name: "streamer", - initialBinlogCoordinates: migrationContext.InitialStreamerCoords, - } -} - -// AddListener registers a new listener for binlog events, on a per-table basis -func (this *EventsStreamer) AddListener( - async bool, databaseName string, tableName string, onDmlEvent func(event *binlog.BinlogEntry) error) (err error) { - this.listenersMutex.Lock() - defer this.listenersMutex.Unlock() - - if databaseName == "" { - return fmt.Errorf("Empty database name in AddListener") - } - if tableName == "" { - return fmt.Errorf("Empty table name in AddListener") - } - listener := &BinlogEventListener{ - async: async, - databaseName: databaseName, - tableName: tableName, - onDmlEvent: onDmlEvent, - } - this.listeners = append(this.listeners, listener) - return nil -} - -// notifyListeners will notify relevant listeners with given DML event. Only -// listeners registered for changes on the table on which the DML operates are notified. 
-func (this *EventsStreamer) notifyListeners(binlogEntry *binlog.BinlogEntry) { - this.listenersMutex.Lock() - defer this.listenersMutex.Unlock() - - for _, listener := range this.listeners { - listener := listener - if !strings.EqualFold(listener.databaseName, binlogEntry.DmlEvent.DatabaseName) { - continue - } - if !strings.EqualFold(listener.tableName, binlogEntry.DmlEvent.TableName) { - continue - } - if listener.async { - go func() { - listener.onDmlEvent(binlogEntry) - }() - } else { - listener.onDmlEvent(binlogEntry) - } - } -} - -func (this *EventsStreamer) InitDBConnections() (err error) { - EventsStreamerUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, EventsStreamerUri); err != nil { - return err - } - version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name) - if err != nil { - return err - } - this.dbVersion = version - if this.initialBinlogCoordinates == nil || this.initialBinlogCoordinates.IsEmpty() { - if err := this.readCurrentBinlogCoordinates(); err != nil { - return err - } - } - if err := this.initBinlogReader(this.initialBinlogCoordinates); err != nil { - return err - } - - return nil -} - -// initBinlogReader creates and connects the reader: we hook up to a MySQL server as a replica -func (this *EventsStreamer) initBinlogReader(binlogCoordinates mysql.BinlogCoordinates) error { - goMySQLReader := binlog.NewGoMySQLReader(this.migrationContext) - if err := goMySQLReader.ConnectBinlogStreamer(binlogCoordinates); err != nil { - return err - } - this.binlogReader = goMySQLReader - return nil -} - -func (this *EventsStreamer) GetCurrentBinlogCoordinates() mysql.BinlogCoordinates { - return this.binlogReader.GetCurrentBinlogCoordinates() -} - -// readCurrentBinlogCoordinates reads master status from hooked server -func (this *EventsStreamer) readCurrentBinlogCoordinates() error { - binaryLogStatusTerm := 
mysql.ReplicaTermFor(this.dbVersion, "master status") - query := fmt.Sprintf("show /* gh-ost readCurrentBinlogCoordinates */ %s", binaryLogStatusTerm) - foundMasterStatus := false - err := sqlutils.QueryRowsMap(this.db, query, func(m sqlutils.RowMap) error { - if this.migrationContext.UseGTIDs { - execGtidSet := m.GetString("Executed_Gtid_Set") - gtidSet, err := gomysql.ParseMysqlGTIDSet(execGtidSet) - if err != nil { - return err - } - this.initialBinlogCoordinates = &mysql.GTIDBinlogCoordinates{GTIDSet: gtidSet.(*gomysql.MysqlGTIDSet)} - } else { - this.initialBinlogCoordinates = &mysql.FileBinlogCoordinates{ - LogFile: m.GetString("File"), - LogPos: m.GetInt64("Position"), - } - } - foundMasterStatus = true - return nil - }) - if err != nil { - return err - } - if !foundMasterStatus { - return fmt.Errorf("Got no results from SHOW %s. Bailing out", strings.ToUpper(binaryLogStatusTerm)) - } - this.migrationContext.Log.Debugf("Streamer binlog coordinates: %+v", this.initialBinlogCoordinates) - return nil -} - -// StreamEvents will begin streaming events. It will be blocking, so should be -// executed by a goroutine -func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { - go func() { - for binlogEntry := range this.eventsChannel { - if binlogEntry.DmlEvent != nil { - this.notifyListeners(binlogEntry) - } - } - }() - // The next should block and execute forever, unless there's a serious error. - var successiveFailures int - var reconnectCoords mysql.BinlogCoordinates - ctx := this.migrationContext.GetContext() - for { - // Check for context cancellation each iteration - if err := ctx.Err(); err != nil { - return err - } - if canStopStreaming() { - return nil - } - // We will reconnect the binlog streamer at the coordinates - // of the last trx that was read completely from the streamer. - // Since row event application is idempotent, it's OK if we reapply some events. 
- if err := this.binlogReader.StreamEvents(canStopStreaming, this.eventsChannel); err != nil { - if canStopStreaming() { - return nil - } - - this.migrationContext.Log.Infof("StreamEvents encountered unexpected error: %+v", err) - this.migrationContext.MarkPointOfInterest() - time.Sleep(ReconnectStreamerSleepSeconds * time.Second) - - // See if there's retry overflow - if this.migrationContext.BinlogSyncerMaxReconnectAttempts > 0 && successiveFailures >= this.migrationContext.BinlogSyncerMaxReconnectAttempts { - return fmt.Errorf("%d successive failures in streamer reconnect at coordinates %+v", successiveFailures, reconnectCoords) - } - - // Reposition at same coordinates - if this.binlogReader.LastTrxCoords != nil { - reconnectCoords = this.binlogReader.LastTrxCoords.Clone() - } else { - reconnectCoords = this.initialBinlogCoordinates.Clone() - } - if !reconnectCoords.SmallerThan(this.GetCurrentBinlogCoordinates()) { - successiveFailures += 1 - } else { - successiveFailures = 0 - } - - this.migrationContext.Log.Infof("Reconnecting EventsStreamer... Will resume at %+v", reconnectCoords) - _ = this.binlogReader.Close() - if err := this.initBinlogReader(reconnectCoords); err != nil { - return err - } - } - } -} - -func (this *EventsStreamer) Close() (err error) { - err = this.binlogReader.Close() - this.migrationContext.Log.Infof("Closed streamer connection. err=%+v", err) - return err -} - -func (this *EventsStreamer) Teardown() { - this.db.Close() -} +// EventsStreamer has been replaced by the transaction-aware streaming +// in gomysql_reader.go (StreamTransactions / handleTransactionEvent). +// The coordinator now handles event dispatching and parallel application. 
diff --git a/go/logic/streamer_test.go b/go/logic/streamer_test.go index 8e0b57f80..c48054717 100644 --- a/go/logic/streamer_test.go +++ b/go/logic/streamer_test.go @@ -1,264 +1,4 @@ package logic -import ( - "context" - "database/sql" - gosql "database/sql" - "fmt" - "testing" - "time" - - "github.com/github/gh-ost/go/binlog" - "github.com/stretchr/testify/suite" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/modules/mysql" - - "golang.org/x/sync/errgroup" -) - -type EventsStreamerTestSuite struct { - suite.Suite - - mysqlContainer testcontainers.Container - db *gosql.DB -} - -func (suite *EventsStreamerTestSuite) SetupSuite() { - ctx := context.Background() - mysqlContainer, err := mysql.Run(ctx, - testMysqlContainerImage, - mysql.WithDatabase(testMysqlDatabase), - mysql.WithUsername(testMysqlUser), - mysql.WithPassword(testMysqlPass), - ) - suite.Require().NoError(err) - - suite.mysqlContainer = mysqlContainer - dsn, err := mysqlContainer.ConnectionString(ctx) - suite.Require().NoError(err) - - db, err := gosql.Open("mysql", dsn) - suite.Require().NoError(err) - - suite.db = db -} - -func (suite *EventsStreamerTestSuite) TeardownSuite() { - suite.Assert().NoError(suite.db.Close()) - suite.Assert().NoError(testcontainers.TerminateContainer(suite.mysqlContainer)) -} - -func (suite *EventsStreamerTestSuite) SetupTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "CREATE DATABASE IF NOT EXISTS "+testMysqlDatabase) - suite.Require().NoError(err) -} - -func (suite *EventsStreamerTestSuite) TearDownTest() { - ctx := context.Background() - - _, err := suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestTableName()) - suite.Require().NoError(err) - _, err = suite.db.ExecContext(ctx, "DROP TABLE IF EXISTS "+getTestGhostTableName()) - suite.Require().NoError(err) -} - -func (suite *EventsStreamerTestSuite) TestStreamEvents() { - ctx := context.Background() - - _, err := 
suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(255))", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - streamer := NewEventsStreamer(migrationContext) - - err = streamer.InitDBConnections() - suite.Require().NoError(err) - defer streamer.Close() - defer streamer.Teardown() - - streamCtx, cancel := context.WithCancel(context.Background()) - - dmlEvents := make([]*binlog.BinlogDMLEvent, 0) - err = streamer.AddListener(false, testMysqlDatabase, testMysqlTableName, func(event *binlog.BinlogEntry) error { - dmlEvents = append(dmlEvents, event.DmlEvent) - - // Stop once we've collected three events - if len(dmlEvents) == 3 { - cancel() - } - - return nil - }) - suite.Require().NoError(err) - - group := errgroup.Group{} - group.Go(func() error { - return streamer.StreamEvents(func() bool { - return streamCtx.Err() != nil - }) - }) - - group.Go(func() error { - var err error - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, 'foo')", getTestTableName())) - if err != nil { - return err - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (2, 'bar')", getTestTableName())) - if err != nil { - return err - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (3, 'baz')", getTestTableName())) - if err != nil { - return err - } - - // Bug: Need to write fourth event to hit the canStopStreaming function again - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (4, 'qux')", getTestTableName())) - if err != nil { - return err - } - - return nil - }) - - err = 
group.Wait() - suite.Require().NoError(err) - - suite.Require().Len(dmlEvents, 3) -} - -func (suite *EventsStreamerTestSuite) TestStreamEventsAutomaticallyReconnects() { - ctx := context.Background() - _, err := suite.db.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name VARCHAR(255))", getTestTableName())) - suite.Require().NoError(err) - - connectionConfig, err := getTestConnectionConfig(ctx, suite.mysqlContainer) - suite.Require().NoError(err) - - migrationContext := newTestMigrationContext() - migrationContext.ApplierConnectionConfig = connectionConfig - migrationContext.InspectorConnectionConfig = connectionConfig - migrationContext.SetConnectionConfig("innodb") - - streamer := NewEventsStreamer(migrationContext) - - err = streamer.InitDBConnections() - suite.Require().NoError(err) - defer streamer.Close() - defer streamer.Teardown() - - streamCtx, cancel := context.WithCancel(context.Background()) - - dmlEvents := make([]*binlog.BinlogDMLEvent, 0) - err = streamer.AddListener(false, testMysqlDatabase, testMysqlTableName, func(event *binlog.BinlogEntry) error { - dmlEvents = append(dmlEvents, event.DmlEvent) - - // Stop once we've collected three events - if len(dmlEvents) == 3 { - cancel() - } - - return nil - }) - suite.Require().NoError(err) - - group := errgroup.Group{} - group.Go(func() error { - return streamer.StreamEvents(func() bool { - return streamCtx.Err() != nil - }) - }) - - group.Go(func() error { - var err error - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, 'foo')", getTestTableName())) - if err != nil { - return err - } - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (2, 'bar')", getTestTableName())) - if err != nil { - return err - } - - var currentConnectionId int - err = suite.db.QueryRowContext(ctx, "SELECT CONNECTION_ID()").Scan(&currentConnectionId) - if err != nil { - return err - } - - //nolint:execinquery - rows, err := suite.db.Query("SHOW
FULL PROCESSLIST") - if err != nil { - return err - } - defer rows.Close() - - connectionIdsToKill := make([]int, 0) - - var id, stateTime int - var user, host, dbName, command, state, info sql.NullString - for rows.Next() { - err = rows.Scan(&id, &user, &host, &dbName, &command, &stateTime, &state, &info) - if err != nil { - return err - } - - fmt.Printf("id: %d, user: %s, host: %s, dbName: %s, command: %s, time: %d, state: %s, info: %s\n", id, user.String, host.String, dbName.String, command.String, stateTime, state.String, info.String) - - if id != currentConnectionId && user.String == testMysqlUser { - connectionIdsToKill = append(connectionIdsToKill, id) - } - } - - if err := rows.Err(); err != nil { - return err - } - - for _, connectionIdToKill := range connectionIdsToKill { - _, err = suite.db.ExecContext(ctx, "KILL ?", connectionIdToKill) - if err != nil { - return err - } - } - - // Bug: We need to wait here for the streamer to reconnect - time.Sleep(time.Second * 2) - - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (3, 'baz')", getTestTableName())) - if err != nil { - return err - } - - // Bug: Need to write fourth event to hit the canStopStreaming function again - _, err = suite.db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (id, name) VALUES (4, 'qux')", getTestTableName())) - if err != nil { - return err - } - - return nil - }) - - err = group.Wait() - suite.Require().NoError(err) - - suite.Require().Len(dmlEvents, 3) -} - -func TestEventsStreamer(t *testing.T) { - suite.Run(t, new(EventsStreamerTestSuite)) -} +// Legacy EventsStreamer tests removed. +// See coordinator_test.go for the replacement test suite. 
diff --git a/go/logic/test_helpers_test.go b/go/logic/test_helpers_test.go new file mode 100644 index 000000000..3b067fcfb --- /dev/null +++ b/go/logic/test_helpers_test.go @@ -0,0 +1,38 @@ +package logic + +import ( + "context" + "fmt" + + "github.com/github/gh-ost/go/mysql" + "github.com/testcontainers/testcontainers-go" +) + +func GetDSN(ctx context.Context, container testcontainers.Container) (string, error) { + host, err := container.Host(ctx) + if err != nil { + return "", err + } + port, err := container.MappedPort(ctx, "3306/tcp") + if err != nil { + return "", err + } + return fmt.Sprintf("root:root-password@tcp(%s:%s)/", host, port.Port()), nil +} + +func GetConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) { + host, err := container.Host(ctx) + if err != nil { + return nil, err + } + port, err := container.MappedPort(ctx, "3306/tcp") + if err != nil { + return nil, err + } + config := mysql.NewConnectionConfig() + config.Key.Hostname = host + config.Key.Port = port.Int() + config.User = "root" + config.Password = "root-password" + return config, nil +} diff --git a/go/logic/test_utils.go b/go/logic/test_utils.go index f552cfc76..d532e0920 100644 --- a/go/logic/test_utils.go +++ b/go/logic/test_utils.go @@ -28,14 +28,6 @@ func getTestGhostTableName() string { return fmt.Sprintf("`%s`.`_%s_gho`", testMysqlDatabase, testMysqlTableName) } -func getTestRevertedTableName() string { - return fmt.Sprintf("`%s`.`_%s_rev_del`", testMysqlDatabase, testMysqlTableName) -} - -func getTestOldTableName() string { - return fmt.Sprintf("`%s`.`_%s_del`", testMysqlDatabase, testMysqlTableName) -} - func getTestConnectionConfig(ctx context.Context, container testcontainers.Container) (*mysql.ConnectionConfig, error) { host, err := container.Host(ctx) if err != nil { diff --git a/localtests/sysbench/generate_load b/localtests/sysbench/generate_load new file mode 100755 index 000000000..f1f641af1 --- /dev/null +++ 
b/localtests/sysbench/generate_load @@ -0,0 +1 @@ +#!/usr/bin/env bash diff --git a/localtests/test.sh b/localtests/test.sh index d918d473b..575cf6a72 100755 --- a/localtests/test.sh +++ b/localtests/test.sh @@ -16,6 +16,7 @@ toxiproxy=false gtid=false storage_engine=innodb exec_command_file=/tmp/gh-ost-test.bash +generate_load_file=/tmp/gh-ost-generate-load.bash ghost_structure_output_file=/tmp/gh-ost-test.ghost.structure.sql orig_content_output_file=/tmp/gh-ost-test.orig.content.csv ghost_content_output_file=/tmp/gh-ost-test.ghost.content.csv diff --git a/script/test b/script/test index 5c32b370c..3f66288d1 100755 --- a/script/test +++ b/script/test @@ -5,7 +5,7 @@ set -e . script/bootstrap echo "Verifying code is formatted via 'gofmt -s -w go/'" -gofmt -s -w go/ +gofmt -s -w go/ git diff --exit-code --quiet echo "Building" @@ -14,4 +14,4 @@ script/build cd .gopath/src/github.com/github/gh-ost echo "Running unit tests" -go test -v -covermode=atomic ./go/... +go test -v -p 1 -covermode=atomic -race ./go/...