Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions backend/pkg/portforward/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func (p *portForwardRequest) Validate() error {
type portForward struct {
mu *sync.Mutex
ID string `json:"id"`
cacheCluster string
closeChan chan struct{}
Pod string `json:"pod"`
Service string `json:"service"`
Expand All @@ -98,6 +99,14 @@ type portForward struct {
Error string `json:"error"`
}

func portforwardClusterKey(clusterName, userID string) string {
if userID == "" {
return clusterName
}

return clusterName + "/" + userID
Comment thread
harrshita123 marked this conversation as resolved.
}

// setStatusAndSnapshot updates the Status and Error fields and returns a
// snapshot of the struct. When mu is initialized (production path), both
// the update and the snapshot are performed within a single critical section.
Expand Down Expand Up @@ -172,7 +181,9 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache
token, _ := auth.GetTokenFromCookie(r, mux.Vars(r)["clusterName"])

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := mux.Vars(r)["clusterName"]
routeClusterName := mux.Vars(r)["clusterName"]
clusterName := routeClusterName
cacheCluster := portforwardClusterKey(routeClusterName, userID)

if userID != "" {
clusterName += userID
Expand All @@ -187,7 +198,7 @@ func StartPortForward(kubeConfigStore kubeconfig.ContextStore, cache cache.Cache
return
}

err = startPortForward(kContext, cache, p, token, clusterName)
err = startPortForward(kContext, cache, p, token, routeClusterName, cacheCluster)
if err != nil {
logger.Log(logger.LevelError, nil, err, "starting portforward")
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -587,7 +598,7 @@ func runAndMonitorPortForward(
// startPortForward starts a port forward. This is the internal function that was refactored.
// It sets up Kubernetes clients, initializes the port forwarder, and manages its lifecycle.
func startPortForward(kContext *kubeconfig.Context, cache cache.Cache[interface{}],
p portForwardRequest, token string, clusterName string,
p portForwardRequest, token string, clusterName string, cacheCluster string,
) error {
clientset, rConf, err := getKubeClientAndConfig(kContext, token)
if err != nil {
Expand Down Expand Up @@ -621,6 +632,7 @@ func startPortForward(kContext *kubeconfig.Context, cache cache.Cache[interface{
pfDetails := &portForward{
mu: &sync.Mutex{},
ID: p.ID,
cacheCluster: cacheCluster,
closeChan: stopChan,
Pod: p.Pod,
Cluster: clusterName,
Expand Down Expand Up @@ -686,10 +698,7 @@ func StopOrDeletePortForward(cache cache.Cache[interface{}], w http.ResponseWrit

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := mux.Vars(r)["clusterName"]

if userID != "" {
clusterName += userID
}
clusterName = portforwardClusterKey(clusterName, userID)

err = stopOrDeletePortForward(cache, clusterName, p.ID, p.StopOrDelete)
if err == nil {
Expand All @@ -715,11 +724,7 @@ func GetPortForwards(cache cache.Cache[interface{}], w http.ResponseWriter, r *h
}

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := cluster

if userID != "" {
clusterName = cluster + userID
}
clusterName := portforwardClusterKey(cluster, userID)

ports := getPortForwardList(cache, clusterName)

Expand Down Expand Up @@ -752,11 +757,7 @@ func GetPortForwardByID(cache cache.Cache[interface{}], w http.ResponseWriter, r
}

userID := r.Header.Get("X-HEADLAMP-USER-ID")
clusterName := cluster

if userID != "" {
clusterName = cluster + userID
}
clusterName := portforwardClusterKey(cluster, userID)

p, err := getPortForwardByID(cache, clusterName, id)
if err != nil {
Expand Down
19 changes: 8 additions & 11 deletions backend/pkg/portforward/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,6 @@ func TestStartPortForward(t *testing.T) {
require.NoError(t, err)
require.Contains(t, string(stopRespBody), "stopped")

cacheKey := "PORT_FORWARD_" + minikubeName + id
chState, err := ch.Get(context.Background(), cacheKey)
require.NoError(t, err, "failed to get port-forward state from cache with key %s", cacheKey)

chData, err := json.Marshal(chState)
require.NoError(t, err)
assert.Contains(t, string(chData), "Stopped")

listReq := &http.Request{
Header: make(http.Header),
}
Expand Down Expand Up @@ -325,7 +317,12 @@ func TestStartPortForward(t *testing.T) {
require.NoError(t, err)
require.Contains(t, string(deleteRespBody), "stopped")

chState, err = ch.Get(context.Background(), cacheKey)
require.Error(t, err, "port-forward with key %s should be deleted from cache, but Get returned no error", cacheKey)
require.Nil(t, chState)
getAfterDeleteResp := httptest.NewRecorder()
portforward.GetPortForwardByID(ch, getAfterDeleteResp, getReq)

getAfterDeleteRes := getAfterDeleteResp.Result()

defer func() { _ = getAfterDeleteRes.Body.Close() }()

assert.Equal(t, http.StatusNotFound, getAfterDeleteRes.StatusCode)
}
127 changes: 117 additions & 10 deletions backend/pkg/portforward/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,21 @@ func TestPortforwardKeyGenerator(t *testing.T) {
p portForward
want string
}{
{"only_cluster_id", portForward{ID: "id", Cluster: "cluster"}, "PORT_FORWARD_clusterid"},
{"only_service", portForward{Cluster: "cluster", Service: "service"}, "PORT_FORWARD_clusterservice"},
{"only_pod", portForward{Cluster: "cluster", Pod: "pod"}, "PORT_FORWARD_clusterpod"},
{"service_and_pod", portForward{Cluster: "cluster", Service: "service", Pod: "pod"}, "PORT_FORWARD_clusterservice"},
{"id_and_service", portForward{Cluster: "cluster", ID: "id", Service: "service"}, "PORT_FORWARD_clusterid"},
{"id_and_pod", portForward{Cluster: "cluster", ID: "id", Pod: "pod"}, "PORT_FORWARD_clusterid"},
{"only_cluster_id", portForward{ID: "id", Cluster: "cluster"}, "PORT_FORWARD_cluster:id"},
{"only_service", portForward{Cluster: "cluster", Service: "service"}, "PORT_FORWARD_cluster:service"},
{"only_pod", portForward{Cluster: "cluster", Pod: "pod"}, "PORT_FORWARD_cluster:pod"},
{"service_and_pod", portForward{Cluster: "cluster", Service: "service", Pod: "pod"}, "PORT_FORWARD_cluster:service"},
{"id_and_service", portForward{Cluster: "cluster", ID: "id", Service: "service"}, "PORT_FORWARD_cluster:id"},
{"id_and_pod", portForward{Cluster: "cluster", ID: "id", Pod: "pod"}, "PORT_FORWARD_cluster:id"},
{
"id_and_service_and_pod",
portForward{Cluster: "cluster", ID: "id", Service: "service", Pod: "pod"},
"PORT_FORWARD_clusterid",
"PORT_FORWARD_cluster:id",
},
{
"delimiter_chars",
portForward{Cluster: "foo:bar", ID: "baz:qux"},
"PORT_FORWARD_foo%3Abar:baz%3Aqux",
},
}

Expand Down Expand Up @@ -213,6 +218,53 @@ func TestGetPortForwardList(t *testing.T) {
assert.ElementsMatch(t, []portForward{p3}, pfList)
}

func TestGetPortForwardList_ClusterPrefixCollision(t *testing.T) {
tests := []struct {
name string
first portForward
second portForward
cluster string
want []portForward
}{
{
name: "similar_names",
first: portForward{ID: "id1", Cluster: "foo"},
second: portForward{ID: "id2", Cluster: "foobar"},
cluster: "foo",
want: []portForward{{ID: "id1", Cluster: "foo"}},
},
{
name: "delimiter_in_cluster_name",
first: portForward{ID: "id1", Cluster: "foo"},
second: portForward{ID: "id2", Cluster: "foo:bar"},
cluster: "foo",
want: []portForward{{ID: "id1", Cluster: "foo"}},
},
{
name: "delimiter_in_id",
first: portForward{ID: "bar:baz", Cluster: "foo"},
second: portForward{ID: "baz", Cluster: "foo:bar"},
cluster: "foo",
want: []portForward{{ID: "bar:baz", Cluster: "foo"}},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := cache.New[interface{}]()

err := cache.Set(context.Background(), portforwardKeyGenerator(tt.first), tt.first)
require.NoError(t, err)

err = cache.Set(context.Background(), portforwardKeyGenerator(tt.second), tt.second)
require.NoError(t, err)

pfList := getPortForwardList(cache, tt.cluster)
assert.ElementsMatch(t, tt.want, pfList)
})
}
}

// Test portForwardRequest.Validate() function.
func TestPortForwardRequestValidate(t *testing.T) {
req := portForwardRequest{}
Expand Down Expand Up @@ -505,10 +557,31 @@ func TestGetPortForwardList_UserIDKeyIsolation(t *testing.T) {
assert.Equal(t, "pf-1", result[0].ID)

// Query with a user-specific key — should NOT find the entry.
resultWithUser := getPortForwardList(c, "clusteruser123")
resultWithUser := getPortForwardList(c, portforwardClusterKey("cluster", "user123"))
assert.Empty(t, resultWithUser, "user-specific key must not return entries stored under base cluster key")
}

func TestGetPortForwardList_UserIDCollision(t *testing.T) {
c := cache.New[interface{}]()

userPF := portForward{
ID: "pf-user",
cacheCluster: portforwardClusterKey("foo", "bar"),
Cluster: "foo",
Status: RUNNING,
}
clusterPF := portForward{ID: "pf-cluster", Cluster: "foobar", Status: RUNNING}

portforwardstore(c, userPF)
portforwardstore(c, clusterPF)

userResult := getPortForwardList(c, portforwardClusterKey("foo", "bar"))
assert.ElementsMatch(t, []portForward{userPF}, userResult)

clusterResult := getPortForwardList(c, "foobar")
assert.ElementsMatch(t, []portForward{clusterPF}, clusterResult)
}

// TestGetPortForwardByID_UserIDKeyIsolation verifies that getPortForwardByID
// uses the correct key when a user ID is part of the cluster name.
func TestGetPortForwardByID_UserIDKeyIsolation(t *testing.T) {
Expand All @@ -524,7 +597,7 @@ func TestGetPortForwardByID_UserIDKeyIsolation(t *testing.T) {
assert.Equal(t, "pf-2", found.ID)

// Lookup with a user-specific key — should fail.
_, err = getPortForwardByID(c, "clusteruser456", "pf-2")
_, err = getPortForwardByID(c, portforwardClusterKey("cluster", "user456"), "pf-2")
assert.Error(t, err, "user-specific key must not find entries stored under base cluster key")
}

Expand Down Expand Up @@ -613,8 +686,42 @@ func TestGetPortForwardByIDHandler_UserIDKeyIsolation(t *testing.T) {
"user-specific lookup must not find entries under base cluster key")
}

func TestGetPortForwardByIDHandler_UserIDKeepsRouteCluster(t *testing.T) {
c := cache.New[interface{}]()

pf := portForward{
ID: "pf-user",
cacheCluster: portforwardClusterKey("cluster", "user999"),
Cluster: "cluster",
Pod: "redis",
Namespace: "cache",
Status: RUNNING,
}
portforwardstore(c, pf)

w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/portforward?id=pf-user", nil)
r = mux.SetURLVars(r, map[string]string{"clusterName": "cluster"})
r.URL = &url.URL{RawQuery: "id=pf-user"}
r.Header.Set("X-HEADLAMP-USER-ID", "user999")

GetPortForwardByID(c, w, r)

res := w.Result()

defer func() { _ = res.Body.Close() }()

require.Equal(t, http.StatusOK, res.StatusCode)

var payload map[string]string

err := json.NewDecoder(res.Body).Decode(&payload)
require.NoError(t, err)
assert.Equal(t, "cluster", payload["cluster"])
}

// TestStopOrDeletePortForwardHandler_UserIDKeyIsolation verifies that
// StopOrDeletePortForward uses cluster+userID as the cache key.
// StopOrDeletePortForward uses the user-scoped cluster cache key.
func TestStopOrDeletePortForwardHandler_UserIDKeyIsolation(t *testing.T) {
c := cache.New[interface{}]()

Expand Down
36 changes: 24 additions & 12 deletions backend/pkg/portforward/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package portforward
import (
"context"
"fmt"
"net/url"
"strings"

"github.com/kubernetes-sigs/headlamp/backend/pkg/cache"
Expand All @@ -27,22 +28,30 @@ import (

const storeKeyPrefix = "PORT_FORWARD_"

func portforwardCacheKeyPrefix(cluster string) string {
return storeKeyPrefix + url.QueryEscape(cluster) + ":"
}

// portforwardKeyGenerator generates a unique key
// based on the cluster name, id,service name, and pod name.
func portforwardKeyGenerator(p portForward) string {
if p.ID != "" {
return storeKeyPrefix + p.Cluster + p.ID
cluster := p.Cluster
if p.cacheCluster != "" {
cluster = p.cacheCluster
}

key := storeKeyPrefix + p.Cluster

if p.Service != "" {
key += p.Service
} else if p.Pod != "" {
key += p.Pod
key := portforwardCacheKeyPrefix(cluster)

switch {
case p.ID != "":
return key + url.QueryEscape(p.ID)
case p.Service != "":
return key + url.QueryEscape(p.Service)
case p.Pod != "":
return key + url.QueryEscape(p.Pod)
default:
return key
}

return key
}

// portforwardstore stores a port forward in the cache.
Expand Down Expand Up @@ -91,7 +100,7 @@ func stopOrDeletePortForward(cache cache.Cache[interface{}], cluster string, id
// getPortForwardList returns a list of port forwards by its cluster name.
func getPortForwardList(cache cache.Cache[interface{}], cluster string) []portForward {
portforwards, err := cache.GetAll(context.Background(), func(key string) bool {
return strings.HasPrefix(key, storeKeyPrefix+cluster)
return strings.HasPrefix(key, portforwardCacheKeyPrefix(cluster))
})
if err != nil {
logger.Log(logger.LevelError, map[string]string{"cluster": cluster},
Expand All @@ -110,7 +119,10 @@ func getPortForwardList(cache cache.Cache[interface{}], cluster string) []portFo

// getPortForwardByID returns a port forward by its cluster name and id.
func getPortForwardByID(cache cache.Cache[interface{}], cluster string, id string) (portForward, error) {
cacheValue, err := cache.Get(context.Background(), storeKeyPrefix+cluster+id)
cacheValue, err := cache.Get(context.Background(), portforwardKeyGenerator(portForward{
cacheCluster: cluster,
ID: id,
}))
if err != nil {
return portForward{}, fmt.Errorf("failed to get portforward from cache: %v", err)
}
Expand Down
Loading