diff --git a/provider/awssd/aws_sd.go b/provider/awssd/aws_sd.go index d6734b0a8c..843032bb6f 100644 --- a/provider/awssd/aws_sd.go +++ b/provider/awssd/aws_sd.go @@ -303,14 +303,18 @@ func (p *AWSSDProvider) updatesToCreates(changes *plan.Changes) ([]*endpoint.End func (p *AWSSDProvider) submitCreates(ctx context.Context, namespaces []*sdtypes.NamespaceSummary, changes []*endpoint.Endpoint) error { changesByNamespaceID := p.changesByNamespaceID(namespaces, changes) + nsIDToName := namespaceIDToName(namespaces) + for nsID, changeList := range changesByNamespaceID { services, err := p.ListServicesByNamespaceID(ctx, aws.String(nsID)) if err != nil { return err } + nsName := nsIDToName[nsID] for _, ch := range changeList { - _, srvName := p.parseHostname(ch.DNSName) + hostname := strings.TrimSuffix(ch.DNSName, ".") + srvName := strings.TrimSuffix(hostname, "."+nsName) srv := services[srvName] if srv == nil { @@ -342,15 +346,18 @@ func (p *AWSSDProvider) submitCreates(ctx context.Context, namespaces []*sdtypes func (p *AWSSDProvider) submitDeletes(ctx context.Context, namespaces []*sdtypes.NamespaceSummary, changes []*endpoint.Endpoint) error { changesByNamespaceID := p.changesByNamespaceID(namespaces, changes) + nsIDToName := namespaceIDToName(namespaces) + for nsID, changeList := range changesByNamespaceID { services, err := p.ListServicesByNamespaceID(ctx, aws.String(nsID)) if err != nil { return err } + nsName := nsIDToName[nsID] for _, ch := range changeList { - hostname := ch.DNSName - _, srvName := p.parseHostname(hostname) + hostname := strings.TrimSuffix(ch.DNSName, ".") + srvName := strings.TrimSuffix(hostname, "."+nsName) srv := services[srvName] if srv == nil { @@ -605,7 +612,7 @@ func (p *AWSSDProvider) changesByNamespaceID(namespaces []*sdtypes.NamespaceSumm for _, c := range changes { // trim the trailing dot from hostname if any hostname := strings.TrimSuffix(c.DNSName, ".") - nsName, _ := p.parseHostname(hostname) + nsName := parseNamespace(hostname, namespaces) matchingNamespaces := matchingNamespaces(nsName, namespaces) if len(matchingNamespaces) == 0 { @@ -627,25 +634,6 @@ func (p *AWSSDProvider) changesByNamespaceID(namespaces []*sdtypes.NamespaceSumm return changesByNsID } -// returns list of all namespaces matching given hostname -func matchingNamespaces(hostname string, namespaces []*sdtypes.NamespaceSummary) []*sdtypes.NamespaceSummary { - matchingNamespaces := make([]*sdtypes.NamespaceSummary, 0) - - for _, ns := range namespaces { - if *ns.Name == hostname { - matchingNamespaces = append(matchingNamespaces, ns) - } - } - - return matchingNamespaces -} - -// parseHostname parse hostname to namespace (domain) and service -func (p *AWSSDProvider) parseHostname(hostname string) (string, string) { - parts := strings.Split(hostname, ".") - return strings.Join(parts[1:], "."), parts[0] -} - // determine service routing policy based on endpoint type func (p *AWSSDProvider) routingPolicyFromEndpoint(ep *endpoint.Endpoint) sdtypes.RoutingPolicy { if ep.RecordType == endpoint.RecordTypeA || ep.RecordType == endpoint.RecordTypeAAAA { @@ -680,3 +668,44 @@ func (p *AWSSDProvider) isAWSLoadBalancer(hostname string) bool { return matchElb || matchNlb } + +// returns list of all namespaces matching given hostname +func matchingNamespaces(hostname string, namespaces []*sdtypes.NamespaceSummary) []*sdtypes.NamespaceSummary { + matchingNamespaces := make([]*sdtypes.NamespaceSummary, 0) + + for _, ns := range namespaces { + if *ns.Name == hostname { + matchingNamespaces = append(matchingNamespaces, ns) + } + } + + return matchingNamespaces +} + +// parseNamespace returns the Cloud Map namespace name that matches the given +// hostname using longest-suffix matching. Falls back to the original first-dot +// split when no namespace suffix matches. +func parseNamespace(hostname string, namespaces []*sdtypes.NamespaceSummary) string { + hostname = strings.TrimSuffix(hostname, ".") + var bestNS string + for _, ns := range namespaces { + nsName := aws.ToString(ns.Name) + if len(nsName) > len(bestNS) && strings.HasSuffix(hostname, "."+nsName) { + bestNS = nsName + } + } + if bestNS != "" { + return bestNS + } + parts := strings.Split(hostname, ".") + return strings.Join(parts[1:], ".") +} + +// namespaceIDToName builds a map from namespace ID to namespace name. +func namespaceIDToName(namespaces []*sdtypes.NamespaceSummary) map[string]string { + m := make(map[string]string, len(namespaces)) + for _, ns := range namespaces { + m[aws.ToString(ns.Id)] = aws.ToString(ns.Name) + } + return m +} diff --git a/provider/awssd/aws_sd_test.go b/provider/awssd/aws_sd_test.go index d7df8dabf9..0c152c711e 100644 --- a/provider/awssd/aws_sd_test.go +++ b/provider/awssd/aws_sd_test.go @@ -280,6 +280,56 @@ func TestAWSSDProvider_ApplyChanges_Update(t *testing.T) { assert.Equal(t, "1.2.3.5", api.deregistered[0], "wrong target de-registered") } +func TestAWSSDProvider_ApplyChanges_DottedServiceName(t *testing.T) { + namespaces := map[string]*sdtypes.Namespace{ + "dev-local": { + Id: aws.String("dev-local"), + Name: aws.String("dev.local"), + Type: sdtypes.NamespaceTypeDnsPrivate, + }, + } + + api := &AWSSDClientStub{ + namespaces: namespaces, + services: make(map[string]map[string]*sdtypes.Service), + instances: make(map[string]map[string]*sdtypes.Instance), + } + + createEndpoints := []*endpoint.Endpoint{ + {DNSName: "my-app.elb.dev.local", Targets: endpoint.Targets{"1.2.3.4"}, RecordType: endpoint.RecordTypeA, RecordTTL: 60}, + } + + provider := newTestAWSSDProvider(api, endpoint.NewDomainFilter([]string{"dev.local"}), "", "") + + ctx := t.Context() + + err := provider.ApplyChanges(ctx, &plan.Changes{ + Create: createEndpoints, + }) + require.NoError(t, err) + + // service must be created with the dotted name "my-app.elb" + assert.Len(t, api.services["dev-local"], 1) + existingServices, err := provider.ListServicesByNamespaceID(ctx, namespaces["dev-local"].Id) + require.NoError(t, err) + assert.NotNil(t, existingServices["my-app.elb"], "service should be named 'my-app.elb'") + + // verify the record round-trips through Records() + endpoints, err := provider.Records(ctx) + require.NoError(t, err) + assert.True(t, testutils.SameEndpoints(createEndpoints, endpoints), + "expected and actual endpoints don't match, expected=%v, actual=%v", createEndpoints, endpoints) + + // apply deletes + err = provider.ApplyChanges(ctx, &plan.Changes{ + Delete: createEndpoints, + }) + require.NoError(t, err) + + endpoints, _ = provider.Records(ctx) + assert.Empty(t, endpoints) +} + func TestAWSSDProvider_ListNamespaces(t *testing.T) { namespaces := map[string]*sdtypes.Namespace{ "private": { @@ -1042,3 +1092,95 @@ func TestAWSSDProvider_awsTags(t *testing.T) { require.ElementsMatch(t, test.Expectation, awsTags(test.Input)) } } + +func Test_parseNamespace(t *testing.T) { + tests := []struct { + name string + hostname string + namespaces []*sdtypes.NamespaceSummary + wantNS string + }{ + { + name: "simple service name", + hostname: "foo.dev.local", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("dev.local")}, + }, + wantNS: "dev.local", + }, + { + name: "dotted service name", + hostname: "foo.bar.dev.local", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("dev.local")}, + }, + wantNS: "dev.local", + }, + { + name: "SRV-style hostname", + hostname: "_tcp.backend.mynet.internal", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("mynet.internal")}, + }, + wantNS: "mynet.internal", + }, + { + name: "longest namespace match wins", + hostname: "foo.a.b.c", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("b.c")}, + {Name: aws.String("a.b.c")}, + }, + wantNS: "a.b.c", + }, + { + name: "no matching namespace falls back to first-dot split", + hostname: "foo.unknown.tld", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("dev.local")}, + }, + wantNS: "unknown.tld", + }, + { + name: "empty namespaces falls back to first-dot split", + hostname: "foo.bar.baz", + namespaces: []*sdtypes.NamespaceSummary{}, + wantNS: "bar.baz", + }, + { + name: "nil namespaces falls back to first-dot split", + hostname: "foo.bar.baz", + namespaces: nil, + wantNS: "bar.baz", + }, + { + name: "trailing dot is stripped before matching", + hostname: "foo.bar.dev.local.", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("dev.local")}, + }, + wantNS: "dev.local", + }, + { + name: "hostname is namespace only, no service prefix", + hostname: "dev.local", + namespaces: []*sdtypes.NamespaceSummary{ + {Name: aws.String("dev.local")}, + }, + wantNS: "local", + }, + { + name: "single label hostname, no dots", + hostname: "foo", + namespaces: nil, + wantNS: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotNS := parseNamespace(tc.hostname, tc.namespaces) + assert.Equal(t, tc.wantNS, gotNS) + }) + } +}