Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions pkg/controller/jobs/raycluster/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"strconv"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
rayutils "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/validation/field"
Expand All @@ -40,6 +44,9 @@ import (
)

const (
redisCleanupPodSetName = kueue.PodSetReference(rayv1.RedisCleanupNode)
maxRayClusterPodSets = 8

// RayClusterPodsetReplicaSizesAnnotation is set on the job when autoscaling causes
// PodSet replica sizes to differ from the original spec. The value is a JSON
// array compatible with []kueue.PodSet, containing only the changed PodSets.
Expand Down Expand Up @@ -96,9 +103,66 @@ func BuildPodSets(rayClusterSpec *rayv1.RayClusterSpec) ([]kueue.PodSet, error)
podSets = append(podSets, workerPodSet)
}

if hasGCSFaultTolerance(rayClusterSpec) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a delicate change let's introduce a feature gate as a bailout option, say KubeRayAccountForRedisCleanup which is Beta, and add a comment o GA in 0.21 in Kueue.

podSets = append(podSets, buildRedisCleanupPodSet(rayClusterSpec))
}

return podSets, nil
}

func buildRedisCleanupPodSet(rayClusterSpec *rayv1.RayClusterSpec) kueue.PodSet {
template := *rayClusterSpec.HeadGroupSpec.Template.DeepCopy()
template.Labels = maps.Clone(template.Labels)
if template.Labels == nil {
template.Labels = make(map[string]string, 1)
}
template.Labels[rayutils.RayNodeTypeLabelKey] = string(rayv1.RedisCleanupNode)
if len(template.Spec.Containers) > 0 {
template.Spec.Containers = []corev1.Container{*template.Spec.Containers[rayutils.RayContainerIndex].DeepCopy()}
template.Spec.Containers[rayutils.RayContainerIndex].Resources = redisCleanupResourceRequirements()
}
template.Spec.RestartPolicy = corev1.RestartPolicyNever

return kueue.PodSet{
Name: redisCleanupPodSetName,
Template: template,
Count: 1,
}
}

func redisCleanupResourceRequirements() corev1.ResourceRequirements {
return corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("200m"),
corev1.ResourceMemory: resource.MustParse("256Mi"),
},
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("200m"),
corev1.ResourceMemory: resource.MustParse("256Mi"),
},
}
}

func hasGCSFaultTolerance(rayClusterSpec *rayv1.RayClusterSpec) bool {
return rayClusterSpec.GcsFaultToleranceOptions != nil
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check with the KubeRay code is this is a sufficient check? I remember that there was also a knob at the kuberay deployment that would be the global enabler, which is disabled would make this check irrelevant.

}

func ExpectedPodSetsCount(rayClusterSpec *rayv1.RayClusterSpec) int {
count := len(rayClusterSpec.WorkerGroupSpecs) + 1
if hasGCSFaultTolerance(rayClusterSpec) {
count++
}
return count
}

func maxWorkerGroupsForSpec(rayClusterSpec *rayv1.RayClusterSpec) int {
maxWorkerGroups := maxRayClusterPodSets - 1
if hasGCSFaultTolerance(rayClusterSpec) {
maxWorkerGroups--
}
return maxWorkerGroups
}

func UpdatePodSets(ctx context.Context, podSets []kueue.PodSet, c client.Client, object client.Object, enableInTreeAutoscaling *bool, rayClusterName string) ([]kueue.PodSet, error) {
log := ctrl.LoggerFrom(ctx)

Expand Down Expand Up @@ -216,9 +280,9 @@ func ValidateCreate(object client.Object, rayClusterSpec *rayv1.RayClusterSpec,
)
}

