diff --git a/deploy/csi-azurefile-driver.yaml b/deploy/csi-azurefile-driver.yaml index e5491b376f..d9f3a7837e 100644 --- a/deploy/csi-azurefile-driver.yaml +++ b/deploy/csi-azurefile-driver.yaml @@ -13,5 +13,7 @@ spec: - Persistent - Ephemeral fsGroupPolicy: ReadWriteOnceWithFSType + requiresRepublish: true tokenRequests: - audience: api://AzureADTokenExchange + expirationSeconds: 3600 diff --git a/pkg/azurefile/azurefile.go b/pkg/azurefile/azurefile.go index 40b5fe169d..8451f33aa8 100644 --- a/pkg/azurefile/azurefile.go +++ b/pkg/azurefile/azurefile.go @@ -20,10 +20,13 @@ import ( "bytes" "context" "encoding/binary" + "encoding/json" "errors" "fmt" "net/http" "net/url" + "os" + "path/filepath" "strconv" "strings" "sync" @@ -88,6 +91,7 @@ const ( defaultAzureFileQuota = 100 minimumAccountQuota = 100 // GB + DefaultTokenAudience = "api://AzureADTokenExchange/.default" // key of snapshot name in metadata snapshotNameKey = "initiator" @@ -168,6 +172,7 @@ const ( runtimeClassHandlerField = "runtimeclasshandler" defaultRuntimeClassHandler = "kata-cc" mountWithManagedIdentityField = "mountwithmanagedidentity" + mountWithWITokenField = "mountwithworkloadidentitytoken" accountNotProvisioned = "StorageAccountIsNotProvisioned" // this is a workaround fix for 429 throttling issue, will update cloud provider for better fix later @@ -229,6 +234,8 @@ var ( azcopyCloneVolumeOptions = []string{"--recursive", "--check-length=false", "--log-level=ERROR"} // azcopySnapshotRestoreOptions used in smb snapshot restore and set --check-length to true because snapshot data is changeless azcopySnapshotRestoreOptions = []string{"--recursive", "--check-length=true", "--log-level=ERROR"} + + defaultAzureOAuthTokenDir = "/var/lib/kubelet/plugins/" + DefaultDriverName ) // Driver implements all interfaces of CSI drivers @@ -804,8 +811,8 @@ func IsCorruptedDir(dir string) bool { } // GetAccountInfo get account info -// return -func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, reqContext map[string]string) (string, string, string, string, string, string, error) { +// return +func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, reqContext map[string]string) (string, string, string, string, string, string, string, string, error) { rgName, accountName, fileShareName, diskName, secretNamespace, subsID, err := GetFileShareInfo(volumeID) if err != nil { // ignore volumeID parsing error @@ -815,8 +822,8 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r var protocol, accountKey, secretName, pvcNamespace string // getAccountKeyFromSecret indicates whether get account key only from k8s secret - var getAccountKeyFromSecret, getLatestAccountKey, mountWithManagedIdentity bool - var clientID, tenantID, serviceAccountToken string + var getAccountKeyFromSecret, getLatestAccountKey, mountWithManagedIdentity, mountWithWIToken bool + var clientID, tenantID, tokenFilePath, serviceAccountToken string for k, v := range reqContext { switch strings.ToLower(k) { @@ -844,13 +851,17 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r pvcNamespace = v case getLatestAccountKeyField: if getLatestAccountKey, err = strconv.ParseBool(v); err != nil { - return rgName, accountName, accountKey, fileShareName, diskName, subsID, fmt.Errorf("invalid %s: %s in volume context", getLatestAccountKeyField, v) + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid %s: %s in volume context", getLatestAccountKeyField, v) } case clientIDField: clientID = v case mountWithManagedIdentityField: if mountWithManagedIdentity, err = strconv.ParseBool(v); err != nil { - return rgName, accountName, accountKey, fileShareName, diskName, subsID, fmt.Errorf("invalid %s: %s in volume context", mountWithManagedIdentityField, v) + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid %s: %s in volume context", mountWithManagedIdentityField, v) + } + case mountWithWITokenField: + if mountWithWIToken, err = strconv.ParseBool(v); err != nil { + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid %s: %s in volume context", mountWithWITokenField, v) } case tenantIDField: tenantID = v @@ -870,7 +881,7 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r } if protocol == nfs && fileShareName != "" { // nfs protocol does not need account key, return directly - return rgName, accountName, accountKey, fileShareName, diskName, subsID, err + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err } if secretNamespace == "" { @@ -883,13 +894,43 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r if mountWithManagedIdentity { klog.V(2).Infof("mountWithManagedIdentity is true, use managed identity auth") - return rgName, accountName, accountKey, fileShareName, diskName, subsID, nil + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, nil + } + + if mountWithWIToken { + if clientID == "" { + clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID + if clientID == "" { + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("clientID is empty for workload identity auth") + } + } + klog.V(2).Infof("mountWithWorkloadIdentityToken is specified, use workload identity auth for mount, clientID: %s, tenantID: %s", clientID, tenantID) + token, err := parseServiceAccountToken(serviceAccountToken) + if err != nil { + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("failed to parse service account token: %v", err) + } + tokenFileName := clientID + "-" + accountName + if !isValidTokenFileName(tokenFileName) { + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid token file name(%s) generated for clientID(%s) and accountName(%s)", tokenFileName, clientID, accountName) + } + tokenFilePath = filepath.Join(defaultAzureOAuthTokenDir, tokenFileName) + // check whether token value is the same as the one in the token file + existingToken, readErr := os.ReadFile(tokenFilePath) + if readErr == nil && string(existingToken) == token { + klog.V(4).Infof("the token file(%s) already exists and the token value is the same, no need to rewrite the token file", tokenFilePath) + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, "", nil + } + // write token to a file + if err := os.WriteFile(tokenFilePath, []byte(token), 0600); err != nil { + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("failed to write azure oAuth token file(%s): %v", tokenFilePath, err) + } + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err } if clientID != "" { klog.V(2).Infof("clientID(%s) is specified, use service account token to get account key", clientID) accountKey, err := d.cloud.GetStorageAccesskeyFromServiceAccountToken(ctx, subsID, accountName, rgName, clientID, tenantID, serviceAccountToken) - return rgName, accountName, accountKey, fileShareName, diskName, subsID, err + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, "", err } if len(secrets) == 0 { @@ -897,7 +938,7 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r // 1. get account key from cache first cache, errCache := d.accountCacheMap.Get(ctx, accountName, azcache.CacheReadTypeDefault) if errCache != nil { - return rgName, accountName, accountKey, fileShareName, diskName, subsID, errCache + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, errCache } if cache != nil { accountKey = cache.(string) @@ -939,7 +980,7 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r if err == nil && accountKey != "" { d.accountCacheMap.Set(accountName, accountKey) } - return rgName, accountName, accountKey, fileShareName, diskName, subsID, err + return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err } func isSupportedProtocol(protocol string) bool { @@ -1489,3 +1530,28 @@ func (d *Driver) createFolderIfNotExists(ctx context.Context, accountName, accou klog.V(2).Infof("Successfully ensured folder path %s exists in share %s", folderName, fileShareName) return nil } + +// serviceAccountToken represents the service account token sent from NodePublishVolume Request. +// ref: https://kubernetes-csi.github.io/docs/token-requests.html +type serviceAccountToken struct { + APIAzureADTokenExchange struct { + Token string `json:"token"` + ExpirationTimestamp time.Time `json:"expirationTimestamp"` + } `json:"api://AzureADTokenExchange"` +} + +// parseServiceAccountToken parses the bound service account token from the token passed from NodePublishVolume Request. +// ref: https://kubernetes-csi.github.io/docs/token-requests.html +func parseServiceAccountToken(tokenStr string) (string, error) { + if len(tokenStr) == 0 { + return "", fmt.Errorf("service account token is empty") + } + token := serviceAccountToken{} + if err := json.Unmarshal([]byte(tokenStr), &token); err != nil { + return "", fmt.Errorf("failed to unmarshal service account tokens, error: %w", err) + } + if token.APIAzureADTokenExchange.Token == "" { + return "", fmt.Errorf("token for audience %s not found", DefaultTokenAudience) + } + return token.APIAzureADTokenExchange.Token, nil +} diff --git a/pkg/azurefile/azurefile_test.go b/pkg/azurefile/azurefile_test.go index f8f96ec669..fe3152c3e4 100644 --- a/pkg/azurefile/azurefile_test.go +++ b/pkg/azurefile/azurefile_test.go @@ -838,6 +838,20 @@ func TestGetAccountInfo(t *testing.T) { expectFileShareName: "test_sharename", expectDiskName: "test_diskname", }, + { + volumeID: "invalid_mountWithWITokenField_value##", + rgName: "vol_2", + secrets: emptySecret, + reqContext: map[string]string{ + shareNameField: "test_sharename", + mountWithWITokenField: "invalid", + }, + expectErr: true, + err: fmt.Errorf("invalid %s: %s in volume context", mountWithWITokenField, "invalid"), + expectAccountName: "", + expectFileShareName: "test_sharename", + expectDiskName: "test_diskname", + }, } for _, test := range tests { @@ -847,7 +861,7 @@ func TestGetAccountInfo(t *testing.T) { d.kubeClient = clientSet d.cloud.Environment = &azclient.Environment{StorageEndpointSuffix: "abc"} mockStorageAccountsClient.EXPECT().ListKeys(gomock.Any(), gomock.Any(), test.rgName).Return(key, nil).AnyTimes() - rgName, accountName, _, fileShareName, diskName, _, err := d.GetAccountInfo(context.Background(), test.volumeID, test.secrets, test.reqContext) + rgName, accountName, _, fileShareName, diskName, _, _, _, err := d.GetAccountInfo(context.Background(), test.volumeID, test.secrets, test.reqContext) if test.expectErr && err == nil { t.Errorf("Unexpected non-error") continue @@ -2052,3 +2066,88 @@ func TestGetInfoFromSnapshotID(t *testing.T) { }) } } + +func TestParseServiceAccountToken(t *testing.T) { + tests := []struct { + name string + tokenStr string + expectedToken string + expectedError string + }{ + { + name: "Empty token string", + tokenStr: "", + expectedToken: "", + expectedError: "service account token is empty", + }, + { + name: "Invalid JSON", + tokenStr: "invalid-json", + expectedToken: "", + expectedError: "failed to unmarshal service account tokens", + }, + { + name: "Valid token with audience", + tokenStr: `{"api://AzureADTokenExchange":{"token":"test-token-value","expirationTimestamp":"2025-01-01T00:00:00Z"}}`, + expectedToken: "test-token-value", + expectedError: "", + }, + { + name: "Token with empty token value", + tokenStr: `{"api://AzureADTokenExchange":{"token":"","expirationTimestamp":"2025-01-01T00:00:00Z"}}`, + expectedToken: "", + expectedError: "token for audience api://AzureADTokenExchange/.default not found", + }, + { + name: "Token with missing api://AzureADTokenExchange field", + tokenStr: `{"someOtherField":{"token":"test-token","expirationTimestamp":"2025-01-01T00:00:00Z"}}`, + expectedToken: "", + expectedError: "token for audience api://AzureADTokenExchange/.default not found", + }, + { + name: "Token with partial JSON structure", + tokenStr: `{"api://AzureADTokenExchange":{}}`, + expectedToken: "", + expectedError: "token for audience api://AzureADTokenExchange/.default not found", + }, + { + name: "Malformed JSON with extra characters", + tokenStr: `{"api://AzureADTokenExchange":{"token":"test-token"}}extra`, + expectedToken: "", + expectedError: "failed to unmarshal service account tokens", + }, + { + name: "Token with special characters", + tokenStr: `{"api://AzureADTokenExchange":{"token":"eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test","expirationTimestamp":"2025-01-01T00:00:00Z"}}`, + expectedToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + expectedError: "", + }, + { + name: "Token with unicode characters", + tokenStr: `{"api://AzureADTokenExchange":{"token":"test-token-.","expirationTimestamp":"2025-01-01T00:00:00Z"}}`, + expectedToken: "test-token-.", + expectedError: "", + }, + { + name: "Token with whitespace in value", + tokenStr: `{"api://AzureADTokenExchange":{"token":" test-token ","expirationTimestamp":"2025-01-01T00:00:00Z"}}`, + expectedToken: " test-token ", + expectedError: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + token, err := parseServiceAccountToken(test.tokenStr) + + if test.expectedError != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), test.expectedError) + assert.Equal(t, "", token) + } else { + assert.NoError(t, err) + assert.Equal(t, test.expectedToken, token) + } + }) + } +} diff --git a/pkg/azurefile/controllerserver.go b/pkg/azurefile/controllerserver.go index 3ec614b83f..d4af54a08c 100644 --- a/pkg/azurefile/controllerserver.go +++ b/pkg/azurefile/controllerserver.go @@ -120,9 +120,9 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } var sku, subsID, resourceGroup, location, account, fileShareName, diskName, fsType, secretName string var secretNamespace, pvcNamespace, protocol, customTags, storageEndpointSuffix, networkEndpointType, shareAccessTier, accountAccessTier, rootSquashType, tagValueDelimiter string - var createAccount, useSeretCache, matchTags, selectRandomMatchingAccount, getLatestAccountKey, encryptInTransit bool + var createAccount, useSeretCache, matchTags, selectRandomMatchingAccount, getLatestAccountKey, encryptInTransit, mountWithManagedIdentity, mountWithWIToken bool var vnetResourceGroup, vnetName, vnetLinkName, publicNetworkAccess, subnetName, shareNamePrefix, fsGroupChangePolicy, useDataPlaneAPI string - var requireInfraEncryption, disableDeleteRetentionPolicy, enableLFS, isMultichannelEnabled, allowSharedKeyAccess, mountWithManagedIdentity *bool + var requireInfraEncryption, disableDeleteRetentionPolicy, enableLFS, isMultichannelEnabled, allowSharedKeyAccess *bool var provisionedBandwidthMibps, provisionedIops *int32 // set allowBlobPublicAccess as false by default allowBlobPublicAccess := ptr.To(false) @@ -131,6 +131,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) // store account key to k8s secret by default storeAccountKey := true + var err error var accountQuota int32 // Apply ProvisionerParameters (case-insensitive). We leave validation of // the values to the cloud provider. @@ -231,6 +232,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) case serverNameField: case folderNameField: case clientIDField: + case tenantIDField: case confidentialContainerLabelField: case runtimeClassHandlerField: case createFolderIfNotExistField: @@ -297,16 +299,24 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } provisionedIops = to.Ptr(int32(value)) case mountWithManagedIdentityField: - value, err := strconv.ParseBool(v) + mountWithManagedIdentity, err = strconv.ParseBool(v) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid %s: %s in storage class", mountWithManagedIdentityField, v) } - mountWithManagedIdentity = &value + case mountWithWITokenField: + mountWithWIToken, err = strconv.ParseBool(v) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid %s: %s in storage class", mountWithWITokenField, v) + } default: return nil, status.Errorf(codes.InvalidArgument, "invalid parameter %q in storage class", k) } } + if mountWithManagedIdentity && mountWithWIToken { + return nil, status.Error(codes.InvalidArgument, "mountwithmanagedidentity and mountwithworkloadidentitytoken cannot be both true in storage class") + } + if matchTags && account != "" { return nil, status.Errorf(codes.InvalidArgument, "matchTags must set as false when storageAccount(%s) is provided", account) } @@ -524,6 +534,12 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } } + var requiresSmbOAuth *bool + if mountWithManagedIdentity || mountWithWIToken { + klog.V(2).Info("enabling smb oauth for managed identity or work identity token based mount") + requiresSmbOAuth = to.Ptr(true) + } + accountOptions := &storage.AccountOptions{ Name: account, Type: sku, @@ -551,7 +567,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) StorageType: storage.StorageTypeFile, StorageEndpointSuffix: storageEndpointSuffix, IsMultichannelEnabled: isMultichannelEnabled, - IsSmbOAuthEnabled: mountWithManagedIdentity, + IsSmbOAuthEnabled: requiresSmbOAuth, PickRandomMatchingAccount: selectRandomMatchingAccount, GetLatestAccountKey: getLatestAccountKey, SourceAccountName: srcAccountName, @@ -819,7 +835,7 @@ func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) } // use data plane api, get account key first - _, _, accountKey, _, _, _, err := d.GetAccountInfo(ctx, volumeID, req.GetSecrets(), reqContext) + _, _, accountKey, _, _, _, _, _, err := d.GetAccountInfo(ctx, volumeID, req.GetSecrets(), reqContext) if err != nil { return nil, status.Errorf(codes.NotFound, "get account info from(%s) failed with error: %v", volumeID, err) } @@ -873,7 +889,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida return nil, status.Error(codes.InvalidArgument, "Volume capabilities not provided") } - resourceGroupName, accountName, _, fileShareName, diskName, subsID, err := d.GetAccountInfo(ctx, volumeID, req.GetSecrets(), req.GetVolumeContext()) + resourceGroupName, accountName, _, fileShareName, diskName, subsID, _, _, err := d.GetAccountInfo(ctx, volumeID, req.GetSecrets(), req.GetVolumeContext()) //nolint:dogsled if err != nil || accountName == "" || fileShareName == "" { return nil, status.Errorf(codes.NotFound, "get account info from(%s) failed with error: %v", volumeID, err) } @@ -1332,7 +1348,7 @@ func (d *Driver) ControllerExpandVolume(ctx context.Context, req *csi.Controller setKeyValueInMap(reqContext, secretNamespaceField, secretNamespace) } // use data plane api, get account key first - _, _, accountKey, _, _, _, err := d.GetAccountInfo(ctx, volumeID, secrets, reqContext) + _, _, accountKey, _, _, _, _, _, err := d.GetAccountInfo(ctx, volumeID, secrets, reqContext) if err != nil { return nil, status.Errorf(codes.NotFound, "get account info from(%s) failed with error: %v", volumeID, err) } @@ -1365,7 +1381,7 @@ func (d *Driver) getShareClient(ctx context.Context, sourceVolumeID string, secr } func (d *Driver) getServiceClient(ctx context.Context, sourceVolumeID string, secrets map[string]string, useDataPlaneAPI string) (*service.Client, string, error) { - _, accountName, accountKey, fileShareName, _, _, err := d.GetAccountInfo(ctx, sourceVolumeID, secrets, map[string]string{}) //nolint:dogsled + _, accountName, accountKey, fileShareName, _, _, _, _, err := d.GetAccountInfo(ctx, sourceVolumeID, secrets, map[string]string{}) //nolint:dogsled if err != nil { return nil, fileShareName, err } diff --git a/pkg/azurefile/controllerserver_test.go b/pkg/azurefile/controllerserver_test.go index f983b23c8b..382673e050 100644 --- a/pkg/azurefile/controllerserver_test.go +++ b/pkg/azurefile/controllerserver_test.go @@ -926,6 +926,7 @@ var _ = ginkgo.Describe("TestCreateVolume", func() { createFolderIfNotExistField: "true", confidentialContainerLabelField: "confidential-container-label", mountWithManagedIdentityField: "true", + mountWithWITokenField: "false", } req := &csi.CreateVolumeRequest{ @@ -1077,6 +1078,41 @@ var _ = ginkgo.Describe("TestCreateVolume", func() { }) }) + ginkgo.When("invalid mountWithWIToken", func() { + ginkgo.It("should fail", func(ctx context.Context) { + req := &csi.CreateVolumeRequest{ + Name: "random-vol-name-valid-request", + VolumeCapabilities: stdVolCap, + CapacityRange: lessThanPremCapRange, + Parameters: map[string]string{ + mountWithWITokenField: "invalid", + }, + } + + expectedErr := status.Errorf(codes.InvalidArgument, "invalid %s: %s in storage class", mountWithWITokenField, "invalid") + _, err := d.CreateVolume(ctx, req) + gomega.Expect(err).To(gomega.Equal(expectedErr)) + }) + }) + + ginkgo.When("mountWithManagedIdentity and mountWithWIToken cannot be both true", func() { + ginkgo.It("should fail", func(ctx context.Context) { + req := &csi.CreateVolumeRequest{ + Name: "random-vol-name-valid-request", + VolumeCapabilities: stdVolCap, + CapacityRange: lessThanPremCapRange, + Parameters: map[string]string{ + mountWithManagedIdentityField: "true", + mountWithWITokenField: "true", + }, + } + + expectedErr := status.Errorf(codes.InvalidArgument, "%s and %s cannot be both true in storage class", mountWithManagedIdentityField, mountWithWITokenField) + _, err := d.CreateVolume(ctx, req) + gomega.Expect(err).To(gomega.Equal(expectedErr)) + }) + }) + ginkgo.When("invalid parameter", func() { ginkgo.It("should fail", func(ctx context.Context) { name := "baz" diff --git a/pkg/azurefile/nodeserver.go b/pkg/azurefile/nodeserver.go index 253bea6a36..167e0fd35b 100644 --- a/pkg/azurefile/nodeserver.go +++ b/pkg/azurefile/nodeserver.go @@ -76,8 +76,8 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu mountPermissions := d.mountPermissions context := req.GetVolumeContext() if context != nil { - if !strings.EqualFold(getValueInMap(context, mountWithManagedIdentityField), trueValue) && getValueInMap(context, serviceAccountTokenField) != "" && getValueInMap(context, clientIDField) != "" { - klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s", volumeID, target, getValueInMap(context, clientIDField)) + if getValueInMap(context, serviceAccountTokenField) != "" && shouldUseServiceAccountToken(context) { + klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with service account token, clientID: %s, mountWithWIToken: %s", volumeID, target, getValueInMap(context, clientIDField), getValueInMap(context, mountWithWITokenField)) _, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{ StagingTargetPath: target, VolumeContext: context, @@ -247,8 +247,8 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe volumeID := req.GetVolumeId() context := req.GetVolumeContext() - if getValueInMap(context, clientIDField) != "" && !strings.EqualFold(getValueInMap(context, mountWithManagedIdentityField), trueValue) && getValueInMap(context, serviceAccountTokenField) == "" { - klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID %s is provided but service account token is empty", volumeID, getValueInMap(context, clientIDField)) + if getValueInMap(context, serviceAccountTokenField) == "" && shouldUseServiceAccountToken(context) { + klog.V(2).Infof("Skip NodeStageVolume for volume(%s) since clientID(%s) or mountWithWIToken(%s) is provided but service account token is empty", volumeID, getValueInMap(context, clientIDField), getValueInMap(context, mountWithWITokenField)) return &csi.NodeStageVolumeResponse{}, nil } @@ -267,7 +267,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe mc.ObserveOperationWithResult(isOperationSucceeded, VolumeID, volumeID) }() - _, accountName, accountKey, fileShareName, diskName, _, err := d.GetAccountInfo(ctx, volumeID, req.GetSecrets(), context) + _, accountName, accountKey, fileShareName, diskName, _, tenantID, tokenFilePath, err := d.GetAccountInfo(ctx, volumeID, req.GetSecrets(), context) if err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("GetAccountInfo(%s) failed with error: %v", volumeID, err)) } @@ -277,7 +277,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe // don't respect fsType from req.GetVolumeCapability().GetMount().GetFsType() // since it's ext4 by default on Linux var fsType, server, protocol, ephemeralVolMountOptions, storageEndpointSuffix, folderName, clientID string - var ephemeralVol, createFolderIfNotExist, encryptInTransit, mountWithManagedIdentity bool + var ephemeralVol, createFolderIfNotExist, encryptInTransit, mountWithManagedIdentity, mountWithWIToken bool fileShareNameReplaceMap := map[string]string{} mountPermissions := d.mountPermissions @@ -333,6 +333,11 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe if err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("Volume context property %q must be a boolean value: %v", k, err)) } + case mountWithWITokenField: + mountWithWIToken, err = strconv.ParseBool(v) + if err != nil { + return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("Volume context property %q must be a boolean value: %v", k, err)) + } case clientIDField: clientID = v } @@ -354,6 +359,10 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe return nil, status.Errorf(codes.InvalidArgument, "fsGroupChangePolicy(%s) is not supported, supported fsGroupChangePolicy list: %v", fsGroupChangePolicy, supportedFSGroupChangePolicyList) } + if mountWithManagedIdentity && mountWithWIToken { + return nil, status.Error(codes.InvalidArgument, "mountWithManagedIdentity and mountWithWIToken cannot be both true") + } + lockKey := fmt.Sprintf("%s-%s", volumeID, targetPath) if acquired := d.volumeLocks.TryAcquire(lockKey); !acquired { return nil, status.Errorf(codes.Aborted, volumeOperationAlreadyExistsFmt, volumeID) @@ -405,12 +414,22 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe mountOptions = util.JoinMountOptions(mountFlags, []string{"vers=4,minorversion=1,sec=sys"}) mountOptions = appendDefaultNfsMountOptions(mountOptions, d.appendNoResvPortOption, d.appendActimeoOption) } else { + if (mountWithManagedIdentity || mountWithWIToken) && clientID == "" { + clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID + } + if mountWithManagedIdentity && runtime.GOOS != "windows" { - if clientID == "" { - clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID - } sensitiveMountOptions = []string{"sec=krb5,cruid=0,upcall_target=mount", fmt.Sprintf("username=%s", clientID)} klog.V(2).Infof("using managed identity %s for volume %s with mount options: %v", clientID, volumeID, sensitiveMountOptions) + } else if mountWithWIToken && runtime.GOOS != "windows" { + sensitiveMountOptions = []string{"sec=krb5,cruid=0,upcall_target=mount"} + klog.V(2).Infof("using workload identity token for volume %s with mount options: %v", volumeID, sensitiveMountOptions) + if tokenFilePath != "" { + // always set credential cache when token file is provided even mount does not happen + if out, err := setCredentialCache(server, clientID, tenantID, tokenFilePath); err != nil { + return nil, status.Errorf(codes.Internal, "setCredentialCache failed for %s with error: %v, output: %s", server, err, out) + } + } } else { if accountName == "" || accountKey == "" { return nil, status.Errorf(codes.Internal, "accountName(%s) or accountKey is empty", accountName) @@ -472,7 +491,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe } else { execFunc := func() error { if mountWithManagedIdentity && protocol != nfs && runtime.GOOS != "windows" { - if out, err := setCredentialCache(server, clientID); err != nil { + if out, err := setCredentialCache(server, clientID, tenantID, tokenFilePath); err != nil { return fmt.Errorf("setCredentialCache failed for %s with error: %v, output: %s", server, err, out) } } @@ -838,3 +857,14 @@ func checkGidPresentInMountFlags(mountFlags []string) bool { } return false } + +// shouldUseServiceAccountToken determines whether a service account token should be used for authentication based on the volume context attributes. +func shouldUseServiceAccountToken(attrib map[string]string) bool { + if getValueInMap(attrib, mountWithWITokenField) == trueValue { + return true + } + if getValueInMap(attrib, clientIDField) != "" && !strings.EqualFold(getValueInMap(attrib, mountWithManagedIdentityField), trueValue) { + return true + } + return false +} diff --git a/pkg/azurefile/nodeserver_test.go b/pkg/azurefile/nodeserver_test.go index ddeeba34cd..d7b85c1bd3 100644 --- a/pkg/azurefile/nodeserver_test.go +++ b/pkg/azurefile/nodeserver_test.go @@ -761,6 +761,22 @@ func TestNodeStageVolume(t *testing.T) { DefaultError: status.Error(codes.InvalidArgument, fmt.Sprintf("invalid mountPermissions %s", "07ab")), }, }, + { + desc: "[Error] mountWithManagedIdentity and mountWithWIToken cannot be both true", + req: &csi.NodeStageVolumeRequest{VolumeId: "vol_1##", StagingTargetPath: sourceTest, + VolumeCapability: &stdVolCap, + VolumeContext: map[string]string{ + shareNameField: "test_sharename", + storageAccountField: "test_accountname", + serviceAccountTokenField: "token", + mountWithManagedIdentityField: "true", + mountWithWITokenField: "true", + }, + Secrets: secrets}, + expectedErr: testutil.TestError{ + DefaultError: status.Error(codes.InvalidArgument, "mountWithManagedIdentity and mountWithWIToken cannot be both true"), + }, + }, { desc: "[Success] Valid request with Kata CC Mount enabled", setup: func() { @@ -1177,6 +1193,50 @@ func TestNodePublishVolumeIdempotentMount(t *testing.T) { assert.NoError(t, err) } +func TestShouldUseServiceAccountToken(t *testing.T) { + tests := []struct { + name string + attrib map[string]string + expect bool + }{ + { + name: "witoken true", + attrib: map[string]string{mountWithWITokenField: trueValue}, + expect: true, + }, + { + name: "clientID without managed identity", + attrib: map[string]string{ + clientIDField: "client-id", + mountWithManagedIdentityField: "", + }, + expect: true, + }, + { + name: "clientID with managed identity true", + attrib: map[string]string{ + clientIDField: "client-id", + mountWithManagedIdentityField: "True", + }, + expect: false, + }, + { + name: "no wi configuration", + attrib: map[string]string{}, + expect: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := shouldUseServiceAccountToken(test.attrib) + if result != test.expect { + t.Fatalf("shouldUseServiceAccountToken() = %t, want %t for attrib %v", result, test.expect, test.attrib) + } + }) + } +} + func makeFakeCmd(fakeCmd *testingexec.FakeCmd, cmd string, args ...string) testingexec.FakeCommandAction { c := cmd a := args diff --git a/pkg/azurefile/utils.go b/pkg/azurefile/utils.go index 2197c64621..58823c2c1f 100644 --- a/pkg/azurefile/utils.go +++ b/pkg/azurefile/utils.go @@ -417,13 +417,43 @@ func getDefaultBandwidth(requestGiB int, storageAccountType string) *int32 { return &bandwidth } -func setCredentialCache(server, clientID string) ([]byte, error) { - if server == "" || clientID == "" { - return nil, fmt.Errorf("server and clientID must be provided") +func setCredentialCache(server, clientID, tenantID, tokenFile string) ([]byte, error) { + if server == "" { + return nil, fmt.Errorf("server must be provided") + } + if clientID == "" { + return nil, fmt.Errorf("clientID must be provided") } - cmd := exec.Command("azfilesauthmanager", "set", "https://"+server, "--imds-client-id", clientID) + var args []string + if tokenFile != "" { + if tenantID == "" { + return nil, fmt.Errorf("tenantID must be provided when tokenFile is provided") + } + args = []string{"set", "https://" + server, "--workload-identity", "--tenant-id", tenantID, "--client-id", clientID, "--token-file", tokenFile} + } else { + args = []string{"set", "https://" + server, "--imds-client-id", clientID} + } + + cmd := exec.Command("azfilesauthmanager", args...) cmd.Env = append(os.Environ(), cmd.Env...) klog.V(2).Infof("Executing command: %q", cmd.String()) return cmd.CombinedOutput() } + +// isValidTokenFileName checks if the token file name is valid +// fileName should only contain alphanumeric characters, hyphens +func isValidTokenFileName(fileName string) bool { + if fileName == "" { + return false + } + for _, c := range fileName { + if !(('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + ('0' <= c && c <= '9') || + (c == '-')) { + return false + } + } + return true +} diff --git a/pkg/azurefile/utils_test.go b/pkg/azurefile/utils_test.go index bc9c1c0fa9..d57adb968a 100644 --- a/pkg/azurefile/utils_test.go +++ b/pkg/azurefile/utils_test.go @@ -1366,30 +1366,62 @@ func TestSetCredentialCache(t *testing.T) { desc string server string clientID string + tenantID string + tokenFile string expectedError string }{ { desc: "empty server", server: "", clientID: "test-client-id", - expectedError: "server and clientID must be provided", + tenantID: "test-tenant-id", + tokenFile: "test-token-file", + expectedError: "server must be provided", }, { desc: "empty clientID", server: "test.file.core.windows.net", clientID: "", - expectedError: "server and clientID must be provided", + tenantID: "test-tenant-id", + tokenFile: "test-token-file", + expectedError: "clientID must be provided", }, { - desc: "both empty", + desc: "empty tenantID with tokenFile", + server: "test.file.core.windows.net", + clientID: "test-client-id", + tenantID: "", + tokenFile: "test-token-file", + expectedError: "tenantID must be provided when tokenFile is provided", + }, + { + desc: "valid IMDS authentication (no tokenFile)", + server: "test.file.core.windows.net", + clientID: "test-client-id", + tenantID: "", + tokenFile: "", + expectedError: "", // Will fail due to missing azfilesauthmanager, but validates argument construction + }, + { + desc: "valid workload identity authentication", + server: "test.file.core.windows.net", + clientID: "test-client-id", + tenantID: "test-tenant-id", + tokenFile: "test-token-file", + expectedError: "", // Will fail due to missing azfilesauthmanager, but validates argument construction + }, + { + desc: "both empty server and clientID", server: "", clientID: "", - expectedError: "server and clientID must be provided", + tenantID: "test-tenant-id", + tokenFile: "test-token-file", + expectedError: "server must be provided", }, } for _, test := range tests { - _, err := setCredentialCache(test.server, test.clientID) + _, err := setCredentialCache(test.server, test.clientID, test.tenantID, test.tokenFile) if test.expectedError != "" { if err == nil { t.Errorf("test[%s]: expected error containing %q, got nil", test.desc, test.expectedError) @@ -1397,9 +1429,73 @@ func TestSetCredentialCache(t *testing.T) { t.Errorf("test[%s]: expected error containing %q, got %v", test.desc, test.expectedError, err) } } + // Note: We don't test successful execution as it requires azfilesauthmanager binary + // The actual command execution will fail, but we've validated the argument construction } } func int32Ptr(i int32) *int32 { return &i } + +func TestIsValidTokenFileName(t *testing.T) { + testCases := []struct { + name string + fileName string + expected bool + }{ + { + name: "valid lowercase", + fileName: "token", + expected: true, + }, + { + name: "valid uppercase", + fileName: "TOKEN", + expected: true, + }, + { + name: "valid mixed alphanumeric with hyphen", + fileName: "Token-123", + expected: true, + }, + { + name: "valid mixed alphanumeric with hyphen#2", + fileName: "0ab48765-efce-4799-8a9c-c3e1de2ee42eg", + expected: true, + }, + { + name: "empty string", + fileName: "", + expected: false, + }, + { + name: "contains underscore", + fileName: "token_file", + expected: false, + }, + { + name: "contains dot", + fileName: "token.file", + expected: false, + }, + { + name: "contains space", + fileName: "token file", + expected: false, + }, + { + name: "contains slash", + fileName: "token/file", + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := isValidTokenFileName(tc.fileName); got != tc.expected { + t.Fatalf("isValidTokenFileName(%q) = %t, want %t", tc.fileName, got, tc.expected) + } + }) + } +} diff --git a/test/e2e/testsuites/pre_provisioned_existing_credentials_tester.go b/test/e2e/testsuites/pre_provisioned_existing_credentials_tester.go index 103c6becb1..c86a718316 100644 --- a/test/e2e/testsuites/pre_provisioned_existing_credentials_tester.go +++ b/test/e2e/testsuites/pre_provisioned_existing_credentials_tester.go @@ -41,7 +41,7 @@ type PreProvisionedExistingCredentialsTest struct { func (t *PreProvisionedExistingCredentialsTest) Run(ctx context.Context, client clientset.Interface, namespace *v1.Namespace) { for _, pod := range t.Pods { for n, volume := range pod.Volumes { - resourceGroupName, accountName, _, fileShareName, _, _, err := t.Azurefile.GetAccountInfo(ctx, volume.VolumeID, nil, nil) + resourceGroupName, accountName, _, fileShareName, _, _, _, _, err := t.Azurefile.GetAccountInfo(ctx, volume.VolumeID, nil, nil) if err != nil { framework.ExpectNoError(err, fmt.Sprintf("Error GetContainerInfo from volumeID(%s): %v", volume.VolumeID, err)) return diff --git a/test/e2e/testsuites/pre_provisioned_provided_credentials_tester.go b/test/e2e/testsuites/pre_provisioned_provided_credentials_tester.go index 4f679bc014..f881f0fde4 100644 --- a/test/e2e/testsuites/pre_provisioned_provided_credentials_tester.go +++ b/test/e2e/testsuites/pre_provisioned_provided_credentials_tester.go @@ -41,7 +41,7 @@ type PreProvisionedProvidedCredentiasTest struct { func (t *PreProvisionedProvidedCredentiasTest) Run(ctx context.Context, client clientset.Interface, namespace *v1.Namespace) { for _, pod := range t.Pods { for n, volume := range pod.Volumes { - _, accountName, accountKey, fileShareName, _, _, err := t.Azurefile.GetAccountInfo(ctx, volume.VolumeID, nil, nil) + _, accountName, accountKey, fileShareName, _, _, _, _, err := t.Azurefile.GetAccountInfo(ctx, volume.VolumeID, nil, nil) framework.ExpectNoError(err, fmt.Sprintf("Error GetAccountInfo from volumeID(%s): %v", volume.VolumeID, err)) ginkgo.By("creating the secret")