diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java new file mode 100644 index 0000000000000..a8a37a70161a5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.QueryContext; +import org.apache.spark.sql.catalyst.util.StringUtils; +import org.apache.spark.sql.errors.QueryExecutionErrors; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Static helpers used by {@code Cast.doGenCode} (and corresponding eval + * paths) for ANSI overflow-checked narrowing to {@code byte} / {@code short}. + * + *
Narrowing to {@code int} / {@code long} is handled by calling the existing + * {@code LongExactNumeric} / {@code FloatExactNumeric} / {@code DoubleExactNumeric} + * Scala objects directly from codegen (see SPARK-56909). The helpers below + * cover {@code byte} / {@code short} only, since {@code ByteExactNumeric} / + * {@code ShortExactNumeric} don't expose a cross-type narrowing API. + * + *
The source and target {@link DataType} objects referenced by the overflow + * error message are held in {@code private static final} fields so the happy + * path performs no per-row {@code references[]} lookups. + */ +public final class CastUtils { + + private CastUtils() {} + + private static final DataType SHORT = DataTypes.ShortType; + private static final DataType INT = DataTypes.IntegerType; + private static final DataType LONG = DataTypes.LongType; + private static final DataType BYTE = DataTypes.ByteType; + private static final DataType FLOAT = DataTypes.FloatType; + private static final DataType DOUBLE = DataTypes.DoubleType; + + // ----- integral narrowing (ANSI: throw on overflow) ----- + + public static byte shortToByteExact(short v) { + if (v == (byte) v) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, SHORT, BYTE); + } + + public static byte intToByteExact(int v) { + if (v == (byte) v) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, INT, BYTE); + } + + public static byte longToByteExact(long v) { + if (v == (byte) v) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, BYTE); + } + + public static short intToShortExact(int v) { + if (v == (short) v) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, INT, SHORT); + } + + public static short longToShortExact(long v) { + if (v == (short) v) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, SHORT); + } + + // ----- fractional -> integral (ANSI: throw on overflow) ----- + // Mirrors castFractionToIntegralTypeCode: floor(v) <= MAX && ceil(v) >= MIN. + + public static byte floatToByteExact(float v) { + if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, BYTE); + } + + public static byte doubleToByteExact(double v) { + if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, BYTE); + } + + public static short floatToShortExact(float v) { + if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, SHORT); + } + + public static short doubleToShortExact(double v) { + if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v; + throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, SHORT); + } + + // ----- decimal precision adjustment ----- + // Mutates the input Decimal in place. Used by Cast.changePrecision (and by + // BinaryArithmetic / DivModLike in follow-up PRs) to apply the target + // precision/scale on the per-row hot path. + + public static Decimal changePrecisionExact( + Decimal d, int precision, int scale, QueryContext context) { + if (d.changePrecision(precision, scale)) return d; + throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(d, precision, scale, context); + } + + public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) { + return d.changePrecision(precision, scale) ? d : null; + } + + // ----- string -> boolean (ANSI: throw on invalid syntax) ----- + + public static boolean stringToBooleanExact(UTF8String s, QueryContext context) { + if (StringUtils.isTrueString(s)) return true; + if (StringUtils.isFalseString(s)) return false; + throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, context); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 419ca3f32d888..e97d1e1812408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -695,6 +695,8 @@ case class Cast( // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { + case _: StringType if ansiEnabled => + buildCast[UTF8String](_, s => CastUtils.stringToBooleanExact(s, getContextOrNull())) case _: StringType => buildCast[UTF8String](_, s => { if (StringUtils.isTrueString(s)) { @@ -702,11 +704,7 @@ case class Cast( } else if (StringUtils.isFalseString(s)) { false } else { - if (ansiEnabled) { - throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(s, getContextOrNull()) - } else { - null - } + null } }) case TimestampType => @@ -984,6 +982,14 @@ case class Cast( errorOrNull(t, from, ShortType) } }) + case IntegerType if ansiEnabled => + b => CastUtils.intToShortExact(b.asInstanceOf[Int]) + case LongType if ansiEnabled => + b => CastUtils.longToShortExact(b.asInstanceOf[Long]) + case FloatType if ansiEnabled => + b => CastUtils.floatToShortExact(b.asInstanceOf[Float]) + case DoubleType if ansiEnabled => + b => CastUtils.doubleToShortExact(b.asInstanceOf[Double]) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -1040,6 +1046,16 @@ case class Cast( errorOrNull(t, from, ByteType) } }) + case ShortType if ansiEnabled => + b => CastUtils.shortToByteExact(b.asInstanceOf[Short]) + case IntegerType if ansiEnabled => + b => CastUtils.intToByteExact(b.asInstanceOf[Int]) + case LongType if ansiEnabled => + b => CastUtils.longToByteExact(b.asInstanceOf[Long]) + case FloatType if ansiEnabled => + b => CastUtils.floatToByteExact(b.asInstanceOf[Float]) + case DoubleType if ansiEnabled => + b => CastUtils.doubleToByteExact(b.asInstanceOf[Double]) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -1079,15 +1095,11 @@ case class Cast( value: Decimal, decimalType: DecimalType, nullOnOverflow: Boolean): Decimal = { - if (value.changePrecision(decimalType.precision, decimalType.scale)) { - value + if (nullOnOverflow) { + CastUtils.changePrecisionOrNull(value, decimalType.precision, decimalType.scale) } else { - if (nullOnOverflow) { - null - } else { - throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - value, decimalType.precision, decimalType.scale, getContextOrNull()) - } + CastUtils.changePrecisionExact( + value, decimalType.precision, decimalType.scale, getContextOrNull()) } } @@ -1540,23 +1552,21 @@ case class Cast( |$d.changePrecision(${decimalType.precision}, ${decimalType.scale}); |$evPrim = $d; """.stripMargin - } else { - val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) - val overflowCode = if (nullOnOverflow) { - s"$evNull = true;" - } else { - s""" - |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( - | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode); - """.stripMargin - } + } else if (nullOnOverflow) { code""" |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { | $evPrim = $d; |} else { - | $overflowCode + | $evNull = true; |} """.stripMargin + } else { + val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow) + val castUtils = classOf[CastUtils].getName + code""" + |$evPrim = $castUtils.changePrecisionExact( + | $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode); + """.stripMargin } } @@ -1869,22 +1879,20 @@ case class Cast( private[this] def castToBooleanCode( from: DataType, ctx: CodegenContext): CastFunction = from match { + case _: StringType if ansiEnabled => + val castUtils = classOf[CastUtils].getName + val errorContext = getContextOrNullCode(ctx) + (c, evPrim, _) => code"$evPrim = $castUtils.stringToBooleanExact($c, $errorContext);" case _: StringType => val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}" (c, evPrim, evNull) => - val castFailureCode = if (ansiEnabled) { - val errorContext = getContextOrNullCode(ctx) - s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);" - } else { - s"$evNull = true;" - } code""" if ($stringUtils.isTrueString($c)) { $evPrim = true; } else if ($stringUtils.isFalseString($c)) { $evPrim = false; } else { - $castFailureCode + $evNull = true; } """ case TimestampType => @@ -1999,30 +2007,15 @@ case class Cast( }).getClass.getCanonicalName.stripSuffix("$") (c, evPrim, _) => code"$evPrim = $numericObj.toInt($c);" } else { - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - (c, evPrim, _) => - code""" - if ($c == ($integralType) $c) { - $evPrim = ($integralType) $c; - } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); - } - """ + // Byte / short narrowing: call the matching CastUtils helper. Existing *ExactNumeric + // objects don't expose cross-type narrowing to byte / short (their toByte / toShort are + // same-type identities), so a Java helper is the cleanest fit. + val castUtils = classOf[CastUtils].getName + val method = s"${integralPrefix(from)}To${integralType.capitalize}Exact" + (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" } } - - private[this] def lowerAndUpperBound(integralType: String): (String, String) = { - val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match { - case "long" => (Long.MinValue, Long.MaxValue, "L") - case "int" => (Int.MinValue, Int.MaxValue, "") - case "short" => (Short.MinValue, Short.MaxValue, "") - case "byte" => (Byte.MinValue, Byte.MaxValue, "") - } - (min.toString + typeIndicator, max.toString + typeIndicator) - } - private[this] def castFractionToIntegralTypeCode( ctx: CodegenContext, integralType: String, @@ -2042,26 +2035,29 @@ case class Cast( val method = s"to${integralType.capitalize}" (c, evPrim, _) => code"$evPrim = $numericObj.$method($c);" } else { - val (min, max) = lowerAndUpperBound(integralType) - val mathClass = classOf[Math].getName - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - // When casting floating values to integral types, Spark uses the method `Numeric.toInt` - // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to - // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`. - // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound` - // to check if the floating value x is in the range of an integral type after rounding. - (c, evPrim, _) => - code""" - if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) { - $evPrim = ($integralType) $c; - } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); - } - """ + // Float / double -> byte / short: same rationale as the integral byte / short branch + // above -- no equivalent *ExactNumeric API, so route through CastUtils. + val castUtils = classOf[CastUtils].getName + val method = s"${fractionalPrefix(from)}To${integralType.capitalize}Exact" + (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);" } } + private[this] def integralPrefix(from: DataType): String = from match { + case ShortType => "short" + case IntegerType => "int" + case LongType => "long" + case _ => throw SparkException.internalError( + s"Unexpected source type $from for castIntegralTypeToIntegralTypeExactCode") + } + + private[this] def fractionalPrefix(from: DataType): String = from match { + case FloatType => "float" + case DoubleType => "double" + case _ => throw SparkException.internalError( + s"Unexpected source type $from for castFractionToIntegralTypeCode") + } + private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case _: StringType if ansiEnabled => val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")