Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ case class ClusteredDistribution(
clustering: Seq[Expression],
requireAllClusterKeys: Boolean = SQLConf.get.getConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION),
requiredNumPartitions: Option[Int] = None) extends Distribution {
requiredNumPartitions: Option[Int] = None,
allowNullKeySpreading: Boolean = false) extends Distribution {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth a Scaladoc on this field describing the contract: it's a permission, not a requirement (an ordinary HashPartitioning still satisfies this distribution when the flag is true; the flag only weakens what the default partitioning produced by createPartitioning looks like). And it's the consumer-side knob — the partitioning-side marker (NullAwareHashPartitioning today, or a flag on HashPartitioning per the comment below) is what tells downstream operators they need to re-shuffle for strict ClusteredDistribution.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
Expand All @@ -97,7 +98,11 @@ case class ClusteredDistribution(
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
s"the actual number of partitions is $numPartitions.")
HashPartitioning(clustering, numPartitions)
if (allowNullKeySpreading) {
NullAwareHashPartitioning(clustering, numPartitions)
} else {
HashPartitioning(clustering, numPartitions)
}
}

/**
Expand Down Expand Up @@ -282,7 +287,7 @@ trait HashPartitioningLike extends Expression with Partitioning with Unevaluable
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
if (requireAllClusterKeys) {
// Checks `HashPartitioning` is partitioned on exactly same clustering keys of
// `ClusteredDistribution`.
Expand Down Expand Up @@ -324,6 +329,46 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
}

/**
* Represents a hash partitioning for equi-join inputs where rows with a NULL join key do not need
* to be co-located. Non-NULL join keys preserve the same partitioning contract as
* [[HashPartitioning]], while rows with any NULL join key may be spread across partitions.
*/
case class NullAwareHashPartitioning(expressions: Seq[Expression], numPartitions: Int)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Design alternative worth considering: a spreadNullKeys: Boolean = false field on HashPartitioning instead of a parallel type hierarchy.

The marker this carries is one bit ("NULL keys may be spread, so I don't deliver strict same-key co-location"). Encoding it as a parallel type means duplicating hashKeyPositions, canCreatePartitioning, createPartitioning, numPartitions, and (modulo the helper just extracted) isCompatibleWith in NullAwareHashShuffleSpec, plus reproducing CoalescedHashPartitioning as CoalescedNullAwareHashPartitioning, plus a new arm in every dispatcher (ShuffleExchangeExec.prepareShuffleDependency's part and getPartitionKeyExtractor, AQEShuffleReadExec.outputPartitioning).

With a flag:

  • HashPartitioning.satisfies0 only matches strict ClusteredDistribution when !spreadNullKeys, only matches allowNullKeySpreading=true distributions when spreadNullKeys.
  • HashShuffleSpec carries the flag; one extra clause in isCompatibleWith.
  • CoalescedHashPartitioning already wraps a HashPartitioning — it inherits the flag transparently. No new coalesced class.
  • Dispatchers branch on h.spreadNullKeys instead of branching on type, so every existing case h: HashPartitioning => site (BucketingUtils, V1Writes, EnsureRequirements, AQEUtils, basicPhysicalOperators, etc.) keeps working unchanged.

The one argument for distinct types is EXPLAIN-string visibility — a one-line toString fix on the flagged variant.

Separately on this class's Scaladoc: worth calling out that NullAwareHashPartitioning intentionally does NOT satisfy a strict ClusteredDistribution (NULL clustering keys aren't co-located). That's the non-obvious downstream contract — it's what forces downstream GROUP BY / window / strict equi-join to re-shuffle.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea this is an alternative design. The pros and cons are:

Pros:

  • Much less duplicated code.
  • Existing HashPartitioning plumbing can often be reused directly.
  • CoalescedHashPartitioning and HashShuffleSpec can carry the flag instead of requiring parallel classes.
  • Fewer pattern-match branches across the codebase.

Cons:

  • The semantic distinction becomes easier to overlook.
  • A HashPartitioning with spreadNullKeys = true is no longer “ordinary hash partitioning” in the old sense.
  • Every place that reasons about HashPartitioning now has to remember to inspect the flag before assuming strict same-key co-location.
  • That is subtle and potentially error-prone because HashPartitioning is already widely used.
  • The class name no longer advertises the weaker contract; you would need careful toString, docs, and audits to preserve the same clarity.

I'm a bit concerned about the cons since HashPartitioning is widely used in the codebase and the change could have a bigger blast radius than just adding another NullAwareHashPartitioning.

