diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc50da1f17fdf..54d821a7a420d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -81,12 +81,17 @@ case object AllTuples extends Distribution { * * @param requireAllClusterKeys When true, `Partitioning` which satisfies this distribution, * must match all `clustering` expressions in the same ordering. + * @param allowNullKeySpreading When true, the default partitioning may spread rows whose + * clustering keys contain NULL values. This is a permission for + * consumers that do not require NULL-key co-location; ordinary + * [[HashPartitioning]] can still satisfy this distribution. */ case class ClusteredDistribution( clustering: Seq[Expression], requireAllClusterKeys: Boolean = SQLConf.get.getConf( SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION), - requiredNumPartitions: Option[Int] = None) extends Distribution { + requiredNumPartitions: Option[Int] = None, + allowNullKeySpreading: Boolean = false) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -97,7 +102,11 @@ case class ClusteredDistribution( assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + s"the actual number of partitions is $numPartitions.") - HashPartitioning(clustering, numPartitions) + if (allowNullKeySpreading) { + NullAwareHashPartitioning(clustering, numPartitions) + } else { + HashPartitioning(clustering, numPartitions) + } } /** @@ -282,7 +291,7 @@ trait HashPartitioningLike extends Expression with Partitioning with Unevaluable expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) => if (requireAllClusterKeys) { // Checks `HashPartitioning` is partitioned on exactly same clustering keys of // `ClusteredDistribution`. @@ -324,6 +333,45 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren) } +/** + * Represents a hash partitioning for equi-join inputs where rows with a NULL join key do not need + * to be co-located. Non-NULL join keys preserve the same partitioning contract as + * [[HashPartitioning]], while rows with any NULL join key may be spread across partitions. As a + * result, this partitioning intentionally does not satisfy a strict [[ClusteredDistribution]]. + */ +case class NullAwareHashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends HashPartitioningLike { + + override def satisfies0(required: Distribution): Boolean = { + (required match { + case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 + case _ => false + }) || { + // Stateful operators require strict NULL-key co-location and therefore cannot consume + // null-aware hash partitioning as a compatible clustered layout. + required match { + case c @ ClusteredDistribution( + requiredClustering, requireAllClusterKeys, _, allowNullKeySpreading) + if allowNullKeySpreading => + if (requireAllClusterKeys) { + c.areAllClusterKeysMatched(expressions) + } else { + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } + case _ => false + } + } + } + + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = + NullAwareHashShuffleSpec(this, distribution) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): NullAwareHashPartitioning = + copy(expressions = newChildren) +} + case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int) /** @@ -345,6 +393,47 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa copy(from = from.copy(expressions = newChildren)) } +/** + * Represents a null-aware hash partitioning whose reducer ranges have been coalesced into fewer + * partitions. It preserves the same relaxed NULL-key co-location contract as + * [[NullAwareHashPartitioning]]. + */ +case class CoalescedNullAwareHashPartitioning( + from: NullAwareHashPartitioning, + partitions: Seq[CoalescedBoundary]) extends HashPartitioningLike { + + override def expressions: Seq[Expression] = from.expressions + + override def satisfies0(required: Distribution): Boolean = { + (required match { + case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 + case _ => false + }) || { + required match { + case c @ ClusteredDistribution( + requiredClustering, requireAllClusterKeys, _, allowNullKeySpreading) + if allowNullKeySpreading => + if (requireAllClusterKeys) { + c.areAllClusterKeysMatched(expressions) + } else { + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } + case _ => false + } + } + } + + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = + CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions) + + override val numPartitions: Int = partitions.length + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): CoalescedNullAwareHashPartitioning = + copy(from = from.copy(expressions = newChildren)) +} + /** * Represents a partitioning where rows are split across partitions based on transforms defined by * `expressions`. @@ -482,7 +571,7 @@ case class KeyedPartitioning( def groupedSatisfies(required: Distribution): Boolean = { required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) => if (requireAllClusterKeys) { // Checks whether this partitioning is partitioned on exactly same clustering keys of // `ClusteredDistribution`. @@ -657,7 +746,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) // `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`. val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) => val expressions = ordering.map(_.child) if (requireAllClusterKeys) { // Checks `RangePartitioning` is partitioned on exactly same clustering keys of @@ -782,7 +871,7 @@ case class ShufflePartitionIdPassThrough( super.satisfies0(required) || { required match { // TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) => val partitioningExpressions = expr.child :: Nil if (requireAllClusterKeys) { c.areAllClusterKeysMatched(partitioningExpressions) @@ -863,6 +952,25 @@ case class RangeShuffleSpec( } } +private object HashShuffleSpecCompatibility { + def isCompatible( + leftDistribution: ClusteredDistribution, + leftNumPartitions: Int, + leftExpressions: Seq[Expression], + leftHashKeyPositions: Seq[mutable.BitSet], + rightDistribution: ClusteredDistribution, + rightNumPartitions: Int, + rightExpressions: Seq[Expression], + rightHashKeyPositions: Seq[mutable.BitSet]): Boolean = { + leftDistribution.clustering.length == rightDistribution.clustering.length && + leftNumPartitions == rightNumPartitions && + leftExpressions.length == rightExpressions.length && + leftHashKeyPositions.zip(rightHashKeyPositions).forall { case (left, right) => + left.intersect(right).nonEmpty + } + } +} + case class HashShuffleSpec( partitioning: HashPartitioning, distribution: ClusteredDistribution) extends ShuffleSpec { @@ -895,14 +1003,26 @@ case class HashShuffleSpec( // 3. both partitioning have the same number of expressions // 4. each pair of partitioning expression from both sides has overlapping positions in their // corresponding distributions. - distribution.clustering.length == otherDistribution.clustering.length && - partitioning.numPartitions == otherPartitioning.numPartitions && - partitioning.expressions.length == otherPartitioning.expressions.length && { - val otherHashKeyPositions = otherHashSpec.hashKeyPositions - hashKeyPositions.zip(otherHashKeyPositions).forall { case (left, right) => - left.intersect(right).nonEmpty - } - } + HashShuffleSpecCompatibility.isCompatible( + distribution, + partitioning.numPartitions, + partitioning.expressions, + hashKeyPositions, + otherDistribution, + otherPartitioning.numPartitions, + otherPartitioning.expressions, + otherHashSpec.hashKeyPositions) + case otherNullAwareSpec @ NullAwareHashShuffleSpec(otherPartitioning, otherDistribution) + if distribution.allowNullKeySpreading && otherDistribution.allowNullKeySpreading => + HashShuffleSpecCompatibility.isCompatible( + distribution, + partitioning.numPartitions, + partitioning.expressions, + hashKeyPositions, + otherDistribution, + otherPartitioning.numPartitions, + otherPartitioning.expressions, + otherNullAwareSpec.hashKeyPositions) case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) case _ => @@ -923,7 +1043,73 @@ case class HashShuffleSpec( override def createPartitioning(clustering: Seq[Expression]): Partitioning = { val exprs = hashKeyPositions.map(v => clustering(v.head)) - HashPartitioning(exprs, partitioning.numPartitions) + if (distribution.allowNullKeySpreading) { + NullAwareHashPartitioning(exprs, partitioning.numPartitions) + } else { + HashPartitioning(exprs, partitioning.numPartitions) + } + } + + override def numPartitions: Int = partitioning.numPartitions +} + +/** + * Shuffle specification for [[NullAwareHashPartitioning]]. It is compatible only with shuffle + * layouts whose distributions explicitly allow NULL-key spreading. + */ +case class NullAwareHashShuffleSpec( + partitioning: NullAwareHashPartitioning, + distribution: ClusteredDistribution) extends ShuffleSpec { + + lazy val hashKeyPositions: Seq[mutable.BitSet] = { + val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet] + distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) => + distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos) + } + partitioning.expressions.map(k => distKeyToPos.getOrElse(k.canonicalized, mutable.BitSet.empty)) + } + + override def isCompatibleWith(other: ShuffleSpec): Boolean = other match { + case SinglePartitionShuffleSpec => + partitioning.numPartitions == 1 + case otherSpec @ NullAwareHashShuffleSpec(otherPartitioning, otherDistribution) => + HashShuffleSpecCompatibility.isCompatible( + distribution, + partitioning.numPartitions, + partitioning.expressions, + hashKeyPositions, + otherDistribution, + otherPartitioning.numPartitions, + otherPartitioning.expressions, + otherSpec.hashKeyPositions) + case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherDistribution) + if distribution.allowNullKeySpreading && otherDistribution.allowNullKeySpreading => + HashShuffleSpecCompatibility.isCompatible( + distribution, + partitioning.numPartitions, + partitioning.expressions, + hashKeyPositions, + otherDistribution, + otherPartitioning.numPartitions, + otherPartitioning.expressions, + otherHashSpec.hashKeyPositions) + case ShuffleSpecCollection(specs) => + specs.exists(isCompatibleWith) + case _ => + false + } + + override def canCreatePartitioning: Boolean = { + if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) { + distribution.areAllClusterKeysMatched(partitioning.expressions) + } else { + true + } + } + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + val exprs = hashKeyPositions.map(v => clustering(v.head)) + NullAwareHashPartitioning(exprs, partitioning.numPartitions) } override def numPartitions: Int = partitioning.numPartitions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5ed831f20f394..6ecc5f433f36a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -967,6 +967,20 @@ object SQLConf { .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive") .createWithDefault(200) + val SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED = + buildConf("spark.sql.shuffle.spreadNullJoinKeys.enabled") + .doc("When true, Spark may spread rows with NULL equi-join keys across shuffle partitions " + + "for shuffled LEFT, RIGHT, and FULL OUTER equi-joins on nullable keys to reduce " + + "shuffle skew. Null-aware join output partitioning does not satisfy a strict " + + "ClusteredDistribution, so downstream grouping, windowing, or equi-joins may require " + + "an extra shuffle. If one input is already hash partitioned, only the other input may " + + "be reshuffled into the null-aware layout, so the pre-shuffled input can keep its NULL " + + "skew.") + .version("4.1.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED = buildConf("spark.sql.shuffle.orderIndependentChecksum.enabled") .doc("Whether to calculate order independent checksum for the shuffle data or not. If " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala index 85d285aa76c0d..cb5d77d445121 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ShuffleSpecSuite.scala @@ -453,6 +453,66 @@ class ShuffleSpecSuite extends SparkFunSuite with SQLHelper { ) } + test("compatibility: NullAwareHashShuffleSpec") { + val spreadAB = ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true) + val spreadCD = ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true) + val regularAB = ClusteredDistribution(Seq($"a", $"b")) + + val nullAwareAB = NullAwareHashShuffleSpec( + NullAwareHashPartitioning(Seq($"a", $"b"), 10), spreadAB) + val nullAwareCD = NullAwareHashShuffleSpec( + NullAwareHashPartitioning(Seq($"c", $"d"), 10), spreadCD) + val regularABSpec = HashShuffleSpec( + HashPartitioning(Seq($"a", $"b"), 10), regularAB) + val spreadABHashSpec = HashShuffleSpec( + HashPartitioning(Seq($"a", $"b"), 10), spreadAB) + + checkCompatible(nullAwareAB, nullAwareCD, expected = true) + checkCompatible(nullAwareAB, SinglePartitionShuffleSpec, expected = false) + checkCompatible( + NullAwareHashShuffleSpec(NullAwareHashPartitioning(Seq($"a", $"b"), 1), spreadAB), + SinglePartitionShuffleSpec, + expected = true) + checkCompatible(nullAwareAB, regularABSpec, expected = false) + checkCompatible(nullAwareAB, spreadABHashSpec, expected = true) + checkCompatible(spreadABHashSpec, nullAwareAB, expected = true) + } + + test("canCreatePartitioning: NullAwareHashShuffleSpec") { + val spreadDistribution = + ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true) + val partialSpec = NullAwareHashShuffleSpec( + NullAwareHashPartitioning(Seq($"a"), 10), spreadDistribution) + val fullSpec = NullAwareHashShuffleSpec( + NullAwareHashPartitioning(Seq($"a", $"b"), 10), spreadDistribution) + + withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false") { + assert(partialSpec.canCreatePartitioning) + } + withSQLConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "true") { + assert(!partialSpec.canCreatePartitioning) + assert(fullSpec.canCreatePartitioning) + } + } + + test("createPartitioning: NullAwareHashShuffleSpec") { + checkCreatePartitioning( + NullAwareHashShuffleSpec( + NullAwareHashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)), + ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true), + NullAwareHashPartitioning(Seq($"c"), 10) + ) + + checkCreatePartitioning( + HashShuffleSpec( + HashPartitioning(Seq($"a"), 10), + ClusteredDistribution(Seq($"a", $"b"), allowNullKeySpreading = true)), + ClusteredDistribution(Seq($"c", $"d"), allowNullKeySpreading = true), + NullAwareHashPartitioning(Seq($"c"), 10) + ) + } + test("createPartitioning: other specs") { val distribution = ClusteredDistribution(Seq($"a", $"b")) checkCreatePartitioning(SinglePartitionShuffleSpec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala index eba0346a94bd0..bff86983961c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEShuffleReadExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, CoalescedNullAwareHashPartitioning, HashPartitioning, NullAwareHashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} @@ -83,6 +83,13 @@ case class AQEShuffleReadExec private( throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected") } CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, partitions)) + case h: NullAwareHashPartitioning => + val partitions = partitionSpecs.map { + case CoalescedPartitionSpec(start, end, _) => CoalescedBoundary(start, end) + case unexpected => + throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected") + } + CurrentOrigin.withOrigin(h.origin)(CoalescedNullAwareHashPartitioning(h, partitions)) case r: RangePartitioning => CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length)) // This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 7444384229162..114f221c52f64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,7 +30,9 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow, UnsafeRowChecksum} +import org.apache.spark.sql.catalyst.expressions.{ + Attribute, BoundReference, CollationAwareMurmur3Hash, Literal, Pmod, UnsafeProjection, + UnsafeRow, UnsafeRowChecksum} import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics @@ -349,6 +351,10 @@ object ShuffleExchangeExec { // For HashPartitioning, the partitioning key is already a valid partition ID, as we use // `HashPartitioning.partitionIdExpression` to produce partitioning key. new PartitionIdPassthrough(n) + case NullAwareHashPartitioning(_, n) => + // The null-aware extractor below produces partition IDs directly: + // Pmod(hash, n) for non-NULL keys, and a round-robin counter for NULL keys. + new PartitionIdPassthrough(n) case ShufflePartitionIdPassThrough(_, n) => // For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression directly // produces partition IDs, so we use PartitionIdPassthrough to pass them through directly. @@ -403,6 +409,32 @@ object ShuffleExchangeExec { case h: HashPartitioning => val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) row => projection(row).getInt(0) + case h: NullAwareHashPartitioning => + // Non-NULL keys must produce the same partition id as + // HashPartitioning.partitionIdExpression so opted-in HashShuffleSpec and + // NullAwareHashShuffleSpec inputs stay aligned. + val joinKeyProjection = UnsafeProjection.create(h.expressions, outputAttributes) + val boundJoinKeys = h.expressions.zipWithIndex.map { case (expr, index) => + BoundReference(index, expr.dataType, expr.nullable) + } + val partitionIdExpression = Pmod( + new CollationAwareMurmur3Hash(boundJoinKeys), + Literal(h.numPartitions)) + val partitionIdProjection = UnsafeProjection.create(partitionIdExpression :: Nil) + var nullKeyPartition = + new XORShiftRandom(TaskContext.get().partitionId()).nextInt(h.numPartitions) + row => { + val joinKeys = joinKeyProjection(row) + if (joinKeys.anyNull()) { + // NULL join keys cannot match under ordinary equi-join semantics. Spread them + // round-robin within each map task so identical rows do not collapse to one reducer. + val partition = nullKeyPartition + nullKeyPartition = (nullKeyPartition + 1) % h.numPartitions + partition + } else { + partitionIdProjection(joinKeys).getInt(0) + } + } case RangePartitioning(sortingExpressions, _) => val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) @@ -419,9 +451,14 @@ object ShuffleExchangeExec { val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && newPartitioning.numPartitions > 1 + val isNullAwareHashPartitioning = + newPartitioning.isInstanceOf[NullAwareHashPartitioning] && + newPartitioning.numPartitions > 1 + val needsDeterministicLocalSort = + (isRoundRobin || isNullAwareHashPartitioning) && SQLConf.get.sortBeforeRepartition val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // [SPARK-23207] Have to make sure stateful row-to-partition assignment is deterministic, // otherwise a retry task may output different rows and thus lead to data loss. // // Currently we following the most straight-forward way that perform a local sort before @@ -429,7 +466,7 @@ object ShuffleExchangeExec { // // Note that we don't perform local sort if the new partitioning has only 1 partition, under // that case all output rows go to the same partition. - val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) { + val newRdd = if (needsDeterministicLocalSort) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { override def get: RecordComparator = new RecordBinaryComparator() @@ -468,7 +505,9 @@ object ShuffleExchangeExec { } // round-robin function is order sensitive if we don't sort the input. - val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition + // Stateful partition assignment is order-sensitive when it depends on row visitation order. + val isOrderSensitive = + (isRoundRobin || isNullAwareHashPartitioning) && !SQLConf.get.sortBeforeRepartition if (needToCopyObjectsBeforeShuffle(part)) { newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index f363156c81e54..d30506034bd2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution} +import org.apache.spark.sql.internal.SQLConf /** * Holds common logic for join operators by shuffling two child relations @@ -28,6 +29,24 @@ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Dist trait ShuffledJoin extends JoinCodegenSupport { def isSkewJoin: Boolean + private lazy val canSpreadNullJoinKeys: Boolean = { + // Only NULL keys on the preserved side can create this skew: they must be emitted, but + // cannot satisfy ordinary equi-join predicates. Non-preserved NULL-keyed rows are filtered + // out by `=` and never emitted, so their reducer placement does not matter here. + // + // Null-safe equality usually rewrites to non-null shuffle keys. The NullType corner can still + // produce NULL shuffle keys, but shuffled join execution already treats those rows as + // unmatched, so spreading them does not change the result. + val preservedSideHasNullableKeys = joinType match { + case LeftOuter => leftKeys.exists(_.nullable) + case RightOuter => rightKeys.exists(_.nullable) + case FullOuter => leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable) + case _ => false + } + conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) && + preservedSideHasNullableKeys + } + override def nodeName: String = { if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName } @@ -39,6 +58,9 @@ trait ShuffledJoin extends JoinCodegenSupport { // We re-arrange the shuffle partitions to deal with skew join, and the new children // partitioning doesn't satisfy `ClusteredDistribution`. UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + } else if (canSpreadNullJoinKeys) { + ClusteredDistribution(leftKeys, allowNullKeySpreading = true) :: + ClusteredDistribution(rightKeys, allowNullKeySpreading = true) :: Nil } else { ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 0273a5d6dd494..c1741cac8ad3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -67,7 +67,7 @@ abstract class DistributionAndOrderingSuiteBase protected def resolveDistribution[T <: QueryPlan[T]]( distribution: physical.Distribution, plan: QueryPlan[T]): physical.Distribution = distribution match { - case physical.ClusteredDistribution(clustering, numPartitions, _) => + case physical.ClusteredDistribution(clustering, numPartitions, _, _) => physical.ClusteredDistribution(clustering.map(resolveAttrs(_, plan)), numPartitions) case physical.OrderedDistribution(ordering) => physical.OrderedDistribution(ordering.map(resolveAttrs(_, plan).asInstanceOf[SortOrder])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 2a0ab52c36933..711f6dbdcdb11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -233,7 +233,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with }.head resolveDistribution(distribution, relation) match { - case physical.ClusteredDistribution(clustering, _, _) => + case physical.ClusteredDistribution(clustering, _, _, _) => assert(relation.keyGroupedPartitioning.isDefined && relation.keyGroupedPartitioning.get == clustering) case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 554cf5111beac..b7798b0bde5db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution import scala.util.Random -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, NullAwareHashPartitioning, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.internal.SQLConf @@ -59,6 +59,39 @@ class ExchangeSuite extends SharedSparkSession { ) } + test("null-aware hash shuffle spreads identical NULL keys from one mapper") { + val input = Seq.fill(64)(Tuple1(null.asInstanceOf[Integer])).toDF("k").coalesce(1) + val plan = input.queryExecution.executedPlan + val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan) + val partitionSizes = exchange.execute().collectPartitions().map(_.length) + + assert(partitionSizes.sorted === Array(16, 16, 16, 16)) + } + + test("null-aware hash shuffle preserves retry determinism with local sorting") { + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") { + val input = spark.range(64).repartition(4).selectExpr("CAST(NULL AS INT) AS k") + val plan = input.queryExecution.executedPlan + val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan) + + assert(plan.execute().outputDeterministicLevel == DeterministicLevel.UNORDERED) + assert(exchange.shuffleDependency.rdd.outputDeterministicLevel != + DeterministicLevel.INDETERMINATE) + } + } + + test("null-aware hash shuffle marks unsorted repartitioning as order-sensitive") { + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "false") { + val input = spark.range(64).repartition(4).selectExpr("CAST(NULL AS INT) AS k") + val plan = input.queryExecution.executedPlan + val exchange = ShuffleExchangeExec(NullAwareHashPartitioning(plan.output, 4), plan) + + assert(plan.execute().outputDeterministicLevel == DeterministicLevel.UNORDERED) + assert(exchange.shuffleDependency.rdd.outputDeterministicLevel == + DeterministicLevel.INDETERMINATE) + } + } + test("BroadcastMode.canonicalized") { val mode1 = IdentityBroadcastMode val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 50322905f29f3..0e7ba599e0fb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, JoinHint, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.physical.CoalescedNullAwareHashPartitioning import org.apache.spark.sql.classic.Strategy import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.BaseAggregateExec @@ -2089,55 +2090,80 @@ class AdaptiveQueryExecSuite |ON CAST(value AS INT) = b """.stripMargin) - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - // Repartition with no partition num specified. - checkBHJ(df.repartition($"b"), - // The top shuffle from repartition is optimized out. - optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true) - - // Repartition with default partition num (5 in test env) specified. - checkBHJ(df.repartition(5, $"b"), - // The top shuffle from repartition is optimized out - // The final plan must have 5 partitions, no optimization can be made to the probe side. - optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false) - - // Repartition with non-default partition num specified. - checkBHJ(df.repartition(4, $"b"), - // The top shuffle from repartition is not optimized out - optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) + def checkRepartitionOptimization( + df: Dataset[Row], + useNullAwarePartitioning: Boolean): Unit = { + val optimizeDefaultRepartition = !useNullAwarePartitioning + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + // Repartition with no partition num specified. + checkBHJ(df.repartition($"b"), + optimizeOutRepartition = optimizeDefaultRepartition, + probeSideLocalRead = useNullAwarePartitioning, + probeSideCoalescedRead = !useNullAwarePartitioning) + + // Repartition with default partition num (5 in test env) specified. + checkBHJ(df.repartition(5, $"b"), + optimizeOutRepartition = optimizeDefaultRepartition, + probeSideLocalRead = useNullAwarePartitioning, + probeSideCoalescedRead = false) + + // Repartition with non-default partition num specified. + checkBHJ(df.repartition(4, $"b"), + optimizeOutRepartition = false, + probeSideLocalRead = true, + probeSideCoalescedRead = true) + + // Repartition by col and project away the partition cols + checkBHJ(df.repartition($"b").select($"key"), + optimizeOutRepartition = false, + probeSideLocalRead = true, + probeSideCoalescedRead = true) + } - // Repartition by col and project away the partition cols - checkBHJ(df.repartition($"b").select($"key"), - // The top shuffle from repartition is not optimized out - optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) + // Force skew join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { + // Repartition with no partition num specified. + checkSMJ(df.repartition($"b"), + optimizeOutRepartition = optimizeDefaultRepartition, + optimizeSkewJoin = useNullAwarePartitioning, + coalescedRead = !useNullAwarePartitioning) + + // Repartition with default partition num (5 in test env) specified. + checkSMJ(df.repartition(5, $"b"), + optimizeOutRepartition = optimizeDefaultRepartition, + optimizeSkewJoin = useNullAwarePartitioning, + coalescedRead = false) + + // Repartition with non-default partition num specified. + checkSMJ(df.repartition(4, $"b"), + optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) + + // Repartition by col and project away the partition cols + checkSMJ(df.repartition($"b").select($"key"), + optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) + } } - // Force skew join - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.SKEW_JOIN_ENABLED.key -> "true", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { - // Repartition with no partition num specified. - checkSMJ(df.repartition($"b"), - // The top shuffle from repartition is optimized out. - optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true) - - // Repartition with default partition num (5 in test env) specified. - checkSMJ(df.repartition(5, $"b"), - // The top shuffle from repartition is optimized out. - // The final plan must have 5 partitions, can't do coalesced read. - optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false) - - // Repartition with non-default partition num specified. - checkSMJ(df.repartition(4, $"b"), - // The top shuffle from repartition is not optimized out. - optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) - - // Repartition by col and project away the partition cols - checkSMJ(df.repartition($"b").select($"key"), - // The top shuffle from repartition is not optimized out. - optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) + checkRepartitionOptimization(df, useNullAwarePartitioning = false) + withSQLConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + // Null-aware join output partitioning is not equivalent to ordinary hash repartitioning. + val nullablePreservedSideDf = sql( + """ + |SELECT * FROM ( + | SELECT * FROM testData WHERE key = 1 + |) + |RIGHT OUTER JOIN ( + | SELECT a, b FROM testData2 + | UNION ALL + | SELECT CAST(NULL AS INT) AS a, CAST(NULL AS INT) AS b + |) + |ON CAST(value AS INT) = b + """.stripMargin) + checkRepartitionOptimization(nullablePreservedSideDf, useNullAwarePartitioning = true) } } } @@ -2604,6 +2630,39 @@ class AdaptiveQueryExecSuite } } + test("AQE preserves coalesced null-aware partitioning for outer equi-join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "8", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null-1"), + (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv") + val df = nullableLeft.join( + nullableRight, nullableLeft("k") === nullableRight("k"), "left_outer") + + checkAnswer(df, Seq( + Row(1, "left-1", 1, "right-1"), + Row(null, "left-null-1", null, null), + Row(null, "left-null-2", null, null))) + + val coalescedNullAwareReads = collect(df.queryExecution.executedPlan) { + case read: AQEShuffleReadExec + if read.hasCoalescedPartition && + read.outputPartitioning.isInstanceOf[CoalescedNullAwareHashPartitioning] => + read + } + assert(coalescedNullAwareReads.nonEmpty) + } + } + test("SPARK-35794: Allow custom plugin for cost evaluator") { CostEvaluator.instantiate( classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 73e739e261b7f..2deb452c3a099 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, LessThan} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, NullAwareHashPartitioning} import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestData} import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} @@ -36,6 +37,18 @@ class OuterJoinSuite extends SharedSparkSession with SQLTestData { private val EnsureRequirements = new EnsureRequirements() + private def extractJoinParts( + left: DataFrame, + right: DataFrame, + condition: Column): ExtractEquiJoinKeys.ReturnType = { + val analyzedJoin = left.join(right, condition, "inner") + .queryExecution.analyzed + .collectFirst { case join: Join => join } + .getOrElse(fail("Failed to build analyzed equi-join")) + ExtractEquiJoinKeys.unapply(analyzedJoin) + .getOrElse(fail("Failed to extract equi-join keys")) + } + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), @@ -345,4 +358,305 @@ class OuterJoinSuite extends SharedSparkSession with SQLTestData { val df2 = join("SHUFFLE_MERGE(t1)") checkAnswer(df1, identity, df2.collect().toSeq) } + + test("ordinary outer equi-join spreads NULL keys in shuffle partitioning") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null-1"), + (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k")) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning])) + + checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, left, right)), + Seq( + Row(1, "left-1", 1, "right-1"), + Row(null, "left-null-1", null, null), + Row(null, "left-null-2", null, null)), + sortAnswers = true) + } + } + + test("ordinary outer equi-join keeps hash partitioning when null-aware shuffle is disabled") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k")) + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "4") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[HashPartitioning])) + } + } + + test("ordinary outer equi-join keeps hash partitioning for non-nullable join keys") { + val nonNullableLeft = spark.range(3).toDF("k") + val nonNullableRight = spark.range(3).toDF("k") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts( + nonNullableLeft, + nonNullableRight, + nonNullableLeft("k") === nonNullableRight("k")) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + nonNullableLeft.queryExecution.sparkPlan, nonNullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[HashPartitioning])) + } + } + + test("ordinary right outer equi-join spreads NULL keys in shuffle partitioning") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null-1"), + (null.asInstanceOf[Integer], "right-null-2")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k")) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, RightOuter, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning])) + + checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, RightOuter, boundCondition, left, right)), + Seq( + Row(1, "left-1", 1, "right-1"), + Row(null, null, null, "right-null-1"), + Row(null, null, null, "right-null-2")), + sortAnswers = true) + } + } + + test("ordinary full outer equi-join keeps NULL keys unmatched while spreading them") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null-1"), + (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null-1"), + (null.asInstanceOf[Integer], "right-null-2")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k")) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, FullOuter, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning])) + + checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, FullOuter, boundCondition, left, right)), + Seq( + Row(1, "left-1", 1, "right-1"), + Row(null, "left-null-1", null, null), + Row(null, "left-null-2", null, null), + Row(null, null, null, "right-null-1"), + Row(null, null, null, "right-null-2")), + sortAnswers = true) + } + } + + test("ordinary outer equi-join preserves null-aware shuffle beside existing hash partitioning") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k")) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val existingLeftShuffle = ShuffleExchangeExec( + HashPartitioning(leftKeys, 4), + nullableLeft.queryExecution.sparkPlan) + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + existingLeftShuffle, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + + assert(partitionings.size == 2) + assert(partitionings.count(_.isInstanceOf[HashPartitioning]) == 1) + assert(partitionings.count(_.isInstanceOf[NullAwareHashPartitioning]) == 1) + } + } + + test("mixed ordinary and null-safe outer equi-join can use null-aware shuffle partitioning") { + val nullableLeft = Seq( + (Integer.valueOf(1), null.asInstanceOf[Integer], "left-match"), + (Integer.valueOf(2), null.asInstanceOf[Integer], "left-no-match")) + .toDF("k1", "k2", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), null.asInstanceOf[Integer], "right-match"), + (Integer.valueOf(2), Integer.valueOf(3), "right-no-match")) + .toDF("k1", "k2", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts( + nullableLeft, + nullableRight, + nullableLeft("k1") === nullableRight("k1") && + nullableLeft("k2").eqNullSafe(nullableRight("k2"))) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning])) + + checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, left, right)), + Seq( + Row(1, null, "left-match", 1, null, "right-match"), + Row(2, null, "left-no-match", null, null, null)), + sortAnswers = true) + } + } + + test("null-safe outer equi-join keeps hash partitioning for non-null shuffle keys") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null")) + .toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts( + nullableLeft, + nullableRight, + nullableLeft("k").eqNullSafe(nullableRight("k"))) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[HashPartitioning])) + } + } + + test("ordinary outer equi-join spreads NULL keys for shuffled hash join") { + val nullableLeft = Seq( + (Integer.valueOf(1), "left-1"), + (null.asInstanceOf[Integer], "left-null-1"), + (null.asInstanceOf[Integer], "left-null-2")).toDF("k", "lv") + val nullableRight = Seq( + (Integer.valueOf(1), "right-1"), + (null.asInstanceOf[Integer], "right-null")).toDF("k", "rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts(nullableLeft, nullableRight, nullableLeft("k") === nullableRight("k")) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + ShuffledHashJoinExec(leftKeys, rightKeys, LeftOuter, BuildRight, boundCondition, + nullableLeft.queryExecution.sparkPlan, nullableRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning])) + + checkAnswer2(nullableLeft, nullableRight, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + ShuffledHashJoinExec( + leftKeys, rightKeys, LeftOuter, BuildRight, boundCondition, left, right)), + Seq( + Row(1, "left-1", 1, "right-1"), + Row(null, "left-null-1", null, null), + Row(null, "left-null-2", null, null)), + sortAnswers = true) + } + } + + test("NullType null-safe outer equi-join remains result-safe with null-aware shuffle") { + val nullTypeLeft = spark.range(2).selectExpr("NULL AS k", "id AS lv") + val nullTypeRight = spark.range(1).selectExpr("NULL AS k", "id AS rv") + val (_, leftKeys, rightKeys, boundCondition, _, _, _, _) = + extractJoinParts( + nullTypeLeft, + nullTypeRight, + nullTypeLeft("k").eqNullSafe(nullTypeRight("k"))) + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "4", + SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED.key -> "true") { + val plan = EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, + nullTypeLeft.queryExecution.sparkPlan, nullTypeRight.queryExecution.sparkPlan)) + val partitionings = plan.collect { + case exchange: ShuffleExchangeExec => exchange.outputPartitioning + } + assert(partitionings.size == 2) + assert(partitionings.forall(_.isInstanceOf[NullAwareHashPartitioning])) + + checkAnswer2(nullTypeLeft, nullTypeRight, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + SortMergeJoinExec(leftKeys, rightKeys, LeftOuter, boundCondition, left, right)), + Seq( + Row(null, 0L, null, null), + Row(null, 1L, null, null)), + sortAnswers = true) + } + } }