From b2c6bd9f1744303b3b666a7338bbf9ef9e920d7e Mon Sep 17 00:00:00 2001 From: Harshita Yadav Date: Sun, 17 May 2026 16:48:07 +0530 Subject: [PATCH] backend: Fix port-forward cache key collisions --- backend/pkg/portforward/handler.go | 35 ++++--- backend/pkg/portforward/handler_test.go | 19 ++-- backend/pkg/portforward/internal_test.go | 127 +++++++++++++++++++++-- backend/pkg/portforward/store.go | 36 ++++--- 4 files changed, 167 insertions(+), 50 deletions(-) diff --git a/backend/pkg/portforward/handler.go b/backend/pkg/portforward/handler.go index b6cee9e2212..2db812094bd 100644 --- a/backend/pkg/portforward/handler.go +++ b/backend/pkg/portforward/handler.go @@ -86,6 +86,7 @@ func (p *portForwardRequest) Validate() error { type portForward struct { mu *sync.Mutex ID string `json:"id"` + cacheCluster string closeChan chan struct{} Pod string `json:"pod"` Service string `json:"service"` @@ -98,6 +99,14 @@ type portForward struct { Error string `json:"error"` } +func portforwardClusterKey(clusterName, userID string) string { + if userID == "" { + return clusterName + } + + return clusterName + "/" + userID +} + // setStatusAndSnapshot updates the Status and Error fields and returns a // snapshot of the struct. When mu is initialized (production path), both // the update and the snapshot are performed within a single critical section. @@ -172,7 +181,9 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache token, _ := auth.GetTokenFromCookie(r, mux.Vars(r)["clusterName"]) userID := r.Header.Get("X-HEADLAMP-USER-ID") - clusterName := mux.Vars(r)["clusterName"] + routeClusterName := mux.Vars(r)["clusterName"] + clusterName := routeClusterName + cacheCluster := portforwardClusterKey(routeClusterName, userID) if userID != "" { clusterName += userID @@ -187,7 +198,7 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache return } - err = startPortForward(kContext, cache, p, token, clusterName) + err = startPortForward(kContext, cache, p, token, routeClusterName, cacheCluster) if err != nil { logger.Log(logger.LevelError, nil, err, "starting portforward") http.Error(w, err.Error(), http.StatusInternalServerError) @@ -587,7 +598,7 @@ func runAndMonitorPortForward( // startPortForward starts a port forward. This is the internal function that was refactored. // It sets up Kubernetes clients, initializes the port forwarder, and manages its lifecycle. func startPortForward(kContext *kubeconfig.Context, cache cache.Cache[interface{}], - p portForwardRequest, token string, clusterName string, + p portForwardRequest, token string, clusterName string, cacheCluster string, ) error { clientset, rConf, err := getKubeClientAndConfig(kContext, token) if err != nil { @@ -621,6 +632,7 @@ func startPortForward(kContext *kubeconfig.Context, cache cache.Cache[interface{ pfDetails := &portForward{ mu: &sync.Mutex{}, ID: p.ID, + cacheCluster: cacheCluster, closeChan: stopChan, Pod: p.Pod, Cluster: clusterName, @@ -686,10 +698,7 @@ func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWrit userID := r.Header.Get("X-HEADLAMP-USER-ID") clusterName := mux.Vars(r)["clusterName"] - - if userID != "" { - clusterName += userID - } + clusterName = portforwardClusterKey(clusterName, userID) err = stopOrDeletePortForward(cache, clusterName, p.ID, p.StopOrDelete) if err == nil { @@ -715,11 +724,7 @@ func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *h } userID := r.Header.Get("X-HEADLAMP-USER-ID") - clusterName := cluster - - if userID != "" { - clusterName = cluster + userID - } + clusterName := portforwardClusterKey(cluster, userID) ports := getPortForwardList(cache, clusterName) @@ -752,11 +757,7 @@ func GetPortForwardByID(cache cache.Cache[interface{}], w http.ResponseWriter, r } userID := r.Header.Get("X-HEADLAMP-USER-ID") - clusterName := cluster - - if userID != "" { - clusterName = cluster + userID - } + clusterName := portforwardClusterKey(cluster, userID) p, err := getPortForwardByID(cache, clusterName, id) if err != nil { diff --git a/backend/pkg/portforward/handler_test.go b/backend/pkg/portforward/handler_test.go index f7afcc3742c..0727786bbe5 100644 --- a/backend/pkg/portforward/handler_test.go +++ b/backend/pkg/portforward/handler_test.go @@ -214,14 +214,6 @@ func TestStartPortForward(t *testing.T) { require.NoError(t, err) require.Contains(t, string(stopRespBody), "stopped") - cacheKey := "PORT_FORWARD_" + minikubeName + id - chState, err := ch.Get(context.Background(), cacheKey) - require.NoError(t, err, "failed to get port-forward state from cache with key %s", cacheKey) - - chData, err := json.Marshal(chState) - require.NoError(t, err) - assert.Contains(t, string(chData), "Stopped") - listReq := &http.Request{ Header: make(http.Header), } @@ -325,7 +317,12 @@ func TestStartPortForward(t *testing.T) { require.NoError(t, err) require.Contains(t, string(deleteRespBody), "stopped") - chState, err = ch.Get(context.Background(), cacheKey) - require.Error(t, err, "port-forward with key %s should be deleted from cache, but Get returned no error", cacheKey) - require.Nil(t, chState) + getAfterDeleteResp := httptest.NewRecorder() + portforward.GetPortForwardByID(ch, getAfterDeleteResp, getReq) + + getAfterDeleteRes := getAfterDeleteResp.Result() + + defer func() { _ = getAfterDeleteRes.Body.Close() }() + + assert.Equal(t, http.StatusNotFound, getAfterDeleteRes.StatusCode) } diff --git a/backend/pkg/portforward/internal_test.go b/backend/pkg/portforward/internal_test.go index 7a1bb17ca22..0f4177adb9f 100644 --- a/backend/pkg/portforward/internal_test.go +++ b/backend/pkg/portforward/internal_test.go @@ -103,16 +103,21 @@ func TestPortforwardKeyGenerator(t *testing.T) { p portForward want string }{ - {"only_cluster_id", portForward{ID: "id", Cluster: "cluster"}, "PORT_FORWARD_clusterid"}, - {"only_service", portForward{Cluster: "cluster", Service: "service"}, "PORT_FORWARD_clusterservice"}, - {"only_pod", portForward{Cluster: "cluster", Pod: "pod"}, "PORT_FORWARD_clusterpod"}, - {"service_and_pod", portForward{Cluster: "cluster", Service: "service", Pod: "pod"}, "PORT_FORWARD_clusterservice"}, - {"id_and_service", portForward{Cluster: "cluster", ID: "id", Service: "service"}, "PORT_FORWARD_clusterid"}, - {"id_and_pod", portForward{Cluster: "cluster", ID: "id", Pod: "pod"}, "PORT_FORWARD_clusterid"}, + {"only_cluster_id", portForward{ID: "id", Cluster: "cluster"}, "PORT_FORWARD_cluster:id"}, + {"only_service", portForward{Cluster: "cluster", Service: "service"}, "PORT_FORWARD_cluster:service"}, + {"only_pod", portForward{Cluster: "cluster", Pod: "pod"}, "PORT_FORWARD_cluster:pod"}, + {"service_and_pod", portForward{Cluster: "cluster", Service: "service", Pod: "pod"}, "PORT_FORWARD_cluster:service"}, + {"id_and_service", portForward{Cluster: "cluster", ID: "id", Service: "service"}, "PORT_FORWARD_cluster:id"}, + {"id_and_pod", portForward{Cluster: "cluster", ID: "id", Pod: "pod"}, "PORT_FORWARD_cluster:id"}, { "id_and_service_and_pod", portForward{Cluster: "cluster", ID: "id", Service: "service", Pod: "pod"}, - "PORT_FORWARD_clusterid", + "PORT_FORWARD_cluster:id", + }, + { + "delimiter_chars", + portForward{Cluster: "foo:bar", ID: "baz:qux"}, + "PORT_FORWARD_foo%3Abar:baz%3Aqux", }, } @@ -213,6 +218,53 @@ func TestGetPortForwardList(t *testing.T) { assert.ElementsMatch(t, []portForward{p3}, pfList) } +func TestGetPortForwardList_ClusterPrefixCollision(t *testing.T) { + tests := []struct { + name string + first portForward + second portForward + cluster string + want []portForward + }{ + { + name: "similar_names", + first: portForward{ID: "id1", Cluster: "foo"}, + second: portForward{ID: "id2", Cluster: "foobar"}, + cluster: "foo", + want: []portForward{{ID: "id1", Cluster: "foo"}}, + }, + { + name: "delimiter_in_cluster_name", + first: portForward{ID: "id1", Cluster: "foo"}, + second: portForward{ID: "id2", Cluster: "foo:bar"}, + cluster: "foo", + want: []portForward{{ID: "id1", Cluster: "foo"}}, + }, + { + name: "delimiter_in_id", + first: portForward{ID: "bar:baz", Cluster: "foo"}, + second: portForward{ID: "baz", Cluster: "foo:bar"}, + cluster: "foo", + want: []portForward{{ID: "bar:baz", Cluster: "foo"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cache := cache.New[interface{}]() + + err := cache.Set(context.Background(), portforwardKeyGenerator(tt.first), tt.first) + require.NoError(t, err) + + err = cache.Set(context.Background(), portforwardKeyGenerator(tt.second), tt.second) + require.NoError(t, err) + + pfList := getPortForwardList(cache, tt.cluster) + assert.ElementsMatch(t, tt.want, pfList) + }) + } +} + // Test portForwardRequest.Validate() function. func TestPortForwardRequestValidate(t *testing.T) { req := portForwardRequest{} @@ -505,10 +557,31 @@ func TestGetPortForwardList_UserIDKeyIsolation(t *testing.T) { assert.Equal(t, "pf-1", result[0].ID) // Query with a user-specific key — should NOT find the entry. - resultWithUser := getPortForwardList(c, "clusteruser123") + resultWithUser := getPortForwardList(c, portforwardClusterKey("cluster", "user123")) assert.Empty(t, resultWithUser, "user-specific key must not return entries stored under base cluster key") } +func TestGetPortForwardList_UserIDCollision(t *testing.T) { + c := cache.New[interface{}]() + + userPF := portForward{ + ID: "pf-user", + cacheCluster: portforwardClusterKey("foo", "bar"), + Cluster: "foo", + Status: RUNNING, + } + clusterPF := portForward{ID: "pf-cluster", Cluster: "foobar", Status: RUNNING} + + portforwardstore(c, userPF) + portforwardstore(c, clusterPF) + + userResult := getPortForwardList(c, portforwardClusterKey("foo", "bar")) + assert.ElementsMatch(t, []portForward{userPF}, userResult) + + clusterResult := getPortForwardList(c, "foobar") + assert.ElementsMatch(t, []portForward{clusterPF}, clusterResult) +} + // TestGetPortForwardByID_UserIDKeyIsolation verifies that getPortForwardByID // uses the correct key when a user ID is part of the cluster name. func TestGetPortForwardByID_UserIDKeyIsolation(t *testing.T) { @@ -524,7 +597,7 @@ func TestGetPortForwardByID_UserIDKeyIsolation(t *testing.T) { assert.Equal(t, "pf-2", found.ID) // Lookup with a user-specific key — should fail. - _, err = getPortForwardByID(c, "clusteruser456", "pf-2") + _, err = getPortForwardByID(c, portforwardClusterKey("cluster", "user456"), "pf-2") assert.Error(t, err, "user-specific key must not find entries stored under base cluster key") } @@ -613,8 +686,42 @@ func TestGetPortForwardByIDHandler_UserIDKeyIsolation(t *testing.T) { "user-specific lookup must not find entries under base cluster key") } +func TestGetPortForwardByIDHandler_UserIDKeepsRouteCluster(t *testing.T) { + c := cache.New[interface{}]() + + pf := portForward{ + ID: "pf-user", + cacheCluster: portforwardClusterKey("cluster", "user999"), + Cluster: "cluster", + Pod: "redis", + Namespace: "cache", + Status: RUNNING, + } + portforwardstore(c, pf) + + w := httptest.NewRecorder() + r := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/portforward?id=pf-user", nil) + r = mux.SetURLVars(r, map[string]string{"clusterName": "cluster"}) + r.URL = &url.URL{RawQuery: "id=pf-user"} + r.Header.Set("X-HEADLAMP-USER-ID", "user999") + + GetPortForwardByID(c, w, r) + + res := w.Result() + + defer func() { _ = res.Body.Close() }() + + require.Equal(t, http.StatusOK, res.StatusCode) + + var payload map[string]string + + err := json.NewDecoder(res.Body).Decode(&payload) + require.NoError(t, err) + assert.Equal(t, "cluster", payload["cluster"]) +} + // TestStopOrDeletePortForwardHandler_UserIDKeyIsolation verifies that -// StopOrDeletePortForward uses cluster+userID as the cache key. +// StopOrDeletePortForward uses the user-scoped cluster cache key. func TestStopOrDeletePortForwardHandler_UserIDKeyIsolation(t *testing.T) { c := cache.New[interface{}]() diff --git a/backend/pkg/portforward/store.go b/backend/pkg/portforward/store.go index bad4233e30c..6e89bb5f7de 100644 --- a/backend/pkg/portforward/store.go +++ b/backend/pkg/portforward/store.go @@ -19,6 +19,7 @@ package portforward import ( "context" "fmt" + "net/url" "strings" "github.com/kubernetes-sigs/headlamp/backend/pkg/cache" @@ -27,22 +28,30 @@ import ( const storeKeyPrefix = "PORT_FORWARD_" +func portforwardCacheKeyPrefix(cluster string) string { + return storeKeyPrefix + url.QueryEscape(cluster) + ":" +} + // portforwardKeyGenerator generates a unique key // based on the cluster name, id,service name, and pod name. func portforwardKeyGenerator(p portForward) string { - if p.ID != "" { - return storeKeyPrefix + p.Cluster + p.ID + cluster := p.Cluster + if p.cacheCluster != "" { + cluster = p.cacheCluster } - key := storeKeyPrefix + p.Cluster - - if p.Service != "" { - key += p.Service - } else if p.Pod != "" { - key += p.Pod + key := portforwardCacheKeyPrefix(cluster) + + switch { + case p.ID != "": + return key + url.QueryEscape(p.ID) + case p.Service != "": + return key + url.QueryEscape(p.Service) + case p.Pod != "": + return key + url.QueryEscape(p.Pod) + default: + return key } - - return key } // portforwardstore stores a port forward in the cache. @@ -91,7 +100,7 @@ func stopOrDeletePortForward(cache cache.Cache[interface{}], cluster string, id // getPortForwardList returns a list of port forwards by its cluster name. func getPortForwardList(cache cache.Cache[interface{}], cluster string) []portForward { portforwards, err := cache.GetAll(context.Background(), func(key string) bool { - return strings.HasPrefix(key, storeKeyPrefix+cluster) + return strings.HasPrefix(key, portforwardCacheKeyPrefix(cluster)) }) if err != nil { logger.Log(logger.LevelError, map[string]string{"cluster": cluster}, @@ -110,7 +119,10 @@ func getPortForwardList(cache cache.Cache[interface{}], cluster string) []portFo // getPortForwardByID returns a port forward by its cluster name and id. func getPortForwardByID(cache cache.Cache[interface{}], cluster string, id string) (portForward, error) { - cacheValue, err := cache.Get(context.Background(), storeKeyPrefix+cluster+id) + cacheValue, err := cache.Get(context.Background(), portforwardKeyGenerator(portForward{ + cacheCluster: cluster, + ID: id, + })) if err != nil { return portForward{}, fmt.Errorf("failed to get portforward from cache: %v", err) }