// Should limit the worker count to 8 - 1 (max podSets num - cluster head)
if len(rayClusterSpec.WorkerGroupSpecs) > 7 {
allErrors = append(allErrors, field.TooMany(rayClusterSpecPath.Child("workerGroupSpecs"), len(rayClusterSpec.WorkerGroupSpecs), 7))
// Should limit the worker count to the maximum PodSet count, minus the head and optional Redis cleanup PodSets.
if maxWorkerGroups := maxWorkerGroupsForSpec(rayClusterSpec); len(rayClusterSpec.WorkerGroupSpecs) > maxWorkerGroups {
allErrors = append(allErrors, field.TooMany(rayClusterSpecPath.Child("workerGroupSpecs"), len(rayClusterSpec.WorkerGroupSpecs), maxWorkerGroups))
}

// None of the workerGroups should be named "head"
Expand Down
105 changes: 105 additions & 0 deletions pkg/controller/jobs/raycluster/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
rayutils "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
Expand Down Expand Up @@ -199,6 +201,86 @@ func TestBuildPodSets(t *testing.T) {
Obj(),
},
},
"spec with gcs fault tolerance": {
rayClusterSpec: &rayv1.RayClusterSpec{
GcsFaultToleranceOptions: &rayv1.GcsFaultToleranceOptions{
RedisAddress: "redis:6379",
},
HeadGroupSpec: rayv1.HeadGroupSpec{
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{"ray.io/cluster": "raycluster"},
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "head",
Image: "rayproject/ray:2.0.0",
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
},
},
}},
},
},
},
WorkerGroupSpecs: []rayv1.WorkerGroupSpec{
{
GroupName: "workers",
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{Name: "worker"}},
},
},
},
},
},
wantPodSets: []kueue.PodSet{
*utiltestingapi.MakePodSet(headGroupPodSetName, 1).
Labels(map[string]string{"ray.io/cluster": "raycluster"}).
PodSpec(corev1.PodSpec{
Containers: []corev1.Container{{
Name: "head",
Image: "rayproject/ray:2.0.0",
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
},
},
}},
}).
Obj(),
*utiltestingapi.MakePodSet("workers", 1).
PodSpec(corev1.PodSpec{
Containers: []corev1.Container{{Name: "worker"}},
}).
Obj(),
*utiltestingapi.MakePodSet(redisCleanupPodSetName, 1).
Labels(map[string]string{
"ray.io/cluster": "raycluster",
rayutils.RayNodeTypeLabelKey: string(rayv1.RedisCleanupNode),
}).
PodSpec(corev1.PodSpec{
Containers: []corev1.Container{{
Name: "head",
Image: "rayproject/ray:2.0.0",
Resources: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("200m"),
corev1.ResourceMemory: resource.MustParse("256Mi"),
},
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("200m"),
corev1.ResourceMemory: resource.MustParse("256Mi"),
},
},
}},
RestartPolicy: corev1.RestartPolicyNever,
}).
Obj(),
},
},
}

