diff --git a/apis/gateway/v1beta1/loadbalancerconfig_types.go b/apis/gateway/v1beta1/loadbalancerconfig_types.go index 0136d89e49..9a262b2269 100644 --- a/apis/gateway/v1beta1/loadbalancerconfig_types.go +++ b/apis/gateway/v1beta1/loadbalancerconfig_types.go @@ -198,6 +198,20 @@ type LoadBalancerConfigurationSpec struct { // +optional MergingMode *LoadBalancerConfigMergeMode `json:"mergingMode,omitempty"` + // region is the AWS region where the load balancer will be deployed. When unset, the controller's default region is used. + // When set to a different region, vpcId, vpcSelector, or loadBalancerSubnets with identifiers must be set so the VPC can be resolved. + // +optional + Region *string `json:"region,omitempty"` + + // vpcId is the VPC ID in the target region. Used when region is set (and especially when it differs from the controller default). + // +optional + VpcID *string `json:"vpcId,omitempty"` + + // vpcSelector selects the VPC in the target region by tags. Each key is a tag name; the value list is the allowed tag values. + // A VPC matches if it has each tag key with one of the corresponding values. Exactly one VPC must match in the target region. + // +optional + VpcSelector *map[string][]string `json:"vpcSelector,omitempty"` + // +kubebuilder:validation:MinLength=1 // +kubebuilder:validation:MaxLength=32 // loadBalancerName defines the name of the LB to provision. If unspecified, it will be automatically generated. diff --git a/apis/gateway/v1beta1/zz_generated.deepcopy.go b/apis/gateway/v1beta1/zz_generated.deepcopy.go index e1312cbb91..59a4cacc63 100644 --- a/apis/gateway/v1beta1/zz_generated.deepcopy.go +++ b/apis/gateway/v1beta1/zz_generated.deepcopy.go @@ -655,6 +655,36 @@ func (in *LoadBalancerConfigurationSpec) DeepCopyInto(out *LoadBalancerConfigura *out = new(LoadBalancerConfigMergeMode) **out = **in } + if in.Region != nil { + in, out := &in.Region, &out.Region + *out = new(string) + **out = **in + } + if in.VpcID != nil { + in, out := &in.VpcID, &out.VpcID + *out = new(string) + **out = **in + } + if in.VpcSelector != nil { + in, out := &in.VpcSelector, &out.VpcSelector + *out = new(map[string][]string) + if **in != nil { + in, out := *in, *out + *out = make(map[string][]string, len(*in)) + for key, val := range *in { + var outVal []string + if val == nil { + (*out)[key] = nil + } else { + inVal := (*in)[key] + in, out := &inVal, &outVal + *out = make([]string, len(*in)) + copy(*out, *in) + } + (*out)[key] = outVal + } + } + } if in.LoadBalancerName != nil { in, out := &in.LoadBalancerName, &out.LoadBalancerName *out = new(string) diff --git a/config/crd/gateway/gateway-crds.yaml b/config/crd/gateway/gateway-crds.yaml index 212daa53b1..3f568f6830 100644 --- a/config/crd/gateway/gateway-crds.yaml +++ b/config/crd/gateway/gateway-crds.yaml @@ -733,6 +733,11 @@ spec: required: - capacityUnits type: object + region: + description: |- + region is the AWS region where the load balancer will be deployed. When unset, the controller's default region is used. + When set to a different region, vpcId, vpcSelector, or loadBalancerSubnets with identifiers must be set so the VPC can be resolved. + type: string scheme: description: scheme defines the type of LB to provision. If unspecified, it will be automatically inferred. @@ -772,6 +777,19 @@ spec: type: string description: Tags the AWS Tags on all related resources to the gateway. type: object + vpcId: + description: vpcId is the VPC ID in the target region. Used when region + is set (and especially when it differs from the controller default). + type: string + vpcSelector: + additionalProperties: + items: + type: string + type: array + description: |- + vpcSelector selects the VPC in the target region by tags. Each key is a tag name; the value list is the allowed tag values. + A VPC matches if it has each tag key with one of the corresponding values. Exactly one VPC must match in the target region. + type: object wafV2: description: WAFv2 define the AWS WAFv2 settings for a Gateway [Application Load Balancer] diff --git a/config/crd/gateway/gateway.k8s.aws_loadbalancerconfigurations.yaml b/config/crd/gateway/gateway.k8s.aws_loadbalancerconfigurations.yaml index 25de7944f2..6ec26ff779 100644 --- a/config/crd/gateway/gateway.k8s.aws_loadbalancerconfigurations.yaml +++ b/config/crd/gateway/gateway.k8s.aws_loadbalancerconfigurations.yaml @@ -278,6 +278,11 @@ spec: required: - capacityUnits type: object + region: + description: |- + region is the AWS region where the load balancer will be deployed. When unset, the controller's default region is used. + When set to a different region, vpcId, vpcSelector, or loadBalancerSubnets with identifiers must be set so the VPC can be resolved. + type: string scheme: description: scheme defines the type of LB to provision. If unspecified, it will be automatically inferred. @@ -317,6 +322,19 @@ spec: type: string description: Tags the AWS Tags on all related resources to the gateway. type: object + vpcId: + description: vpcId is the VPC ID in the target region. Used when region + is set (and especially when it differs from the controller default). + type: string + vpcSelector: + additionalProperties: + items: + type: string + type: array + description: |- + vpcSelector selects the VPC in the target region by tags. Each key is a tag name; the value list is the allowed tag values. + A VPC matches if it has each tag key with one of the corresponding values. Exactly one VPC must match in the target region. + type: object wafV2: description: WAFv2 define the AWS WAFv2 settings for a Gateway [Application Load Balancer] diff --git a/controllers/gateway/gateway_controller.go b/controllers/gateway/gateway_controller.go index afb56862ff..af74b68a1e 100644 --- a/controllers/gateway/gateway_controller.go +++ b/controllers/gateway/gateway_controller.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "sigs.k8s.io/aws-load-balancer-controller/pkg/certs" @@ -28,6 +29,7 @@ import ( elbv2deploy "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/elbv2" "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/tracking" ctrlerrors "sigs.k8s.io/aws-load-balancer-controller/pkg/error" + gatewaypkg "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway" "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway/constants" gateway_constants "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway/constants" gatewaymodel "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway/model" @@ -60,14 +62,14 @@ const ( var _ Reconciler = &gatewayReconciler{} -// NewNLBGatewayReconciler constructs a gateway reconciler to handle specifically for NLB gateways -func NewNLBGatewayReconciler(routeLoader routeutils.Loader, referenceCounter referencecounter.ServiceReferenceCounter, cloud services.Cloud, k8sClient client.Client, certDiscovery certs.CertDiscovery, eventRecorder record.EventRecorder, controllerConfig config.ControllerConfig, finalizerManager k8s.FinalizerManager, networkingManager networking.NetworkingManager, networkingSGReconciler networking.SecurityGroupReconciler, networkingSGManager networking.SecurityGroupManager, elbv2TaggingManager elbv2deploy.TaggingManager, subnetResolver networking.SubnetsResolver, vpcInfoProvider networking.VPCInfoProvider, backendSGProvider networking.BackendSGProvider, sgResolver networking.SecurityGroupResolver, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters, targetGroupCollector awsmetrics.TargetGroupCollector, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper) Reconciler { - return newGatewayReconciler(constants.NLBGatewayController, elbv2model.LoadBalancerTypeNetwork, controllerConfig.NLBGatewayMaxConcurrentReconciles, constants.NLBGatewayTagPrefix, shared_constants.NLBGatewayFinalizer, certDiscovery, routeLoader, referenceCounter, routeutils.L4RouteFilter, cloud, k8sClient, eventRecorder, controllerConfig, finalizerManager, networkingSGReconciler, networkingManager, networkingSGManager, elbv2TaggingManager, subnetResolver, vpcInfoProvider, backendSGProvider, sgResolver, nlbAddons, targetGroupNameToArnMapper, logger, metricsCollector, reconcileCounters.IncrementNLBGateway, targetGroupCollector) +// NewNLBGatewayReconciler constructs a gateway reconciler to handle specifically for NLB gateways. +func NewNLBGatewayReconciler(routeLoader routeutils.Loader, referenceCounter referencecounter.ServiceReferenceCounter, cloud services.Cloud, k8sClient client.Client, certDiscovery certs.CertDiscovery, eventRecorder record.EventRecorder, controllerConfig config.ControllerConfig, finalizerManager k8s.FinalizerManager, networkingManager networking.NetworkingManager, networkingSGReconciler networking.SecurityGroupReconciler, networkingSGManager networking.SecurityGroupManager, elbv2TaggingManager elbv2deploy.TaggingManager, subnetResolver networking.SubnetsResolver, vpcInfoProvider networking.VPCInfoProvider, backendSGProvider networking.BackendSGProvider, sgResolver networking.SecurityGroupResolver, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters, targetGroupCollector awsmetrics.TargetGroupCollector, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, cloudProvider gatewaypkg.CloudProvider, stackDeployerFactory func(cloud services.Cloud) deploy.StackDeployer) Reconciler { + return newGatewayReconciler(constants.NLBGatewayController, elbv2model.LoadBalancerTypeNetwork, controllerConfig.NLBGatewayMaxConcurrentReconciles, constants.NLBGatewayTagPrefix, shared_constants.NLBGatewayFinalizer, certDiscovery, routeLoader, referenceCounter, routeutils.L4RouteFilter, cloud, k8sClient, eventRecorder, controllerConfig, finalizerManager, networkingSGReconciler, networkingManager, networkingSGManager, elbv2TaggingManager, subnetResolver, vpcInfoProvider, backendSGProvider, sgResolver, nlbAddons, targetGroupNameToArnMapper, logger, metricsCollector, reconcileCounters.IncrementNLBGateway, targetGroupCollector, cloudProvider, stackDeployerFactory) } -// NewALBGatewayReconciler constructs a gateway reconciler to handle specifically for ALB gateways -func NewALBGatewayReconciler(routeLoader routeutils.Loader, cloud services.Cloud, k8sClient client.Client, certDiscovery certs.CertDiscovery, referenceCounter referencecounter.ServiceReferenceCounter, eventRecorder record.EventRecorder, controllerConfig config.ControllerConfig, finalizerManager k8s.FinalizerManager, networkingManager networking.NetworkingManager, networkingSGReconciler networking.SecurityGroupReconciler, networkingSGManager networking.SecurityGroupManager, elbv2TaggingManager elbv2deploy.TaggingManager, subnetResolver networking.SubnetsResolver, vpcInfoProvider networking.VPCInfoProvider, backendSGProvider networking.BackendSGProvider, sgResolver networking.SecurityGroupResolver, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters, targetGroupCollector awsmetrics.TargetGroupCollector, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper) Reconciler { - return newGatewayReconciler(constants.ALBGatewayController, elbv2model.LoadBalancerTypeApplication, controllerConfig.ALBGatewayMaxConcurrentReconciles, constants.ALBGatewayTagPrefix, shared_constants.ALBGatewayFinalizer, certDiscovery, routeLoader, referenceCounter, routeutils.L7RouteFilter, cloud, k8sClient, eventRecorder, controllerConfig, finalizerManager, networkingSGReconciler, networkingManager, networkingSGManager, elbv2TaggingManager, subnetResolver, vpcInfoProvider, backendSGProvider, sgResolver, albAddons, targetGroupNameToArnMapper, logger, metricsCollector, reconcileCounters.IncrementALBGateway, targetGroupCollector) +// NewALBGatewayReconciler constructs a gateway reconciler to handle specifically for ALB gateways. +func NewALBGatewayReconciler(routeLoader routeutils.Loader, cloud services.Cloud, k8sClient client.Client, certDiscovery certs.CertDiscovery, referenceCounter referencecounter.ServiceReferenceCounter, eventRecorder record.EventRecorder, controllerConfig config.ControllerConfig, finalizerManager k8s.FinalizerManager, networkingManager networking.NetworkingManager, networkingSGReconciler networking.SecurityGroupReconciler, networkingSGManager networking.SecurityGroupManager, elbv2TaggingManager elbv2deploy.TaggingManager, subnetResolver networking.SubnetsResolver, vpcInfoProvider networking.VPCInfoProvider, backendSGProvider networking.BackendSGProvider, sgResolver networking.SecurityGroupResolver, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters, targetGroupCollector awsmetrics.TargetGroupCollector, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, cloudProvider gatewaypkg.CloudProvider, stackDeployerFactory func(cloud services.Cloud) deploy.StackDeployer) Reconciler { + return newGatewayReconciler(constants.ALBGatewayController, elbv2model.LoadBalancerTypeApplication, controllerConfig.ALBGatewayMaxConcurrentReconciles, constants.ALBGatewayTagPrefix, shared_constants.ALBGatewayFinalizer, certDiscovery, routeLoader, referenceCounter, routeutils.L7RouteFilter, cloud, k8sClient, eventRecorder, controllerConfig, finalizerManager, networkingSGReconciler, networkingManager, networkingSGManager, elbv2TaggingManager, subnetResolver, vpcInfoProvider, backendSGProvider, sgResolver, albAddons, targetGroupNameToArnMapper, logger, metricsCollector, reconcileCounters.IncrementALBGateway, targetGroupCollector, cloudProvider, stackDeployerFactory) } // newGatewayReconciler constructs a reconciler that responds to gateway object changes @@ -78,7 +80,7 @@ func newGatewayReconciler(controllerName string, lbType elbv2model.LoadBalancerT networkingManager networking.NetworkingManager, networkingSGManager networking.SecurityGroupManager, elbv2TaggingManager elbv2deploy.TaggingManager, subnetResolver networking.SubnetsResolver, vpcInfoProvider networking.VPCInfoProvider, backendSGProvider networking.BackendSGProvider, sgResolver networking.SecurityGroupResolver, supportedAddons []addon.Addon, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, - reconcileTracker func(namespaceName types.NamespacedName), targetGroupCollector awsmetrics.TargetGroupCollector) Reconciler { + reconcileTracker func(namespaceName types.NamespacedName), targetGroupCollector awsmetrics.TargetGroupCollector, cloudProvider gatewaypkg.CloudProvider, stackDeployerFactory func(cloud services.Cloud) deploy.StackDeployer) Reconciler { trackingProvider := tracking.NewDefaultProvider(gatewayTagPrefix, controllerConfig.ClusterName) modelBuilder := gatewaymodel.NewModelBuilder(subnetResolver, vpcInfoProvider, cloud.VpcID(), lbType, trackingProvider, elbv2TaggingManager, controllerConfig, cloud.EC2(), cloud.ELBV2(), certDiscovery, k8sClient, controllerConfig.FeatureGates, controllerConfig.ClusterName, controllerConfig.DefaultTags, sets.New(controllerConfig.ExternalManagedTags...), controllerConfig.DefaultSSLPolicy, controllerConfig.DefaultTargetType, controllerConfig.DefaultLoadBalancerScheme, backendSGProvider, sgResolver, controllerConfig.EnableBackendSecurityGroup, controllerConfig.DisableRestrictedSGRules, supportedAddons, logger) @@ -109,6 +111,14 @@ func newGatewayReconciler(controllerName string, lbType elbv2model.LoadBalancerT serviceReferenceCounter: serviceReferenceCounter, gatewayConditionUpdater: prepareGatewayConditionUpdate, targetGroupNameToArnMapper: targetGroupNameToArnMapper, + cloudProvider: cloudProvider, + stackDeployerFactory: stackDeployerFactory, + defaultRegion: cloud.Region(), + controllerConfig: controllerConfig, + networkingManager: networkingManager, + networkingSGManager: networkingSGManager, + networkingSGReconciler: networkingSGReconciler, + deployerCache: make(map[string]deploy.StackDeployer), } } @@ -136,6 +146,17 @@ type gatewayReconciler struct { gatewayConditionUpdater func(gw *gwv1.Gateway, targetConditionType string, newStatus metav1.ConditionStatus, reason string, message string) bool cfgResolver gatewayConfigResolver + + // Multi-region support + cloudProvider gatewaypkg.CloudProvider + stackDeployerFactory func(cloud services.Cloud) deploy.StackDeployer + defaultRegion string + controllerConfig config.ControllerConfig + networkingManager networking.NetworkingManager + networkingSGManager networking.SecurityGroupManager + networkingSGReconciler networking.SecurityGroupReconciler + deployerCache map[string]deploy.StackDeployer + deployerCacheMu sync.RWMutex } //+kubebuilder:rbac:groups=gateway.networking.k8s.io,resources=referencegrants,verbs=get;list;watch;patch @@ -222,6 +243,19 @@ func (r *gatewayReconciler) reconcileHelper(ctx context.Context, req reconcile.R return err } + effectiveRegion := r.defaultRegion + if mergedLbConfig.Spec.Region != nil && *mergedLbConfig.Spec.Region != "" { + effectiveRegion = *mergedLbConfig.Spec.Region + } + reconcileContext, err := r.cloudProvider.GetReconcileContext(ctx, effectiveRegion, &mergedLbConfig.Spec) + if err != nil { + statusErr := r.updateGatewayStatusFailure(ctx, gw, gwv1.GatewayReasonInvalid, err.Error(), nil) + if statusErr != nil { + r.logger.Error(statusErr, "Unable to update gateway status on failure to get reconcile context") + } + return err + } + isDeleting := isGatewayDeleting(gw) loaderResults, err := r.gatewayLoader.LoadRoutesForGateway(ctx, *gw, r.routeFilter, r.controllerName) @@ -257,7 +291,7 @@ func (r *gatewayReconciler) reconcileHelper(ctx context.Context, req reconcile.R } } - stack, lb, newAddOnConfig, backendSGRequired, secrets, err := r.buildModel(ctx, gw, mergedLbConfig, allRoutes, currentAddOns, isDeleting) + stack, lb, newAddOnConfig, backendSGRequired, secrets, err := r.buildModel(ctx, gw, mergedLbConfig, allRoutes, currentAddOns, isDeleting, reconcileContext) if err != nil { r.handleReconcileError(ctx, gw, err) @@ -278,8 +312,18 @@ func (r *gatewayReconciler) reconcileHelper(ctx context.Context, req reconcile.R } } + deployer := r.stackDeployer + if reconcileContext.IsCrossRegion() && r.stackDeployerFactory != nil { + deployer = r.getDeployerForCloud(reconcileContext.Cloud) + } + + backendSGProviderToUse := r.backendSGProvider + if reconcileContext.GetBackendSGProvider() != nil { + backendSGProviderToUse = reconcileContext.GetBackendSGProvider() + } + if lb == nil { - err = r.reconcileDelete(ctx, gw, stack, allRoutes) + err = r.reconcileDelete(ctx, gw, stack, allRoutes, deployer, backendSGProviderToUse) if err != nil { r.logger.Error(err, "Failed to process gateway delete") return err @@ -287,7 +331,7 @@ func (r *gatewayReconciler) reconcileHelper(ctx context.Context, req reconcile.R return nil } r.serviceReferenceCounter.UpdateRelations(getServicesFromRoutes(allRoutes), k8s.NamespacedName(gw), false) - err = r.reconcileUpdate(ctx, gw, stack, lb, backendSGRequired, secrets, *loaderResults) + err = r.reconcileUpdate(ctx, gw, stack, lb, backendSGRequired, secrets, *loaderResults, deployer, backendSGProviderToUse) if err != nil { r.logger.Error(err, "Failed to process gateway update", "gw", k8s.NamespacedName(gw)) return err @@ -302,13 +346,31 @@ func (r *gatewayReconciler) reconcileHelper(ctx context.Context, req reconcile.R return nil } -func (r *gatewayReconciler) reconcileDelete(ctx context.Context, gw *gwv1.Gateway, stack core.Stack, routes map[int32][]routeutils.RouteDescriptor) error { +func (r *gatewayReconciler) getDeployerForCloud(cloud services.Cloud) deploy.StackDeployer { + cacheKey := cloud.Region() + ":" + cloud.VpcID() + r.deployerCacheMu.RLock() + if d, ok := r.deployerCache[cacheKey]; ok { + r.deployerCacheMu.RUnlock() + return d + } + r.deployerCacheMu.RUnlock() + r.deployerCacheMu.Lock() + defer r.deployerCacheMu.Unlock() + if d, ok := r.deployerCache[cacheKey]; ok { + return d + } + d := r.stackDeployerFactory(cloud) + r.deployerCache[cacheKey] = d + return d +} + +func (r *gatewayReconciler) reconcileDelete(ctx context.Context, gw *gwv1.Gateway, stack core.Stack, routes map[int32][]routeutils.RouteDescriptor, deployer deploy.StackDeployer, backendSGProvider networking.BackendSGProvider) error { if k8s.HasFinalizer(gw, r.finalizer) { - err := r.deployModel(ctx, gw, stack, nil) + err := r.deployModel(ctx, gw, stack, nil, deployer) if err != nil { return err } - if err := r.backendSGProvider.Release(ctx, networking.ResourceTypeGateway, []types.NamespacedName{k8s.NamespacedName(gw)}); err != nil { + if err := backendSGProvider.Release(ctx, networking.ResourceTypeGateway, []types.NamespacedName{k8s.NamespacedName(gw)}); err != nil { return err } r.serviceReferenceCounter.UpdateRelations([]types.NamespacedName{}, k8s.NamespacedName(gw), true) @@ -322,21 +384,21 @@ func (r *gatewayReconciler) reconcileDelete(ctx context.Context, gw *gwv1.Gatewa } func (r *gatewayReconciler) reconcileUpdate(ctx context.Context, gw *gwv1.Gateway, stack core.Stack, - lb *elbv2model.LoadBalancer, backendSGRequired bool, secrets []types.NamespacedName, loaderResults routeutils.LoaderResult) error { + lb *elbv2model.LoadBalancer, backendSGRequired bool, secrets []types.NamespacedName, loaderResults routeutils.LoaderResult, deployer deploy.StackDeployer, backendSGProvider networking.BackendSGProvider) error { // add gateway finalizer if err := r.finalizerManager.AddFinalizers(ctx, gw, r.finalizer); err != nil { r.eventRecorder.Event(gw, corev1.EventTypeWarning, k8s.GatewayEventReasonFailedAddFinalizer, fmt.Sprintf("Failed add gateway finalizer due to %v", err)) return err } - err := r.deployModel(ctx, gw, stack, secrets) + err := r.deployModel(ctx, gw, stack, secrets, deployer) if err != nil { r.handleReconcileError(ctx, gw, err) return err } if !backendSGRequired { - if err := r.backendSGProvider.Release(ctx, networking.ResourceTypeGateway, []types.NamespacedName{k8s.NamespacedName(gw)}); err != nil { + if err := backendSGProvider.Release(ctx, networking.ResourceTypeGateway, []types.NamespacedName{k8s.NamespacedName(gw)}); err != nil { return err } } @@ -365,8 +427,11 @@ func (r *gatewayReconciler) handleReconcileError(ctx context.Context, gw *gwv1.G } } -func (r *gatewayReconciler) deployModel(ctx context.Context, gw *gwv1.Gateway, stack core.Stack, secrets []types.NamespacedName) error { - if err := r.stackDeployer.Deploy(ctx, stack, r.metricsCollector, r.controllerName); err != nil { +func (r *gatewayReconciler) deployModel(ctx context.Context, gw *gwv1.Gateway, stack core.Stack, secrets []types.NamespacedName, deployer deploy.StackDeployer) error { + if deployer == nil { + deployer = r.stackDeployer + } + if err := deployer.Deploy(ctx, stack, r.metricsCollector, r.controllerName); err != nil { var requeueNeededAfter *ctrlerrors.RequeueNeededAfter if errors.As(err, &requeueNeededAfter) { return err @@ -381,8 +446,8 @@ func (r *gatewayReconciler) deployModel(ctx context.Context, gw *gwv1.Gateway, s return nil } -func (r *gatewayReconciler) buildModel(ctx context.Context, gw *gwv1.Gateway, cfg elbv2gw.LoadBalancerConfiguration, listenerToRoute map[int32][]routeutils.RouteDescriptor, currentAddonConfig []addon.Addon, isDelete bool) (core.Stack, *elbv2model.LoadBalancer, []addon.AddonMetadata, bool, []types.NamespacedName, error) { - stack, lb, newAddOnConfig, backendSGRequired, secrets, err := r.modelBuilder.Build(ctx, gw, cfg, listenerToRoute, currentAddonConfig, r.secretsManager, r.targetGroupNameToArnMapper, isDelete) +func (r *gatewayReconciler) buildModel(ctx context.Context, gw *gwv1.Gateway, cfg elbv2gw.LoadBalancerConfiguration, listenerToRoute map[int32][]routeutils.RouteDescriptor, currentAddonConfig []addon.Addon, isDelete bool, rc *gatewaypkg.ReconcileContext) (core.Stack, *elbv2model.LoadBalancer, []addon.AddonMetadata, bool, []types.NamespacedName, error) { + stack, lb, newAddOnConfig, backendSGRequired, secrets, err := r.modelBuilder.Build(ctx, gw, cfg, listenerToRoute, currentAddonConfig, r.secretsManager, r.targetGroupNameToArnMapper, isDelete, rc) if err != nil { r.eventRecorder.Event(gw, corev1.EventTypeWarning, k8s.GatewayEventReasonFailedBuildModel, fmt.Sprintf("Failed build model due to %v", err)) return nil, nil, nil, false, nil, err diff --git a/docs/guide/gateway/loadbalancerconfig.md b/docs/guide/gateway/loadbalancerconfig.md index e2b69d6b58..d5e6947017 100644 --- a/docs/guide/gateway/loadbalancerconfig.md +++ b/docs/guide/gateway/loadbalancerconfig.md @@ -65,6 +65,26 @@ Defines the LoadBalancer Scheme. **Default** internal +#### Region + +`region` + +The AWS region where the load balancer will be deployed. When unset, the controller's default region (from `--aws-region` or environment) is used. When set to a different region, you must specify the VPC in that region using one of: `vpcId`, `vpcSelector`, or `loadBalancerSubnets` with subnet identifiers so the controller can resolve the VPC. + +**Default** Controller's default region + +#### VpcID + +`vpcId` + +The VPC ID in the target region. Used when `region` is set, especially when it differs from the controller default. Required (or use `vpcSelector` / `loadBalancerSubnets` with identifiers) when deploying to a non-default region. + +#### VpcSelector + +`vpcSelector` + +Selects the VPC in the target region by tags. Same shape as `loadBalancerSubnetsSelector`: each key is a tag name, the value list is the allowed tag values. A VPC matches if it has each tag key with one of the corresponding values. Exactly one VPC must match in the target region. Use when `region` is set and you prefer tag-based selection over a fixed `vpcId`. + #### IpAddressType `ipAddressType` diff --git a/docs/guide/gateway/spec.md b/docs/guide/gateway/spec.md index 04c92a1339..c9b255b1c2 100644 --- a/docs/guide/gateway/spec.md +++ b/docs/guide/gateway/spec.md @@ -460,6 +460,9 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | | `mergingMode` _[LoadBalancerConfigMergeMode](#loadbalancerconfigmergemode)_ | mergingMode defines the merge behavior when both the Gateway and GatewayClass have a defined LoadBalancerConfiguration.
This field is only honored for the configuration attached to the GatewayClass. | | Enum: [prefer-gateway prefer-gateway-class]
| +| `region` _string_ | region is the AWS region where the load balancer will be deployed. When unset, the controller's default region is used. When set to a different region, set vpcId, vpcSelector, or loadBalancerSubnets with identifiers so the VPC can be resolved. | | | +| `vpcId` _string_ | vpcId is the VPC ID in the target region. Used when region is set (especially when it differs from the controller default). | | | +| `vpcSelector` _map[string][]string_ | vpcSelector selects the VPC in the target region by tags. Each key is a tag name; the value list is the allowed tag values. Exactly one VPC must match in the target region. | | | | `loadBalancerName` _string_ | loadBalancerName defines the name of the LB to provision. If unspecified, it will be automatically generated. | | MaxLength: 32
MinLength: 1
| | `scheme` _[LoadBalancerScheme](#loadbalancerscheme)_ | scheme defines the type of LB to provision. If unspecified, it will be automatically inferred. | | Enum: [internal internet-facing]
| | `ipAddressType` _[LoadBalancerIpAddressType](#loadbalanceripaddresstype)_ | loadBalancerIPType defines what kind of load balancer to provision (ipv4, dual stack) | | Enum: [ipv4 dualstack dualstack-without-public-ipv4]
| @@ -472,7 +475,6 @@ _Appears in:_ | `securityGroups` _string_ | securityGroups an optional list of security group ids or names to apply to the LB | | | | `securityGroupPrefixes` _string_ | securityGroupPrefixes an optional list of prefixes that are allowed to access the LB. | | | | `sourceRanges` _string_ | sourceRanges an optional list of CIDRs that are allowed to access the LB. | | | -| `vpcId` _string_ | vpcId is the ID of the VPC for the load balancer. | | | | `loadBalancerAttributes` _[LoadBalancerAttribute](#loadbalancerattribute) array_ | LoadBalancerAttributes defines the attribute of LB | | | | `tags` _map[string]string_ | Tags the AWS Tags on all related resources to the gateway. | | | | `enableICMP` _boolean_ | EnableICMP [Network LoadBalancer]
enables the creation of security group rules to the managed security group
to allow explicit ICMP traffic for Path MTU discovery for IPv4 and dual-stack VPCs | | | diff --git a/helm/aws-load-balancer-controller/crds/gateway-crds.yaml b/helm/aws-load-balancer-controller/crds/gateway-crds.yaml index 212daa53b1..3f568f6830 100644 --- a/helm/aws-load-balancer-controller/crds/gateway-crds.yaml +++ b/helm/aws-load-balancer-controller/crds/gateway-crds.yaml @@ -733,6 +733,11 @@ spec: required: - capacityUnits type: object + region: + description: |- + region is the AWS region where the load balancer will be deployed. When unset, the controller's default region is used. + When set to a different region, vpcId, vpcSelector, or loadBalancerSubnets with identifiers must be set so the VPC can be resolved. + type: string scheme: description: scheme defines the type of LB to provision. If unspecified, it will be automatically inferred. @@ -772,6 +777,19 @@ spec: type: string description: Tags the AWS Tags on all related resources to the gateway. type: object + vpcId: + description: vpcId is the VPC ID in the target region. Used when region + is set (and especially when it differs from the controller default). + type: string + vpcSelector: + additionalProperties: + items: + type: string + type: array + description: |- + vpcSelector selects the VPC in the target region by tags. Each key is a tag name; the value list is the allowed tag values. + A VPC matches if it has each tag key with one of the corresponding values. Exactly one VPC must match in the target region. + type: object wafV2: description: WAFv2 define the AWS WAFv2 settings for a Gateway [Application Load Balancer] diff --git a/main.go b/main.go index 5a4fa0bceb..809fd81886 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,8 @@ import ( elbv2gw "sigs.k8s.io/aws-load-balancer-controller/apis/gateway/v1beta1" "sigs.k8s.io/aws-load-balancer-controller/controllers/gateway" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy" + gatewaypkg "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway" gateway_constants "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway/constants" "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway/referencecounter" "sigs.k8s.io/aws-load-balancer-controller/pkg/gateway/routeutils" @@ -119,6 +121,9 @@ type gatewayControllerConfig struct { targetGroupCollector awsmetrics.TargetGroupCollector targetGroupARNMapper shared_utils.TargetGroupARNMapper certDiscovery certs.CertDiscovery + cloudProvider gatewaypkg.CloudProvider + stackDeployerFactoryALB func(cloud services.Cloud) deploy.StackDeployer + stackDeployerFactoryNLB func(cloud services.Cloud) deploy.StackDeployer } func main() { @@ -198,7 +203,11 @@ func main() { tgArnMapper := shared_utils.NewTargetGroupNameToArnMapper(cloud.ELBV2()) - tgbResManager := targetgroupbinding.NewDefaultResourceManager(mgr.GetClient(), cloud.ELBV2(), + elbv2ForRegion := func(region string) (services.ELBV2, error) { + return aws.NewELBV2ForRegion(controllerCFG.AWSConfig, region, awsMetricsCollector, ctrl.Log.WithName("elbv2-"+region), aws.DefaultLbStabilizationTime) + } + + tgbResManager := targetgroupbinding.NewDefaultResourceManager(mgr.GetClient(), cloud.ELBV2(), cloud.Region(), elbv2ForRegion, podInfoRepo, networkingManager, vpcInfoProvider, multiClusterManager, lbcMetricsCollector, cloud.VpcID(), controllerCFG.FeatureGates.Enabled(config.EndpointsFailOpen), controllerCFG.EnableEndpointSlices, mgr.GetEventRecorderFor("targetGroupBinding"), ctrl.Log, controllerCFG.MaxTargetsPerTargetGroup) @@ -263,6 +272,20 @@ func main() { serviceReferenceCounter := referencecounter.NewServiceReferenceCounter() certDiscovery := certs.NewACMCertDiscovery(cloud.ACM(), controllerCFG.IngressConfig.AllowedCertificateAuthorityARNs, ctrl.Log.WithName("gateway-cert-discovery")) + cloudProvider := gatewaypkg.NewDefaultCloudProvider(cloud, subnetResolver, vpcInfoProvider, controllerCFG.AWSConfig, controllerCFG, mgr.GetClient(), awsMetricsCollector, ctrl.Log.WithName("cloud-provider")) + // Use region-scoped networking SG manager and reconciler per cloud so the deployer lists/finds SGs and reconciles ingress in the target region (fixes cross-region Duplicate and DescribeSecurityGroups NotFound). + stackDeployerFactoryALB := func(c services.Cloud) deploy.StackDeployer { + deployLogger := ctrl.Log.WithName("deploy").WithName(c.Region()) + networkingSGManagerForCloud := networking.NewDefaultSecurityGroupManager(c.EC2(), deployLogger) + sgReconcilerForCloud := networking.NewDefaultSecurityGroupReconciler(networkingSGManagerForCloud, deployLogger) + return deploy.NewDefaultStackDeployer(c, mgr.GetClient(), networkingManager, networkingSGManagerForCloud, sgReconcilerForCloud, elbv2deploy.NewDefaultTaggingManager(c.ELBV2(), c.VpcID(), controllerCFG.FeatureGates, c.RGT(), ctrl.Log), controllerCFG, gateway_constants.ALBGatewayTagPrefix, ctrl.Log, lbcMetricsCollector, gateway_constants.ALBGatewayController, true, targetGroupCollector, false) + } + stackDeployerFactoryNLB := func(c services.Cloud) deploy.StackDeployer { + deployLogger := ctrl.Log.WithName("deploy").WithName(c.Region()) + networkingSGManagerForCloud := networking.NewDefaultSecurityGroupManager(c.EC2(), deployLogger) + sgReconcilerForCloud := networking.NewDefaultSecurityGroupReconciler(networkingSGManagerForCloud, deployLogger) + return deploy.NewDefaultStackDeployer(c, mgr.GetClient(), networkingManager, networkingSGManagerForCloud, sgReconcilerForCloud, elbv2deploy.NewDefaultTaggingManager(c.ELBV2(), c.VpcID(), controllerCFG.FeatureGates, c.RGT(), ctrl.Log), controllerCFG, gateway_constants.NLBGatewayTagPrefix, ctrl.Log, lbcMetricsCollector, gateway_constants.NLBGatewayController, true, targetGroupCollector, true) + } gwControllerConfig := &gatewayControllerConfig{ cloud: cloud, k8sClient: mgr.GetClient(), @@ -282,6 +305,9 @@ func main() { targetGroupCollector: targetGroupCollector, targetGroupARNMapper: tgArnMapper, certDiscovery: certDiscovery, + cloudProvider: cloudProvider, + stackDeployerFactoryALB: stackDeployerFactoryALB, + stackDeployerFactoryNLB: stackDeployerFactoryNLB, } enabledControllers := sets.Set[string]{} @@ -441,8 +467,8 @@ func main() { } corewebhook.NewServiceMutator(controllerCFG.ServiceConfig.LoadBalancerClass, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) elbv2webhook.NewIngressClassParamsValidator(lbcMetricsCollector).SetupWithManager(mgr) - elbv2webhook.NewTargetGroupBindingMutator(cloud.ELBV2(), ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) - elbv2webhook.NewTargetGroupBindingValidator(mgr.GetClient(), cloud.ELBV2(), cloud.VpcID(), ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) + elbv2webhook.NewTargetGroupBindingMutator(cloud.ELBV2(), cloud.Region(), elbv2ForRegion, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) + elbv2webhook.NewTargetGroupBindingValidator(mgr.GetClient(), cloud.ELBV2(), cloud.Region(), elbv2ForRegion, cloud.VpcID(), ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) networkingwebhook.NewIngressValidator(mgr.GetClient(), controllerCFG.IngressConfig, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) // Setup GlobalAccelerator validator only if enabled @@ -514,6 +540,8 @@ func setupGatewayController(ctx context.Context, mgr ctrl.Manager, cfg *gatewayC cfg.reconcileCounters, cfg.targetGroupCollector, cfg.targetGroupARNMapper, + cfg.cloudProvider, + cfg.stackDeployerFactoryNLB, ) case gateway_constants.ALBGatewayController: reconciler = gateway.NewALBGatewayReconciler( @@ -538,6 +566,8 @@ func setupGatewayController(ctx context.Context, mgr ctrl.Manager, cfg *gatewayC cfg.reconcileCounters, cfg.targetGroupCollector, cfg.targetGroupARNMapper, + cfg.cloudProvider, + cfg.stackDeployerFactoryALB, ) default: return fmt.Errorf("unknown controller type: %s", controllerType) diff --git a/pkg/aws/region.go b/pkg/aws/region.go new file mode 100644 index 0000000000..18870d18ea --- /dev/null +++ b/pkg/aws/region.go @@ -0,0 +1,83 @@ +package aws + +import ( + "context" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/go-logr/logr" + "github.com/pkg/errors" + epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + aws_metrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws" +) + +// NewCloudForRegion creates a Cloud for the given region and vpcID without using EC2 metadata. +// The base cfg is copied; only Region and VpcID are set. Use this when deploying to a non-default region. +func NewCloudForRegion(cfg CloudConfig, region, vpcID string, clusterName string, metricsCollector *aws_metrics.Collector, logger logr.Logger, lbStabilizationTime time.Duration) (services.Cloud, error) { + cfgForRegion := cfg + cfgForRegion.Region = region + cfgForRegion.VpcID = vpcID + return NewCloud(cfgForRegion, clusterName, metricsCollector, logger, nil, lbStabilizationTime) +} + +// NewEC2ClientForRegion returns an EC2 client configured for the given region. +// Used for VPC resolution (e.g. DescribeVpcs, DescribeSubnets) in that region before creating a full Cloud. +func NewEC2ClientForRegion(cfg CloudConfig, region string, metricsCollector *aws_metrics.Collector, logger logr.Logger) (services.EC2, error) { + awsClientsProvider, err := newClientsProviderForRegion(cfg, region, metricsCollector) + if err != nil { + return nil, err + } + return services.NewEC2(awsClientsProvider), nil +} + +// NewELBV2ForRegion returns an ELBV2 client configured for the given region. +// Used by webhooks and model builders that need to describe/validate resources in a non-default region. +// This does not require a VPC ID or full Cloud — only the region is needed for API calls like DescribeTargetGroups. +func NewELBV2ForRegion(cfg CloudConfig, region string, metricsCollector *aws_metrics.Collector, logger logr.Logger, lbStabilizationTime time.Duration) (services.ELBV2, error) { + awsClientsProvider, err := newClientsProviderForRegion(cfg, region, metricsCollector) + if err != nil { + return nil, err + } + cloud := ®ionStubCloud{region: region} + elbv2 := services.NewELBV2(awsClientsProvider, cloud, lbStabilizationTime) + cloud.elbv2 = elbv2 + return elbv2, nil +} + +func newClientsProviderForRegion(cfg CloudConfig, region string, metricsCollector *aws_metrics.Collector) (provider.AWSClientsProvider, error) { + cfgForRegion := cfg + cfgForRegion.Region = region + endpointsResolver := epresolver.NewResolver(cfgForRegion.AWSEndpoints) + configGenerator := NewAWSConfigGenerator(cfgForRegion, imds.EndpointModeStateIPv4, metricsCollector) + awsConfig, err := configGenerator.GenerateAWSConfig() + if err != nil { + return nil, err + } + return provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver) +} + +// regionStubCloud is a minimal Cloud implementation used only by NewELBV2ForRegion. +// Only Region() and ELBV2() are meaningful; other methods are unused by the ELBV2 client +// for basic operations like DescribeTargetGroups. +type regionStubCloud struct { + region string + elbv2 services.ELBV2 +} + +func (c *regionStubCloud) Region() string { return c.region } +func (c *regionStubCloud) VpcID() string { return "" } +func (c *regionStubCloud) ELBV2() services.ELBV2 { return c.elbv2 } +func (c *regionStubCloud) EC2() services.EC2 { return nil } +func (c *regionStubCloud) ACM() services.ACM { return nil } +func (c *regionStubCloud) WAFv2() services.WAFv2 { return nil } +func (c *regionStubCloud) WAFRegional() services.WAFRegional { return nil } +func (c *regionStubCloud) Shield() services.Shield { return nil } +func (c *regionStubCloud) RGT() services.RGT { return nil } +func (c *regionStubCloud) GlobalAccelerator() services.GlobalAccelerator { return nil } +func (c *regionStubCloud) GetAssumedRoleELBV2(_ context.Context, _ string, _ string) (services.ELBV2, error) { + return nil, errors.New("AssumeRole is not supported for cross-region stub cloud; use a full Cloud instead") +} + +var _ services.Cloud = ®ionStubCloud{} diff --git a/pkg/deploy/ec2/security_group_synthesizer.go b/pkg/deploy/ec2/security_group_synthesizer.go index b5e76845dc..c92b2ab196 100644 --- a/pkg/deploy/ec2/security_group_synthesizer.go +++ b/pkg/deploy/ec2/security_group_synthesizer.go @@ -2,6 +2,12 @@ package ec2 import ( "context" + "strings" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + ec2sdk "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/smithy-go" "github.com/go-logr/logr" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/util/sets" @@ -57,6 +63,21 @@ func (s *securityGroupSynthesizer) Synthesize(ctx context.Context) error { for _, resSG := range unmatchedResSGs { sgStatus, err := s.sgManager.Create(ctx, resSG) if err != nil { + if isInvalidGroupDuplicateError(err) { + sdkSG, findErr := s.findExistingSecurityGroupByName(ctx, resSG.Spec.GroupName) + if findErr != nil { + return errors.Wrapf(err, "Create failed with Duplicate and finding existing SG by name failed: %v", findErr) + } + if sdkSG != nil { + s.logger.Info("adopting existing security group after Duplicate", "groupName", resSG.Spec.GroupName, "securityGroupID", sdkSG.SecurityGroupID) + sgStatus, updateErr := s.sgManager.Update(ctx, resSG, *sdkSG) + if updateErr != nil { + return errors.Wrapf(updateErr, "failed to update adopted security group %s", sdkSG.SecurityGroupID) + } + resSG.SetStatus(sgStatus) + continue + } + } return err } resSG.SetStatus(sgStatus) @@ -148,3 +169,30 @@ func mapSDKSecurityGroupByResourceID(sdkSGs []networking.SecurityGroupInfo, reso } return sdkSGsByID, nil } + +func isInvalidGroupDuplicateError(err error) bool { + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return apiErr.ErrorCode() == "InvalidGroup.Duplicate" + } + return strings.Contains(err.Error(), "InvalidGroup.Duplicate") +} + +// findExistingSecurityGroupByName describes SGs in the synthesizer's VPC by group name and returns the first match, or nil if not found. +func (s *securityGroupSynthesizer) findExistingSecurityGroupByName(ctx context.Context, groupName string) (*networking.SecurityGroupInfo, error) { + req := &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []ec2types.Filter{ + {Name: awssdk.String("vpc-id"), Values: []string{s.vpcID}}, + {Name: awssdk.String("group-name"), Values: []string{groupName}}, + }, + } + sgs, err := s.ec2Client.DescribeSecurityGroupsAsList(ctx, req) + if err != nil { + return nil, err + } + if len(sgs) == 0 { + return nil, nil + } + info := networking.NewRawSecurityGroupInfo(sgs[0]) + return &info, nil +} diff --git a/pkg/gateway/cloud_provider.go b/pkg/gateway/cloud_provider.go new file mode 100644 index 0000000000..46be2f29a9 --- /dev/null +++ b/pkg/gateway/cloud_provider.go @@ -0,0 +1,281 @@ +package gateway + +import ( + "context" + "fmt" + "sync" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + ec2sdk "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/go-logr/logr" + "github.com/pkg/errors" + elbv2gw "sigs.k8s.io/aws-load-balancer-controller/apis/gateway/v1beta1" + awspkg "sigs.k8s.io/aws-load-balancer-controller/pkg/aws" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/aws-load-balancer-controller/pkg/certs" + "sigs.k8s.io/aws-load-balancer-controller/pkg/config" + elbv2deploy "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/elbv2" + awsmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws" + "sigs.k8s.io/aws-load-balancer-controller/pkg/networking" + "sigs.k8s.io/aws-load-balancer-controller/pkg/shared_utils" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// ReconcileContext holds the Cloud and region-specific resolvers for a gateway reconcile. +// For the default region it wraps the default cloud and resolvers with optional fields nil. +// For non-default regions all fields are populated with region-scoped implementations. +type ReconcileContext struct { + Cloud services.Cloud + SubnetsResolver networking.SubnetsResolver + VPCInfoProvider networking.VPCInfoProvider + Elbv2TaggingManager elbv2deploy.TaggingManager + BackendSGProvider networking.BackendSGProvider + SecurityGroupResolver networking.SecurityGroupResolver + CertDiscovery certs.CertDiscovery + TargetGroupARNMapper shared_utils.TargetGroupARNMapper + crossRegion bool +} + +// GetCloud returns the Cloud for this context. +func (r *ReconcileContext) GetCloud() services.Cloud { return r.Cloud } + +// GetSubnetsResolver returns the SubnetsResolver for this context. +func (r *ReconcileContext) GetSubnetsResolver() networking.SubnetsResolver { return r.SubnetsResolver } + +// GetVPCInfoProvider returns the VPCInfoProvider for this context. +func (r *ReconcileContext) GetVPCInfoProvider() networking.VPCInfoProvider { return r.VPCInfoProvider } + +// GetElbv2TaggingManager returns the ELBV2 tagging manager for this context, or nil to use the default. +func (r *ReconcileContext) GetElbv2TaggingManager() elbv2deploy.TaggingManager { + return r.Elbv2TaggingManager +} + +// GetBackendSGProvider returns the BackendSGProvider for this context, or nil to use the default (default region). +func (r *ReconcileContext) GetBackendSGProvider() networking.BackendSGProvider { + return r.BackendSGProvider +} + +// GetSecurityGroupResolver returns the SecurityGroupResolver for this context, or nil to use the default. +func (r *ReconcileContext) GetSecurityGroupResolver() networking.SecurityGroupResolver { + return r.SecurityGroupResolver +} + +// GetCertDiscovery returns the CertDiscovery for this context, or nil to use the default (default region's ACM). +func (r *ReconcileContext) GetCertDiscovery() certs.CertDiscovery { return r.CertDiscovery } + +// GetTargetGroupARNMapper returns the TargetGroupARNMapper for this context, or nil to use the default. +func (r *ReconcileContext) GetTargetGroupARNMapper() shared_utils.TargetGroupARNMapper { + return r.TargetGroupARNMapper +} + +// IsCrossRegion returns true when the gateway targets a region different from the controller's default. +func (r *ReconcileContext) IsCrossRegion() bool { return r.crossRegion } + +// CloudProvider returns a ReconcileContext for a given region and optional LoadBalancerConfiguration spec. +// For the default region it returns the default context; for other regions it resolves VPC and creates (or caches) a Cloud and resolvers. +type CloudProvider interface { + GetReconcileContext(ctx context.Context, region string, spec *elbv2gw.LoadBalancerConfigurationSpec) (*ReconcileContext, error) +} + +// NewDefaultCloudProvider returns a CloudProvider that uses the default cloud for the default region +// and creates Clouds for other regions with VPC resolution from spec (vpcId, vpcSelector, or first subnet). +// k8sClient is used to create region-scoped BackendSGProvider for non-default regions. +func NewDefaultCloudProvider( + defaultCloud services.Cloud, + defaultSubnetsResolver networking.SubnetsResolver, + defaultVPCInfoProvider networking.VPCInfoProvider, + baseConfig awspkg.CloudConfig, + controllerConfig config.ControllerConfig, + k8sClient client.Client, + metricsCollector *awsmetrics.Collector, + logger logr.Logger, +) CloudProvider { + return &defaultCloudProvider{ + defaultCloud: defaultCloud, + defaultSubnetsResolver: defaultSubnetsResolver, + defaultVPCInfoProvider: defaultVPCInfoProvider, + baseConfig: baseConfig, + controllerConfig: controllerConfig, + k8sClient: k8sClient, + metricsCollector: metricsCollector, + logger: logger, + cache: make(map[string]*ReconcileContext), + } +} + +var _ CloudProvider = &defaultCloudProvider{} + +type defaultCloudProvider struct { + defaultCloud services.Cloud + defaultSubnetsResolver networking.SubnetsResolver + defaultVPCInfoProvider networking.VPCInfoProvider + baseConfig awspkg.CloudConfig + controllerConfig config.ControllerConfig + k8sClient client.Client + metricsCollector *awsmetrics.Collector + logger logr.Logger + mu sync.RWMutex + cache map[string]*ReconcileContext +} + +func (p *defaultCloudProvider) GetReconcileContext(ctx context.Context, region string, spec *elbv2gw.LoadBalancerConfigurationSpec) (*ReconcileContext, error) { + defaultRegion := p.defaultCloud.Region() + if region == "" || region == defaultRegion { + return &ReconcileContext{ + Cloud: p.defaultCloud, + SubnetsResolver: p.defaultSubnetsResolver, + VPCInfoProvider: p.defaultVPCInfoProvider, + }, nil + } + + // Resolve VPC for the target region + vpcID, err := p.resolveVPCForRegion(ctx, region, spec) + if err != nil { + return nil, err + } + + cacheKey := region + ":" + vpcID + p.mu.RLock() + if ctx, ok := p.cache[cacheKey]; ok { + p.mu.RUnlock() + return ctx, nil + } + p.mu.RUnlock() + + p.mu.Lock() + defer p.mu.Unlock() + if ctx, ok := p.cache[cacheKey]; ok { + return ctx, nil + } + + cloud, err := awspkg.NewCloudForRegion(p.baseConfig, region, vpcID, p.controllerConfig.ClusterName, p.metricsCollector, p.logger, awspkg.DefaultLbStabilizationTime) + if err != nil { + return nil, errors.Wrapf(err, "failed to create cloud for region %q vpc %q", region, vpcID) + } + + azInfoProvider := networking.NewDefaultAZInfoProvider(cloud.EC2(), p.logger.WithName("az-info-provider")) + vpcInfoProvider := networking.NewDefaultVPCInfoProvider(cloud.EC2(), p.logger.WithName("vpc-info-provider")) + subnetsResolver := networking.NewDefaultSubnetsResolver( + azInfoProvider, + cloud.EC2(), + cloud.VpcID(), + p.controllerConfig.ClusterName, + p.controllerConfig.FeatureGates.Enabled(config.SubnetsClusterTagCheck), + p.controllerConfig.FeatureGates.Enabled(config.ALBSingleSubnet), + p.controllerConfig.FeatureGates.Enabled(config.SubnetDiscoveryByReachability), + p.logger.WithName("subnets-resolver"), + ) + + elbv2TaggingManager := elbv2deploy.NewDefaultTaggingManager(cloud.ELBV2(), cloud.VpcID(), p.controllerConfig.FeatureGates, cloud.RGT(), p.logger.WithName("elbv2-tagging")) + enableGatewayCheck := p.controllerConfig.FeatureGates.Enabled(config.NLBGatewayAPI) || p.controllerConfig.FeatureGates.Enabled(config.ALBGatewayAPI) + backendSGProvider := networking.NewBackendSGProvider( + p.controllerConfig.ClusterName, + p.controllerConfig.BackendSecurityGroup, + cloud.VpcID(), + cloud.EC2(), + p.k8sClient, + p.controllerConfig.DefaultTags, + enableGatewayCheck, + p.logger.WithName("backend-sg-provider").WithName(region), + ) + sgResolver := networking.NewDefaultSecurityGroupResolver(cloud.EC2(), cloud.VpcID()) + certDiscovery := certs.NewACMCertDiscovery(cloud.ACM(), p.controllerConfig.IngressConfig.AllowedCertificateAuthorityARNs, p.logger.WithName("cert-discovery").WithName(region)) + tgARNMapper := shared_utils.NewTargetGroupNameToArnMapper(cloud.ELBV2()) + reconcileCtx := &ReconcileContext{ + Cloud: cloud, + SubnetsResolver: subnetsResolver, + VPCInfoProvider: vpcInfoProvider, + Elbv2TaggingManager: elbv2TaggingManager, + BackendSGProvider: backendSGProvider, + SecurityGroupResolver: sgResolver, + CertDiscovery: certDiscovery, + TargetGroupARNMapper: tgARNMapper, + crossRegion: true, + } + p.cache[cacheKey] = reconcileCtx + return reconcileCtx, nil +} + +// resolveVPCForRegion resolves the VPC ID for the target region from spec (vpcId, vpcSelector, or first subnet). +func (p *defaultCloudProvider) resolveVPCForRegion(ctx context.Context, region string, spec *elbv2gw.LoadBalancerConfigurationSpec) (string, error) { + if spec == nil { + return "", errors.New("when region differs from controller default, set vpcId, vpcSelector, or loadBalancerSubnets with identifiers in LoadBalancerConfiguration") + } + + if spec.VpcID != nil && *spec.VpcID != "" { + return *spec.VpcID, nil + } + + ec2Client, err := awspkg.NewEC2ClientForRegion(p.baseConfig, region, p.metricsCollector, p.logger) + if err != nil { + return "", errors.Wrapf(err, "failed to create EC2 client for region %q", region) + } + + if spec.VpcSelector != nil && len(*spec.VpcSelector) > 0 { + vpcID, err := p.resolveVPCFromSelector(ctx, ec2Client, *spec.VpcSelector) + if err != nil { + return "", err + } + return vpcID, nil + } + + if spec.LoadBalancerSubnets != nil && len(*spec.LoadBalancerSubnets) > 0 { + first := (*spec.LoadBalancerSubnets)[0] + if first.Identifier != "" { + vpcID, err := p.resolveVPCFromFirstSubnet(ctx, ec2Client, first.Identifier) + if err != nil { + return "", err + } + return vpcID, nil + } + } + + return "", errors.New("when region differs from controller default, set vpcId, vpcSelector, or loadBalancerSubnets with identifiers in LoadBalancerConfiguration") +} + +func (p *defaultCloudProvider) resolveVPCFromSelector(ctx context.Context, ec2Client services.EC2, selector map[string][]string) (string, error) { + filters := make([]ec2types.Filter, 0, len(selector)) + for tagKey, values := range selector { + if len(values) == 0 { + continue + } + filters = append(filters, ec2types.Filter{ + Name: awssdk.String("tag:" + tagKey), + Values: values, + }) + } + if len(filters) == 0 { + return "", errors.New("vpcSelector must have at least one tag key with values") + } + + vpcs, err := ec2Client.DescribeVPCsAsList(ctx, &ec2sdk.DescribeVpcsInput{ + Filters: filters, + }) + if err != nil { + return "", errors.Wrap(err, "failed to describe VPCs by tag selector") + } + if len(vpcs) == 0 { + return "", errors.New("no VPC found matching vpcSelector in target region") + } + if len(vpcs) > 1 { + return "", fmt.Errorf("multiple VPCs (%d) found matching vpcSelector in target region; exactly one is required", len(vpcs)) + } + return awssdk.ToString(vpcs[0].VpcId), nil +} + +func (p *defaultCloudProvider) resolveVPCFromFirstSubnet(ctx context.Context, ec2Client services.EC2, subnetIDOrName string) (string, error) { + subnets, err := ec2Client.DescribeSubnetsAsList(ctx, &ec2sdk.DescribeSubnetsInput{ + SubnetIds: []string{subnetIDOrName}, + }) + if err != nil { + return "", errors.Wrapf(err, "failed to describe subnet %q in target region", subnetIDOrName) + } + if len(subnets) == 0 { + return "", fmt.Errorf("subnet %q not found in target region", subnetIDOrName) + } + if subnets[0].VpcId == nil { + return "", fmt.Errorf("subnet %q has no VpcId", subnetIDOrName) + } + return *subnets[0].VpcId, nil +} diff --git a/pkg/gateway/cloud_provider_test.go b/pkg/gateway/cloud_provider_test.go new file mode 100644 index 0000000000..6e3914c6e8 --- /dev/null +++ b/pkg/gateway/cloud_provider_test.go @@ -0,0 +1,406 @@ +package gateway + +import ( + "context" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + ec2sdk "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/go-logr/logr" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + elbv2gw "sigs.k8s.io/aws-load-balancer-controller/apis/gateway/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/aws-load-balancer-controller/pkg/certs" + elbv2deploy "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/elbv2" + "sigs.k8s.io/aws-load-balancer-controller/pkg/networking" + "sigs.k8s.io/aws-load-balancer-controller/pkg/shared_utils" + "sigs.k8s.io/controller-runtime/pkg/log" +) + +// fakeCloud implements services.Cloud for testing without needing the full aws package. +type fakeCloud struct { + region string + vpcID string +} + +func (c *fakeCloud) Region() string { return c.region } +func (c *fakeCloud) VpcID() string { return c.vpcID } +func (c *fakeCloud) ELBV2() services.ELBV2 { return nil } +func (c *fakeCloud) EC2() services.EC2 { return nil } +func (c *fakeCloud) ACM() services.ACM { return nil } +func (c *fakeCloud) WAFv2() services.WAFv2 { return nil } +func (c *fakeCloud) WAFRegional() services.WAFRegional { return nil } +func (c *fakeCloud) Shield() services.Shield { return nil } +func (c *fakeCloud) RGT() services.RGT { return nil } +func (c *fakeCloud) GlobalAccelerator() services.GlobalAccelerator { return nil } +func (c *fakeCloud) GetAssumedRoleELBV2(_ context.Context, _ string, _ string) (services.ELBV2, error) { + return nil, errors.New("not supported") +} + +var _ services.Cloud = &fakeCloud{} + +// --- ReconcileContext getter tests --- + +func TestReconcileContext_Getters(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cloud := &fakeCloud{region: "us-east-1", vpcID: "vpc-111"} + subnetsResolver := networking.NewMockSubnetsResolver(ctrl) + vpcInfoProvider := networking.NewMockVPCInfoProvider(ctrl) + + rc := &ReconcileContext{ + Cloud: cloud, + SubnetsResolver: subnetsResolver, + VPCInfoProvider: vpcInfoProvider, + } + + assert.Equal(t, cloud, rc.GetCloud()) + assert.Equal(t, subnetsResolver, rc.GetSubnetsResolver()) + assert.Equal(t, vpcInfoProvider, rc.GetVPCInfoProvider()) + assert.Nil(t, rc.GetElbv2TaggingManager()) + assert.Nil(t, rc.GetBackendSGProvider()) + assert.Nil(t, rc.GetSecurityGroupResolver()) + assert.Nil(t, rc.GetCertDiscovery()) + assert.Nil(t, rc.GetTargetGroupARNMapper()) + assert.False(t, rc.IsCrossRegion()) +} + +func TestReconcileContext_NonDefaultRegionGetters(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cloud := &fakeCloud{region: "ap-northeast-1", vpcID: "vpc-222"} + subnetsResolver := networking.NewMockSubnetsResolver(ctrl) + vpcInfoProvider := networking.NewMockVPCInfoProvider(ctrl) + + type fakeTaggingManager struct{ elbv2deploy.TaggingManager } + type fakeBackendSGProvider struct{ networking.BackendSGProvider } + type fakeSGResolver struct{ networking.SecurityGroupResolver } + type fakeCertDiscovery struct{ certs.CertDiscovery } + type fakeTGMapper struct{ shared_utils.TargetGroupARNMapper } + + taggingMgr := &fakeTaggingManager{} + backendSG := &fakeBackendSGProvider{} + sgResolver := &fakeSGResolver{} + certDisc := &fakeCertDiscovery{} + tgMapper := &fakeTGMapper{} + + rc := &ReconcileContext{ + Cloud: cloud, + SubnetsResolver: subnetsResolver, + VPCInfoProvider: vpcInfoProvider, + Elbv2TaggingManager: taggingMgr, + BackendSGProvider: backendSG, + SecurityGroupResolver: sgResolver, + CertDiscovery: certDisc, + TargetGroupARNMapper: tgMapper, + } + + assert.Equal(t, cloud, rc.GetCloud()) + assert.Equal(t, subnetsResolver, rc.GetSubnetsResolver()) + assert.Equal(t, vpcInfoProvider, rc.GetVPCInfoProvider()) + assert.Equal(t, taggingMgr, rc.GetElbv2TaggingManager()) + assert.Equal(t, backendSG, rc.GetBackendSGProvider()) + assert.Equal(t, sgResolver, rc.GetSecurityGroupResolver()) + assert.Equal(t, certDisc, rc.GetCertDiscovery()) + assert.Equal(t, tgMapper, rc.GetTargetGroupARNMapper()) + assert.False(t, rc.IsCrossRegion()) +} + +func TestReconcileContext_IsCrossRegion(t *testing.T) { + defaultRC := &ReconcileContext{ + Cloud: &fakeCloud{region: "us-east-1"}, + } + assert.False(t, defaultRC.IsCrossRegion()) + + crossRegionRC := &ReconcileContext{ + Cloud: &fakeCloud{region: "ap-northeast-1"}, + crossRegion: true, + } + assert.True(t, crossRegionRC.IsCrossRegion()) +} + +// --- GetReconcileContext default region tests --- + +func TestGetReconcileContext_DefaultRegion(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cloud := &fakeCloud{region: "us-east-1", vpcID: "vpc-default"} + subnetsResolver := networking.NewMockSubnetsResolver(ctrl) + vpcInfoProvider := networking.NewMockVPCInfoProvider(ctrl) + + provider := &defaultCloudProvider{ + defaultCloud: cloud, + defaultSubnetsResolver: subnetsResolver, + defaultVPCInfoProvider: vpcInfoProvider, + logger: logr.New(&log.NullLogSink{}), + cache: make(map[string]*ReconcileContext), + } + + rc, err := provider.GetReconcileContext(context.Background(), "us-east-1", nil) + assert.NoError(t, err) + assert.Equal(t, cloud, rc.Cloud) + assert.Equal(t, subnetsResolver, rc.SubnetsResolver) + assert.Equal(t, vpcInfoProvider, rc.VPCInfoProvider) + assert.Nil(t, rc.Elbv2TaggingManager) + assert.Nil(t, rc.BackendSGProvider) + assert.Nil(t, rc.SecurityGroupResolver) + assert.Nil(t, rc.CertDiscovery) + assert.Nil(t, rc.TargetGroupARNMapper) + assert.False(t, rc.IsCrossRegion()) +} + +func TestGetReconcileContext_EmptyRegion(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + cloud := &fakeCloud{region: "us-east-1", vpcID: "vpc-default"} + subnetsResolver := networking.NewMockSubnetsResolver(ctrl) + vpcInfoProvider := networking.NewMockVPCInfoProvider(ctrl) + + provider := &defaultCloudProvider{ + defaultCloud: cloud, + defaultSubnetsResolver: subnetsResolver, + defaultVPCInfoProvider: vpcInfoProvider, + logger: logr.New(&log.NullLogSink{}), + cache: make(map[string]*ReconcileContext), + } + + rc, err := provider.GetReconcileContext(context.Background(), "", nil) + assert.NoError(t, err) + assert.Equal(t, cloud, rc.Cloud) + assert.False(t, rc.IsCrossRegion()) +} + +// --- resolveVPCForRegion tests --- + +func TestResolveVPCForRegion_NilSpec(t *testing.T) { + provider := &defaultCloudProvider{ + logger: logr.New(&log.NullLogSink{}), + } + + _, err := provider.resolveVPCForRegion(context.Background(), "ap-northeast-1", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "set vpcId, vpcSelector, or loadBalancerSubnets") +} + +func TestResolveVPCForRegion_ExplicitVpcID(t *testing.T) { + provider := &defaultCloudProvider{ + logger: logr.New(&log.NullLogSink{}), + } + + spec := &elbv2gw.LoadBalancerConfigurationSpec{ + VpcID: awssdk.String("vpc-explicit"), + } + vpcID, err := provider.resolveVPCForRegion(context.Background(), "ap-northeast-1", spec) + assert.NoError(t, err) + assert.Equal(t, "vpc-explicit", vpcID) +} + +func TestResolveVPCForRegion_EmptyVpcID_NoSelectorNoSubnets(t *testing.T) { + provider := &defaultCloudProvider{ + logger: logr.New(&log.NullLogSink{}), + } + + spec := &elbv2gw.LoadBalancerConfigurationSpec{ + VpcID: awssdk.String(""), + } + _, err := provider.resolveVPCForRegion(context.Background(), "ap-northeast-1", spec) + assert.Error(t, err) + assert.Contains(t, err.Error(), "set vpcId, vpcSelector, or loadBalancerSubnets") +} + +func TestResolveVPCForRegion_EmptySpec(t *testing.T) { + provider := &defaultCloudProvider{ + logger: logr.New(&log.NullLogSink{}), + } + + spec := &elbv2gw.LoadBalancerConfigurationSpec{} + _, err := provider.resolveVPCForRegion(context.Background(), "eu-west-1", spec) + assert.Error(t, err) + assert.Contains(t, err.Error(), "set vpcId, vpcSelector, or loadBalancerSubnets") +} + +// --- resolveVPCFromSelector tests --- + +func TestResolveVPCFromSelector_SingleMatch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeVPCsAsList(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *ec2sdk.DescribeVpcsInput) ([]ec2types.Vpc, error) { + assert.Len(t, input.Filters, 1) + assert.Equal(t, "tag:env", *input.Filters[0].Name) + assert.Equal(t, []string{"production"}, input.Filters[0].Values) + return []ec2types.Vpc{ + {VpcId: awssdk.String("vpc-matched")}, + }, nil + }, + ) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + vpcID, err := provider.resolveVPCFromSelector(context.Background(), ec2Client, map[string][]string{ + "env": {"production"}, + }) + assert.NoError(t, err) + assert.Equal(t, "vpc-matched", vpcID) +} + +func TestResolveVPCFromSelector_NoMatch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeVPCsAsList(gomock.Any(), gomock.Any()).Return([]ec2types.Vpc{}, nil) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromSelector(context.Background(), ec2Client, map[string][]string{ + "env": {"staging"}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no VPC found matching vpcSelector") +} + +func TestResolveVPCFromSelector_MultipleMatches(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeVPCsAsList(gomock.Any(), gomock.Any()).Return([]ec2types.Vpc{ + {VpcId: awssdk.String("vpc-1")}, + {VpcId: awssdk.String("vpc-2")}, + }, nil) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromSelector(context.Background(), ec2Client, map[string][]string{ + "env": {"production"}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "multiple VPCs (2) found") +} + +func TestResolveVPCFromSelector_EmptySelector(t *testing.T) { + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromSelector(context.Background(), nil, map[string][]string{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "at least one tag key with values") +} + +func TestResolveVPCFromSelector_SelectorWithEmptyValues(t *testing.T) { + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromSelector(context.Background(), nil, map[string][]string{ + "env": {}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "at least one tag key with values") +} + +func TestResolveVPCFromSelector_APIError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeVPCsAsList(gomock.Any(), gomock.Any()).Return(nil, errors.New("throttled")) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromSelector(context.Background(), ec2Client, map[string][]string{ + "env": {"prod"}, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "throttled") +} + +func TestResolveVPCFromSelector_MultipleFilters(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeVPCsAsList(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *ec2sdk.DescribeVpcsInput) ([]ec2types.Vpc, error) { + assert.Len(t, input.Filters, 2) + return []ec2types.Vpc{ + {VpcId: awssdk.String("vpc-multi-tag")}, + }, nil + }, + ) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + vpcID, err := provider.resolveVPCFromSelector(context.Background(), ec2Client, map[string][]string{ + "env": {"production"}, + "cluster": {"main"}, + }) + assert.NoError(t, err) + assert.Equal(t, "vpc-multi-tag", vpcID) +} + +// --- resolveVPCFromFirstSubnet tests --- + +func TestResolveVPCFromFirstSubnet_Success(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeSubnetsAsList(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, input *ec2sdk.DescribeSubnetsInput) ([]ec2types.Subnet, error) { + assert.Equal(t, []string{"subnet-abc123"}, input.SubnetIds) + return []ec2types.Subnet{ + { + SubnetId: awssdk.String("subnet-abc123"), + VpcId: awssdk.String("vpc-from-subnet"), + }, + }, nil + }, + ) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + vpcID, err := provider.resolveVPCFromFirstSubnet(context.Background(), ec2Client, "subnet-abc123") + assert.NoError(t, err) + assert.Equal(t, "vpc-from-subnet", vpcID) +} + +func TestResolveVPCFromFirstSubnet_NotFound(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeSubnetsAsList(gomock.Any(), gomock.Any()).Return([]ec2types.Subnet{}, nil) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromFirstSubnet(context.Background(), ec2Client, "subnet-missing") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found in target region") +} + +func TestResolveVPCFromFirstSubnet_NilVpcID(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeSubnetsAsList(gomock.Any(), gomock.Any()).Return([]ec2types.Subnet{ + {SubnetId: awssdk.String("subnet-no-vpc"), VpcId: nil}, + }, nil) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromFirstSubnet(context.Background(), ec2Client, "subnet-no-vpc") + assert.Error(t, err) + assert.Contains(t, err.Error(), "has no VpcId") +} + +func TestResolveVPCFromFirstSubnet_APIError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + ec2Client.EXPECT().DescribeSubnetsAsList(gomock.Any(), gomock.Any()).Return(nil, errors.New("access denied")) + + provider := &defaultCloudProvider{logger: logr.New(&log.NullLogSink{})} + _, err := provider.resolveVPCFromFirstSubnet(context.Background(), ec2Client, "subnet-denied") + assert.Error(t, err) + assert.Contains(t, err.Error(), "access denied") +} diff --git a/pkg/gateway/lb_config_merger.go b/pkg/gateway/lb_config_merger.go index c5b353b65e..97d5ada9ec 100644 --- a/pkg/gateway/lb_config_merger.go +++ b/pkg/gateway/lb_config_merger.go @@ -136,6 +136,24 @@ func (merger *loadBalancerConfigMergerImpl) performTakeOneMerges(merged *elbv2gw merged.LoadBalancerSubnetsSelector = lowPriority.Spec.LoadBalancerSubnetsSelector } + if highPriority.Spec.Region != nil { + merged.Region = highPriority.Spec.Region + } else { + merged.Region = lowPriority.Spec.Region + } + + if highPriority.Spec.VpcID != nil { + merged.VpcID = highPriority.Spec.VpcID + } else { + merged.VpcID = lowPriority.Spec.VpcID + } + + if highPriority.Spec.VpcSelector != nil { + merged.VpcSelector = highPriority.Spec.VpcSelector + } else { + merged.VpcSelector = lowPriority.Spec.VpcSelector + } + if highPriority.Spec.SecurityGroups != nil { merged.SecurityGroups = highPriority.Spec.SecurityGroups } else { diff --git a/pkg/gateway/lb_config_merger_test.go b/pkg/gateway/lb_config_merger_test.go index 664534ecff..01978b6204 100644 --- a/pkg/gateway/lb_config_merger_test.go +++ b/pkg/gateway/lb_config_merger_test.go @@ -510,6 +510,61 @@ func Test_Merge(t *testing.T) { }, }, }, + { + name: "region and vpc merge - prefer gateway class", + gwClassLbConfig: elbv2gw.LoadBalancerConfiguration{ + Spec: elbv2gw.LoadBalancerConfigurationSpec{ + MergingMode: &mergeModeGWC, + Region: awssdk.String("us-east-1"), + VpcID: awssdk.String("vpc-111"), + VpcSelector: &map[string][]string{"Environment": {"prod"}}, + Tags: &map[string]string{}, + }, + }, + gwLbConfig: elbv2gw.LoadBalancerConfiguration{ + Spec: elbv2gw.LoadBalancerConfigurationSpec{ + Region: awssdk.String("eu-west-1"), + VpcID: awssdk.String("vpc-222"), + VpcSelector: &map[string][]string{"Purpose": {"shared"}}, + Tags: &map[string]string{}, + }, + }, + expected: elbv2gw.LoadBalancerConfiguration{ + Spec: elbv2gw.LoadBalancerConfigurationSpec{ + Region: awssdk.String("us-east-1"), + VpcID: awssdk.String("vpc-111"), + VpcSelector: &map[string][]string{"Environment": {"prod"}}, + LoadBalancerAttributes: []elbv2gw.LoadBalancerAttribute{}, + Tags: &map[string]string{}, + }, + }, + }, + { + name: "region and vpc merge - prefer gateway", + gwClassLbConfig: elbv2gw.LoadBalancerConfiguration{ + Spec: elbv2gw.LoadBalancerConfigurationSpec{ + MergingMode: &mergeModeGW, + Region: awssdk.String("us-east-1"), + VpcID: awssdk.String("vpc-111"), + Tags: &map[string]string{}, + }, + }, + gwLbConfig: elbv2gw.LoadBalancerConfiguration{ + Spec: elbv2gw.LoadBalancerConfigurationSpec{ + Region: awssdk.String("eu-west-1"), + VpcID: awssdk.String("vpc-222"), + Tags: &map[string]string{}, + }, + }, + expected: elbv2gw.LoadBalancerConfiguration{ + Spec: elbv2gw.LoadBalancerConfigurationSpec{ + Region: awssdk.String("eu-west-1"), + VpcID: awssdk.String("vpc-222"), + LoadBalancerAttributes: []elbv2gw.LoadBalancerAttribute{}, + Tags: &map[string]string{}, + }, + }, + }, } for _, tc := range testCases { diff --git a/pkg/gateway/model/base_model_builder.go b/pkg/gateway/model/base_model_builder.go index fd76645940..c7879f1b3a 100644 --- a/pkg/gateway/model/base_model_builder.go +++ b/pkg/gateway/model/base_model_builder.go @@ -31,10 +31,25 @@ import ( gwv1 "sigs.k8s.io/gateway-api/apis/v1" ) +// ReconcileContextInterface provides Cloud and resolvers for a reconcile. +// Always passed to Build(); for the default region, getters return the default implementations +// and IsCrossRegion() returns false. For non-default regions, all fields are region-scoped. +type ReconcileContextInterface interface { + GetCloud() services.Cloud + GetSubnetsResolver() networking.SubnetsResolver + GetVPCInfoProvider() networking.VPCInfoProvider + GetElbv2TaggingManager() elbv2deploy.TaggingManager + GetBackendSGProvider() networking.BackendSGProvider + GetSecurityGroupResolver() networking.SecurityGroupResolver + GetCertDiscovery() certs.CertDiscovery + GetTargetGroupARNMapper() shared_utils.TargetGroupARNMapper + IsCrossRegion() bool +} + // Builder builds the model stack for a Gateway resource. type Builder interface { // Build model stack for a gateway - Build(ctx context.Context, gw *gwv1.Gateway, lbConf elbv2gw.LoadBalancerConfiguration, routes map[int32][]routeutils.RouteDescriptor, currentAddonConfig []addon.Addon, secretsManager k8s.SecretsManager, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, isDelete bool) (core.Stack, *elbv2model.LoadBalancer, []addon.AddonMetadata, bool, []types.NamespacedName, error) + Build(ctx context.Context, gw *gwv1.Gateway, lbConf elbv2gw.LoadBalancerConfiguration, routes map[int32][]routeutils.RouteDescriptor, currentAddonConfig []addon.Addon, secretsManager k8s.SecretsManager, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, isDelete bool, rc ReconcileContextInterface) (core.Stack, *elbv2model.LoadBalancer, []addon.AddonMetadata, bool, []types.NamespacedName, error) } // NewModelBuilder construct a new baseModelBuilder @@ -58,6 +73,7 @@ func NewModelBuilder(subnetsResolver networking.SubnetsResolver, backendSGProvider: backendSGProvider, tgPropertiesConstructor: tgConfigConstructor, sgResolver: sgResolver, + enableBackendSG: enableBackendSG, vpcInfoProvider: vpcInfoProvider, elbv2TaggingManager: elbv2TaggingManager, featureGates: featureGates, @@ -115,6 +131,7 @@ type baseModelBuilder struct { subnetBuilder subnetModelBuilder securityGroupBuilder securityGroupBuilder + enableBackendSG bool tgPropertiesConstructor config2.TargetGroupConfigConstructor addOnBuilder modelAddons.AddOnBuilder @@ -123,7 +140,12 @@ type baseModelBuilder struct { defaultIPType elbv2model.IPAddressType } -func (baseBuilder *baseModelBuilder) Build(ctx context.Context, gw *gwv1.Gateway, lbConf elbv2gw.LoadBalancerConfiguration, routes map[int32][]routeutils.RouteDescriptor, currentAddonConfig []addon.Addon, secretsManager k8s.SecretsManager, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, isDelete bool) (core.Stack, *elbv2model.LoadBalancer, []addon.AddonMetadata, bool, []types.NamespacedName, error) { +func (baseBuilder *baseModelBuilder) Build(ctx context.Context, gw *gwv1.Gateway, lbConf elbv2gw.LoadBalancerConfiguration, routes map[int32][]routeutils.RouteDescriptor, currentAddonConfig []addon.Addon, secretsManager k8s.SecretsManager, targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper, isDelete bool, rc ReconcileContextInterface) (core.Stack, *elbv2model.LoadBalancer, []addon.AddonMetadata, bool, []types.NamespacedName, error) { + effectiveVpcID := rc.GetCloud().VpcID() + effectiveSubnetsResolver := rc.GetSubnetsResolver() + effectiveVpcInfoProvider := rc.GetVPCInfoProvider() + effectiveELBV2 := rc.GetCloud().ELBV2() + stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(gw))) if isDelete { if baseBuilder.isDeleteProtected(lbConf) { @@ -151,17 +173,45 @@ func (baseBuilder *baseModelBuilder) Build(ctx context.Context, gw *gwv1.Gateway /* Subnets */ - subnets, err := baseBuilder.subnetBuilder.buildLoadBalancerSubnets(ctx, lbConf.Spec.LoadBalancerSubnets, lbConf.Spec.LoadBalancerSubnetsSelector, scheme, ipAddressType, stack) + taggingManager := rc.GetElbv2TaggingManager() + if taggingManager == nil { + taggingManager = baseBuilder.subnetBuilder.(*subnetModelBuilderImpl).elbv2TaggingManager + } + subnetBuilderForBuild := newSubnetModelBuilder(baseBuilder.loadBalancerType, baseBuilder.subnetBuilder.(*subnetModelBuilderImpl).trackingProvider, effectiveSubnetsResolver, taggingManager) + subnets, err := subnetBuilderForBuild.buildLoadBalancerSubnets(ctx, lbConf.Spec.LoadBalancerSubnets, lbConf.Spec.LoadBalancerSubnetsSelector, scheme, ipAddressType, stack) if err != nil { return nil, nil, nil, false, nil, err } /* Security Groups */ - securityGroups, err := baseBuilder.securityGroupBuilder.buildSecurityGroups(ctx, stack, lbConf, gw, ipAddressType) + effectiveBackendSGProvider := baseBuilder.backendSGProvider + effectiveSGResolver := baseBuilder.sgResolver + if rc.GetBackendSGProvider() != nil { + effectiveBackendSGProvider = rc.GetBackendSGProvider() + } + if rc.GetSecurityGroupResolver() != nil { + effectiveSGResolver = rc.GetSecurityGroupResolver() + } + enableBackendSGForBuild := baseBuilder.enableBackendSG + if rc.IsCrossRegion() { + enableBackendSGForBuild = false + } + sgBuilder := baseBuilder.securityGroupBuilder + if effectiveBackendSGProvider != baseBuilder.backendSGProvider || effectiveSGResolver != baseBuilder.sgResolver || enableBackendSGForBuild != baseBuilder.enableBackendSG { + sgBuilder = newSecurityGroupBuilder(baseBuilder.gwTagHelper, baseBuilder.clusterName, baseBuilder.loadBalancerType, enableBackendSGForBuild, effectiveSGResolver, effectiveBackendSGProvider, baseBuilder.logger) + } + securityGroups, err := sgBuilder.buildSecurityGroups(ctx, stack, lbConf, gw, ipAddressType) if err != nil { return nil, nil, nil, false, nil, err } + if rc.IsCrossRegion() { + baseBuilder.logger.Info("Cross-region gateway detected, disabling backend SG networking for TGBs", + "effectiveVpcID", effectiveVpcID, "defaultVpcID", baseBuilder.vpcID) + securityGroups.backendSecurityGroupToken = nil + securityGroups.backendSecurityGroupAllocated = false + } + /* Combine everything to form a LoadBalancer */ spec, err := baseBuilder.lbBuilder.buildLoadBalancerSpec(scheme, ipAddressType, gw, lbConf, subnets, securityGroups.securityGroupTokens) if err != nil { @@ -179,9 +229,18 @@ func (baseBuilder *baseModelBuilder) Build(ctx context.Context, gw *gwv1.Gateway } lb := elbv2model.NewLoadBalancer(stack, shared_constants.ResourceIDLoadBalancer, spec) - tgbNetworkingBuilder := newTargetGroupBindingNetworkBuilder(baseBuilder.disableRestrictedSGRules, baseBuilder.vpcID, spec.Scheme, lbConf.Spec.SourceRanges, securityGroups, subnets.ec2Result, baseBuilder.vpcInfoProvider) - tgBuilder := newTargetGroupBuilder(baseBuilder.clusterName, baseBuilder.vpcID, baseBuilder.gwTagHelper, baseBuilder.loadBalancerType, tgbNetworkingBuilder, baseBuilder.tgPropertiesConstructor, baseBuilder.defaultTargetType, targetGroupNameToArnMapper) - listenerBuilder := newListenerBuilder(baseBuilder.loadBalancerType, tgBuilder, baseBuilder.gwTagHelper, baseBuilder.certDiscovery, baseBuilder.clusterName, baseBuilder.defaultSSLPolicy, baseBuilder.elbv2Client, baseBuilder.k8sClient, secretsManager, baseBuilder.logger) + effectiveTGMapper := targetGroupNameToArnMapper + if rc.GetTargetGroupARNMapper() != nil { + effectiveTGMapper = rc.GetTargetGroupARNMapper() + } + + tgbNetworkingBuilder := newTargetGroupBindingNetworkBuilder(baseBuilder.disableRestrictedSGRules, effectiveVpcID, spec.Scheme, lbConf.Spec.SourceRanges, securityGroups, subnets.ec2Result, effectiveVpcInfoProvider) + tgBuilder := newTargetGroupBuilder(baseBuilder.clusterName, effectiveVpcID, baseBuilder.gwTagHelper, baseBuilder.loadBalancerType, tgbNetworkingBuilder, baseBuilder.tgPropertiesConstructor, baseBuilder.defaultTargetType, effectiveTGMapper) + effectiveCertDiscovery := baseBuilder.certDiscovery + if rc.GetCertDiscovery() != nil { + effectiveCertDiscovery = rc.GetCertDiscovery() + } + listenerBuilder := newListenerBuilder(baseBuilder.loadBalancerType, tgBuilder, baseBuilder.gwTagHelper, effectiveCertDiscovery, baseBuilder.clusterName, baseBuilder.defaultSSLPolicy, effectiveELBV2, baseBuilder.k8sClient, secretsManager, baseBuilder.logger) secrets, err := listenerBuilder.buildListeners(ctx, stack, lb, gw, routes, lbConf) if err != nil { diff --git a/pkg/gateway/model/model_build_target_group_binding_network_test.go b/pkg/gateway/model/model_build_target_group_binding_network_test.go index cd54bc8f74..d9a898725c 100644 --- a/pkg/gateway/model/model_build_target_group_binding_network_test.go +++ b/pkg/gateway/model/model_build_target_group_binding_network_test.go @@ -421,6 +421,22 @@ func Test_buildTargetGroupBindingNetworking_standardBuilder(t *testing.T) { }, }, }, + { + name: "nil backend SG token (cross-region) returns nil networking", + sgOutput: securityGroupOutput{ + securityGroupTokens: []core.StringToken{core.LiteralStringToken("sg-remote-region")}, + backendSecurityGroupToken: nil, + backendSecurityGroupAllocated: false, + }, + tgSpec: elbv2model.TargetGroupSpec{ + Protocol: elbv2model.ProtocolHTTP, + HealthCheckConfig: &elbv2model.TargetGroupHealthCheckConfig{ + Port: &intstr80, + }, + }, + targetPort: intstr80, + expected: nil, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/pkg/targetgroupbinding/resource_manager.go b/pkg/targetgroupbinding/resource_manager.go index f912d4b296..db159cf916 100644 --- a/pkg/targetgroupbinding/resource_manager.go +++ b/pkg/targetgroupbinding/resource_manager.go @@ -46,13 +46,14 @@ type ResourceManager interface { } // NewDefaultResourceManager constructs new defaultResourceManager. -func NewDefaultResourceManager(k8sClient client.Client, elbv2Client services.ELBV2, +// elbv2Provider is optional; when set, cross-region target group ARNs are resolved to a regional ELBV2 client. +func NewDefaultResourceManager(k8sClient client.Client, elbv2Client services.ELBV2, defaultRegion string, elbv2Provider ELBV2ClientProvider, podInfoRepo k8s.PodInfoRepo, networkingManager networking.NetworkingManager, vpcInfoProvider networking.VPCInfoProvider, multiClusterManager MultiClusterManager, metricsCollector lbcmetrics.MetricCollector, vpcID string, failOpenEnabled bool, endpointSliceEnabled bool, eventRecorder record.EventRecorder, logger logr.Logger, maxTargetsPerTargetGroup int) *defaultResourceManager { - targetsManager := NewCachedTargetsManager(elbv2Client, logger) + targetsManager := NewCachedTargetsManager(elbv2Client, defaultRegion, elbv2Provider, logger) endpointResolver := backend.NewDefaultEndpointResolver(k8sClient, podInfoRepo, failOpenEnabled, endpointSliceEnabled, logger) return &defaultResourceManager{ k8sClient: k8sClient, @@ -731,13 +732,14 @@ func (m *defaultResourceManager) generateOverrideAzFn(ctx context.Context, vpcID } } + usingNonLocalVPC := vpcID != m.vpcID vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID) if err != nil { - // A VPC Not Found Error along with cross-account usage means that the VPC either, is not shared with the assume - // role account OR this falls into case (1) from above where the VPC is just peered but not shared with RAM. - // As we can't differentiate if RAM sharing wasn't set up correctly OR the VPC is set up via peering, we will - // just default to assume that the VPC is peered but not shared. - if isVPCNotFoundError(err) && usingCrossAccount { + // A VPC Not Found Error means either: + // 1. Cross-account with peered VPC (not RAM-shared) — the VPC isn't visible to the controller. + // 2. Cross-region — the VPC is in a different region and can't be described by the local EC2 client. + // In both cases, pod IPs are outside the TG's VPC, so we override AZ to "all" for all targets. + if isVPCNotFoundError(err) && (usingCrossAccount || usingNonLocalVPC) { m.invalidVpcCacheMutex.Lock() m.invalidVpcCache.Set(invalidVPCCacheKey, true, m.invalidVpcCacheTTL) m.invalidVpcCacheMutex.Unlock() diff --git a/pkg/targetgroupbinding/resource_manager_test.go b/pkg/targetgroupbinding/resource_manager_test.go index 278f5e6e59..3cbf7f6398 100644 --- a/pkg/targetgroupbinding/resource_manager_test.go +++ b/pkg/targetgroupbinding/resource_manager_test.go @@ -502,13 +502,14 @@ func Test_defaultResourceManager_GenerateOverrideAzFn(t *testing.T) { } testCases := []struct { - name string - vpcInfoCalls int - assumeRole string - vpcInfo networking.VPCInfo - vpcInfoError error - ipTestCases []ipTestCase - expectErr bool + name string + vpcInfoCalls int + assumeRole string + controllerVPCID string + vpcInfo networking.VPCInfo + vpcInfoError error + ipTestCases []ipTestCase + expectErr bool }{ { name: "standard case ipv4", @@ -665,10 +666,11 @@ func Test_defaultResourceManager_GenerateOverrideAzFn(t *testing.T) { }, }, { - name: "not found error from vpc info should be propagated when not using assume role", - vpcInfoCalls: 1, - vpcInfoError: &smithy.GenericAPIError{Code: "InvalidVpcID.NotFound", Message: ""}, - expectErr: true, + name: "not found error from vpc info should be propagated when not using assume role and local vpc", + controllerVPCID: vpcId, + vpcInfoCalls: 1, + vpcInfoError: &smithy.GenericAPIError{Code: "InvalidVpcID.NotFound", Message: ""}, + expectErr: true, }, { name: "assume role case peered vpc other error should get propagated", @@ -677,6 +679,26 @@ func Test_defaultResourceManager_GenerateOverrideAzFn(t *testing.T) { vpcInfoError: &smithy.GenericAPIError{Code: "other error", Message: ""}, expectErr: true, }, + { + name: "cross-region vpc not found should override AZ for all IPs", + controllerVPCID: "different-vpc", + vpcInfoCalls: 1, + vpcInfoError: &smithy.GenericAPIError{Code: "InvalidVpcID.NotFound", Message: ""}, + ipTestCases: []ipTestCase{ + { + ip: netip.MustParseAddr("172.0.0.0"), + result: true, + }, + { + ip: netip.MustParseAddr("127.0.0.1"), + result: true, + }, + { + ip: netip.MustParseAddr("2001:db8:0:0:0:0:0:0"), + result: true, + }, + }, + }, } for _, tc := range testCases { @@ -689,6 +711,7 @@ func Test_defaultResourceManager_GenerateOverrideAzFn(t *testing.T) { logger: logr.New(&log.NullLogSink{}), invalidVpcCache: cache.NewExpiring(), vpcInfoProvider: vpcInfoProvider, + vpcID: tc.controllerVPCID, } returnedFn, err := m.generateOverrideAzFn(context.Background(), vpcId, tc.assumeRole) diff --git a/pkg/targetgroupbinding/targets_manager.go b/pkg/targetgroupbinding/targets_manager.go index be7b7bd217..1d6c4d9075 100644 --- a/pkg/targetgroupbinding/targets_manager.go +++ b/pkg/targetgroupbinding/targets_manager.go @@ -2,6 +2,7 @@ package targetgroupbinding import ( "context" + "strings" "sync" "time" @@ -34,10 +35,16 @@ type TargetsManager interface { ListTargets(ctx context.Context, tgb *elbv2api.TargetGroupBinding) ([]TargetInfo, error) } -// NewCachedTargetsManager constructs new cachedTargetsManager -func NewCachedTargetsManager(elbv2Client services.ELBV2, logger logr.Logger) *cachedTargetsManager { +// ELBV2ClientProvider returns an ELBV2 client configured for the given region. +type ELBV2ClientProvider func(region string) (services.ELBV2, error) + +// NewCachedTargetsManager constructs new cachedTargetsManager. +// elbv2Provider is optional; when set, cross-region target group ARNs are resolved to a regional ELBV2 client. +func NewCachedTargetsManager(elbv2Client services.ELBV2, defaultRegion string, elbv2Provider ELBV2ClientProvider, logger logr.Logger) *cachedTargetsManager { return &cachedTargetsManager{ elbv2Client: elbv2Client, + defaultRegion: defaultRegion, + elbv2Provider: elbv2Provider, targetsCache: cache.NewExpiring(), targetsCacheTTL: defaultTargetsCacheTTL, registerTargetsChunkSize: defaultRegisterTargetsChunkSize, @@ -53,7 +60,9 @@ var _ TargetsManager = &cachedTargetsManager{} // When list Targets with RefreshTargets list Option set, // only targets with ongoing TargetHealth(unknown/initial/draining) TargetHealth will be refreshed. type cachedTargetsManager struct { - elbv2Client services.ELBV2 + elbv2Client services.ELBV2 + defaultRegion string + elbv2Provider ELBV2ClientProvider // cache of targets by targetGroupARN. // NOTE: since this cache implementation will automatically GC expired entries, we don't need to GC entries. @@ -72,6 +81,22 @@ type cachedTargetsManager struct { logger logr.Logger } +// resolveELBV2 returns the ELBV2 client for the region encoded in the target group ARN. +// Falls back to the default client if the provider is nil or the region matches the default. +func (m *cachedTargetsManager) resolveELBV2(tgARN string) (services.ELBV2, error) { + if m.elbv2Provider == nil || tgARN == "" { + return m.elbv2Client, nil + } + parts := strings.SplitN(tgARN, ":", 6) + if len(parts) >= 5 { + arnRegion := parts[3] + if arnRegion != "" && arnRegion != m.defaultRegion { + return m.elbv2Provider(arnRegion) + } + } + return m.elbv2Client, nil +} + // cache entry for targetsCache type targetsCacheItem struct { // mutex protects below fields @@ -92,7 +117,11 @@ func (m *cachedTargetsManager) RegisterTargets(ctx context.Context, tgb *elbv2ap "arn", tgARN, "targets", targetsChunk) - clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + regionalClient, err := m.resolveELBV2(tgARN) + if err != nil { + return err + } + clientToUse, err := regionalClient.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) if err != nil { return err } @@ -120,7 +149,11 @@ func (m *cachedTargetsManager) DeregisterTargets(ctx context.Context, tgb *elbv2 m.logger.Info("deRegistering targets", "arn", tgARN, "targets", targetsChunk) - clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + regionalClient, err := m.resolveELBV2(tgARN) + if err != nil { + return err + } + clientToUse, err := regionalClient.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) if err != nil { return err } @@ -213,7 +246,11 @@ func (m *cachedTargetsManager) listTargetsFromAWS(ctx context.Context, tgb *elbv TargetGroupArn: aws.String(tgARN), Targets: targetByIdPort(targets), } - clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + regionalClient, err := m.resolveELBV2(tgARN) + if err != nil { + return nil, err + } + clientToUse, err := regionalClient.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) if err != nil { return nil, err } diff --git a/pkg/targetgroupbinding/targets_manager_resolve_test.go b/pkg/targetgroupbinding/targets_manager_resolve_test.go new file mode 100644 index 0000000000..b3ed8c4cf1 --- /dev/null +++ b/pkg/targetgroupbinding/targets_manager_resolve_test.go @@ -0,0 +1,116 @@ +package targetgroupbinding + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" +) + +func Test_cachedTargetsManager_resolveELBV2(t *testing.T) { + tests := []struct { + name string + defaultRegion string + tgARN string + hasProvider bool + providerRegion string + providerErr error + wantDefault bool + wantErr string + }{ + { + name: "same region ARN returns default client", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/tg/abc", + hasProvider: true, + wantDefault: true, + }, + { + name: "different region ARN returns provider client", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:ap-northeast-1:025054649006:targetgroup/tg/abc", + hasProvider: true, + providerRegion: "ap-northeast-1", + wantDefault: false, + }, + { + name: "empty ARN returns default client", + defaultRegion: "us-east-1", + tgARN: "", + hasProvider: true, + wantDefault: true, + }, + { + name: "nil provider returns default client for cross-region ARN", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:ap-northeast-1:025054649006:targetgroup/tg/abc", + hasProvider: false, + wantDefault: true, + }, + { + name: "provider error is propagated", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:eu-west-1:111111111111:targetgroup/tg/abc", + hasProvider: true, + providerRegion: "eu-west-1", + providerErr: errors.New("region not supported"), + wantErr: "region not supported", + }, + { + name: "malformed ARN returns default client", + defaultRegion: "us-east-1", + tgARN: "not-an-arn", + hasProvider: true, + wantDefault: true, + }, + { + name: "ARN with empty region field returns default client", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing::123456789012:targetgroup/tg/abc", + hasProvider: true, + wantDefault: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + defaultClient := services.NewMockELBV2(ctrl) + providerClient := services.NewMockELBV2(ctrl) + + var provider ELBV2ClientProvider + if tt.hasProvider { + provider = func(region string) (services.ELBV2, error) { + if tt.providerRegion != "" { + assert.Equal(t, tt.providerRegion, region) + } + if tt.providerErr != nil { + return nil, tt.providerErr + } + return providerClient, nil + } + } + + m := &cachedTargetsManager{ + elbv2Client: defaultClient, + defaultRegion: tt.defaultRegion, + elbv2Provider: provider, + } + + got, err := m.resolveELBV2(tt.tgARN) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + if tt.wantDefault { + assert.Equal(t, defaultClient, got) + } else { + assert.Equal(t, providerClient, got) + } + } + }) + } +} diff --git a/webhooks/elbv2/targetgroup_helper.go b/webhooks/elbv2/targetgroup_helper.go index dc847dc6fe..89c6bb6b21 100644 --- a/webhooks/elbv2/targetgroup_helper.go +++ b/webhooks/elbv2/targetgroup_helper.go @@ -2,6 +2,8 @@ package elbv2 import ( "context" + "strings" + elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/pkg/errors" @@ -9,6 +11,34 @@ import ( "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" ) +// ELBV2ClientProvider returns an ELBV2 client configured for the given region. +// Used by webhooks to describe target groups in a non-default region. +type ELBV2ClientProvider func(region string) (services.ELBV2, error) + +// regionFromTGARN extracts the AWS region from a target group ARN. +// ARN format: arn:aws:elasticloadbalancing:::targetgroup/... +// Returns empty string if the ARN cannot be parsed. +func regionFromTGARN(arn string) string { + parts := strings.SplitN(arn, ":", 6) + if len(parts) >= 5 { + return parts[3] + } + return "" +} + +// resolveELBV2ForTGB returns the ELBV2 client that should be used for the given TGB. +// If the TGB's target group ARN refers to a different region, a regional client is obtained from the provider. +func resolveELBV2ForTGB(defaultClient services.ELBV2, defaultRegion string, provider ELBV2ClientProvider, tgARN string) (services.ELBV2, error) { + if tgARN == "" || provider == nil { + return defaultClient, nil + } + arnRegion := regionFromTGARN(tgARN) + if arnRegion == "" || arnRegion == defaultRegion { + return defaultClient, nil + } + return provider(arnRegion) +} + // getTargetGroupFromAWS returns the AWS target group corresponding to the arn func getTargetGroupFromAWS(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) { tgARN := tgb.Spec.TargetGroupARN diff --git a/webhooks/elbv2/targetgroup_helper_test.go b/webhooks/elbv2/targetgroup_helper_test.go new file mode 100644 index 0000000000..0c4ef15bff --- /dev/null +++ b/webhooks/elbv2/targetgroup_helper_test.go @@ -0,0 +1,162 @@ +package elbv2 + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" +) + +func Test_regionFromTGARN(t *testing.T) { + tests := []struct { + name string + arn string + want string + }{ + { + name: "standard TG ARN", + arn: "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-tg/abc123", + want: "us-east-1", + }, + { + name: "tokyo region TG ARN", + arn: "arn:aws:elasticloadbalancing:ap-northeast-1:025054649006:targetgroup/k8s-kubesyst-traefikt-1cdbaaab9f/964ac6392723fe3a", + want: "ap-northeast-1", + }, + { + name: "eu-west-1 region", + arn: "arn:aws:elasticloadbalancing:eu-west-1:111111111111:targetgroup/tg-name/deadbeef", + want: "eu-west-1", + }, + { + name: "empty ARN", + arn: "", + want: "", + }, + { + name: "malformed ARN with fewer colons", + arn: "arn:aws:elasticloadbalancing", + want: "", + }, + { + name: "ARN with exactly 4 parts", + arn: "arn:aws:elasticloadbalancing:us-west-2", + want: "", + }, + { + name: "ARN with 5 parts (minimum for region extraction)", + arn: "arn:aws:elasticloadbalancing:us-west-2:123456789012", + want: "us-west-2", + }, + { + name: "china partition ARN", + arn: "arn:aws-cn:elasticloadbalancing:cn-north-1:123456789012:targetgroup/tg/abc", + want: "cn-north-1", + }, + { + name: "govcloud ARN", + arn: "arn:aws-us-gov:elasticloadbalancing:us-gov-west-1:123456789012:targetgroup/tg/abc", + want: "us-gov-west-1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := regionFromTGARN(tt.arn) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_resolveELBV2ForTGB(t *testing.T) { + tests := []struct { + name string + defaultRegion string + tgARN string + providerRegion string + providerReturns bool + providerErr error + wantDefault bool + wantErr string + }{ + { + name: "empty ARN returns default client", + defaultRegion: "us-east-1", + tgARN: "", + wantDefault: true, + }, + { + name: "same region returns default client", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/tg/abc", + wantDefault: true, + }, + { + name: "different region returns provider client", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:ap-northeast-1:123456789012:targetgroup/tg/abc", + providerRegion: "ap-northeast-1", + providerReturns: true, + wantDefault: false, + }, + { + name: "provider error is propagated", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:eu-west-1:123456789012:targetgroup/tg/abc", + providerRegion: "eu-west-1", + providerErr: errors.New("failed to create client"), + wantErr: "failed to create client", + }, + { + name: "nil provider returns default client even for different region", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing:ap-northeast-1:123456789012:targetgroup/tg/abc", + wantDefault: true, + }, + { + name: "malformed ARN returns default client", + defaultRegion: "us-east-1", + tgARN: "not-an-arn", + wantDefault: true, + }, + { + name: "ARN with empty region returns default client", + defaultRegion: "us-east-1", + tgARN: "arn:aws:elasticloadbalancing::123456789012:targetgroup/tg/abc", + wantDefault: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + defaultClient := services.NewMockELBV2(ctrl) + providerClient := services.NewMockELBV2(ctrl) + + var provider ELBV2ClientProvider + if tt.providerRegion != "" { + provider = func(region string) (services.ELBV2, error) { + assert.Equal(t, tt.providerRegion, region) + if tt.providerErr != nil { + return nil, tt.providerErr + } + return providerClient, nil + } + } + + got, err := resolveELBV2ForTGB(defaultClient, tt.defaultRegion, provider, tt.tgARN) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + if tt.wantDefault { + assert.Equal(t, defaultClient, got) + } else { + assert.Equal(t, providerClient, got) + } + } + }) + } +} diff --git a/webhooks/elbv2/targetgroupbinding_mutator.go b/webhooks/elbv2/targetgroupbinding_mutator.go index 20df3e741a..89db9e9cf8 100644 --- a/webhooks/elbv2/targetgroupbinding_mutator.go +++ b/webhooks/elbv2/targetgroupbinding_mutator.go @@ -27,9 +27,12 @@ type tgCacheObject struct { } // NewTargetGroupBindingMutator returns a mutator for TargetGroupBinding CRD. -func NewTargetGroupBindingMutator(elbv2Client services.ELBV2, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector) *targetGroupBindingMutator { +// elbv2Provider is optional; when set, cross-region target group ARNs are resolved to a regional ELBV2 client. +func NewTargetGroupBindingMutator(elbv2Client services.ELBV2, defaultRegion string, elbv2Provider ELBV2ClientProvider, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector) *targetGroupBindingMutator { return &targetGroupBindingMutator{ elbv2Client: elbv2Client, + defaultRegion: defaultRegion, + elbv2Provider: elbv2Provider, logger: logger, metricsCollector: metricsCollector, } @@ -39,6 +42,8 @@ var _ webhook.Mutator = &targetGroupBindingMutator{} type targetGroupBindingMutator struct { elbv2Client services.ELBV2 + defaultRegion string + elbv2Provider ELBV2ClientProvider logger logr.Logger metricsCollector lbcmetrics.MetricCollector } @@ -51,14 +56,6 @@ func (m *targetGroupBindingMutator) MutateCreate(ctx context.Context, obj runtim tgb := obj.(*elbv2api.TargetGroupBinding) - targetGroupCache := sync.OnceValue(func() tgCacheObject { - targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb) - return tgCacheObject{ - tg: targetGroup, - error: err, - } - }) - if tgb.Spec.TargetGroupARN == "" && tgb.Spec.TargetGroupName == "" { m.metricsCollector.ObserveWebhookMutationError(apiPathMutateELBv2TargetGroupBinding, "checkTargetGroupArnOrName") return nil, errors.Errorf("must provide either TargetGroupARN or TargetGroupName") @@ -67,6 +64,20 @@ func (m *targetGroupBindingMutator) MutateCreate(ctx context.Context, obj runtim m.metricsCollector.ObserveWebhookMutationError(apiPathMutateELBv2TargetGroupBinding, "getArnFromNameIfNeeded") return nil, err } + + elbv2ForTGB, err := resolveELBV2ForTGB(m.elbv2Client, m.defaultRegion, m.elbv2Provider, tgb.Spec.TargetGroupARN) + if err != nil { + m.metricsCollector.ObserveWebhookMutationError(apiPathMutateELBv2TargetGroupBinding, "resolveELBV2ForTGB") + return nil, errors.Wrapf(err, "unable to create ELBV2 client for target group %s", tgb.Spec.TargetGroupARN) + } + + targetGroupCache := sync.OnceValue(func() tgCacheObject { + targetGroup, err := getTargetGroupFromAWS(ctx, elbv2ForTGB, tgb) + return tgCacheObject{ + tg: targetGroup, + error: err, + } + }) if err := m.defaultingTargetType(tgb, targetGroupCache); err != nil { m.metricsCollector.ObserveWebhookMutationError(apiPathMutateELBv2TargetGroupBinding, "defaultingTargetType") return nil, err diff --git a/webhooks/elbv2/targetgroupbinding_validator.go b/webhooks/elbv2/targetgroupbinding_validator.go index 0ec817b4cb..2121ada287 100644 --- a/webhooks/elbv2/targetgroupbinding_validator.go +++ b/webhooks/elbv2/targetgroupbinding_validator.go @@ -37,10 +37,13 @@ const ( var vpcIDPatternRegex = regexp.MustCompile("^(?:vpc-[0-9a-f]{8}|vpc-[0-9a-f]{17}|vpc-[0-9a-f]{32})$") // NewTargetGroupBindingValidator returns a validator for TargetGroupBinding CRD. -func NewTargetGroupBindingValidator(k8sClient client.Client, elbv2Client services.ELBV2, vpcID string, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector) *targetGroupBindingValidator { +// elbv2Provider is optional; when set, cross-region target group ARNs are resolved to a regional ELBV2 client. +func NewTargetGroupBindingValidator(k8sClient client.Client, elbv2Client services.ELBV2, defaultRegion string, elbv2Provider ELBV2ClientProvider, vpcID string, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector) *targetGroupBindingValidator { return &targetGroupBindingValidator{ k8sClient: k8sClient, elbv2Client: elbv2Client, + defaultRegion: defaultRegion, + elbv2Provider: elbv2Provider, logger: logger, vpcID: vpcID, metricsCollector: metricsCollector, @@ -52,6 +55,8 @@ var _ webhook.Validator = &targetGroupBindingValidator{} type targetGroupBindingValidator struct { k8sClient client.Client elbv2Client services.ELBV2 + defaultRegion string + elbv2Provider ELBV2ClientProvider logger logr.Logger vpcID string metricsCollector lbcmetrics.MetricCollector @@ -64,15 +69,21 @@ func (v *targetGroupBindingValidator) Prototype(_ admission.Request) (runtime.Ob func (v *targetGroupBindingValidator) ValidateCreate(ctx context.Context, obj runtime.Object) error { tgb := obj.(*elbv2api.TargetGroupBinding) + elbv2ForTGB, err := resolveELBV2ForTGB(v.elbv2Client, v.defaultRegion, v.elbv2Provider, tgb.Spec.TargetGroupARN) + if err != nil { + v.metricsCollector.ObserveWebhookValidationError(apiPathValidateELBv2TargetGroupBinding, "resolveELBV2ForTGB") + return errors.Wrapf(err, "unable to create ELBV2 client for target group %s", tgb.Spec.TargetGroupARN) + } + targetGroupCache := sync.OnceValue(func() tgCacheObject { - targetGroup, err := getTargetGroupFromAWS(ctx, v.elbv2Client, tgb) + targetGroup, err := getTargetGroupFromAWS(ctx, elbv2ForTGB, tgb) return tgCacheObject{ tg: targetGroup, error: err, } }) - if err := v.checkRequiredFields(ctx, tgb); err != nil { + if err := v.checkRequiredFields(ctx, tgb, elbv2ForTGB); err != nil { v.metricsCollector.ObserveWebhookValidationError(apiPathValidateELBv2TargetGroupBinding, "checkRequiredFields") return err } @@ -109,7 +120,14 @@ func (v *targetGroupBindingValidator) ValidateCreate(ctx context.Context, obj ru func (v *targetGroupBindingValidator) ValidateUpdate(ctx context.Context, obj runtime.Object, oldObj runtime.Object) error { tgb := obj.(*elbv2api.TargetGroupBinding) oldTgb := oldObj.(*elbv2api.TargetGroupBinding) - if err := v.checkRequiredFields(ctx, tgb); err != nil { + + elbv2ForTGB, err := resolveELBV2ForTGB(v.elbv2Client, v.defaultRegion, v.elbv2Provider, tgb.Spec.TargetGroupARN) + if err != nil { + v.metricsCollector.ObserveWebhookValidationError(apiPathValidateELBv2TargetGroupBinding, "resolveELBV2ForTGB") + return errors.Wrapf(err, "unable to create ELBV2 client for target group %s", tgb.Spec.TargetGroupARN) + } + + if err := v.checkRequiredFields(ctx, tgb, elbv2ForTGB); err != nil { v.metricsCollector.ObserveWebhookValidationError(apiPathValidateELBv2TargetGroupBinding, "checkRequiredFields") return err } @@ -137,7 +155,8 @@ func (v *targetGroupBindingValidator) ValidateDelete(ctx context.Context, obj ru } // checkRequiredFields will check required fields are not absent. -func (v *targetGroupBindingValidator) checkRequiredFields(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error { +// elbv2ForTGB is the ELBV2 client scoped to the correct region for this TGB. +func (v *targetGroupBindingValidator) checkRequiredFields(ctx context.Context, tgb *elbv2api.TargetGroupBinding, elbv2ForTGB services.ELBV2) error { var absentRequiredFields []string if tgb.Spec.TargetGroupARN == "" { if tgb.Spec.TargetGroupName == "" { @@ -154,7 +173,7 @@ func (v *targetGroupBindingValidator) checkRequiredFields(ctx context.Context, t By changing the object here I guarantee as early as possible that that assumption is true. */ - tgObj, err := getTargetGroupsByNameFromAWS(ctx, v.elbv2Client, tgb) + tgObj, err := getTargetGroupsByNameFromAWS(ctx, elbv2ForTGB, tgb) if err != nil { return fmt.Errorf("searching TargetGroup with name %s: %w", tgb.Spec.TargetGroupName, err) } diff --git a/webhooks/elbv2/targetgroupbinding_validator_test.go b/webhooks/elbv2/targetgroupbinding_validator_test.go index 5a15dc95eb..aa0c11e569 100644 --- a/webhooks/elbv2/targetgroupbinding_validator_test.go +++ b/webhooks/elbv2/targetgroupbinding_validator_test.go @@ -766,7 +766,7 @@ func Test_targetGroupBindingValidator_checkRequiredFields(t *testing.T) { logger: logr.New(&log.NullLogSink{}), metricsCollector: mockMetricsCollector, } - err := v.checkRequiredFields(context.Background(), tt.args.tgb) + err := v.checkRequiredFields(context.Background(), tt.args.tgb, nil) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else {