diff --git a/.buildkite/test-e2e-native-scheduling.yml b/.buildkite/test-e2e-native-scheduling.yml new file mode 100644 index 00000000000..168653b8093 --- /dev/null +++ b/.buildkite/test-e2e-native-scheduling.yml @@ -0,0 +1,34 @@ +- label: 'Test Native Workload Scheduling E2E (nightly operator)' + instance_size: large + image: golang:1.26-bookworm + commands: + - source .buildkite/setup-env.sh + + # Install a kind version that supports Kubernetes 1.36 + - echo "--- Installing kind v0.27.0 (required for K8s 1.36 support)" + - curl -Lo ./kind https://kind.sigs.k8s.io/dl/v0.27.0/kind-linux-amd64 + - chmod +x ./kind + - mv ./kind /usr/local/bin/kind + + # Build the kind node image for K8s 1.36 with GenericWorkload and GangScheduling + - echo "--- Building kind node image for K8s 1.36" + - kind build node-image --image kindest/node:v1.36.0 v1.36.0 + + # Create the kind cluster with native scheduling feature gates + - kind create cluster --wait 900s --config ./ci/kind-config-buildkite-native-scheduling.yml + - kubectl config set clusters.kind-kind.server https://docker:6443 + + # Build nightly KubeRay operator image and deploy with NativeWorkloadScheduling enabled + - pushd ray-operator + - IMG=kuberay/operator:nightly make docker-image + - kind load docker-image kuberay/operator:nightly + - IMG=kuberay/operator:nightly make deploy-native-scheduling + - kubectl wait --timeout=90s --for=condition=Available=true deployment kuberay-operator + + # Run native scheduling e2e tests + - echo "--- START:Running Native Workload Scheduling E2E tests" + - set -o pipefail + - mkdir -p "$(pwd)/tmp" && export KUBERAY_TEST_OUTPUT_DIR=$(pwd)/tmp + - echo "KUBERAY_TEST_OUTPUT_DIR=$$KUBERAY_TEST_OUTPUT_DIR" + - KUBERAY_TEST_TIMEOUT_SHORT=1m KUBERAY_TEST_TIMEOUT_MEDIUM=5m KUBERAY_TEST_TIMEOUT_LONG=10m go test -timeout 30m -v ./test/e2enativescheduling 2>&1 | awk -f ../.buildkite/format.awk | tee $$KUBERAY_TEST_OUTPUT_DIR/gotest.log || (kubectl logs --tail -1 -l 
app.kubernetes.io/name=kuberay | tee $$KUBERAY_TEST_OUTPUT_DIR/kuberay-operator.log && cd $$KUBERAY_TEST_OUTPUT_DIR && find . -name "*.log" | tar -cf /artifact-mount/e2e-native-scheduling-log.tar -T - && exit 1) + - echo "--- END:Native Workload Scheduling E2E tests finished" diff --git a/ci/kind-config-buildkite-native-scheduling.yml b/ci/kind-config-buildkite-native-scheduling.yml new file mode 100644 index 00000000000..b0db15bb2a4 --- /dev/null +++ b/ci/kind-config-buildkite-native-scheduling.yml @@ -0,0 +1,28 @@ +kind: Cluster +apiVersion: kind.x-k8s.io/v1alpha4 +networking: + apiServerAddress: "0.0.0.0" + # Ensure stable port so we can rewrite the server address later + apiServerPort: 6443 +# Adding this so containers from the same docker network can access it +# https://blog.scottlowe.org/2019/07/30/adding-a-name-to-kubernetes-api-server-certificate/ +nodes: +- role: control-plane + # No pre-built kindest/node:v1.36.0 images exist yet. + # The Buildkite job builds it via: kind build node-image --image kindest/node:v1.36.0 v1.36.0 + image: kindest/node:v1.36.0 + kubeadmConfigPatches: + - | + kind: ClusterConfiguration + apiServer: + extraArgs: + feature-gates: "GenericWorkload=true" + runtime-config: "scheduling.k8s.io/v1alpha2=true" + certSANs: + - "docker" + controllerManager: + extraArgs: + feature-gates: "GenericWorkload=true" + scheduler: + extraArgs: + feature-gates: "GenericWorkload=true,GangScheduling=true" diff --git a/ci/kind-config-native-workload-scheduling.yml b/ci/kind-config-native-workload-scheduling.yml new file mode 100644 index 00000000000..2e8c018f766 --- /dev/null +++ b/ci/kind-config-native-workload-scheduling.yml @@ -0,0 +1,30 @@ +kind: Cluster +apiVersion: kind.x-k8s.io/v1alpha4 +networking: + apiServerAddress: "127.0.0.1" + # Ensure stable port so we can rewrite the server address later + apiServerPort: 6443 +# Adding this so containers from the same docker network can access it +# 
https://blog.scottlowe.org/2019/07/30/adding-a-name-to-kubernetes-api-server-certificate/ +nodes: +- role: control-plane + # No pre-built kindest/node:v1.36.0 images exist yet. + # Build locally: kind build node-image --image kindest/node:v1.36.0 v1.36.0 + # Update this image tag to match whichever version is built. + image: kindest/node:v1.36.0 + kubeadmConfigPatches: + - | + kind: ClusterConfiguration + apiServer: + extraArgs: + feature-gates: "GenericWorkload=true" + runtime-config: "scheduling.k8s.io/v1alpha2=true" + certSANs: + - "127.0.0.1" + - "docker" + controllerManager: + extraArgs: + feature-gates: "GenericWorkload=true" + scheduler: + extraArgs: + feature-gates: "GenericWorkload=true,GangScheduling=true" diff --git a/docs/guidance/native-workload-scheduling.md b/docs/guidance/native-workload-scheduling.md new file mode 100644 index 00000000000..c166c341554 --- /dev/null +++ b/docs/guidance/native-workload-scheduling.md @@ -0,0 +1,222 @@ + +# Native Workload Scheduling (Gang Scheduling) + +This guide explains how to use KubeRay's native Kubernetes gang scheduling integration, which ensures that all pods in a RayCluster worker group are either entirely schedulable or not scheduled at all. + +## Overview + +Distributed AI/ML workloads on Kubernetes can suffer from partial scheduling: some pods in a group get scheduled and hold expensive GPU nodes idle while waiting for the remaining pods, +or partially scheduled groups block other workloads indefinitely (livelock). Gang scheduling solves this by treating a group of pods as an atomic unit. + +KubeRay integrates with the **native** Kubernetes gang scheduling APIs (`scheduling.k8s.io/v1alpha2`) introduced in [KEP-4671](https://github.com/kubernetes/enhancements/tree/master/keps/sig-scheduling/4671-gang-scheduling) +and [KEP-5832](https://github.com/kubernetes/enhancements/tree/master/keps/sig-scheduling/5832-decouple-podgroup-api). 
+"Native" means this gang scheduling is built into the default Kubernetes scheduler (kube-scheduler) starting in Kubernetes 1.36, unlike other external solutions (Volcano, YuniKorn, etc.) which replace or wrap the default scheduler. + +When enabled, the KubeRay operator creates `Workload` and `PodGroup` resources for each RayCluster and sets `spec.schedulingGroup` on every pod, linking it to the appropriate PodGroup. +The Kubernetes scheduler then evaluates each worker group as a gang — all pods in the group must be schedulable before any of them are scheduled. + +## Prerequisites + +- **Kubernetes 1.36+** with the following feature gates enabled: + - `GenericWorkload=true` on the **kube-apiserver** — enables the Workload/PodGroup APIs + - `scheduling.k8s.io/v1alpha2=true` in the apiserver's **runtime-config** — serves the alpha API (alpha APIs are off by default) + - `GenericWorkload=true` on the **kube-controller-manager** + - `GangScheduling=true` on the **kube-scheduler** — enables the gang scheduling plugin that processes `schedulingGroup` on pods +- **KubeRay operator** with the `NativeWorkloadScheduling` feature gate enabled + +## Enabling the Feature + +Start with a Kubernetes 1.36+ cluster with the required feature gates (`GenericWorkload`, `GangScheduling`) and alpha API enabled as described in [Prerequisites](#prerequisites). + +### 1. Enable the KubeRay operator feature gate + +Pass the `NativeWorkloadScheduling` feature gate when starting the operator: + +```bash +--feature-gates=NativeWorkloadScheduling=true +``` + +With Helm: + +```yaml +featureGates: + - name: NativeWorkloadScheduling + enabled: true +``` + +Or via `--set`: + +```bash +helm install kuberay-operator kuberay/kuberay-operator \ + --set featureGates[0].name=NativeWorkloadScheduling \ + --set featureGates[0].enabled=true +``` + +### 2. 
Opt-in per RayCluster + +Add the `ray.io/native-workload-scheduling: "true"` annotation to each RayCluster that should use gang scheduling: + +```yaml +apiVersion: ray.io/v1 +kind: RayCluster +metadata: + name: my-cluster + annotations: + ray.io/native-workload-scheduling: "true" +spec: + headGroupSpec: + rayStartParams: + dashboard-host: "0.0.0.0" + template: + spec: + containers: + - name: ray-head + image: rayproject/ray:2.52.0 + resources: + limits: + cpu: "1" + memory: "2Gi" + workerGroupSpecs: + - groupName: gpu-workers + replicas: 4 + minReplicas: 4 + maxReplicas: 4 + rayStartParams: {} + template: + spec: + containers: + - name: ray-worker + image: rayproject/ray:2.52.0 + resources: + limits: + cpu: "1" + memory: "2Gi" + nvidia.com/gpu: "1" +``` + +## What Happens Under the Hood + +When both the feature gate and annotation are set, the operator creates the following resources before creating pods: + +1. **A `Workload` object** (one per RayCluster) with a `podGroupTemplate` for each group: + - One template named `head` using `BasicSchedulingPolicy` (single pod, no gang constraint) + - One template per worker group named `worker-<groupName>` using `GangSchedulingPolicy` with `minCount` set to the desired replica count + +2. **PodGroup objects** (one per template) named `<cluster-name>-head`, `<cluster-name>-worker-<groupName>`, etc. + +3. **`spec.schedulingGroup`** is set on every pod, linking it to the appropriate PodGroup + +The Workload and PodGroups are owned by the RayCluster (via `ownerReferences`), so they are automatically garbage collected when the RayCluster is deleted. + +### Spec drift detection + +If you change the RayCluster spec (add/remove worker groups, change replica counts), the operator detects the mismatch, deletes the stale Workload and PodGroups, and recreates them from the updated spec. + +### Suspend and resume + +When a RayCluster is suspended, the operator deletes the Workload and PodGroups alongside the pods. On resume, fresh scheduling resources are created with the current spec. 
+ +### Status condition + +When the `RayClusterStatusConditions` feature gate is also enabled, the operator sets a `WorkloadScheduled` condition on the RayCluster: + +| Condition | Status | Reason | Meaning | +|-----------|--------|--------|---------| +| `WorkloadScheduled` | `True` | `WorkloadReady` | Workload and PodGroups have been created | +| `WorkloadScheduled` | `False` | `WorkloadPending` | Workload has not been created yet | + +The condition is removed when the cluster is suspended or when native scheduling is not enabled. + +## Limitations + +> **Note**: This feature is in early alpha. Both the Kubernetes `scheduling.k8s.io/v1alpha2` API and the KubeRay integration are under active development. Notably, autoscaling is not supported — only fixed-size worker groups are compatible. + +- **No autoscaling support**: RayClusters with autoscaling enabled (`enableInTreeAutoscaling: true`) will skip native scheduling with a warning event. Fixed-size worker groups only. +- **Max 7 worker groups**: The `scheduling.k8s.io/v1alpha2` API allows at most 8 PodGroupTemplates per Workload (1 reserved for the head group). +- **Per-worker-group atomicity only**: Each worker group is scheduled as an independent gang. There is no cross-worker-group atomicity (e.g., "schedule all GPU workers AND all CPU workers or none"). +- **Mutually exclusive with batch schedulers**: Cannot be used together with `batchScheduler` configuration (Volcano, YuniKorn, etc.). The operator will refuse to start if both are enabled. +- **Immutable `schedulingGroup` on pods**: The `spec.schedulingGroup` field on pods is immutable. If you enable native scheduling on an already-running cluster, existing pods will not get `schedulingGroup` set. New pods (from scale-up, recreation, or suspend/resume) will be correctly configured. +- **Requires Kubernetes 1.36+**: The `scheduling.k8s.io/v1alpha2` API is not available in earlier versions. 
+ +## Troubleshooting + +### Verify resources were created + +```bash +# Check Workload +kubectl get workloads -n <namespace> + +# Check PodGroups +kubectl get podgroups.scheduling.k8s.io -n <namespace> + +# Check events on the RayCluster +kubectl describe raycluster <cluster-name> -n <namespace> | grep -A 20 Events +``` + +You should see `CreatedWorkload` and `CreatedPodGroup` events on the RayCluster. + +### Pods stuck in PreEnqueue + +If the kube-apiserver's `GenericWorkload` feature gate is enabled but the kube-scheduler's `GangScheduling` feature gate is **not**, the operator will successfully create Workload and PodGroup resources, but pods will remain stuck in the `PreEnqueue` scheduling gate. + +The operator's startup check cannot detect this — it only verifies the API is registered on the apiserver, not that the scheduler plugin is active. To diagnose: + +```bash +# Check pod events for scheduling gate messages +kubectl describe pod <pod-name> -n <namespace> + +# Verify the scheduler has GangScheduling enabled +kubectl get pod kube-scheduler-<node-name> -n kube-system -o yaml | grep feature-gates +``` + +Ensure **both** `GenericWorkload` (apiserver) and `GangScheduling` (scheduler) feature gates are enabled. + +### Operator fails to start + +If you see an error like: + +```text +NativeWorkloadScheduling feature gate and batchScheduler configuration are mutually exclusive +``` + +The operator has both native scheduling and a batch scheduler (Volcano, YuniKorn, etc.) configured. Disable one of them. + +If you see: + +```text +scheduling.k8s.io/v1alpha2 API is not available +``` + +Your Kubernetes cluster is either older than 1.36 or the `GenericWorkload` feature gate / `scheduling.k8s.io/v1alpha2` runtime-config is not enabled on the apiserver. 
+ +### Warning events + +| Event | Meaning | +|-------|---------| +| `WorkloadSchedulingSkipped` | Native scheduling was skipped because autoscaling is enabled or a batch scheduler is configured | +| `WorkloadSchedulingInvalidSpec` | Too many worker groups (>7) | +| `FailedToCreateWorkload` / `FailedToCreatePodGroup` | API error creating scheduling resources | + +## Setting Up a Local Test Cluster + +To test native workload scheduling locally with [kind](https://kind.sigs.k8s.io/), you need a Kubernetes 1.36+ node image with the required feature gates: + +```bash +# Build the kind node image from a K8s 1.36 release, tagged to match the image in the kind config +kind build node-image --image kindest/node:v1.36.0 v1.36.0 + +# Create the cluster with the required feature gates +kind create cluster --name native-sched \ + --config ci/kind-config-native-workload-scheduling.yml +``` + +The [`ci/kind-config-native-workload-scheduling.yml`](../../ci/kind-config-native-workload-scheduling.yml) config enables `GenericWorkload` on the apiserver/controller-manager, `GangScheduling` on the scheduler, and serves the `scheduling.k8s.io/v1alpha2` alpha API. + +Then deploy the operator with the feature gate enabled: + +```bash +cd ray-operator +make docker-image IMG=kuberay/operator:latest +kind load docker-image kuberay/operator:latest --name native-sched +make deploy-native-scheduling IMG=kuberay/operator:latest +``` diff --git a/helm-chart/kuberay-operator/README.md b/helm-chart/kuberay-operator/README.md index b3a3be195af..c4d05c0174a 100644 --- a/helm-chart/kuberay-operator/README.md +++ b/helm-chart/kuberay-operator/README.md @@ -181,6 +181,8 @@ spec: | featureGates[3].enabled | bool | `false` | | | featureGates[4].name | string | `"RayCronJob"` | | | featureGates[4].enabled | bool | `false` | | +| featureGates[5].name | string | `"NativeWorkloadScheduling"` | | +| featureGates[5].enabled | bool | `false` | | | metrics.enabled | bool | `true` | Whether KubeRay operator should emit control plane metrics. 
| | metrics.serviceMonitor.enabled | bool | `false` | Enable a prometheus ServiceMonitor | | metrics.serviceMonitor.interval | string | `"30s"` | Prometheus ServiceMonitor interval | diff --git a/helm-chart/kuberay-operator/templates/_helpers.tpl b/helm-chart/kuberay-operator/templates/_helpers.tpl index 0af59ec4d73..f4b7166d0d5 100644 --- a/helm-chart/kuberay-operator/templates/_helpers.tpl +++ b/helm-chart/kuberay-operator/templates/_helpers.tpl @@ -346,6 +346,19 @@ rules: - patch - update - watch +- apiGroups: + - scheduling.k8s.io + resources: + - podgroups + - workloads + verbs: + - create + - delete + - get + - list + - patch + - update + - watch {{- if or .batchSchedulerEnabled (eq .batchSchedulerName "volcano") }} - apiGroups: - scheduling.volcano.sh diff --git a/helm-chart/kuberay-operator/tests/deployment_test.yaml b/helm-chart/kuberay-operator/tests/deployment_test.yaml index 8b2d5e1e631..db9f3ad1f99 100644 --- a/helm-chart/kuberay-operator/tests/deployment_test.yaml +++ b/helm-chart/kuberay-operator/tests/deployment_test.yaml @@ -403,3 +403,14 @@ tests: - contains: path: spec.template.spec.containers[?(@.name=="kuberay-operator")].args content: "--burst=200" + + - it: Should succeed when NativeWorkloadScheduling feature gate is enabled + set: + featureGates: + - name: NativeWorkloadScheduling + enabled: true + asserts: + - containsDocument: + apiVersion: apps/v1 + kind: Deployment + name: kuberay-operator diff --git a/helm-chart/kuberay-operator/tests/multiple_namespaces_role_test.yaml b/helm-chart/kuberay-operator/tests/multiple_namespaces_role_test.yaml new file mode 100644 index 00000000000..421ce704177 --- /dev/null +++ b/helm-chart/kuberay-operator/tests/multiple_namespaces_role_test.yaml @@ -0,0 +1,31 @@ +suite: Test Multiple Namespaces Role + +templates: + - multiple_namespaces_role.yaml + +release: + name: kuberay-operator + namespace: kuberay-system + +tests: + - it: Should always include scheduling.k8s.io RBAC rules + set: + 
singleNamespaceInstall: true + crNamespacedRbacEnable: true + asserts: + - contains: + path: rules + content: + apiGroups: + - scheduling.k8s.io + resources: + - podgroups + - workloads + verbs: + - create + - delete + - get + - list + - patch + - update + - watch diff --git a/helm-chart/kuberay-operator/tests/role_test.yaml b/helm-chart/kuberay-operator/tests/role_test.yaml index 91b5c52134d..1e97a887d7f 100644 --- a/helm-chart/kuberay-operator/tests/role_test.yaml +++ b/helm-chart/kuberay-operator/tests/role_test.yaml @@ -31,3 +31,22 @@ tests: apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole name: kuberay-operator + + - it: Should always include scheduling.k8s.io RBAC rules + asserts: + - contains: + path: rules + content: + apiGroups: + - scheduling.k8s.io + resources: + - podgroups + - workloads + verbs: + - create + - delete + - get + - list + - patch + - update + - watch diff --git a/helm-chart/kuberay-operator/values.yaml b/helm-chart/kuberay-operator/values.yaml index 8810d0305b4..85cf49e9c40 100644 --- a/helm-chart/kuberay-operator/values.yaml +++ b/helm-chart/kuberay-operator/values.yaml @@ -142,6 +142,8 @@ featureGates: enabled: false - name: RayCronJob enabled: false +- name: NativeWorkloadScheduling + enabled: false # Configurations for KubeRay operator metrics. metrics: diff --git a/ray-operator/DEVELOPMENT.md b/ray-operator/DEVELOPMENT.md index 5aba02fad1e..d8f74bc2966 100644 --- a/ray-operator/DEVELOPMENT.md +++ b/ray-operator/DEVELOPMENT.md @@ -345,3 +345,47 @@ docker buildx build --tag quay.io//operator:latest --tag docker.io/-. +func podGroupName(clusterName, templateName string) string { + return clusterName + "-" + templateName +} + +// isNativeWorkloadSchedulingEnabled returns true when both the feature gate and the per-cluster +// opt-in annotation are set. 
+func isNativeWorkloadSchedulingEnabled(instance *rayv1.RayCluster) bool { + return features.Enabled(features.NativeWorkloadScheduling) && + instance.Annotations[NativeWorkloadSchedulingAnnotation] == "true" +} + +// shouldSetSchedulingGroup returns true when native scheduling is active and +// Workload/PodGroup resources will actually be created (i.e. no skip conditions). +func (r *RayClusterReconciler) shouldSetSchedulingGroup(instance *rayv1.RayCluster) bool { + return r.nativeSchedulingSkipReason(instance) == skipReasonNone +} + +// nativeSchedulingSkipReason returns a schedulingSkipReason indicating why native workload +// scheduling should be skipped for this RayCluster, or skipReasonNone if it should proceed. +// It is used by both reconcileNativeWorkloadScheduling and shouldSetSchedulingGroup. +func (r *RayClusterReconciler) nativeSchedulingSkipReason(instance *rayv1.RayCluster) schedulingSkipReason { + if !isNativeWorkloadSchedulingEnabled(instance) { + return skipReasonDisabled + } + if r.options.BatchSchedulerManager != nil { + if scheduler, err := r.options.BatchSchedulerManager.GetScheduler(); err == nil && scheduler != nil && scheduler.Name() != "default" { + return skipReasonBatchScheduler + } + } + if utils.IsAutoscalingEnabled(&instance.Spec) { + return skipReasonAutoscaling + } + if len(instance.Spec.WorkerGroupSpecs) > schedulingv1alpha2.WorkloadMaxPodGroupTemplates-1 { + return skipReasonTooManyWorkerGroups + } + return skipReasonNone +} + +// setSchedulingGroup sets the schedulingGroup field on a pod to link it to a PodGroup. +func setSchedulingGroup(pod *corev1.Pod, pgName string) { + pod.Spec.SchedulingGroup = &corev1.PodSchedulingGroup{ + PodGroupName: &pgName, + } +} + +// isWorkloadStale returns true if the existing Workload's PodGroupTemplates no longer match +// the current RayCluster spec (e.g., worker groups added/removed/renamed, replica count changed, +// or suspension state changed). 
+func isWorkloadStale(existing *schedulingv1alpha2.Workload, instance *rayv1.RayCluster) bool { + desired := buildPodGroupSpecs(instance) + + // Number of templates must match: 1 (head) + len(WorkerGroupSpecs). + if len(existing.Spec.PodGroupTemplates) != len(desired) { + return true + } + + // Build a map of existing templates by name for efficient lookup. + existingByName := make(map[string]schedulingv1alpha2.PodGroupTemplate, len(existing.Spec.PodGroupTemplates)) + for _, tmpl := range existing.Spec.PodGroupTemplates { + existingByName[tmpl.Name] = tmpl + } + + for _, d := range desired { + e, ok := existingByName[d.templateName] + if !ok { + // Template name not found — worker group added or renamed. + return true + } + if !schedulingPoliciesMatch(e.SchedulingPolicy, d.schedulingPolicy) { + return true + } + } + + return false +} + +// schedulingPoliciesMatch returns true if two PodGroupSchedulingPolicy values are equivalent. +func schedulingPoliciesMatch(a, b schedulingv1alpha2.PodGroupSchedulingPolicy) bool { + // Both unset — structurally equal. + if a.Basic == nil && a.Gang == nil && b.Basic == nil && b.Gang == nil { + return true + } + // Both Basic + if a.Basic != nil && b.Basic != nil { + return true + } + // Both Gang with same MinCount + if a.Gang != nil && b.Gang != nil { + return a.Gang.MinCount == b.Gang.MinCount + } + // One is Basic and the other is Gang (or one is nil and the other set) + return false +} + +// deleteNativeWorkloadSchedulingResources deletes the Workload and all PodGroups owned by +// the given RayCluster. NotFound errors are treated as no-ops. +func (r *RayClusterReconciler) deleteNativeWorkloadSchedulingResources(ctx context.Context, instance *rayv1.RayCluster) error { + logger := ctrl.LoggerFrom(ctx) + + // Delete all PodGroups owned by this RayCluster. 
+ var podGroupList schedulingv1alpha2.PodGroupList + if err := r.List(ctx, &podGroupList, + client.InNamespace(instance.Namespace), + client.MatchingLabels{utils.RayClusterLabelKey: instance.Name}, + ); err != nil { + return fmt.Errorf("failed to list PodGroups for RayCluster %s/%s: %w", instance.Namespace, instance.Name, err) + } + for i := range podGroupList.Items { + pg := &podGroupList.Items[i] + // Remove the scheduler's PodGroup protection finalizer before deletion so that + // the PodGroup is deleted immediately rather than waiting for the scheduler to + // process the finalizer removal. + if controllerutil.RemoveFinalizer(pg, podGroupProtectionFinalizer) { + if err := r.Update(ctx, pg); err != nil { + if !errors.IsNotFound(err) { + return fmt.Errorf("failed to remove finalizer from PodGroup %s/%s: %w", pg.Namespace, pg.Name, err) + } + } + } + if err := r.Delete(ctx, pg); err != nil { + if !errors.IsNotFound(err) { + r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(FailedToDeletePodGroup), + "Failed to delete PodGroup %s/%s: %v", pg.Namespace, pg.Name, err) + return err + } + } else { + logger.Info("Deleted PodGroup", "name", pg.Name) + r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(DeletedPodGroup), + "Deleted PodGroup %s/%s", pg.Namespace, pg.Name) + } + } + + // Delete the Workload. 
+ workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: instance.Name, + Namespace: instance.Namespace, + }, + } + if err := r.Delete(ctx, workload); err != nil { + if !errors.IsNotFound(err) { + r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(FailedToDeleteWorkload), + "Failed to delete Workload %s/%s: %v", workload.Namespace, workload.Name, err) + return err + } + } else { + logger.Info("Deleted Workload", "name", workload.Name) + r.Recorder.Eventf(instance, corev1.EventTypeNormal, string(DeletedWorkload), + "Deleted Workload %s/%s", workload.Namespace, workload.Name) + } + + return nil +} + +// setWorkloadScheduledCondition sets the WorkloadScheduled condition on the RayCluster status. +// When native workload scheduling is enabled and the cluster is not suspended, the condition +// reflects whether the Workload resource has been created. When native scheduling is not enabled +// or the cluster is suspended/suspending, the condition is removed entirely rather than set to +// False because the condition is not meaningful when scheduling resources do not exist. 
+func (r *RayClusterReconciler) setWorkloadScheduledCondition(ctx context.Context, instance *rayv1.RayCluster, suspendStatus rayv1.RayClusterConditionType) { + if !isNativeWorkloadSchedulingEnabled(instance) || suspendStatus == rayv1.RayClusterSuspended || suspendStatus == rayv1.RayClusterSuspending { + meta.RemoveStatusCondition(&instance.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + return + } + + workload := &schedulingv1alpha2.Workload{} + err := r.Get(ctx, types.NamespacedName{Name: instance.Name, Namespace: instance.Namespace}, workload) + if err != nil { + if !errors.IsNotFound(err) { + logger := ctrl.LoggerFrom(ctx) + logger.V(1).Info("Failed to get Workload for condition check", "error", err) + } + meta.SetStatusCondition(&instance.Status.Conditions, metav1.Condition{ + Type: string(rayv1.RayClusterWorkloadScheduled), + Status: metav1.ConditionFalse, + Reason: rayv1.WorkloadPending, + Message: "Workload has not been created yet", + }) + return + } + + meta.SetStatusCondition(&instance.Status.Conditions, metav1.Condition{ + Type: string(rayv1.RayClusterWorkloadScheduled), + Status: metav1.ConditionTrue, + Reason: rayv1.WorkloadReady, + Message: "Workload and PodGroups have been created", + }) +} diff --git a/ray-operator/controllers/ray/native_workload_scheduling_test.go b/ray-operator/controllers/ray/native_workload_scheduling_test.go new file mode 100644 index 00000000000..18aa957a3b6 --- /dev/null +++ b/ray-operator/controllers/ray/native_workload_scheduling_test.go @@ -0,0 +1,2029 @@ +package ray + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + schedulingv1alpha2 "k8s.io/api/scheduling/v1alpha2" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + 
"k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + clientFake "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + + configapi "github.com/ray-project/kuberay/ray-operator/apis/config/v1alpha1" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/batchscheduler" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/expectations" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + "github.com/ray-project/kuberay/ray-operator/pkg/features" +) + +func newTestScheme() *runtime.Scheme { + s := runtime.NewScheme() + _ = rayv1.AddToScheme(s) + _ = corev1.AddToScheme(s) + _ = schedulingv1alpha2.AddToScheme(s) + return s +} + +func newTestRayCluster(workerGroups ...rayv1.WorkerGroupSpec) *rayv1.RayCluster { + return &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + UID: types.UID("test-uid"), + Annotations: map[string]string{ + NativeWorkloadSchedulingAnnotation: "true", + }, + }, + Spec: rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + RayStartParams: map[string]string{ + "port": "6379", + "num-cpus": "1", + }, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "ray-head", Image: "rayproject/ray:latest"}}, + }, + }, + }, + WorkerGroupSpecs: workerGroups, + }, + } +} + +func newWorkerGroup(name string, replicas int32) rayv1.WorkerGroupSpec { + return rayv1.WorkerGroupSpec{ + GroupName: name, + Replicas: new(replicas), + MinReplicas: new(replicas), + MaxReplicas: new(replicas), + NumOfHosts: 1, + RayStartParams: map[string]string{ + "port": "6379", + "num-cpus": "1", + }, + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "ray-worker", Image: "rayproject/ray:latest"}}, + }, + }, + } +} + +func newReconciler(fakeClient 
client.Client, s *runtime.Scheme, recorder record.EventRecorder, opts ...RayClusterReconcilerOptions) *RayClusterReconciler { + var options RayClusterReconcilerOptions + if len(opts) > 0 { + options = opts[0] + } + return &RayClusterReconciler{ + Client: fakeClient, + Scheme: s, + Recorder: recorder, + options: options, + rayClusterScaleExpectation: expectations.NewRayClusterScaleExpectation(fakeClient), + } +} + +// --- Reconcile behavior tests --- + +func TestReconcileNativeWorkloadScheduling_CreatesWorkloadAndPodGroups(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // Verify Workload was created. + workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + assert.Equal(t, "test-cluster", workload.Name) + assert.Len(t, workload.Spec.PodGroupTemplates, 2) // head + 1 worker group + + // Verify head PodGroup was created. + headPG := &schedulingv1alpha2.PodGroup{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster-head", Namespace: "default"}, headPG) + require.NoError(t, err) + + // Verify worker PodGroup was created. 
+ workerPG := &schedulingv1alpha2.PodGroup{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster-worker-workers", Namespace: "default"}, workerPG) + require.NoError(t, err) +} + +func TestReconcileNativeWorkloadScheduling_MultipleWorkerGroups(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster( + newWorkerGroup("cpu-workers", 2), + newWorkerGroup("gpu-workers", 4), + newWorkerGroup("tpu-workers", 1), + ) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // Verify Workload has 4 templates (1 head + 3 worker groups). + workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + assert.Len(t, workload.Spec.PodGroupTemplates, 4) + + // Verify all 4 PodGroups exist. + for _, pgName := range []string{"test-cluster-head", "test-cluster-worker-cpu-workers", "test-cluster-worker-gpu-workers", "test-cluster-worker-tpu-workers"} { + pg := &schedulingv1alpha2.PodGroup{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: pgName, Namespace: "default"}, pg) + require.NoError(t, err, "PodGroup %s should exist", pgName) + } +} + +func TestReconcileNativeWorkloadScheduling_Idempotent(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(20)) + ctx := context.Background() + + // Call twice — second call should succeed (AlreadyExists is a no-op). 
+ err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + err = r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // Verify still only one Workload. + workloadList := &schedulingv1alpha2.WorkloadList{} + err = fakeClient.List(ctx, workloadList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Len(t, workloadList.Items, 1) + + // Verify still only 2 PodGroups (head + worker). + pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Len(t, pgList.Items, 2) +} + +func TestReconcileNativeWorkloadScheduling_SkipsWhenAnnotationMissing(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + cluster.Annotations = nil // No annotation + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // No Workloads should be created. + workloadList := &schedulingv1alpha2.WorkloadList{} + err = fakeClient.List(ctx, workloadList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Empty(t, workloadList.Items) +} + +func TestReconcileNativeWorkloadScheduling_SkipsWhenFeatureGateDisabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, false) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // No Workloads should be created. 
+ workloadList := &schedulingv1alpha2.WorkloadList{} + err = fakeClient.List(ctx, workloadList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Empty(t, workloadList.Items) +} + +func TestReconcileNativeWorkloadScheduling_SkipsWhenAutoscalingEnabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + cluster.Spec.EnableInTreeAutoscaling = new(true) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // No Workloads should be created. + workloadList := &schedulingv1alpha2.WorkloadList{} + err = fakeClient.List(ctx, workloadList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Empty(t, workloadList.Items) + + // Should have emitted a warning event. + assert.Len(t, fakeRecorder.Events, 1) +} + +func TestReconcileNativeWorkloadScheduling_SkipsWhenBatchSchedulerConfigured(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(10) + // Create a SchedulerManager with a real non-default batch scheduler (yunikorn). 
+ batchMgr, err := batchscheduler.NewSchedulerManager(context.Background(), + configapi.Configuration{BatchScheduler: "yunikorn"}, nil, nil) + require.NoError(t, err) + r := newReconciler(fakeClient, s, fakeRecorder, RayClusterReconcilerOptions{ + BatchSchedulerManager: batchMgr, + }) + ctx := context.Background() + + err = r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // No Workloads should be created. + workloadList := &schedulingv1alpha2.WorkloadList{} + err = fakeClient.List(ctx, workloadList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Empty(t, workloadList.Items) + + // Should have emitted a warning event. + assert.Len(t, fakeRecorder.Events, 1) +} + +func TestReconcileNativeWorkloadScheduling_FailsWhenMoreThan7WorkerGroups(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + workers := make([]rayv1.WorkerGroupSpec, 8) + for i := range workers { + workers[i] = newWorkerGroup("workers-"+string(rune('a'+i)), 1) + } + cluster := newTestRayCluster(workers...) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "exceeding the maximum of 7") + + // No Workloads should be created. 
+ workloadList := &schedulingv1alpha2.WorkloadList{} + err = fakeClient.List(ctx, workloadList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Empty(t, workloadList.Items) +} + +// --- Workload construction tests --- + +func TestBuildWorkload_HeadTemplateUsesBasicPolicy(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + + // First template should be "head" with BasicSchedulingPolicy. + require.Len(t, workload.Spec.PodGroupTemplates, 2) + headTemplate := workload.Spec.PodGroupTemplates[0] + assert.Equal(t, "head", headTemplate.Name) + assert.NotNil(t, headTemplate.SchedulingPolicy.Basic) + assert.Nil(t, headTemplate.SchedulingPolicy.Gang) +} + +func TestBuildWorkload_WorkerTemplateUsesGangPolicy(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + + // Second template should be "worker-workers" with GangSchedulingPolicy. 
+ require.Len(t, workload.Spec.PodGroupTemplates, 2) + workerTemplate := workload.Spec.PodGroupTemplates[1] + assert.Equal(t, "worker-workers", workerTemplate.Name) + assert.Nil(t, workerTemplate.SchedulingPolicy.Basic) + require.NotNil(t, workerTemplate.SchedulingPolicy.Gang) + assert.Equal(t, int32(3), workerTemplate.SchedulingPolicy.Gang.MinCount) +} + +func TestBuildWorkload_MinCountMatchesDesiredReplicas(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster( + newWorkerGroup("small", 2), + newWorkerGroup("large", 5), + ) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + require.Len(t, workload.Spec.PodGroupTemplates, 3) + + assert.Equal(t, int32(2), workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount) + assert.Equal(t, int32(5), workload.Spec.PodGroupTemplates[2].SchedulingPolicy.Gang.MinCount) +} + +func TestBuildWorkload_MinCountWithNumOfHosts2(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + wg := newWorkerGroup("multi-host", 3) + wg.NumOfHosts = 2 + cluster := newTestRayCluster(wg) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + require.Len(t, workload.Spec.PodGroupTemplates, 2) + + // GetWorkerGroupDesiredReplicas multiplies by NumOfHosts: 3 * 2 = 6 + require.NotNil(t, workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang) + assert.Equal(t, int32(6), workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount) +} + +func TestBuildWorkload_SuspendedWorkerGroupMinCount0(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, 
true) + + wg := newWorkerGroup("suspended-group", 3) + wg.Suspend = new(true) + cluster := newTestRayCluster(wg) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + require.Len(t, workload.Spec.PodGroupTemplates, 2) + + // Suspended group should use BasicSchedulingPolicy, not gang with minCount=0. + workerTemplate := workload.Spec.PodGroupTemplates[1] + assert.NotNil(t, workerTemplate.SchedulingPolicy.Basic) + assert.Nil(t, workerTemplate.SchedulingPolicy.Gang) +} + +func TestBuildWorkload_MinReplicas0UsesBasicPolicy(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + wg := newWorkerGroup("zero-min", 0) + wg.MinReplicas = new(int32(0)) + wg.Replicas = new(int32(0)) + cluster := newTestRayCluster(wg) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + require.Len(t, workload.Spec.PodGroupTemplates, 2) + + workerTemplate := workload.Spec.PodGroupTemplates[1] + assert.NotNil(t, workerTemplate.SchedulingPolicy.Basic) + assert.Nil(t, workerTemplate.SchedulingPolicy.Gang) +} + +// --- OwnerReference tests --- + +func TestBuildWorkload_OwnerReference(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + + require.Len(t, workload.OwnerReferences, 1) + ownerRef := workload.OwnerReferences[0] + assert.Equal(t, "ray.io/v1", ownerRef.APIVersion) + 
assert.Equal(t, "RayCluster", ownerRef.Kind) + assert.Equal(t, "test-cluster", ownerRef.Name) + assert.Equal(t, cluster.UID, ownerRef.UID) + assert.True(t, *ownerRef.Controller) + assert.True(t, *ownerRef.BlockOwnerDeletion) +} + +func TestBuildPodGroup_OwnerReference(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + policy := schedulingv1alpha2.PodGroupSchedulingPolicy{ + Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}, + } + pg, err := r.buildPodGroup(cluster, "worker-workers", policy) + require.NoError(t, err) + + require.Len(t, pg.OwnerReferences, 1) + ownerRef := pg.OwnerReferences[0] + assert.Equal(t, "ray.io/v1", ownerRef.APIVersion) + assert.Equal(t, "RayCluster", ownerRef.Kind) + assert.Equal(t, "test-cluster", ownerRef.Name) + assert.Equal(t, cluster.UID, ownerRef.UID) + assert.True(t, *ownerRef.Controller) + assert.True(t, *ownerRef.BlockOwnerDeletion) +} + +// --- PodGroup construction tests --- + +func TestBuildPodGroup_TemplateRef(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster() + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + policy := schedulingv1alpha2.PodGroupSchedulingPolicy{ + Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}, + } + pg, err := r.buildPodGroup(cluster, "head", policy) + require.NoError(t, err) + + require.NotNil(t, pg.Spec.PodGroupTemplateRef) + require.NotNil(t, pg.Spec.PodGroupTemplateRef.Workload) + assert.Equal(t, "test-cluster", pg.Spec.PodGroupTemplateRef.Workload.WorkloadName) + assert.Equal(t, "head", 
pg.Spec.PodGroupTemplateRef.Workload.PodGroupTemplateName) +} + +func TestBuildPodGroup_PolicyCopied(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster() + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + gangPolicy := schedulingv1alpha2.PodGroupSchedulingPolicy{ + Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 5}, + } + pg, err := r.buildPodGroup(cluster, "worker-gpu", gangPolicy) + require.NoError(t, err) + + require.NotNil(t, pg.Spec.SchedulingPolicy.Gang) + assert.Equal(t, int32(5), pg.Spec.SchedulingPolicy.Gang.MinCount) + assert.Nil(t, pg.Spec.SchedulingPolicy.Basic) +} + +// --- Workload spec tests --- + +func TestBuildWorkload_ControllerRef(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + + require.NotNil(t, workload.Spec.ControllerRef) + assert.Equal(t, "ray.io", workload.Spec.ControllerRef.APIGroup) + assert.Equal(t, "RayCluster", workload.Spec.ControllerRef.Kind) + assert.Equal(t, "test-cluster", workload.Spec.ControllerRef.Name) +} + +func TestBuildWorkload_Labels(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.NoError(t, err) + + assert.Equal(t, "test-cluster", workload.Labels[utils.RayClusterLabelKey]) +} + 
+func TestBuildPodGroup_Labels(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster() + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + + policy := schedulingv1alpha2.PodGroupSchedulingPolicy{ + Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}, + } + pg, err := r.buildPodGroup(cluster, "head", policy) + require.NoError(t, err) + + assert.Equal(t, "test-cluster", pg.Labels[utils.RayClusterLabelKey]) +} + +// --- Pod scheduling group tests --- + +func TestSetSchedulingGroup_HeadPod(t *testing.T) { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "head-pod"}, + Spec: corev1.PodSpec{}, + } + setSchedulingGroup(pod, podGroupName("my-cluster", "head")) + + require.NotNil(t, pod.Spec.SchedulingGroup) + require.NotNil(t, pod.Spec.SchedulingGroup.PodGroupName) + assert.Equal(t, "my-cluster-head", *pod.Spec.SchedulingGroup.PodGroupName) +} + +func TestSetSchedulingGroup_WorkerPod(t *testing.T) { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "worker-pod"}, + Spec: corev1.PodSpec{}, + } + setSchedulingGroup(pod, podGroupName("my-cluster", "worker-gpu-workers")) + + require.NotNil(t, pod.Spec.SchedulingGroup) + require.NotNil(t, pod.Spec.SchedulingGroup.PodGroupName) + assert.Equal(t, "my-cluster-worker-gpu-workers", *pod.Spec.SchedulingGroup.PodGroupName) +} + +func TestSetSchedulingGroup_Idempotent(t *testing.T) { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "head-pod"}, + Spec: corev1.PodSpec{}, + } + + // Set once, then set again with same value — should be stable. 
+ setSchedulingGroup(pod, "my-cluster-head") + require.NotNil(t, pod.Spec.SchedulingGroup) + assert.Equal(t, "my-cluster-head", *pod.Spec.SchedulingGroup.PodGroupName) + + setSchedulingGroup(pod, "my-cluster-head") + assert.Equal(t, "my-cluster-head", *pod.Spec.SchedulingGroup.PodGroupName) +} + +// --- Naming tests --- + +func TestPodGroupName(t *testing.T) { + tests := []struct { + clusterName string + templateName string + expected string + }{ + {"my-cluster", "head", "my-cluster-head"}, + {"my-cluster", "worker-gpu-workers", "my-cluster-worker-gpu-workers"}, + {"ray-cluster-1", "worker-cpu", "ray-cluster-1-worker-cpu"}, + } + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + assert.Equal(t, tt.expected, podGroupName(tt.clusterName, tt.templateName)) + }) + } +} + +// --- isNativeWorkloadSchedulingEnabled tests --- + +func TestIsNativeWorkloadSchedulingEnabled(t *testing.T) { + tests := []struct { + name string + featureGate bool + annotation string + expected bool + }{ + {"both enabled", true, "true", true}, + {"gate on, annotation missing", true, "", false}, + {"gate off, annotation on", false, "true", false}, + {"both off", false, "", false}, + {"gate on, annotation wrong value", true, "false", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, tt.featureGate) + cluster := &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + NativeWorkloadSchedulingAnnotation: tt.annotation, + }, + }, + } + if tt.annotation == "" { + cluster.Annotations = nil + } + assert.Equal(t, tt.expected, isNativeWorkloadSchedulingEnabled(cluster)) + }) + } +} + +func TestShouldSetSchedulingGroup(t *testing.T) { + tests := []struct { + name string + featureGate bool + annotation string + autoscaling bool + batchSched bool + tooManyWorkers bool + expected bool + }{ + {"enabled without autoscaling", true, "true", false, false, false, 
true}, + {"enabled with autoscaling", true, "true", true, false, false, false}, + {"disabled", false, "true", false, false, false, false}, + {"no annotation", true, "", false, false, false, false}, + {"batch scheduler configured", true, "true", false, true, false, false}, + {"too many worker groups", true, "true", false, false, true, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, tt.featureGate) + cluster := &rayv1.RayCluster{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + NativeWorkloadSchedulingAnnotation: tt.annotation, + }, + }, + } + if tt.annotation == "" { + cluster.Annotations = nil + } + if tt.autoscaling { + cluster.Spec.EnableInTreeAutoscaling = new(true) + } + if tt.tooManyWorkers { + for i := range 8 { + cluster.Spec.WorkerGroupSpecs = append(cluster.Spec.WorkerGroupSpecs, + newWorkerGroup(fmt.Sprintf("wg-%d", i), 1)) + } + } + + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(10) + var opts RayClusterReconcilerOptions + if tt.batchSched { + batchMgr, err := batchscheduler.NewSchedulerManager(context.Background(), + configapi.Configuration{BatchScheduler: "yunikorn"}, nil, nil) + require.NoError(t, err) + opts = RayClusterReconcilerOptions{ + BatchSchedulerManager: batchMgr, + } + } + r := newReconciler(fakeClient, s, fakeRecorder, opts) + assert.Equal(t, tt.expected, r.shouldSetSchedulingGroup(cluster)) + }) + } +} + +// --- buildPodGroupSpecs direct tests --- + +func TestBuildPodGroupSpecs_HeadOnly(t *testing.T) { + cluster := newTestRayCluster() + specs := buildPodGroupSpecs(cluster) + + require.Len(t, specs, 1) + assert.Equal(t, "head", specs[0].templateName) + assert.NotNil(t, specs[0].schedulingPolicy.Basic) + assert.Nil(t, specs[0].schedulingPolicy.Gang) +} + +func TestBuildPodGroupSpecs_WorkerGroupPolicies(t *testing.T) { + tests := 
[]struct { + name string + workerGroup rayv1.WorkerGroupSpec + expectGang bool + expectMinCnt int32 + }{ + { + name: "active worker group uses gang policy", + workerGroup: newWorkerGroup("active", 3), + expectGang: true, + expectMinCnt: 3, + }, + { + name: "suspended worker group uses basic policy", + workerGroup: func() rayv1.WorkerGroupSpec { + wg := newWorkerGroup("suspended", 3) + wg.Suspend = new(true) + return wg + }(), + expectGang: false, + }, + { + name: "zero replicas uses basic policy", + workerGroup: func() rayv1.WorkerGroupSpec { + wg := newWorkerGroup("zero", 0) + wg.MinReplicas = new(int32(0)) + wg.Replicas = new(int32(0)) + return wg + }(), + expectGang: false, + }, + { + name: "multi-host multiplies replicas", + workerGroup: func() rayv1.WorkerGroupSpec { + wg := newWorkerGroup("multi", 2) + wg.NumOfHosts = 3 + return wg + }(), + expectGang: true, + expectMinCnt: 6, // 2 * 3 + }, + { + name: "replicas clamped to max", + workerGroup: func() rayv1.WorkerGroupSpec { + wg := newWorkerGroup("clamped", 10) + wg.MaxReplicas = new(int32(5)) + return wg + }(), + expectGang: true, + expectMinCnt: 5, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cluster := newTestRayCluster(tt.workerGroup) + specs := buildPodGroupSpecs(cluster) + + require.Len(t, specs, 2) + // First is always head. 
+ assert.Equal(t, "head", specs[0].templateName) + + workerSpec := specs[1] + assert.Equal(t, "worker-"+tt.workerGroup.GroupName, workerSpec.templateName) + if tt.expectGang { + require.NotNil(t, workerSpec.schedulingPolicy.Gang) + assert.Nil(t, workerSpec.schedulingPolicy.Basic) + assert.Equal(t, tt.expectMinCnt, workerSpec.schedulingPolicy.Gang.MinCount) + } else { + assert.NotNil(t, workerSpec.schedulingPolicy.Basic) + assert.Nil(t, workerSpec.schedulingPolicy.Gang) + } + }) + } +} + +func TestBuildPodGroupSpecs_MultipleWorkerGroups(t *testing.T) { + cluster := newTestRayCluster( + newWorkerGroup("cpu", 2), + newWorkerGroup("gpu", 4), + ) + specs := buildPodGroupSpecs(cluster) + + require.Len(t, specs, 3) + assert.Equal(t, "head", specs[0].templateName) + assert.Equal(t, "worker-cpu", specs[1].templateName) + assert.Equal(t, int32(2), specs[1].schedulingPolicy.Gang.MinCount) + assert.Equal(t, "worker-gpu", specs[2].templateName) + assert.Equal(t, int32(4), specs[2].schedulingPolicy.Gang.MinCount) +} + +// --- Boundary condition tests --- + +func TestReconcileNativeWorkloadScheduling_Exactly7WorkerGroups(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + workers := make([]rayv1.WorkerGroupSpec, 7) + for i := range workers { + workers[i] = newWorkerGroup("workers-"+strconv.Itoa(i), 1) + } + cluster := newTestRayCluster(workers...) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(20)) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // Should create 1 Workload with 8 templates (1 head + 7 workers). 
+ workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + assert.Len(t, workload.Spec.PodGroupTemplates, 8) + + // Should create 8 PodGroups (1 head + 7 workers). + pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + assert.Len(t, pgList.Items, 8) +} + +// --- Error path tests --- + +func TestBuildWorkload_SetControllerReferenceError(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + // Use a scheme that doesn't know about RayCluster — SetControllerReference will fail. + badScheme := runtime.NewScheme() + _ = schedulingv1alpha2.AddToScheme(badScheme) + fakeClient := clientFake.NewClientBuilder().WithScheme(badScheme).Build() + r := newReconciler(fakeClient, badScheme, record.NewFakeRecorder(10)) + + workload, err := r.buildWorkload(cluster) + require.Error(t, err) + assert.Nil(t, workload) +} + +func TestBuildPodGroup_SetControllerReferenceError(t *testing.T) { + cluster := newTestRayCluster() + badScheme := runtime.NewScheme() + _ = schedulingv1alpha2.AddToScheme(badScheme) + fakeClient := clientFake.NewClientBuilder().WithScheme(badScheme).Build() + r := newReconciler(fakeClient, badScheme, record.NewFakeRecorder(10)) + + policy := schedulingv1alpha2.PodGroupSchedulingPolicy{ + Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}, + } + pg, err := r.buildPodGroup(cluster, "head", policy) + require.Error(t, err) + assert.Nil(t, pg) +} + +func TestReconcileNativeWorkloadScheduling_WorkloadCreateFailure(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s). 
+ WithInterceptorFuncs(interceptor.Funcs{ + Create: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if _, ok := obj.(*schedulingv1alpha2.Workload); ok { + return fmt.Errorf("simulated API server error") + } + return c.Create(ctx, obj, opts...) + }, + }).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "simulated API server error") + + // Should have emitted a FailedToCreateWorkload event. + assert.Len(t, fakeRecorder.Events, 1) + event := <-fakeRecorder.Events + assert.Contains(t, event, string(FailedToCreateWorkload)) +} + +func TestReconcileNativeWorkloadScheduling_PodGroupCreateFailure(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithInterceptorFuncs(interceptor.Funcs{ + Create: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if _, ok := obj.(*schedulingv1alpha2.PodGroup); ok { + return fmt.Errorf("simulated PodGroup creation error") + } + return c.Create(ctx, obj, opts...) + }, + }).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "simulated PodGroup creation error") + + // Workload should have been created successfully. 
+ workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + + // Should have emitted CreatedWorkload + FailedToCreatePodGroup events. + assert.Len(t, fakeRecorder.Events, 2) +} + +// --- Controller integration tests: schedulingGroup on pods --- + +func TestCreateHeadPod_SetsSchedulingGroup(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.createHeadPod(ctx, *cluster, "") + require.NoError(t, err) + + podList := &corev1.PodList{} + err = fakeClient.List(ctx, podList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + require.Len(t, podList.Items, 1) + + pod := podList.Items[0] + require.NotNil(t, pod.Spec.SchedulingGroup, "head pod should have schedulingGroup set") + assert.Equal(t, "test-cluster-head", *pod.Spec.SchedulingGroup.PodGroupName) +} + +func TestCreateWorkerPod_SetsSchedulingGroup(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + worker := newWorkerGroup("gpu-workers", 3) + cluster := newTestRayCluster(worker) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.createWorkerPod(ctx, *cluster, worker) + require.NoError(t, err) + + podList := &corev1.PodList{} + err = fakeClient.List(ctx, podList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + require.Len(t, podList.Items, 1) + + pod := podList.Items[0] + require.NotNil(t, pod.Spec.SchedulingGroup, "worker pod should have schedulingGroup set") + 
assert.Equal(t, "test-cluster-worker-gpu-workers", *pod.Spec.SchedulingGroup.PodGroupName) +} + +func TestCreateWorkerPodWithIndex_SetsSchedulingGroup(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + worker := newWorkerGroup("tpu-workers", 2) + cluster := newTestRayCluster(worker) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.createWorkerPodWithIndex(ctx, *cluster, worker, "replica-0", 0, 0) + require.NoError(t, err) + + podList := &corev1.PodList{} + err = fakeClient.List(ctx, podList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + require.Len(t, podList.Items, 1) + + pod := podList.Items[0] + require.NotNil(t, pod.Spec.SchedulingGroup, "worker pod should have schedulingGroup set") + assert.Equal(t, "test-cluster-worker-tpu-workers", *pod.Spec.SchedulingGroup.PodGroupName) +} + +func TestCreateHeadPod_NoSchedulingGroupWhenDisabled(t *testing.T) { + // Feature gate enabled but annotation missing + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + delete(cluster.Annotations, NativeWorkloadSchedulingAnnotation) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.createHeadPod(ctx, *cluster, "") + require.NoError(t, err) + + podList := &corev1.PodList{} + err = fakeClient.List(ctx, podList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + require.Len(t, podList.Items, 1) + + pod := podList.Items[0] + assert.Nil(t, pod.Spec.SchedulingGroup, "head pod should not have schedulingGroup when annotation is missing") +} + +func TestCreateWorkerPod_NoSchedulingGroupWhenFeatureGateDisabled(t 
*testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, false) + + worker := newWorkerGroup("workers", 3) + cluster := newTestRayCluster(worker) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.createWorkerPod(ctx, *cluster, worker) + require.NoError(t, err) + + podList := &corev1.PodList{} + err = fakeClient.List(ctx, podList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + require.Len(t, podList.Items, 1) + + pod := podList.Items[0] + assert.Nil(t, pod.Spec.SchedulingGroup, "worker pod should not have schedulingGroup when feature gate is disabled") +} + +func TestCreateWorkerPodWithIndex_NoSchedulingGroupWhenDisabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + worker := newWorkerGroup("tpu-workers", 2) + cluster := newTestRayCluster(worker) + delete(cluster.Annotations, NativeWorkloadSchedulingAnnotation) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + err := r.createWorkerPodWithIndex(ctx, *cluster, worker, "replica-0", 0, 0) + require.NoError(t, err) + + podList := &corev1.PodList{} + err = fakeClient.List(ctx, podList, &client.ListOptions{Namespace: "default"}) + require.NoError(t, err) + require.Len(t, podList.Items, 1) + + pod := podList.Items[0] + assert.Nil(t, pod.Spec.SchedulingGroup, "worker pod should not have schedulingGroup when annotation is missing") +} + +// --- isWorkloadStale tests --- + +func TestIsWorkloadStale_NoChange(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: 
"head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-workers", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}}}, + }, + }, + } + assert.False(t, isWorkloadStale(workload, cluster)) +} + +func TestIsWorkloadStale_WorkerGroupAdded(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("cpu", 2), newWorkerGroup("gpu", 4)) + // Existing workload only has head + cpu. + workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-cpu", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 2}}}, + }, + }, + } + assert.True(t, isWorkloadStale(workload, cluster)) +} + +func TestIsWorkloadStale_WorkerGroupRemoved(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("cpu", 2)) + // Existing workload has head + cpu + gpu. + workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-cpu", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 2}}}, + {Name: "worker-gpu", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 4}}}, + }, + }, + } + assert.True(t, isWorkloadStale(workload, cluster)) +} + +func TestIsWorkloadStale_WorkerGroupRenamed(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("gpu-v2", 3)) + // Existing workload has old name. 
+ workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-gpu-v1", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}}}, + }, + }, + } + assert.True(t, isWorkloadStale(workload, cluster)) +} + +func TestIsWorkloadStale_ReplicaCountChanged(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 5)) + workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-workers", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}}}, + }, + }, + } + assert.True(t, isWorkloadStale(workload, cluster)) +} + +func TestIsWorkloadStale_NumOfHostsChanged(t *testing.T) { + wg := newWorkerGroup("workers", 2) + wg.NumOfHosts = 3 + cluster := newTestRayCluster(wg) + // Existing workload has minCount=2 (old NumOfHosts=1), desired is 2*3=6. 
+ workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-workers", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 2}}}, + }, + }, + } + assert.True(t, isWorkloadStale(workload, cluster)) +} + +func TestIsWorkloadStale_WorkerGroupSuspended(t *testing.T) { + wg := newWorkerGroup("workers", 3) + wg.Suspend = new(true) + cluster := newTestRayCluster(wg) + // Existing workload has gang policy with minCount=3, but suspended group should have basic policy. + workload := &schedulingv1alpha2.Workload{ + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-workers", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}}}, + }, + }, + } + assert.True(t, isWorkloadStale(workload, cluster)) +} + +// --- schedulingPoliciesMatch tests --- + +func TestSchedulingPoliciesMatch(t *testing.T) { + tests := []struct { + name string + a, b schedulingv1alpha2.PodGroupSchedulingPolicy + expected bool + }{ + { + name: "both basic", + a: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}, + b: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}, + expected: true, + }, + { + name: "both gang same minCount", + a: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 5}}, + b: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 5}}, + expected: true, 
+ }, + { + name: "both gang different minCount", + a: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}}, + b: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 5}}, + expected: false, + }, + { + name: "basic vs gang", + a: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}, + b: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 1}}, + expected: false, + }, + { + name: "gang vs basic", + a: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 1}}, + b: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}, + expected: false, + }, + { + name: "both nil", + a: schedulingv1alpha2.PodGroupSchedulingPolicy{}, + b: schedulingv1alpha2.PodGroupSchedulingPolicy{}, + expected: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, schedulingPoliciesMatch(tt.a, tt.b)) + }) + } +} + +// --- deleteNativeWorkloadSchedulingResources tests --- + +func TestDeleteNativeWorkloadSchedulingResources_DeletesAll(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // Pre-create Workload and PodGroups. 
+ workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + headPG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + workerPG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-worker-workers", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(workload, headPG, workerPG).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.deleteNativeWorkloadSchedulingResources(ctx, cluster) + require.NoError(t, err) + + // Verify Workload is deleted. + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, &schedulingv1alpha2.Workload{}) + assert.True(t, apierrors.IsNotFound(err), "Workload should be deleted") + + // Verify PodGroups are deleted. + pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, client.InNamespace("default"), client.MatchingLabels{utils.RayClusterLabelKey: "test-cluster"}) + require.NoError(t, err) + assert.Empty(t, pgList.Items, "All PodGroups should be deleted") +} + +func TestDeleteNativeWorkloadSchedulingResources_NotFoundIsNoop(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + // No pre-existing resources. 
+ fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.deleteNativeWorkloadSchedulingResources(ctx, cluster) + require.NoError(t, err) +} + +func TestDeleteNativeWorkloadSchedulingResources_PodGroupDeleteFailure(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + pg := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(pg). + WithInterceptorFuncs(interceptor.Funcs{ + Delete: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.DeleteOption) error { + if _, ok := obj.(*schedulingv1alpha2.PodGroup); ok { + return fmt.Errorf("simulated PodGroup delete error") + } + return c.Delete(ctx, obj, opts...) + }, + }).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.deleteNativeWorkloadSchedulingResources(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "simulated PodGroup delete error") + + // Should have emitted a FailedToDeletePodGroup event. + event := <-fakeRecorder.Events + assert.Contains(t, event, string(FailedToDeletePodGroup)) +} + +func TestDeleteNativeWorkloadSchedulingResources_WorkloadDeleteFailure(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(workload). 
+ WithInterceptorFuncs(interceptor.Funcs{ + Delete: func(ctx context.Context, c client.WithWatch, obj client.Object, opts ...client.DeleteOption) error { + if _, ok := obj.(*schedulingv1alpha2.Workload); ok { + return fmt.Errorf("simulated Workload delete error") + } + return c.Delete(ctx, obj, opts...) + }, + }).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.deleteNativeWorkloadSchedulingResources(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "simulated Workload delete error") + + // Should have emitted a FailedToDeleteWorkload event. + event := <-fakeRecorder.Events + assert.Contains(t, event, string(FailedToDeleteWorkload)) +} + +func TestDeleteNativeWorkloadSchedulingResources_ListFailure(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithInterceptorFuncs(interceptor.Funcs{ + List: func(ctx context.Context, c client.WithWatch, list client.ObjectList, opts ...client.ListOption) error { + if _, ok := list.(*schedulingv1alpha2.PodGroupList); ok { + return fmt.Errorf("simulated list error") + } + return c.List(ctx, list, opts...) + }, + }).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.deleteNativeWorkloadSchedulingResources(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to list PodGroups") +} + +func TestDeleteNativeWorkloadSchedulingResources_RemovesSchedulerFinalizer(t *testing.T) { + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // Pre-create a PodGroup with the scheduler's protection finalizer. 
+ pg := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + Finalizers: []string{podGroupProtectionFinalizer}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(pg).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.deleteNativeWorkloadSchedulingResources(ctx, cluster) + require.NoError(t, err) + + // PodGroup should be deleted (not stuck in deleting state). + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster-head", Namespace: "default"}, &schedulingv1alpha2.PodGroup{}) + assert.True(t, apierrors.IsNotFound(err), "PodGroup with finalizer should be deleted after stripping finalizer") +} + +func TestReconcileNativeWorkloadScheduling_PodGroupDeletionTimestampReturnsError(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // Pre-create a Workload whose spec matches the desired state so that + // isWorkloadStale() returns false and we proceed to PodGroup creation. + // This isolates the test to the PodGroup AlreadyExists + DeletionTimestamp path. 
+ existingWorkload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-workers", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 3}}}, + }, + }, + } + + // Pre-create a head PodGroup with DeletionTimestamp so Create returns AlreadyExists + // and the subsequent Get finds the object mid-deletion. + now := metav1.Now() + headPG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + DeletionTimestamp: &now, + Finalizers: []string{"fake-finalizer"}, // Needed to keep the fake client from deleting immediately. + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(existingWorkload, headPG).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "is being deleted (finalizer pending)") +} + +func TestReconcileNativeWorkloadScheduling_GetFailureAfterAlreadyExists(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // Pre-create a Workload so that Create returns AlreadyExists. 
+ existingWorkload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(existingWorkload). + WithInterceptorFuncs(interceptor.Funcs{ + Get: func(ctx context.Context, c client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if _, ok := obj.(*schedulingv1alpha2.Workload); ok { + return fmt.Errorf("simulated Get error") + } + return c.Get(ctx, key, obj, opts...) + }, + }).Build() + fakeRecorder := record.NewFakeRecorder(10) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get existing Workload") +} + +// --- Reconcile drift detection integration tests --- + +func TestReconcile_StaleWorkloadRecreated(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + // Start with 1 worker group with 3 replicas. + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // Pre-create a stale Workload with old minCount=2. 
+ staleWorkload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-workers", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 2}}}, + }, + }, + } + stalePG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-worker-workers", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(staleWorkload, stalePG).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // Verify the Workload was recreated with the correct minCount. + workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + require.Len(t, workload.Spec.PodGroupTemplates, 2) + assert.Equal(t, "worker-workers", workload.Spec.PodGroupTemplates[1].Name) + require.NotNil(t, workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang) + assert.Equal(t, int32(3), workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount) + + // Verify PodGroups were recreated (head + worker). 
+ pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, client.InNamespace("default")) + require.NoError(t, err) + assert.Len(t, pgList.Items, 2) +} + +func TestReconcile_UpToDateWorkloadNotRecreated(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + // First reconcile — creates everything. + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + originalUID := workload.UID + originalRV := workload.ResourceVersion + + // Second reconcile — should not recreate. + err = r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + assert.Equal(t, originalUID, workload.UID, "Workload UID should not change when spec is up-to-date") + assert.Equal(t, originalRV, workload.ResourceVersion, "Workload ResourceVersion should not change when spec is up-to-date") +} + +func TestReconcile_StaleWorkloadWorkerGroupAdded(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + + // Cluster now has 2 worker groups. + cluster := newTestRayCluster(newWorkerGroup("cpu", 2), newWorkerGroup("gpu", 4)) + s := newTestScheme() + + // Pre-create Workload with only cpu worker group. 
+ staleWorkload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + Spec: schedulingv1alpha2.WorkloadSpec{ + PodGroupTemplates: []schedulingv1alpha2.PodGroupTemplate{ + {Name: "head", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Basic: &schedulingv1alpha2.BasicSchedulingPolicy{}}}, + {Name: "worker-cpu", SchedulingPolicy: schedulingv1alpha2.PodGroupSchedulingPolicy{Gang: &schedulingv1alpha2.GangSchedulingPolicy{MinCount: 2}}}, + }, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(staleWorkload).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcileNativeWorkloadScheduling(ctx, cluster) + require.NoError(t, err) + + // Verify Workload has 3 templates now (head + cpu + gpu). + workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err) + assert.Len(t, workload.Spec.PodGroupTemplates, 3) + + // Verify 3 PodGroups exist. + pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, client.InNamespace("default")) + require.NoError(t, err) + assert.Len(t, pgList.Items, 3) +} + +// --- reconcilePods lifecycle tests (suspend / recreate paths) --- + +func TestReconcilePods_SuspendDeletesNativeSchedulingResources(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, false) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + suspend := true + cluster.Spec.Suspend = &suspend + + s := newTestScheme() + + // Pre-create Workload and PodGroups that should be deleted on suspend. 
+ workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + headPG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + workerPG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-worker-workers", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(workload, headPG, workerPG).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcilePods(ctx, cluster) + require.NoError(t, err) + + // Verify Workload is deleted. + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, &schedulingv1alpha2.Workload{}) + assert.True(t, apierrors.IsNotFound(err), "Workload should be deleted on suspend") + + // Verify PodGroups are deleted. 
+ pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, client.InNamespace("default"), client.MatchingLabels{utils.RayClusterLabelKey: "test-cluster"}) + require.NoError(t, err) + assert.Empty(t, pgList.Items, "PodGroups should be deleted on suspend") +} + +func TestReconcilePods_SuspendSkipsDeletionWhenNativeSchedulingDisabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, false) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, false) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + delete(cluster.Annotations, NativeWorkloadSchedulingAnnotation) + suspend := true + cluster.Spec.Suspend = &suspend + + s := newTestScheme() + + // Pre-create resources — they should NOT be deleted since native scheduling is disabled. + workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). + WithObjects(workload).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcilePods(ctx, cluster) + require.NoError(t, err) + + // Verify Workload still exists. 
+ err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, &schedulingv1alpha2.Workload{}) + assert.NoError(t, err, "Workload should still exist when native scheduling is disabled") +} + +func TestReconcilePods_RecreateUpgradeDeletesNativeSchedulingResources(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, false) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + recreate := rayv1.RayClusterRecreate + cluster.Spec.UpgradeStrategy = &rayv1.RayClusterUpgradeStrategy{Type: &recreate} + + // Compute the "stale" hash (a different value than what the current spec produces). + staleHash := "stale-hash-value" + + s := newTestScheme() + + // Pre-create a head pod with a stale hash to trigger shouldRecreatePodsForUpgrade. + headPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head-pod", + Namespace: "default", + Labels: map[string]string{ + utils.RayClusterLabelKey: "test-cluster", + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + }, + Annotations: map[string]string{ + utils.UpgradeStrategyRecreateHashKey: staleHash, + utils.KubeRayVersion: utils.KUBERAY_VERSION, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{Name: "ray-head", Image: "rayproject/ray:latest"}}, + }, + } + + // Pre-create Workload and PodGroups that should be deleted on recreate. + workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + workerPG := &schedulingv1alpha2.PodGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-worker-workers", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s). 
+ WithObjects(headPod, workload, workerPG).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcilePods(ctx, cluster) + require.NoError(t, err) + + // Verify Workload is deleted. + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, &schedulingv1alpha2.Workload{}) + assert.True(t, apierrors.IsNotFound(err), "Workload should be deleted on recreate upgrade") + + // Verify PodGroups are deleted. + pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, client.InNamespace("default"), client.MatchingLabels{utils.RayClusterLabelKey: "test-cluster"}) + require.NoError(t, err) + assert.Empty(t, pgList.Items, "PodGroups should be deleted on recreate upgrade") +} + +func TestReconcilePods_ResumeRecreatesNativeSchedulingResources(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, false) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + // Cluster is NOT suspended (simulating resume after prior suspension deleted resources). + + s := newTestScheme() + + // No pre-existing Workload or PodGroups — they were deleted during suspend. + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + fakeRecorder := record.NewFakeRecorder(20) + r := newReconciler(fakeClient, s, fakeRecorder) + ctx := context.Background() + + err := r.reconcilePods(ctx, cluster) + require.NoError(t, err) + + // Verify Workload was recreated. 
+ workload := &schedulingv1alpha2.Workload{} + err = fakeClient.Get(ctx, types.NamespacedName{Name: "test-cluster", Namespace: "default"}, workload) + require.NoError(t, err, "Workload should be created after resume") + assert.Len(t, workload.Spec.PodGroupTemplates, 2) + assert.Equal(t, "head", workload.Spec.PodGroupTemplates[0].Name) + assert.Equal(t, "worker-workers", workload.Spec.PodGroupTemplates[1].Name) + require.NotNil(t, workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang) + assert.Equal(t, int32(3), workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount) + + // Verify PodGroups were recreated. + pgList := &schedulingv1alpha2.PodGroupList{} + err = fakeClient.List(ctx, pgList, client.InNamespace("default"), client.MatchingLabels{utils.RayClusterLabelKey: "test-cluster"}) + require.NoError(t, err) + assert.Len(t, pgList.Items, 2, "PodGroups should be recreated after resume") +} + +// --- WorkloadScheduled condition tests --- + +func TestSetWorkloadScheduledCondition_TrueWhenWorkloadExists(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // Pre-create a Workload so the condition check finds it. 
+ workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + fakeClient := clientFake.NewClientBuilder().WithScheme(s).WithObjects(workload).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + r.setWorkloadScheduledCondition(ctx, cluster, "") + + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + require.NotNil(t, cond, "WorkloadScheduled condition should be set") + assert.Equal(t, metav1.ConditionTrue, cond.Status) + assert.Equal(t, rayv1.WorkloadReady, cond.Reason) +} + +func TestSetWorkloadScheduledCondition_FalseWhenWorkloadDoesNotExist(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + + // No pre-existing Workload. 
+ fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + r.setWorkloadScheduledCondition(ctx, cluster, "") + + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + require.NotNil(t, cond, "WorkloadScheduled condition should be set") + assert.Equal(t, metav1.ConditionFalse, cond.Status) + assert.Equal(t, rayv1.WorkloadPending, cond.Reason) +} + +func TestSetWorkloadScheduledCondition_NotSetWhenFeatureGateDisabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, false) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + r.setWorkloadScheduledCondition(ctx, cluster, "") + + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + assert.Nil(t, cond, "WorkloadScheduled condition should not be set when feature gate is disabled") +} + +func TestSetWorkloadScheduledCondition_NotSetWhenAnnotationMissing(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + cluster.Annotations = nil // Remove the opt-in annotation. 
+ s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + r.setWorkloadScheduledCondition(ctx, cluster, "") + + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + assert.Nil(t, cond, "WorkloadScheduled condition should not be set when annotation is missing") +} + +func TestSetWorkloadScheduledCondition_RemovedWhenSuspended(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + // Pre-set the condition to verify it gets removed. + meta.SetStatusCondition(&cluster.Status.Conditions, metav1.Condition{ + Type: string(rayv1.RayClusterWorkloadScheduled), + Status: metav1.ConditionTrue, + Reason: rayv1.WorkloadReady, + Message: "Workload and PodGroups have been created", + }) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + r.setWorkloadScheduledCondition(ctx, cluster, rayv1.RayClusterSuspended) + + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + assert.Nil(t, cond, "WorkloadScheduled condition should be removed when cluster is suspended") +} + +func TestSetWorkloadScheduledCondition_RemovedWhenSuspending(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 3)) + // Pre-set the condition to verify it gets removed. 
+ meta.SetStatusCondition(&cluster.Status.Conditions, metav1.Condition{ + Type: string(rayv1.RayClusterWorkloadScheduled), + Status: metav1.ConditionTrue, + Reason: rayv1.WorkloadReady, + Message: "Workload and PodGroups have been created", + }) + s := newTestScheme() + fakeClient := clientFake.NewClientBuilder().WithScheme(s).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + r.setWorkloadScheduledCondition(ctx, cluster, rayv1.RayClusterSuspending) + + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + assert.Nil(t, cond, "WorkloadScheduled condition should be removed when cluster is suspending") +} + +// --- calculateStatus integration tests --- + +func TestCalculateStatus_WorkloadScheduledConditionSetWhenBothGatesEnabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, true) + + cluster := newTestRayCluster(newWorkerGroup("workers", 1)) + s := newTestScheme() + + // Create a Workload so the condition check finds it. + workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + // Create a head service so calculateStatus's updateEndpoints/updateHeadInfo succeed. 
+ headService := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head-svc", + Namespace: "default", + Labels: map[string]string{ + utils.RayClusterLabelKey: "test-cluster", + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + utils.RayIDLabelKey: utils.CheckLabel(utils.GenerateIdentifier("test-cluster", rayv1.HeadNode)), + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.0.0.1", + }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s).WithObjects(workload, headService).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + newInstance, err := r.calculateStatus(ctx, cluster, nil) + require.NoError(t, err) + + cond := meta.FindStatusCondition(newInstance.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + require.NotNil(t, cond, "WorkloadScheduled condition should be set through calculateStatus") + assert.Equal(t, metav1.ConditionTrue, cond.Status) + assert.Equal(t, rayv1.WorkloadReady, cond.Reason) +} + +func TestCalculateStatus_WorkloadScheduledConditionNotSetWhenStatusConditionsGateDisabled(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, true) + features.SetFeatureGateDuringTest(t, features.RayClusterStatusConditions, false) + + cluster := newTestRayCluster(newWorkerGroup("workers", 1)) + s := newTestScheme() + + workload := &schedulingv1alpha2.Workload{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster", + Namespace: "default", + Labels: map[string]string{utils.RayClusterLabelKey: "test-cluster"}, + }, + } + + headService := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cluster-head-svc", + Namespace: "default", + Labels: map[string]string{ + utils.RayClusterLabelKey: "test-cluster", + utils.RayNodeTypeLabelKey: string(rayv1.HeadNode), + utils.RayIDLabelKey: utils.CheckLabel(utils.GenerateIdentifier("test-cluster", rayv1.HeadNode)), + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.0.0.1", 
+ }, + } + + fakeClient := clientFake.NewClientBuilder().WithScheme(s).WithObjects(workload, headService).Build() + r := newReconciler(fakeClient, s, record.NewFakeRecorder(10)) + ctx := context.Background() + + newInstance, err := r.calculateStatus(ctx, cluster, nil) + require.NoError(t, err) + + cond := meta.FindStatusCondition(newInstance.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + assert.Nil(t, cond, "WorkloadScheduled condition should not be set when RayClusterStatusConditions gate is disabled") +} diff --git a/ray-operator/controllers/ray/raycluster_controller.go b/ray-operator/controllers/ray/raycluster_controller.go index 033a2395894..dbf4d85fb21 100644 --- a/ray-operator/controllers/ray/raycluster_controller.go +++ b/ray-operator/controllers/ray/raycluster_controller.go @@ -20,6 +20,7 @@ import ( corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" rbacv1 "k8s.io/api/rbac/v1" + schedulingv1alpha2 "k8s.io/api/scheduling/v1alpha2" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/resource" @@ -101,6 +102,7 @@ type RayClusterReconcilerOptions struct { // +kubebuilder:rbac:groups=core,resources=serviceaccounts,verbs=get;list;watch;create;delete // +kubebuilder:rbac:groups="rbac.authorization.k8s.io",resources=roles,verbs=get;list;watch;create;delete;update // +kubebuilder:rbac:groups="rbac.authorization.k8s.io",resources=rolebindings,verbs=get;list;watch;create;delete +// +kubebuilder:rbac:groups=scheduling.k8s.io,resources=podgroups;workloads,verbs=get;list;watch;create;update;patch;delete // [WARNING]: There MUST be a newline after kubebuilder markers. 
@@ -289,7 +291,14 @@ func (r *RayClusterReconciler) rayClusterReconcile(ctx context.Context, instance } if instance.DeletionTimestamp != nil && !instance.DeletionTimestamp.IsZero() { - logger.Info("RayCluster is being deleted, just ignore") + logger.Info("RayCluster is being deleted, cleaning up native scheduling resources") + // Clean up Workload/PodGroups explicitly rather than relying on garbage collection + // because PodGroups may have a scheduler-added finalizer that blocks GC deletion. + if isNativeWorkloadSchedulingEnabled(instance) { + if err := r.deleteNativeWorkloadSchedulingResources(ctx, instance); err != nil { + return ctrl.Result{RequeueAfter: DefaultRequeueDuration}, err + } + } return ctrl.Result{}, nil } @@ -630,6 +639,11 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv statusConditionGateEnabled := features.Enabled(features.RayClusterStatusConditions) if suspendStatus == rayv1.RayClusterSuspending || (!statusConditionGateEnabled && instance.Spec.Suspend != nil && *instance.Spec.Suspend) { + if isNativeWorkloadSchedulingEnabled(instance) { + if err := r.deleteNativeWorkloadSchedulingResources(ctx, instance); err != nil { + return err + } + } if _, err := r.deleteAllPods(ctx, common.RayClusterAllPodsAssociationOptions(instance)); err != nil { r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(utils.FailedToDeletePodCollection), "Failed deleting Pods due to suspension for RayCluster %s/%s, %v", @@ -656,6 +670,11 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv // Check if pods need to be recreated with Recreate upgradeStrategy if r.shouldRecreatePodsForUpgrade(ctx, instance) { logger.Info("RayCluster spec changed with Recreate upgradeStrategy, deleting all pods") + if isNativeWorkloadSchedulingEnabled(instance) { + if err := r.deleteNativeWorkloadSchedulingResources(ctx, instance); err != nil { + return err + } + } if _, err := r.deleteAllPods(ctx, 
common.RayClusterAllPodsAssociationOptions(instance)); err != nil { r.Recorder.Eventf(instance, corev1.EventTypeWarning, string(utils.FailedToDeletePodCollection), "Failed deleting Pods due to spec change with Recreate upgradeStrategy for RayCluster %s/%s, %v", @@ -685,6 +704,12 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv return err } } + + // Create Workload and PodGroup resources for native workload scheduling + if err := r.reconcileNativeWorkloadScheduling(ctx, instance); err != nil { + return err + } + // Reconcile head Pod if !r.rayClusterScaleExpectation.IsSatisfied(ctx, instance.Namespace, instance.Name, expectations.HeadGroup) { logger.Info("reconcilePods", "Expectation", "NotSatisfiedHeadExpectations, reconcile head later") @@ -1326,6 +1351,11 @@ func (r *RayClusterReconciler) createHeadPod(ctx context.Context, instance rayv1 } } + // Native workload scheduling: set schedulingGroup on head pod. + if r.shouldSetSchedulingGroup(&instance) { + setSchedulingGroup(&pod, podGroupName(instance.Name, "head")) + } + if err := r.Create(ctx, &pod); err != nil { r.Recorder.Eventf(&instance, corev1.EventTypeWarning, string(utils.FailedToCreateHeadPod), "Failed to create head Pod %s/%s, %v", pod.Namespace, pod.Name, err) return err @@ -1349,6 +1379,11 @@ func (r *RayClusterReconciler) createWorkerPod(ctx context.Context, instance ray } } + // Native workload scheduling: set schedulingGroup on worker pod. 
+ if r.shouldSetSchedulingGroup(&instance) { + setSchedulingGroup(&pod, podGroupName(instance.Name, "worker-"+worker.GroupName)) + } + replica := pod if err := r.Create(ctx, &replica); err != nil { r.Recorder.Eventf(&instance, corev1.EventTypeWarning, string(utils.FailedToCreateWorkerPod), "Failed to create worker Pod for the cluster %s/%s, %v", instance.Namespace, instance.Name, err) @@ -1373,6 +1408,11 @@ func (r *RayClusterReconciler) createWorkerPodWithIndex(ctx context.Context, ins } } + // Native workload scheduling: set schedulingGroup on worker pod. + if r.shouldSetSchedulingGroup(&instance) { + setSchedulingGroup(&pod, podGroupName(instance.Name, "worker-"+worker.GroupName)) + } + replica := pod if err := r.Create(ctx, &replica); err != nil { r.Recorder.Eventf(&instance, corev1.EventTypeWarning, string(utils.FailedToCreateWorkerPod), "Failed to create worker Pod for the cluster %s/%s, %v", instance.Namespace, instance.Name, err) @@ -1535,6 +1575,11 @@ func (r *RayClusterReconciler) SetupWithManager(mgr ctrl.Manager, reconcileConcu r.options.BatchSchedulerManager.ConfigureReconciler(b) } + if features.Enabled(features.NativeWorkloadScheduling) { + b = b.Owns(&schedulingv1alpha2.Workload{}). + Owns(&schedulingv1alpha2.PodGroup{}) + } + return b. 
WithOptions(controller.Options{ MaxConcurrentReconciles: reconcileConcurrency, @@ -1691,6 +1736,8 @@ func (r *RayClusterReconciler) calculateStatus(ctx context.Context, instance *ra }) } } + + r.setWorkloadScheduledCondition(ctx, newInstance, suspendStatus) } if newInstance.Spec.Suspend != nil && *newInstance.Spec.Suspend && len(runtimePods.Items) == 0 { diff --git a/ray-operator/main.go b/ray-operator/main.go index f11f9d504ef..c2a64dc7adc 100644 --- a/ray-operator/main.go +++ b/ray-operator/main.go @@ -12,11 +12,14 @@ import ( "go.uber.org/zap/zapcore" "gopkg.in/natefinch/lumberjack.v2" batchv1 "k8s.io/api/batch/v1" + schedulingv1alpha2 "k8s.io/api/scheduling/v1alpha2" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer" utilruntime "k8s.io/apimachinery/pkg/util/runtime" utilfeature "k8s.io/apiserver/pkg/util/feature" + "k8s.io/client-go/discovery" clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/rest" "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -48,6 +51,7 @@ func init() { utilruntime.Must(rayv1.AddToScheme(scheme)) utilruntime.Must(routev1.Install(scheme)) utilruntime.Must(batchv1.AddToScheme(scheme)) + utilruntime.Must(schedulingv1alpha2.AddToScheme(scheme)) utilruntime.Must(configapi.AddToScheme(scheme)) // +kubebuilder:scaffold:scheme } @@ -190,6 +194,10 @@ func main() { } features.LogFeatureGates(setupLog) + if err := validateNativeWorkloadSchedulingConfig(config); err != nil { + exitOnError(err, "startup validation failed") + } + if features.Enabled(features.RayServiceIncrementalUpgrade) { utilruntime.Must(gwv1.Install(scheme)) } @@ -240,6 +248,14 @@ func main() { restConfig.UserAgent = userAgent restConfig.QPS = float32(*config.QPS) restConfig.Burst = *config.Burst + + if features.Enabled(features.NativeWorkloadScheduling) { + if err := checkSchedulingV1alpha2Available(restConfig); err != nil { + exitOnError(err, "NativeWorkloadScheduling feature gate 
is enabled but scheduling.k8s.io/v1alpha2 API is not available. "+ + "Ensure Kubernetes 1.36+ with GenericWorkload feature gate enabled.") + } + } + mgr, err := ctrl.NewManager(restConfig, options) exitOnError(err, "unable to start manager") @@ -321,6 +337,30 @@ func exitOnError(err error, msg string, keysAndValues ...any) { } } +// validateNativeWorkloadSchedulingConfig checks that the NativeWorkloadScheduling feature gate +// is not enabled alongside a batch scheduler configuration, since the two are mutually exclusive. +func validateNativeWorkloadSchedulingConfig(config configapi.Configuration) error { + if features.Enabled(features.NativeWorkloadScheduling) && (config.EnableBatchScheduler || len(config.BatchScheduler) > 0) { + return fmt.Errorf("NativeWorkloadScheduling feature gate and batchScheduler configuration are mutually exclusive") + } + return nil +} + +// checkSchedulingV1alpha2Available uses the discovery API to verify that scheduling.k8s.io/v1alpha2 +// is served by the API server. This confirms the cluster is running Kubernetes 1.36+ with the +// GenericWorkload feature gate enabled. +func checkSchedulingV1alpha2Available(restConfig *rest.Config) error { + discoveryClient, err := discovery.NewDiscoveryClientForConfig(restConfig) + if err != nil { + return fmt.Errorf("failed to create discovery client: %w", err) + } + _, err = discoveryClient.ServerResourcesForGroupVersion("scheduling.k8s.io/v1alpha2") + if err != nil { + return fmt.Errorf("scheduling.k8s.io/v1alpha2 API is not available: %w", err) + } + return nil +} + // decodeConfig decodes raw config data and returns the Configuration type. 
func decodeConfig(configData []byte, scheme *runtime.Scheme) (configapi.Configuration, error) { cfg := configapi.Configuration{} diff --git a/ray-operator/main_test.go b/ray-operator/main_test.go index 5859654c2cc..24b8272c6db 100644 --- a/ray-operator/main_test.go +++ b/ray-operator/main_test.go @@ -1,15 +1,20 @@ package main import ( + "encoding/json" + "net/http" + "net/http/httptest" "reflect" "strings" "testing" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/rest" "k8s.io/utils/ptr" configapi "github.com/ray-project/kuberay/ray-operator/apis/config/v1alpha1" + "github.com/ray-project/kuberay/ray-operator/pkg/features" ) func Test_decodeConfig(t *testing.T) { @@ -256,3 +261,234 @@ reconcileConcurrency: 100 }) } } + +func Test_validateNativeWorkloadSchedulingConfig(t *testing.T) { + tests := []struct { + name string + featureGateEnabled bool + enableBatchScheduler bool + batchScheduler string + wantErr bool + errContains string + }{ + // Positive cases: no error expected + { + name: "all disabled — no conflict", + featureGateEnabled: false, + enableBatchScheduler: false, + batchScheduler: "", + wantErr: false, + }, + { + name: "only NativeWorkloadScheduling enabled — no conflict", + featureGateEnabled: true, + enableBatchScheduler: false, + batchScheduler: "", + wantErr: false, + }, + { + name: "only EnableBatchScheduler enabled — no conflict", + featureGateEnabled: false, + enableBatchScheduler: true, + batchScheduler: "", + wantErr: false, + }, + { + name: "only BatchScheduler set — no conflict", + featureGateEnabled: false, + enableBatchScheduler: false, + batchScheduler: "volcano", + wantErr: false, + }, + { + name: "gate off with both batch scheduler options — no conflict", + featureGateEnabled: false, + enableBatchScheduler: true, + batchScheduler: "volcano", + wantErr: false, + }, + // Negative cases: error expected (mutually exclusive) + { + name: "NativeWorkloadScheduling + EnableBatchScheduler — mutually 
exclusive", + featureGateEnabled: true, + enableBatchScheduler: true, + batchScheduler: "", + wantErr: true, + errContains: "mutually exclusive", + }, + { + name: "NativeWorkloadScheduling + BatchScheduler=volcano — mutually exclusive", + featureGateEnabled: true, + enableBatchScheduler: false, + batchScheduler: "volcano", + wantErr: true, + errContains: "mutually exclusive", + }, + { + name: "NativeWorkloadScheduling + BatchScheduler=yunikorn — mutually exclusive", + featureGateEnabled: true, + enableBatchScheduler: false, + batchScheduler: "yunikorn", + wantErr: true, + errContains: "mutually exclusive", + }, + { + name: "NativeWorkloadScheduling + BatchScheduler=kai-scheduler — mutually exclusive", + featureGateEnabled: true, + enableBatchScheduler: false, + batchScheduler: "kai-scheduler", + wantErr: true, + errContains: "mutually exclusive", + }, + { + name: "NativeWorkloadScheduling + EnableBatchScheduler + BatchScheduler — mutually exclusive", + featureGateEnabled: true, + enableBatchScheduler: true, + batchScheduler: "volcano", + wantErr: true, + errContains: "mutually exclusive", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + features.SetFeatureGateDuringTest(t, features.NativeWorkloadScheduling, tt.featureGateEnabled) + + config := configapi.Configuration{ + EnableBatchScheduler: tt.enableBatchScheduler, + BatchScheduler: tt.batchScheduler, + } + + err := validateNativeWorkloadSchedulingConfig(config) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("expected error containing %q, got: %v", tt.errContains, err) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func Test_checkSchedulingV1alpha2Available(t *testing.T) { + tests := []struct { + name string + handler http.HandlerFunc + wantErr bool + errContains string + }{ + // Positive: API is available + { + name: "API available 
— returns resource list", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/apis/scheduling.k8s.io/v1alpha2" { + w.Header().Set("Content-Type", "application/json") + resourceList := metav1.APIResourceList{ + GroupVersion: "scheduling.k8s.io/v1alpha2", + APIResources: []metav1.APIResource{ + {Name: "workloads", Kind: "Workload", Namespaced: true}, + {Name: "podgroups", Kind: "PodGroup", Namespaced: true}, + }, + } + _ = json.NewEncoder(w).Encode(resourceList) + return + } + http.NotFound(w, r) + }, + wantErr: false, + }, + { + name: "API available — empty resource list", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/apis/scheduling.k8s.io/v1alpha2" { + w.Header().Set("Content-Type", "application/json") + resourceList := metav1.APIResourceList{ + GroupVersion: "scheduling.k8s.io/v1alpha2", + } + _ = json.NewEncoder(w).Encode(resourceList) + return + } + http.NotFound(w, r) + }, + wantErr: false, + }, + // Negative: API not available + { + name: "API not available — 404", + handler: func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }, + wantErr: true, + errContains: "scheduling.k8s.io/v1alpha2 API is not available", + }, + { + name: "API not available — 500 server error", + handler: func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "internal server error", http.StatusInternalServerError) + }, + wantErr: true, + errContains: "scheduling.k8s.io/v1alpha2 API is not available", + }, + { + name: "API not available — different group version exists but not v1alpha2", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/apis/scheduling.k8s.io/v1" { + w.Header().Set("Content-Type", "application/json") + resourceList := metav1.APIResourceList{ + GroupVersion: "scheduling.k8s.io/v1", + } + _ = json.NewEncoder(w).Encode(resourceList) + return + } + http.NotFound(w, r) + }, + wantErr: true, + errContains: "scheduling.k8s.io/v1alpha2 API is not available", + }, 
+ } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(tt.handler) + defer server.Close() + + restConfig := &rest.Config{ + Host: server.URL, + } + + err := checkSchedulingV1alpha2Available(restConfig) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("expected error containing %q, got: %v", tt.errContains, err) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func Test_checkSchedulingV1alpha2Available_unreachableServer(t *testing.T) { + restConfig := &rest.Config{ + Host: "http://127.0.0.1:1", // port 1 is almost certainly not listening + } + + err := checkSchedulingV1alpha2Available(restConfig) + if err == nil { + t.Fatalf("expected error for unreachable server but got nil") + } + if !strings.Contains(err.Error(), "scheduling.k8s.io/v1alpha2 API is not available") { + t.Errorf("expected error about API not available, got: %v", err) + } +} diff --git a/ray-operator/pkg/features/features.go b/ray-operator/pkg/features/features.go index e0cab5b41d6..6bbd1d70ff6 100644 --- a/ray-operator/pkg/features/features.go +++ b/ray-operator/pkg/features/features.go @@ -47,6 +47,13 @@ const ( // // Enables RayCronJob controller for scheduled RayJob execution. RayCronJob featuregate.Feature = "RayCronJob" + + // owner: @marosset + // rep: N/A + // alpha: v1.6 + // + // Enables native Kubernetes gang scheduling via scheduling.k8s.io/v1alpha2 Workload and PodGroup APIs. 
+ NativeWorkloadScheduling featuregate.Feature = "NativeWorkloadScheduling" ) func init() { @@ -59,6 +66,7 @@ var defaultFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{ RayMultiHostIndexing: {Default: true, PreRelease: featuregate.Beta}, RayServiceIncrementalUpgrade: {Default: false, PreRelease: featuregate.Alpha}, RayCronJob: {Default: false, PreRelease: featuregate.Alpha}, + NativeWorkloadScheduling: {Default: false, PreRelease: featuregate.Alpha}, } // SetFeatureGateDuringTest is a helper method to override feature gates in tests. diff --git a/ray-operator/test/e2enativescheduling/raycluster_nativescheduling_test.go b/ray-operator/test/e2enativescheduling/raycluster_nativescheduling_test.go new file mode 100644 index 00000000000..c6865cf88c3 --- /dev/null +++ b/ray-operator/test/e2enativescheduling/raycluster_nativescheduling_test.go @@ -0,0 +1,853 @@ +package e2enativescheduling + +import ( + "testing" + "time" + + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + schedulingv1alpha2 "k8s.io/api/scheduling/v1alpha2" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" + rayv1ac "github.com/ray-project/kuberay/ray-operator/pkg/client/applyconfiguration/ray/v1" + . "github.com/ray-project/kuberay/ray-operator/test/support" +) + +func TestNativeScheduling_CreatesWorkloadAndPodGroups(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("native-sched", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). 
+ WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify Workload exists with correct spec. + LogWithTimestamp(test.T(), "Verifying Workload %s/%s exists", namespace.Name, rayCluster.Name) + workload, err := GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + + g.Expect(workload.Spec.ControllerRef).NotTo(BeNil()) + g.Expect(workload.Spec.ControllerRef.APIGroup).To(Equal("ray.io")) + g.Expect(workload.Spec.ControllerRef.Kind).To(Equal("RayCluster")) + g.Expect(workload.Spec.ControllerRef.Name).To(Equal(rayCluster.Name)) + + // Workload should have 2 PodGroupTemplates: head + 1 worker group. + g.Expect(workload.Spec.PodGroupTemplates).To(HaveLen(2)) + g.Expect(workload.Spec.PodGroupTemplates[0].Name).To(Equal("head")) + g.Expect(workload.Spec.PodGroupTemplates[0].SchedulingPolicy.Basic).NotTo(BeNil()) + g.Expect(workload.Spec.PodGroupTemplates[1].Name).To(Equal("worker-small-group")) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + + // Verify ownerReference points to the RayCluster. 
+ g.Expect(workload.OwnerReferences).To(HaveLen(1)) + g.Expect(workload.OwnerReferences[0].Kind).To(Equal("RayCluster")) + g.Expect(workload.OwnerReferences[0].Name).To(Equal(rayCluster.Name)) + g.Expect(workload.OwnerReferences[0].Controller).To(HaveValue(BeTrue())) + + // Verify labels on Workload. + g.Expect(workload.Labels[utils.RayClusterLabelKey]).To(Equal(rayCluster.Name)) + + // Verify PodGroups exist. + LogWithTimestamp(test.T(), "Verifying PodGroups exist") + headPG, err := GetPodGroup(test, namespace.Name, rayCluster.Name+"-head") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(headPG.Spec.PodGroupTemplateRef).NotTo(BeNil()) + g.Expect(headPG.Spec.PodGroupTemplateRef.Workload).NotTo(BeNil()) + g.Expect(headPG.Spec.PodGroupTemplateRef.Workload.WorkloadName).To(Equal(rayCluster.Name)) + g.Expect(headPG.Spec.PodGroupTemplateRef.Workload.PodGroupTemplateName).To(Equal("head")) + g.Expect(headPG.Spec.SchedulingPolicy.Basic).NotTo(BeNil()) + + // Verify PodGroup ownerReference and labels. + g.Expect(headPG.OwnerReferences).To(HaveLen(1)) + g.Expect(headPG.OwnerReferences[0].Kind).To(Equal("RayCluster")) + g.Expect(headPG.OwnerReferences[0].Name).To(Equal(rayCluster.Name)) + g.Expect(headPG.OwnerReferences[0].Controller).To(HaveValue(BeTrue())) + g.Expect(headPG.Labels[utils.RayClusterLabelKey]).To(Equal(rayCluster.Name)) + + workerPG, err := GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(workerPG.Spec.PodGroupTemplateRef).NotTo(BeNil()) + g.Expect(workerPG.Spec.PodGroupTemplateRef.Workload).NotTo(BeNil()) + g.Expect(workerPG.Spec.PodGroupTemplateRef.Workload.WorkloadName).To(Equal(rayCluster.Name)) + g.Expect(workerPG.Spec.PodGroupTemplateRef.Workload.PodGroupTemplateName).To(Equal("worker-small-group")) + g.Expect(workerPG.Spec.SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(workerPG.Spec.SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + g.Expect(workerPG.OwnerReferences).To(HaveLen(1)) + 
g.Expect(workerPG.OwnerReferences[0].Kind).To(Equal("RayCluster")) + g.Expect(workerPG.OwnerReferences[0].Name).To(Equal(rayCluster.Name)) + g.Expect(workerPG.Labels[utils.RayClusterLabelKey]).To(Equal(rayCluster.Name)) + + // Verify the scheduler processed the PodGroups by checking PodGroupScheduled condition. + LogWithTimestamp(test.T(), "Verifying PodGroupScheduled condition on PodGroups") + g.Eventually(PodGroup(test, namespace.Name, rayCluster.Name+"-head"), TestTimeoutShort). + Should(WithTransform(func(pg *schedulingv1alpha2.PodGroup) bool { + return meta.IsStatusConditionTrue(pg.Status.Conditions, schedulingv1alpha2.PodGroupScheduled) + }, BeTrue())) + g.Eventually(PodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group"), TestTimeoutShort). + Should(WithTransform(func(pg *schedulingv1alpha2.PodGroup) bool { + return meta.IsStatusConditionTrue(pg.Status.Conditions, schedulingv1alpha2.PodGroupScheduled) + }, BeTrue())) +} + +func TestNativeScheduling_PodSchedulingGroup(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("sched-group", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify head pod has schedulingGroup set. 
+ headPod, err := GetHeadPod(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(headPod.Spec.SchedulingGroup).NotTo(BeNil()) + g.Expect(headPod.Spec.SchedulingGroup.PodGroupName).NotTo(BeNil()) + g.Expect(*headPod.Spec.SchedulingGroup.PodGroupName).To(Equal(rayCluster.Name + "-head")) + + // Verify worker pods have schedulingGroup set. + workerPods, err := GetWorkerPods(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(workerPods).NotTo(BeEmpty()) + for _, pod := range workerPods { + g.Expect(pod.Spec.SchedulingGroup).NotTo(BeNil()) + g.Expect(pod.Spec.SchedulingGroup.PodGroupName).NotTo(BeNil()) + g.Expect(*pod.Spec.SchedulingGroup.PodGroupName).To(Equal(rayCluster.Name + "-worker-small-group")) + } +} + +func TestNativeScheduling_MultipleWorkerGroups(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("multi-wg", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(rayv1ac.RayClusterSpec(). + WithRayVersion(GetRayVersion()). + WithHeadGroupSpec(rayv1ac.HeadGroupSpec(). + WithRayStartParams(map[string]string{"dashboard-host": "0.0.0.0"}). + WithTemplate(HeadPodTemplateApplyConfiguration())). + WithWorkerGroupSpecs( + rayv1ac.WorkerGroupSpec(). + WithReplicas(1). + WithMinReplicas(1). + WithMaxReplicas(1). + WithGroupName("group-a"). + WithRayStartParams(map[string]string{"num-cpus": "1"}). + WithTemplate(WorkerPodTemplateApplyConfiguration()), + rayv1ac.WorkerGroupSpec(). + WithReplicas(2). + WithMinReplicas(2). + WithMaxReplicas(2). + WithGroupName("group-b"). + WithRayStartParams(map[string]string{"num-cpus": "1"}). 
+ WithTemplate(WorkerPodTemplateApplyConfiguration()), + )) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s with 2 worker groups", rayCluster.Namespace, rayCluster.Name) + + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify Workload has 3 PodGroupTemplates: head + 2 worker groups. + workload, err := GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(workload.Spec.PodGroupTemplates).To(HaveLen(3)) + g.Expect(workload.Spec.PodGroupTemplates[0].Name).To(Equal("head")) + g.Expect(workload.Spec.PodGroupTemplates[1].Name).To(Equal("worker-group-a")) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + g.Expect(workload.Spec.PodGroupTemplates[2].Name).To(Equal("worker-group-b")) + g.Expect(workload.Spec.PodGroupTemplates[2].SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(workload.Spec.PodGroupTemplates[2].SchedulingPolicy.Gang.MinCount).To(Equal(int32(2))) + + // Verify 3 PodGroups exist (head + group-a + group-b). 
+ g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(3)) + + pgGroupA, err := GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-group-a") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(pgGroupA.Spec.SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(pgGroupA.Spec.SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + + pgGroupB, err := GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-group-b") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(pgGroupB.Spec.SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(pgGroupB.Spec.SchedulingPolicy.Gang.MinCount).To(Equal(int32(2))) + + // Verify pods in each worker group reference the correct PodGroup. + LogWithTimestamp(test.T(), "Verifying per-group pod schedulingGroup references") + allWorkerPods, err := GetWorkerPods(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + for _, pod := range allWorkerPods { + group := pod.Labels[utils.RayNodeGroupLabelKey] + g.Expect(pod.Spec.SchedulingGroup).NotTo(BeNil(), "pod %s missing schedulingGroup", pod.Name) + g.Expect(pod.Spec.SchedulingGroup.PodGroupName).NotTo(BeNil(), "pod %s missing podGroupName", pod.Name) + g.Expect(*pod.Spec.SchedulingGroup.PodGroupName).To(Equal(rayCluster.Name+"-worker-"+group), + "pod %s has wrong podGroupName", pod.Name) + } +} + +func TestNativeScheduling_Events(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("sched-events", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). 
+ Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify CreatedWorkload event was emitted. + g.Eventually(GetEvents(test, namespace.Name, rayCluster.Name, "CreatedWorkload"), TestTimeoutShort). + ShouldNot(BeEmpty()) + + // Verify CreatedPodGroup events were emitted (head + worker). + g.Eventually(GetEvents(test, namespace.Name, rayCluster.Name, "CreatedPodGroup"), TestTimeoutShort). + Should(HaveLen(2)) +} + +func TestNativeScheduling_NoAnnotation(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + // Create a RayCluster without the native scheduling annotation. + rayClusterAC := rayv1ac.RayCluster("no-annotation", namespace.Name). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s without native scheduling annotation", rayCluster.Namespace, rayCluster.Name) + + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify no Workload was created. + _, err = GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(errors.IsNotFound(err)).To(BeTrue(), "expected NotFound for Workload, got: %v", err) + + // Verify no PodGroups were created. + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(BeEmpty()) + + // Verify head pod does not have schedulingGroup set. 
+ headPod, err := GetHeadPod(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(headPod.Spec.SchedulingGroup).To(BeNil()) +} + +func TestNativeScheduling_AutoscalingSkipped(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + // Create a RayCluster with autoscaling + native scheduling annotation. + rayClusterAC := rayv1ac.RayCluster("autoscale-skip", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec().WithEnableInTreeAutoscaling(true)) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s with autoscaling + native scheduling", rayCluster.Namespace, rayCluster.Name) + + // Wait for the cluster to start reconciling (HeadPodReady condition appears). + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to start reconciling", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(StatusCondition(rayv1.HeadPodReady), MatchCondition(metav1.ConditionTrue, rayv1.HeadPodRunningAndReady))) + + // Verify WorkloadSchedulingSkipped warning event was emitted. + g.Eventually(GetEvents(test, namespace.Name, rayCluster.Name, "WorkloadSchedulingSkipped"), TestTimeoutShort). + ShouldNot(BeEmpty()) + + // Verify no Workload was created. + _, err = GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(errors.IsNotFound(err)).To(BeTrue(), "expected NotFound for Workload, got: %v", err) + + // Verify head pod does not have schedulingGroup set (autoscaling skipped native scheduling). 
+ headPod, err := GetHeadPod(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(headPod.Spec.SchedulingGroup).To(BeNil(), "head pod should not have schedulingGroup when autoscaling is enabled") +} + +func TestNativeScheduling_GangSchedules(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("gang-sched", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for the cluster to become ready — this validates that the scheduler + // processes the gang and all pods in the gang become Running. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready (gang scheduling)", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify all pods are Running. + allPods, err := GetAllPods(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(allPods).NotTo(BeEmpty()) + for _, pod := range allPods { + g.Expect(pod.Status.Phase).To(Equal(corev1.PodRunning)) + } + + // Verify Workload and PodGroups were created. + _, err = GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + + // Verify the scheduler's gang plugin processed the PodGroups. + // PodGroupScheduled=True confirms the gang constraint was evaluated and satisfied, + // not just that pods happened to schedule independently. 
+ LogWithTimestamp(test.T(), "Verifying PodGroupScheduled condition on worker PodGroup") + g.Eventually(PodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group"), TestTimeoutShort). + Should(WithTransform(func(pg *schedulingv1alpha2.PodGroup) bool { + return meta.IsStatusConditionTrue(pg.Status.Conditions, schedulingv1alpha2.PodGroupScheduled) + }, BeTrue())) +} + +func TestNativeScheduling_OwnerReferenceGC(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("gc-test", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready so Workload and PodGroups are created. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify Workload and PodGroups exist before deletion. + _, err = GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(2)) + + // Delete the RayCluster. + LogWithTimestamp(test.T(), "Deleting RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Delete(test.Ctx(), rayCluster.Name, metav1.DeleteOptions{}) + g.Expect(err).NotTo(HaveOccurred()) + + // Verify Workload is garbage collected. 
+ LogWithTimestamp(test.T(), "Waiting for Workload to be garbage collected") + g.Eventually(func() bool { + _, err := GetWorkload(test, namespace.Name, rayCluster.Name) + return errors.IsNotFound(err) + }, TestTimeoutShort).Should(BeTrue()) + + // Verify PodGroups are garbage collected. + LogWithTimestamp(test.T(), "Waiting for PodGroups to be garbage collected") + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(BeEmpty()) +} + +func TestNativeScheduling_Idempotent(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("idempotent", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify exactly 1 Workload and 2 PodGroups exist. + g.Eventually(Workloads(test, namespace.Name), TestTimeoutShort).Should(HaveLen(1)) + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(2)) + + // Wait and verify the count stays stable (no duplicates created by re-reconciliation). 
+ LogWithTimestamp(test.T(), "Verifying resource counts remain stable over time") + g.Consistently(Workloads(test, namespace.Name), 10*time.Second, time.Second).Should(HaveLen(1)) + g.Consistently(PodGroups(test, namespace.Name), 10*time.Second, time.Second).Should(HaveLen(2)) +} + +func TestNativeScheduling_SuspendDeletesResources(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("suspend-del", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready so Workload and PodGroups are created. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify Workload and PodGroups exist before suspend. + g.Eventually(Workloads(test, namespace.Name), TestTimeoutShort).Should(HaveLen(1)) + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(2)) + + // Verify WorkloadScheduled condition is True before suspend. + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutShort). + Should(WithTransform(StatusCondition(rayv1.RayClusterWorkloadScheduled), MatchCondition(metav1.ConditionTrue, rayv1.WorkloadReady))) + + // Suspend the cluster. 
+ LogWithTimestamp(test.T(), "Suspending RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + rayClusterAC = rayClusterAC.WithSpec(rayClusterAC.Spec.WithSuspend(true)) + rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + // Wait for the cluster to be suspended. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to be suspended", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(StatusCondition(rayv1.RayClusterSuspended), MatchCondition(metav1.ConditionTrue, string(rayv1.RayClusterSuspended)))) + + // Verify Workload is deleted after suspend. + LogWithTimestamp(test.T(), "Verifying Workload is deleted after suspend") + g.Eventually(func() bool { + _, err := GetWorkload(test, namespace.Name, rayCluster.Name) + return errors.IsNotFound(err) + }, TestTimeoutShort).Should(BeTrue()) + + // Verify PodGroups are deleted after suspend. + LogWithTimestamp(test.T(), "Verifying PodGroups are deleted after suspend") + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(BeEmpty()) + + // Verify WorkloadScheduled condition is removed after suspend. + LogWithTimestamp(test.T(), "Verifying WorkloadScheduled condition is removed after suspend") + g.Eventually(func(gg Gomega) { + cluster, err := GetRayCluster(test, namespace.Name, rayCluster.Name) + gg.Expect(err).NotTo(HaveOccurred()) + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + gg.Expect(cond).To(BeNil(), "WorkloadScheduled condition should be removed when suspended") + }, TestTimeoutShort).Should(Succeed()) + + // Verify DeletedWorkload and DeletedPodGroup events were emitted. + g.Eventually(GetEvents(test, namespace.Name, rayCluster.Name, "DeletedWorkload"), TestTimeoutShort). 
+ ShouldNot(BeEmpty()) + g.Eventually(GetEvents(test, namespace.Name, rayCluster.Name, "DeletedPodGroup"), TestTimeoutShort). + ShouldNot(BeEmpty()) +} + +func TestNativeScheduling_ResumeRecreatesResources(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("resume-rec", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + g.Eventually(Workloads(test, namespace.Name), TestTimeoutShort).Should(HaveLen(1)) + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(2)) + + // Capture the original Workload UID to verify it changes after resume. + workload, err := GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + originalWorkloadUID := workload.UID + + // Suspend the cluster. + LogWithTimestamp(test.T(), "Suspending RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + rayClusterAC = rayClusterAC.WithSpec(rayClusterAC.Spec.WithSuspend(true)) + _, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + // Wait for suspend to complete and resources to be deleted. + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). 
+ Should(WithTransform(StatusCondition(rayv1.RayClusterSuspended), MatchCondition(metav1.ConditionTrue, string(rayv1.RayClusterSuspended)))) + g.Eventually(func() bool { + _, err := GetWorkload(test, namespace.Name, rayCluster.Name) + return errors.IsNotFound(err) + }, TestTimeoutShort).Should(BeTrue()) + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(BeEmpty()) + + // Resume the cluster. + LogWithTimestamp(test.T(), "Resuming RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + rayClusterAC = rayClusterAC.WithSpec(rayClusterAC.Spec.WithSuspend(false)) + _, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + // Wait for cluster to become ready again. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready after resume", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify Workload is recreated with correct spec and a new UID (proving it was deleted and recreated). + LogWithTimestamp(test.T(), "Verifying Workload is recreated after resume") + g.Eventually(func(inner Gomega) { + w, err := GetWorkload(test, namespace.Name, rayCluster.Name) + inner.Expect(err).NotTo(HaveOccurred()) + inner.Expect(w.UID).NotTo(Equal(originalWorkloadUID), "Workload should have a new UID after resume") + inner.Expect(w.Spec.PodGroupTemplates).To(HaveLen(2)) + inner.Expect(w.Spec.PodGroupTemplates[0].Name).To(Equal("head")) + inner.Expect(w.Spec.PodGroupTemplates[1].Name).To(Equal("worker-small-group")) + inner.Expect(w.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + inner.Expect(w.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + }, TestTimeoutShort).Should(Succeed()) + + // Verify PodGroups are recreated. 
+ g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(2)) + + _, err = GetPodGroup(test, namespace.Name, rayCluster.Name+"-head") + g.Expect(err).NotTo(HaveOccurred()) + _, err = GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group") + g.Expect(err).NotTo(HaveOccurred()) + + // Verify WorkloadScheduled condition is True after resume. + LogWithTimestamp(test.T(), "Verifying WorkloadScheduled condition is True after resume") + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutShort). + Should(WithTransform(StatusCondition(rayv1.RayClusterWorkloadScheduled), MatchCondition(metav1.ConditionTrue, rayv1.WorkloadReady))) + + // Verify new pods get schedulingGroup set. + headPod, err := GetHeadPod(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(headPod.Spec.SchedulingGroup).NotTo(BeNil()) + g.Expect(headPod.Spec.SchedulingGroup.PodGroupName).NotTo(BeNil()) + g.Expect(*headPod.Spec.SchedulingGroup.PodGroupName).To(Equal(rayCluster.Name + "-head")) + + workerPods, err := GetWorkerPods(test, rayCluster) + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(workerPods).NotTo(BeEmpty()) + for _, pod := range workerPods { + g.Expect(pod.Spec.SchedulingGroup).NotTo(BeNil()) + g.Expect(pod.Spec.SchedulingGroup.PodGroupName).NotTo(BeNil()) + g.Expect(*pod.Spec.SchedulingGroup.PodGroupName).To(Equal(rayCluster.Name + "-worker-small-group")) + } +} + +func TestNativeScheduling_ScaleUpRecreatesWorkload(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("scale-up", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). 
+ WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Record the original Workload UID. + workload, err := GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + originalUID := workload.UID + g.Expect(workload.Spec.PodGroupTemplates).To(HaveLen(2)) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + + // Scale up workers from 1 to 3. + LogWithTimestamp(test.T(), "Scaling up worker replicas from 1 to 3") + rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(3).WithMinReplicas(3).WithMaxReplicas(3) + rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + // Wait for the cluster to become ready with new replicas. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready after scale-up", rayCluster.Namespace, rayCluster.Name) + g.Eventually(func(inner Gomega) { + rc, err := GetRayCluster(test, namespace.Name, rayCluster.Name) + inner.Expect(err).NotTo(HaveOccurred()) + inner.Expect(RayClusterState(rc)).To(Equal(rayv1.Ready)) + inner.Expect(RayClusterDesiredWorkerReplicas(rc)).To(Equal(int32(3))) + }, TestTimeoutMedium).Should(Succeed()) + + // Verify the Workload was recreated (different UID) with updated minCount. 
+ // Both checks are in a single Eventually to avoid a race where the UID passes + // but the spec is read from a stale object. + LogWithTimestamp(test.T(), "Verifying Workload was recreated with updated minCount") + g.Eventually(func(inner Gomega) { + w, err := GetWorkload(test, namespace.Name, rayCluster.Name) + inner.Expect(err).NotTo(HaveOccurred()) + inner.Expect(w.UID).NotTo(Equal(originalUID), "Workload should have been recreated with a new UID") + inner.Expect(w.Spec.PodGroupTemplates).To(HaveLen(2)) + inner.Expect(w.Spec.PodGroupTemplates[1].Name).To(Equal("worker-small-group")) + inner.Expect(w.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + inner.Expect(w.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(3))) + }, TestTimeoutShort).Should(Succeed()) + + // Verify PodGroups are recreated with updated minCount. + // PodGroups may be delayed by scheduler finalizer removal during the delete-then-create flow, + // so we use Eventually to wait for the reconciler to successfully recreate them. + LogWithTimestamp(test.T(), "Waiting for worker PodGroup to be recreated with updated minCount") + g.Eventually(func() int32 { + pg, err := GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group") + if err != nil || pg.DeletionTimestamp != nil { + return -1 + } + if pg.Spec.SchedulingPolicy.Gang == nil { + return -1 + } + return pg.Spec.SchedulingPolicy.Gang.MinCount + }, TestTimeoutShort).Should(Equal(int32(3))) +} + +func TestNativeScheduling_ScaleDownRecreatesWorkload(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + // Start with 3 replicas so we can scale down. + rayClusterAC := rayv1ac.RayCluster("scale-down", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(rayv1ac.RayClusterSpec(). + WithRayVersion(GetRayVersion()). + WithHeadGroupSpec(rayv1ac.HeadGroupSpec(). 
+ WithRayStartParams(map[string]string{"dashboard-host": "0.0.0.0"}). + WithTemplate(HeadPodTemplateApplyConfiguration())). + WithWorkerGroupSpecs(rayv1ac.WorkerGroupSpec(). + WithReplicas(3). + WithMinReplicas(3). + WithMaxReplicas(3). + WithGroupName("small-group"). + WithRayStartParams(map[string]string{"num-cpus": "1"}). + WithTemplate(WorkerPodTemplateApplyConfiguration()))) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s with 3 replicas", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Record the original Workload UID and verify minCount=3. + workload, err := GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + originalUID := workload.UID + g.Expect(workload.Spec.PodGroupTemplates).To(HaveLen(2)) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + g.Expect(workload.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(3))) + + // Scale down workers from 3 to 1. + LogWithTimestamp(test.T(), "Scaling down worker replicas from 3 to 1") + rayClusterAC.Spec.WorkerGroupSpecs[0].WithReplicas(1).WithMinReplicas(1).WithMaxReplicas(1) + rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + // Wait for the cluster to become ready with new replicas. 
+ LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready after scale-down", rayCluster.Namespace, rayCluster.Name) + g.Eventually(func(inner Gomega) { + rc, err := GetRayCluster(test, namespace.Name, rayCluster.Name) + inner.Expect(err).NotTo(HaveOccurred()) + inner.Expect(RayClusterState(rc)).To(Equal(rayv1.Ready)) + inner.Expect(RayClusterDesiredWorkerReplicas(rc)).To(Equal(int32(1))) + }, TestTimeoutMedium).Should(Succeed()) + + // Verify the Workload was recreated (different UID) with updated minCount=1. + LogWithTimestamp(test.T(), "Verifying Workload was recreated with updated minCount") + g.Eventually(func(inner Gomega) { + w, err := GetWorkload(test, namespace.Name, rayCluster.Name) + inner.Expect(err).NotTo(HaveOccurred()) + inner.Expect(w.UID).NotTo(Equal(originalUID), "Workload should have a new UID after scale-down") + inner.Expect(w.Spec.PodGroupTemplates).To(HaveLen(2)) + inner.Expect(w.Spec.PodGroupTemplates[1].Name).To(Equal("worker-small-group")) + inner.Expect(w.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang).NotTo(BeNil()) + inner.Expect(w.Spec.PodGroupTemplates[1].SchedulingPolicy.Gang.MinCount).To(Equal(int32(1))) + }, TestTimeoutShort).Should(Succeed()) + + // Verify worker PodGroup is recreated with updated minCount. + LogWithTimestamp(test.T(), "Waiting for worker PodGroup to be recreated with updated minCount") + g.Eventually(func() int32 { + pg, err := GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group") + if err != nil || pg.DeletionTimestamp != nil { + return -1 + } + if pg.Spec.SchedulingPolicy.Gang == nil { + return -1 + } + return pg.Spec.SchedulingPolicy.Gang.MinCount + }, TestTimeoutShort).Should(Equal(int32(1))) +} + +func TestNativeScheduling_AddWorkerGroupRecreatesWorkload(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("add-wg", namespace.Name). 
+ WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s with 1 worker group", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Record the original Workload UID and verify 2 PodGroupTemplates (head + 1 worker group). + workload, err := GetWorkload(test, namespace.Name, rayCluster.Name) + g.Expect(err).NotTo(HaveOccurred()) + originalUID := workload.UID + g.Expect(workload.Spec.PodGroupTemplates).To(HaveLen(2)) + + // Add a second worker group. + LogWithTimestamp(test.T(), "Adding second worker group 'gpu-group' to RayCluster") + rayClusterAC.Spec.WithWorkerGroupSpecs(rayv1ac.WorkerGroupSpec(). + WithReplicas(2). + WithMinReplicas(2). + WithMaxReplicas(2). + WithGroupName("gpu-group"). + WithRayStartParams(map[string]string{"num-cpus": "1"}). + WithTemplate(WorkerPodTemplateApplyConfiguration())) + rayCluster, err = test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + + // Wait for the cluster to become ready with the new worker group. + LogWithTimestamp(test.T(), "Waiting for RayCluster %s/%s to become ready after adding worker group", rayCluster.Namespace, rayCluster.Name) + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). 
+ Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify the Workload was recreated with 3 PodGroupTemplates (head + 2 worker groups). + LogWithTimestamp(test.T(), "Verifying Workload was recreated with 3 PodGroupTemplates") + g.Eventually(func(inner Gomega) { + w, err := GetWorkload(test, namespace.Name, rayCluster.Name) + inner.Expect(err).NotTo(HaveOccurred()) + inner.Expect(w.UID).NotTo(Equal(originalUID), "Workload should have a new UID after adding worker group") + inner.Expect(w.Spec.PodGroupTemplates).To(HaveLen(3)) + inner.Expect(w.Spec.PodGroupTemplates[0].Name).To(Equal("head")) + inner.Expect(w.Spec.PodGroupTemplates[1].Name).To(Equal("worker-small-group")) + inner.Expect(w.Spec.PodGroupTemplates[2].Name).To(Equal("worker-gpu-group")) + inner.Expect(w.Spec.PodGroupTemplates[2].SchedulingPolicy.Gang).NotTo(BeNil()) + inner.Expect(w.Spec.PodGroupTemplates[2].SchedulingPolicy.Gang.MinCount).To(Equal(int32(2))) + }, TestTimeoutShort).Should(Succeed()) + + // Verify 3 PodGroups exist (head + 2 worker groups). + g.Eventually(PodGroups(test, namespace.Name), TestTimeoutShort).Should(HaveLen(3)) + + _, err = GetPodGroup(test, namespace.Name, rayCluster.Name+"-head") + g.Expect(err).NotTo(HaveOccurred()) + _, err = GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-small-group") + g.Expect(err).NotTo(HaveOccurred()) + _, err = GetPodGroup(test, namespace.Name, rayCluster.Name+"-worker-gpu-group") + g.Expect(err).NotTo(HaveOccurred()) +} + +func TestNativeScheduling_WorkloadScheduledCondition(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + rayClusterAC := rayv1ac.RayCluster("cond-test", namespace.Name). + WithAnnotations(map[string]string{"ray.io/native-workload-scheduling": "true"}). 
+ WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify the WorkloadScheduled condition is True/WorkloadReady. + LogWithTimestamp(test.T(), "Verifying WorkloadScheduled condition is True") + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutShort). + Should(WithTransform(StatusCondition(rayv1.RayClusterWorkloadScheduled), SatisfyAll( + WithTransform(func(c metav1.Condition) metav1.ConditionStatus { return c.Status }, Equal(metav1.ConditionTrue)), + WithTransform(func(c metav1.Condition) string { return c.Reason }, Equal(rayv1.WorkloadReady)), + ))) +} + +func TestNativeScheduling_WorkloadScheduledConditionAbsentWhenDisabled(t *testing.T) { + test := With(t) + g := NewWithT(t) + + namespace := test.NewTestNamespace() + + // Create a RayCluster without the native scheduling annotation. + rayClusterAC := rayv1ac.RayCluster("no-cond", namespace.Name). + WithSpec(NewRayClusterSpec()) + + rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions) + g.Expect(err).NotTo(HaveOccurred()) + LogWithTimestamp(test.T(), "Created RayCluster %s/%s without native scheduling annotation", rayCluster.Namespace, rayCluster.Name) + + // Wait for cluster to become ready. + g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium). + Should(WithTransform(RayClusterState, Equal(rayv1.Ready))) + + // Verify no WorkloadScheduled condition is set, and it stays absent over time. 
+ LogWithTimestamp(test.T(), "Verifying WorkloadScheduled condition is absent") + g.Consistently(func(gg Gomega) { + cluster, err := GetRayCluster(test, namespace.Name, rayCluster.Name) + gg.Expect(err).NotTo(HaveOccurred()) + cond := meta.FindStatusCondition(cluster.Status.Conditions, string(rayv1.RayClusterWorkloadScheduled)) + gg.Expect(cond).To(BeNil(), "WorkloadScheduled condition should not be set without annotation") + }, 10*time.Second, time.Second).Should(Succeed()) +} diff --git a/ray-operator/test/support/scheduling.go b/ray-operator/test/support/scheduling.go new file mode 100644 index 00000000000..e37fa3a1809 --- /dev/null +++ b/ray-operator/test/support/scheduling.go @@ -0,0 +1,59 @@ +package support + +import ( + "github.com/onsi/gomega" + schedulingv1alpha2 "k8s.io/api/scheduling/v1alpha2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func Workload(t Test, namespace, name string) func() (*schedulingv1alpha2.Workload, error) { + return func() (*schedulingv1alpha2.Workload, error) { + return GetWorkload(t, namespace, name) + } +} + +func GetWorkload(t Test, namespace, name string) (*schedulingv1alpha2.Workload, error) { + return t.Client().Core().SchedulingV1alpha2().Workloads(namespace).Get(t.Ctx(), name, metav1.GetOptions{}) +} + +func PodGroup(t Test, namespace, name string) func() (*schedulingv1alpha2.PodGroup, error) { + return func() (*schedulingv1alpha2.PodGroup, error) { + return GetPodGroup(t, namespace, name) + } +} + +func GetPodGroup(t Test, namespace, name string) (*schedulingv1alpha2.PodGroup, error) { + return t.Client().Core().SchedulingV1alpha2().PodGroups(namespace).Get(t.Ctx(), name, metav1.GetOptions{}) +} + +func Workloads(t Test, namespace string) func(g gomega.Gomega) []schedulingv1alpha2.Workload { + return func(g gomega.Gomega) []schedulingv1alpha2.Workload { + workloads, err := t.Client().Core().SchedulingV1alpha2().Workloads(namespace).List(t.Ctx(), metav1.ListOptions{}) + g.Expect(err).NotTo(gomega.HaveOccurred()) + 
return workloads.Items + } +} + +func PodGroups(t Test, namespace string) func(g gomega.Gomega) []schedulingv1alpha2.PodGroup { + return func(g gomega.Gomega) []schedulingv1alpha2.PodGroup { + podGroups, err := t.Client().Core().SchedulingV1alpha2().PodGroups(namespace).List(t.Ctx(), metav1.ListOptions{}) + g.Expect(err).NotTo(gomega.HaveOccurred()) + return podGroups.Items + } +} + +func GetEvents(t Test, namespace string, objectName string, reason string) func() ([]string, error) { + return func() ([]string, error) { + events, err := t.Client().Core().EventsV1().Events(namespace).List(t.Ctx(), metav1.ListOptions{ + FieldSelector: "regarding.name=" + objectName + ",reason=" + reason, + }) + if err != nil { + return nil, err + } + messages := make([]string, 0, len(events.Items)) + for _, e := range events.Items { + messages = append(messages, e.Note) + } + return messages, nil + } +}