diff --git a/endpoint/target_filter.go b/endpoint/target_filter.go index 2706155e91..fc3ba22e43 100644 --- a/endpoint/target_filter.go +++ b/endpoint/target_filter.go @@ -25,7 +25,7 @@ import ( // TargetFilterInterface defines the interface to select matching targets for a specific provider or runtime type TargetFilterInterface interface { - Match(target string) bool + Match(target string, ep *Endpoint) bool IsEnabled() bool } @@ -60,7 +60,12 @@ func NewTargetNetFilterWithExclusions(targetFilterNets []string, excludeNets []s } // Match checks whether a target can be found in the TargetNetFilter. -func (tf TargetNetFilter) Match(target string) bool { +func (tf TargetNetFilter) Match(target string, ep *Endpoint) bool { + // Target network filter is only relevant for A/AAAA records containing IPv4/IPv6 address + // therefore we can add other endpoints directly and skip filtering + if ep.RecordType != RecordTypeA && ep.RecordType != RecordTypeAAAA { + return true + } return matchTargetNetFilter(tf.filterNets, target, true) && !matchTargetNetFilter(tf.excludeNets, target, false) } diff --git a/endpoint/target_filter_test.go b/endpoint/target_filter_test.go index d803093c17..09da6575e9 100644 --- a/endpoint/target_filter_test.go +++ b/endpoint/target_filter_test.go @@ -23,61 +23,78 @@ import ( ) type targetFilterTest struct { + stubEndpoint *Endpoint targetFilter []string exclusions []string targets []string expected bool } +var targetFilterTestsStubEndpointNetwork = NewEndpoint("test-a.example.com", RecordTypeA) +var targetFilterTestsStubEndpointCname = NewEndpoint("test-cname.example.com", RecordTypeCNAME) var targetFilterTests = []targetFilterTest{ { + targetFilterTestsStubEndpointNetwork, []string{"10.0.0.0/8"}, []string{}, []string{"10.1.2.3"}, true, }, { + targetFilterTestsStubEndpointNetwork, []string{" 10.0.0.0/8 "}, []string{}, []string{"10.1.2.3"}, true, }, { + targetFilterTestsStubEndpointNetwork, []string{"0"}, []string{}, []string{"10.1.2.3"}, true, }, { + targetFilterTestsStubEndpointNetwork, []string{"10.0.0.0/8"}, []string{}, []string{"1.1.1.1"}, false, }, { + targetFilterTestsStubEndpointNetwork, []string{}, []string{"10.0.0.0/8"}, []string{"1.1.1.1"}, true, }, { + targetFilterTestsStubEndpointNetwork, []string{}, []string{"10.0.0.0/8"}, []string{"10.1.2.3"}, false, }, { + targetFilterTestsStubEndpointNetwork, []string{}, []string{"10.0.0.0/8"}, []string{"49.13.41.161"}, true, }, { + targetFilterTestsStubEndpointNetwork, []string{}, []string{"10.0.0.0/8"}, []string{"10.0.1.101"}, false, }, + {targetFilterTestsStubEndpointCname, + []string{"10.0.0.0/8"}, + []string{"10.1.0.0/24"}, + []string{"10.2.254.254", "10.1.1.1", "cname-1.example.com", "random text data"}, + true, + }, } func TestTargetFilterWithExclusions(t *testing.T) { @@ -87,7 +104,7 @@ func TestTargetFilterWithExclusions(t *testing.T) { } targetFilter := NewTargetNetFilterWithExclusions(tt.targetFilter, tt.exclusions) for _, target := range tt.targets { - assert.Equal(t, tt.expected, targetFilter.Match(target), "should not fail: %v in test-case #%v", target, i) + assert.Equal(t, tt.expected, targetFilter.Match(target, tt.stubEndpoint), "should not fail: %v in test-case #%v", target, i) } } } @@ -96,7 +113,7 @@ func TestTargetFilterMatchWithEmptyFilter(t *testing.T) { for _, tt := range targetFilterTests { targetFilter := TargetNetFilter{} for i, target := range tt.targets { - assert.True(t, targetFilter.Match(target), "should not fail: %v in test-case #%v", target, i) + assert.True(t, targetFilter.Match(target, tt.stubEndpoint), "should not fail: %v in test-case #%v", target, i) } } } diff --git a/source/wrappers/targetfiltersource.go b/source/wrappers/targetfiltersource.go index 7cbbaa8ceb..cd0efc6526 100644 --- a/source/wrappers/targetfiltersource.go +++ b/source/wrappers/targetfiltersource.go @@ -55,7 +55,7 @@ func (ms *targetFilterSource) Endpoints(ctx context.Context) ([]*endpoint.Endpoi filteredTargets := make([]string, 0, len(ep.Targets)) for _, t := range ep.Targets { - if ms.targetFilter.Match(t) { + if ms.targetFilter.Match(t, ep) { filteredTargets = append(filteredTargets, t) } } diff --git a/source/wrappers/targetfiltersource_test.go b/source/wrappers/targetfiltersource_test.go index ca711d94e9..5d6d9c537e 100644 --- a/source/wrappers/targetfiltersource_test.go +++ b/source/wrappers/targetfiltersource_test.go @@ -40,7 +40,7 @@ func NewMockTargetNetFilter(targets []string) endpoint.TargetFilterInterface { return &mockTargetNetFilter{targets: targetMap} } -func (m *mockTargetNetFilter) Match(target string) bool { +func (m *mockTargetNetFilter) Match(target string, _ *endpoint.Endpoint) bool { return m.targets[target] } @@ -227,6 +227,43 @@ func TestTargetFilterConcreteTargetFilter(t *testing.T) { } } +func TestTargetFilterNonAddressRecords(t *testing.T) { + tests := []struct { + title string + filters endpoint.TargetFilterInterface + endpoints []*endpoint.Endpoint + expected []*endpoint.Endpoint + }{ + { + title: "should pass CNAME records", + filters: endpoint.NewTargetNetFilterWithExclusions([]string{"10.0.0.0/8", "91ca::/16"}, []string{}), + endpoints: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.1.2.3"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "192.168.7.1"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "91ca:beef:83b2:beef:1490:7604:d192:b326"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "2ac3:beef:83b2:beef:1490:beef:d192:beef"), + endpoint.NewEndpoint("foo-cname", endpoint.RecordTypeCNAME, "target-rr.example.com"), + }, + expected: []*endpoint.Endpoint{ + endpoint.NewEndpoint("foo", endpoint.RecordTypeA, "10.1.2.3"), + endpoint.NewEndpoint("foo", endpoint.RecordTypeAAAA, "91ca:beef:83b2:beef:1490:7604:d192:b326"), + endpoint.NewEndpoint("foo-cname", endpoint.RecordTypeCNAME, "target-rr.example.com"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + echo := testutils.NewMockSource(tt.endpoints...) + src := NewTargetFilterSource(echo, tt.filters) + + endpoints, err := src.Endpoints(t.Context()) + require.NoError(t, err, "failed to get Endpoints") + + testutils.ValidateEndpoints(t, endpoints, tt.expected) + }) + } +} + func TestTargetFilterSource_AddEventHandler(t *testing.T) { tests := []struct { title string