extends HashPartitioningLike {

override def satisfies0(required: Distribution): Boolean = {
(required match {
case UnspecifiedDistribution => true
case AllTuples => numPartitions == 1
case _ => false
}) || {
required match {
case c @ ClusteredDistribution(
requiredClustering, requireAllClusterKeys, _, allowNullKeySpreading)
if allowNullKeySpreading =>
if (requireAllClusterKeys) {
c.areAllClusterKeysMatched(expressions)
} else {
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
case _ => false
}
}
}

override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
NullAwareHashShuffleSpec(this, distribution)

def partitionIdExpression: Expression = Pmod(
new CollationAwareMurmur3Hash(expressions), Literal(numPartitions)
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the single-eval refactor in ShuffleExchangeExec (the case h: NullAwareHashPartitioning => branch in getPartitionKeyExtractor), this method is no longer called anywhere. ShuffleExchangeExec builds an equivalent Pmod(CollationAwareMurmur3Hash(boundJoinKeys), Literal(n)) inline against the projected key row instead of calling h.partitionIdExpression. Grep confirms no external callers. Safe to delete the three-line def.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Removed


override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): NullAwareHashPartitioning =
copy(expressions = newChildren)
}

case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)

/**
Expand All @@ -345,6 +390,42 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
copy(from = from.copy(expressions = newChildren))
}

case class CoalescedNullAwareHashPartitioning(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Scaladoc here (and on NullAwareHashShuffleSpec below). CoalescedHashPartitioning documents what it represents — worth matching that here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

from: NullAwareHashPartitioning,
partitions: Seq[CoalescedBoundary]) extends HashPartitioningLike {

override def expressions: Seq[Expression] = from.expressions

override def satisfies0(required: Distribution): Boolean = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This body is identical to NullAwareHashPartitioning.satisfies0 at line 340 — same outer UnspecifiedDistribution/AllTuples/_ => false match, same ClusteredDistribution inner match guarded on allowNullKeySpreading, same requireAllClusterKeys branching. This is the same kind of duplication addressed elsewhere in this PR by extracting HashShuffleSpecCompatibility.isCompatible (lines 944-955).

Two cleaner shapes:

  • Lift the inner block to a private helper, e.g. private def nullAwareSatisfies0(exprs, n, required) shared by both classes.
  • Or just delegate: since boundaries don't change satisfaction semantics for the allowNullKeySpreading contract, CoalescedNullAwareHashPartitioning.satisfies0(required) is essentially from.satisfies0(required) except for the AllTuples case where numPartitions differs — that single divergence is easy to handle inline.

Side note: both overrides skip the StatefulOpClusteredDistribution case that HashPartitioningLike.satisfies0 handles. Currently unreachable (streaming joins use StatefulOpClusteredDistribution, not ClusteredDistribution, so they never opt into allowNullKeySpreading), but a one-line comment that the omission is deliberate would help the next reader.

(required match {
case UnspecifiedDistribution => true
case AllTuples => numPartitions == 1
case _ => false
}) || {
required match {
case c @ ClusteredDistribution(
requiredClustering, requireAllClusterKeys, _, allowNullKeySpreading)
if allowNullKeySpreading =>
if (requireAllClusterKeys) {
c.areAllClusterKeysMatched(expressions)
} else {
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
case _ => false
}
}
}

override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions)

override val numPartitions: Int = partitions.length

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): CoalescedNullAwareHashPartitioning =
copy(from = from.copy(expressions = newChildren))
}

/**
* Represents a partitioning where rows are split across partitions based on transforms defined by
* `expressions`.
Expand Down Expand Up @@ -482,7 +563,7 @@ case class KeyedPartitioning(

def groupedSatisfies(required: Distribution): Boolean = {
required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same clustering keys of
// `ClusteredDistribution`.
Expand Down Expand Up @@ -657,7 +738,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
// `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`.
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
val expressions = ordering.map(_.child)
if (requireAllClusterKeys) {
// Checks `RangePartitioning` is partitioned on exactly same clustering keys of
Expand Down Expand Up @@ -782,7 +863,7 @@ case class ShufflePartitionIdPassThrough(
super.satisfies0(required) || {
required match {
// TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _, _) =>
val partitioningExpressions = expr.child :: Nil
if (requireAllClusterKeys) {
c.areAllClusterKeysMatched(partitioningExpressions)
Expand Down Expand Up @@ -903,6 +984,16 @@ case class HashShuffleSpec(
left.intersect(right).nonEmpty
}
}
case otherNullAwareSpec @ NullAwareHashShuffleSpec(otherPartitioning, otherDistribution)
if distribution.allowNullKeySpreading && otherDistribution.allowNullKeySpreading =>
distribution.clustering.length == otherDistribution.clustering.length &&
partitioning.numPartitions == otherPartitioning.numPartitions &&
partitioning.expressions.length == otherPartitioning.expressions.length && {
val otherHashKeyPositions = otherNullAwareSpec.hashKeyPositions
hashKeyPositions.zip(otherHashKeyPositions).forall { case (left, right) =>
left.intersect(right).nonEmpty
}
}
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ =>
Expand All @@ -923,7 +1014,67 @@ case class HashShuffleSpec(

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
val exprs = hashKeyPositions.map(v => clustering(v.head))
HashPartitioning(exprs, partitioning.numPartitions)
if (distribution.allowNullKeySpreading) {
NullAwareHashPartitioning(exprs, partitioning.numPartitions)
} else {
HashPartitioning(exprs, partitioning.numPartitions)
}
}

override def numPartitions: Int = partitioning.numPartitions
}

case class NullAwareHashShuffleSpec(
partitioning: NullAwareHashPartitioning,
distribution: ClusteredDistribution) extends ShuffleSpec {

lazy val hashKeyPositions: Seq[mutable.BitSet] = {
val distKeyToPos = mutable.Map.empty[Expression, mutable.BitSet]
distribution.clustering.zipWithIndex.foreach { case (distKey, distKeyPos) =>
distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
}
partitioning.expressions.map(k => distKeyToPos.getOrElse(k.canonicalized, mutable.BitSet.empty))
}

override def isCompatibleWith(other: ShuffleSpec): Boolean = other match {
Comment thread
peter-toth marked this conversation as resolved.
case SinglePartitionShuffleSpec =>
partitioning.numPartitions == 1
case otherSpec @ NullAwareHashShuffleSpec(otherPartitioning, otherDistribution) =>
distribution.clustering.length == otherDistribution.clustering.length &&
partitioning.numPartitions == otherPartitioning.numPartitions &&
partitioning.expressions.length == otherPartitioning.expressions.length && {
val otherHashKeyPositions = otherSpec.hashKeyPositions
hashKeyPositions.zip(otherHashKeyPositions).forall { case (left, right) =>
left.intersect(right).nonEmpty
}
}
case otherHashSpec @ HashShuffleSpec(otherPartitioning, otherDistribution)
if distribution.allowNullKeySpreading && otherDistribution.allowNullKeySpreading =>
distribution.clustering.length == otherDistribution.clustering.length &&
partitioning.numPartitions == otherPartitioning.numPartitions &&
partitioning.expressions.length == otherPartitioning.expressions.length && {
val otherHashKeyPositions = otherHashSpec.hashKeyPositions
hashKeyPositions.zip(otherHashKeyPositions).forall { case (left, right) =>
left.intersect(right).nonEmpty
}
}
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ =>
false
}

override def canCreatePartitioning: Boolean = {
if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
distribution.areAllClusterKeysMatched(partitioning.expressions)
} else {
true
}
}

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
val exprs = hashKeyPositions.map(v => clustering(v.head))
NullAwareHashPartitioning(exprs, partitioning.numPartitions)
}

override def numPartitions: Int = partitioning.numPartitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, CoalescedNullAwareHashPartitioning, HashPartitioning, NullAwareHashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition, UnknownPartitioning}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
Expand Down Expand Up @@ -83,6 +83,13 @@ case class AQEShuffleReadExec private(
throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected")
}
CurrentOrigin.withOrigin(h.origin)(CoalescedHashPartitioning(h, partitions))
case h: NullAwareHashPartitioning =>
val partitions = partitionSpecs.map {
case CoalescedPartitionSpec(start, end, _) => CoalescedBoundary(start, end)
case unexpected =>
throw SparkException.internalError(s"Unexpected ShufflePartitionSpec: $unexpected")
}
CurrentOrigin.withOrigin(h.origin)(CoalescedNullAwareHashPartitioning(h, partitions))
case r: RangePartitioning =>
CurrentOrigin.withOrigin(r.origin)(r.copy(numPartitions = partitionSpecs.length))
// This can only happen for `REBALANCE_PARTITIONS_BY_NONE`, which uses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ case class ShuffleExchangeExec(
*/
@transient
lazy val shuffleDependency : ShuffleDependency[Int, InternalRow, InternalRow] = {
outputPartitioning match {
Comment thread
peter-toth marked this conversation as resolved.
Outdated
case h: NullAwareHashPartitioning =>
logWarning(s"Materializing null-aware hash shuffle with ${h.numPartitions} partitions.")
case _ =>
}
// Wrap in the exchange's RDD scope so that any wrapper RDDs created during shuffle dependency
// preparation (e.g. by prepareShuffleDependency's mapPartitionsInternal calls) get this
// exchange's scope ID.
Expand Down Expand Up @@ -349,6 +354,8 @@ object ShuffleExchangeExec {
// For HashPartitioning, the partitioning key is already a valid partition ID, as we use
// `HashPartitioning.partitionIdExpression` to produce partitioning key.
new PartitionIdPassthrough(n)
case NullAwareHashPartitioning(_, n) =>
new PartitionIdPassthrough(n)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the parallel HashPartitioning case immediately above carries the comment "the partitioning key is already a valid partition ID, as we use HashPartitioning.partitionIdExpression to produce partitioning key." Worth matching here so a reader who jumps to this case sees why PartitionIdPassthrough is the right partitioner. Something like:

// The NullAware extractor below produces partition IDs directly:
// `Pmod(hash, n)` for non-NULL keys, a round-robin counter for NULL keys.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Added the comment above in the code

case ShufflePartitionIdPassThrough(_, n) =>
// For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression directly
// produces partition IDs, so we use PartitionIdPassthrough to pass them through directly.
Expand Down Expand Up @@ -403,6 +410,24 @@ object ShuffleExchangeExec {
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case h: NullAwareHashPartitioning =>
val partitionIdProjection =
UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
val joinKeyProjection = UnsafeProjection.create(h.expressions, outputAttributes)
var nullKeyPartition =
new XORShiftRandom(TaskContext.get().partitionId()).nextInt(h.numPartitions)
Comment thread
peter-toth marked this conversation as resolved.
row => {
val joinKeys = joinKeyProjection(row)
if (joinKeys.anyNull()) {
// NULL join keys cannot match under ordinary equi-join semantics. Spread them
// round-robin within each map task so identical rows do not collapse to one reducer.
val partition = nullKeyPartition
nullKeyPartition = (nullKeyPartition + 1) % h.numPartitions
partition
} else {
partitionIdProjection(row).getInt(0)
}
}
Comment on lines +412 to +437
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Join keys are evaluated twice for non-NULL rows on this path: once via joinKeyProjection(row) to call anyNull(), again via partitionIdProjection(row).getInt(0) which re-evaluates the same expressions to compute the hash. For most expression shapes that's a tight loop, but redundant.

Could evaluate the keys once, check anyNull on the materialized row, then hash directly from that row.

Combined with the static-nullability gate at ShuffledJoin.canSpreadNullJoinKeys (which skips this path entirely when keys are statically non-nullable), the residual overhead becomes "check the null bitset once per row when at least one key is nullable" — about as low as this gets without adaptive observation of actual NULL frequency.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated it to only evaluate the join keys once but the logic becomes more complicated. Please take another look!

case RangePartitioning(sortingExpressions, _) =>
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
Expand All @@ -419,17 +444,22 @@ object ShuffleExchangeExec {

val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
newPartitioning.numPartitions > 1
val isNullAwareRoundRobin =
Comment thread
peter-toth marked this conversation as resolved.
Outdated
newPartitioning.isInstanceOf[NullAwareHashPartitioning] &&
newPartitioning.numPartitions > 1
val needsDeterministicLocalSort =
(isRoundRobin || isNullAwareRoundRobin) && SQLConf.get.sortBeforeRepartition

val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
// [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
// [SPARK-23207] Have to make sure stateful row-to-partition assignment is deterministic,
// otherwise a retry task may output different rows and thus lead to data loss.
//
// Currently we following the most straight-forward way that perform a local sort before
// partitioning.
//
// Note that we don't perform local sort if the new partitioning has only 1 partition, under
// that case all output rows go to the same partition.
val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
val newRdd = if (needsDeterministicLocalSort) {
rdd.mapPartitionsInternal { iter =>
val recordComparatorSupplier = new Supplier[RecordComparator] {
override def get: RecordComparator = new RecordBinaryComparator()
Expand Down Expand Up @@ -468,7 +498,9 @@ object ShuffleExchangeExec {
}

// round-robin function is order sensitive if we don't sort the input.
val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
// Stateful partition assignment is order-sensitive when it depends on row visitation order.
val isOrderSensitive =
(isRoundRobin || isNullAwareRoundRobin) && !SQLConf.get.sortBeforeRepartition
if (needToCopyObjectsBeforeShuffle(part)) {
newRdd.mapPartitionsWithIndexInternal((_, iter) => {
val getPartitionKey = getPartitionKeyExtractor()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, IsNull}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution}

Expand All @@ -28,6 +28,21 @@ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Dist
trait ShuffledJoin extends JoinCodegenSupport {
def isSkewJoin: Boolean

private def containsNullSafeJoinMarker(keys: Seq[Expression]): Boolean = {
keys.exists(_.exists(_.isInstanceOf[IsNull]))
}

private lazy val canSpreadNullJoinKeys: Boolean = {
Copy link
Copy Markdown
Contributor

@peter-toth peter-toth May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this robust enough? What if someone crafts a null handling join condition by hand?

Actually, this looks good.

Actually, why this is needed at all and when can't we spread nulls?
<=> is translated to 2 key pairs Coalesce(a.k, default), Coalesce(b.k, default)) and (IsNull(a.k), IsNull(b.k)), so null never shows up in shuffle keys. The join type check seems fair enough.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For most types the coalesce key is non-null, but Literal.default(NullType) is itself null, so it seems the extracted shuffle key can still contain nulls even though those rows remain matchable under <=>.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without spreading, NullType <=> keys all hash to the same value (Murmur3Hash(null) is deterministic) → all NULL rows collocate on one reducer. The executor then runs:

  • SortMergeJoinExec.scala:1116: while (advancedStreamed() && streamedRowKey.anyNull) — skip every NULL-keyed streamed row.
  • SortMergeJoinExec.scala:1529: in full-outer, leftRowKey.anyNull triggers padding emission, never a match.

So even with NULL rows colocated, the executor's anyNull guard prevents NULL=NULL from matching. The <=> semantics the user wanted (NULL matches NULL) is never delivered for NullType — the rewrite was supposed to convert NULLs
to non-null sentinels so the executor's guard wouldn't fire, but for NullType the sentinel itself is NULL, so the guard fires anyway and the join produces only padding (full outer) or nothing (inner).

With spreading, NULL rows scatter across reducers. Each reducer's executor sees some NULL rows from both sides. The anyNull guard fires the same way. Same padding emission, same lack of matching.

Output is identical with or without spreading — both produce the broken-but-self-consistent "NULL=NULL doesn't match" behavior for NullType.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm that is a good point. It seems the check is indeed unnecessary then, let me remove it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't forget to update the PR description and let's leave some comments here why spreading nulls is safe in <=> outer joins.
I wonder if left anti join could also benefit from the feature.

Copy link
Copy Markdown
Member Author

@sunchao sunchao May 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the PR description.

As for left anti join, yes, ordinary shuffled left anti equi-joins with = could likely benefit for the same reason as outer joins: preserved left-side rows with NULL join keys are guaranteed not to match, so concentrating them on one reducer is unnecessary. I kept this PR scoped to outer joins for now to avoid broadening the change.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Sure, handling outer joins in this PR is a nice improvement.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gate opts in based on join type alone, ignoring whether the shuffle keys are actually nullable. For an outer join on non-nullable keys (e.g. f.k = d.k where both k are NOT NULL — common after a NOT NULL filter or on schema-non-null columns), the new path:

  1. Adds a per-row joinKeys.anyNull() check in ShuffleExchangeExec.getPartitionKeyExtractor that always returns false.
  2. Produces NullAwareHashPartitioning as the join's output partitioning, which doesn't satisfy ordinary ClusteredDistribution. The AdaptiveQueryExecSuite diff in this PR (optimizeOutRepartition = false cases around lines 2079-2127) shows the cost — a downstream df.repartition($"b") is no longer collapsed even though the underlying NULL-skew problem can never have existed.

Two options worth considering:

  • Gate also on leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable) so a non-nullable-key outer join falls back to plain HashPartitioning.
  • If the simpler shape is preferred, add a sentence to the lazy val's comment explicitly calling out the trade-off (skew reduction vs. potentially unnecessary downstream re-shuffle / lost optimizeOutRepartition) so future readers don't read it as an oversight.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Updated

val isOuterJoin = joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter
val canSpread = isOuterJoin &&
!containsNullSafeJoinMarker(leftKeys) &&
!containsNullSafeJoinMarker(rightKeys)
if (canSpread) {
logWarning(s"Using null-aware shuffle distribution for $joinType equi-join keys.")
}
canSpread
}
Comment on lines +32 to +48
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two improvements on this gate:

(1) Static nullability check. Outer joins on non-nullable keys (PK/FK / NOT-NULL columns / post-IsNotNull filtered keys) gain nothing from the null-aware path but still pay both the runtime per-row anyNull check and the downstream re-shuffle cost from outputPartitioning no longer satisfying strict ClusteredDistribution. The analyzer already tracks Expression.nullable — use it here to make the mechanism a no-op when there's no NULL to spread.

(2) Reframe the comment around the structural reason. The current comment only addresses the <=> corner. The real "why this PR exists" story is the preserved-side / pushdown-asymmetry argument — worth leading with that, with the <=> and NullType notes as a corollary.

Suggested change
private lazy val canSpreadNullJoinKeys: Boolean = {
// Null-safe equality usually rewrites to non-null shuffle keys. The NullType corner can still
// produce NULL shuffle keys, but shuffled join execution already treats those rows as
// unmatched, so spreading them does not change the result.
val isOuterJoin = joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter
conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) &&
isOuterJoin
}
private lazy val canSpreadNullJoinKeys: Boolean = {
// NULL keys on the preserved side of an outer join must be emitted but can never
// satisfy `a.k = b.k` under three-valued logic, so their reducer placement is a
// pure layout choice. Inner joins don't have this problem because
// InferFiltersFromConstraints pushes IsNotNull(key) to both sides; for outer joins
// that pushdown is blocked on the preserved side(s) -- which is exactly where
// NULL-key skew can land.
//
// For null-safe equality (`<=>`), ExtractEquiJoinKeys rewrites to (coalesce, isNull)
// shuffle keys, which are non-null for any concrete type. The NullType corner can
// still produce NULL shuffle keys, but shuffled join execution already treats those
// rows as unmatched, so spreading them does not change the result.
val isOuterJoin = joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter
conf.getConf(SQLConf.SHUFFLE_SPREAD_NULL_JOIN_KEYS_ENABLED) &&
isOuterJoin &&
(leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable))
}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated


override def nodeName: String = {
if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName
}
Expand All @@ -39,6 +54,9 @@ trait ShuffledJoin extends JoinCodegenSupport {
// We re-arrange the shuffle partitions to deal with skew join, and the new children
// partitioning doesn't satisfy `ClusteredDistribution`.
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
} else if (canSpreadNullJoinKeys) {
ClusteredDistribution(leftKeys, allowNullKeySpreading = true) ::
Comment thread
peter-toth marked this conversation as resolved.
ClusteredDistribution(rightKeys, allowNullKeySpreading = true) :: Nil
} else {
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Condition : ((isnotnull(cs_warehouse_sk#1) AND isnotnull(cs_item_sk#2)) AND migh

(4) Exchange
Input [5]: [cs_warehouse_sk#1, cs_item_sk#2, cs_order_number#3, cs_sales_price#4, cs_sold_date_sk#5]
Arguments: hashpartitioning(cs_order_number#3, cs_item_sk#2, 5), ENSURE_REQUIREMENTS, [plan_id=2]
Arguments: nullawarehashpartitioning(cs_order_number#3, cs_item_sk#2, 5), ENSURE_REQUIREMENTS, [plan_id=2]

(5) Sort [codegen id : 2]
Input [5]: [cs_warehouse_sk#1, cs_item_sk#2, cs_order_number#3, cs_sales_price#4, cs_sold_date_sk#5]
Expand All @@ -77,7 +77,7 @@ Input [4]: [cr_item_sk#8, cr_order_number#9, cr_refunded_cash#10, cr_returned_da

(10) Exchange
Input [3]: [cr_item_sk#8, cr_order_number#9, cr_refunded_cash#10]
Arguments: hashpartitioning(cr_order_number#9, cr_item_sk#8, 5), ENSURE_REQUIREMENTS, [plan_id=3]
Arguments: nullawarehashpartitioning(cr_order_number#9, cr_item_sk#8, 5), ENSURE_REQUIREMENTS, [plan_id=3]

(11) Sort [codegen id : 4]
Input [3]: [cr_item_sk#8, cr_order_number#9, cr_refunded_cash#10]
Expand Down
Loading