for name, tc := range testCases {
Expand Down Expand Up @@ -825,6 +907,29 @@ func TestValidateCreateRayClusterSpec(t *testing.T) {
field.TooMany(field.NewPath("spec", "workerGroupSpecs"), 8, 7),
},
},
"too many worker groups with gcs fault tolerance": {
object: testingrayutil.MakeCluster("raycluster", "ns").Obj(),
rayClusterSpec: &rayv1.RayClusterSpec{
GcsFaultToleranceOptions: &rayv1.GcsFaultToleranceOptions{
RedisAddress: "redis:6379",
},
HeadGroupSpec: rayv1.HeadGroupSpec{
Template: corev1.PodTemplateSpec{},
},
WorkerGroupSpecs: []rayv1.WorkerGroupSpec{
{GroupName: "workers1"},
{GroupName: "workers2"},
{GroupName: "workers3"},
{GroupName: "workers4"},
{GroupName: "workers5"},
{GroupName: "workers6"},
{GroupName: "workers7"}, // 7 workers plus head plus Redis cleanup is too many.
},
},
wantErrors: field.ErrorList{
field.TooMany(field.NewPath("spec", "workerGroupSpecs"), 7, 6),
},
},
"worker group named 'head'": {
object: testingrayutil.MakeCluster("raycluster", "ns").Obj(),
rayClusterSpec: &rayv1.RayClusterSpec{
Expand Down
49 changes: 3 additions & 46 deletions pkg/controller/jobs/raycluster/raycluster_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,54 +102,11 @@ func (j *RayCluster) PodLabelSelector() string {
}

func (j *RayCluster) PodSets(ctx context.Context) ([]kueue.PodSet, error) {
// len = workerGroups + head
podSets := make([]kueue.PodSet, len(j.Spec.WorkerGroupSpecs)+1)

// head
podSets[0] = kueue.PodSet{
Name: headGroupPodSetName,
Template: *j.Spec.HeadGroupSpec.Template.DeepCopy(),
Count: 1,
}

if features.Enabled(features.TopologyAwareScheduling) {
topologyRequest, err := jobframework.NewPodSetTopologyRequest(
&j.Spec.HeadGroupSpec.Template.ObjectMeta).Build()
if err != nil {
return nil, err
}
podSets[0].TopologyRequest = topologyRequest
}

// workers
for index := range j.Spec.WorkerGroupSpecs {
wgs := &j.Spec.WorkerGroupSpecs[index]
count := int32(1)
if wgs.Replicas != nil {
count = *wgs.Replicas
}
if wgs.NumOfHosts > 1 {
count *= wgs.NumOfHosts
}
podSets[index+1] = kueue.PodSet{
Name: kueue.NewPodSetReference(wgs.GroupName),
Template: *wgs.Template.DeepCopy(),
Count: count,
}
if features.Enabled(features.TopologyAwareScheduling) {
topologyRequest, err := jobframework.NewPodSetTopologyRequest(
&wgs.Template.ObjectMeta).Build()
if err != nil {
return nil, err
}
podSets[index+1].TopologyRequest = topologyRequest
}
}
return podSets, nil
return BuildPodSets(&j.Spec)
}

func (j *RayCluster) RunWithPodSetsInfo(ctx context.Context, podSetsInfo []podset.PodSetInfo) error {
expectedLen := len(j.Spec.WorkerGroupSpecs) + 1
expectedLen := ExpectedPodSetsCount(&j.Spec)
if len(podSetsInfo) != expectedLen {
return podset.BadPodSetsInfoLenError(expectedLen, len(podSetsInfo))
}
Expand All @@ -165,7 +122,7 @@ func (j *RayCluster) RunWithPodSetsInfo(ctx context.Context, podSetsInfo []podse
}

func (j *RayCluster) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
if len(podSetsInfo) != len(j.Spec.WorkerGroupSpecs)+1 {
if len(podSetsInfo) != ExpectedPodSetsCount(&j.Spec) {
return false
}

Expand Down
13 changes: 8 additions & 5 deletions pkg/controller/jobs/raycluster/raycluster_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ func (w *RayClusterWebhook) validateCreate(ctx context.Context, job *rayv1.RayCl
)
}

// Should limit the worker count to 8 - 1 (max podSets num - cluster head)
if len(spec.WorkerGroupSpecs) > 7 {
allErrors = append(allErrors, field.TooMany(specPath.Child("workerGroupSpecs"), len(spec.WorkerGroupSpecs), 7))
// Should limit the worker count to the maximum PodSet count, minus the head and optional Redis cleanup PodSets.
if maxWorkerGroups := maxWorkerGroupsForSpec(spec); len(spec.WorkerGroupSpecs) > maxWorkerGroups {
allErrors = append(allErrors, field.TooMany(specPath.Child("workerGroupSpecs"), len(spec.WorkerGroupSpecs), maxWorkerGroups))
}

// None of the workerGroups should be named "head"
Expand Down Expand Up @@ -220,9 +220,12 @@ func (w *RayClusterWebhook) validateTopologyRequest(ctx context.Context, rayJob
if podSetsErr == nil {
allErrs = append(allErrs, jobframework.ValidatePodSetGroupingTopology(podSets, BuildPodSetAnnotationsPathByNameMap(&rayJob.Spec, headGroupMetaPath, workerGroupSpecsPath))...)
for i, p := range podSets {
if p.Name == headGroupPodSetName {
switch p.Name {
case headGroupPodSetName:
allErrs = append(allErrs, jobframework.ValidateSliceSizeAnnotationUpperBound(headGroupMetaPath, &p.Template.ObjectMeta, &p)...)
} else {
case redisCleanupPodSetName:
continue
default:
// the raycluster PodSets function places the worker podsets from index 1
workerGroupMetaPath := workerGroupSpecsPath.Index(i-1).Child("template", "metadata")
allErrs = append(allErrs, jobframework.ValidateTASPodSetRequest(workerGroupMetaPath, &p.Template.ObjectMeta)...)
Expand Down
12 changes: 12 additions & 0 deletions pkg/controller/jobs/raycluster/raycluster_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ func TestValidateCreate(t *testing.T) {
field.TooMany(field.NewPath("spec", "workerGroupSpecs"), 8, 7),
}.ToAggregate(),
},
"invalid managed - too many worker groups with gcs fault tolerance": {
job: func() *rayv1.RayCluster {
job := testingrayutil.MakeCluster("job", "ns").Queue("queue").
WithWorkerGroups(bigWorkerGroup[:7]...).
Obj()
job.Spec.GcsFaultToleranceOptions = &rayv1.GcsFaultToleranceOptions{RedisAddress: "redis:6379"}
return job
}(),
wantErr: field.ErrorList{
field.TooMany(field.NewPath("spec", "workerGroupSpecs"), 7, 6),
}.ToAggregate(),
},
"worker group uses head name": {
job: testingrayutil.MakeCluster("job", "ns").Queue("queue").
WithWorkerGroups(rayv1.WorkerGroupSpec{
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/jobs/rayservice/rayservice_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (j *RayService) PodSets(ctx context.Context) ([]kueue.PodSet, error) {
}

func (j *RayService) RunWithPodSetsInfo(ctx context.Context, podSetsInfo []podset.PodSetInfo) error {
expectedLen := len(j.Spec.RayClusterSpec.WorkerGroupSpecs) + 1
expectedLen := raycluster.ExpectedPodSetsCount(&j.Spec.RayClusterSpec)
if len(podSetsInfo) != expectedLen {
return podset.BadPodSetsInfoLenError(expectedLen, len(podSetsInfo))
}
Expand All @@ -188,7 +188,7 @@ func (j *RayService) RunWithPodSetsInfo(ctx context.Context, podSetsInfo []podse
}

func (j *RayService) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
if len(podSetsInfo) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 {
if len(podSetsInfo) != raycluster.ExpectedPodSetsCount(&j.Spec.RayClusterSpec) {
return false
}

Expand Down
Loading