diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ddfe80443d561..22cd7fda2414c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -2564,11 +2564,29 @@ object DecimalAggregates extends Rule[LogicalPlan] { /** Maximum number of decimal digits representable precisely in a Double */ private val MAX_DOUBLE_DIGITS = 15 + /** Tighter than the AVG fast path's `prec + 4 <= MAX_DOUBLE_DIGITS` (= 11): + * the strict-subset keeps SPARK-37024 Double-regime exposure unchanged. */ + private val AVG_PEEL_MAX_INNER_PRECISION = 7 + + /** Matches a scale-preserving widening decimal Cast; refuses CheckOverflow + * to preserve overflow semantics on the unscaled value. */ + private object WidenedDecimalChild { + def unapply(e: Expression): Option[(Expression, Int, Int, Int)] = e match { + case Cast(inner @ DecimalExpression(p, s), DecimalType.Fixed(pPrime, sPrime), _, _) + if s == sPrime && pPrime >= p && !inner.isInstanceOf[CheckOverflow] => + Some((inner, p, pPrime, s)) + case _ => None + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsAnyPattern(SUM, AVERAGE), ruleId) { case q: LogicalPlan => q.transformExpressionsDownWithPruning( _.containsAnyPattern(SUM, AVERAGE), ruleId) { case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match { + // Window arm: `ExtractWindowExpressions` hoists composite children + // (here the widening Cast) into a child Project, so widened-Cast + // peel is unreachable from this expression-level rule. 
case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), prec + 10, scale) @@ -2583,9 +2601,27 @@ object DecimalAggregates extends Rule[LogicalPlan] { case _ => we } case ae @ AggregateExpression(af, _, _, _, _) => af match { + case Sum(WidenedDecimalChild(inner, p, pPrime, s), _) + if p + 10 <= MAX_LONG_DIGITS => + Cast( + MakeDecimal( + ae.copy(aggregateFunction = Sum(UnscaledValue(inner))), + p + 10, s), + DecimalType.bounded(pPrime + 10, s), + Option(conf.sessionLocalTimeZone)) + case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS => MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) + // Ordered before the un-widened Average arm: when pPrime in [8, 11], + // the outer Cast's DecimalType would otherwise match that arm first. + case Average(WidenedDecimalChild(inner, p, pPrime, s), _) + if p <= AVG_PEEL_MAX_INNER_PRECISION => + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(inner))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, s), DoubleType)), + DecimalType.bounded(pPrime + 4, s + 4), Option(conf.sessionLocalTimeZone)) + case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS => val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index 25adbce143fb9..bc5f8b984f5ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.catalyst.optimizer +import org.scalacheck.Gen +import 
org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks + import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{Decimal, DecimalType, DoubleType, LongType} -class DecimalAggregatesSuite extends PlanTest { +class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyChecks { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), @@ -68,6 +72,115 @@ class DecimalAggregatesSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + val testRelationC = LocalRelation($"c".decimal(7, 2)) + + test("Decimal Average Aggregation widened-cast peel: Optimized (p=7, p'=12)") { + val widened = $"c".cast(DecimalType(12, 2)) + val originalQuery = testRelationC.select(avg(widened)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelationC + .select((avg(UnscaledValue($"c")) / 100.0).cast(DecimalType(16, 6)) + .as("avg(CAST(c AS DECIMAL(12,2)))")).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average Aggregation widened-cast peel: Not Optimized (narrowing cast)") { + val testRelationD = LocalRelation($"d".decimal(10, 2)) + val narrowed = $"d".cast(DecimalType(8, 2)) + val originalQuery = testRelationD.select(avg(narrowed)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelationD + .select((avg(UnscaledValue(narrowed)) / 100.0).cast(DecimalType(12, 6)) + .as("avg(CAST(d AS DECIMAL(8,2)))")).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average 
Aggregation widened-cast peel: Not Optimized (scale change)") { + val testRelationD = LocalRelation($"d".decimal(7, 2)) + val rescaled = $"d".cast(DecimalType(12, 4)) + val originalQuery = testRelationD.select(avg(rescaled)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + + test("Decimal Average Aggregation widened-cast peel: Not Optimized (boundary p=8)") { + val testRelationE = LocalRelation($"e".decimal(8, 2)) + val widened = $"e".cast(DecimalType(13, 2)) + val originalQuery = testRelationE.select(avg(widened)) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + + // SPARK-56627 F2 regression: with `pPrime in [8, 11]`, the outer Cast's + // dataType `Decimal(pPrime, s)` would let the un-widened existing + // `Average(DecimalExpression)` arm match first via `prec + 4 <= MAX_DOUBLE_DIGITS` + // (= pPrime <= 11). New AVG peel arm must be ordered before to win this band + // and rewrite via the inner `p`-based UnscaledValue path. + test("Decimal Average Aggregation widened-cast peel: " + + "Optimized for pPrime band [8, 11] (must beat existing AVG fast-path arm)") { + val testRelationE = LocalRelation($"e".decimal(7, 2)) + val widened = $"e".cast(DecimalType(10, 2)) + val originalQuery = testRelationE.select(avg(widened).as("avg_widened")) + val optimized = Optimize.execute(originalQuery.analyze) + // Expected: peeled via WidenedDecimalChild(inner=e, p=7, pPrime=10, s=2), + // outer Cast bounded(pPrime+4=14, s+4=6). NOT the existing-arm form + // `Cast(Divide(Average(UnscaledValue(cast(e as dec(10,2)))), 100.0), dec(14,6))`, + // which keeps the widened cast under UnscaledValue and so would lose F2's + // intent of avoiding the widened-cast intermediate.
+ val correctAnswer = testRelationE + .select( + Cast( + Divide( + avg(UnscaledValue($"e")), + Literal.create(math.pow(10.0, 2), DoubleType)), + DecimalType.bounded(14, 6), + Option(conf.sessionLocalTimeZone)) + .as("avg_widened")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + // SPARK-56627 F1 regression: `WidenedDecimalChild` must NOT peel when the + // inner expression is a `CheckOverflow` (introduced by `DecimalPrecision` + // analyzer for nullOnOverflow semantics). Peeling through `CheckOverflow` + // would change the overflow behavior of the inner aggregate. + // + // The existing un-widened `Average(DecimalExpression)` arm still fires on + // the outer Cast (dataType `Decimal(pPrime=10, s=2)`, `pPrime + 4 = 14 <= 15`), + // so the optimized plan wraps `UnscaledValue` around the OUTER cast (not + // the inner CheckOverflow). The peel-arm-fired form would instead be + // `UnscaledValue(CheckOverflow(e))` (no outer cast), which we want to AVOID. + test("Decimal Average Aggregation widened-cast peel: " + + "Not peeled for Cast(CheckOverflow(inner), wider) form (F1 guard)") { + val testRelationE = LocalRelation($"e".decimal(7, 2)) + val co = CheckOverflow($"e", DecimalType(7, 2), nullOnOverflow = true) + val widened = Cast(co, DecimalType(10, 2)) + val originalQuery = testRelationE.select(avg(widened).as("avg_co")) + val optimized = Optimize.execute(originalQuery.analyze) + + // Existing un-widened AVG arm fires on the outer Cast (pPrime=10, + // pPrime + 4 = 14 <= 15), wrapping UnscaledValue around the OUTER cast. 
+ val correctAnswer = testRelationE + .select( + Cast( + Divide( + avg(UnscaledValue(widened)), + Literal.create(math.pow(10.0, 2), DoubleType)), + DecimalType(14, 6), + Option(conf.sessionLocalTimeZone)) + .as("avg_co")) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("Decimal Sum Aggregation over Window: Optimized") { val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame) val originalQuery = testRelation.select(windowExpr(sum($"a"), spec).as("sum_a")) @@ -120,4 +233,353 @@ class DecimalAggregatesSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + // --------------------------------------------------------------------------- + // Widened-Cast Peel (SUM, plus AVG schema/boundary guards) -- SPARK-56627 + // + // Extractor `WidenedDecimalChild` recognises a scale-preserving widening + // Cast, enabling the existing fast path to fire on `SUM(CAST(x, wider))` + // patterns that previously fell off the p+10 <= MAX_LONG_DIGITS guard. + // + // These tests assert behavioural plan-shape invariants via the local + // `Optimize` RuleExecutor (runs only DecimalAggregates). Literal-no-peel + // is covered separately via SimpleTestOptimizer because the local + // RuleExecutor here does not run ConstantFolding. + // --------------------------------------------------------------------------- + + private val widenRel = LocalRelation( + $"d7_2".decimal(7, 2), + $"d8_2".decimal(8, 2), + $"d9_2".decimal(9, 2), + $"d17_2".decimal(17, 2), + $"i".int) + + test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) peels via widened-Cast fast path") { + // Witness chosen so p+10=17 <= MAX_LONG_DIGITS(18) < pPrime+10=27 -- the + // new case fires (a bare-Cast inner cannot fall through to the existing + // DecimalExpression case).
Expected shape: + // Cast(MakeDecimal(Sum(UnscaledValue(d7_2)), p+10=17, s=2), + // DecimalType.bounded(pPrime+10=27, s=2), + // Option(conf.sessionLocalTimeZone)) + val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = widenRel + .select(Cast( + MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2), + DecimalType.bounded(27, 2), + Option(conf.sessionLocalTimeZone)) + .as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) -- peel preserves schema") { + // Schema invariance via DataType equality (not string). + // Top-level output type of SUM(dec(p,s)) is DecimalType(min(p+10,38), s); + // peeled tree wraps inner with outer Cast(_, dec(pPrime+10,s)) = dec(27,2) + // -- identical to baseline schema. + val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))) + val baselineSchema = q.analyze.schema + val optimized = Optimize.execute(q.analyze) + assert(optimized.schema === baselineSchema, + s"peel changed schema: $baselineSchema -> ${optimized.schema}") + } + + test("SPARK-56627: SUM(CAST(int AS dec(10,0))) does NOT peel (non-decimal inner)") { + val q = widenRel.select(sum($"i".cast(DecimalType(10, 0)))) + val optimized = Optimize.execute(q.analyze) + // Peel must NOT fire; plan shape == input analyze. 
+ val correctAnswer = q.analyze + + comparePlans(optimized, correctAnswer) + } + + test("SPARK-56627: AVG(CAST(dec(7,2) AS dec(17,2))) -- peel preserves schema") { + val q = widenRel.select(avg($"d7_2".cast(DecimalType(17, 2)))) + val baselineSchema = q.analyze.schema + val optimized = Optimize.execute(q.analyze) + assert(optimized.schema === baselineSchema, + s"peel changed schema: $baselineSchema -> ${optimized.schema}") + } + + test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(18,6))) does NOT peel (scale change)") { + val q = widenRel.select(sum($"d7_2".cast(DecimalType(18, 6)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = q.analyze + + comparePlans(optimized, correctAnswer) + } + + test("SPARK-56627: SUM(CAST(dec(17,2) AS dec(10,2))) does NOT peel (narrowing)") { + val q = widenRel.select(sum($"d17_2".cast(DecimalType(10, 2)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = q.analyze + + comparePlans(optimized, correctAnswer) + } + + test("SPARK-56627: SUM(CheckOverflow(Cast(...))) does NOT peel") { + val co = CheckOverflow( + $"d7_2".cast(DecimalType(17, 2)), DecimalType(17, 2), nullOnOverflow = true) + val q = widenRel.select(sum(co).as("s")) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = q.analyze + + comparePlans(optimized, correctAnswer) + } + + // Pre-existing fast-path regression guard. + // Witness: SUM(d7_2), no Cast. p+10 = 17 <= MAX_LONG_DIGITS(18) hits the + // existing `Sum(e @ DecimalExpression(p, s))` case. The new peel case + // prepended must NOT shadow the existing fast path on no-cast inputs. + test("SPARK-56627: SUM(dec(7,2)) hits existing DecimalExpression fast path") { + val expected = widenRel + .select(MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2).as("sum(d7_2)")).analyze + val q = widenRel.select(sum($"d7_2")) + val optimized = Optimize.execute(q.analyze) + comparePlans(optimized, expected) + } + + // Literal-in-Cast no-peel regression guard. 
+ // + // Uses `SimpleTestOptimizer` (full optimizer batches) rather than the local + // `Optimize` RuleExecutor defined above, because this test depends on + // `ConstantFolding` running before `DecimalAggregates`: the outer Cast on a + // foldable Literal child is folded away before the peel rule ever sees it, + // so there is no Cast left to peel. Post-optimization the plan contains + // neither `MakeDecimal` nor an `UnscaledValue` call -- SUM sees a bare + // `Literal(dec(17,2))` whose precision (17) already fails the existing + // `prec + 10 <= MAX_LONG_DIGITS` guard (27 > 18), so the whole SUM arm is + // a no-op. The assertion is deliberately absence-of-peel-shape rather than + // structural equality, to survive unrelated ConstantFolding changes. + test("SPARK-56627: SUM(CAST(Literal(dec(7,2)) AS dec(17,2))) does NOT peel " + + "after ConstantFolding") { + val lit = Literal.create(Decimal("1.23"), DecimalType(7, 2)) + val q = widenRel.select(sum(lit.cast(DecimalType(17, 2)))) + val optimized = SimpleTestOptimizer.execute(q.analyze) + val hasMakeDecimal = optimized.expressions.exists(_.exists { + case _: MakeDecimal => true + case _ => false + }) + val hasUnscaledValue = optimized.expressions.exists(_.exists { + case _: UnscaledValue => true + case _ => false + }) + assert(!hasMakeDecimal, + s"peel unexpectedly fired on a folded Literal child; plan:\n$optimized") + assert(!hasUnscaledValue, + s"UnscaledValue unexpectedly present on folded Literal child; plan:\n$optimized") + } + + // Plan-shape invariant guards (null / empty-relation witnesses). + // + // DecimalAggregatesSuite is a PlanTest without a SparkSession; the local + // `Optimize` RuleExecutor runs DecimalAggregates only. At plan level, an + // all-null Literal-typed column shares the extractor path of any other + // DecimalExpression, and an empty LocalRelation shares the shape of the + // non-empty widenRel. 
These two witnesses assert the peel rule fires + // identically to the canonical witness under both inputs -- rule body is + // data-independent. End-to-end null-propagation semantics are covered + // separately in the sql-core equivalence suite. + + test("SPARK-56627: SUM(CAST(Literal(null, dec(7,2)) AS dec(17,2))) peels " + + "(null Literal in Cast, plan-shape invariant)") { + val nullLit = Literal.create(null, DecimalType(7, 2)) + val q = widenRel.select(sum(nullLit.cast(DecimalType(17, 2)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = widenRel + .select(Cast( + MakeDecimal(sum(UnscaledValue(nullLit)), 17, 2), + DecimalType.bounded(27, 2), + Option(conf.sessionLocalTimeZone)) + .as("sum(CAST(NULL AS DECIMAL(17,2)))")).analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-56627: SUM(CAST(dec(7,2) AS dec(17,2))) on empty LocalRelation peels " + + "(empty-relation plan-shape invariant)") { + val emptyRel = LocalRelation($"d7_2".decimal(7, 2)) + val q = emptyRel.select(sum($"d7_2".cast(DecimalType(17, 2)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = emptyRel + .select(Cast( + MakeDecimal(sum(UnscaledValue($"d7_2")), 17, 2), + DecimalType.bounded(27, 2), + Option(conf.sessionLocalTimeZone)) + .as("sum(CAST(d7_2 AS DECIMAL(17,2)))")).analyze + comparePlans(optimized, correctAnswer) + } + + // Idempotence invariant guard. + // + // Post-peel, the `Sum` child is `UnscaledValue(DecimalExpression)` which + // types to `LongType`, so the `WidenedDecimalChild` extractor (which only + // matches a scale-preserving widening `Cast` over a `DecimalExpression`) + // cannot re-match on the second pass. Use `canonicalized` (not `==`) to + // neutralise `exprId` drift across `Sum` aggregate-expression allocation + // in successive rule applications.
+ test("SPARK-56627: DecimalAggregates is idempotent on canonical widened witness " + + "(peel(peel(t)) == peel(t) under canonicalization)") { + val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))).analyze + val once = DecimalAggregates(q) + val twice = DecimalAggregates(DecimalAggregates(q)) + assert(once.canonicalized == twice.canonicalized, + s"DecimalAggregates re-fired on already-peeled plan.\n" + + s"once:\n$once\ntwice:\n$twice") + } + + // RuleExecutor convergence: drive DecimalAggregates inside a fixed-point + // RuleExecutor batch and assert it converges in <= 1 application after the + // first peel. Catches accidental rewrite oscillations in fixed-point batches. + test("SPARK-56627: DecimalAggregates converges under RuleExecutor on widened SUM") { + object Once extends RuleExecutor[LogicalPlan] { + val batches: Seq[Batch] = + Seq(Batch("DecimalAggregates", FixedPoint(10), DecimalAggregates)) + } + val q = widenRel.select(sum($"d7_2".cast(DecimalType(17, 2)))).analyze + val once = DecimalAggregates(q) + val converged = Once.execute(q) + assert(once.canonicalized == converged.canonicalized, + s"FixedPoint did not converge to single peel.\n" + + s"once:\n$once\nconverged:\n$converged") + } + + // Negative guard-miss: at p=9, the inner decimal already exceeds the + // existing DecimalExpression fast path (p+10=19 > MAX_LONG_DIGITS=18) so + // the peel rewrite must NOT fire. Pin via plan-equality against analyzed. + test("SPARK-56627: SUM(CAST(dec(9,2) AS dec(19,2))) does NOT peel (p=9 guard)") { + val rel = LocalRelation($"d9_2".decimal(9, 2)) + val q = rel.select(sum($"d9_2".cast(DecimalType(19, 2)))).analyze + val optimized = Optimize.execute(q) + comparePlans(optimized, q) + } + + // Plan-shape property: structural invariants on the peeled tree. 
+ // + // Sweeps the (p, p', s) lattice where the widened-cast peel fires. The + // generators below only produce p in [1, 8] and p' in [9, 28], so every + // sample satisfies p + 10 <= 18 < p' + 10 <= 38 (new arm on, old fast-path off). + // Assertions (peel-on, structural -- NOT a hand-typed RHS clone): + // - the plan contains exactly one Sum, whose child has + // dataType=LongType (i.e. UnscaledValue was inserted) + // - the plan contains exactly one MakeDecimal + // - the plan contains at least one Cast to a DecimalType, and the + // widest such Cast has precision >= p' + // The exact outer-Cast type (precision p' + 10, scale s) is pinned by + // the hand-written plan-equality tests above, not by this property. + // Reframed away from RHS-equality to detect behavioural regressions + // rather than just refactor drift. + // There is no peel-off branch here: every generated input is + // peel-eligible; no-peel shapes are covered by dedicated tests. + + private case class PeelInputs(p: Int, pPrime: Int, s: Int) + + private val peelGen: Gen[PeelInputs] = Gen.frequency( + 5 -> (for { + p <- Gen.choose(1, 8) + pPrime <- Gen.choose(math.max(p + 1, 9), 28) + s <- Gen.choose(0, p) + } yield PeelInputs(p, pPrime, s)), + 5 -> (for { + p <- Gen.choose(1, 8) + pPrime <- Gen.choose(9, 28) + s <- Gen.choose(0, p) + } yield PeelInputs(p, pPrime, s)) + ) + + private val boundaryGen: Gen[PeelInputs] = Gen.oneOf( + PeelInputs(7, 17, 2), PeelInputs(7, 18, 2), PeelInputs(7, 19, 2)) + + private val peelSpaceGen: Gen[PeelInputs] = Gen.frequency( + 8 -> peelGen, + 2 -> boundaryGen + ).retryUntil(in => in.p + 10 <= 18 && in.p < in.pPrime && in.pPrime + 10 <= 38) + + implicit override val generatorDrivenConfig: PropertyCheckConfiguration = + PropertyCheckConfiguration(minSuccessful = 50, minSize = 0, sizeRange = 0) + + test("SPARK-56627: DecimalAggregates widened-Cast SUM peel -- plan-shape " + + "structural-invariants property") { + forAll(peelSpaceGen) { in => + val rel = LocalRelation($"x".decimal(in.p, in.s)) + val q =
rel.select(sum($"x".cast(DecimalType(in.pPrime, in.s)))) + val analyzed = q.analyze + + val optimized = Optimize.execute(analyzed) + + // Structural invariants the peel rewrite must establish, regardless + // of incidental tree-shape changes from neighbouring rules: + // + // I1. exactly one Sum node, whose child has LongType (the peeled + // UnscaledValue feed); + // I2. exactly one MakeDecimal node in the tree (rebuilds Decimal + // from the LONG accumulator); + // I3. an outer Cast whose target DecimalType has precision at + // least as wide as the user-written widened cast, so we never + // narrow result precision below the baseline plan. + val sums = optimized.expressions.flatMap(_.collect { case s: Sum => s }) + assert(sums.size == 1, s"expected exactly 1 Sum, got ${sums.size} in $optimized") + assert(sums.head.child.dataType == LongType, + s"expected Sum.child: LongType, got ${sums.head.child.dataType} in $optimized") + + val mds = optimized.expressions.flatMap(_.collect { case m: MakeDecimal => m }) + assert(mds.size == 1, + s"expected exactly 1 MakeDecimal, got ${mds.size} in $optimized") + + val outerCasts = optimized.expressions.flatMap(_.collect { + case c @ Cast(_, _: DecimalType, _, _) => c + }) + assert(outerCasts.nonEmpty, + s"expected an outer Cast to DecimalType, got none in $optimized") + val outerPrec = outerCasts.map(_.dataType.asInstanceOf[DecimalType].precision).max + assert(outerPrec >= in.pPrime, + s"outer Cast precision $outerPrec < baseline ${in.pPrime} in $optimized") + } + } + + // --------------------------------------------------------------------------- + // F5 (skeptic round 1): Long-accumulator / Double-regime safety boundary + // invariant guards. 
+ // + // Background: a strict "overflow oracle" cannot be written at unit-test + // scale -- the existing fast-path guards (`p + 10 <= MAX_LONG_DIGITS = 18` + // for SUM, `AVG_PEEL_MAX_INNER_PRECISION = 7` for AVG) keep the peel-eligible + // inner-precision band so narrow that the Long accumulator (~9.22e18) cannot + // wrap on any reachable peel input: at `p=8` we'd need ~9.22e10 rows. So + // there is no production input that exercises a "peeled Long-wrap vs + // un-peeled CheckOverflow" asymmetry to oracle against. + // + // What we CAN lock is the boundary itself: if someone in the future relaxes + // either guard (raising `MAX_LONG_DIGITS - 10` for SUM, or + // `AVG_PEEL_MAX_INNER_PRECISION` for AVG), the input shapes below WOULD + // start peeling -- and the assertion that the rule is a no-op for these + // inputs would fail. That is the safety net we want: a mechanical guard + // that catches accidental widening of the peel-trigger surface. + test("SPARK-56627: SUM(CAST(dec(9,2) AS dec(19,2))) does NOT peel " + + "(Long-accumulator safety boundary)") { + // Boundary witness: inner p=9 makes widened-arm `p + 10 = 19 > 18` reject, + // AND outer-cast existing-arm `prec + 10 = 29 > 18` reject. Both arms are + // no-ops by design -- peel cannot fire on this shape today, and must not + // start firing if the inner-precision band is later widened without + // re-deriving the Long-accumulator bound. + val q = widenRel.select(sum($"d9_2".cast(DecimalType(19, 2)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = q.analyze + comparePlans(optimized, correctAnswer) + } + + test("SPARK-56627: AVG(CAST(dec(8,2) AS dec(20,2))) does NOT peel " + + "(Double-regime / SPARK-37024 safety boundary)") { + // Boundary witness: inner p=8 makes widened-AVG arm + // `p > AVG_PEEL_MAX_INNER_PRECISION (7)` reject, AND outer-cast existing + // AVG arm `prec + 4 = 24 > MAX_DOUBLE_DIGITS (15)` reject. 
The strict- + // subset guard `p <= 7` keeps this rule's trigger surface strictly + // inside the existing AVG fast path's surface, so SPARK-37024 + // (Double-regime silent precision loss) is not amplified. If someone + // raises `AVG_PEEL_MAX_INNER_PRECISION` past 7 without first fixing + // SPARK-37024, this test will start firing and flag the regression. + val q = widenRel.select(avg($"d8_2".cast(DecimalType(20, 2)))) + val optimized = Optimize.execute(q.analyze) + val correctAnswer = q.analyze + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..1186901b35753 --- /dev/null +++ b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt @@ -0,0 +1,74 @@ +================================================================================================ +DecimalAggregates SUM widened-cast peel (Aggregate) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2178 2236 56 4.6 217.8 1.0X +widened cast, peel off 2369 2381 9 4.2 236.9 0.9X +widened cast, peel on 2105 2118 12 4.8 210.5 1.0X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A2 p=7 s=2 p'=17: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2103 2115 17 4.8 210.3 1.0X +widened cast, peel off 2366 2377 
7 4.2 236.6 0.9X +widened cast, peel on 2100 2109 11 4.8 210.0 1.0X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2117 2138 29 4.7 211.7 1.0X +widened cast, peel off 2403 2416 13 4.2 240.3 0.9X +widened cast, peel on 2157 2164 7 4.6 215.7 1.0X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2151 2157 7 4.6 215.1 1.0X +widened cast, peel off 2420 2427 10 4.1 242.0 0.9X +widened cast, peel on 2152 2159 9 4.6 215.2 1.0X + + +================================================================================================ +DecimalAggregates AVG widened-cast peel (Aggregate) +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2130 2136 5 4.7 213.0 1.0X +widened cast, peel off 2358 2367 15 4.2 235.8 0.9X +widened cast, peel on 2140 2150 7 4.7 214.0 1.0X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B2 p=7 s=2 p'=12: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2147 2151 3 4.7 214.7 1.0X +widened cast, peel off 2359 2361 2 4.2 235.9 0.9X +widened cast, peel on 2126 2161 20 4.7 212.6 1.0X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2173 2185 9 4.6 217.3 1.0X +widened cast, peel off 2405 2413 7 4.2 240.5 0.9X +widened cast, peel on 2167 2177 12 4.6 216.7 1.0X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 2173 2179 7 4.6 217.3 1.0X +widened cast, peel off 2393 2400 11 4.2 239.3 0.9X +widened cast, peel on 2172 2178 5 4.6 217.2 1.0X + + diff --git a/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt new file mode 100644 index 0000000000000..60109cac85ec9 --- /dev/null +++ b/sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt @@ -0,0 +1,74 @@ +================================================================================================ +DecimalAggregates SUM widened-cast peel (Aggregate) +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1194 1230 57 8.4 119.4 1.0X +widened cast, peel off 1421 1433 11 7.0 142.1 0.8X +widened cast, peel on 1181 1188 5 8.5 118.1 1.0X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A2 p=7 s=2 p'=17: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1174 1189 12 8.5 117.4 1.0X +widened cast, peel off 1401 1414 8 7.1 140.1 0.8X +widened cast, peel on 1169 1178 8 8.6 116.9 1.0X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1245 1254 10 8.0 124.5 1.0X +widened cast, peel off 1498 1503 5 6.7 149.8 0.8X +widened cast, peel on 1222 1232 10 8.2 122.2 1.0X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +A4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1234 1238 3 8.1 123.4 1.0X +widened cast, peel off 1473 1478 7 6.8 147.3 0.8X +widened cast, peel on 1242 1255 16 8.1 124.2 1.0X + + +================================================================================================ +DecimalAggregates AVG widened-cast peel (Aggregate) +================================================================================================ + +OpenJDK 64-Bit Server VM 
25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1178 1185 9 8.5 117.8 1.0X +widened cast, peel off 1434 1440 8 7.0 143.4 0.8X +widened cast, peel on 1232 1235 3 8.1 123.2 1.0X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B2 p=7 s=2 p'=12: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1222 1229 7 8.2 122.2 1.0X +widened cast, peel off 1434 1444 10 7.0 143.4 0.9X +widened cast, peel on 1216 1223 6 8.2 121.6 1.0X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1267 1274 6 7.9 126.7 1.0X +widened cast, peel off 1505 1509 4 6.6 150.5 0.8X +widened cast, peel on 1272 1277 7 7.9 127.2 1.0X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1013-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz +B4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 1269 1275 5 7.9 126.9 1.0X +widened cast, peel off 1494 1501 9 6.7 149.4 0.8X +widened cast, peel on 1268 1274 6 7.9 126.8 1.0X + + diff --git a/sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt 
b/sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt new file mode 100644 index 0000000000000..d9c2c9662826a --- /dev/null +++ b/sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt @@ -0,0 +1,74 @@ +================================================================================================ +DecimalAggregates SUM widened-cast peel (Aggregate) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +A1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3068 3095 35 3.3 306.8 1.0X +widened cast, peel off 3396 3410 19 2.9 339.6 0.9X +widened cast, peel on 3107 3115 10 3.2 310.7 1.0X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +A2 p=7 s=2 p'=17: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3104 3120 23 3.2 310.4 1.0X +widened cast, peel off 3386 3407 27 3.0 338.6 0.9X +widened cast, peel on 3094 3106 17 3.2 309.4 1.0X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +A3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3039 3053 21 3.3 303.9 1.0X +widened cast, peel off 3336 3340 5 3.0 333.6 0.9X +widened cast, peel on 3034 3048 14 3.3 303.4 1.0X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +A4 p=5 s=0 p'=15: Best 
Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3037 3049 16 3.3 303.7 1.0X +widened cast, peel off 3324 3340 16 3.0 332.4 0.9X +widened cast, peel on 3027 3031 4 3.3 302.7 1.0X + + +================================================================================================ +DecimalAggregates AVG widened-cast peel (Aggregate) +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +B1 p=7 s=2 p'=8: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3038 3041 2 3.3 303.8 1.0X +widened cast, peel off 3274 3283 18 3.1 327.4 0.9X +widened cast, peel on 3056 3074 15 3.3 305.6 1.0X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +B2 p=7 s=2 p'=12: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3029 3033 3 3.3 302.9 1.0X +widened cast, peel off 3288 3291 2 3.0 328.8 0.9X +widened cast, peel on 3031 3036 6 3.3 303.1 1.0X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +B3 p=5 s=0 p'=6: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3022 3030 5 3.3 302.2 1.0X +widened cast, peel off 3275 3307 28 3.1 327.5 0.9X +widened cast, peel on 3025 3028 3 3.3 302.5 1.0X + 
+OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1013-azure +AMD EPYC 9V74 80-Core Processor +B4 p=5 s=0 p'=15: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +native (no cast, rule on) 3024 3039 21 3.3 302.4 1.0X +widened cast, peel off 3279 3298 17 3.1 327.9 0.9X +widened cast, peel on 3016 3023 6 3.3 301.6 1.0X + + diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt index f7c0dcd7c56b6..ff0b0e468530e 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/explain.txt @@ -257,7 +257,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, (46) HashAggregate [codegen id : 13] Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#22, i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(cast(cs_list_price#5 as decimal(12,2))), partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))), partial_avg(cast(cs_sales_price#6 as decimal(12,2))), partial_avg(cast(cs_net_profit#8 as decimal(12,2))), partial_avg(cast(c_birth_year#22 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] +Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(UnscaledValue(cs_list_price#5)), partial_avg(UnscaledValue(cs_coupon_amt#7)), partial_avg(UnscaledValue(cs_sales_price#6)), partial_avg(UnscaledValue(cs_net_profit#8)), 
partial_avg(cast(c_birth_year#22 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37, count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45, count#46] Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] @@ -268,9 +268,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29, ca_state#30, ca_county# (48) HashAggregate [codegen id : 14] Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))), avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#22 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] -Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64, avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#22 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] -Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70, avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71, avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72, 
avg(cast(c_birth_year#22 as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] +Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)), avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)), avg(cast(c_birth_year#22 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] +Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63, avg(UnscaledValue(cs_sales_price#6))#64, avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#22 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] +Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 / 100.0) as decimal(16,6)) AS agg4#71, cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS agg5#72, avg(cast(c_birth_year#22 as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] (49) TakeOrderedAndProject Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68, agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt index 276165729be54..079bb6aba3ec8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18.sf100/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject 
[ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7] WholeStageCodegen (14) - HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as decimal(12,2))),avg(cast(cs_coupon_amt as decimal(12,2))),avg(cast(cs_sales_price as decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] + HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] InputAdapter Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1 WholeStageCodegen (13) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt index 7db1c87c52a6a..8f25c83767ffc 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/explain.txt @@ -227,7 +227,7 @@ Arguments: [[cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, (40) HashAggregate [codegen id : 7] Input [12]: [cs_quantity#4, cs_list_price#5, cs_sales_price#6, cs_coupon_amt#7, cs_net_profit#8, cd_dep_count#14, c_birth_year#19, i_item_id#28, ca_country#29, ca_state#30, ca_county#31, 
spark_grouping_id#32] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(cast(cs_list_price#5 as decimal(12,2))), partial_avg(cast(cs_coupon_amt#7 as decimal(12,2))), partial_avg(cast(cs_sales_price#6 as decimal(12,2))), partial_avg(cast(cs_net_profit#8 as decimal(12,2))), partial_avg(cast(c_birth_year#19 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] +Functions [7]: [partial_avg(cast(cs_quantity#4 as decimal(12,2))), partial_avg(UnscaledValue(cs_list_price#5)), partial_avg(UnscaledValue(cs_coupon_amt#7)), partial_avg(UnscaledValue(cs_sales_price#6)), partial_avg(UnscaledValue(cs_net_profit#8)), partial_avg(cast(c_birth_year#19 as decimal(12,2))), partial_avg(cast(cd_dep_count#14 as decimal(12,2)))] Aggregate Attributes [14]: [sum#33, count#34, sum#35, count#36, sum#37, count#38, sum#39, count#40, sum#41, count#42, sum#43, count#44, sum#45, count#46] Results [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] @@ -238,9 +238,9 @@ Arguments: hashpartitioning(i_item_id#28, ca_country#29, ca_state#30, ca_county# (42) HashAggregate [codegen id : 8] Input [19]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32, sum#47, count#48, sum#49, count#50, sum#51, count#52, sum#53, count#54, sum#55, count#56, sum#57, count#58, sum#59, count#60] Keys [5]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, spark_grouping_id#32] -Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(cast(cs_list_price#5 as decimal(12,2))), avg(cast(cs_coupon_amt#7 as decimal(12,2))), avg(cast(cs_sales_price#6 as decimal(12,2))), avg(cast(cs_net_profit#8 as decimal(12,2))), avg(cast(c_birth_year#19 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] 
-Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(cast(cs_list_price#5 as decimal(12,2)))#62, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63, avg(cast(cs_sales_price#6 as decimal(12,2)))#64, avg(cast(cs_net_profit#8 as decimal(12,2)))#65, avg(cast(c_birth_year#19 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] -Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, avg(cast(cs_list_price#5 as decimal(12,2)))#62 AS agg2#69, avg(cast(cs_coupon_amt#7 as decimal(12,2)))#63 AS agg3#70, avg(cast(cs_sales_price#6 as decimal(12,2)))#64 AS agg4#71, avg(cast(cs_net_profit#8 as decimal(12,2)))#65 AS agg5#72, avg(cast(c_birth_year#19 as decimal(12,2)))#66 AS agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] +Functions [7]: [avg(cast(cs_quantity#4 as decimal(12,2))), avg(UnscaledValue(cs_list_price#5)), avg(UnscaledValue(cs_coupon_amt#7)), avg(UnscaledValue(cs_sales_price#6)), avg(UnscaledValue(cs_net_profit#8)), avg(cast(c_birth_year#19 as decimal(12,2))), avg(cast(cd_dep_count#14 as decimal(12,2)))] +Aggregate Attributes [7]: [avg(cast(cs_quantity#4 as decimal(12,2)))#61, avg(UnscaledValue(cs_list_price#5))#62, avg(UnscaledValue(cs_coupon_amt#7))#63, avg(UnscaledValue(cs_sales_price#6))#64, avg(UnscaledValue(cs_net_profit#8))#65, avg(cast(c_birth_year#19 as decimal(12,2)))#66, avg(cast(cd_dep_count#14 as decimal(12,2)))#67] +Results [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, avg(cast(cs_quantity#4 as decimal(12,2)))#61 AS agg1#68, cast((avg(UnscaledValue(cs_list_price#5))#62 / 100.0) as decimal(16,6)) AS agg2#69, cast((avg(UnscaledValue(cs_coupon_amt#7))#63 / 100.0) as decimal(16,6)) AS agg3#70, cast((avg(UnscaledValue(cs_sales_price#6))#64 / 100.0) as decimal(16,6)) AS agg4#71, cast((avg(UnscaledValue(cs_net_profit#8))#65 / 100.0) as decimal(16,6)) AS agg5#72, avg(cast(c_birth_year#19 as decimal(12,2)))#66 AS 
agg6#73, avg(cast(cd_dep_count#14 as decimal(12,2)))#67 AS agg7#74] (43) TakeOrderedAndProject Input [11]: [i_item_id#28, ca_country#29, ca_state#30, ca_county#31, agg1#68, agg2#69, agg3#70, agg4#71, agg5#72, agg6#73, agg7#74] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt index 269bfd3f44fcb..7c3075e26fa23 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q18/simplified.txt @@ -1,6 +1,6 @@ TakeOrderedAndProject [ca_country,ca_state,ca_county,i_item_id,agg1,agg2,agg3,agg4,agg5,agg6,agg7] WholeStageCodegen (8) - HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(cast(cs_list_price as decimal(12,2))),avg(cast(cs_coupon_amt as decimal(12,2))),avg(cast(cs_sales_price as decimal(12,2))),avg(cast(cs_net_profit as decimal(12,2))),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] + HashAggregate [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] [avg(cast(cs_quantity as decimal(12,2))),avg(UnscaledValue(cs_list_price)),avg(UnscaledValue(cs_coupon_amt)),avg(UnscaledValue(cs_sales_price)),avg(UnscaledValue(cs_net_profit)),avg(cast(c_birth_year as decimal(12,2))),avg(cast(cd_dep_count as decimal(12,2))),agg1,agg2,agg3,agg4,agg5,agg6,agg7,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count,sum,count] InputAdapter Exchange [i_item_id,ca_country,ca_state,ca_county,spark_grouping_id] #1 WholeStageCodegen (7) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 3c99c975977a2..180dd5d5db949 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -22,7 +22,9 @@ import java.util.Locale import scala.util.Random +import org.scalacheck.Gen import org.scalatest.matchers.must.Matchers.the +import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.plans.logical.Expand @@ -47,7 +49,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub @SlowSQLTest class DataFrameAggregateSuite extends SharedSparkSession - with AdaptiveSparkPlanHelper { + with AdaptiveSparkPlanHelper + with ScalaCheckDrivenPropertyChecks { import testImplicits._ val absTol = 1e-8 @@ -4816,6 +4819,193 @@ class DataFrameAggregateSuite extends SharedSparkSession assert(estimate != null) assert(estimate.asInstanceOf[Double] == 2.0) } + + // Numerical-equivalence property (sql-core layer). + // + // Sweeps the (p, p', s, n) lattice where the widened-cast peel fires, + // asserting that SUM(CAST(x AS DECIMAL(p', s))) is numerically equal + // (BigDecimal.compareTo == 0) to a java.math.BigDecimal reference sum of + // the same inputs. Domain is restricted to the non-overflow regime so the + // peeled LONG accumulator cannot wrap. + // + // Non-overflow bound: with |unscaled(x)| < 10^p, p <= 8, n <= 1000, + // worst-case accumulator is 1000 * (10^8 - 1) < 10^12 << 2^63. + // + // A wide-target-precision fixed witness (p=8, p'=30, s=2) is exercised + // below as a unit case to guarantee a hand-enumerated boundary even if the + // property generator shrinks.
+ + private case class PeelDomain(p: Int, pPrime: Int, s: Int) + + private val peelDomainGen: Gen[PeelDomain] = (for { + p <- Gen.choose(1, 8) + pPrime <- Gen.choose(math.max(p + 1, 9), 28) + s <- Gen.choose(0, p) + } yield PeelDomain(p, pPrime, s)) + .retryUntil(d => d.p + 10 <= 18 && d.p < d.pPrime && d.pPrime + 10 <= 38) + + // Reference SUM via java.math.BigDecimal at scale s (the cast keeps scale). + // Inside the non-overflow domain (|sum unscaled| < 10^(p+10)) this is + // bit-exact equivalent to both the peeled and the baseline plan, so we + // can pin the peeled result against an external oracle without depending + // on a baseline plan we no longer exercise. + private def referenceSum( + unscaledLongs: Seq[Long], d: PeelDomain): java.math.BigDecimal = { + if (unscaledLongs.isEmpty) { + null + } else { + val acc = unscaledLongs + .map(u => java.math.BigDecimal.valueOf(u, d.s)) + .foldLeft(java.math.BigDecimal.ZERO)(_.add(_)) + acc.setScale(d.s) + } + } + + private def sumCastResult( + unscaledLongs: Seq[Long], d: PeelDomain): java.math.BigDecimal = { + // Use an explicit DecimalType(p, s) schema rather than Scala-tuple + // inference. createDataFrame on Tuple1[java.math.BigDecimal] infers + // DecimalType.SYSTEM_DEFAULT (38, 18), which would force the subsequent + // CAST to widen from (38, 18) -> (pPrime, s) rather than from the + // intended narrow (p, s) -> (pPrime, s) widening, defeating the + // WidenedDecimalChild trigger and silently exercising the wrong rule arm.
+ val rows = unscaledLongs.map(u => Row(java.math.BigDecimal.valueOf(u, d.s))) + val schema = StructType(StructField("x", DecimalType(d.p, d.s)) :: Nil) + val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + assert(df.schema("x").dataType == DecimalType(d.p, d.s), + s"expected inner schema DecimalType(${d.p}, ${d.s}), got ${df.schema("x").dataType}") + df.select(sum(col("x").cast(DecimalType(d.pPrime, d.s))).as("s")) + .collect()(0).getDecimal(0) + } + + test("SPARK-56627: DecimalAggregates widened-Cast SUM peel -- numerical " + + "equivalence property (sql-core layer)") { + val combinedGen: Gen[(PeelDomain, List[Long])] = for { + d <- peelDomainGen + upper = math.pow(10, d.p).toLong - 1 + n <- Gen.choose(1, 1000) + xs <- Gen.listOfN(n, Gen.choose(-upper, upper)) + } yield (d, xs) + forAll(combinedGen, minSuccessful(20), sizeRange(0)) { case (d, xs) => + val r = sumCastResult(xs, d) + val ref = referenceSum(xs, d) + assert(r.compareTo(ref) == 0, + s"peel result diverges from BigDecimal reference for " + + s"PeelDomain(p=${d.p}, pPrime=${d.pPrime}, s=${d.s}), n=${xs.size}, " + + s"sample=${xs.take(3)}, got=$r ref=$ref") + } + } + + // Wide target-precision fixed witness: (p=8, p'=30, s=2). Hand-enumerated so + // a wide target precision case is always exercised even if property shrinks. + test("SPARK-56627: SUM(CAST(dec(8,2) AS dec(30,2))) matches BigDecimal " + + "reference (wide-target-scale fixed witness, sql-core)") { + val d = PeelDomain(8, 30, 2) + val xs = Seq(0L, 1L, -1L, 99999999L, -99999999L, 12345678L, -87654321L) + val r = sumCastResult(xs, d) + val ref = referenceSum(xs, d) + assert(r.compareTo(ref) == 0, s"got=$r ref=$ref") + } + + // AVG widened-Cast peel: equivalence property (sql-core layer). + // + // Oracle: peel(AVG(CAST(x AS dec(pPrime, s)))) must be observationally + // identical to the existing fast path on AVG(x) directly.
Both arms in + // Optimizer.DecimalAggregates produce + // Cast(Divide(Avg(UnscaledValue(<e>)), Lit(10^s, Double)), + // DecimalType.bounded(<prec>, s + 4)) + // and the peel arm makes <e> equal to the user's column, so the + // Double-divide dividends are bit-identical between the two paths; only + // the outer Cast target precision differs (pPrime+4 vs p+4), a widening + // precision Cast that preserves numerical value. We therefore assert + // BigDecimal.compareTo == 0 (value equality across differing precisions). + // + // Domain: inner p in [1, 7] (the AVG strict-subset guard + // `AVG_PEEL_MAX_INNER_PRECISION = 7`), pPrime in [8, 11] (the band where + // the existing `Average(DecimalExpression)` arm would intercept on the + // outer Cast type if not for our prepended arm), s in [0, p], + // n <= 1000 rows. The inner DataFrame schema is constructed as + // DecimalType(p, s) explicitly (NOT via tuple-inference, which would + // infer DecimalType.SYSTEM_DEFAULT and silently route through a DIFFERENT + // rule arm than intended -- the failure mode this PBT must lock down).
+ private case class AvgDomain(p: Int, pPrime: Int, s: Int) + + private val avgDomainGen: Gen[AvgDomain] = (for { + p <- Gen.choose(1, 7) + pPrime <- Gen.choose(8, 11) + s <- Gen.choose(0, p) + } yield AvgDomain(p, pPrime, s)) + .retryUntil(d => d.p < d.pPrime) + + private def avgInputDf(unscaledLongs: Seq[Long], d: AvgDomain) = { + val rows = unscaledLongs.map(u => Row(java.math.BigDecimal.valueOf(u, d.s))) + val schema = StructType(StructField("x", DecimalType(d.p, d.s)) :: Nil) + val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + assert(df.schema("x").dataType == DecimalType(d.p, d.s), + s"expected inner schema DecimalType(${d.p}, ${d.s}), got ${df.schema("x").dataType}") + df + } + + private def avgCastResult( + unscaledLongs: Seq[Long], d: AvgDomain): java.math.BigDecimal = { + avgInputDf(unscaledLongs, d) + .select(avg(col("x").cast(DecimalType(d.pPrime, d.s))).as("a")) + .collect()(0).getDecimal(0) + } + + private def avgDirectResult( + unscaledLongs: Seq[Long], d: AvgDomain): java.math.BigDecimal = { + avgInputDf(unscaledLongs, d) + .select(avg(col("x")).as("a")) + .collect()(0).getDecimal(0) + } + + test("SPARK-56627: DecimalAggregates widened-Cast AVG peel -- " + + "equivalence vs unpeeled AVG (sql-core)") { + val combinedGen: Gen[(AvgDomain, List[Long])] = for { + d <- avgDomainGen + upper = math.pow(10, d.p).toLong - 1 + n <- Gen.choose(1, 1000) + xs <- Gen.listOfN(n, Gen.choose(-upper, upper)) + } yield (d, xs) + forAll(combinedGen, minSuccessful(20), sizeRange(0)) { case (d, xs) => + val peeled = avgCastResult(xs, d) + val direct = avgDirectResult(xs, d) + // BigDecimal.compareTo ignores trailing-zero precision differences: + // peeled has output DecimalType.bounded(pPrime+4, s+4), direct has + // DecimalType(p+4, s+4). Both wrap the same Double-divide bit pattern + // so the underlying value is identical. 
+ assert(peeled.compareTo(direct) == 0, + s"peeled AVG diverges from unpeeled AVG for " + + s"AvgDomain(p=${d.p}, pPrime=${d.pPrime}, s=${d.s}), n=${xs.size}, " + + s"sample=${xs.take(3)}, peeled=$peeled direct=$direct") + } + } + + // Wider-pPrime regime shape witness: (p=4, p'=20, s=2). The equivalence + // PBT above only covers pPrime in [8, 11] (where the existing AVG arm + // would otherwise intercept and provide a comparable oracle). For pPrime + // outside that band the new arm still fires (only constrained by inner + // p <= 7), but the comparison oracle "AVG(x) directly" is no longer + // available because the existing arm targets a narrower output type. + // This witness asserts non-null result and the expected widened output + // schema, locking the rule's shape contract without claiming an + // unreachable oracle. + test("SPARK-56627: AVG(CAST(dec(4,2) AS dec(20,2))) peels and yields " + + "widened output schema (wider-pPrime regime shape witness)") { + val rows = Seq(123L, -456L, 789L, 0L) + .map(u => Row(java.math.BigDecimal.valueOf(u, 2))) + val schema = StructType(StructField("x", DecimalType(4, 2)) :: Nil) + val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema) + .select(avg(col("x").cast(DecimalType(20, 2))).as("a")) + val row = df.collect()(0) + assert(!row.isNullAt(0), s"expected non-null AVG, got null; df schema = ${df.schema}") + val outType = df.schema("a").dataType.asInstanceOf[DecimalType] + // Widened-arm output Cast target = DecimalType.bounded(pPrime + 4, s + 4) + // = DecimalType.bounded(24, 6). 
+ assert(outType.precision == 24 && outType.scale == 6, + s"expected DecimalType(24, 6) from widened-arm peel, got $outType") + } } case class B(c: Option[Double]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala new file mode 100644 index 0000000000000..fc00bea62dd16 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DecimalAggregatesBenchmark.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.Decimal + +/** + * Benchmark for the DecimalAggregates widened-cast peel optimizer rule + * (both SUM and AVG arms). + * + * Each case is a three-way comparison on the same `DECIMAL(p, s)` input: + * 1. `native` -- query writes `SUM(x)` / `AVG(x)` directly, + * hitting the existing fast path (rule on). + * 2. 
`widened, peel off` -- query writes `SUM(CAST(x AS DECIMAL(p', s)))` + * with `DecimalAggregates` excluded; the cast + * defeats the existing fast path, so the + * baseline `CheckOverflow` path runs. + * 3. `widened, peel on` -- same widened query with the rule enabled; + * the new peel arm strips the cast and + * restores the fast path. + * + * Reviewer story: + * - `native` vs `widened, peel off` -- shows the widening cast really + * evicts user-visible work onto the slow path (motivation). + * - `widened, peel off` vs `widened, peel on` -- shows the peel rule + * recovers the lost performance (rule benefit). + * - `widened, peel on` vs `native` -- shows the peel makes the cast + * effectively free (rule correctness echo of the numerical-equivalence + * PBT in `DataFrameAggregateSuite`). + * + * Sections: + * A -- Aggregate SUM widened-cast sweep over (p, s, p') cases. + * B -- Aggregate AVG widened-cast sweep (p <= 7 per + * AVG_PEEL_MAX_INNER_PRECISION). + * + * NOTE on Window arm: the optimizer does not extend widened-Cast peel to + * the Window arm (see DecimalAggregates rule comment) because the analyzer + * hoists the Cast into a child Project, so a Window microbenchmark would + * not exercise this rule. A Window benchmark belongs with a future + * plan-layer rewrite, not here. + * + * Case design (`p+1` boundary vs `p+10`-class wider) deliberately includes + * both the minimum widening (one extra digit, e.g. `dec(7,2) -> dec(8,2)`) + * and a production-canonical wider one (e.g. `dec(7,2) -> dec(17,2)` is the + * inner-widened-decimal shape in TPC-DS q18) so reviewers see whether + * widening magnitude matters and whether the canonical shape behaves like + * the boundary one. + * + * Args: aN (Section A/B row count), iters, apl + * (`spark.sql.decimalOperations.allowPrecisionLoss`; default true). + * Defaults committed for GHA: aN=10000000, iters=5, apl=true. 
+ *
+ * To run this benchmark locally (pre-GHA smoke):
+ * {{{
+ *   build/sbt "sql/Test/runMain \
+ *     org.apache.spark.sql.execution.benchmark.DecimalAggregatesBenchmark \
+ *     10000000 5"
+ * }}}
+ *
+ * To regenerate committed baselines (via `benchmark.yml` GHA workflow):
+ * {{{
+ *   SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt \
+ *     "sql/Test/runMain org.apache.spark.sql.execution.benchmark.DecimalAggregatesBenchmark"
+ * }}}
+ *
+ * Committed results:
+ *   `sql/core/benchmarks/DecimalAggregatesBenchmark-results.txt` (JDK 17).
+ *   `sql/core/benchmarks/DecimalAggregatesBenchmark-jdk21-results.txt`.
+ *   `sql/core/benchmarks/DecimalAggregatesBenchmark-jdk25-results.txt`.
+ */
+object DecimalAggregatesBenchmark extends SqlBasedBenchmark {
+
+  /**
+   * Largest inner precision used by the AVG cases. Kept equal to the
+   * optimizer's `AVG_PEEL_MAX_INNER_PRECISION`, which is private to
+   * `DecimalAggregates` and so cannot be referenced from this object.
+   */
+  private val AvgPeelMaxInnerPrecision: Int = 7
+
+  /**
+   * Aggregate SUM cases: (label, p, s, widened p').
+   *
+   * All `p + 10 <= 18` so the *native* `SUM(x)` query hits the existing
+   * SUM Long fast path -- providing a meaningful baseline for the
+   * peel-on leg. Coverage: `p+1` boundary widening (A1, A3) plus a
+   * `p+10`-class wider cast representative of production shapes (A2,
+   * A4; A2 mirrors the TPC-DS q18 inner-widened-decimal shape).
+   */
+  private val SumAggCases: Seq[(String, Int, Int, Int)] = Seq(
+    ("A1 p=7 s=2 p'=8", 7, 2, 8), // p+1 boundary
+    ("A2 p=7 s=2 p'=17", 7, 2, 17), // p+10, canonical TPC-DS q18 shape
+    ("A3 p=5 s=0 p'=6", 5, 0, 6), // p+1 boundary, zero scale
+    ("A4 p=5 s=0 p'=15", 5, 0, 15) // p+10, zero scale
+  )
+
+  /**
+   * Aggregate AVG cases: (label, p, s, widened p').
+   *
+   * All `p <= 7` per the conservative `AVG_PEEL_MAX_INNER_PRECISION = 7`
+   * guard (see design doc 0001 rev 7 S7.1 -- strict-subset narrowing so
+   * SPARK-37024 Double-regime exposure is NOT amplified by this rule).
+   * Same `p+1` / `p+10` split as Section A. B2 mirrors the canonical
+   * TPC-DS q18 AVG shape.
+   */
+  private val AvgAggCases: Seq[(String, Int, Int, Int)] = Seq(
+    ("B1 p=7 s=2 p'=8", 7, 2, 8), // p+1 boundary
+    ("B2 p=7 s=2 p'=12", 7, 2, 12), // canonical TPC-DS q18 AVG shape
+    ("B3 p=5 s=0 p'=6", 5, 0, 6), // p+1 boundary, zero scale
+    ("B4 p=5 s=0 p'=15", 5, 0, 15) // p+10, zero scale
+  )
+
+  /**
+   * Multiplier for `rand()` in the generated column: `10^(p-s) - 1`, the
+   * largest whole number whose integral digits fit `DECIMAL(p, s)`.
+   *
+   * Despite the name this bounds the integral part, not the unscaled
+   * (`10^p`) representation.
+   *
+   * @param p decimal precision
+   * @param s decimal scale; must satisfy `0 <= p - s <= Decimal.MAX_LONG_DIGITS`
+   * @return `10^(p-s) - 1` as a Long
+   */
+  private def unscaledBound(p: Int, s: Int): Long = {
+    require(p - s >= 0, s"p=$p s=$s p-s must be non-negative")
+    // Double-to-Long conversion saturates silently once 10^(p-s) exceeds
+    // Long.MaxValue, so reject exponents that cannot be held in a Long.
+    require(p - s <= Decimal.MAX_LONG_DIGITS,
+      s"p=$p s=$s p-s must not exceed ${Decimal.MAX_LONG_DIGITS}")
+    math.pow(10.0, (p - s).toDouble).toLong - 1L
+  }
+
+  /**
+   * Registers temp view `t` with a single `DECIMAL(p, s)` column `x` of
+   * `n` rows derived from `rand(42)`, coalesced to one partition.
+   * Replaces any existing view named `t`.
+   */
+  private def setupAggTable(spark: org.apache.spark.sql.SparkSession,
+      n: Long, p: Int, s: Int): Unit = {
+    val bound = unscaledBound(p, s)
+    spark.range(n)
+      .selectExpr(s"cast(rand(42) * $bound as decimal($p, $s)) as x")
+      .coalesce(1)
+      .createOrReplaceTempView("t")
+  }
+
+  private val ExcludedRulesKey: String = SQLConf.OPTIMIZER_EXCLUDED_RULES.key
+  private val DecimalAggregatesRule: String =
+    "org.apache.spark.sql.catalyst.optimizer.DecimalAggregates"
+
+  /**
+   * Run a single three-way comparison: native (no cast, rule on),
+   * widened with rule off (baseline slow path), widened with rule on
+   * (peel restores fast path). `apl` is held identical across all three
+   * legs so any delta is attributable to (a) the widening cast and
+   * (b) the peel rule respectively.
+   *
+   * @param label benchmark title
+   * @param n row count reported to the `Benchmark`
+   * @param nativeSql query without the widening cast
+   * @param widenedSql same aggregate over the widened cast
+   * @param iters `numIters` passed to every case
+   * @param apl value set for the allow-precision-loss conf in every leg
+   */
+  private def runThreeWay(label: String, n: Long, nativeSql: String,
+      widenedSql: String, iters: Int, apl: String): Unit = {
+    val aplConf = SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> apl
+    val bench = new Benchmark(label, n, output = output)
+    bench.addCase("native (no cast, rule on)", numIters = iters) { _ =>
+      withSQLConf(aplConf) {
+        spark.sql(nativeSql).noop()
+      }
+    }
+    bench.addCase("widened cast, peel off", numIters = iters) { _ =>
+      // Excluding the rule is the only difference from the "peel on" leg.
+      withSQLConf(ExcludedRulesKey -> DecimalAggregatesRule, aplConf) {
+        spark.sql(widenedSql).noop()
+      }
+    }
+    bench.addCase("widened cast, peel on", numIters = iters) { _ =>
+      withSQLConf(aplConf) {
+        spark.sql(widenedSql).noop()
+      }
+    }
+    bench.run()
+  }
+
+  /**
+   * Entry point. Positional args, all optional: row count (default
+   * 10,000,000), iterations (default 5), and the allow-precision-loss
+   * flag (default "true"). Runs Section A (SUM) then Section B (AVG).
+   */
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    val aN: Long = mainArgs.lift(0).map(_.toLong).getOrElse(10L * 1000L * 1000L)
+    val iters: Int = mainArgs.lift(1).map(_.toInt).getOrElse(5)
+    val apl: String = mainArgs.lift(2).getOrElse("true")
+
+    // Reject a malformed flag up front; otherwise it is first consumed
+    // inside the timed cases, after the table has already been set up.
+    require(scala.util.Try(apl.trim.toBoolean).isSuccess,
+      s"apl must be a boolean string ('true' or 'false'), got '$apl'")
+
+    require(Decimal.MAX_LONG_DIGITS == 18,
+      s"Decimal.MAX_LONG_DIGITS drift: expected 18 got ${Decimal.MAX_LONG_DIGITS}")
+
+    // Section A -- Aggregate SUM widened-cast.
+    runBenchmark("DecimalAggregates SUM widened-cast peel (Aggregate)") {
+      SumAggCases.foreach { case (label, p, s, pPrime) =>
+        require(pPrime > p, s"$label: p'=$pPrime must exceed p=$p")
+        require(p + 10 <= Decimal.MAX_LONG_DIGITS,
+          s"$label: p=$p violates SUM Long fast path guard " +
+            s"p+10<=MAX_LONG_DIGITS=${Decimal.MAX_LONG_DIGITS}; " +
+            s"native baseline would not be meaningful")
+        setupAggTable(spark, aN, p, s)
+        runThreeWay(label, aN,
+          nativeSql = "select sum(x) from t",
+          widenedSql = s"select sum(cast(x as decimal($pPrime, $s))) from t",
+          iters, apl)
+      }
+    }
+
+    // Section B -- Aggregate AVG widened-cast.
+    runBenchmark("DecimalAggregates AVG widened-cast peel (Aggregate)") {
+      AvgAggCases.foreach { case (label, p, s, pPrime) =>
+        require(pPrime > p, s"$label: p'=$pPrime must exceed p=$p")
+        require(p <= AvgPeelMaxInnerPrecision,
+          s"$label: p=$p violates conservative " +
+            s"AVG_PEEL_MAX_INNER_PRECISION=$AvgPeelMaxInnerPrecision guard")
+        setupAggTable(spark, aN, p, s)
+        runThreeWay(label, aN,
+          nativeSql = "select avg(x) from t",
+          widenedSql = s"select avg(cast(x as decimal($pPrime, $s))) from t",
+          iters, apl)
+      }
+    }
+  }
+}