diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index 5363692f7..e2c6cd85e 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -216,7 +216,12 @@ type DefaultBackendStorage struct { // randomly pick a key from a map (in this case, the backends) in // Golang. agentIDs []string - random *rand.Rand + // nonDrainingIDs tracks identifiers whose primary backend is not draining. + // Used to avoid scanning all identifiers on every selection. + nonDrainingIDs []string + // nonDrainingIndex tracks index positions in nonDrainingIDs for O(1) removals. + nonDrainingIndex map[string]int + random *rand.Rand // idTypes contains the valid identifier types for this // DefaultBackendStorage. The DefaultBackendStorage may only tolerate certain // types of identifiers when associating to a specific BackendManager, @@ -244,10 +249,11 @@ func NewDefaultBackendStorage(idTypes []header.IdentifierType, proxyStrategy pro metrics.Metrics.SetTotalBackendCount(proxyStrategy, 0) return &DefaultBackendStorage{ - backends: make(map[string][]*Backend), - random: rand.New(rand.NewSource(time.Now().UnixNano())), /* #nosec G404 */ - idTypes: idTypes, - proxyStrategy: proxyStrategy, + backends: make(map[string][]*Backend), + nonDrainingIndex: make(map[string]int), + random: rand.New(rand.NewSource(time.Now().UnixNano())), /* #nosec G404 */ + idTypes: idTypes, + proxyStrategy: proxyStrategy, } } @@ -255,15 +261,57 @@ func containIDType(idTypes []header.IdentifierType, idType header.IdentifierType return slices.Contains(idTypes, idType) } +func (s *DefaultBackendStorage) addNonDrainingIdentifierLocked(identifier string) { + if _, ok := s.nonDrainingIndex[identifier]; ok { + return + } + s.nonDrainingIDs = append(s.nonDrainingIDs, identifier) + s.nonDrainingIndex[identifier] = len(s.nonDrainingIDs) - 1 +} + +func (s *DefaultBackendStorage) removeNonDrainingIdentifierLocked(identifier string) { + idx, ok := s.nonDrainingIndex[identifier] + if !ok { + return + } + lastIdx := len(s.nonDrainingIDs) - 1 + if idx != lastIdx { + lastIdentifier := s.nonDrainingIDs[lastIdx] + s.nonDrainingIDs[idx] = lastIdentifier + s.nonDrainingIndex[lastIdentifier] = idx + } + s.nonDrainingIDs = s.nonDrainingIDs[:lastIdx] + delete(s.nonDrainingIndex, identifier) +} + +func (s *DefaultBackendStorage) refreshNonDrainingIdentifierLocked(identifier string) { + backends, ok := s.backends[identifier] + if !ok || len(backends) == 0 { + s.removeNonDrainingIdentifierLocked(identifier) + return + } + + // For a given identifier, routing always uses the first backend. + if backends[0].IsDraining() { + s.removeNonDrainingIdentifierLocked(identifier) + return + } + s.addNonDrainingIdentifierLocked(identifier) +} + // addBackend adds a backend. func (s *DefaultBackendStorage) addBackend(identifier string, idType header.IdentifierType, backend *Backend) { + s.mu.Lock() + defer s.mu.Unlock() + s.addBackendLocked(identifier, idType, backend) +} + +func (s *DefaultBackendStorage) addBackendLocked(identifier string, idType header.IdentifierType, backend *Backend) { if !containIDType(s.idTypes, idType) { klog.V(3).InfoS("fail to add backend", "backend", identifier, "error", &ErrWrongIDType{idType, s.idTypes}) return } klog.V(2).InfoS("Register backend for agent", "agentID", identifier) - s.mu.Lock() - defer s.mu.Unlock() _, ok := s.backends[identifier] if ok { for _, b := range s.backends[identifier] { @@ -276,6 +324,7 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden return } s.backends[identifier] = []*Backend{backend} + s.refreshNonDrainingIdentifierLocked(identifier) metrics.Metrics.SetBackendCountDeprecated(len(s.backends)) metrics.Metrics.SetTotalBackendCount(s.proxyStrategy, len(s.backends)) s.agentIDs = append(s.agentIDs, identifier) @@ -283,30 +332,39 @@ func (s *DefaultBackendStorage) addBackend(identifier string, idType header.Iden // removeBackend removes a backend. func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.IdentifierType, backend *Backend) { + s.mu.Lock() + defer s.mu.Unlock() + s.removeBackendLocked(identifier, idType, backend) +} + +func (s *DefaultBackendStorage) removeBackendLocked(identifier string, idType header.IdentifierType, backend *Backend) { if !containIDType(s.idTypes, idType) { klog.ErrorS(&ErrWrongIDType{idType, s.idTypes}, "fail to remove backend") return } klog.V(2).InfoS("Remove connection for agent", "agentID", identifier) - s.mu.Lock() - defer s.mu.Unlock() backends, ok := s.backends[identifier] if !ok { klog.V(1).InfoS("Cannot find agent in backends", "identifier", identifier) return } var found bool + var removedFirst bool for i, b := range backends { if b == backend { s.backends[identifier] = append(s.backends[identifier][:i], s.backends[identifier][i+1:]...) if i == 0 && len(s.backends[identifier]) != 0 { klog.V(1).InfoS("This should not happen. Removed connection that is not the first connection", "agentID", identifier) } + if i == 0 { + removedFirst = true + } found = true } } if len(s.backends[identifier]) == 0 { delete(s.backends, identifier) + s.removeNonDrainingIdentifierLocked(identifier) for i := range s.agentIDs { if s.agentIDs[i] == identifier { s.agentIDs[i] = s.agentIDs[len(s.agentIDs)-1] @@ -314,6 +372,8 @@ func (s *DefaultBackendStorage) removeBackend(identifier string, idType header.I break } } + } else if removedFirst { + s.refreshNonDrainingIdentifierLocked(identifier) } if !found { klog.V(1).InfoS("Could not find connection matching identifier to remove", "agentID", identifier, "idType", idType) @@ -361,35 +421,30 @@ func (s *DefaultBackendStorage) GetRandomBackend() (*Backend, error) { return nil, &ErrNotFound{} } - var firstDrainingBackend *Backend - - // Start at a random agent and check each agent in sequence - startIdx := s.random.Intn(len(s.agentIDs)) - for i := 0; i < len(s.agentIDs); i++ { - // Wrap around using modulo - currentIdx := (startIdx + i) % len(s.agentIDs) - agentID := s.agentIDs[currentIdx] - // always return the first connection to an agent, because the agent - // will close later connections if there are multiple. - backend := s.backends[agentID][0] - - if !backend.IsDraining() { - klog.V(3).InfoS("Pick agent as backend", "agentID", agentID) - return backend, nil + // Prefer random selection from known non-draining identifiers. + // This avoids a full scan on every selection. + for len(s.nonDrainingIDs) > 0 { + idx := s.random.Intn(len(s.nonDrainingIDs)) + identifier := s.nonDrainingIDs[idx] + backends, ok := s.backends[identifier] + if !ok || len(backends) == 0 { + s.removeNonDrainingIdentifierLocked(identifier) + continue } - - // Keep track of first draining backend as fallback - if firstDrainingBackend == nil { - firstDrainingBackend = backend + backend := backends[0] + // A backend may have transitioned to draining since the pool was updated. + // Remove stale entries lazily and retry. + if backend.IsDraining() { + s.removeNonDrainingIdentifierLocked(identifier) + continue } + klog.V(3).InfoS("Pick agent as backend", "agentID", identifier) + return backend, nil } // All agents are draining, use one as fallback - if firstDrainingBackend != nil { - agentID := firstDrainingBackend.id - klog.V(3).InfoS("No non-draining backends available, using draining backend as fallback", "agentID", agentID) - return firstDrainingBackend, nil - } - - return nil, &ErrNotFound{} + agentID := s.agentIDs[s.random.Intn(len(s.agentIDs))] + backend := s.backends[agentID][0] + klog.V(3).InfoS("No non-draining backends available, using draining backend as fallback", "agentID", agentID) + return backend, nil } diff --git a/pkg/server/backend_manager_test.go b/pkg/server/backend_manager_test.go index fe7f97839..c556d311e 100644 --- a/pkg/server/backend_manager_test.go +++ b/pkg/server/backend_manager_test.go @@ -18,6 +18,7 @@ package server import ( "context" + "math/rand" "reflect" "testing" @@ -41,6 +42,17 @@ func mockAgentConn(ctrl *gomock.Controller, agentID string, agentIdentifiers []s return agentConn } +func expectedBackendIndex(backends map[string][]*Backend) map[string]map[*Backend]int { + index := make(map[string]map[*Backend]int, len(backends)) + for identifier, bes := range backends { + index[identifier] = make(map[*Backend]int, len(bes)) + for i, backend := range bes { + index[identifier][backend] = i + } + } + return index +} + func TestNewBackend(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -384,6 +396,74 @@ func TestDestHostBackendManager_WithDuplicateIdents(t *testing.T) { } } +func TestDestHostBackendManager_NonDrainingCacheTracksCurrentBackends(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost"})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"host=localhost"})) + backend2.SetDraining() + backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{"host=localhost"})) + + p := NewDestHostBackendManager() + + p.AddBackend(backend1) + p.AddBackend(backend2) + + expectedBackends := map[string][]*Backend{ + "localhost": {backend1, backend2}, + } + expectedNonDraining := map[string][]*Backend{ + "localhost": {backend1}, + } + + if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := expectedNonDraining, p.nonDrainingBackends; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := expectedBackendIndex(expectedNonDraining), p.nonDrainingIndex; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + + p.RemoveBackend(backend1) + + expectedBackends = map[string][]*Backend{ + "localhost": {backend2}, + } + expectedNonDraining = map[string][]*Backend{} + + if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := expectedNonDraining, p.nonDrainingBackends; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := expectedBackendIndex(expectedNonDraining), p.nonDrainingIndex; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + + p.AddBackend(backend3) + + expectedBackends = map[string][]*Backend{ + "localhost": {backend2, backend3}, + } + expectedNonDraining = map[string][]*Backend{ + "localhost": {backend3}, + } + + if e, a := expectedBackends, p.backends; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := expectedNonDraining, p.nonDrainingBackends; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := expectedBackendIndex(expectedNonDraining), p.nonDrainingIndex; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } +} + func TestDefaultBackendManager_GetRandomBackend_DrainingFallback(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -444,6 +524,44 @@ func TestDefaultBackendManager_GetRandomBackend_DrainingFallback(t *testing.T) { } } +func TestDefaultBackendManager_GetRandomBackend_LazilyEvictsDrainingCandidate(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) + + p := NewDefaultBackendManager() + // Seed so the first selection checks the stale entry for agent1. + p.random = rand.New(rand.NewSource(0)) + + p.AddBackend(backend1) + p.AddBackend(backend2) + + if e, a := []string{"agent1", "agent2"}, p.nonDrainingIDs; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := map[string]int{"agent1": 0, "agent2": 1}, p.nonDrainingIndex; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + + backend1.SetDraining() + + b, err := p.Backend(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if b != backend2 { + t.Fatalf("expected backend2 after lazy eviction, got %v", b) + } + if e, a := []string{"agent2"}, p.nonDrainingIDs; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + if e, a := map[string]int{"agent2": 0}, p.nonDrainingIndex; !reflect.DeepEqual(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } +} + func TestDestHostBackendManager_Backend_DrainingFallback(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -512,3 +630,89 @@ func TestDestHostBackendManager_Backend_DrainingFallback(t *testing.T) { t.Error("expected non-draining backend for otherhost") } } + +func TestDefaultBackendManager_GetRandomBackend_DistributesAcrossNonDrainingCandidates(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{})) + backend3, _ := NewBackend(mockAgentConn(ctrl, "agent3", []string{})) + + p := NewDefaultBackendManager() + // Use a deterministic random source so this test is stable. + p.random = rand.New(rand.NewSource(7)) + + p.AddBackend(backend1) + p.AddBackend(backend2) + p.AddBackend(backend3) + backend1.SetDraining() + + const iterations = 6000 + counts := map[*Backend]int{ + backend2: 0, + backend3: 0, + } + + for i := 0; i < iterations; i++ { + b, err := p.Backend(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if b == backend1 { + t.Fatalf("expected non-draining backend, got draining backend1") + } + counts[b]++ + } + + diff := counts[backend2] - counts[backend3] + if diff < 0 { + diff = -diff + } + // With random selection from the non-draining set, distribution should be + // approximately even. Allow 5% skew to avoid flakiness. + if diff > iterations/20 { + t.Fatalf("expected near-even distribution across non-draining backends, got backend2=%d backend3=%d", counts[backend2], counts[backend3]) + } +} + +func TestDestHostBackendManager_Backend_DistributesAcrossNonDrainingCandidates(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + backend1, _ := NewBackend(mockAgentConn(ctrl, "agent1", []string{"host=localhost"})) + backend2, _ := NewBackend(mockAgentConn(ctrl, "agent2", []string{"host=localhost"})) + + p := NewDestHostBackendManager() + // Use a deterministic random source so this test is stable. + p.random = rand.New(rand.NewSource(11)) + + p.AddBackend(backend1) + p.AddBackend(backend2) + + ctx := context.WithValue(context.Background(), destHostKey, "localhost") + + const iterations = 6000 + counts := map[*Backend]int{ + backend1: 0, + backend2: 0, + } + + for i := 0; i < iterations; i++ { + b, err := p.Backend(ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + counts[b]++ + } + + diff := counts[backend1] - counts[backend2] + if diff < 0 { + diff = -diff + } + // With random selection from non-draining candidates, distribution should + // be approximately even. Allow 5% skew to avoid flakiness. + if diff > iterations/20 { + t.Fatalf("expected near-even distribution across non-draining backends, got backend1=%d backend2=%d", counts[backend1], counts[backend2]) + } +} diff --git a/pkg/server/desthost_backend_manager.go b/pkg/server/desthost_backend_manager.go index 5c041cc04..b0e88d5de 100644 --- a/pkg/server/desthost_backend_manager.go +++ b/pkg/server/desthost_backend_manager.go @@ -26,6 +26,11 @@ import ( type DestHostBackendManager struct { *DefaultBackendStorage + // nonDrainingBackends tracks non-draining backends per destination host. + // This avoids scanning all candidate backends on each request. + nonDrainingBackends map[string][]*Backend + // nonDrainingIndex tracks backend index positions for O(1) removals. + nonDrainingIndex map[string]map[*Backend]int } var _ BackendManager = &DestHostBackendManager{} @@ -33,71 +38,127 @@ var _ BackendManager = &DestHostBackendManager{} func NewDestHostBackendManager() *DestHostBackendManager { return &DestHostBackendManager{ DefaultBackendStorage: NewDefaultBackendStorage( - []header.IdentifierType{header.IPv4, header.IPv6, header.Host}, proxystrategies.ProxyStrategyDestHost)} + []header.IdentifierType{header.IPv4, header.IPv6, header.Host}, proxystrategies.ProxyStrategyDestHost), + nonDrainingBackends: make(map[string][]*Backend), + nonDrainingIndex: make(map[string]map[*Backend]int), + } +} + +func (dibm *DestHostBackendManager) addNonDrainingBackendLocked(identifier string, backend *Backend) { + backendIndexByIdentifier, ok := dibm.nonDrainingIndex[identifier] + if !ok { + backendIndexByIdentifier = make(map[*Backend]int) + dibm.nonDrainingIndex[identifier] = backendIndexByIdentifier + } + if _, ok := backendIndexByIdentifier[backend]; ok { + return + } + dibm.nonDrainingBackends[identifier] = append(dibm.nonDrainingBackends[identifier], backend) + backendIndexByIdentifier[backend] = len(dibm.nonDrainingBackends[identifier]) - 1 +} + +func (dibm *DestHostBackendManager) removeNonDrainingBackendLocked(identifier string, backend *Backend) { + backendIndexByIdentifier, ok := dibm.nonDrainingIndex[identifier] + if !ok { + return + } + idx, ok := backendIndexByIdentifier[backend] + if !ok { + return + } + + lastIdx := len(dibm.nonDrainingBackends[identifier]) - 1 + if idx != lastIdx { + lastBackend := dibm.nonDrainingBackends[identifier][lastIdx] + dibm.nonDrainingBackends[identifier][idx] = lastBackend + backendIndexByIdentifier[lastBackend] = idx + } + dibm.nonDrainingBackends[identifier] = dibm.nonDrainingBackends[identifier][:lastIdx] + delete(backendIndexByIdentifier, backend) + + if len(dibm.nonDrainingBackends[identifier]) == 0 { + delete(dibm.nonDrainingBackends, identifier) + delete(dibm.nonDrainingIndex, identifier) + } +} + +func (dibm *DestHostBackendManager) addBackendForIdentifierLocked(identifier string, idType header.IdentifierType, backend *Backend, isDraining bool) { + dibm.addBackendLocked(identifier, idType, backend) + if !isDraining { + dibm.addNonDrainingBackendLocked(identifier, backend) + } +} + +func (dibm *DestHostBackendManager) removeBackendForIdentifierLocked(identifier string, idType header.IdentifierType, backend *Backend) { + dibm.removeBackendLocked(identifier, idType, backend) + dibm.removeNonDrainingBackendLocked(identifier, backend) } func (dibm *DestHostBackendManager) AddBackend(backend *Backend) { + isDraining := backend.IsDraining() agentIdentifiers := backend.GetAgentIdentifiers() + dibm.mu.Lock() + defer dibm.mu.Unlock() for _, ipv4 := range agentIdentifiers.IPv4 { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv4) - dibm.addBackend(ipv4, header.IPv4, backend) + dibm.addBackendForIdentifierLocked(ipv4, header.IPv4, backend, isDraining) } for _, ipv6 := range agentIdentifiers.IPv6 { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", ipv6) - dibm.addBackend(ipv6, header.IPv6, backend) + dibm.addBackendForIdentifierLocked(ipv6, header.IPv6, backend, isDraining) } for _, host := range agentIdentifiers.Host { klog.V(5).InfoS("Add the agent to DestHostBackendManager", "agent address", host) - dibm.addBackend(host, header.Host, backend) + dibm.addBackendForIdentifierLocked(host, header.Host, backend, isDraining) } } func (dibm *DestHostBackendManager) RemoveBackend(backend *Backend) { agentIdentifiers := backend.GetAgentIdentifiers() + dibm.mu.Lock() + defer dibm.mu.Unlock() for _, ipv4 := range agentIdentifiers.IPv4 { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv4) - dibm.removeBackend(ipv4, header.IPv4, backend) + dibm.removeBackendForIdentifierLocked(ipv4, header.IPv4, backend) } for _, ipv6 := range agentIdentifiers.IPv6 { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", ipv6) - dibm.removeBackend(ipv6, header.IPv6, backend) + dibm.removeBackendForIdentifierLocked(ipv6, header.IPv6, backend) } for _, host := range agentIdentifiers.Host { klog.V(5).InfoS("Remove the agent from the DestHostBackendManager", "agentHost", host) - dibm.removeBackend(host, header.Host, backend) + dibm.removeBackendForIdentifierLocked(host, header.Host, backend) } } // Backend tries to get a backend associating to the request destination host. func (dibm *DestHostBackendManager) Backend(ctx context.Context) (*Backend, error) { - dibm.mu.RLock() - defer dibm.mu.RUnlock() + dibm.mu.Lock() + defer dibm.mu.Unlock() if len(dibm.backends) == 0 { return nil, &ErrNotFound{} } - destHost := ctx.Value(destHostKey).(string) + destHost, _ := ctx.Value(destHostKey).(string) if destHost != "" { - bes, exist := dibm.backends[destHost] - if exist && len(bes) > 0 { - var firstDrainingBackend *Backend - - // Find a non-draining backend for this destination host - for _, backend := range bes { - if !backend.IsDraining() { - klog.V(5).InfoS("Get the backend through the DestHostBackendManager", "destHost", destHost) - return backend, nil - } - // Keep track of first draining backend as fallback - if firstDrainingBackend == nil { - firstDrainingBackend = backend - } + // Prefer random selection from known non-draining backends. + // Remove stale entries lazily when backends transition to draining. + for len(dibm.nonDrainingBackends[destHost]) > 0 { + idx := dibm.random.Intn(len(dibm.nonDrainingBackends[destHost])) + backend := dibm.nonDrainingBackends[destHost][idx] + if backend.IsDraining() { + dibm.removeNonDrainingBackendLocked(destHost, backend) + continue } + klog.V(5).InfoS("Get the backend through the DestHostBackendManager", "destHost", destHost) + return backend, nil + } - // All backends for this destination are draining, use one as fallback - if firstDrainingBackend != nil { - klog.V(3).InfoS("All backends for destination host are draining, using one as fallback", "destHost", destHost) - return firstDrainingBackend, nil - } + // All backends for this destination are draining, use one as fallback. + bes, exist := dibm.backends[destHost] + if exist && len(bes) > 0 { + backend := bes[dibm.random.Intn(len(bes))] + klog.V(3).InfoS("All backends for destination host are draining, using one as fallback", "destHost", destHost) + return backend, nil } } return nil, &ErrNotFound{}