From 55ec5d5051c632bb4f165184179736974fcb1c47 Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Tue, 18 Jul 2023 08:21:35 +0900 Subject: [PATCH 001/986] [SPARK-44453][PYTHON] Use difflib to display errors in assertDataFrameEqual ### What changes were proposed in this pull request? This PR uses the built-in Python library, difflib, to display errors in the testing util `assertDataFrameEqual` ### Why are the changes needed? The change makes the error message output more user-friendly, as well as consistent with `assertSchemaEqual` ### Does this PR introduce _any_ user-facing change? Yes, the PR changes the test util output for the user-facing util function `assertDataFrameEqual`. ### How was this patch tested? Existing tests in `runtime/python/pyspark/sql/tests/test_utils.py` and `runtime/python/pyspark/sql/tests/connect/test_utils.py` Example output: Screenshot 2023-07-16 at 8 20 31 PM Screenshot 2023-07-16 at 8 20 41 PM Closes #42031 from asl3/difflib-assertdfequal. Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon (cherry picked from commit aa688106df02c9d31e3c93be4c5e28a8e8aec92b) Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_utils.py | 177 +++++++++---------------- python/pyspark/testing/utils.py | 33 +++-- 2 files changed, 82 insertions(+), 128 deletions(-) diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index eae3f52850419..9c31eb4d6bdd1 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -151,19 +151,14 @@ def test_assert_approx_equal_arraytype_float_default_rtol_fail(self): expected_error_message = "Results do not match: " percent_diff = (1 / 2) * 100 expected_error_message += "( %.5f %% )" % percent_diff - diff_msg = ( - "[actual]" - + "\n" - + str(df1.collect()[1]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[1]) - + "\n\n" - + "********************" - + "\n\n" - ) - expected_error_message += "\n" + diff_msg + + generated_diff = 
difflib.ndiff( + str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + ) + diff_msg = "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + + expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -294,19 +289,14 @@ def test_assert_notequal_arraytype(self): expected_error_message = "Results do not match: " percent_diff = (1 / 2) * 100 expected_error_message += "( %.5f %% )" % percent_diff - diff_msg = ( - "[actual]" - + "\n" - + str(df1.collect()[1]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[1]) - + "\n\n" - + "********************" - + "\n\n" - ) - expected_error_message += "\n" + diff_msg + + generated_diff = difflib.ndiff( + str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + ) + diff_msg = "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + + expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -598,19 +588,14 @@ def test_assert_notequal_nullval(self): expected_error_message = "Results do not match: " percent_diff = (1 / 2) * 100 expected_error_message += "( %.5f %% )" % percent_diff - diff_msg = ( - "[actual]" - + "\n" - + str(df1.collect()[1]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[1]) - + "\n\n" - + "********************" - + "\n\n" - ) - expected_error_message += "\n" + diff_msg + + generated_diff = difflib.ndiff( + str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + ) + diff_msg = "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + + expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -722,31 +707,19 @@ def test_check_row_order_error(self): 
expected_error_message = "Results do not match: " percent_diff = (2 / 2) * 100 expected_error_message += "( %.5f %% )" % percent_diff - diff_msg = ( - "[actual]" - + "\n" - + str(df1.collect()[0]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[0]) - + "\n\n" - + "********************" - + "\n\n" - ) - diff_msg += ( - "[actual]" - + "\n" - + str(df1.collect()[1]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[1]) - + "\n\n" - + "********************" - + "\n\n" - ) - expected_error_message += "\n" + diff_msg + + generated_diff = difflib.ndiff( + str(df1.collect()[0]).splitlines(), str(df2.collect()[0]).splitlines() + ) + diff_msg = "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + generated_diff = difflib.ndiff( + str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + ) + diff_msg += "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + + expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2, checkRowOrder=True) @@ -829,31 +802,19 @@ def test_assert_pyspark_df_not_equal(self): expected_error_message = "Results do not match: " percent_diff = (2 / 3) * 100 expected_error_message += "( %.5f %% )" % percent_diff - diff_msg = ( - "[actual]" - + "\n" - + str(df1.collect()[0]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[0]) - + "\n\n" - + "********************" - + "\n\n" - ) - diff_msg += ( - "[actual]" - + "\n" - + str(df1.collect()[2]) - + "\n\n" - + "[expected]" - + "\n" - + str(df2.collect()[2]) - + "\n\n" - + "********************" - + "\n\n" - ) - expected_error_message += "\n" + diff_msg + + generated_diff = difflib.ndiff( + str(df1.collect()[0]).splitlines(), str(df2.collect()[0]).splitlines() + ) + diff_msg = "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + generated_diff = difflib.ndiff( + 
str(df1.collect()[2]).splitlines(), str(df2.collect()[2]).splitlines() + ) + diff_msg += "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + + expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -1197,31 +1158,19 @@ def test_list_row_unequal_schema(self): expected_error_message = "Results do not match: " percent_diff = (2 / 2) * 100 expected_error_message += "( %.5f %% )" % percent_diff - diff_msg = ( - "[actual]" - + "\n" - + str(df1.collect()[0]) - + "\n\n" - + "[expected]" - + "\n" - + str(list_of_rows[0]) - + "\n\n" - + "********************" - + "\n\n" - ) - diff_msg += ( - "[actual]" - + "\n" - + str(df1.collect()[1]) - + "\n\n" - + "[expected]" - + "\n" - + str(list_of_rows[1]) - + "\n\n" - + "********************" - + "\n\n" - ) - expected_error_message += "\n" + diff_msg + + generated_diff = difflib.ndiff( + str(df1.collect()[0]).splitlines(), str(list_of_rows[0]).splitlines() + ) + diff_msg = "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + generated_diff = difflib.ndiff( + str(df1.collect()[1]).splitlines(), str(list_of_rows[1]).splitlines() + ) + diff_msg += "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" + + expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, list_of_rows) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 14db926420942..acbfb522f69f4 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -346,6 +346,9 @@ def assertDataFrameEqual( Notes ----- + When assertDataFrameEqual fails, the error message uses the Python `difflib` library to display + a diff log of each row that differs in `actual` and `expected`. 
+ For checkRowOrder, note that PySpark DataFrame ordering is non-deterministic, unless explicitly sorted. @@ -374,15 +377,18 @@ def assertDataFrameEqual( >>> assertDataFrameEqual(df1, df2) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.667 % ) - [actual] - Row(id='1', amount=1000.0) - [expected] - Row(id='1', amount=1001.0) - [actual] - Row(id='3', amount=2000.0) - [expected] - Row(id='3', amount=2003.0) + PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % ) + --- actual + +++ expected + - Row(id='1', amount=1000.0) + ? ^ + + Row(id='1', amount=1001.0) + ? ^ + - Row(id='3', amount=2000.0) + ? ^ + + Row(id='3', amount=2003.0) + ? ^ + """ if actual is None and expected is None: return True @@ -471,15 +477,14 @@ def assert_rows_equal(rows1: List[Row], rows2: List[Row]): if not compare_rows(r1, r2): rows_equal = False diff_rows_cnt += 1 - diff_msg += ( - "[actual]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" + str(r2) + "\n\n" - ) - diff_msg += "********************" + "\n\n" + generated_diff = difflib.ndiff(str(r1).splitlines(), str(r2).splitlines()) + diff_msg += "\n" + "\n".join(generated_diff) + "\n" + diff_msg += "********************" + "\n" if not rows_equal: percent_diff = (diff_rows_cnt / len(zipped)) * 100 error_msg += "( %.5f %% )" % percent_diff - error_msg += "\n" + diff_msg + error_msg += "\n" + "--- actual\n+++ expected\n" + diff_msg raise PySparkAssertionError( error_class="DIFFERENT_ROWS", message_parameters={"error_msg": error_msg}, From 30971c41aedaed32da3670058e3a13917d6a4e0f Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Tue, 18 Jul 2023 08:50:49 +0900 Subject: [PATCH 002/986] [SPARK-44413][PYTHON] Clarify error for unsupported arg data type in assertDataFrameEqual ### What changes were proposed in this pull request? 
This PR adds an error class, `INVALID_TYPE_DF_EQUALITY_ARG`, to clarify the error message for unsupported argument data types when calling `assertDataFrameEqual`. ### Why are the changes needed? The fix helps clarify why an error is thrown and what is wrong when a user passes unsupported arg types into the `assertDataFrameEqual` util function. ### Does this PR introduce any user-facing change? Yes, the PR modifies error message seen by users. ### How was this patch tested? Modified tests in `runtime/python/pyspark/sql/tests/test_utils.py` and `runtime/python/pyspark/sql/tests/connect/test_utils.py` Closes #42027 from asl3/datatype-error-clarify. Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon (cherry picked from commit d574dbcc85965df4a48d608230e591cc23adb525) Signed-off-by: Hyukjin Kwon --- python/pyspark/errors/error_classes.py | 5 ++++ python/pyspark/sql/tests/test_utils.py | 33 +++++++++++++++++------ python/pyspark/testing/utils.py | 36 +++++++++++++++++++------- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 2cecee4da440a..e45bc0797c95b 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -263,6 +263,11 @@ "StructField does not have typeName. Use typeName on its type explicitly instead." ] }, + "INVALID_TYPE_DF_EQUALITY_ARG" : { + "message" : [ + "Expected type <expected_type> for `<arg_name>` but got type <actual_type>." + ] + }, "INVALID_UDF_EVAL_TYPE" : { "message" : [ "Eval type for UDF must be <eval_type>." 
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 9c31eb4d6bdd1..a1cefe7c840d6 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -39,6 +39,7 @@ IntegerType, BooleanType, ) +from pyspark.sql.dataframe import DataFrame import difflib @@ -633,8 +634,12 @@ def test_assert_error_pandas_df(self): self.check_error( exception=pe.exception, - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": pd.DataFrame}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": DataFrame, + "arg_name": "df", + "actual_type": pd.DataFrame, + }, ) with self.assertRaises(PySparkAssertionError) as pe: @@ -642,8 +647,12 @@ def test_assert_error_pandas_df(self): self.check_error( exception=pe.exception, - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": pd.DataFrame}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": DataFrame, + "arg_name": "df", + "actual_type": pd.DataFrame, + }, ) def test_assert_error_non_pyspark_df(self): @@ -655,8 +664,12 @@ def test_assert_error_non_pyspark_df(self): self.check_error( exception=pe.exception, - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(dict1)}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": DataFrame, + "arg_name": "df", + "actual_type": type(dict1), + }, ) with self.assertRaises(PySparkAssertionError) as pe: @@ -664,8 +677,12 @@ def test_assert_error_non_pyspark_df(self): self.check_error( exception=pe.exception, - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(dict1)}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": DataFrame, + "arg_name": "df", + "actual_type": type(dict1), + }, ) def test_row_order_ignored(self): diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 
acbfb522f69f4..b8977b6fffd79 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -35,7 +35,7 @@ from pyspark import SparkContext, SparkConf from pyspark.errors import PySparkAssertionError, PySparkException from pyspark.find_spark_home import _find_spark_home -from pyspark.sql.dataframe import DataFrame as DataFrame +from pyspark.sql.dataframe import DataFrame from pyspark.sql import Row from pyspark.sql.types import StructType, AtomicType, StructField @@ -322,7 +322,7 @@ def assertDataFrameEqual( ): r""" A util function to assert equality between `actual` (DataFrame) and `expected` - (either DataFrame or list of Rows), with optional parameter `checkRowOrder`. + (DataFrame or list of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`. .. versionadded:: 3.5.0 @@ -401,8 +401,12 @@ def assertDataFrameEqual( if not isinstance(actual, DataFrame) and not isinstance(actual, ConnectDataFrame): raise PySparkAssertionError( - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(actual)}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": DataFrame, + "arg_name": "df", + "actual_type": type(actual), + }, ) elif ( not isinstance(expected, DataFrame) @@ -410,19 +414,31 @@ def assertDataFrameEqual( and not isinstance(expected, List) ): raise PySparkAssertionError( - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(expected)}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": type(expected), + }, ) except Exception: if not isinstance(actual, DataFrame): raise PySparkAssertionError( - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(actual)}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": DataFrame, + "arg_name": "df", + "actual_type": type(actual), + }, ) elif not 
isinstance(expected, DataFrame) and not isinstance(expected, List): raise PySparkAssertionError( - error_class="UNSUPPORTED_DATA_TYPE", - message_parameters={"data_type": type(expected)}, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": Union[DataFrame, List[Row]], + "arg_name": "expected", + "actual_type": type(expected), + }, ) # special cases: empty datasets, datasets with 0 columns From 24b63b67c36356f18b0d7cf676acf8d027c7e437 Mon Sep 17 00:00:00 2001 From: jdesjean Date: Tue, 18 Jul 2023 09:04:33 +0900 Subject: [PATCH 003/986] [SPARK-43923][CONNECT] Post listenerBus events during ExecutePlanRequest ### What changes were proposed in this pull request? Add new SparkListenerEvent during Spark Connect ExecutePlanRequest: SparkListenerConnectOperationStarted SparkListenerConnectOperationParsed SparkListenerConnectOperationCanceled, SparkListenerConnectOperationFailed SparkListenerConnectOperationFinished SparkListenerConnectOperationClosed SparkListenerConnectSessionClosed . ### Why are the changes needed? HiveThriftServer2EventManager currently posts events to the listener bus to allow external listeners to track query execution. Mirror these events in Spark Connect. Created new events instead of reusing the thrift events to allow them to evolve separately. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Manual + Unit + E2E Closes #41443 from jdesjean/SPARK-43923. 
Authored-by: jdesjean Signed-off-by: Hyukjin Kwon (cherry picked from commit b44e6054dac545008075d602b5d3612522fd3b3a) Signed-off-by: Hyukjin Kwon --- .../spark/sql/connect/common/ProtoUtils.scala | 10 +- .../execution/ExecuteThreadRunner.scala | 10 +- .../execution/SparkConnectPlanExecution.scala | 26 +- .../connect/planner/SparkConnectPlanner.scala | 121 ++-- .../service/ExecuteEventsManager.scala | 420 +++++++++++++ .../sql/connect/service/ExecuteHolder.scala | 3 + .../service/SessionEventsManager.scala | 128 ++++ .../sql/connect/service/SessionHolder.scala | 12 + .../SparkConnectExecutePlanHandler.scala | 2 + .../connect/service/SparkConnectService.scala | 11 +- .../spark/sql/connect/utils/ErrorUtils.scala | 29 +- .../planner/SparkConnectPlannerSuite.scala | 29 +- .../planner/SparkConnectServiceSuite.scala | 588 ++++++++++++++---- .../SparkConnectPluginRegistrySuite.scala | 6 +- .../service/ExecuteEventsManagerSuite.scala | 318 ++++++++++ .../service/SessionEventsManagerSuite.scala | 102 +++ 16 files changed, 1621 insertions(+), 194 deletions(-) create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala create mode 100644 connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala create mode 100644 connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index 83f84f45b317e..e0c7d267c604e 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -27,15 +27,15 @@ private[connect] object ProtoUtils { private val MAX_BYTES_SIZE = 8 private val MAX_STRING_SIZE = 1024 - def abbreviate(message: Message): Message = { + def abbreviate(message: Message, maxStringSize: Int = MAX_STRING_SIZE): Message = { val builder = message.toBuilder message.getAllFields.asScala.iterator.foreach { case (field: FieldDescriptor, string: String) if field.getJavaType == FieldDescriptor.JavaType.STRING && string != null => val size = string.size - if (size > MAX_STRING_SIZE) { - builder.setField(field, createString(string.take(MAX_STRING_SIZE), size)) + if (size > maxStringSize) { + builder.setField(field, createString(string.take(maxStringSize), size)) } else { builder.setField(field, string) } @@ -43,8 +43,8 @@ private[connect] object ProtoUtils { case (field: FieldDescriptor, byteString: ByteString) if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null => val size = byteString.size - if (size > MAX_BYTES_SIZE) { - val prefix = Array.tabulate(MAX_BYTES_SIZE)(byteString.byteAt) + if (size > maxStringSize) { + val prefix = Array.tabulate(maxStringSize)(byteString.byteAt) builder.setField(field, createByteString(prefix, size)) } else { builder.setField(field, byteString) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index b7b3d2adf9f79..6c2ffa4654747 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -78,7 +78,6 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends // and different exceptions like InterruptedException, 
ClosedByInterruptException etc. // could be thrown. if (interrupted) { - // Turn the interrupt into OPERATION_CANCELED error. throw new SparkSQLException("OPERATION_CANCELED", Map.empty) } else { // Rethrown the original error. @@ -92,7 +91,9 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends "execute", executeHolder.responseObserver, executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId) + executeHolder.sessionHolder.sessionId, + Some(executeHolder.eventsManager), + interrupted) } } @@ -148,9 +149,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends val planner = new SparkConnectPlanner(executeHolder.sessionHolder) planner.process( command = command, - userId = request.getUserContext.getUserId, - sessionId = request.getSessionId, - responseObserver = responseObserver) + responseObserver = responseObserver, + executeHolder = executeHolder) responseObserver.onCompleted() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 74b4a5f659741..d2124a38c9d4e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.execution import scala.collection.JavaConverters._ +import scala.util.{Failure, Success} import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -54,13 +55,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) throw new IllegalStateException( s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.") } - - // Extract the plan from the request and convert it to a logical plan val 
planner = new SparkConnectPlanner(sessionHolder) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker val dataframe = - Dataset.ofRows(sessionHolder.session, planner.transformRelation(request.getPlan.getRoot)) + Dataset.ofRows( + sessionHolder.session, + planner.transformRelation(request.getPlan.getRoot), + tracker) responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) - processAsArrowBatches(request.getSessionId, dataframe, responseObserver) + processAsArrowBatches(dataframe, responseObserver, executeHolder) responseObserver.onNext( MetricGenerator.createMetricsResponse(request.getSessionId, dataframe)) if (dataframe.queryExecution.observedMetrics.nonEmpty) { @@ -87,10 +90,11 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) batches.map(b => b -> batches.rowCountInLastBatch) } - private def processAsArrowBatches( - sessionId: String, + def processAsArrowBatches( dataframe: DataFrame, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executePlan: ExecuteHolder): Unit = { + val sessionId = executePlan.sessionHolder.sessionId val spark = dataframe.sparkSession val schema = dataframe.schema val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch @@ -120,6 +124,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) dataframe.queryExecution.executedPlan match { case LocalTableScanExec(_, rows) => + executePlan.eventsManager.postFinished() converter(rows.iterator).foreach { case (bytes, count) => sendBatch(bytes, count) } @@ -156,13 +161,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) resultFunc = () => ()) // Collect errors and propagate them to the main thread. 
- future.onComplete { result => - result.failed.foreach { throwable => + future.onComplete { + case Success(_) => + executePlan.eventsManager.postFinished() + case Failure(throwable) => signal.synchronized { error = Some(throwable) signal.notify() } - } }(ThreadUtils.sameThread) // The main thread will wait until 0-th partition is available, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index f82cb6760404d..39cb4c1b972b1 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -57,8 +57,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry -import org.apache.spark.sql.connect.service.SessionHolder -import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService} import org.apache.spark.sql.connect.utils.MetricGenerator import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution @@ -86,7 +85,11 @@ final case class InvalidCommandInput( class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { - def session: SparkSession = sessionHolder.session + private[connect] def session: SparkSession = sessionHolder.session + + private[connect] def userId: String = sessionHolder.userId + + 
private[connect] def sessionId: String = sessionHolder.sessionId private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) @@ -2333,56 +2336,58 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { def process( command: proto.Command, - userId: String, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { command.getCommandTypeCase match { case proto.Command.CommandTypeCase.REGISTER_FUNCTION => - handleRegisterUserDefinedFunction(command.getRegisterFunction) + handleRegisterUserDefinedFunction(command.getRegisterFunction, executeHolder) case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION => - handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction) + handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction, executeHolder) case proto.Command.CommandTypeCase.WRITE_OPERATION => - handleWriteOperation(command.getWriteOperation) + handleWriteOperation(command.getWriteOperation, executeHolder) case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW => - handleCreateViewCommand(command.getCreateDataframeView) + handleCreateViewCommand(command.getCreateDataframeView, executeHolder) case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 => - handleWriteOperationV2(command.getWriteOperationV2) + handleWriteOperationV2(command.getWriteOperationV2, executeHolder) case proto.Command.CommandTypeCase.EXTENSION => - handleCommandPlugin(command.getExtension) + handleCommandPlugin(command.getExtension, executeHolder) case proto.Command.CommandTypeCase.SQL_COMMAND => - handleSqlCommand(command.getSqlCommand, sessionId, responseObserver) + handleSqlCommand(command.getSqlCommand, responseObserver, executeHolder) case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START => handleWriteStreamOperationStart( 
command.getWriteStreamOperationStart, - userId, - sessionId, - responseObserver) + responseObserver, + executeHolder) case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND => - handleStreamingQueryCommand(command.getStreamingQueryCommand, sessionId, responseObserver) + handleStreamingQueryCommand( + command.getStreamingQueryCommand, + responseObserver, + executeHolder) case proto.Command.CommandTypeCase.STREAMING_QUERY_MANAGER_COMMAND => handleStreamingQueryManagerCommand( command.getStreamingQueryManagerCommand, - sessionId, - responseObserver) + responseObserver, + executeHolder) case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND => - handleGetResourcesCommand(sessionId, responseObserver) + handleGetResourcesCommand(responseObserver, executeHolder) case _ => throw new UnsupportedOperationException(s"$command not supported.") } } def handleSqlCommand( getSqlCommand: SqlCommand, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { // Eagerly execute commands of the provided SQL string. val args = getSqlCommand.getArgsMap val posArgs = getSqlCommand.getPosArgsList + val tracker = executeHolder.eventsManager.createQueryPlanningTracker val df = if (!args.isEmpty) { - session.sql(getSqlCommand.getSql, args.asScala.mapValues(transformLiteral).toMap) + session.sql(getSqlCommand.getSql, args.asScala.mapValues(transformLiteral).toMap, tracker) } else if (!posArgs.isEmpty) { - session.sql(getSqlCommand.getSql, posArgs.asScala.map(transformLiteral).toArray) + session.sql(getSqlCommand.getSql, posArgs.asScala.map(transformLiteral).toArray, tracker) } else { - session.sql(getSqlCommand.getSql) + session.sql(getSqlCommand.getSql, Map.empty[String, Any], tracker) } // Check if commands have been executed. 
val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult] @@ -2430,6 +2435,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .newBuilder() .setData(ByteString.copyFrom(bytes)))) } else { + // Trigger assertExecutedPlanPrepared to ensure post ReadyForExecution before finished + // executedPlan is currently called by createMetricsResponse below + df.queryExecution.assertExecutedPlanPrepared() result.setRelation( proto.Relation .newBuilder() @@ -2440,6 +2448,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .putAllArgs(getSqlCommand.getArgsMap) .addAllPosArgs(getSqlCommand.getPosArgsList))) } + executeHolder.eventsManager.postFinished() // Exactly one SQL Command Result Batch responseObserver.onNext( ExecutePlanResponse @@ -2453,7 +2462,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } private def handleRegisterUserDefinedFunction( - fun: proto.CommonInlineUserDefinedFunction): Unit = { + fun: proto.CommonInlineUserDefinedFunction, + executeHolder: ExecuteHolder): Unit = { fun.getFunctionCase match { case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => handleRegisterPythonUDF(fun) @@ -2465,10 +2475,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw InvalidPlanInput( s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") } + executeHolder.eventsManager.postFinished() } private def handleRegisterUserDefinedTableFunction( - fun: proto.CommonInlineUserDefinedTableFunction): Unit = { + fun: proto.CommonInlineUserDefinedTableFunction, + executeHolder: ExecuteHolder): Unit = { fun.getFunctionCase match { case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF => val function = createPythonUserDefinedTableFunction(fun) @@ -2477,6 +2489,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw InvalidPlanInput( s"Function with ID: 
${fun.getFunctionCase.getNumber} is not supported") } + executeHolder.eventsManager.postFinished() } private def createPythonUserDefinedTableFunction( @@ -2532,7 +2545,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { session.udf.register(fun.getFunctionName, udf) } - private def handleCommandPlugin(extension: ProtoAny): Unit = { + private def handleCommandPlugin(extension: ProtoAny, executeHolder: ExecuteHolder): Unit = { SparkConnectPluginRegistry.commandRegistry // Lazily traverse the collection. .view @@ -2542,9 +2555,12 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .find(_.nonEmpty) .flatten .getOrElse(throw InvalidPlanInput("No handler found for extension")) + executeHolder.eventsManager.postFinished() } - private def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = { + private def handleCreateViewCommand( + createView: proto.CreateDataFrameViewCommand, + executeHolder: ExecuteHolder): Unit = { val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView val tableIdentifier = @@ -2566,7 +2582,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { replace = createView.getReplace, viewType = viewType) - Dataset.ofRows(session, plan).queryExecution.commandExecuted + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + Dataset.ofRows(session, plan, tracker).queryExecution.commandExecuted + executeHolder.eventsManager.postFinished() } /** @@ -2578,11 +2596,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { * * @param writeOperation */ - private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = { + private def handleWriteOperation( + writeOperation: proto.WriteOperation, + executeHolder: ExecuteHolder): Unit = { // Transform the input plan into the logical plan. val plan = transformRelation(writeOperation.getInput) // And create a Dataset from the plan. 
- val dataset = Dataset.ofRows(session, logicalPlan = plan) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + val dataset = Dataset.ofRows(session, plan, tracker) val w = dataset.write if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) { @@ -2637,6 +2658,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { "WriteOperation:SaveTypeCase not supported " + s"${writeOperation.getSaveTypeCase.getNumber}") } + executeHolder.eventsManager.postFinished() } /** @@ -2648,11 +2670,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { * * @param writeOperation */ - def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = { + def handleWriteOperationV2( + writeOperation: proto.WriteOperationV2, + executeHolder: ExecuteHolder): Unit = { // Transform the input plan into the logical plan. val plan = transformRelation(writeOperation.getInput) // And create a Dataset from the plan. 
- val dataset = Dataset.ofRows(session, logicalPlan = plan) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + val dataset = Dataset.ofRows(session, plan, tracker) val w = dataset.writeTo(table = writeOperation.getTableName) @@ -2703,15 +2728,18 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw new UnsupportedOperationException( s"WriteOperationV2:ModeValue not supported ${writeOperation.getModeValue}") } + executeHolder.eventsManager.postFinished() } def handleWriteStreamOperationStart( writeOp: WriteStreamOperationStart, - userId: String, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { val plan = transformRelation(writeOp.getInput) - val dataset = Dataset.ofRows(session, logicalPlan = plan) + val tracker = executeHolder.eventsManager.createQueryPlanningTracker + val dataset = Dataset.ofRows(session, plan, tracker) + // Call manually as writeStream does not trigger ReadyForExecution + tracker.setReadyForExecution() val writer = dataset.writeStream @@ -2789,6 +2817,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { SparkConnectService.streamingSessionManager.registerNewStreamingQuery( sessionHolder = SessionHolder(userId = userId, sessionId = sessionId, session), query = query) + executeHolder.eventsManager.postFinished() val result = WriteStreamOperationStartResult .newBuilder() @@ -2811,8 +2840,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { def handleStreamingQueryCommand( command: StreamingQueryCommand, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { val id = command.getQueryId.getId val runId = command.getQueryId.getRunId @@ -2915,6 +2944,7 @@ class 
SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw new IllegalArgumentException("Missing command in StreamingQueryCommand") } + executeHolder.eventsManager.postFinished() responseObserver.onNext( ExecutePlanResponse .newBuilder() @@ -2982,9 +3012,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { def handleStreamingQueryManagerCommand( command: StreamingQueryManagerCommand, - sessionId: String, - responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - + responseObserver: StreamObserver[ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { val respBuilder = StreamingQueryManagerCommandResult.newBuilder() command.getCommandCase match { @@ -3045,6 +3074,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { throw new IllegalArgumentException("Missing command in StreamingQueryManagerCommand") } + executeHolder.eventsManager.postFinished() responseObserver.onNext( ExecutePlanResponse .newBuilder() @@ -3054,8 +3084,9 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } def handleGetResourcesCommand( - sessionId: String, - responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = { + responseObserver: StreamObserver[proto.ExecutePlanResponse], + executeHolder: ExecuteHolder): Unit = { + executeHolder.eventsManager.postFinished() responseObserver.onNext( proto.ExecutePlanResponse .newBuilder() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala new file mode 100644 index 0000000000000..0af54f034a254 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import com.fasterxml.jackson.annotation.JsonIgnore +import com.google.protobuf.Message + +import org.apache.spark.connect.proto +import org.apache.spark.scheduler.SparkListenerEvent +import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.common.ProtoUtils +import org.apache.spark.util.{Clock, Utils} + +object ExecuteEventsManager { + // TODO: Make this configurable + val MAX_STATEMENT_TEXT_SIZE = 65535 +} + +sealed abstract class ExecuteStatus(value: Int) + +object ExecuteStatus { + case object Pending extends ExecuteStatus(0) + case object Started extends ExecuteStatus(1) + case object Analyzed extends ExecuteStatus(2) + case object ReadyForExecution extends ExecuteStatus(3) + case object Finished extends ExecuteStatus(4) + case object Failed extends ExecuteStatus(5) + case object Canceled extends ExecuteStatus(6) + case object Closed extends ExecuteStatus(7) +} + +/** + * Post request Connect events to @link org.apache.spark.scheduler.LiveListenerBus. + * + * @param executeHolder: + * Request for which the events are generated. + * @param clock: + * Source of time for unit tests. 
+ */ +case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { + + private def operationId = executeHolder.operationId + + private def jobTag = executeHolder.jobTag + + private def listenerBus = sessionHolder.session.sparkContext.listenerBus + + private def sessionHolder = executeHolder.sessionHolder + + private def sessionId = executeHolder.request.getSessionId + + private def sessionStatus = sessionHolder.eventManager.status + + private var _status: ExecuteStatus = ExecuteStatus.Pending + + private var error = Option.empty[Boolean] + + private var canceled = Option.empty[Boolean] + + /** + * @return + * Last event posted by the Connect request + */ + private[connect] def status: ExecuteStatus = _status + + /** + * @return + * True when the Connect request has posted @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationCanceled + */ + private[connect] def hasCanceled: Option[Boolean] = canceled + + /** + * @return + * True when the Connect request has posted @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationFailed + */ + private[connect] def hasError: Option[Boolean] = error + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted. 
+ */ + def postStarted(): Unit = { + assertStatus(List(ExecuteStatus.Pending), ExecuteStatus.Started) + val request = executeHolder.request + val plan: Message = + request.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.COMMAND => request.getPlan.getCommand + case proto.Plan.OpTypeCase.ROOT => request.getPlan.getRoot + case _ => + throw new UnsupportedOperationException( + s"${request.getPlan.getOpTypeCase} not supported.") + } + + listenerBus.post( + SparkListenerConnectOperationStarted( + jobTag, + operationId, + clock.getTimeMillis(), + sessionId, + request.getUserContext.getUserId, + request.getUserContext.getUserName, + Utils.redact( + sessionHolder.session.sessionState.conf.stringRedactionPattern, + ProtoUtils.abbreviate(plan, ExecuteEventsManager.MAX_STATEMENT_TEXT_SIZE).toString), + Some(request))) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationAnalyzed. + * + * @param analyzedPlan + * The analyzed plan generated by the Connect request plan. None when the request does not + * generate a plan. + */ + def postAnalyzed(analyzedPlan: Option[LogicalPlan] = None): Unit = { + assertStatus(List(ExecuteStatus.Started, ExecuteStatus.Analyzed), ExecuteStatus.Analyzed) + val event = + SparkListenerConnectOperationAnalyzed(jobTag, operationId, clock.getTimeMillis()) + event.analyzedPlan = analyzedPlan + listenerBus.post(event) + } + + /** + * Post @link + * org.apache.spark.sql.connect.service.SparkListenerConnectOperationReadyForExecution. + */ + def postReadyForExecution(): Unit = { + assertStatus(List(ExecuteStatus.Analyzed), ExecuteStatus.ReadyForExecution) + listenerBus.post( + SparkListenerConnectOperationReadyForExecution(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationCanceled. 
+ */ + def postCanceled(): Unit = { + assertStatus( + List( + ExecuteStatus.Started, + ExecuteStatus.Analyzed, + ExecuteStatus.ReadyForExecution, + ExecuteStatus.Finished, + ExecuteStatus.Failed), + ExecuteStatus.Canceled) + canceled = Some(true) + listenerBus + .post(SparkListenerConnectOperationCanceled(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFailed. + * + * @param errorMessage + * The message of the error thrown during the request. + */ + def postFailed(errorMessage: String): Unit = { + assertStatus( + List( + ExecuteStatus.Started, + ExecuteStatus.Analyzed, + ExecuteStatus.ReadyForExecution, + ExecuteStatus.Finished), + ExecuteStatus.Failed) + error = Some(true) + listenerBus.post( + SparkListenerConnectOperationFailed( + jobTag, + operationId, + clock.getTimeMillis(), + errorMessage)) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished. + */ + def postFinished(): Unit = { + assertStatus( + List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution), + ExecuteStatus.Finished) + listenerBus + .post(SparkListenerConnectOperationFinished(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationClosed. + */ + def postClosed(): Unit = { + assertStatus( + List(ExecuteStatus.Finished, ExecuteStatus.Failed, ExecuteStatus.Canceled), + ExecuteStatus.Closed) + listenerBus + .post(SparkListenerConnectOperationClosed(jobTag, operationId, clock.getTimeMillis())) + } + + /** + * @return + * \@link A org.apache.spark.sql.catalyst.QueryPlanningTracker that calls postAnalyzed & + * postReadyForExecution after analysis & prior execution. 
+ */ + def createQueryPlanningTracker(): QueryPlanningTracker = { + new QueryPlanningTracker(Some(new QueryPlanningTrackerCallback { + def analyzed(tracker: QueryPlanningTracker, analyzedPlan: LogicalPlan): Unit = { + postAnalyzed(Some(analyzedPlan)) + } + + def readyForExecution(tracker: QueryPlanningTracker): Unit = postReadyForExecution + })) + } + + private[connect] def status_(executeStatus: ExecuteStatus): Unit = { + _status = executeStatus + } + + private def assertStatus( + validStatuses: List[ExecuteStatus], + eventStatus: ExecuteStatus): Unit = { + if (!validStatuses + .find(s => s == status) + .isDefined) { + throw new IllegalStateException(s""" + operationId: $operationId with status ${status} + is not within statuses $validStatuses for event $eventStatus + """) + } + if (sessionHolder.eventManager.status != SessionStatus.Started) { + throw new IllegalStateException(s""" + sessionId: $sessionId with status $sessionStatus + is not Started for event $eventStatus + """) + } + _status = eventStatus + } +} + +/** + * Event sent after reception of a Connect request (i.e. not queued), but prior any analysis or + * execution. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.setJobGroup) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param sessionId: + * ID assigned by the client or Connect the operation was executed on. + * @param userId: + * Opaque userId set in the Connect request. + * @param userName: + * Opaque userName set in the Connect request. + * @param statementText: + * The connect request plan converted to text. + * @param planRequest: + * The Connect request. None if the operation is not of type @link proto.ExecutePlanRequest + * @param extraTags: + * Additional metadata during the request. 
+ */ +case class SparkListenerConnectOperationStarted( + jobTag: String, + operationId: String, + eventTime: Long, + sessionId: String, + userId: String, + userName: String, + statementText: String, + planRequest: Option[proto.ExecutePlanRequest], + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * The event is sent after a Connect request has been analyzed (@link + * org.apache.spark.sql.catalyst.QueryPlanningTracker.ANALYSIS). + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationAnalyzed( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent { + + /** + * Analyzed Spark plan generated by the Connect request. None when the Connect request does not + * generate a Spark plan. + */ + @JsonIgnore var analyzedPlan: Option[LogicalPlan] = None +} + +/** + * The event is sent after a Connect request is ready for execution. For eager commands this is + * after @link org.apache.spark.sql.catalyst.QueryPlanningTracker.ANALYSIS. For other requests it + * is after \@link org.apache.spark.sql.catalyst.QueryPlanningTracker.PLANNING + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. 
+ */ +case class SparkListenerConnectOperationReadyForExecution( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has been canceled. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationCanceled( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has failed. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param errorMessage: + * The message of the error thrown during the request. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationFailed( + jobTag: String, + operationId: String, + eventTime: Long, + errorMessage: String, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has finished executing, but prior results have been sent to + * client. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. 
+ * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationFinished( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect request has finished executing and results have been sent to client. + * + * @param jobTag: + * Opaque Spark jobTag (@link org.apache.spark.SparkContext.addJobTag) assigned by Connect + * during a request. Designed to be unique across sessions and requests. + * @param operationId: + * 36 characters UUID assigned by Connect during a request. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata during the request. + */ +case class SparkListenerConnectOperationClosed( + jobTag: String, + operationId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index 89aceaee1e4af..1f70973b60e0d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner} +import org.apache.spark.util.SystemClock /** * Object used to hold the Spark Connect execution state. 
@@ -41,6 +42,8 @@ private[connect] class ExecuteHolder( val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] = new ExecuteResponseObserver[proto.ExecutePlanResponse]() + val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock()) + private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this) /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala new file mode 100644 index 0000000000000..f275fab56bf5f --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionEventsManager.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import org.apache.spark.scheduler.SparkListenerEvent +import org.apache.spark.util.{Clock} + +sealed abstract class SessionStatus(value: Int) + +object SessionStatus { + case object Pending extends SessionStatus(0) + case object Started extends SessionStatus(1) + case object Closed extends SessionStatus(2) +} + +/** + * Post session Connect events to @link org.apache.spark.scheduler.LiveListenerBus. + * + * @param sessionHolder: + * Session for which the events are generated. + * @param clock: + * Source of time for unit tests. + */ +case class SessionEventsManager(sessionHolder: SessionHolder, clock: Clock) { + + private def sessionId = sessionHolder.sessionId + + private var _status: SessionStatus = SessionStatus.Pending + + private[connect] def status_(sessionStatus: SessionStatus): Unit = { + _status = sessionStatus + } + + /** + * @return + * Last event posted by the Connect session + */ + def status: SessionStatus = _status + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectSessionStarted. + */ + def postStarted(): Unit = { + assertStatus(List(SessionStatus.Pending), SessionStatus.Started) + sessionHolder.session.sparkContext.listenerBus + .post( + SparkListenerConnectSessionStarted( + sessionHolder.sessionId, + sessionHolder.userId, + clock.getTimeMillis())) + } + + /** + * Post @link org.apache.spark.sql.connect.service.SparkListenerConnectSessionClosed. 
+ */ + def postClosed(): Unit = { + assertStatus(List(SessionStatus.Started), SessionStatus.Closed) + sessionHolder.session.sparkContext.listenerBus + .post( + SparkListenerConnectSessionClosed( + sessionHolder.sessionId, + sessionHolder.userId, + clock.getTimeMillis())) + } + + private def assertStatus( + validStatuses: List[SessionStatus], + eventStatus: SessionStatus): Unit = { + if (!validStatuses + .find(s => s == status) + .isDefined) { + throw new IllegalStateException(s""" + sessionId: $sessionId with status ${status} + is not within statuses $validStatuses for event $eventStatus + """) + } + _status = eventStatus + } +} + +/** + * Event sent after a Connect session has been started. + * + * @param sessionId: + * ID assigned by the client or Connect the operation was executed on. + * @param eventTime: + * The time in ms when the event was generated. + * @param extraTags: + * Additional metadata + */ +case class SparkListenerConnectSessionStarted( + sessionId: String, + userId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent + +/** + * Event sent after a Connect session has been closed. + * + * @param sessionId: + * ID assigned by the client or Connect the operation was executed on. + * @param eventTime: + * The time in ms when the event was generated. 
+ * @param extraTags: + * Additional metadata + */ +case class SparkListenerConnectSessionClosed( + sessionId: String, + userId: String, + eventTime: Long, + extraTags: Map[String, String] = Map.empty) + extends SparkListenerEvent diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 2f3bd1badcecc..5ac4f6db82aa3 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.util.{SystemClock} import org.apache.spark.util.Utils /** @@ -44,6 +45,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio val executions: ConcurrentMap[String, ExecuteHolder] = new ConcurrentHashMap[String, ExecuteHolder]() + val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock()) + // Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like // foreachBatch() in Streaming. Lazy since most sessions don't need it. 
private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap() @@ -60,6 +63,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio executePlanHolder } + private[connect] def executeHolder(operationId: String): Option[ExecuteHolder] = { + Option(executions.get(operationId)) + } + private[connect] def removeExecuteHolder(operationId: String): Unit = { executions.remove(operationId) } @@ -98,12 +105,17 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ def classloader: ClassLoader = artifactManager.classloader + private[connect] def initializeSession(): Unit = { + eventManager.postStarted() + } + /** * Expire this session and trigger state cleanup mechanisms. */ private[connect] def expireSession(): Unit = { logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") artifactManager.cleanUpResources() + eventManager.postClosed() } /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala index 50ca733b43915..b4e91c438359c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala @@ -32,6 +32,7 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec val executeHolder = sessionHolder.createExecuteHolder(v) try { + executeHolder.eventsManager.postStarted() executeHolder.start() val responseSender = new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](responseObserver) @@ -44,6 +45,7 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec } finally { // TODO this will change with detachable execution. 
executeHolder.join() + executeHolder.eventsManager.postClosed() sessionHolder.removeExecuteHolder(executeHolder.operationId) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index c38fbbdfcf97d..ad40c94d5498c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -223,10 +223,19 @@ object SparkConnectService { userSessionMapping.get( (userId, sessionId), () => { - SessionHolder(userId, sessionId, newIsolatedSession()) + val holder = SessionHolder(userId, sessionId, newIsolatedSession()) + holder.initializeSession() + holder }) } + /** + * Used for testing + */ + private[connect] def invalidateAllSessions(): Unit = { + userSessionMapping.invalidateAll() + } + private def newIsolatedSession(): SparkSession = { SparkSession.active.newSession() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index d0f754827dad8..326bdd0052c64 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} import org.apache.spark.api.python.PythonException import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.ExecuteEventsManager import org.apache.spark.sql.connect.service.SparkConnectService import org.apache.spark.sql.internal.SQLConf @@ -103,32 +104,42 @@ private[connect] object ErrorUtils 
extends Logging { opType: String, observer: StreamObserver[V], userId: String, - sessionId: String): PartialFunction[Throwable, Unit] = { + sessionId: String, + events: Option[ExecuteEventsManager] = None, + isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { val session = SparkConnectService .getOrCreateIsolatedSession(userId, sessionId) .session val stackTraceEnabled = session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) - { + val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", se) - observer.onError( + ( + se, StatusProto.toStatusRuntimeException( buildStatusFromThrowable(se.getCause, stackTraceEnabled))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( - StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) case e: Throwable => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( + ( + e, Status.UNKNOWN .withCause(e) .withDescription(StringUtils.abbreviate(e.getMessage, 2048)) .asRuntimeException()) } + partial + .andThen { case (original, wrapped) => + logError(s"Error during: $opType. UserId: $userId. 
SessionId: $sessionId.", original) + if (isInterrupted) { + events.foreach(_.postCanceled) + } else { + events.foreach(_.postFailed(wrapped.getMessage)) + } + observer.onError(wrapped) + } } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index a10540676b04e..595f9d65c269b 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto -import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -58,8 +58,9 @@ trait SparkConnectPlanTest extends SharedSparkSession { } def transform(cmd: proto.Command): Unit = { - new SparkConnectPlanner(SessionHolder.forTesting(spark)) - .process(cmd, "clientId", "sessionId", new MockObserver()) + val executeHolder = buildExecutePlanHolder(cmd) + new SparkConnectPlanner(executeHolder.sessionHolder) + .process(cmd, new MockObserver(), executeHolder) } def readRel: proto.Relation = @@ -114,6 +115,28 @@ trait SparkConnectPlanTest extends SharedSparkSession { localRelationBuilder.setData(ByteString.copyFrom(bytes)) proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } + + def 
buildExecutePlanHolder(command: proto.Command): ExecuteHolder = { + val sessionHolder = SessionHolder.forTesting(spark) + sessionHolder.eventManager.status_(SessionStatus.Started) + + val context = proto.UserContext + .newBuilder() + .setUserId(sessionHolder.userId) + .build() + val plan = proto.Plan + .newBuilder() + .setCommand(command) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .build() + val executeHolder = sessionHolder.createExecuteHolder(request) + executeHolder.eventsManager.status_(ExecuteStatus.Started) + executeHolder + } } /** diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index bceaada9051e3..498084efb8f3f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -16,30 +16,49 @@ */ package org.apache.spark.sql.connect.planner +import java.util.UUID +import java.util.concurrent.Semaphore + import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.protobuf +import com.google.protobuf.ByteString import io.grpc.StatusRuntimeException import io.grpc.stub.StreamObserver import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.commons.lang3.{JavaVersion, SystemUtils} +import org.mockito.Mockito.when +import org.scalatest.Tag +import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.CreateDataFrameViewCommand +import org.apache.spark.internal.Logging +import 
org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ -import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService} -import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted} +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.util.Utils /** * Testing Connect Service implementation. 
*/ -class SparkConnectServiceSuite extends SharedSparkSession { +class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with Logging { private def sparkSessionHolder = SessionHolder.forTesting(spark) + private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093") test("Test schema in analyze response") { withTable("test") { @@ -131,126 +150,365 @@ class SparkConnectServiceSuite extends SharedSparkSession { } test("SPARK-41224: collect data using arrow") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) - val instance = new SparkConnectService(false) - val connect = new MockRemoteSession() - val context = proto.UserContext - .newBuilder() - .setUserId("c1") - .build() - val plan = proto.Plan - .newBuilder() - .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)")) - .build() - val request = proto.ExecutePlanRequest - .newBuilder() - .setPlan(plan) - .setUserContext(context) - .build() - - // Execute plan. - @volatile var done = false - val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] - instance.executePlan( - request, - new StreamObserver[proto.ExecutePlanResponse] { - override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v - - override def onError(throwable: Throwable): Unit = throw throwable - - override def onCompleted(): Unit = done = true - }) - - // The current implementation is expected to be blocking. This is here to make sure it is. 
- assert(done) - - // 4 Partitions + Metrics - assert(responses.size == 6) - - // Make sure the first response is schema only - val head = responses.head - assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) - - // Make sure the last response is metrics only - val last = responses.last - assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) - - val allocator = new RootAllocator() - - // Check the 'data' batches - var expectedId = 0L - var previousEId = 0.0d - responses.tail.dropRight(1).foreach { response => - assert(response.hasArrowBatch) - val batch = response.getArrowBatch - assert(batch.getData != null) - assert(batch.getRowCount == 25) - - val reader = new ArrowStreamReader(batch.getData.newInput(), allocator) - while (reader.loadNextBatch()) { - val root = reader.getVectorSchemaRoot - val idVector = root.getVector(0).asInstanceOf[BigIntVector] - val eidVector = root.getVector(1).asInstanceOf[Float8Vector] - val numRows = root.getRowCount - var i = 0 - while (i < numRows) { - assert(idVector.get(i) == expectedId) - expectedId += 1 - val eid = eidVector.get(i) - assert(eid > previousEId) - previousEId = eid - i += 1 + withEvents { verifyEvents => + // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 + assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .build() + + // Execute plan. 
+ @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted() + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + + // 4 Partitions + Metrics + assert(responses.size == 6) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) + + val allocator = new RootAllocator() + + // Check the 'data' batches + var expectedId = 0L + var previousEId = 0.0d + responses.tail.dropRight(1).foreach { response => + assert(response.hasArrowBatch) + val batch = response.getArrowBatch + assert(batch.getData != null) + assert(batch.getRowCount == 25) + + val reader = new ArrowStreamReader(batch.getData.newInput(), allocator) + while (reader.loadNextBatch()) { + val root = reader.getVectorSchemaRoot + val idVector = root.getVector(0).asInstanceOf[BigIntVector] + val eidVector = root.getVector(1).asInstanceOf[Float8Vector] + val numRows = root.getRowCount + var i = 0 + while (i < numRows) { + assert(idVector.get(i) == expectedId) + expectedId += 1 + val eid = eidVector.get(i) + assert(eid > previousEId) + previousEId = eid + i += 1 + } } + reader.close() } - reader.close() + allocator.close() } - allocator.close() } - test("SPARK-41165: failures in the arrow collect path should not cause hangs") { - val instance = new SparkConnectService(false) + gridTest("SPARK-43923: 
commands send events")( + Seq( + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()), + proto.Command + .newBuilder() + .setWriteOperation( + proto.WriteOperation + .newBuilder() + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) + .setPath("my/test/path") + .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), + proto.Command + .newBuilder() + .setWriteOperationV2( + proto.WriteOperationV2 + .newBuilder() + .setInput(proto.Relation.newBuilder.setRange( + proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) + .setTableName("testcat.testtable") + .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), + proto.Command + .newBuilder() + .setCreateDataframeView( + CreateDataFrameViewCommand + .newBuilder() + .setName("testview") + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), + proto.Command + .newBuilder() + .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), + proto.Command + .newBuilder() + .setExtension( + protobuf.Any.pack( + proto.ExamplePluginCommand + .newBuilder() + .setCustomField("SPARK-43923") + .build())), + proto.Command + .newBuilder() + .setWriteStreamOperationStart( + proto.WriteStreamOperationStart + .newBuilder() + .setInput( + proto.Relation + .newBuilder() + .setRead(proto.Read + .newBuilder() + .setIsStreaming(true) + .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .build()) + .build()) + .setOutputMode("Append") + .setAvailableNow(true) + .setQueryName("test") + .setFormat("memory") + .putOptions("checkpointLocation", s"${UUID.randomUUID}") + .setPath("test-path") + .build()), + proto.Command + .newBuilder() + .setStreamingQueryCommand( + proto.StreamingQueryCommand + .newBuilder() + .setQueryId( + proto.StreamingQueryInstanceId + 
.newBuilder() + .setId(DEFAULT_UUID.toString) + .setRunId(DEFAULT_UUID.toString) + .build()) + .setStop(true)), + proto.Command + .newBuilder() + .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand + .newBuilder() + .setListListeners(true)), + proto.Command + .newBuilder() + .setRegisterFunction( + proto.CommonInlineUserDefinedFunction + .newBuilder() + .setFunctionName("function") + .setPythonUdf( + proto.PythonUDF + .newBuilder() + .setEvalType(100) + .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) + .setCommand(ByteString.copyFrom("command".getBytes())) + .setPythonVer("3.10") + .build())))) { command => + withCommandTest { verifyEvents => + val instance = new SparkConnectService(false) + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setCommand(command) + .build() + + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setSessionId("s1") + .setUserContext(context) + .build() + + // Execute plan. + @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted() + // The current implementation is expected to be blocking. + // This is here to make sure it is. 
+ assert(done) - // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session - val instaKill: Long => Long = { _ => - throw new Exception("Kaboom") + // Result + Metrics + if (responses.size > 1) { + assert(responses.size == 2) + + // Make sure the first response result only + val head = responses.head + assert(head.hasSqlCommandResult && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSqlCommandResult) + } } - session.udf.register("insta_kill", instaKill) - - val connect = new MockRemoteSession() - val context = proto.UserContext - .newBuilder() - .setUserId("c1") - .build() - val plan = proto.Plan - .newBuilder() - .setRoot(connect.sql("select insta_kill(id) from range(10)")) - .build() - val request = proto.ExecutePlanRequest - .newBuilder() - .setPlan(plan) - .setUserContext(context) - .setSessionId("session") - .build() - - // The observer is executed inside this thread. So - // we can perform the checks inside the observer. 
- instance.executePlan( - request, - new StreamObserver[proto.ExecutePlanResponse] { - override def onNext(v: proto.ExecutePlanResponse): Unit = { - fail("this should not receive responses") - } + } - override def onError(throwable: Throwable): Unit = { - assert(throwable.isInstanceOf[StatusRuntimeException]) - } + test("SPARK-43923: canceled request send events") { + withEvents { verifyEvents => + val instance = new SparkConnectService(false) + + // Add an always crashing UDF + val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session + val sleep: Long => Long = { time => + Thread.sleep(time) + time + } + session.udf.register("sleep", sleep) - override def onCompleted(): Unit = { - fail("this should not complete") + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select sleep(10000)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId("session") + .build() + + val thread = new Thread { + override def run: Unit = { + verifyEvents.listener.semaphoreStarted.acquire() + instance.interrupt( + proto.InterruptRequest + .newBuilder() + .setSessionId("session") + .setUserContext(context) + .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL) + .build(), + new StreamObserver[proto.InterruptResponse] { + override def onNext(v: proto.InterruptResponse): Unit = {} + + override def onError(throwable: Throwable): Unit = {} + + override def onCompleted(): Unit = {} + }) } - }) + } + thread.start() + // The observer is executed inside this thread. So + // we can perform the checks inside the observer. 
+ instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + logInfo(s"$v") + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onCanceled + } + + override def onCompleted(): Unit = { + fail("this should not complete") + } + }) + thread.join() + verifyEvents.onCompleted() + } + } + + test("SPARK-41165: failures in the arrow collect path should not cause hangs") { + withEvents { verifyEvents => + val instance = new SparkConnectService(false) + + // Add an always crashing UDF + val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session + val instaKill: Long => Long = { _ => + throw new Exception("Kaboom") + } + session.udf.register("insta_kill", instaKill) + + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select insta_kill(id) from range(10)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId("session") + .build() + + // The observer is executed inside this thread. So + // we can perform the checks inside the observer. 
+ instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + fail("this should not receive responses") + } + + override def onError(throwable: Throwable): Unit = { + assert(throwable.isInstanceOf[StatusRuntimeException]) + verifyEvents.onError(throwable) + } + + override def onCompleted(): Unit = { + fail("this should not complete") + } + }) + verifyEvents.onCompleted() + } } test("Test explain mode in analyze response") { @@ -378,4 +636,108 @@ class SparkConnectServiceSuite extends SharedSparkSession { assert(valuesList.last.hasLong && valuesList.last.getLong == 99) } } + + protected def withCommandTest(f: VerifyEvents => Unit): Unit = { + withView("testview") { + withTable("testcat.testtable") { + withSparkConf( + "spark.sql.catalog.testcat" -> classOf[InMemoryPartitionTableCatalog].getName, + Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES.key -> + "org.apache.spark.sql.connect.plugin.ExampleCommandPlugin") { + withEvents { verifyEvents => + val restartedQuery = mock[StreamingQuery] + when(restartedQuery.id).thenReturn(DEFAULT_UUID) + when(restartedQuery.runId).thenReturn(DEFAULT_UUID) + SparkConnectService.streamingSessionManager.registerNewStreamingQuery( + SparkConnectService.getOrCreateIsolatedSession("c1", "s1"), + restartedQuery) + f(verifyEvents) + } + } + } + } + } + + protected def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + pairs.foreach { kv => conf.set(kv._1, kv._2) } + try f + finally { + pairs.foreach { kv => conf.remove(kv._1) } + } + } + + protected def withEvents(f: VerifyEvents => Unit): Unit = { + val verifyEvents = new VerifyEvents(spark.sparkContext) + spark.sparkContext.addSparkListener(verifyEvents.listener) + Utils.tryWithSafeFinally({ + f(verifyEvents) + SparkConnectService.invalidateAllSessions() + verifyEvents.onSessionClosed() + }) { + verifyEvents.waitUntilEmpty() + 
spark.sparkContext.removeSparkListener(verifyEvents.listener) + SparkConnectService.invalidateAllSessions() + SparkConnectPluginRegistry.reset() + } + } + + protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])( + testFun: A => Unit): Unit = { + for (param <- params) { + test(testNamePrefix + s" ($param)", testTags: _*)(testFun(param)) + } + } + + class VerifyEvents(val sparkContext: SparkContext) { + val listener: MockSparkListener = new MockSparkListener() + val listenerBus = sparkContext.listenerBus + val LISTENER_BUS_TIMEOUT = 30000 + def executeHolder: ExecuteHolder = { + assert(listener.executeHolder.isDefined) + listener.executeHolder.get + } + def onNext(v: proto.ExecutePlanResponse): Unit = { + if (v.hasSchema) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Analyzed) + } + if (v.hasMetrics) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Finished) + } + } + def onError(throwable: Throwable): Unit = { + assert(executeHolder.eventsManager.hasCanceled.isEmpty) + assert(executeHolder.eventsManager.hasError.isDefined) + } + def onCompleted(): Unit = { + assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + } + def onCanceled(): Unit = { + assert(executeHolder.eventsManager.hasCanceled.contains(true)) + assert(executeHolder.eventsManager.hasError.isEmpty) + } + def onSessionClosed(): Unit = { + assert(executeHolder.sessionHolder.eventManager.status == SessionStatus.Closed) + } + def onSessionStarted(): Unit = { + assert(executeHolder.sessionHolder.eventManager.status == SessionStatus.Started) + } + def waitUntilEmpty(): Unit = { + listenerBus.waitUntilEmpty(LISTENER_BUS_TIMEOUT) + } + } + class MockSparkListener() extends SparkListener { + val semaphoreStarted = new Semaphore(0) + var executeHolder = Option.empty[ExecuteHolder] + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case e: SparkListenerConnectOperationStarted => + semaphoreStarted.release() + 
val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession(e.userId, e.sessionId) + executeHolder = sessionHolder.executeHolder(e.operationId) + case _ => + } + } + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index 2bdabc7ccc214..fdb9032379419 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest} -import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.test.SharedSparkSession class DummyPlugin extends RelationPlugin { @@ -196,8 +195,9 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne .build())) .build() - new SparkConnectPlanner(SessionHolder.forTesting(spark)) - .process(plan, "clientId", "sessionId", new MockObserver()) + val executeHolder = buildExecutePlanHolder(plan) + new SparkConnectPlanner(executeHolder.sessionHolder) + .process(plan, new MockObserver(), executeHolder) assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin")) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala new file mode 100644 index 0000000000000..365b17632a742 --- /dev/null +++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import scala.util.matching.Regex + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{ExecutePlanRequest, Plan, UserContext} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.planner.SparkConnectPlanTest +import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.util.ManualClock + +class ExecuteEventsManagerSuite + extends SparkFunSuite + with MockitoSugar + with SparkConnectPlanTest { + + val DEFAULT_ERROR = "error" + val DEFAULT_CLOCK = new ManualClock() + val DEFAULT_NODE_NAME = "nodeName" + val DEFAULT_TEXT = """limit { + limit: 10 +} +""" + val DEFAULT_USER_ID = "1" + val DEFAULT_USER_NAME = "userName" + val DEFAULT_SESSION_ID = "2" + val DEFAULT_QUERY_ID = 
"3" + val DEFAULT_CLIENT_TYPE = "clientType" + + test("SPARK-43923: post started") { + val events = setupEvents(ExecuteStatus.Pending) + events.postStarted() + + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(SparkListenerConnectOperationStarted( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_USER_NAME, + DEFAULT_TEXT, + Some(events.executeHolder.request), + Map.empty)) + } + + test("SPARK-43923: post analyzed with plan") { + val events = setupEvents(ExecuteStatus.Started) + + val mockPlan = mock[LogicalPlan] + events.postAnalyzed(Some(mockPlan)) + val event = SparkListenerConnectOperationAnalyzed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + event.analyzedPlan = Some(mockPlan) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(event) + } + + test("SPARK-43923: post analyzed with empty plan") { + val events = setupEvents(ExecuteStatus.Started) + events.postAnalyzed() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationAnalyzed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post readyForExecution") { + val events = setupEvents(ExecuteStatus.Analyzed) + events.postReadyForExecution() + val event = SparkListenerConnectOperationReadyForExecution( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(event) + } + + test("SPARK-43923: post canceled") { + val events = setupEvents(ExecuteStatus.Started) + events.postCanceled() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationCanceled( + events.executeHolder.jobTag, 
+ DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post failed") { + val events = setupEvents(ExecuteStatus.Started) + events.postFailed(DEFAULT_ERROR) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFailed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + DEFAULT_ERROR, + Map.empty[String, String])) + } + + test("SPARK-43923: post finished") { + val events = setupEvents(ExecuteStatus.Started) + events.postFinished() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFinished( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post closed") { + val events = setupEvents(ExecuteStatus.Finished) + events.postClosed() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationClosed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: Closed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Closed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postCanceled() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Finished wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Finished) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + 
assertThrows[IllegalStateException] { + events.postFinished() + } + } + + test("SPARK-43923: Failed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Failed) // start from Failed, the state this test is named for + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postFailed(DEFAULT_ERROR) + } + } + + test("SPARK-43923: Canceled wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Canceled) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postCanceled() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postFailed(DEFAULT_ERROR) + } + } + + test("SPARK-43923: ReadyForExecution wrong order throws exception") { + val events = setupEvents(ExecuteStatus.ReadyForExecution) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Analyzed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Analyzed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Started wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Started) + assertThrows[IllegalStateException] { + events.postStarted() +
} + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Started wrong session status") { + val events = setupEvents(ExecuteStatus.Started, SessionStatus.Pending) + assertThrows[IllegalStateException] { + events.postStarted() + } + } + + def setupEvents( + executeStatus: ExecuteStatus, + sessionStatus: SessionStatus = SessionStatus.Started): ExecuteEventsManager = { + val mockSession = mock[SparkSession] + val sessionHolder = SessionHolder(DEFAULT_USER_ID, DEFAULT_SESSION_ID, mockSession) + sessionHolder.eventManager.status_(sessionStatus) + val mockContext = mock[SparkContext] + val mockListenerBus = mock[LiveListenerBus] + val mockSessionState = mock[SessionState] + val mockConf = mock[SQLConf] + when(mockSession.sessionState).thenReturn(mockSessionState) + when(mockSessionState.conf).thenReturn(mockConf) + when(mockConf.stringRedactionPattern).thenReturn(Option.empty[Regex]) + when(mockContext.listenerBus).thenReturn(mockListenerBus) + when(mockSession.sparkContext).thenReturn(mockContext) + + val relation = proto.Relation.newBuilder + .setLimit(proto.Limit.newBuilder.setLimit(10)) + .build() + + val executePlanRequest = ExecutePlanRequest + .newBuilder() + .setPlan(Plan.newBuilder().setRoot(relation)) + .setUserContext( + UserContext + .newBuilder() + .setUserId(DEFAULT_USER_ID) + .setUserName(DEFAULT_USER_NAME)) + .setSessionId(DEFAULT_SESSION_ID) + .setClientType(DEFAULT_CLIENT_TYPE) + .build() + + val executeHolder = new ExecuteHolder(executePlanRequest, DEFAULT_QUERY_ID, sessionHolder) + + val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK) + eventsManager.status_(executeStatus) + eventsManager + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala new file mode 100644 index 0000000000000..7025146b0295b --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.planner.SparkConnectPlanTest +import org.apache.spark.util.ManualClock + +class SessionEventsManagerSuite + extends SparkFunSuite + with MockitoSugar + with SparkConnectPlanTest { + + val DEFAULT_ERROR = "error" + val DEFAULT_CLOCK = new ManualClock() + val DEFAULT_NODE_NAME = "nodeName" + val DEFAULT_TEXT = """limit { + limit: 10 +} +""" + val DEFAULT_USER_ID = "1" + val DEFAULT_USER_NAME = "userName" + val DEFAULT_SESSION_ID = "2" + val DEFAULT_QUERY_ID = "3" + val DEFAULT_CLIENT_TYPE = "clientType" + + test("SPARK-43923: post started") { + val events = setupEvents(SessionStatus.Pending) + events.postStarted() + + verify(events.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectSessionStarted( + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_CLOCK.getTimeMillis(), + Map.empty)) + } + + test("SPARK-43923: post closed") { + val events = setupEvents(SessionStatus.Started) + events.postClosed() + + verify(events.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectSessionClosed( + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_CLOCK.getTimeMillis(), + Map.empty)) + } + + test("SPARK-43923: Started wrong order throws exception") { + val events = setupEvents(SessionStatus.Started) + assertThrows[IllegalStateException] { + events.postStarted() + } + } + + test("SPARK-43923: Closed wrong order throws exception") { + val events = setupEvents(SessionStatus.Closed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + def setupEvents(status: SessionStatus): SessionEventsManager = { + val mockSession 
= mock[SparkSession] + val sessionHolder = SessionHolder(DEFAULT_USER_ID, DEFAULT_SESSION_ID, mockSession) + val mockContext = mock[SparkContext] + val mockListenerBus = mock[LiveListenerBus] + when(mockContext.listenerBus).thenReturn(mockListenerBus) + when(mockSession.sparkContext).thenReturn(mockContext) + + val eventsManager = SessionEventsManager(sessionHolder, DEFAULT_CLOCK) + eventsManager.status_(status) + eventsManager + } +} From c96b9710f50ed045d8e949db6c0e251bbd57b201 Mon Sep 17 00:00:00 2001 From: Richard Yu Date: Tue, 18 Jul 2023 08:57:55 +0800 Subject: [PATCH 004/986] [SPARK-44059] Add analyzer support of named arguments for built-in functions ### What changes were proposed in this pull request? Add analyzer support for named function arguments. ### Why are the changes needed? Part of the project needed for general named function argument support. ### Does this PR introduce _any_ user-facing change? We added support for named arguments for the ```CountMinSketchAgg``` and ```Mask``` SQL functions. ### How was this patch tested? A new test suite, NamedParameterFunctionSuite, was added for this change. Closes #42020 from learningchess2003/44059-final.
Authored-by: Richard Yu Signed-off-by: Wenchen Fan (cherry picked from commit 228b5dbfd7688a8efa7135d9ec7b00b71e41a38a) Signed-off-by: Wenchen Fan --- .../utils/src/main/resources/error/README.md | 1 + .../main/resources/error/error-classes.json | 44 ++- ...outine-parameter-assignment-error-class.md | 36 ++ docs/sql-error-conditions.md | 34 +- .../catalyst/analysis/FunctionRegistry.scala | 113 +++++- .../aggregate/CountMinSketchAgg.scala | 49 ++- .../sql/catalyst/expressions/generators.scala | 79 +++- .../expressions/maskExpressions.scala | 31 +- .../plans/logical/FunctionBuilderBase.scala | 177 +++++++++ .../sql/errors/QueryCompilationErrors.scala | 76 +++- .../NamedParameterFunctionSuite.scala | 151 ++++++++ .../named-function-arguments.sql.out | 337 +++++++++++++++--- .../inputs/named-function-arguments.sql | 55 +++ .../results/named-function-arguments.sql.out | 324 ++++++++++++++--- .../sql/errors/QueryParsingErrorsSuite.scala | 9 +- 15 files changed, 1386 insertions(+), 130 deletions(-) create mode 100644 docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala diff --git a/common/utils/src/main/resources/error/README.md b/common/utils/src/main/resources/error/README.md index dfcb42d49e79a..aed2c0becd311 100644 --- a/common/utils/src/main/resources/error/README.md +++ b/common/utils/src/main/resources/error/README.md @@ -666,6 +666,7 @@ The following SQLSTATEs are collated from: |4274C |42 |Syntax Error or Access Rule Violation |74C |The specified attribute was not found in the trusted context.|DB2 |N |DB2 | |4274D |42 |Syntax Error or Access Rule Violation |74D |The specified attribute already exists in the trusted context.|DB2 |N |DB2 | |4274E |42 |Syntax Error or Access Rule Violation |74E |The 
specified attribute is not supported in the trusted context.|DB2 |N |DB2 | +|4274K |42 |Syntax Error or Access Rule Violation |74K |Invalid use of a named argument when invoking a routine.|DB2 |N |DB2 | |4274M |42 |Syntax Error or Access Rule Violation |74M |An undefined period name was detected. |DB2 |N |DB2 | |42801 |42 |Syntax Error or Access Rule Violation |801 |Isolation level UR is invalid, because the result table is not read-only.|DB2 |N |DB2 | |42802 |42 |Syntax Error or Access Rule Violation |802 |The number of target values is not the same as the number of source values.|DB2 |N |DB2 | diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index e8cdaa6c63b3f..b136878e6d2c0 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -738,6 +738,24 @@ ], "sqlState" : "23505" }, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT" : { + "message" : [ + "Call to function is invalid because it includes multiple argument assignments to the same parameter name ." + ], + "subClass" : { + "BOTH_POSITIONAL_AND_NAMED" : { + "message" : [ + "A positional argument and named argument both referred to the same parameter." + ] + }, + "DOUBLE_NAMED_ARGUMENT_REFERENCE" : { + "message" : [ + "More than one named argument referred to the same parameter." + ] + } + }, + "sqlState" : "4274K" + }, "EMPTY_JSON_FIELD_VALUE" : { "message" : [ "Failed to parse an empty string for data type ." @@ -1956,7 +1974,13 @@ "Not allowed to implement multiple UDF interfaces, UDF class ." ] }, - "NAMED_ARGUMENTS_SUPPORT_DISABLED" : { + "NAMED_PARAMETERS_NOT_SUPPORTED" : { + "message" : [ + "Named parameters are not supported for function ; please retry the query with positional arguments to the function call instead." 
+ ], + "sqlState" : "4274K" + }, + "NAMED_PARAMETER_SUPPORT_DISABLED" : { "message" : [ "Cannot call function because named argument references are not enabled here. In this case, the named argument reference was . Set \"spark.sql.allowNamedFunctionArguments\" to \"true\" to turn on feature." ] @@ -2295,6 +2319,12 @@ ], "sqlState" : "42614" }, + "REQUIRED_PARAMETER_NOT_FOUND" : { + "message" : [ + "Cannot invoke function because the parameter named is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally or by name) and retry the query again." + ], + "sqlState" : "4274K" + }, "REQUIRES_SINGLE_PART_NAMESPACE" : { "message" : [ " requires a single-part namespace, but got ." @@ -2485,6 +2515,12 @@ ], "sqlState" : "42K09" }, + "UNEXPECTED_POSITIONAL_ARGUMENT" : { + "message" : [ + "Cannot invoke function because it contains positional argument(s) following named argument(s); please rearrange them so the positional arguments come first and then retry the query again." + ], + "sqlState" : "4274K" + }, "UNKNOWN_PROTOBUF_MESSAGE_TYPE" : { "message" : [ "Attempting to treat as a Message, but it was ." @@ -2514,6 +2550,12 @@ ], "sqlState" : "428C4" }, + "UNRECOGNIZED_PARAMETER_NAME" : { + "message" : [ + "Cannot invoke function because the function call included a named argument reference for the argument named , but this function does not include any signature containing an argument with this name. Did you mean one of the following? []." + ], + "sqlState" : "4274K" + }, "UNRECOGNIZED_SQL_TYPE" : { "message" : [ "Unrecognized SQL type - name: , id: ." 
diff --git a/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md b/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md new file mode 100644 index 0000000000000..d9f14b5a55ef8 --- /dev/null +++ b/docs/sql-error-conditions-duplicate-routine-parameter-assignment-error-class.md @@ -0,0 +1,36 @@ +--- +layout: global +title: DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT error class +displayTitle: DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT error class +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Call to function `` is invalid because it includes multiple argument assignments to the same parameter name ``. + +This error class has the following derived error classes: + +## BOTH_POSITIONAL_AND_NAMED + +A positional argument and named argument both referred to the same parameter. + +## DOUBLE_NAMED_ARGUMENT_REFERENCE + +More than one named argument referred to the same parameter. 
+ + diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 91b77a6452bc5..5686324a0558b 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -456,6 +456,14 @@ Found duplicate clauses: ``. Please, remove one of them. Found duplicate keys ``. +### [DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT](sql-error-conditions-duplicate-routine-parameter-assignment-error-class.html) + +[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Call to function `` is invalid because it includes multiple argument assignments to the same parameter name ``. + +For more details see [DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT](sql-error-conditions-duplicate-routine-parameter-assignment-error-class.html) + ### EMPTY_JSON_FIELD_VALUE [SQLSTATE: 42604](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -1210,7 +1218,13 @@ SQLSTATE: none assigned Not allowed to implement multiple UDF interfaces, UDF class ``. -### NAMED_ARGUMENTS_SUPPORT_DISABLED +### NAMED_PARAMETERS_NOT_SUPPORTED + +[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Named parameters are not supported for function ``; please retry the query with positional arguments to the function call instead. + +### NAMED_PARAMETER_SUPPORT_DISABLED SQLSTATE: none assigned @@ -1521,6 +1535,12 @@ Failed to rename as `` was not found. The `` clause may be used at most once per `` operation. +### REQUIRED_PARAMETER_NOT_FOUND + +[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot invoke function `` because the parameter named `` is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally or by name) and retry the query again. 
+ ### REQUIRES_SINGLE_PART_NAMESPACE [SQLSTATE: 42K05](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -1724,6 +1744,12 @@ Found an unclosed bracketed comment. Please, append */ at the end of the comment Parameter `` of function `` requires the `` type, however `` has the type ``. +### UNEXPECTED_POSITIONAL_ARGUMENT + +[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot invoke function `` because it contains positional argument(s) following named argument(s); please rearrange them so the positional arguments come first and then retry the query again. + ### UNKNOWN_PROTOBUF_MESSAGE_TYPE SQLSTATE: none assigned @@ -1754,6 +1780,12 @@ Unpivot value columns must share a least common type, some types do not: [``). +### UNRECOGNIZED_PARAMETER_NAME + +[SQLSTATE: 4274K](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot invoke function `` because the function call included a named argument reference for the argument named ``, but this function does not include any signature containing an argument with this name. Did you mean one of the following? [``]. 
+ ### UNRECOGNIZED_SQL_TYPE [SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a9bda2e0b7c99..558579cdb80ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.xml._ -import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRowRelation, Range} +import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -358,8 +358,8 @@ object FunctionRegistry { // misc non-aggregate functions expression[Abs]("abs"), expression[Coalesce]("coalesce"), - expression[Explode]("explode"), - expressionGeneratorOuter[Explode]("explode_outer"), + expressionBuilder("explode", ExplodeExpressionBuilder), + expressionGeneratorBuilderOuter("explode_outer", ExplodeExpressionBuilder), expression[Greatest]("greatest"), expression[If]("if"), expression[Inline]("inline"), @@ -491,7 +491,7 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectList]("array_agg", true, Some("3.3.0")), expression[CollectSet]("collect_set"), - expression[CountMinSketchAgg]("count_min_sketch"), + expressionBuilder("count_min_sketch", CountMinSketchAggExpressionBuilder), expression[BoolAnd]("every", true), expression[BoolAnd]("bool_and"), 
expression[BoolOr]("any", true), @@ -823,7 +823,7 @@ object FunctionRegistry { castAlias("string", StringType), // mask functions - expression[Mask]("mask"), + expressionBuilder("mask", MaskExpressionBuilder), // csv expression[CsvToStructs]("from_csv"), @@ -887,6 +887,9 @@ object FunctionRegistry { since: Option[String] = None): (String, (ExpressionInfo, FunctionBuilder)) = { val (expressionInfo, builder) = FunctionRegistryBase.build[T](name, since) val newBuilder = (expressions: Seq[Expression]) => { + if (expressions.exists(_.isInstanceOf[NamedArgumentExpression])) { + throw QueryCompilationErrors.namedArgumentsNotSupported(name) + } val expr = builder(expressions) if (setAlias) expr.setTagValue(FUNC_ALIAS, name) expr @@ -894,6 +897,32 @@ object FunctionRegistry { (name, (expressionInfo, newBuilder)) } + /** + * This method will be used to rearrange the arguments provided in function invocation + * in the order defined by the function signature given in the builder instance. + * + * @param name The name of the function + * @param builder The builder of the function expression + * @param expressions The argument list passed in function invocation + * @tparam T The class of the builder + * @return An argument list in positional order defined by the builder + */ + def rearrangeExpressions[T <: FunctionBuilderBase[_]]( + name: String, + builder: T, + expressions: Seq[Expression]) : Seq[Expression] = { + val rearrangedExpressions = if (!builder.functionSignature.isEmpty) { + val functionSignature = builder.functionSignature.get + builder.rearrange(functionSignature, expressions, name) + } else { + expressions + } + if (rearrangedExpressions.exists(_.isInstanceOf[NamedArgumentExpression])) { + throw QueryCompilationErrors.namedArgumentsNotSupported(name) + } + rearrangedExpressions + } + private def expressionBuilder[T <: ExpressionBuilder : ClassTag]( name: String, builder: T, @@ -902,7 +931,8 @@ object FunctionRegistry { val info = 
FunctionRegistryBase.expressionInfo[T](name, since) val funcBuilder = (expressions: Seq[Expression]) => { assert(expressions.forall(_.resolved), "function arguments must be resolved.") - val expr = builder.build(name, expressions) + val rearrangedExpressions = rearrangeExpressions(name, builder, expressions) + val expr = builder.build(name, rearrangedExpressions) if (setAlias) expr.setTagValue(FUNC_ALIAS, name) expr } @@ -935,9 +965,22 @@ object FunctionRegistry { private def expressionGeneratorOuter[T <: Generator : ClassTag](name: String) : (String, (ExpressionInfo, FunctionBuilder)) = { - val (_, (info, generatorBuilder)) = expression[T](name) + val (_, (info, builder)) = expression[T](name) val outerBuilder = (args: Seq[Expression]) => { - GeneratorOuter(generatorBuilder(args).asInstanceOf[Generator]) + GeneratorOuter(builder(args).asInstanceOf[Generator]) + } + (name, (info, outerBuilder)) + } + + private def expressionGeneratorBuilderOuter[T <: ExpressionBuilder : ClassTag] + (name: String, builder: T) : (String, (ExpressionInfo, FunctionBuilder)) = { + val info = FunctionRegistryBase.expressionInfo[T](name, since = None) + val outerBuilder = (args: Seq[Expression]) => { + val rearrangedArgs = + FunctionRegistry.rearrangeExpressions(name, builder, args) + val generator = builder.build(name, rearrangedArgs) + assert(generator.isInstanceOf[Generator]) + GeneratorOuter(generator.asInstanceOf[Generator]) } (name, (info, outerBuilder)) } @@ -980,6 +1023,30 @@ object TableFunctionRegistry { (name, (info, (expressions: Seq[Expression]) => builder(expressions))) } + /** + * A function used for table-valued functions to return a builder that + * when given input arguments, will return a function expression representing + * the table-valued functions. 
+ * + * @param name Name of the function + * @param builder Object which will build the expression given input arguments + * @param since Time of implementation + * @tparam T Type of the builder + * @return A tuple of the function name, expression info, and function builder + */ + def generatorBuilder[T <: GeneratorBuilder : ClassTag]( + name: String, + builder: T, + since: Option[String] = None): (String, (ExpressionInfo, TableFunctionBuilder)) = { + val info = FunctionRegistryBase.expressionInfo[T](name, since) + val funcBuilder = (expressions: Seq[Expression]) => { + assert(expressions.forall(_.resolved), "function arguments must be resolved.") + val rearrangedExpressions = FunctionRegistry.rearrangeExpressions(name, builder, expressions) + builder.build(name, rearrangedExpressions) + } + (name, (info, funcBuilder)) + } + def generator[T <: Generator : ClassTag](name: String, outer: Boolean = false) : (String, (ExpressionInfo, TableFunctionBuilder)) = { val (info, builder) = FunctionRegistryBase.build[T](name, since = None) @@ -999,8 +1066,8 @@ object TableFunctionRegistry { val logicalPlans: Map[String, (ExpressionInfo, TableFunctionBuilder)] = Map( logicalPlan[Range]("range"), - generator[Explode]("explode"), - generator[Explode]("explode_outer", outer = true), + generatorBuilder("explode", ExplodeGeneratorBuilder), + generatorBuilder("explode_outer", ExplodeOuterGeneratorBuilder), generator[Inline]("inline"), generator[Inline]("inline_outer", outer = true), generator[JsonTuple]("json_tuple"), @@ -1022,6 +1089,28 @@ object TableFunctionRegistry { val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet } -trait ExpressionBuilder { - def build(funcName: String, expressions: Seq[Expression]): Expression +/** + * This is a trait used for scalar valued functions that defines how their expression + * representations are constructed in [[FunctionRegistry]]. 
+ */ +trait ExpressionBuilder extends FunctionBuilderBase[Expression] + +/** + * This is a trait used for table valued functions that defines how their expression + * representations are constructed in [[TableFunctionRegistry]]. + */ +trait GeneratorBuilder extends FunctionBuilderBase[LogicalPlan] { + override final def build(funcName: String, expressions: Seq[Expression]) : LogicalPlan = { + Generate( + buildGenerator(funcName, expressions), + unrequiredChildIndex = Nil, + outer = isOuter, + qualifier = None, + generatorOutput = Nil, + child = OneRowRelation()) + } + + def isOuter: Boolean + + def buildGenerator(funcName: String, expressions: Seq[Expression]) : Generator } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index 6cefca418cea0..b7988922bd79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.catalyst.trees.QuaternaryLike import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types._ @@ -39,22 +40,6 @@ import org.apache.spark.util.sketch.CountMinSketch * @param confidenceExpression confidence, 
must be positive and less than 1.0 * @param seedExpression random seed */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp, - confidence and seed. The result is an array of bytes, which can be deserialized to a - `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for - cardinality estimation using sub-linear space. - """, - examples = """ - Examples: - > SELECT hex(_FUNC_(col, 0.5d, 0.5d, 1)) FROM VALUES (1), (2), (1) AS tab(col); - 0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000 - """, - group = "agg_funcs", - since = "2.2.0") -// scalastyle:on line.size.limit case class CountMinSketchAgg( child: Expression, epsExpression: Expression, @@ -208,3 +193,33 @@ case class CountMinSketchAgg( confidenceExpression = third, seedExpression = fourth) } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp, + confidence and seed. The result is an array of bytes, which can be deserialized to a + `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for + cardinality estimation using sub-linear space. 
+ """, + examples = """ + Examples: + > SELECT hex(_FUNC_(col, 0.5d, 0.5d, 1)) FROM VALUES (1), (2), (1) AS tab(col); + 0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000 + """, + group = "agg_funcs", + since = "2.2.0") +// scalastyle:on line.size.limit +object CountMinSketchAggExpressionBuilder extends ExpressionBuilder { + final val defaultFunctionSignature = FunctionSignature(Seq( + InputParameter("column"), + InputParameter("epsilon"), + InputParameter("confidence"), + InputParameter("seed") + )) + override def functionSignature: Option[FunctionSignature] = Some(defaultFunctionSignature) + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + assert(expressions.size == 4) + new CountMinSketchAgg(expressions(0), expressions(1), expressions(2), expressions(3)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 6ae7ea206c833..afaaf07d2726b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -21,11 +21,12 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, GeneratorBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import 
org.apache.spark.sql.catalyst.trees.TreePattern.{GENERATOR, TreePattern} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.catalyst.util.SQLKeywordUtils._ @@ -413,6 +414,21 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with * 20 * }}} */ +case class Explode(child: Expression) extends ExplodeBase { + override val position: Boolean = false + override protected def withNewChildInternal(newChild: Expression): Explode = + copy(child = newChild) +} + +trait ExplodeGeneratorBuilderBase extends GeneratorBuilder { + override def functionSignature: Option[FunctionSignature] = + Some(FunctionSignature(Seq(InputParameter("collection")))) + override def buildGenerator(funcName: String, expressions: Seq[Expression]): Generator = { + assert(expressions.size == 1) + Explode(expressions(0)) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(expr) - Separates the elements of array `expr` into multiple rows, or the elements of map `expr` into multiple rows and columns. 
Unless specified otherwise, uses the default column name `col` for elements of the array or `key` and `value` for the elements of the map.", @@ -421,16 +437,66 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with > SELECT _FUNC_(array(10, 20)); 10 20 + > SELECT _FUNC_(collection => array(10, 20)); + 10 + 20 + > SELECT * FROM _FUNC_(collection => array(10, 20)); + 10 + 20 """, since = "1.0.0", group = "generator_funcs") // scalastyle:on line.size.limit -case class Explode(child: Expression) extends ExplodeBase { - override val position: Boolean = false - override protected def withNewChildInternal(newChild: Expression): Explode = - copy(child = newChild) +object ExplodeExpressionBuilder extends ExpressionBuilder { + override def functionSignature: Option[FunctionSignature] = + Some(FunctionSignature(Seq(InputParameter("collection")))) + + override def build(funcName: String, expressions: Seq[Expression]) : Expression = + Explode(expressions(0)) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr) - Separates the elements of array `expr` into multiple rows, or the elements of map `expr` into multiple rows and columns. Unless specified otherwise, uses the default column name `col` for elements of the array or `key` and `value` for the elements of the map.", + examples = """ + Examples: + > SELECT _FUNC_(array(10, 20)); + 10 + 20 + > SELECT _FUNC_(collection => array(10, 20)); + 10 + 20 + > SELECT * FROM _FUNC_(collection => array(10, 20)); + 10 + 20 + """, + since = "1.0.0", + group = "generator_funcs") +// scalastyle:on line.size.limit +object ExplodeGeneratorBuilder extends ExplodeGeneratorBuilderBase { + override def isOuter: Boolean = false +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr) - Separates the elements of array `expr` into multiple rows, or the elements of map `expr` into multiple rows and columns. 
Unless specified otherwise, uses the default column name `col` for elements of the array or `key` and `value` for the elements of the map.", + examples = """ + Examples: + > SELECT _FUNC_(array(10, 20)); + 10 + 20 + > SELECT _FUNC_(collection => array(10, 20)); + 10 + 20 + """, + since = "1.0.0", + group = "generator_funcs") +// scalastyle:on line.size.limit +object ExplodeOuterGeneratorBuilder extends ExplodeGeneratorBuilderBase { + override def isOuter: Boolean = true +} + + /** * Given an input array produces a sequence of rows for each position and value in the array. * @@ -448,6 +514,9 @@ case class Explode(child: Expression) extends ExplodeBase { > SELECT _FUNC_(array(10,20)); 0 10 1 20 + > SELECT * FROM _FUNC_(array(10,20)); + 0 10 + 1 20 """, since = "2.0.0", group = "generator_funcs") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index af74e7c0f7b24..61a96ff5ff951 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -75,6 +76,26 @@ import 
org.apache.spark.unsafe.types.UTF8String since = "3.4.0", group = "string_funcs") // scalastyle:on line.size.limit +object MaskExpressionBuilder extends ExpressionBuilder { + override def functionSignature: Option[FunctionSignature] = { + val strArg = InputParameter("str") + val upperCharArg = InputParameter("upperChar", Some(Literal(Mask.MASKED_UPPERCASE))) + val lowerCharArg = InputParameter("lowerChar", Some(Literal(Mask.MASKED_LOWERCASE))) + val digitCharArg = InputParameter("digitChar", Some(Literal(Mask.MASKED_DIGIT))) + val otherCharArg = InputParameter( + "otherChar", + Some(Literal(Mask.MASKED_IGNORE, StringType))) + val functionSignature: FunctionSignature = FunctionSignature(Seq( + strArg, upperCharArg, lowerCharArg, digitCharArg, otherCharArg)) + Some(functionSignature) + } + + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + assert(expressions.size == 5) + new Mask(expressions(0), expressions(1), expressions(2), expressions(3), expressions(4)) + } +} + case class Mask( input: Expression, upperChar: Expression, @@ -277,13 +298,13 @@ case class MaskArgument(maskChar: Char, ignore: Boolean) object Mask { // Default character to replace upper-case characters - private val MASKED_UPPERCASE = 'X' + val MASKED_UPPERCASE = 'X' // Default character to replace lower-case characters - private val MASKED_LOWERCASE = 'x' + val MASKED_LOWERCASE = 'x' // Default character to replace digits - private val MASKED_DIGIT = 'n' + val MASKED_DIGIT = 'n' // This value helps to retain original value in the input by ignoring the replacement rules - private val MASKED_IGNORE = null + val MASKED_IGNORE = null def transformInput( input: Any, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala new file mode 100644 index 0000000000000..4a2b9eae98100 --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * This is a base trait that is used for implementing builder classes that can be used to construct + * expressions or logical plans depending on if it is a table-valued or scalar-valued function. + * + * Two classes of builders currently exist for this trait: [[GeneratorBuilder]] and + * [[ExpressionBuilder]]. If a new class of functions are to be added, a new trait should also be + * created which extends this trait. + * + * @tparam T The type that is expected to be returned by the [[FunctionBuilderBase.build]] function + */ +trait FunctionBuilderBase[T] { + /** + * A method that returns the method signature for this function. + * Each function signature includes a list of parameters to which the analyzer can + * compare a function call with provided arguments to determine if that function + * call is a match for the function signature. 
+ * + * IMPORTANT: For now, each function expression builder should have only one function signature. + * Also, for any function signature, required arguments must always come before optional ones. + */ + def functionSignature: Option[FunctionSignature] = None + + /** + * This function rearranges the arguments provided during function invocation in positional order + * according to the function signature. This method will fill in the default values if optional + * parameters do not have their values specified. Any function which supports named arguments + * will have this routine invoked, even if no named arguments are present in the argument list. + * This is done to eliminate constructor overloads in some methods which use them for default + * values prior to the implementation of the named argument framework. This function will also + * check if the number of arguments is correct. If that is not the case, then an error will be + * thrown. + * + * IMPORTANT: This method will be called before the [[FunctionBuilderBase.build]] method is + * invoked. It is guaranteed that the expressions provided to the [[FunctionBuilderBase.build]] + * function form a valid set of argument expressions that can be used in the construction of + * the function expression. 
+ * + * @param expectedSignature The method signature which we rearrange our arguments according to + * @param providedArguments The list of arguments passed from function invocation + * @param functionName The name of the function + * @return The rearranged argument list with arguments in positional order + */ + def rearrange( + expectedSignature: FunctionSignature, + providedArguments: Seq[Expression], + functionName: String) : Seq[Expression] = { + NamedParametersSupport.defaultRearrange(expectedSignature, providedArguments, functionName) + } + + def build(funcName: String, expressions: Seq[Expression]): T +} + +object NamedParametersSupport { + /** + * This method is the default routine which rearranges the arguments in positional order according + * to the function signature provided. This will also fill in any default values that exist for + * optional arguments. This method will also be invoked even if there are no named arguments in + * the argument list. This method will keep all positional arguments in their original order. + * + * @param functionSignature The function signature that defines the positional ordering + * @param args The argument list provided in function invocation + * @param functionName The name of the function + * @return A list of arguments rearranged in positional order defined by the provided signature + */ + final def defaultRearrange( + functionSignature: FunctionSignature, + args: Seq[Expression], + functionName: String): Seq[Expression] = { + val parameters: Seq[InputParameter] = functionSignature.parameters + if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) { + throw QueryCompilationErrors.unexpectedRequiredParameterInFunctionSignature( + functionName, functionSignature) + } + + val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression]) + val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) + + // The following loop checks for the following: + // 1. 
Unrecognized parameter names + // 2. Duplicate routine parameter assignments + val allParameterNames: Seq[String] = parameters.map(_.name) + val parameterNamesSet: Set[String] = allParameterNames.toSet + val positionalParametersSet = allParameterNames.take(positionalArgs.size).toSet + val namedParametersSet = collection.mutable.Set[String]() + + for (arg <- namedArgs) { + arg match { + case namedArg: NamedArgumentExpression => + val parameterName = namedArg.key + if (!parameterNamesSet.contains(parameterName)) { + throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, + parameterNamesSet.toSeq) + } + if (positionalParametersSet.contains(parameterName)) { + throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( + functionName, namedArg.key) + } + if (namedParametersSet.contains(parameterName)) { + throw QueryCompilationErrors.doubleNamedArgumentReference( + functionName, namedArg.key) + } + namedParametersSet.add(namedArg.key) + case _ => + throw QueryCompilationErrors.unexpectedPositionalArgument(functionName) + } + } + + // Check the argument list size against the provided parameter list length. + if (parameters.size < args.length) { + val validParameterSizes = + Array.range(parameters.count(_.default.isEmpty), parameters.size + 1).toSeq + throw QueryCompilationErrors.wrongNumArgsError( + functionName, validParameterSizes, args.length) + } + + // This constructs a map from argument name to value for argument rearrangement. + val namedArgMap = namedArgs.map { arg => + val namedArg = arg.asInstanceOf[NamedArgumentExpression] + namedArg.key -> namedArg.value + }.toMap + + // We rearrange named arguments to match their positional order. 
+ val rearrangedNamedArgs: Seq[Expression] = namedParameters.map { param => + namedArgMap.getOrElse( + param.name, + if (param.default.isEmpty) { + throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name) + } else { + param.default.get + } + ) + } + val rearrangedArgs = positionalArgs ++ rearrangedNamedArgs + assert(rearrangedArgs.size == parameters.size) + rearrangedArgs + } +} + +/** + * Represents a parameter of a function expression. Function expressions should use this class + * to construct the argument lists returned in [[FunctionBuilderBase.functionSignature]]. + * + * @param name The name of the parameter. + * @param default The default value of the argument. If the default is none, then that means the + * argument is required. If no argument is provided, an exception is thrown. + */ +case class InputParameter(name: String, default: Option[Expression] = None) + +/** + * Represents a method signature and the list of arguments it receives as input. + * Currently, overloads are not supported and only one FunctionSignature is allowed + * per function expression. 
+ * + * @param parameters The list of arguments which the function takes + */ +case class FunctionSignature(parameters: Seq[InputParameter]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 955046e74e1ed..346f25580aaeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Join, LogicalPlan, SerdeInfo, Window} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, FunctionSignature, Join, LogicalPlan, SerdeInfo, Window} import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.connector.catalog._ @@ -50,6 +50,78 @@ import org.apache.spark.sql.types._ */ private[sql] object QueryCompilationErrors extends QueryErrorsBase { + def unexpectedRequiredParameterInFunctionSignature( + functionName: String, functionSignature: FunctionSignature) : Throwable = { + val errorMessage = s"Function $functionName has an unexpected required argument for" + + s" the provided function signature $functionSignature. All required arguments should" + + " come before optional arguments." 
+ SparkException.internalError(errorMessage) + } + + def namedArgumentsNotSupported(functionName: String) : Throwable = { + new AnalysisException( + errorClass = "NAMED_PARAMETERS_NOT_SUPPORTED", + messageParameters = Map("functionName" -> toSQLId(functionName)) + ) + } + + def positionalAndNamedArgumentDoubleReference( + functionName: String, parameterName: String) : Throwable = { + val errorClass = + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED" + new AnalysisException( + errorClass = errorClass, + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "parameterName" -> toSQLId(parameterName)) + ) + } + + def doubleNamedArgumentReference( + functionName: String, parameterName: String): Throwable = { + val errorClass = + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE" + new AnalysisException( + errorClass = errorClass, + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "parameterName" -> toSQLId(parameterName)) + ) + } + + def requiredParameterNotFound( + functionName: String, parameterName: String) : Throwable = { + new AnalysisException( + errorClass = "REQUIRED_PARAMETER_NOT_FOUND", + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "parameterName" -> toSQLId(parameterName)) + ) + } + + def unrecognizedParameterName( + functionName: String, argumentName: String, candidates: Seq[String]): Throwable = { + import org.apache.spark.sql.catalyst.util.StringUtils.orderSuggestedIdentifiersBySimilarity + + val inputs = candidates.map(candidate => Seq(candidate)).toSeq + val recommendations = orderSuggestedIdentifiersBySimilarity(argumentName, inputs) + .take(3) + new AnalysisException( + errorClass = "UNRECOGNIZED_PARAMETER_NAME", + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "argumentName" -> toSQLId(argumentName), + "proposal" -> recommendations.mkString(" ")) + ) + } + + def unexpectedPositionalArgument(functionName: String): Throwable = { 
+ new AnalysisException( + errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", + messageParameters = Map("functionName" -> toSQLId(functionName)) + ) + } + def groupingIDMismatchError(groupingID: GroupingID, groupByExprs: Seq[Expression]): Throwable = { new AnalysisException( errorClass = "GROUPING_ID_COLUMN_MISMATCH", @@ -195,7 +267,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { def namedArgumentsNotEnabledError(functionName: String, argumentName: String): Throwable = { new AnalysisException( - errorClass = "NAMED_ARGUMENTS_SUPPORT_DISABLED", + errorClass = "NAMED_PARAMETER_SUPPORT_DISABLED", messageParameters = Map( "functionName" -> toSQLId(functionName), "argument" -> toSQLId(argumentName)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala new file mode 100644 index 0000000000000..dd5cb5e7d03c8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/NamedParameterFunctionSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkThrowable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NamedArgumentExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, FunctionSignature, InputParameter, NamedParametersSupport} +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.types.DataType + + +case class DummyExpression( + k1: Expression, + k2: Expression, + k3: Expression, + k4: Expression) extends Expression { + override def nullable: Boolean = false + override def eval(input: InternalRow): Any = None + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = null + override def dataType: DataType = null + override def children: Seq[Expression] = Nil + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = null +} + +object DummyExpressionBuilder extends ExpressionBuilder { + def defaultFunctionSignature: FunctionSignature = { + FunctionSignature(Seq(InputParameter("k1"), + InputParameter("k2"), + InputParameter("k3"), + InputParameter("k4"))) + } + + override def functionSignature: Option[FunctionSignature] = + Some(defaultFunctionSignature) + override def build(funcName: String, expressions: Seq[Expression]): Expression = + DummyExpression(expressions(0), expressions(1), expressions(2), expressions(3)) +} + +class NamedArgumentFunctionSuite extends AnalysisTest { + + final val k1Arg = Literal("v1") + final val k2Arg = NamedArgumentExpression("k2", Literal("v2")) + final val k3Arg = NamedArgumentExpression("k3", Literal("v3")) + final val k4Arg = NamedArgumentExpression("k4", Literal("v4")) + final val namedK1Arg = NamedArgumentExpression("k1", Literal("v1-2")) + final val args = Seq(k1Arg, k4Arg, k2Arg, k3Arg) + final val 
expectedSeq = Seq(Literal("v1"), Literal("v2"), Literal("v3"), Literal("v4")) + final val signature = DummyExpressionBuilder.defaultFunctionSignature + final val illegalSignature = FunctionSignature(Seq( + InputParameter("k1"), InputParameter("k2", Option(Literal("v2"))), InputParameter("k3"))) + + test("Check rearrangement of expressions") { + val rearrangedArgs = NamedParametersSupport.defaultRearrange( + signature, args, "function") + for ((returnedArg, expectedArg) <- rearrangedArgs.zip(expectedSeq)) { + assert(returnedArg == expectedArg) + } + val rearrangedArgsWithBuilder = + FunctionRegistry.rearrangeExpressions("function", DummyExpressionBuilder, args) + for ((returnedArg, expectedArg) <- rearrangedArgsWithBuilder.zip(expectedSeq)) { + assert(returnedArg == expectedArg) + } + } + + private def parseRearrangeException(functionSignature: FunctionSignature, + expressions: Seq[Expression], + functionName: String = "function"): SparkThrowable = { + intercept[SparkThrowable]( + NamedParametersSupport.defaultRearrange(functionSignature, expressions, functionName)) + } + + private def parseExternalException[T <: FunctionBuilderBase[_]]( + functionName: String, + builder: T, + expressions: Seq[Expression]) : SparkThrowable = { + intercept[SparkThrowable]( + FunctionRegistry.rearrangeExpressions[T](functionName, builder, expressions)) + } + + test("DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT") { + val errorClass = + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED" + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, namedK1Arg), "foo"), + errorClass = errorClass, + parameters = Map("functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k1")) + ) + checkError( + exception = parseRearrangeException( + signature, Seq(k1Arg, k2Arg, k3Arg, k4Arg, k4Arg), "foo"), + errorClass = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map("functionName" -> toSQLId("foo"), 
"parameterName" -> toSQLId("k4")) + ) + } + + test("REQUIRED_PARAMETER_NOT_FOUND") { + checkError( + exception = parseRearrangeException(signature, Seq(k1Arg, k2Arg, k3Arg), "foo"), + errorClass = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map("functionName" -> toSQLId("foo"), "parameterName" -> toSQLId("k4")) + ) + } + + test("UNRECOGNIZED_PARAMETER_NAME") { + checkError( + exception = parseRearrangeException(signature, + Seq(k1Arg, k2Arg, k3Arg, k4Arg, NamedArgumentExpression("k5", Literal("k5"))), "foo"), + errorClass = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map("functionName" -> toSQLId("foo"), "argumentName" -> toSQLId("k5"), + "proposal" -> (toSQLId("k1") + " " + toSQLId("k2") + " " + toSQLId("k3"))) + ) + } + + test("UNEXPECTED_POSITIONAL_ARGUMENT") { + checkError( + exception = parseRearrangeException(signature, + Seq(k2Arg, k3Arg, k1Arg, k4Arg), "foo"), + errorClass = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map("functionName" -> toSQLId("foo")) + ) + } + + test("INTERNAL_ERROR: Enforce optional arguments after required arguments") { + val errorMessage = s"Function foo has an unexpected required argument for the provided" + + s" function signature ${illegalSignature}. All required arguments should come before" + + s" optional arguments." 
+ checkError( + exception = parseRearrangeException(illegalSignature, args, "foo"), + errorClass = "INTERNAL_ERROR", + parameters = Map("message" -> errorMessage) + ) + } +} diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out index faa05535cb322..e01e0ca5ee011 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out @@ -2,111 +2,368 @@ -- !query SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd') -- !query analysis +Project [mask(AbCD123-@$#, Q, q, d, o) AS mask(AbCD123-@$#, Q, q, d, o)#x] ++- OneRowRelation + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#') +-- !query analysis +Project [mask(AbCD123-@$#, Q, q, d, o) AS mask(AbCD123-@$#, Q, q, d, o)#x] ++- OneRowRelation + + +-- !query +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd') +-- !query analysis +Project [mask(AbCD123-@$#, Q, q, d, null) AS mask(AbCD123-@$#, Q, q, d, NULL)#x] ++- OneRowRelation + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#') +-- !query analysis +Project [mask(AbCD123-@$#, Q, q, d, null) AS mask(AbCD123-@$#, Q, q, d, NULL)#x] ++- OneRowRelation + + +-- !query +create temporary view t2 as select * from values + ('val2a', 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ('val1c', 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date 
'2016-05-04'), + ('val1b', null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ('val2e', 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1f', 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1b', 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ('val1c', 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ('val1e', 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ('val1f', 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ('val1b', null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) +-- !query analysis +CreateViewCommand `t2`, select * from values + ('val2a', 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ('val1c', 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ('val1b', null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ('val2e', 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1f', 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1b', 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), 
+ ('val1c', 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ('val1e', 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ('val1f', 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ('val1b', null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i), false, false, LocalTempView, true + +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + +- SubqueryAlias t2 + +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + + +-- !query +SELECT hex(count_min_sketch(t2d, seed => 1, epsilon => 0.5d, confidence => 0.5d)) FROM t2 +-- !query analysis +Aggregate [hex(count_min_sketch(t2d#xL, 0.5, 0.5, 1, 0, 0)) AS hex(count_min_sketch(t2d, 0.5, 0.5, 1))#x] ++- SubqueryAlias t2 + +- View (`t2`, [t2a#x,t2b#x,t2c#x,t2d#xL,t2e#x,t2f#x,t2g#x,t2h#x,t2i#x]) + +- Project [cast(t2a#x as string) AS t2a#x, cast(t2b#x as smallint) AS t2b#x, cast(t2c#x as int) AS t2c#x, cast(t2d#xL as bigint) AS t2d#xL, cast(t2e#x as float) AS t2e#x, cast(t2f#x as double) AS t2f#x, cast(t2g#x as double) AS t2g#x, cast(t2h#x as timestamp) AS t2h#x, cast(t2i#x as date) AS t2i#x] + +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + +- SubqueryAlias t2 + +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + + +-- !query +SELECT hex(count_min_sketch(seed => 1, epsilon => 0.5d, confidence => 0.5d, column => t2d)) FROM t2 +-- !query analysis +Aggregate [hex(count_min_sketch(t2d#xL, 0.5, 0.5, 1, 0, 0)) AS hex(count_min_sketch(t2d, 0.5, 0.5, 1))#x] ++- SubqueryAlias t2 + +- View (`t2`, [t2a#x,t2b#x,t2c#x,t2d#xL,t2e#x,t2f#x,t2g#x,t2h#x,t2i#x]) + +- Project [cast(t2a#x as string) AS t2a#x, cast(t2b#x as smallint) AS t2b#x, cast(t2c#x as int) AS t2c#x, cast(t2d#xL as bigint) AS t2d#xL, cast(t2e#x as 
float) AS t2e#x, cast(t2f#x as double) AS t2f#x, cast(t2g#x as double) AS t2g#x, cast(t2h#x as timestamp) AS t2h#x, cast(t2i#x as date) AS t2i#x] + +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + +- SubqueryAlias t2 + +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + + +-- !query +SELECT hex(count_min_sketch(t2d, 0.5d, seed => 1, confidence => 0.5d)) FROM t2 +-- !query analysis +Aggregate [hex(count_min_sketch(t2d#xL, 0.5, 0.5, 1, 0, 0)) AS hex(count_min_sketch(t2d, 0.5, 0.5, 1))#x] ++- SubqueryAlias t2 + +- View (`t2`, [t2a#x,t2b#x,t2c#x,t2d#xL,t2e#x,t2f#x,t2g#x,t2h#x,t2i#x]) + +- Project [cast(t2a#x as string) AS t2a#x, cast(t2b#x as smallint) AS t2b#x, cast(t2c#x as int) AS t2c#x, cast(t2d#xL as bigint) AS t2d#xL, cast(t2e#x as float) AS t2e#x, cast(t2f#x as double) AS t2f#x, cast(t2g#x as double) AS t2g#x, cast(t2h#x as timestamp) AS t2h#x, cast(t2i#x as date) AS t2i#x] + +- Project [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + +- SubqueryAlias t2 + +- LocalRelation [t2a#x, t2b#x, t2c#x, t2d#xL, t2e#x, t2f#x, t2g#x, t2h#x, t2i#x] + + +-- !query +SELECT * FROM explode(collection => array(1, 2)) +-- !query analysis +Project [col#x] ++- Generate explode(array(1, 2)), false, [col#x] + +- OneRowRelation + + +-- !query +SELECT * FROM explode_outer(collection => map('a', 1, 'b', 2)) +-- !query analysis +Project [key#x, value#x] ++- Generate explode(map(a, 1, b, 2)), true, [key#x, value#x] + +- OneRowRelation + + +-- !query +SELECT * FROM explode(array(1, 2)), explode(array(3, 4)) +-- !query analysis +Project [col#x, col#x] ++- Join Inner + :- Generate explode(array(1, 2)), false, [col#x] + : +- OneRowRelation + +- Generate explode(array(3, 4)), false, [col#x] + +- OneRowRelation + + +-- !query +SELECT * FROM explode(array(1, 2)) AS t, LATERAL explode(array(3 * t.col, 4 * t.col)) +-- !query analysis +Project [col#x, col#x] ++- LateralJoin lateral-subquery#x [col#x && col#x], Inner + : 
+- Generate explode(array((3 * outer(col#x)), (4 * outer(col#x)))), false, [col#x] + : +- OneRowRelation + +- SubqueryAlias t + +- Generate explode(array(1, 2)), false, [col#x] + +- OneRowRelation + + +-- !query +SELECT num, val, 'Spark' FROM explode(map(1, 'a', 2, 'b')) AS t(num, val) +-- !query analysis +Project [num#x, val#x, Spark AS Spark#x] ++- SubqueryAlias t + +- Project [key#x AS num#x, value#x AS val#x] + +- Generate explode(map(1, a, 2, b)), false, [key#x, value#x] + +- OneRowRelation + + +-- !query +SELECT * FROM explode(collection => explode(array(1))) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + "sqlState" : "0A000", + "messageParameters" : { + "expression" : "\"explode(explode(array(1)))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 54, + "fragment" : "explode(collection => explode(array(1)))" + } ] +} + + +-- !query +SELECT * FROM explode(collection => explode(collection => array(1))) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + "sqlState" : "0A000", + "messageParameters" : { + "expression" : "\"explode(explode(array(1)))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 68, + "fragment" : "explode(collection => explode(collection => array(1)))" + } ] +} + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW v AS SELECT id FROM range(0, 8) +-- !query analysis +CreateViewCommand `v`, SELECT id FROM range(0, 8), false, true, LocalTempView, true + +- Project [id#xL] + +- Range (0, 8, step=1, splits=None) + + +-- !query +SELECT * FROM explode(collection => TABLE(v)) +-- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", 
"messageParameters" : { - "inputExpr" : "\"namedargumentexpression(q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"", + "inputType" : "\"STRUCT\"", + "paramIndex" : "1", + "requiredType" : "(\"ARRAY\" or \"MAP\")", + "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 45, + "fragment" : "explode(collection => TABLE(v))" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNEXPECTED_POSITIONAL_ARGUMENT", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`mask`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, "stopIndex" : 98, - "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + "fragment" : "mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" } ] } -- !query -SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#') +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', digitChar => 'e') -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - "sqlState" : "42K09", + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + "sqlState" : "4274K", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(Q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(namedargumentexpression(q), 
namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#))\"" + "functionName" : "`mask`", + "parameterName" : "`digitChar`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 105, - "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#')" + "stopIndex" : 116, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', digitChar => 'e')" } ] } -- !query -SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd') +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbC') -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - "sqlState" : "42K09", + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + "sqlState" : "4274K", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), NULL)\"" + "functionName" : "`mask`", + "parameterName" : "`str`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 80, - "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd')" + "stopIndex" : 112, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbC')" } ] } -- !query -SELECT mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#') +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd') -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - "sqlState" : "42K09", 
+ "errorClass" : "REQUIRED_PARAMETER_NOT_FOUND", + "sqlState" : "4274K", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(Q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#), NULL)\"" + "functionName" : "`mask`", + "parameterName" : "`str`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 87, - "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#')" + "stopIndex" : 83, + "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" } ] } -- !query -SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', cellular => 'automata') -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.INPUT_SIZE_NOT_ONE", - "sqlState" : "42K09", + "errorClass" : "UNRECOGNIZED_PARAMETER_NAME", + "sqlState" : "4274K", "messageParameters" : { - "exprName" : "upperChar", - "sqlExpr" : "\"mask(namedargumentexpression(q), AbCD123-@$#, namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + "argumentName" : "`cellular`", + "functionName" : "`mask`", + "proposal" : "`str` `upperChar` `otherChar`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 98, - "fragment" : "mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + "stopIndex" : 122, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', cellular => 'automata')" + } ] +} + + +-- !query +SELECT encode(str => 'a', charset => 'utf-8') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + 
"errorClass" : "NAMED_PARAMETERS_NOT_SUPPORTED", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`encode`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 45, + "fragment" : "encode(str => 'a', charset => 'utf-8')" + } ] +} + + +-- !query +SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o', 'k') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "6", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2, 3, 4, 5]", + "functionName" : "`mask`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 51, + "fragment" : "mask('AbCD123-@$#', 'Q', 'q', 'd', 'o', 'k')" } ] } diff --git a/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql b/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql index aeb7b1e85cd8c..99f33d7815255 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/named-function-arguments.sql @@ -1,5 +1,60 @@ +-- Test for named arguments for Mask SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd'); SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#'); SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd'); SELECT mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#'); + +-- Test for named arguments for CountMinSketchAgg +create temporary view t2 as select * from values + ('val2a', 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 8S, 16, 119L, 
float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ('val1c', 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ('val1b', null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ('val2e', 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1f', 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1b', 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ('val1c', 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ('val1e', 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ('val1f', 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ('val1b', null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i); + +SELECT hex(count_min_sketch(t2d, seed => 1, epsilon => 0.5d, confidence => 0.5d)) FROM t2; +SELECT hex(count_min_sketch(seed => 1, epsilon => 0.5d, confidence => 0.5d, column => t2d)) FROM t2; +SELECT hex(count_min_sketch(t2d, 0.5d, seed => 1, confidence => 0.5d)) FROM t2; + +-- Test for tabled value functions explode and explode_outer +SELECT * FROM explode(collection => array(1, 2)); +SELECT * FROM explode_outer(collection => map('a', 1, 'b', 2)); +SELECT * FROM explode(array(1, 2)), explode(array(3, 4)); +SELECT * FROM explode(array(1, 2)) AS t, LATERAL explode(array(3 * t.col, 4 * t.col)); +SELECT num, val, 'Spark' FROM explode(map(1, 'a', 2, 'b')) AS t(num, val); + +-- Test for wrapped EXPLODE call to check error preservation +SELECT * FROM explode(collection => explode(array(1))); +SELECT * FROM 
explode(collection => explode(collection => array(1))); + +-- Test with TABLE parser rule +CREATE OR REPLACE TEMPORARY VIEW v AS SELECT id FROM range(0, 8); +SELECT * FROM explode(collection => TABLE(v)); + +-- Unexpected positional argument SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd'); + +-- Duplicate parameter assignment +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', digitChar => 'e'); +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbC'); + +-- Required parameter not found +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd'); + +-- Unrecognized parameter name +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', cellular => 'automata'); + +-- Named arguments not supported +SELECT encode(str => 'a', charset => 'utf-8'); + +-- Wrong number of arguments +SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o', 'k'); diff --git a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out index 842374542ec6e..3b223cc0e1529 100644 --- a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out @@ -2,121 +2,365 @@ -- !query SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd') -- !query schema +struct +-- !query output +QqQQdddoooo + + +-- !query +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#') +-- !query schema +struct +-- !query output +QqQQdddoooo + + +-- !query +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd') +-- !query schema +struct +-- !query output +QqQQddd-@$# + + +-- !query +SELECT 
mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#') +-- !query schema +struct +-- !query output +QqQQddd-@$# + + +-- !query +create temporary view t2 as select * from values + ('val2a', 6S, 12, 14L, float(15), 20D, 20E2, timestamp '2014-04-04 01:01:00.000', date '2014-04-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 8S, 16, 119L, float(17), 25D, 26E2, timestamp '2015-05-04 01:01:00.000', date '2015-05-04'), + ('val1c', 12S, 16, 219L, float(17), 25D, 26E2, timestamp '2016-05-04 01:01:00.000', date '2016-05-04'), + ('val1b', null, 16, 319L, float(17), 25D, 26E2, timestamp '2017-05-04 01:01:00.000', null), + ('val2e', 8S, null, 419L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1f', 19S, null, 519L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'), + ('val1b', 10S, 12, 19L, float(17), 25D, 26E2, timestamp '2014-06-04 01:01:00.000', date '2014-06-04'), + ('val1b', 8S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-07-04 01:01:00.000', date '2014-07-04'), + ('val1c', 12S, 16, 19L, float(17), 25D, 26E2, timestamp '2014-08-04 01:01:00.000', date '2014-08-05'), + ('val1e', 8S, null, 19L, float(17), 25D, 26E2, timestamp '2014-09-04 01:01:00.000', date '2014-09-04'), + ('val1f', 19S, null, 19L, float(17), 25D, 26E2, timestamp '2014-10-04 01:01:00.000', date '2014-10-04'), + ('val1b', null, 16, 19L, float(17), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', null) + as t2(t2a, t2b, t2c, t2d, t2e, t2f, t2g, t2h, t2i) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT hex(count_min_sketch(t2d, seed => 1, epsilon => 0.5d, confidence => 0.5d)) FROM t2 +-- !query schema +struct +-- !query output +00000001000000000000000D0000000100000004000000005D8D6AB90000000000000002000000000000000700000000000000010000000000000003 + + +-- !query +SELECT hex(count_min_sketch(seed => 1, epsilon => 0.5d, 
confidence => 0.5d, column => t2d)) FROM t2 +-- !query schema +struct +-- !query output +00000001000000000000000D0000000100000004000000005D8D6AB90000000000000002000000000000000700000000000000010000000000000003 + + +-- !query +SELECT hex(count_min_sketch(t2d, 0.5d, seed => 1, confidence => 0.5d)) FROM t2 +-- !query schema +struct +-- !query output +00000001000000000000000D0000000100000004000000005D8D6AB90000000000000002000000000000000700000000000000010000000000000003 + + +-- !query +SELECT * FROM explode(collection => array(1, 2)) +-- !query schema +struct +-- !query output +1 +2 + + +-- !query +SELECT * FROM explode_outer(collection => map('a', 1, 'b', 2)) +-- !query schema +struct +-- !query output +a 1 +b 2 + + +-- !query +SELECT * FROM explode(array(1, 2)), explode(array(3, 4)) +-- !query schema +struct +-- !query output +1 3 +1 4 +2 3 +2 4 + + +-- !query +SELECT * FROM explode(array(1, 2)) AS t, LATERAL explode(array(3 * t.col, 4 * t.col)) +-- !query schema +struct +-- !query output +1 3 +1 4 +2 6 +2 8 + + +-- !query +SELECT num, val, 'Spark' FROM explode(map(1, 'a', 2, 'b')) AS t(num, val) +-- !query schema +struct +-- !query output +1 a Spark +2 b Spark + + +-- !query +SELECT * FROM explode(collection => explode(array(1))) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + "sqlState" : "0A000", + "messageParameters" : { + "expression" : "\"explode(explode(array(1)))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 54, + "fragment" : "explode(collection => explode(array(1)))" + } ] +} + + +-- !query +SELECT * FROM explode(collection => explode(collection => array(1))) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + "sqlState" : "0A000", + "messageParameters" : { + "expression" : 
"\"explode(explode(array(1)))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 68, + "fragment" : "explode(collection => explode(collection => array(1)))" + } ] +} + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW v AS SELECT id FROM range(0, 8) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT * FROM explode(collection => TABLE(v)) +-- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"", + "inputType" : "\"STRUCT\"", + "paramIndex" : "1", + "requiredType" : "(\"ARRAY\" or \"MAP\")", + "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 15, + "stopIndex" : 45, + "fragment" : "explode(collection => TABLE(v))" + } ] +} + + +-- !query +SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNEXPECTED_POSITIONAL_ARGUMENT", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`mask`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, "stopIndex" : 98, - "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + "fragment" : "mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" } ] } -- !query -SELECT 
mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#') +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', digitChar => 'e') -- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - "sqlState" : "42K09", + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + "sqlState" : "4274K", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(Q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#))\"" + "functionName" : "`mask`", + "parameterName" : "`digitChar`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 105, - "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbCD123-@$#')" + "stopIndex" : 116, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', digitChar => 'e')" } ] } -- !query -SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd') +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbC') -- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - "sqlState" : "42K09", + "errorClass" : "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + "sqlState" : "4274K", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(AbCD123-@$#, namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), NULL)\"" + "functionName" : 
"`mask`", + "parameterName" : "`str`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 80, - "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', digitChar => 'd')" + "stopIndex" : 112, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', str => 'AbC')" } ] } -- !query -SELECT mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#') +SELECT mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd') -- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", - "sqlState" : "42K09", + "errorClass" : "REQUIRED_PARAMETER_NOT_FOUND", + "sqlState" : "4274K", "messageParameters" : { - "inputExpr" : "\"namedargumentexpression(Q)\"", - "inputName" : "upperChar", - "inputType" : "\"STRING\"", - "sqlExpr" : "\"mask(namedargumentexpression(q), namedargumentexpression(Q), namedargumentexpression(d), namedargumentexpression(AbCD123-@$#), NULL)\"" + "functionName" : "`mask`", + "parameterName" : "`str`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 87, - "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', digitChar => 'd', str => 'AbCD123-@$#')" + "stopIndex" : 83, + "fragment" : "mask(lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" } ] } -- !query -SELECT mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd') +SELECT mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', cellular => 'automata') -- !query schema struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "DATATYPE_MISMATCH.INPUT_SIZE_NOT_ONE", - "sqlState" : "42K09", + "errorClass" : "UNRECOGNIZED_PARAMETER_NAME", + "sqlState" : "4274K", "messageParameters" : { - "exprName" : "upperChar", - "sqlExpr" : 
"\"mask(namedargumentexpression(q), AbCD123-@$#, namedargumentexpression(Q), namedargumentexpression(o), namedargumentexpression(d))\"" + "argumentName" : "`cellular`", + "functionName" : "`mask`", + "proposal" : "`str` `upperChar` `otherChar`" }, "queryContext" : [ { "objectType" : "", "objectName" : "", "startIndex" : 8, - "stopIndex" : 98, - "fragment" : "mask(lowerChar => 'q', 'AbCD123-@$#', upperChar => 'Q', otherChar => 'o', digitChar => 'd')" + "stopIndex" : 122, + "fragment" : "mask('AbCD123-@$#', lowerChar => 'q', upperChar => 'Q', otherChar => 'o', digitChar => 'd', cellular => 'automata')" + } ] +} + + +-- !query +SELECT encode(str => 'a', charset => 'utf-8') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NAMED_PARAMETERS_NOT_SUPPORTED", + "sqlState" : "4274K", + "messageParameters" : { + "functionName" : "`encode`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 45, + "fragment" : "encode(str => 'a', charset => 'utf-8')" + } ] +} + + +-- !query +SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o', 'k') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "6", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2, 3, 4, 5]", + "functionName" : "`mask`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 51, + "fragment" : "mask('AbCD123-@$#', 'Q', 'q', 'd', 'o', 'k')" } ] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index 2731760f7ef05..7ebb677b12158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -32,16 +32,11 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL intercept[ParseException](sql(sqlText).collect()) } - test("NAMED_ARGUMENTS_SUPPORT_DISABLED: named arguments not turned on") { + test("NAMED_PARAMETER_SUPPORT_DISABLED: named arguments not turned on") { withSQLConf("spark.sql.allowNamedFunctionArguments" -> "false") { - checkError( - exception = parseException("SELECT * FROM encode(value => 'abc', charset => 'utf-8')"), - errorClass = "NAMED_ARGUMENTS_SUPPORT_DISABLED", - parameters = Map("functionName" -> toSQLId("encode"), "argument" -> toSQLId("value")) - ) checkError( exception = parseException("SELECT explode(arr => array(10, 20))"), - errorClass = "NAMED_ARGUMENTS_SUPPORT_DISABLED", + errorClass = "NAMED_PARAMETER_SUPPORT_DISABLED", parameters = Map("functionName"-> toSQLId("explode"), "argument" -> toSQLId("arr")) ) } From 28eaa9c0861984d925098172ad08fa5d255ef8bc Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Tue, 18 Jul 2023 11:03:19 +0800 Subject: [PATCH 005/986] [SPARK-44154][SQL] Added more unit tests to BitmapExpressionUtilsSuite and made minor improvements to Bitmap Aggregate Expressions ### What changes were proposed in this pull request? I firstly added more unit tests for the `BITMAP_BIT_POSITION` and `BITMAP_BUCKET_NUMBER` expressions. Secondly, I made a minor improvement in the implementation of the `BITMAP_CONSTRUCT_AGG` and `BITMAP_OR_AGG` expressions, where I converted `inputAggBufferAttributes` from a method to a value. ### Why are the changes needed? The unit tests cover more corner cases. Having `inputAggBufferAttributes` as a value makes it so that the AttributeReferences aren't reinitialized every time `inputAggBufferAttributes` is referred to. ### How was this patch tested? I reran all the tests for Bitmap expressions and they succeeded.
The test suites were `BitmapExpressionUtilsSuite` and `BitmapExpressionsQuerySuite`. Closes #42043 from harshmotw-db/harsh-dev. Authored-by: Harsh Motwani Signed-off-by: Wenchen Fan (cherry picked from commit df550cfc3ad820448ed5040df94c2b32fb161290) Signed-off-by: Wenchen Fan --- .../expressions/bitmapExpressions.scala | 16 ++++++++-------- .../BitmapExpressionUtilsSuite.scala | 17 +++++++++++------ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala index 2c89428f24330..bd8f4efa059bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala @@ -177,17 +177,17 @@ case class BitmapConstructAgg(child: Expression, override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + // The aggregation buffer is a fixed size binary. + private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)() + override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil override def defaultResult: Option[Literal] = Option(Literal(Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0))) - override def inputAggBufferAttributes: Seq[AttributeReference] = + override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - // The aggregation buffer is a fixed size binary. 
- private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)() - override def initialize(buffer: InternalRow): Unit = { buffer.update(mutableAggBufferOffset, Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0)) } @@ -270,17 +270,17 @@ case class BitmapOrAgg(child: Expression, override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + // The aggregation buffer is a fixed size binary. + private val bitmapAttr = AttributeReference("bitmap", BinaryType, false)() + override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil override def defaultResult: Option[Literal] = Option(Literal(Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0))) - override def inputAggBufferAttributes: Seq[AttributeReference] = + override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - // The aggregation buffer is a fixed size binary. - private val bitmapAttr = AttributeReference("bitmap", BinaryType, false)() - override def initialize(buffer: InternalRow): Unit = { buffer.update(mutableAggBufferOffset, Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0)) } diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala index ee1f4026fedb8..53935c66c6136 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtilsSuite.scala @@ -22,23 +22,27 @@ import org.apache.spark.SparkFunSuite class BitmapExpressionUtilsSuite extends SparkFunSuite { test("bitmap_bucket_number with positive inputs") { - Seq((0L, 0L), (1L, 1L), (2L, 1L), (3L, 1L), - (32768L, 1L), (32769L, 2L), (32770L, 2L)).foreach { + Seq((0L, 0L), (1L, 1L), (2L, 1L), (3L, 1L), (65537L, 3L), (65536L, 2L), 
(3232423L, 99L), + (4538345L, 139L), (845894934L, 25815L), (2147483647L, 65536L), + (Long.MaxValue, 281474976710656L), (32768L, 1L), (32769L, 2L), (32770L, 2L)).foreach { case (input, expected) => assert(BitmapExpressionUtils.bitmapBucketNumber(input) == expected) } } test("bitmap_bucket_number with negative inputs") { - Seq((-1L, 0L), (-2L, 0L), (-3L, 0L), - (-32767L, 0L), (-32768L, -1L), (-32769L, -1L)).foreach { + Seq((-1L, 0L), (-2L, 0L), (-3L, 0L), (-65536L, -2L), (65537L, 3L), (-65535L, -1L), + (-3843485L, -117L), (-2147483647L, -65535L), (-2147483648L, -65536L), + (Long.MinValue, -281474976710656L), (Long.MinValue + 1, -281474976710655L), (-32767L, 0L), + (-32768L, -1L), (-32769L, -1L)).foreach { case (input, expected) => assert(BitmapExpressionUtils.bitmapBucketNumber(input) == expected) } } test("bitmap_bit_position with positive inputs") { - Seq((0L, 0L), (1L, 0L), (2L, 1L), (3L, 2L), + Seq((0L, 0L), (1L, 0L), (2L, 1L), (3L, 2L), (65537L, 0L), (65536L, 32767L), (3232423L, 21158L), + (4538345L, 16360L), (845894934L, 21781L), (2147483647L, 32766L), (Long.MaxValue, 32766L), (32768L, 32767L), (32769L, 0L), (32770L, 1L)).foreach { case (input, expected) => assert(BitmapExpressionUtils.bitmapBitPosition(input) == expected) @@ -46,7 +50,8 @@ class BitmapExpressionUtilsSuite extends SparkFunSuite { } test("bitmap_bit_position with negative inputs") { - Seq((-1L, 1L), (-2L, 2L), (-3L, 3L), + Seq((-1L, 1L), (-2L, 2L), (-3L, 3L), (-65536L, 0L), (-65535L, 32767L), (-3843485L, 9629L), + (-2147483647L, 32767L), (-2147483648L, 0L), (Long.MinValue, 0L), (Long.MinValue + 1, 32767L), (-32767L, 32767L), (-32768L, 0L), (-32769L, 1L)).foreach { case (input, expected) => assert(BitmapExpressionUtils.bitmapBitPosition(input) == expected) From bda0545e785b2c15db57b07b2b72b3a164d8a929 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 18 Jul 2023 14:50:21 +0900 Subject: [PATCH 006/986] [SPARK-44348][CONNECT][FOLLOW-UP] Avoid double slashes in the URI ### What changes were 
proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/41942. It applies `substring(1)` inside `encodeRelativeUnixPathToURIRawPath` to remove the leading slash, so that the returned raw path is relative. Otherwise, the resulting URI can end up with double slashes in the middle. ### Why are the changes needed? To avoid having unnecessary double slashes ... and save one byte :-) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually tested. It's really trivial. Closes #42051 from HyukjinKwon/minor-change. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- core/src/main/scala/org/apache/spark/util/Utils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index eafb444366209..048c6ebfa7489 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -442,7 +442,7 @@ private[spark] object Utils extends Logging with SparkClassUtils { // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. // We should remove it after we get the raw path. - encodeRelativeUnixPathToURIRawPath(fileName).substring(1) + encodeRelativeUnixPathToURIRawPath(fileName) } /** @@ -453,7 +453,7 @@ private[spark] object Utils extends Logging with SparkClassUtils { // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. // We should remove it after we get the raw path.
- new URI("file", null, "localhost", -1, "/" + path, null, null).getRawPath + new URI("file", null, "localhost", -1, "/" + path, null, null).getRawPath.substring(1) } /** From c9bfcb9448b51985ad9a5361fbaaf828ad670cdc Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 18 Jul 2023 14:59:09 +0900 Subject: [PATCH 007/986] [SPARK-43967][PYTHON] Support regular Python UDTFs with empty return values ### What changes were proposed in this pull request? This PR adds support for regular (non-arrow-optimized) Python UDTFs that return empty results, for example: ``` def eval(self): ... ``` or ``` def eval(self): yield ``` This feature is already available in arrow-optimized UDTFs. ### Why are the changes needed? To align the behavior of regular Python UDTFs with arrow-optimized UDTFs. ### Does this PR introduce _any_ user-facing change? Yes. After this PR, users can run regular Python UDTFs with empty return statement. ### How was this patch tested? Existing UTs. Closes #42044 from allisonwang-db/spark-43967-empty-return. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_udtf.py | 35 ++----------------- python/pyspark/worker.py | 20 +++++++++-- .../execution/python/EvalPythonUDTFExec.scala | 11 +++++- 3 files changed, 30 insertions(+), 36 deletions(-) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index f109302dec512..ec3379accca60 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -164,17 +164,14 @@ class TestUDTF: def eval(self, a: int): ... 
- # TODO(SPARK-43967): Support Python UDTFs with empty return values - with self.assertRaisesRegex(PythonException, "TypeError"): - TestUDTF(lit(1)).collect() + self.assertEqual(TestUDTF(lit(1)).collect(), []) @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): return - with self.assertRaisesRegex(PythonException, "TypeError"): - TestUDTF(lit(1)).collect() + self.assertEqual(TestUDTF(lit(1)).collect(), []) def test_udtf_with_conditional_return(self): class TestUDTF: @@ -195,9 +192,7 @@ class TestUDTF: def eval(self, a: int): yield - # TODO(SPARK-43967): Support Python UDTFs with empty return values - with self.assertRaisesRegex(Py4JJavaError, "java.lang.NullPointerException"): - TestUDTF(lit(1)).collect() + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=None)]) def test_udtf_with_none_output(self): @udtf(returnType="a: int") @@ -807,21 +802,6 @@ def eval(self, a: int): func = udtf(TestUDTF, returnType="a: int") self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) - def test_udtf_eval_with_no_return(self): - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - ... - - self.assertEqual(TestUDTF(lit(1)).collect(), []) - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - return - - self.assertEqual(TestUDTF(lit(1)).collect(), []) - def test_udtf_terminate_with_wrong_num_output(self): # The error message for arrow-optimized UDTF is different from regular UDTF. err_msg = "The number of columns in the result does not match the specified schema." @@ -848,15 +828,6 @@ def terminate(self): with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).show() - def test_udtf_with_empty_yield(self): - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield - - # Arrow-optimized UDTF can support this. - self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=None)]) - def test_udtf_with_wrong_num_output(self): # The error message for arrow-optimized UDTF is different. 
err_msg = "The number of columns in the result does not match the specified schema." diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8c12312da27bc..2445b46970cdd 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -579,7 +579,21 @@ def mapper(_, it): def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal - return lambda *a: map(toInternal, f(*a)) + + # Evaluate the function and return a tuple back to the executor. + def evaluate(*a) -> tuple: + res = f(*a) + if res is None: + # If the function returns None or does not have an explicit return statement, + # an empty tuple is returned to the executor. + # This is because directly constructing tuple(None) results in an exception. + return tuple() + else: + # If the function returns a result, we map it to the internal representation and + # returns the results as a tuple. + return tuple(map(toInternal, res)) + + return evaluate eval = wrap_udtf(getattr(udtf, "eval"), return_type) @@ -592,11 +606,11 @@ def wrap_udtf(f, return_type): def mapper(_, it): try: for a in it: - yield tuple(eval(*[a[o] for o in arg_offsets])) + yield eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: try: - yield tuple(terminate()) + yield terminate() except BaseException as e: raise PySparkRuntimeError( error_class="UDTF_EXEC_ERROR", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala index 827b2fc2bb395..fab417a0f86fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala @@ -103,6 +103,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode { } val joined = new JoinedRow + val nullRow = new GenericInternalRow(udtf.elementSchema.length) val resultProj = 
UnsafeProjection.create(output, output) outputRowIterator.flatMap { outputRows => @@ -118,7 +119,15 @@ trait EvalPythonUDTFExec extends UnaryExecNode { // from the UDTF are from the `terminate()` call. We leave the left side as the last // element of its child output to keep it consistent with the Generate implementation // and Hive UDTFs. - outputRows.map(r => resultProj(joined.withRight(r))) + outputRows.map { r => + // When the UDTF's result is None, such as `def eval(): yield`, + // we join it with a null row to avoid NullPointerException. + if (r == null) { + resultProj(joined.withRight(nullRow)) + } else { + resultProj(joined.withRight(r)) + } + } } } } From cd312e70a775013f844a55b965aa18eef4185442 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 18 Jul 2023 17:14:44 +0900 Subject: [PATCH 008/986] [SPARK-44471][INFRA][BRANCH-3.5] Add Github action test job for branch-3.5 ### What changes were proposed in this pull request? Add Github action test job for branch-3.5 ### Why are the changes needed? Daily test for branch-3.5 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Daily build for branch 3.5 Closes #42057 from xuanyuanking/SPARK-44471-3.5. Lead-authored-by: Yuanjian Li Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .github/workflows/build_and_test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 98da34f7cdef1..f5a109d95de5f 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -30,8 +30,7 @@ on: description: Branch to run the build against required: false type: string - # Change 'master' to 'branch-3.5' in branch-3.5 branch after cutting it. - default: master + default: branch-3.5 hadoop: description: Hadoop version to run with. HADOOP_PROFILE environment variable should accept it. 
required: false From 80dee3c955943f26d79c75fd68456a6ba44ad6f1 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 18 Jul 2023 17:31:51 +0900 Subject: [PATCH 009/986] [SPARK-43923][CONNECT][FOLLOW-UP][TESTS] Skip "Test observe response" at SparkConnectServiceSuite ### What changes were proposed in this pull request? This PR skips "Test observe response" at SparkConnectServiceSuite for now. ### Why are the changes needed? To unblock other PRs. ### Does this PR introduce _any_ user-facing change? No, dev-only. ### How was this patch tested? Skipped unittests Closes #42059 from HyukjinKwon/SPARK-43923-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon (cherry picked from commit 32b7b58114d8dc828f20990921f0cd6067157d55) Signed-off-by: Hyukjin Kwon --- .../spark/sql/connect/planner/SparkConnectServiceSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 498084efb8f3f..cfa37b86cd41a 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -564,7 +564,8 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } - test("Test observe response") { + // TODO(SPARK-44474): Reenable Test observe response at SparkConnectServiceSuite + ignore("Test observe response") { // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("test") { From 799f70477fbff1ba26481c87bb7f7c6cfe55d026 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 18 Jul 2023 17:39:56 +0800 Subject: [PATCH 010/986] [SPARK-43203][SQL][FOLLOWUP] V2SessionCatalog.dropTable should handle null table ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/41348 .
Previously `V2SessionCatalog.dropTable` treated null table as table not exists, but #41348 broke it. This PR fixes it. ### Why are the changes needed? to keep old behavior. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing tests Closes #42056 from cloud-fan/mm. Authored-by: Wenchen Fan Signed-off-by: Kent Yao (cherry picked from commit 704131b13f88004a587d5959094147fc9f4c4b73) Signed-off-by: Kent Yao --- .../datasources/v2/V2SessionCatalog.scala | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index e5496a4676015..f311ccbb6309d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -115,10 +115,6 @@ class V2SessionCatalog(catalog: SessionCatalog) createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties) } - override def purgeTable(ident: Identifier): Boolean = { - dropTableInternal(ident, purge = true) - } - // TODO: remove it when no tests calling this deprecated method. 
override def createTable( ident: Identifier, @@ -202,6 +198,10 @@ class V2SessionCatalog(catalog: SessionCatalog) loadTable(ident) } + override def purgeTable(ident: Identifier): Boolean = { + dropTableInternal(ident, purge = true) + } + override def dropTable(ident: Identifier): Boolean = { dropTableInternal(ident) } @@ -218,14 +218,16 @@ class V2SessionCatalog(catalog: SessionCatalog) foundType = v1Table.tableType.name, alternative = "DROP VIEW" ) + case null => + false case _ => + catalog.invalidateCachedTable(ident.asTableIdentifier) + catalog.dropTable( + ident.asTableIdentifier, + ignoreIfNotExists = true, + purge = purge) + true } - catalog.invalidateCachedTable(ident.asTableIdentifier) - catalog.dropTable( - ident.asTableIdentifier, - ignoreIfNotExists = true, - purge = purge) - true } catch { case _: NoSuchTableException => false From e718d0d4dc57f9e0ecdb7067ee9778250200fe83 Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Wed, 19 Jul 2023 07:39:11 +0900 Subject: [PATCH 011/986] [SPARK-44464][SS] Implement applyInPandasWithState in PySpark ### What changes were proposed in this pull request? Change the serialization format for group-by-with-state outputs: include an explicit hidden column indicating how many data and state records there are. ### Why are the changes needed? The current implementation of ApplyInPandasWithStatePythonRunner cannot deal with outputs where the first column of the row is null, as it cannot distinguish between the case where the column is genuinely null and the case where the row was filled with empty data because there are fewer data records than state records. It causes incorrect results for the former case. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Add unit tests that cover null cases and various other scenarios. Closes #42046 from siying/pypanda.
Authored-by: Siying Dong Signed-off-by: Jungtaek Lim (cherry picked from commit 6cebd97f4670e1e998b0064ddf4db11050fe52dd) Signed-off-by: Jungtaek Lim --- python/pyspark/sql/pandas/serializers.py | 43 ++++++- ...st_parity_pandas_grouped_map_with_state.py | 20 +++ .../test_pandas_grouped_map_with_state.py | 114 ++++++++++++++++-- .../ApplyInPandasWithStatePythonRunner.scala | 84 +++++++------ 4 files changed, 208 insertions(+), 53 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f835ea57b7751..f22a73cbbef3b 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -27,7 +27,15 @@ _create_converter_from_pandas, _create_converter_to_pandas, ) -from pyspark.sql.types import DataType, StringType, StructType, BinaryType, StructField, LongType +from pyspark.sql.types import ( + DataType, + StringType, + StructType, + BinaryType, + StructField, + LongType, + IntegerType, +) class SpecialLengths: @@ -603,6 +611,15 @@ def __init__( self.utf8_deserializer = UTF8Deserializer() self.state_object_schema = state_object_schema + self.result_count_df_type = StructType( + [ + StructField("dataCount", IntegerType()), + StructField("stateCount", IntegerType()), + ] + ) + + self.result_count_pdf_arrow_type = to_arrow_type(self.result_count_df_type) + self.result_state_df_type = StructType( [ StructField("properties", StringType()), @@ -799,16 +816,26 @@ def construct_state_pdf(state): def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): """ Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each - one matches to the single struct field for Arrow schema, hence the return value of - Arrow RecordBatch will have schema with two fields, in `data`, `state` order. + one matches to the single struct field for Arrow schema. 
We also need an extra one to + indicate array length for data and state, so the return value of Arrow RecordBatch will + have schema with three fields, in `count`, `data`, `state` order. (Readers are expected to access the field via position rather than the name. We do not guarantee the name of the field.) Note that Arrow RecordBatch requires all columns to have all same number of rows, - hence this function inserts empty data for state/data with less elements to compensate. + hence this function inserts empty data for count/state/data with less elements to + compensate. """ - max_data_cnt = max(pdf_data_cnt, state_data_cnt) + max_data_cnt = max(1, max(pdf_data_cnt, state_data_cnt)) + + # We only use the first row in the count column, and fill other rows to be the same + # value, hoping it is more friendly for compression, in case it is needed. + count_dict = { + "dataCount": [pdf_data_cnt] * max_data_cnt, + "stateCount": [state_data_cnt] * max_data_cnt, + } + count_pdf = pd.DataFrame.from_dict(count_dict) empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt empty_row_cnt_in_state = max_data_cnt - state_data_cnt @@ -829,7 +856,11 @@ def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_dat merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) return self._create_batch( - [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)] + [ + (count_pdf, self.result_count_pdf_arrow_type), + (merged_pdf, pdf_schema), + (merged_state_pdf, self.result_state_pdf_arrow_type), + ] ) def serialize_batches(): diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py index 3a38cd17406ce..dc3bdf28f81c8 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py @@ -29,6 +29,26 @@ class 
GroupedApplyInPandasWithStateTests( def test_apply_in_pandas_with_state_basic(self): super().test_apply_in_pandas_with_state_basic() + @unittest.skip("foreachBatch will be supported in SPARK-42944.") + def test_apply_in_pandas_with_state_basic_no_state(self): + super().test_apply_in_pandas_with_state_basic() + + @unittest.skip("foreachBatch will be supported in SPARK-42944.") + def test_apply_in_pandas_with_state_basic_no_state_no_data(self): + super().test_apply_in_pandas_with_state_basic() + + @unittest.skip("foreachBatch will be supported in SPARK-42944.") + def test_apply_in_pandas_with_state_basic_more_data(self): + super().test_apply_in_pandas_with_state_basic() + + @unittest.skip("foreachBatch will be supported in SPARK-42944.") + def test_apply_in_pandas_with_state_basic_fewer_data(self): + super().test_apply_in_pandas_with_state_basic() + + @unittest.skip("foreachBatch will be supported in SPARK-42944.") + def test_apply_in_pandas_with_state_basic_with_null(self): + super().test_apply_in_pandas_with_state_basic() + if __name__ == "__main__": from pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state import * # noqa: F401 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py index a2a6544faa0fd..e1ec97928f72b 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py @@ -60,7 +60,7 @@ def conf(cls): cfg.set("spark.sql.shuffle.partitions", "5") return cfg - def test_apply_in_pandas_with_state_basic(self): + def _test_apply_in_pandas_with_state_basic(self, func, check_results): input_path = tempfile.mkdtemp() def prepare_test_resource(): @@ -81,6 +81,22 @@ def prepare_test_resource(): ) state_type = StructType([StructField("c", LongType())]) + q = ( + df.groupBy(df["value"]) + .applyInPandasWithState( + func, output_type, state_type, "Update", 
GroupStateTimeout.NoTimeout + ) + .writeStream.queryName("this_query") + .foreachBatch(check_results) + .outputMode("update") + .start() + ) + + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + + def test_apply_in_pandas_with_state_basic(self): def func(key, pdf_iter, state): assert isinstance(state, GroupState) @@ -98,20 +114,92 @@ def check_results(batch_df, _): {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")}, ) - q = ( - df.groupBy(df["value"]) - .applyInPandasWithState( - func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout + self._test_apply_in_pandas_with_state_basic(func, check_results) + + def test_apply_in_pandas_with_state_basic_no_state(self): + def func(key, pdf_iter, state): + assert isinstance(state, GroupState) + # 2 data rows + yield pd.DataFrame({"key": [key[0], "foo"], "countAsString": ["100", "222"]}) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.sort("key").collect()), + { + Row(key="hello", countAsString="100"), + Row(key="this", countAsString="100"), + Row(key="foo", countAsString="222"), + }, ) - .writeStream.queryName("this_query") - .foreachBatch(check_results) - .outputMode("update") - .start() - ) - self.assertEqual(q.name, "this_query") - self.assertTrue(q.isActive) - q.processAllAvailable() + self._test_apply_in_pandas_with_state_basic(func, check_results) + + def test_apply_in_pandas_with_state_basic_no_state_no_data(self): + def func(key, pdf_iter, state): + assert isinstance(state, GroupState) + # 2 data rows + yield pd.DataFrame({"key": [], "countAsString": []}) + + def check_results(batch_df, _): + self.assertTrue(len(set(batch_df.sort("key").collect())) == 0) + + self._test_apply_in_pandas_with_state_basic(func, check_results) + + def test_apply_in_pandas_with_state_basic_more_data(self): + # Test data rows returned are more or fewer than state. 
+ def func(key, pdf_iter, state): + state.update((1,)) + assert isinstance(state, GroupState) + # 3 rows + yield pd.DataFrame( + {"key": [key[0], "foo", key[0] + "_2"], "countAsString": ["1", "666", "2"]} + ) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.sort("key").collect()), + { + Row(key="hello", countAsString="1"), + Row(key="foo", countAsString="666"), + Row(key="hello_2", countAsString="2"), + Row(key="this", countAsString="1"), + Row(key="this_2", countAsString="2"), + }, + ) + + self._test_apply_in_pandas_with_state_basic(func, check_results) + + def test_apply_in_pandas_with_state_basic_fewer_data(self): + # Test data rows returned are more or fewer than state. + def func(key, pdf_iter, state): + state.update((1,)) + assert isinstance(state, GroupState) + yield pd.DataFrame({"key": [], "countAsString": []}) + + def check_results(batch_df, _): + self.assertTrue(len(set(batch_df.sort("key").collect())) == 0) + + self._test_apply_in_pandas_with_state_basic(func, check_results) + + def test_apply_in_pandas_with_state_basic_with_null(self): + def func(key, pdf_iter, state): + assert isinstance(state, GroupState) + + total_len = 0 + for pdf in pdf_iter: + total_len += len(pdf) + + state.update((total_len,)) + assert state.get[0] == 1 + yield pd.DataFrame({"key": [None], "countAsString": [str(total_len)]}) + + def check_results(batch_df, _): + self.assertEqual( + set(batch_df.sort("key").collect()), + {Row(key=None, countAsString="1")}, + ) + + self._test_apply_in_pandas_with_state_basic(func, check_results) def test_apply_in_pandas_with_state_python_worker_random_failure(self): input_path = tempfile.mkdtemp() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 9fc6ae04e94c7..d4c535fe76a3e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} +import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER, InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.internal.SQLConf @@ -155,7 +155,24 @@ class ApplyInPandasWithStatePythonRunner( // data and state metadata have same number of rows, which is required by Arrow record // batch. assert(batch.numRows() > 0) - assert(schema.length == 2) + assert(schema.length == 3) + + def getValueFromCountColumn(batch: ColumnarBatch): (Int, Int) = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val dataType = schema(0).dataType.asInstanceOf[StructType] + assert( + DataTypeUtils.sameType(dataType, COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER), + s"Schema equality check failure! type from Arrow: $dataType, " + + s"expected type: $COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER" + ) + + // NOTE: See ApplyInPandasWithStatePythonRunner.COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER + // for the schema. 
+ val dataCount = structVector.getChild(0).getInt(0) + val stateCount = structVector.getChild(1).getInt(0) + (dataCount, stateCount) + } def getColumnarBatchForStructTypeColumn( batch: ColumnarBatch, @@ -174,51 +191,43 @@ class ApplyInPandasWithStatePythonRunner( flattenedBatch } - def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { - val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, outputSchema) - dataBatch.rowIterator.asScala.flatMap { row => - if (row.isNullAt(0)) { - // The entire row in record batch seems to be for state metadata. - None - } else { - Some(row) - } + + def constructIterForData(batch: ColumnarBatch, numRows: Int): Iterator[InternalRow] = { + val dataBatch = getColumnarBatchForStructTypeColumn(batch, 1, outputSchema) + dataBatch.rowIterator.asScala.take(numRows).flatMap { row => + Some(row) } } - def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = { - val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1, + def constructIterForState(batch: ColumnarBatch, numRows: Int): Iterator[OutTypeForState] = { + val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 2, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER) - stateMetadataBatch.rowIterator().asScala.flatMap { row => + stateMetadataBatch.rowIterator().asScala.take(numRows).flatMap { row => implicit val formats = org.json4s.DefaultFormats - if (row.isNullAt(0)) { - // The entire row in record batch seems to be for data. + // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER + // for the schema. 
+ val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { None } else { - // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER - // for the schema. - val propertiesAsJson = parse(row.getUTF8String(0).toString) - val keyRowAsUnsafeAsBinary = row.getBinary(1) - val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) - keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) - val maybeObjectRow = if (row.isNullAt(2)) { - None - } else { - val pickledStateValue = row.getBinary(2) - Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, - stateRowDeserializer)) - } - val oldTimeoutTimestamp = row.getLong(3) - - Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), - oldTimeoutTimestamp)) + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) } + val oldTimeoutTimestamp = row.getLong(3) + + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) } } - (constructIterForState(batch), constructIterForData(batch)) + val (dataCount, stateCount) = getValueFromCountColumn(batch) + (constructIterForState(batch, stateCount), constructIterForData(batch, dataCount)) } } @@ -235,4 +244,11 @@ object ApplyInPandasWithStatePythonRunner { StructField("oldTimeoutTimestamp", LongType) ) ) + val COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType( + Array( + StructField("dataCount", IntegerType), + StructField("stateCount", IntegerType) + ) + ) + } From 5bfaa71d7bc63a19c73bc0208eccf1d68dbf6ac7 Mon Sep 17 00:00:00 2001 From: Raghu Angadi Date: Wed, 19 Jul 2023 09:04:05 +0900 Subject: [PATCH 012/986] 
[SPARK-42944][SS][PYTHON] Streaming ForeachBatch in Python Adds `foreachBatch()` in Python. This adds a new runner `StreamingPythonRunner`. Note that this PR focuses on core functionality and includes TODO for followup improvements (will update with jira tickets where missing). Included more inline comments to help with the review. ### What changes were proposed in this pull request? Adds support for foreachBatch() in Spark connect. ### Why are the changes needed? - Manual tests - Unit tests: - The tests are updated to use a global temp view, rather than shared variable since connect version of the function runs on the server side. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Closes #42035 from rangadi/feb-py. Authored-by: Raghu Angadi Signed-off-by: Hyukjin Kwon (cherry picked from commit d93f6d145142ec15a96b0a3bfbbaa044c4b725e9) Signed-off-by: Hyukjin Kwon --- .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../connect/planner/SparkConnectPlanner.scala | 5 +- .../planner/StreamingForeachBatchHelper.scala | 70 +++++++++++++-- .../spark/api/python/PythonRunner.scala | 1 + .../api/python/PythonWorkerFactory.scala | 9 +- .../api/python/StreamingPythonRunner.scala | 88 +++++++++++++++++++ python/pyspark/sql/connect/session.py | 9 ++ .../sql/connect/streaming/readwriter.py | 12 +-- .../test_parity_streaming_foreachBatch.py | 44 ++++++++++ .../streaming/test_streaming_foreachBatch.py | 20 +++-- python/pyspark/streaming_worker.py | 78 ++++++++++++++++ 11 files changed, 311 insertions(+), 27 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala create mode 100644 python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py create mode 100644 python/pyspark/streaming_worker.py diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1287176d76e88..91d744b9e4863 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -335,7 +335,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { .start() eventually(timeout(30.seconds)) { // Wait for first progress. - assert(q.lastProgress != null) + assert(q.lastProgress != null, "Failed to make progress") assert(q.lastProgress.numInputRows > 0) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 39cb4c1b972b1..92a9524f67aaf 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2792,7 +2792,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { if (writeOp.hasForeachBatch) { val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match { case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION => - throw InvalidPlanInput("Python ForeachBatch is not supported yet. 
WIP.") + val pythonFn = transformPythonFunction(writeOp.getForeachBatch.getPythonFunction) + StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder) case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION => val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType]( @@ -2801,7 +2802,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder) case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET => - throw InvalidPlanInput("Unexpected") + throw InvalidPlanInput("Unexpected") // Unreachable } writer.foreachBatch(foreachBatchFn) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 66487e7048c08..3148139377767 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -18,9 +18,13 @@ package org.apache.spark.sql.connect.planner import java.util.UUID +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.SimplePythonFunction +import org.apache.spark.api.python.StreamingPythonRunner import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.service.SparkConnectService /** * A helper class for handling ForeachBatch related functionality in Spark Connect servers @@ -29,23 +33,25 @@ object StreamingForeachBatchHelper extends Logging { type ForeachBatchFnType = (DataFrame, Long) => Unit + private case class FnArgsWithId(dfId: String, df: DataFrame, batchId: Long) + /** * Return a new ForeachBatch function that wraps 
`fn`. It sets up DataFrame cache so that the * user function can access it. The cache is cleared once ForeachBatch returns. */ - def dataFrameCachingWrapper( - fn: ForeachBatchFnType, + private def dataFrameCachingWrapper( + fn: FnArgsWithId => Unit, sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) => { val dfId = UUID.randomUUID().toString log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to the log. - // TODO: Sanity check there is no other active DataFrame for this query. Need to include - // query id available in the cache for this check. + // TODO: Sanity check there is no other active DataFrame for this query. The query id + // needs to be saved in the cache for this check. sessionHolder.cacheDataFrameById(dfId, df) try { - fn(df, batchId) + fn(FnArgsWithId(dfId, df, batchId)) } finally { log.info(s"Removing DataFrame with id $dfId from the cache") sessionHolder.removeCachedDataFrame(dfId) @@ -57,13 +63,61 @@ object StreamingForeachBatchHelper extends Logging { * Handles setting up Scala remote session and other Spark Connect environment and then runs the * provided foreachBatch function `fn`. * - * HACK ALERT: This version does not atually set up Spark connect. Directly passes the - * DataFrame, so the user code actually runs with legacy DataFrame. + * HACK ALERT: This version does not actually set up Spark Connect session. Directly passes the + * DataFrame, so the user code actually runs with legacy DataFrame and session.. */ def scalaForeachBatchWrapper( fn: ForeachBatchFnType, sessionHolder: SessionHolder): ForeachBatchFnType = { // TODO: Set up Spark Connect session. Do we actually need this for the first version? - dataFrameCachingWrapper(fn, sessionHolder) + dataFrameCachingWrapper( + (args: FnArgsWithId) => { + fn(args.df, args.batchId) // dfId is not used, see hack comment above. + }, + sessionHolder) } + + /** + * Starts up Python worker and initializes it with Python function. 
Returns a foreachBatch + * function that sets up the session and Dataframe cache and interacts with the Python + * worker to execute user's function. + */ + def pythonForeachBatchWrapper( + pythonFn: SimplePythonFunction, + sessionHolder: SessionHolder): ForeachBatchFnType = { + + val port = SparkConnectService.localPort + val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + val runner = StreamingPythonRunner(pythonFn, connectUrl) + val (dataOut, dataIn) = runner.init(sessionHolder.sessionId) + + val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { + + // TODO(SPARK-44460): Support Auth credentials + // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created. + // This is because MicroBatch execution clones the session during start. + // The session attached to the foreachBatch dataframe is different from the one + // the query was started with. `sessionHolder` here contains the latter. + + PythonRDD.writeUTF(args.dfId, dataOut) + dataOut.writeLong(args.batchId) + dataOut.flush() + + val ret = dataIn.readInt() + log.info(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)") + } + + dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder) + } + + // TODO(SPARK-44433): Improve termination of Processes + // The goal is that when a query is terminated, the python process associated with foreachBatch + // should be terminated. One way to do that is by registering streaming query listener: + // After pythonForeachBatchWrapper() is invoked by the SparkConnectPlanner. + // At that time, we don't have the streaming queries yet. + // Planner should call back into this helper with the query id when it starts it immediately + // after. Save the query id to StreamingPythonRunner mapping. This mapping should be + // part of the SessionHolder. + // When a query is terminated, check the mapping and terminate any associated runner. 
+ // These runners should be terminated when a session is deleted (due to timeout, etc). } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ffb1098576814..2831ae74f5606 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -799,6 +799,7 @@ private[spark] object SpecialLengths { val END_OF_STREAM = -4 val NULL = -5 val START_ARROW_STREAM = -6 + val END_OF_MICRO_BATCH = -7 } private[spark] object BarrierTaskContextMessageProtocol { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 19181bd98e111..6039f8d232b4c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -106,10 +106,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } createThroughDaemon() } else { - createSimpleWorker() + createSimpleWorker(workerModule) } } + /** Creates a Python worker with `pyspark.streaming_worker` module. */ + def createStreamingWorker(): (Socket, Option[Int]) = { + createSimpleWorker("pyspark.streaming_worker") + } + /** * Connect to a worker launched through pyspark/daemon.py (by default), which forks python * processes itself to avoid the high cost of forking from Java. This currently only works @@ -150,7 +155,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String /** * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. 
*/ - private def createSimpleWorker(): (Socket, Option[Int]) = { + private def createSimpleWorker(workerModule: String): (Socket, Option[Int]) = { var serverSocket: ServerSocket = null try { serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress()) diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala new file mode 100644 index 0000000000000..77dc88e0cfa5d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} +import java.net.Socket + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.BUFFER_SIZE +import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTHON_USE_DAEMON} + + +private[spark] object StreamingPythonRunner { + def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = { + new StreamingPythonRunner(func, connectUrl) + } +} + +private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String) + extends Logging { + private val conf = SparkEnv.get.conf + protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) + + private val envVars: java.util.Map[String, String] = func.envVars + private val pythonExec: String = func.pythonExec + protected val pythonVer: String = func.pythonVer + + /** + * Initializes the Python worker for streaming functions. Sets up Spark Connect session + * to be used with the functions. 
+ */ + def init(sessionId: String): (DataOutputStream, DataInputStream) = { + log.info(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") + + val env = SparkEnv.get + + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) + + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + conf.set(PYTHON_USE_DAEMON, false) + envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) + + val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap) + val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker() + + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + + // TODO: verify python version + + // Send sessionId + PythonRDD.writeUTF(sessionId, dataOut) + + // send the user function to python process + val command = func.command + dataOut.writeInt(command.length) + dataOut.write(command.toArray) + dataOut.flush() + + val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val resFromPython = dataIn.readInt() + log.info(s"Runner initialization returned $resFromPython") + + (dataOut, dataIn) + } +} diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 52eab1bf5f930..37a5bdd9f9fd7 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -58,6 +58,7 @@ LogicalPlan, CachedLocalRelation, CachedRelation, + CachedRemoteRelation, ) from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager @@ -670,6 +671,14 @@ def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None: copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__ + def _createRemoteDataFrame(self, remote_id: str) -> 
"DataFrame": + """ + In internal API to reference a runtime DataFrame on the server side. + This is used in ForeachBatch() runner, where the remote DataFrame refers to the + output of a micro batch. + """ + return DataFrame.withPlan(CachedRemoteRelation(remote_id), self) + @staticmethod def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: """ diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 156a3ba87db43..c8cd408404f8f 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -32,7 +32,7 @@ DataStreamWriter as PySparkDataStreamWriter, ) from pyspark.sql.types import Row, StructType -from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkNotImplementedError +from pyspark.errors import PySparkTypeError, PySparkValueError if TYPE_CHECKING: from pyspark.sql.connect.session import SparkSession @@ -495,14 +495,14 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ - # TODO (SPARK-42944): Implement and uncomment the doc def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": - raise PySparkNotImplementedError( - error_class="NOT_IMPLEMENTED", - message_parameters={"feature": "foreachBatch()"}, + self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps( + func ) + self._write_proto.foreach_batch.python_function.python_ver = "%d.%d" % sys.version_info[:2] + return self - # foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ + foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ def _start_internal( self, diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py new file mode 100644 index 
0000000000000..c4aa936a43efe --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase): + @unittest.skip("SPARK-44463: Error handling needs improvement in connect foreachBatch") + def test_streaming_foreachBatch_propagates_python_errors(self): + super().test_streaming_foreachBatch_propagates_python_errors + + @unittest.skip("This seems specific to py4j and pinned threads. 
The intention is unclear") + def test_streaming_foreachBatch_graceful_stop(self): + super().test_streaming_foreachBatch_graceful_stop() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.streaming.test_parity_streaming_foreachBatch import * # noqa: F401,E501 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py index 7e5720e429990..d4e185c3d856d 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -20,40 +20,40 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingTestsForeachBatch(ReusedSQLTestCase): +class StreamingTestsForeachBatchMixin: def test_streaming_foreachBatch(self): q = None - collected = dict() def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() + batch_df.createOrReplaceGlobalTempView("test_view") try: df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") q = df.writeStream.foreachBatch(collectBatch).start() q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) + collected = self.spark.sql("select * from global_temp.test_view").collect() + self.assertTrue(len(collected), 2) finally: if q: q.stop() def test_streaming_foreachBatch_tempview(self): q = None - collected = dict() def collectBatch(batch_df, batch_id): batch_df.createOrReplaceTempView("updates") # it should use the spark session within given DataFrame, as microbatch execution will # clone the session which is no longer same with the session used to start the # streaming query - collected[batch_id] = 
batch_df.sparkSession.sql("SELECT * FROM updates").collect() + assert len(batch_df.sparkSession.sql("SELECT * FROM updates").collect()) == 2 + # Write to a global view verify on the repl/client side. + batch_df.createOrReplaceGlobalTempView("temp_view") try: df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") q = df.writeStream.foreachBatch(collectBatch).start() q.processAllAvailable() - self.assertTrue(0 in collected) + collected = self.spark.sql("SELECT * FROM global_temp.temp_view").collect() self.assertTrue(len(collected[0]), 2) finally: if q: @@ -89,6 +89,10 @@ def func(batch_df, _): self.assertIsNone(q.exception(), "No exception has to be propagated.") +class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": import unittest from pyspark.sql.tests.streaming.test_streaming_foreachBatch import * # noqa: F401 diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/streaming_worker.py new file mode 100644 index 0000000000000..490bae44d99f3 --- /dev/null +++ b/python/pyspark/streaming_worker.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" +A worker for streaming foreachBatch and query listener in Spark Connect. +""" +import os + +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + write_int, + read_long, + UTF8Deserializer, + CPickleSerializer, +) +from pyspark import worker +from pyspark.sql import SparkSession + +pickleSer = CPickleSerializer() +utf8_deserializer = UTF8Deserializer() + + +def main(infile, outfile): # type: ignore[no-untyped-def] + log_name = "Streaming ForeachBatch worker" + connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] + sessionId = utf8_deserializer.loads(infile) + + print(f"{log_name} is starting with url {connect_url} and sessionId {sessionId}.") + + sparkConnectSession = SparkSession.builder.remote(connect_url).getOrCreate() + sparkConnectSession._client._session_id = sessionId + + # TODO(SPARK-44460): Pass credentials. + # TODO(SPARK-44461): Enable Process Isolation + + func = worker.read_command(pickleSer, infile) + write_int(0, outfile) # Indicate successful initialization + + outfile.flush() + + def process(dfId, batchId): # type: ignore[no-untyped-def] + print(f"{log_name} Started batch {batchId} with DF id {dfId}") + batchDf = sparkConnectSession._createRemoteDataFrame(dfId) + func(batchDf, batchId) + print(f"{log_name} Completed batch {batchId} with DF id {dfId}") + + while True: + dfRefId = utf8_deserializer.loads(infile) + batchId = read_long(infile) + process(dfRefId, int(batchId)) # TODO(SPARK-44463): Propagate error to the user. + write_int(0, outfile) + outfile.flush() + + +if __name__ == "__main__": + print("Starting streaming worker") + + # Read information about how to connect back to the JVM from the environment. 
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() + main(sock_file, sock_file) From b831d826ba8443ca6195704218045dceae632d97 Mon Sep 17 00:00:00 2001 From: vicennial Date: Wed, 19 Jul 2023 09:36:51 +0900 Subject: [PATCH 013/986] [SPARK-44476][CORE][CONNECT] Fix population of artifacts for a JobArtifactState with no associated artifacts ### What changes were proposed in this pull request? When a `JobArtifactSet` is created from a specific `JobArtifactState`, an empty collection of files/jars/archives is returned if there have previously been no associated artifacts of that specific type rather than all files/jars/archives. ### Why are the changes needed? Consider each artifact type - files/jars/archives. For each artifact type, the following bug exists: 1. Initialise a `JobArtifactState` with no artifacts added to it. 2. Create a `JobArtifactSet` from the `JobArtifactState`. 3. Add an artifact with the same active `JobArtifactState`. 4. Create another `JobArtifactSet` In the current behaviour, the set created in step 2 contains all existing artifacts of that type (through `sc.allAddedFiles` for example) while step 4 would only contain the single artifact added in step 3. ### Does this PR introduce _any_ user-facing change? No. (Bug-fix addresses the previously incorrect user behaviour) ### How was this patch tested? New unit test in `JobArtifactSetSuite`. Closes #42062 from vicennial/SPARK-44476. 
Authored-by: vicennial Signed-off-by: Hyukjin Kwon (cherry picked from commit 98de9a0ad46086bfad5d89a66540a24e39d2f029) Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/JobArtifactSet.scala | 6 ++-- .../apache/spark/JobArtifactSetSuite.scala | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala index 54922f5783af0..c304e15f35889 100644 --- a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala +++ b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala @@ -110,13 +110,13 @@ private[spark] object JobArtifactSet { new JobArtifactSet( state = maybeState, jars = maybeState - .map(s => sc.addedJars.getOrElse(s.uuid, sc.allAddedJars)) + .map(s => sc.addedJars.getOrElse(s.uuid, Map.empty[String, Long])) .getOrElse(sc.allAddedJars).toMap, files = maybeState - .map(s => sc.addedFiles.getOrElse(s.uuid, sc.allAddedFiles)) + .map(s => sc.addedFiles.getOrElse(s.uuid, Map.empty[String, Long])) .getOrElse(sc.allAddedFiles).toMap, archives = maybeState - .map(s => sc.addedArchives.getOrElse(s.uuid, sc.allAddedArchives)) + .map(s => sc.addedArchives.getOrElse(s.uuid, Map.empty[String, Long])) .getOrElse(sc.allAddedArchives).toMap) } } diff --git a/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala b/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala index 66d02e8b511aa..bf1cb4dded85b 100644 --- a/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala @@ -86,4 +86,40 @@ class JobArtifactSetSuite extends SparkFunSuite with LocalSparkContext { assert(JobArtifactSet.getActiveOrDefault(sc).state.isEmpty) } } + + test("SPARK-44476: JobArtifactState is not populated with all artifacts if none are " + + "explicitly added to it.") { + withTempDir { dir => + val conf = new SparkConf() + .setAppName("test") + .setMaster("local") + 
.set("spark.repl.class.uri", "dummyUri") + sc = new SparkContext(conf) + + val jarPath = File.createTempFile("testJar", ".jar", dir).getAbsolutePath + val filePath = File.createTempFile("testFile", ".txt", dir).getAbsolutePath + val fileToZip = File.createTempFile("testFile", "", dir).getAbsolutePath + val archivePath = s"$fileToZip.zip" + createZipFile(fileToZip, archivePath) + + val otherJobArtifactState = JobArtifactState("other", Some("state")) + + JobArtifactSet.withActiveJobArtifactState(otherJobArtifactState) { + sc.addJar(jarPath) + sc.addFile(filePath) + sc.addArchive(archivePath) + } + + val artifactState = JobArtifactState("abc", Some("xyz")) + JobArtifactSet.withActiveJobArtifactState(artifactState) { + val jobArtifactSet = JobArtifactSet.getActiveOrDefault(sc) + + // Artifacts from the other state must be not visible to this state. + assert(jobArtifactSet.state.contains(artifactState)) + assert(jobArtifactSet.jars.isEmpty) + assert(jobArtifactSet.files.isEmpty) + assert(jobArtifactSet.archives.isEmpty) + } + } + } } From f497ab568bb9541621785492e699168dd8cb996c Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 19 Jul 2023 09:51:41 +0900 Subject: [PATCH 014/986] [SPARK-43755][CONNECT][MINOR] Open `AdaptiveSparkPlanHelper.allChildren` instead of using copy in `MetricGenerator` ### What changes were proposed in this pull request? Minor refactor - make `AdaptiveSparkPlanHelper.allChildren` protected in place of private, so that we don't have to copy it over, risking that the two versions will get out of sync. ### Why are the changes needed? Minor refactor. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI Closes #42060 from juliuszsompolski/SPARK-43755-fup. 
Authored-by: Juliusz Sompolski Signed-off-by: Hyukjin Kwon (cherry picked from commit 227e28cf84c0720ef0515a3744bc017b6eab26ed) Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/connect/utils/MetricGenerator.scala | 8 +------- .../sql/execution/adaptive/AdaptiveSparkPlanHelper.scala | 2 +- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala index 88120e616efdb..6395fb588ab84 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper /** * Helper object for generating responses with metrics from queries. 
@@ -47,12 +47,6 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper { allChildren(p).flatMap(c => transformPlan(c, p.id)) } - private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { - case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) - case s: QueryStageExec => Seq(s.plan) - case _ => p.children - } - private def transformPlan( p: SparkPlan, parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala index eecfa40e8d0bd..c58d925f28e5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala @@ -122,7 +122,7 @@ trait AdaptiveSparkPlanHelper { subqueries ++ subqueries.flatMap(subqueriesAll) } - private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { + protected def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) case s: QueryStageExec => Seq(s.plan) case _ => p.children From 4ca29161ba16ce195a08b36c9abf2fc5d1841d1a Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Wed, 19 Jul 2023 09:02:32 +0800 Subject: [PATCH 015/986] [SPARK-44324][SQL][CONNECT] Move CaseInsensitiveMap to sql/api ### What changes were proposed in this pull request? Move CaseInsensitiveMap to sql/api. ### Why are the changes needed? So that Spark Connect Scala client do not need to depend on the Catalyst. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test Closes #41882 from amaliujia/move_case_senstiive_map. 
Authored-by: Rui Wang Signed-off-by: Wenchen Fan (cherry picked from commit ede82a9519241492ae7bc3aef703b9eb875befb7) Signed-off-by: Wenchen Fan --- sql/api/pom.xml | 18 ++++++++++++++++++ .../sql/catalyst/util/CaseInsensitiveMap.scala | 0 .../sql/catalyst/util/CaseInsensitiveMap.scala | 0 3 files changed, 18 insertions(+) rename sql/{catalyst => api}/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala (100%) rename sql/{catalyst => api}/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala (100%) diff --git a/sql/api/pom.xml b/sql/api/pom.xml index 41a5b85d4c670..6add5679ce7c4 100644 --- a/sql/api/pom.xml +++ b/sql/api/pom.xml @@ -50,6 +50,24 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.codehaus.mojo + build-helper-maven-plugin + + + add-sources + generate-sources + + add-source + + + + src/main/scala-${scala.binary.version} + + + + + \ No newline at end of file diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/api/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala similarity index 100% rename from sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala rename to sql/api/src/main/scala-2.12/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/api/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala similarity index 100% rename from sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala rename to sql/api/src/main/scala-2.13/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala From cf52568af1ef0a17a68f2ea57fb35442547482b5 Mon Sep 17 00:00:00 2001 From: Jack Chen Date: Wed, 19 Jul 2023 09:12:54 +0800 Subject: [PATCH 016/986] 
[SPARK-44448][SQL] Fix wrong results bug from DenseRankLimitIterator and InferWindowGroupLimit ### What changes were proposed in this pull request? Top-k filters on a dense_rank() window function return wrong results, due to a bug in optimization InferWindowGroupLimit, specifically in the code for DenseRankLimitIterator, introduced in https://issues.apache.org/jira/browse/SPARK-37099. The bug is in DenseRankLimitIterator, it fails to reset state properly when transitioning from one window partition to the next. reset only resets rank = 0, what it is missing is to reset currentRankRow = null. This means that when processing the second and later window partitions, the rank incorrectly gets incremented based on comparing the ordering of the last row of the previous partition to the first row of the new partition. This means that a dense_rank window func that has more than one window partition and more than one row with dense_rank = 1 in the second or later partitions can give wrong results when optimized. RankLimitIterator narrowly avoids this bug by happenstance, the first row in the new partition will try to increment rank, but increment it by the value of count which is 0, so it happens to work by accident. This PR also fixes the reset function in RankLimitIterator to make it more robust. Example repro: ``` create or replace temp view t1 (p, o) as values (1, 1), (1, 1), (1, 2), (2, 1), (2, 1), (2, 2); select * from (select *, dense_rank() over (partition by p order by o) as rnk from t1) where rnk = 1; ``` Spark result: ``` [1,1,1] [1,1,1] [2,1,1] ``` Correct result: ``` [1,1,1] [1,1,1] [2,1,1] [2,1,1] ``` ### Why are the changes needed? Fix wrong results bug. ### Does this PR introduce _any_ user-facing change? Yes, fixes wrong results. ### How was this patch tested? Add sql tests and unit tests. Unfortunately, the previous tests for the optimization only had a single row per rank, so did not catch the bug as the bug requires multiple rows per rank. 
This PR strengthens the tests with data that contains multiple rows per rank. Closes #42026 from jchen5/dense-rank-limit. Authored-by: Jack Chen Signed-off-by: Wenchen Fan (cherry picked from commit f35c814f9b3cf73b94297f8086a117df4c46f39a) Signed-off-by: Wenchen Fan --- .../window/WindowGroupLimitExec.scala | 2 + .../sql-tests/analyzer-results/window.sql.out | 131 ++++++++++++++++++ .../resources/sql-tests/inputs/window.sql | 13 ++ .../sql-tests/results/window.sql.out | 90 ++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 52 +++++-- 5 files changed, 276 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala index b1f375f415102..98969f60c2b43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowGroupLimitExec.scala @@ -170,6 +170,7 @@ case class RankLimitIterator( override def reset(): Unit = { rank = 0 count = 0 + currentRankRow = null } } @@ -193,6 +194,7 @@ case class DenseRankLimitIterator( override def reset(): Unit = { rank = 0 + currentRankRow = null } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out index ad75b97ceb177..cc2cc8bfc0e26 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/window.sql.out @@ -1279,3 +1279,134 @@ org.apache.spark.sql.AnalysisException "windowName" : "w" } } + + +-- !query +create or replace temp view t1 (p, o) as values (1, 1), (1, 1), (1, 2), (2, 1), (2, 1), (2, 2) +-- !query analysis +CreateViewCommand `t1`, [(p,None), (o,None)], values (1, 1), (1, 1), (1, 2), (2, 1), (2, 1), (2, 2), false, true, LocalTempView, true + +- LocalRelation 
[col1#x, col2#x] + + +-- !query +select * from (select *, dense_rank() over (partition by p order by o) as rnk from t1) where rnk = 1 +-- !query analysis +Project [p#x, o#x, rnk#x] ++- Filter (rnk#x = 1) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [p#x, o#x, rnk#x] + +- Project [p#x, o#x, rnk#x, rnk#x] + +- Window [dense_rank(o#x) windowspecdefinition(p#x, o#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rnk#x], [p#x], [o#x ASC NULLS FIRST] + +- Project [p#x, o#x] + +- SubqueryAlias t1 + +- View (`t1`, [p#x,o#x]) + +- Project [cast(col1#x as int) AS p#x, cast(col2#x as int) AS o#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +SELECT * FROM (SELECT cate, val, rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1 +-- !query analysis +Project [cate#x, val#x, r#x] ++- Filter (r#x = 1) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [cate#x, val#x, r#x] + +- Project [cate#x, val#x, r#x, r#x] + +- Window [rank(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS r#x], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias testdata + +- View (`testData`, [val#x,val_long#xL,val_double#x,val_date#x,val_timestamp#x,cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +SELECT * FROM (SELECT cate, val, rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2 +-- !query analysis +Project [cate#x, val#x, r#x] ++- Filter 
(r#x <= 2) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [cate#x, val#x, r#x] + +- Project [cate#x, val#x, r#x, r#x] + +- Window [rank(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS r#x], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias testdata + +- View (`testData`, [val#x,val_long#xL,val_double#x,val_date#x,val_timestamp#x,cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +SELECT * FROM (SELECT cate, val, dense_rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1 +-- !query analysis +Project [cate#x, val#x, r#x] ++- Filter (r#x = 1) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [cate#x, val#x, r#x] + +- Project [cate#x, val#x, r#x, r#x] + +- Window [dense_rank(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS r#x], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias testdata + +- View (`testData`, [val#x,val_long#xL,val_double#x,val_date#x,val_timestamp#x,cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- 
LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +SELECT * FROM (SELECT cate, val, dense_rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2 +-- !query analysis +Project [cate#x, val#x, r#x] ++- Filter (r#x <= 2) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [cate#x, val#x, r#x] + +- Project [cate#x, val#x, r#x, r#x] + +- Window [dense_rank(val#x) windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS r#x], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias testdata + +- View (`testData`, [val#x,val_long#xL,val_double#x,val_date#x,val_timestamp#x,cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +SELECT * FROM (SELECT cate, val, row_number() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1 +-- !query analysis +Project [cate#x, val#x, r#x] ++- Filter (r#x = 1) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [cate#x, val#x, r#x] + +- Project [cate#x, val#x, r#x, r#x] + +- Window [row_number() windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS r#x], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias testdata + +- View (`testData`, [val#x,val_long#xL,val_double#x,val_date#x,val_timestamp#x,cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as 
double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + + +-- !query +SELECT * FROM (SELECT cate, val, row_number() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2 +-- !query analysis +Project [cate#x, val#x, r#x] ++- Filter (r#x <= 2) + +- SubqueryAlias __auto_generated_subquery_name + +- Project [cate#x, val#x, r#x] + +- Project [cate#x, val#x, r#x, r#x] + +- Window [row_number() windowspecdefinition(cate#x, val#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS r#x], [cate#x], [val#x ASC NULLS FIRST] + +- Project [cate#x, val#x] + +- SubqueryAlias testdata + +- View (`testData`, [val#x,val_long#xL,val_double#x,val_date#x,val_timestamp#x,cate#x]) + +- Project [cast(val#x as int) AS val#x, cast(val_long#xL as bigint) AS val_long#xL, cast(val_double#x as double) AS val_double#x, cast(val_date#x as date) AS val_date#x, cast(val_timestamp#x as timestamp) AS val_timestamp#x, cast(cate#x as string) AS cate#x] + +- Project [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] + +- SubqueryAlias testData + +- LocalRelation [val#x, val_long#xL, val_double#x, val_date#x, val_timestamp#x, cate#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index 8f8963c55586a..f94ff0f0a68a5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -2,6 +2,8 @@ --CONFIG_DIM1 spark.sql.codegen.wholeStage=true --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY --CONFIG_DIM1 
spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN +--CONFIG_DIM2 spark.sql.optimizer.windowGroupLimitThreshold=-1 +--CONFIG_DIM2 spark.sql.optimizer.windowGroupLimitThreshold=1000 -- Test data. CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES @@ -465,3 +467,14 @@ SELECT SUM(salary) OVER w sum_salary FROM basic_pays; + +-- Test cases for InferWindowGroupLimit +create or replace temp view t1 (p, o) as values (1, 1), (1, 1), (1, 2), (2, 1), (2, 1), (2, 2); +select * from (select *, dense_rank() over (partition by p order by o) as rnk from t1) where rnk = 1; + +SELECT * FROM (SELECT cate, val, rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1; +SELECT * FROM (SELECT cate, val, rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2; +SELECT * FROM (SELECT cate, val, dense_rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1; +SELECT * FROM (SELECT cate, val, dense_rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2; +SELECT * FROM (SELECT cate, val, row_number() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1; +SELECT * FROM (SELECT cate, val, row_number() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2; diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 13e2f5209dd36..7566f9eb20fd9 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1342,3 +1342,93 @@ org.apache.spark.sql.AnalysisException "windowName" : "w" } } + + +-- !query +create or replace temp view t1 (p, o) as values (1, 1), (1, 1), (1, 2), (2, 1), (2, 1), (2, 2) +-- !query schema +struct<> +-- !query output + + + +-- !query +select * from (select *, dense_rank() over (partition by p order by o) as rnk from t1) where rnk = 1 +-- !query schema +struct 
+-- !query output +1 1 1 +1 1 1 +2 1 1 +2 1 1 + + +-- !query +SELECT * FROM (SELECT cate, val, rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1 +-- !query schema +struct +-- !query output +NULL NULL 1 +a NULL 1 +b 1 1 + + +-- !query +SELECT * FROM (SELECT cate, val, rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2 +-- !query schema +struct +-- !query output +NULL 3 2 +NULL NULL 1 +a 1 2 +a 1 2 +a NULL 1 +b 1 1 +b 2 2 + + +-- !query +SELECT * FROM (SELECT cate, val, dense_rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1 +-- !query schema +struct +-- !query output +NULL NULL 1 +a NULL 1 +b 1 1 + + +-- !query +SELECT * FROM (SELECT cate, val, dense_rank() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2 +-- !query schema +struct +-- !query output +NULL 3 2 +NULL NULL 1 +a 1 2 +a 1 2 +a NULL 1 +b 1 1 +b 2 2 + + +-- !query +SELECT * FROM (SELECT cate, val, row_number() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r = 1 +-- !query schema +struct +-- !query output +NULL NULL 1 +a NULL 1 +b 1 1 + + +-- !query +SELECT * FROM (SELECT cate, val, row_number() OVER(PARTITION BY cate ORDER BY val) as r FROM testData) where r <= 2 +-- !query schema +struct +-- !query output +NULL 3 2 +NULL NULL 1 +a 1 2 +a NULL 1 +b 1 1 +b 2 2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index f2f645b126cbf..a57e927ba8427 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -1283,7 +1283,13 @@ class DataFrameWindowFunctionsSuite extends QueryTest ("b", 1, "n", Double.PositiveInfinity), ("c", 1, "z", -2.0), ("c", 1, "a", -4.0), - ("c", 2, nullStr, 5.0)).toDF("key", "value", "order", "value2") + ("c", 2, nullStr, 5.0), + 
("d", 0, "1", 1.0), + ("d", 1, "1", 2.0), + ("d", 2, "2", 3.0), + ("d", 3, "2", -1.0), + ("d", 4, "2", 2.0), + ("d", 4, "3", 2.0)).toDF("key", "value", "order", "value2") val window = Window.partitionBy($"key").orderBy($"order".asc_nulls_first) val window2 = Window.partitionBy($"key").orderBy($"order".desc_nulls_first) @@ -1304,7 +1310,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest Seq( Row("a", 4, "", 2.0, 1), Row("b", 1, "h", Double.NaN, 1), - Row("c", 2, null, 5.0, 1) + Row("c", 2, null, 5.0, 1), + Row("d", 0, "1", 1.0, 1) ) ) @@ -1313,7 +1320,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("a", 4, "", 2.0, 1), Row("a", 4, "", 2.0, 1), Row("b", 1, "h", Double.NaN, 1), - Row("c", 2, null, 5.0, 1) + Row("c", 2, null, 5.0, 1), + Row("d", 0, "1", 1.0, 1), + Row("d", 1, "1", 2.0, 1) ) ) @@ -1322,7 +1331,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("a", 4, "", 2.0, 1), Row("a", 4, "", 2.0, 1), Row("b", 1, "h", Double.NaN, 1), - Row("c", 2, null, 5.0, 1) + Row("c", 2, null, 5.0, 1), + Row("d", 0, "1", 1.0, 1), + Row("d", 1, "1", 2.0, 1) ) ) @@ -1353,7 +1364,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("b", 1, "h", Double.NaN, 1), Row("b", 1, "n", Double.PositiveInfinity, 2), Row("c", 1, "a", -4.0, 2), - Row("c", 2, null, 5.0, 1) + Row("c", 2, null, 5.0, 1), + Row("d", 0, "1", 1.0, 1), + Row("d", 1, "1", 2.0, 2) ) ) @@ -1364,7 +1377,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("b", 1, "h", Double.NaN, 1), Row("b", 1, "n", Double.PositiveInfinity, 2), Row("c", 1, "a", -4.0, 2), - Row("c", 2, null, 5.0, 1) + Row("c", 2, null, 5.0, 1), + Row("d", 0, "1", 1.0, 1), + Row("d", 1, "1", 2.0, 1) ) ) @@ -1376,7 +1391,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("b", 1, "h", Double.NaN, 1), Row("b", 1, "n", Double.PositiveInfinity, 2), Row("c", 1, "a", -4.0, 2), - Row("c", 2, null, 5.0, 1) + Row("c", 2, null, 5.0, 1), + Row("d", 0, "1", 1.0, 1), + Row("d", 1, "1", 2.0, 1), + 
Row("d", 2, "2", 3.0, 2), + Row("d", 3, "2", -1.0, 2), + Row("d", 4, "2", 2.0, 2) ) ) @@ -1408,7 +1428,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest checkAnswer(df.withColumn("rn", row_number().over(window)).where(condition), Seq( Row("a", 4, "", 2.0, 2), - Row("b", 1, "n", Double.PositiveInfinity, 2) + Row("b", 1, "n", Double.PositiveInfinity, 2), + Row("d", 1, "1", 2.0, 2) ) ) @@ -1421,7 +1442,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest checkAnswer(df.withColumn("rn", dense_rank().over(window)).where(condition), Seq( Row("a", 0, "c", 1.0, 2), - Row("b", 1, "n", Double.PositiveInfinity, 2) + Row("b", 1, "n", Double.PositiveInfinity, 2), + Row("d", 2, "2", 3.0, 2), + Row("d", 4, "2", 2.0, 2) ) ) @@ -1433,7 +1456,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest Seq( Row("a", 4, "", 2.0, 1, 1), Row("b", 1, "h", Double.NaN, 1, 1), - Row("c", 2, null, 5.0, 1, 1) + Row("c", 2, null, 5.0, 1, 1), + Row("d", 0, "1", 1.0, 1, 1) ) ) @@ -1446,7 +1470,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("a", 4, "", 2.0, 1, 1), Row("a", 4, "", 2.0, 1, 1), Row("b", 1, "h", Double.NaN, 1, 1), - Row("c", 2, null, 5.0, 1, 1) + Row("c", 2, null, 5.0, 1, 1), + Row("d", 0, "1", 1.0, 1, 1), + Row("d", 1, "1", 2.0, 1, 1) ) ) @@ -1459,7 +1485,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest Row("a", 4, "", 2.0, 1, 1), Row("a", 4, "", 2.0, 1, 1), Row("b", 1, "h", Double.NaN, 1, 1), - Row("c", 2, null, 5.0, 1, 1) + Row("c", 2, null, 5.0, 1, 1), + Row("d", 0, "1", 1.0, 1, 1), + Row("d", 1, "1", 2.0, 1, 1) ) ) From c417c123a48082d8db7362685d37b2133c4bb670 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 19 Jul 2023 10:21:39 +0900 Subject: [PATCH 017/986] [SPARK-44401][PYTHON][DOCS] Arrow Python UDF Use Guide ### What changes were proposed in this pull request? Add use guide for Arrow Python UDF. ### Why are the changes needed? Better documentation and usability of Arrow Python UDF. 
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? N/A. Closes #41974 from xinrong-meng/guide. Authored-by: Xinrong Meng Signed-off-by: Hyukjin Kwon (cherry picked from commit bdea4c5f52835b8375dc0f523a2887112df0ad47) Signed-off-by: Hyukjin Kwon --- examples/src/main/python/sql/arrow.py | 21 +++++++++++++++ .../source/user_guide/sql/arrow_pandas.rst | 26 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 2510ffd423bfa..03daf18eadbf3 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -275,6 +275,27 @@ def merge_ordered(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: # +--------+---+---+----+ +def arrow_python_udf_example(spark: SparkSession) -> None: + from pyspark.sql.functions import udf + + @udf(returnType='int') # A default, pickled Python UDF + def slen(s): # type: ignore[no-untyped-def] + return len(s) + + @udf(returnType='int', useArrow=True) # An Arrow Python UDF + def arrow_slen(s): # type: ignore[no-untyped-def] + return len(s) + + df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + + df.select(slen("name"), arrow_slen("name")).show() + # +----------+----------------+ + # |slen(name)|arrow_slen(name)| + # +----------+----------------+ + # | 8| 8| + # +----------+----------------+ + + if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index 9667b745628e2..e9355cdf4f0af 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -333,6 +333,32 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa For detailed usage, please see :meth:`PandasCogroupedOps.applyInPandas` +Arrow Python UDFs +----------------- + +Arrow Python UDFs are 
user defined functions that are executed row-by-row, utilizing Arrow for efficient batch data +transfer and serialization. To define an Arrow Python UDF, you can use the :meth:`udf` decorator or wrap the function +with the :meth:`udf` method, ensuring the ``useArrow`` parameter is set to True. Additionally, you can enable Arrow +optimization for Python UDFs throughout the entire SparkSession by setting the Spark configuration ``spark.sql +.execution.pythonUDF.arrow.enabled`` to true. It's important to note that the Spark configuration takes effect only +when ``useArrow`` is either not set or set to None. + +The type hints for Arrow Python UDFs should be specified in the same way as for default, pickled Python UDFs. + +Here's an example that demonstrates the usage of both a default, pickled Python UDF and an Arrow Python UDF: + +.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py + :language: python + :lines: 279-297 + :dedent: 4 + +Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF +type coercion poses challenges when the Python instances returned by UDFs do not align with the user-specified +return type. The default, pickled Python UDFs' type coercion has certain limitations, such as relying on None as a +fallback for type mismatches, leading to potential ambiguity and data loss. Additionally, converting date, datetime, +and tuples to strings can yield ambiguous results. Arrow Python UDFs, on the other hand, leverage Arrow's +capabilities to standardize type coercion and address these issues effectively. 
+ Usage Notes ----------- From f4672b5fa5c229e9a76d45c620a07ec0ae1f81f4 Mon Sep 17 00:00:00 2001 From: Mathew Jacob Date: Wed, 19 Jul 2023 10:50:10 +0900 Subject: [PATCH 018/986] [SPARK-44264][ML][PYTHON] Refactoring TorchDistributor To Allow for Custom "run_training_on_file" Function Pointer ### What Was Changed We enable for a custom function pointer to be passed around the private functions that allow for distributed training of a function. ### Why Do We Need This Change By abstracting the "run_training_on_pytorch_file" function to something that can be passed in, it allows for much easier creation of distributors that run on top of torch.distributed. Specifically, it makes it easy to implement distributed training of picklable functions in DeepspeedTorchDistributor. As mentioned, if there are accelerators that come out in the future built on top of torch.distributed, it will be very easy to support them in Spark. One can simply do the following: 1. Inherit from TorchDistributor and define a _run_training_on_pytorch_file function or equivalent for your class 2. When defining run(...), simply return _run() and pass in your custom _run_training_on_pytorch_file function in as the respective argument ### Any User-Facing Changes? No. ### How Is This Tested? The existing tests for TorchDistributor. Closes #41973 from mathewjacob1002/distributed_func_support_prototype. 
Lead-authored-by: Mathew Jacob Co-authored-by: Mathew Jacob <134338709+mathewjacob1002@users.noreply.github.com> Signed-off-by: Hyukjin Kwon (cherry picked from commit ee0e687b2bf50c6b8ab1b40b948e91ef03da42b8) Signed-off-by: Hyukjin Kwon --- python/pyspark/ml/torch/distributor.py | 117 ++++++++++++++++-- .../ml/torch/tests/test_distributor.py | 9 +- 2 files changed, 113 insertions(+), 13 deletions(-) diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py index 5b5d7c2288a62..71bcde3b48e45 100644 --- a/python/pyspark/ml/torch/distributor.py +++ b/python/pyspark/ml/torch/distributor.py @@ -515,10 +515,64 @@ def _execute_command( f"The {last_n_msg} included below: {task_output}" ) + @staticmethod + def _get_output_from_framework_wrapper( + framework_wrapper: Optional[Callable], + input_params: Dict, + train_object: Union[Callable, str], + run_pytorch_file_fn: Optional[Callable], + *args: Any, + **kwargs: Any, + ) -> Optional[Any]: + """ + This function is meant to get the output from framework wrapper function by passing in the + correct arguments, depending on the type of train_object. + + Parameters + ---------- + framework_wrapper: Optional[Callable] + Function pointer that will be invoked. Can either be the function that runs distributed + training on files if train_object is a string. Otherwise, it will be the function that + runs distributed training for functions if the train_object is a Callable + input_params: Dict + A dictionary that maps parameter to arguments for the command to be created. + train_object: Union[Callable, str] + This input comes from the user. If the user inputs a string, then this means + it's a filepath. Otherwise, if the input is a function, then this means that + the user wants to run this function in a distributed manner. + run_pytorch_file_fn: Optional[Callable] + The function that will be used to run distributed training of a file; + mainly used for the distributed training using a function. 
+ *args: Any + Extra arguments to be used by framework wrapper. + **kwargs: Any + Extra keyword args to be used. Not currently supported but kept for + future improvement. + + Returns + ------- + Optional[Any] + Returns the result of the framework_wrapper + """ + if not framework_wrapper: + raise RuntimeError("`framework_wrapper` is not set. ...") + # The object to train is a file path, so framework_wrapper is some + # run_training_on_pytorch_file function. + if type(train_object) is str: + return framework_wrapper(input_params, train_object, *args, **kwargs) + else: + # We are doing training with a function, will call run_training_on_pytorch_function + if not run_pytorch_file_fn: + run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file + return framework_wrapper( + input_params, train_object, run_pytorch_file_fn, *args, **kwargs + ) + def _run_local_training( self, framework_wrapper_fn: Callable, train_object: Union[Callable, str], + run_pytorch_file_fn: Optional[Callable], *args: Any, **kwargs: Any, ) -> Optional[Any]: @@ -536,7 +590,14 @@ def _run_local_training( os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus) self.logger.info(f"Started local training with {self.num_processes} processes") - output = framework_wrapper_fn(self.input_params, train_object, *args, **kwargs) + output = TorchDistributor._get_output_from_framework_wrapper( + framework_wrapper_fn, + self.input_params, + train_object, + run_pytorch_file_fn, + *args, + **kwargs, + ) self.logger.info(f"Finished local training with {self.num_processes} processes") finally: @@ -552,6 +613,7 @@ def _get_spark_task_function( self, framework_wrapper_fn: Optional[Callable], train_object: Union[Callable, str], + run_pytorch_file_fn: Optional[Callable], input_dataframe: Optional["DataFrame"], *args: Any, **kwargs: Any, @@ -656,7 +718,14 @@ def set_gpus(context: "BarrierTaskContext") -> None: input_params["log_streaming_client"] = log_streaming_client try: with 
TorchDistributor._setup_spark_partition_data(iterator, schema_json): - output = framework_wrapper_fn(input_params, train_object, *args, **kwargs) + output = TorchDistributor._get_output_from_framework_wrapper( + framework_wrapper_fn, + input_params, + train_object, + run_pytorch_file_fn, + *args, + **kwargs, + ) finally: try: LogStreamingClient._destroy() @@ -686,6 +755,7 @@ def _run_distributed_training( self, framework_wrapper_fn: Callable, train_object: Union[Callable, str], + run_pytorch_file_fn: Optional[Callable], spark_dataframe: Optional["DataFrame"], *args: Any, **kwargs: Any, @@ -702,7 +772,12 @@ def _run_distributed_training( try: spark_task_function = self._get_spark_task_function( - framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs + framework_wrapper_fn, + train_object, + run_pytorch_file_fn, + spark_dataframe, + *args, + **kwargs, ) self._check_encryption() self.logger.info( @@ -804,13 +879,21 @@ def _setup_spark_partition_data( @staticmethod def _run_training_on_pytorch_function( - input_params: Dict[str, Any], train_fn: Callable, *args: Any, **kwargs: Any + input_params: Dict[str, Any], + train_fn: Callable, + run_pytorch_file_fn: Optional[Callable], + *args: Any, + **kwargs: Any, ) -> Any: + + if not run_pytorch_file_fn: + run_pytorch_file_fn = TorchDistributor._run_training_on_pytorch_file + with TorchDistributor._setup_files(train_fn, *args, **kwargs) as ( train_file_path, output_file_path, ): - TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path) + run_pytorch_file_fn(input_params, train_file_path) if not os.path.exists(output_file_path): raise RuntimeError( "TorchDistributor failed during training." @@ -910,17 +993,28 @@ def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> train_object is a Callable with an expected output. Returns None if train_object is a file. 
""" + return self._run( + train_object, TorchDistributor._run_training_on_pytorch_file, *args, **kwargs + ) + + def _run( + self, + train_object: Union[Callable, str], + run_pytorch_file_fn: Callable, + *args: Any, + **kwargs: Any, + ) -> Optional[Any]: if isinstance(train_object, str): - framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file + framework_wrapper_fn = run_pytorch_file_fn else: - framework_wrapper_fn = ( - TorchDistributor._run_training_on_pytorch_function # type: ignore - ) + framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_function if self.local_mode: - output = self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs) + output = self._run_local_training( + framework_wrapper_fn, train_object, run_pytorch_file_fn, *args, **kwargs + ) else: output = self._run_distributed_training( - framework_wrapper_fn, train_object, None, *args, **kwargs + framework_wrapper_fn, train_object, run_pytorch_file_fn, None, *args, **kwargs ) return output @@ -976,6 +1070,7 @@ def _train_on_dataframe( return self._run_distributed_training( TorchDistributor._run_training_on_pytorch_function, train_function, + TorchDistributor._run_training_on_pytorch_file, spark_dataframe, *args, **kwargs, diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py index 89c8893ea1b30..364ed83f98dce 100644 --- a/python/pyspark/ml/torch/tests/test_distributor.py +++ b/python/pyspark/ml/torch/tests/test_distributor.py @@ -383,7 +383,7 @@ def test_local_training_succeeds(self) -> None: ) self.assertEqual( expected, - dist._run_local_training(dist._run_training_on_pytorch_file, "train.py"), + dist._run_local_training(dist._run_training_on_pytorch_file, "train.py", None), ) # cleanup if cuda_env_var: @@ -463,7 +463,12 @@ def test_dist_training_succeeds(self) -> None: ) self.assertEqual( expected, - dist._run_distributed_training(dist._run_training_on_pytorch_file, "...", None), + 
dist._run_distributed_training( + dist._run_training_on_pytorch_file, + "...", + TorchDistributor._run_training_on_pytorch_file, + None, + ), ) def test_get_num_tasks_distributed(self) -> None: From 8d805b8c65ac3b6bfcb204473b138cf2b06fef0c Mon Sep 17 00:00:00 2001 From: jdesjean Date: Wed, 19 Jul 2023 11:25:24 +0900 Subject: [PATCH 019/986] [SPARK-44474][CONNECT] Reenable "Test observe response" at SparkConnectServiceSuite ### What changes were proposed in this pull request? Finished is emitted in SparkConnectPlanExecution after the arrow conversion job is completed. However, since we don't await the completion of the job, it's possible for SparkConnectPlanExecution to complete before sending Finished. Closed is emitted in SparkConnectExecutePlanHandler in a separate thread. Add await in order to guarantee the order of events between Finished & Closed. ### Why are the changes needed? `Test observe response` at SparkConnectServiceSuite was disabled as flaky after [introduction of events](https://github.com/apache/spark/pull/41443). The failure surfaced a race condition in emitting the Finished & Closed events for Connect requests of type plan. The correct order of events is Finished < Closed. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. Closes #42063 from jdesjean/SPARK-44474. 
Authored-by: jdesjean Signed-off-by: Hyukjin Kwon (cherry picked from commit cf99e6c00c3c167b1a799db65bca85d7a627ea08) Signed-off-by: Hyukjin Kwon --- .../execution/SparkConnectPlanExecution.scala | 38 ++++++++++--------- .../planner/SparkConnectServiceSuite.scala | 3 +- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index d2124a38c9d4e..334dcdbcb4287 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.execution import scala.collection.JavaConverters._ +import scala.concurrent.duration.Duration import scala.util.{Failure, Success} import com.google.protobuf.ByteString @@ -153,23 +154,23 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) () } - val future = spark.sparkContext.submitJob( - rdd = batches, - processPartition = (iter: Iterator[Batch]) => iter.toArray, - partitions = Seq.range(0, numPartitions), - resultHandler = resultHandler, - resultFunc = () => ()) - - // Collect errors and propagate them to the main thread. - future.onComplete { - case Success(_) => - executePlan.eventsManager.postFinished() - case Failure(throwable) => - signal.synchronized { - error = Some(throwable) - signal.notify() - } - }(ThreadUtils.sameThread) + val future = spark.sparkContext + .submitJob( + rdd = batches, + processPartition = (iter: Iterator[Batch]) => iter.toArray, + partitions = Seq.range(0, numPartitions), + resultHandler = resultHandler, + resultFunc = () => ()) + // Collect errors and propagate them to the main thread. 
+ .andThen { + case Success(_) => + executePlan.eventsManager.postFinished() + case Failure(throwable) => + signal.synchronized { + error = Some(throwable) + signal.notify() + } + }(ThreadUtils.sameThread) // The main thread will wait until 0-th partition is available, // then send it to client and wait for the next partition. @@ -199,6 +200,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) currentPartitionId += 1 } + ThreadUtils.awaitReady(future, Duration.Inf) + } else { + executePlan.eventsManager.postFinished() } } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index cfa37b86cd41a..498084efb8f3f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -564,8 +564,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } - // TODO(SPARK-44474): Reenable Test observe response at SparkConnectServiceSuite - ignore("Test observe response") { + test("Test observe response") { // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("test") { From f250ab2cb7b2a7a158403a1b5e3ca4e57e934c94 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 19 Jul 2023 12:00:31 +0800 Subject: [PATCH 020/986] [SPARK-43965][PYTHON][CONNECT][FOLLOWUP] Include test_parity_udtf in spark test module ### What changes were proposed in this pull request? This is a follow up PR for SPARK-43965 to include `test_parity_udtf` in spark test module. It also fixes a test failure. ### Why are the changes needed? To include a new test case. ### Does this PR introduce _any_ user-facing change? 
No ### How was this patch tested? Existing tests Closes #42065 from allisonwang-db/spark-43965-follow-up. Authored-by: allisonwang-db Signed-off-by: Ruifeng Zheng (cherry picked from commit a1eea7f206b92963417ce77ee8047705a7b32942) Signed-off-by: Ruifeng Zheng --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/tests/connect/test_parity_udtf.py | 9 --------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 85ef0c494b8ef..6382e9a536931 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -852,6 +852,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_column", "pyspark.sql.tests.connect.test_parity_readwriter", "pyspark.sql.tests.connect.test_parity_udf", + "pyspark.sql.tests.connect.test_parity_udtf", "pyspark.sql.tests.connect.test_parity_pandas_udf", "pyspark.sql.tests.connect.test_parity_pandas_map", "pyspark.sql.tests.connect.test_parity_arrow_map", diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index f5f37f1f5c070..e18e116e0034d 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -104,15 +104,6 @@ def terminate(self): with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): TestUDTF(lit(1)).show() - def test_udtf_with_empty_yield(self): - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield - - with self.assertRaisesRegex(SparkConnectGrpcException, "java.lang.NullPointerException"): - TestUDTF(lit(1)).collect() - class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): @classmethod From 4f22b2f944d3253898bddff9b0958b7a2c813bc7 Mon Sep 17 00:00:00 2001 From: Vinod KC Date: Wed, 19 Jul 2023 12:02:52 +0800 Subject: [PATCH 021/986] [SPARK-44361][SQL] Use PartitionEvaluator API in MapInBatchExec ### What changes were proposed in this pull 
request? SQL operator `MapInBatchExec` is updated to use the `PartitionEvaluator` API to do execution. Added a new method `mapPartitionsWithEvaluator` in `RDDBarrier`. ### Why are the changes needed? To avoid the use of lambda during distributed execution. Ref: SPARK-43061 for more details. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test cases. Once all SQL operators are refactored, will enable `spark.sql.execution.usePartitionEvaluator` by default, so all tests cover this code path. Closes #42024 from vinodkc/br_SPARK-44361. Authored-by: Vinod KC Signed-off-by: Wenchen Fan (cherry picked from commit 9b43a9f3ea551a594835a4742f7b2d1fdb1cf518) Signed-off-by: Wenchen Fan --- .../org/apache/spark/rdd/RDDBarrier.scala | 16 +++- .../python/MapInBatchEvaluatorFactory.scala | 92 +++++++++++++++++++ .../sql/execution/python/MapInBatchExec.scala | 80 +++++++--------- 3 files changed, 137 insertions(+), 51 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index b70ea0073c9a0..13ce8f1e1b540 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -19,8 +19,8 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.TaskContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.{PartitionEvaluatorFactory, TaskContext} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -76,5 +76,17 @@ class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) { ) } + /** + * Return a new RDD by applying an evaluator to each partition of the wrapped RDD. 
The given + * evaluator factory will be serialized and sent to executors, and each task will create an + * evaluator with the factory, and use the evaluator to transform the data of the input + * partition. + */ + @DeveloperApi + @Since("3.5.0") + def mapPartitionsWithEvaluator[U: ClassTag]( + evaluatorFactory: PartitionEvaluatorFactory[T, U]): RDD[U] = rdd.withScope { + new MapPartitionsWithEvaluatorRDD(rdd, evaluatorFactory) + } // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout. } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala new file mode 100644 index 0000000000000..efb063476a41e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.{ContextAwareIterator, TaskContext} +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} +import org.apache.spark.api.python.{ChainedPythonFunctions} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +class MapInBatchEvaluatorFactory( + output: Seq[Attribute], + chainedFunc: Seq[ChainedPythonFunctions], + outputTypes: StructType, + batchSize: Int, + pythonEvalType: Int, + sessionLocalTimeZone: String, + largeVarTypes: Boolean, + pythonRunnerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends PartitionEvaluatorFactory[InternalRow, InternalRow] { + + override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = + new MapInBatchEvaluator + + private class MapInBatchEvaluator extends PartitionEvaluator[InternalRow, InternalRow] { + override def eval( + partitionIndex: Int, + inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { + assert(inputs.length == 1) + val inputIter = inputs.head + // Single function with one struct. + val argOffsets = Array(Array(0)) + val context = TaskContext.get() + val contextAwareIterator = new ContextAwareIterator(context, inputIter) + + // Here we wrap it via another row so that Python sides understand it + // as a DataFrame. + val wrappedIter = contextAwareIterator.map(InternalRow(_)) + + // DO NOT use iter.grouped(). See BatchIterator. 
+ val batchIter = + if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, + pythonEvalType, + argOffsets, + StructType(Array(StructField("struct", outputTypes))), + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) + + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter + .flatMap { batch => + // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select + // the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + } + .map(unsafeProj) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index b4af3db3c83a7..0703f57c33d38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -17,18 +17,14 @@ package org.apache.spark.sql.execution.python -import scala.collection.JavaConverters._ - -import org.apache.spark.{ContextAwareIterator, JobArtifactSet, TaskContext} +import org.apache.spark.JobArtifactSet import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.UnaryExecNode -import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.util.ArrowUtils -import 
org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** * A relation produced by applying a function that takes an iterator of batches @@ -56,53 +52,39 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { override def outputPartitioning: Partitioning = child.outputPartitioning override protected def doExecute(): RDD[InternalRow] = { - def mapper(inputIter: Iterator[InternalRow]): Iterator[InternalRow] = { - // Single function with one struct. - val argOffsets = Array(Array(0)) - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - val outputTypes = child.schema - - val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, inputIter) - - // Here we wrap it via another row so that Python sides understand it - // as a DataFrame. - val wrappedIter = contextAwareIterator.map(InternalRow(_)) - - // DO NOT use iter.grouped(). See BatchIterator. 
- val batchIter = - if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) - - val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, - pythonEvalType, - argOffsets, - StructType(Array(StructField("struct", outputTypes))), - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) - - columnarBatchIter.flatMap { batch => - // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select - // the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) - } + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + val pythonFunction = func.asInstanceOf[PythonUDF].func + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + val evaluatorFactory = new MapInBatchEvaluatorFactory( + output, + chainedFunc, + child.schema, + conf.arrowMaxRecordsPerBatch, + pythonEvalType, + conf.sessionLocalTimeZone, + conf.arrowUseLargeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID) if (isBarrier) { - child.execute().barrier().mapPartitions(mapper) + val rddBarrier = child.execute().barrier() + if (conf.usePartitionEvaluator) { + rddBarrier.mapPartitionsWithEvaluator(evaluatorFactory) + } else { + rddBarrier.mapPartitions { iter => + evaluatorFactory.createEvaluator().eval(0, iter) + } + } } else { - child.execute().mapPartitionsInternal(mapper) + val inputRdd = child.execute() + if (conf.usePartitionEvaluator) { + inputRdd.mapPartitionsWithEvaluator(evaluatorFactory) + } else { + inputRdd.mapPartitionsInternal { iter => + evaluatorFactory.createEvaluator().eval(0, iter) + } + } } } } From 
3c07a3e7e5dba165f2c579bb4969405bc9d4f20b Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Wed, 19 Jul 2023 02:18:59 -0500 Subject: [PATCH 022/986] [SPARK-44272][YARN] Path Inconsistency when Operating statCache within Yarn Client ### What changes were proposed in this pull request? 1. Change `statCache.getOrElse` to `statCache.getOrElseUpdate` so that the corresponding FileStatus can be cached into `statCache` 2. Change the `Path` parameter `isPublic`, `checkPermissionOfOther`, and `ancestorsHaveExecutePermissions` to `URI`. 3. Add `getParentURI` method when we construct the parent URI. ### Why are the changes needed? We should not use `uri.getPath()` when constructing the Path which will not retain information like scheme. This means that `statCache` is not really taking any effect. For example, if uri is "file:/foo.invalid.com:8080/tmp/testing", then ``` uri.getPath -> /foo.invalid.com:8080/tmp/testing uri.toString -> file:/foo.invalid.com:8080/tmp/testing ``` Please also see more details from JIRA [ticket](https://issues.apache.org/jira/browse/SPARK-44272). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Add additional UT to validate the FileStatus is cached as expected. Closes #41821 from shuwang21/fixcache. 
Lead-authored-by: Shu Wang Co-authored-by: Shu Wang Signed-off-by: Mridul Muralidharan gmail.com> (cherry picked from commit 0879a25c4c271dea6cd8f2a45e5c5e6e6743a962) Signed-off-by: Mridul Muralidharan --- .../yarn/ClientDistributedCacheManager.scala | 41 +++++++++----- .../ClientDistributedCacheManagerSuite.scala | 55 +++++++++++++++++++ 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 8add74428882b..6d50b5e4fd2da 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -68,7 +68,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { link: String, statCache: Map[URI, FileStatus], appMasterOnly: Boolean = false): Unit = { - val destStatus = statCache.getOrElse(destPath.toUri(), fs.getFileStatus(destPath)) + val destStatus = getFileStatus(fs, destPath.toUri, statCache) val amJarRsrc = Records.newRecord(classOf[LocalResource]) amJarRsrc.setType(resourceType) val visibility = getVisibility(conf, destPath.toUri(), statCache) @@ -119,46 +119,61 @@ private[spark] class ClientDistributedCacheManager() extends Logging { */ private def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { val fs = FileSystem.get(uri, conf) - val current = new Path(uri.getPath()) // the leaf level file should be readable by others - if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) { + if (!checkPermissionOfOther(fs, uri, FsAction.READ, statCache)) { return false } - ancestorsHaveExecutePermissions(fs, current.getParent(), statCache) + ancestorsHaveExecutePermissions(fs, getParentURI(uri), statCache) } /** - * Returns true if all 
ancestors of the specified path have the 'execute' + * Get the Parent URI of the given URI. Notes that the query & fragment of original URI will not + * be inherited when obtaining parent URI. + * + * @return the parent URI, null if the given uri is the root + */ + private[yarn] def getParentURI(uri: URI): URI = { + val path = new Path(uri.toString) + val parent = path.getParent() + if (parent == null) { + null + } else { + parent.toUri() + } + } + + /** + * Returns true if all ancestors of the specified uri have the 'execute' * permission set for all users (i.e. that other users can traverse - * the directory hierarchy to the given path) + * the directory hierarchy to the given uri) * @return true if all ancestors have the 'execute' permission set for all users */ private def ancestorsHaveExecutePermissions( fs: FileSystem, - path: Path, + uri: URI, statCache: Map[URI, FileStatus]): Boolean = { - var current = path + var current = uri while (current != null) { - // the subdirs in the path should have execute permissions for others + // the subdirs in the corresponding uri path should have execute permissions for others if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) { return false } - current = current.getParent() + current = getParentURI(current) } true } /** - * Checks for a given path whether the Other permissions on it + * Checks for a given URI whether the Other permissions on it * imply the permission in the passed FsAction * @return true if the path in the uri is visible to all, false otherwise */ private def checkPermissionOfOther( fs: FileSystem, - path: Path, + uri: URI, action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { - val status = getFileStatus(fs, path.toUri(), statCache) + val status = getFileStatus(fs, uri, statCache) val perms = status.getPermission() val otherAction = perms.getOtherAction() otherAction.implies(action) diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 996654f7415a9..4e8971cbfa05d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import java.net.URI +import scala.collection.mutable import scala.collection.mutable.HashMap import scala.collection.mutable.Map @@ -44,6 +45,60 @@ class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar } } + test("SPARK-44272: test addResource added FileStatus to statCache and getVisibility can read" + + " from statCache") { + val distMgr = new ClientDistributedCacheManager() { + override private[yarn] def getFileStatus(fs: FileSystem, uri: URI, + statCache: mutable.Map[URI, FileStatus]): FileStatus = { + statCache.getOrElseUpdate(uri, new FileStatus()) + } + } + val fs = mock[FileSystem] + val conf = new Configuration() + val destPathA = new Path("file:///foo.invalid.com:8080/tmp/A") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + distMgr.addResource(fs, conf, destPathA, localResources, LocalResourceType.FILE, "link", + statCache, false) + assert(statCache.size === 2) + assert(statCache.contains(destPathA.toUri)) + assert(statCache.contains(destPathA.getParent.toUri)) + + val destPathB = new Path("file:///foo.invalid.com:8080/tmp/B") + distMgr.addResource(fs, conf, destPathB, localResources, LocalResourceType.FILE, "link", + statCache, false) + assert(statCache.size === 3) + assert(statCache.contains(destPathB.toUri)) + + val destPathC = new Path("file:///foo.invalid.com:8080/root/C") + distMgr.addResource(fs, conf, destPathC, 
localResources, LocalResourceType.FILE, "link", + statCache, false) + assert(statCache.size === 5) + assert(statCache.contains(destPathC.toUri)) + assert(statCache.contains(destPathC.getParent.toUri)) + } + + test("SPARK-44272: test getParentURI") { + val distMgr = new ClientDistributedCacheManager() + val scheme = "file" + val userInfo = "user" + val host = "foo.com" + val port = 8080 + val path = "/tmp/testing" + val uri = new URI(scheme, userInfo, host, port, path, null, null) + val parentURI = distMgr.getParentURI(uri) + assert(uri.getScheme === parentURI.getScheme) + assert(uri.getUserInfo === parentURI.getUserInfo) + assert(uri.getHost === parentURI.getHost) + assert(uri.getPort === parentURI.getPort) + assert(new Path(uri.getPath).getParent.toString === parentURI.getPath) + + val rootPath = "/" + val parentRootURI = distMgr.getParentURI( + new URI(scheme, userInfo, host, port, rootPath, null, null)) + assert(parentRootURI === null) + } + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] From e6d3009265374ff2e6f431ca372f3a6f0f9554a9 Mon Sep 17 00:00:00 2001 From: Mathew Jacob Date: Wed, 19 Jul 2023 17:29:29 +0900 Subject: [PATCH 023/986] [SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed Made the DeepspeedTorchDistributor run() method use the _run() function as the backbone. It allows the user to run distributed training of a function with deepspeed easily. This adds the ability for the user to pass in a function as the train_object when calling DeepspeedTorchDistributor.run(). The user must have all necessary imports within the function itself, and the function must be picklable. An example use case can be found in the python file linked in the JIRA ticket. Notebook/file linked in the JIRA ticket. Formal e2e tests will come in future PR. 
- [ ] Add more e2e tests for both running a regular pytorch file and running a function for training - [ ] Write more documentation Closes #42067 from mathewjacob1002/add_func_deepspeed. Authored-by: Mathew Jacob Signed-off-by: Hyukjin Kwon (cherry picked from commit 392f8d80c8cc4823ea513e78d452bba7f1a7d76c) Signed-off-by: Hyukjin Kwon --- .../ml/deepspeed/deepspeed_distributor.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/deepspeed/deepspeed_distributor.py b/python/pyspark/ml/deepspeed/deepspeed_distributor.py index df1aac21e1f89..d6ae98de5e345 100644 --- a/python/pyspark/ml/deepspeed/deepspeed_distributor.py +++ b/python/pyspark/ml/deepspeed/deepspeed_distributor.py @@ -15,7 +15,6 @@ # limitations under the License. # import json -import os import sys import tempfile from typing import ( @@ -135,19 +134,6 @@ def _run_training_on_pytorch_file( def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]: # If the "train_object" is a string, then we assume it's a filepath. # Otherwise, we assume it's a function. 
- if isinstance(train_object, str): - if os.path.exists(train_object) is False: - raise FileNotFoundError(f"The path to training file {train_object} does not exist.") - framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file - else: - raise RuntimeError("Python training functions aren't supported as inputs at this time") - - if self.local_mode: - return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs) - return self._run_distributed_training( - framework_wrapper_fn, - train_object, - spark_dataframe=None, - *args, - **kwargs, # type:ignore[misc] + return self._run( + train_object, DeepspeedTorchDistributor._run_training_on_pytorch_file, *args, **kwargs ) From 3107c1e642b5efe7fd88329197d912f72f711c80 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 19 Jul 2023 18:02:27 +0900 Subject: [PATCH 024/986] [SPARK-44361][SQL][FOLLOW-UP] Remove unused variables and fix import statements ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/42024 that removes unused variables and fixes import statements (which should have been part of the whole refactoring). ### Why are the changes needed? To properly clean up. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests should cover Closes #42068 from HyukjinKwon/SPARK-44361-followup. 
Lead-authored-by: Hyukjin Kwon Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon (cherry picked from commit bca28f87ae12ffe3b49c78503af580b503f120ee) Signed-off-by: Hyukjin Kwon --- .../python/MapInBatchEvaluatorFactory.scala | 25 ++++++++----------- .../sql/execution/python/MapInBatchExec.scala | 6 ----- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala index efb063476a41e..1e15aa7f777bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ -import org.apache.spark.{ContextAwareIterator, TaskContext} -import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} -import org.apache.spark.api.python.{ChainedPythonFunctions} +import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, PartitionEvaluatorFactory, TaskContext} +import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -76,17 +75,15 @@ class MapInBatchEvaluatorFactory( val unsafeProj = UnsafeProjection.create(output, output) - columnarBatchIter - .flatMap { batch => - // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select - // the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - } - .map(unsafeProj) + 
columnarBatchIter.flatMap { batch => + // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select + // the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index 0703f57c33d38..368184934fa03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -39,14 +39,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { protected val isBarrier: Boolean - private val pythonFunction = func.asInstanceOf[PythonUDF].func - override def producedAttributes: AttributeSet = AttributeSet(output) - private val batchSize = conf.arrowMaxRecordsPerBatch - - private val largeVarTypes = conf.arrowUseLargeVarTypes - private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) override def outputPartitioning: Partitioning = child.outputPartitioning From 244d02be8e4535458d2c12be7b461aa1b1e497f6 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 19 Jul 2023 09:26:26 -0400 Subject: [PATCH 025/986] [SPARK-44396][CONNECT] Direct Arrow Deserialization ### What changes were proposed in this pull request? This PR adds direct arrow to user object deserialization to the Spark Connect Scala Client. ### Why are the changes needed? We want to decouple the scala client from catalyst. We need a way to encode user objects from and to Arrow. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests to `ArrowEncoderSuite`. 
Closes #42011 from hvanhovell/SPARK-44396. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell (cherry picked from commit 5939b75b5fe701cb63fedc64f57c9f0a15ef9202) Signed-off-by: Herman van Hovell --- connector/connect/client/jvm/pom.xml | 19 + .../client/arrow/ScalaCollectionUtils.scala | 38 ++ .../client/arrow/ScalaCollectionUtils.scala | 37 ++ .../sql/connect/client/SparkResult.scala | 230 +++++--- .../client/arrow/ArrowDeserializer.scala | 533 ++++++++++++++++++ .../client/arrow/ArrowEncoderUtils.scala | 3 + .../ConcatenatingArrowStreamReader.scala | 185 ++++++ .../apache/spark/sql/ClientE2ETestSuite.scala | 49 +- .../KeyValueGroupedDatasetE2ETestSuite.scala | 36 +- .../spark/sql/application/ReplE2ESuite.scala | 6 +- .../client/arrow/ArrowEncoderSuite.scala | 127 +++-- 11 files changed, 1085 insertions(+), 178 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala create mode 100644 connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 8a51bf65d6a88..0f6783cbd685b 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -140,6 +140,7 @@ + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes @@ -224,6 +225,24 @@ + + org.codehaus.mojo + build-helper-maven-plugin + + + add-sources + generate-sources + + add-source + + + + src/main/scala-${scala.binary.version} + + + + + \ No newline at end of file diff --git 
a/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala new file mode 100644 index 0000000000000..c2e01d974e0e4 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import scala.collection.generic.{GenericCompanion, GenMapFactory} +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion + +/** + * A couple of scala version specific collection utility functions. 
+ */ +private[arrow] object ScalaCollectionUtils { + def getIterableCompanion(tag: ClassTag[_]): GenericCompanion[Iterable] = { + ArrowDeserializers.resolveCompanion[GenericCompanion[Iterable]](tag) + } + def getMapCompanion(tag: ClassTag[_]): GenMapFactory[Map] = { + resolveCompanion[GenMapFactory[Map]](tag) + } + def wrap[T](array: AnyRef): mutable.WrappedArray[T] = { + mutable.WrappedArray.make(array) + } +} diff --git a/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala new file mode 100644 index 0000000000000..8a80e34162283 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client.arrow + +import scala.collection.{mutable, IterableFactory, MapFactory} +import scala.reflect.ClassTag + +import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion + +/** + * A couple of scala version specific collection utility functions. + */ +private[arrow] object ScalaCollectionUtils { + def getIterableCompanion(tag: ClassTag[_]): IterableFactory[Iterable] = { + ArrowDeserializers.resolveCompanion[IterableFactory[Iterable]](tag) + } + def getMapCompanion(tag: ClassTag[_]): MapFactory[Map] = { + resolveCompanion[MapFactory[Map]](tag) + } + def wrap[T](array: AnyRef): mutable.WrappedArray[T] = { + mutable.WrappedArray.make(array.asInstanceOf[Array[T]]) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index a727c86f70fc6..1cdc2035de60b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -16,53 +16,48 @@ */ package org.apache.spark.sql.connect.client -import java.util.Collections +import java.util.Objects -import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.FieldVector -import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch} +import org.apache.arrow.vector.types.pojo import org.apache.spark.connect.proto -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer -import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable} +import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} +import org.apache.spark.sql.connect.client.util.Cleanable import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} private[sql] class SparkResult[T]( responses: java.util.Iterator[proto.ExecutePlanResponse], allocator: BufferAllocator, encoder: AgnosticEncoder[T]) extends AutoCloseable - with Cleanable { + with Cleanable { self => private[this] var numRecords: Int = 0 private[this] var structType: StructType = _ - private[this] var boundEncoder: ExpressionEncoder[T] = _ - private[this] var nextBatchIndex: Int = 0 - private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch] - - private def createEncoder(schema: StructType): ExpressionEncoder[T] = { - val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]] - ExpressionEncoder(agnosticEncoder) - } + private[this] var arrowSchema: pojo.Schema = _ + private[this] var nextResultIndex: Int = 0 + private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])] /** * Update RowEncoder and recursively update the fields of the ProductEncoder if found. 
*/ - private def createEncoder(enc: AgnosticEncoder[_], dataType: DataType): AgnosticEncoder[_] = { + private def createEncoder[E]( + enc: AgnosticEncoder[E], + dataType: DataType): AgnosticEncoder[E] = { enc match { case UnboundRowEncoder => // Replace the row encoder with the encoder inferred from the schema. - RowEncoder.encoderFor(dataType.asInstanceOf[StructType]) + RowEncoder + .encoderFor(dataType.asInstanceOf[StructType]) + .asInstanceOf[AgnosticEncoder[E]] case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) => // Recursively continue updating the tuple product encoder val schema = dataType.asInstanceOf[StructType] @@ -76,53 +71,61 @@ private[sql] class SparkResult[T]( } } - private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = { - while (responses.hasNext) { + private def processResponses( + stopOnSchema: Boolean = false, + stopOnArrowSchema: Boolean = false, + stopOnFirstNonEmptyResponse: Boolean = false): Boolean = { + var nonEmpty = false + var stop = false + while (!stop && responses.hasNext) { val response = responses.next() if (response.hasSchema) { // The original schema should arrive before ArrowBatches. structType = DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType] - } else if (response.hasArrowBatch) { + stop |= stopOnSchema + } + if (response.hasArrowBatch) { val ipcStreamBytes = response.getArrowBatch.getData - val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator) - try { - val root = reader.getVectorSchemaRoot - if (structType == null) { - // If the schema is not available yet, fallback to the schema from Arrow. - structType = ArrowUtils.fromArrowSchema(root.getSchema) - } - // TODO: create encoders that directly operate on arrow vectors. 
- if (boundEncoder == null) { - boundEncoder = createEncoder(structType) - .resolveAndBind(DataTypeUtils.toAttributes(structType)) - } - while (reader.loadNextBatch()) { - val rowCount = root.getRowCount - if (rowCount > 0) { - val vectors = root.getFieldVectors.asScala - .map(v => new ArrowColumnVector(transferToNewVector(v))) - .toArray[ColumnVector] - idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount)) - nextBatchIndex += 1 - numRecords += rowCount - if (stopOnFirstNonEmptyResponse) { - return true - } - } + val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator) + if (arrowSchema == null) { + arrowSchema = reader.schema + stop |= stopOnArrowSchema + } else if (arrowSchema != reader.schema) { + throw new IllegalStateException( + s"""Schema Mismatch between expected and received schema: + |=== Expected Schema === + |$arrowSchema + |=== Received Schema === + |${reader.schema} + |""".stripMargin) + } + if (structType == null) { + // If the schema is not available yet, fallback to the arrow schema. + structType = ArrowUtils.fromArrowSchema(reader.schema) + } + var numRecordsInBatch = 0 + val messages = Seq.newBuilder[ArrowMessage] + while (reader.hasNext) { + val message = reader.next() + message match { + case batch: ArrowRecordBatch => + numRecordsInBatch += batch.getLength + case _ => } - } finally { - reader.close() + messages += message + } + // Skip the entire result if it is empty. 
+ if (numRecordsInBatch > 0) { + numRecords += numRecordsInBatch + resultMap.put(nextResultIndex, (reader.bytesRead, messages.result())) + nextResultIndex += 1 + nonEmpty |= true + stop |= stopOnFirstNonEmptyResponse } } } - false - } - - private def transferToNewVector(in: FieldVector): FieldVector = { - val pair = in.getTransferPair(allocator) - pair.transfer() - pair.getTo.asInstanceOf[FieldVector] + nonEmpty } /** @@ -130,7 +133,7 @@ private[sql] class SparkResult[T]( */ def length: Int = { // We need to process all responses to make sure numRecords is correct. - processResponses(stopOnFirstNonEmptyResponse = false) + processResponses() numRecords } @@ -139,7 +142,9 @@ private[sql] class SparkResult[T]( * the schema of the result. */ def schema: StructType = { - processResponses(stopOnFirstNonEmptyResponse = true) + if (structType == null) { + processResponses(stopOnSchema = true) + } structType } @@ -172,52 +177,93 @@ private[sql] class SparkResult[T]( private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = { new java.util.Iterator[T] with AutoCloseable { - private[this] var batchIndex: Int = -1 - private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator() - private[this] var deserializer: Deserializer[T] = _ + private[this] var iterator: CloseableIterator[T] = _ - override def hasNext: Boolean = { - if (iterator.hasNext) { - return true - } - - val nextBatchIndex = batchIndex + 1 - if (destructive) { - idxToBatches.remove(batchIndex).foreach(_.close()) + private def initialize(): Unit = { + if (iterator == null) { + iterator = new ArrowDeserializingIterator( + createEncoder(encoder, schema), + new ConcatenatingArrowStreamReader( + allocator, + Iterator.single(new ResultMessageIterator(destructive)), + destructive)) } + } - val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) { - processResponses(stopOnFirstNonEmptyResponse = true) - } else { - true - } - if (hasNextBatch) { - 
batchIndex = nextBatchIndex - iterator = idxToBatches(nextBatchIndex).rowIterator() - if (deserializer == null) { - deserializer = boundEncoder.createDeserializer() - } - } - hasNextBatch + override def hasNext: Boolean = { + initialize() + iterator.hasNext } override def next(): T = { - if (!hasNext) { - throw new NoSuchElementException - } - deserializer(iterator.next()) + initialize() + iterator.next() } - override def close(): Unit = SparkResult.this.close() + override def close(): Unit = { + if (iterator != null) { + iterator.close() + } + } } } /** * Close this result, freeing any underlying resources. */ - override def close(): Unit = { - idxToBatches.values.foreach(_.close()) + override def close(): Unit = cleaner.close() + + override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap) + + private class ResultMessageIterator(destructive: Boolean) extends AbstractMessageIterator { + private[this] var totalBytesRead = 0L + private[this] var nextResultIndex = 0 + private[this] var current: Iterator[ArrowMessage] = Iterator.empty + + override def bytesRead: Long = totalBytesRead + + override def schema: pojo.Schema = { + if (arrowSchema == null) { + // We need a schema to proceed. Spark Connect will always + // return a result (with a schema) even if the result is empty. 
+ processResponses(stopOnArrowSchema = true) + Objects.requireNonNull(arrowSchema) + } + arrowSchema + } + + override def hasNext: Boolean = { + if (current.hasNext) { + return true + } + val hasNextResult = if (!resultMap.contains(nextResultIndex)) { + self.processResponses(stopOnFirstNonEmptyResponse = true) + } else { + true + } + if (hasNextResult) { + val Some((sizeInBytes, messages)) = if (destructive) { + resultMap.remove(nextResultIndex) + } else { + resultMap.get(nextResultIndex) + } + totalBytesRead += sizeInBytes + current = messages.iterator + nextResultIndex += 1 + } + hasNextResult + } + + override def next(): ArrowMessage = { + if (!hasNext) { + throw new NoSuchElementException() + } + current.next() + } } +} - override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq) +private[client] class SparkResultCloseable(resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])]) + extends AutoCloseable { + override def close(): Unit = resultMap.values.foreach(_._2.foreach(_.close())) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala new file mode 100644 index 0000000000000..154866d699a34 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import java.io.{ByteArrayInputStream, IOException} +import java.lang.invoke.{MethodHandles, MethodType} +import java.lang.reflect.Modifier +import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} +import java.time._ +import java.util +import java.util.{List => JList, Locale, Map => JMap} + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.util.Text + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.types.Decimal + +/** + * Helper class for converting arrow batches into user objects. + */ +object ArrowDeserializers { + import ArrowEncoderUtils._ + + /** + * Create an Iterator of `T`. 
This iterator takes an Iterator of Arrow IPC Streams, and + * deserializes these streams into one or more instances of `T` + */ + def deserializeFromArrow[T]( + input: Iterator[Array[Byte]], + encoder: AgnosticEncoder[T], + allocator: BufferAllocator): CloseableIterator[T] = { + try { + val reader = new ConcatenatingArrowStreamReader( + allocator, + input.map(bytes => new MessageIterator(new ByteArrayInputStream(bytes), allocator)), + destructive = true) + new ArrowDeserializingIterator(encoder, reader) + } catch { + case _: IOException => + new EmptyDeserializingIterator(encoder) + } + } + + /** + * Create a deserializer of `T` on top of the given `root`. + */ + private[arrow] def deserializerFor[T]( + encoder: AgnosticEncoder[T], + root: VectorSchemaRoot): Deserializer[T] = { + val data: AnyRef = if (encoder.isStruct) { + root + } else { + // The input schema is allowed to have multiple columns, + // by convention we bind to the first one. + root.getVector(0) + } + deserializerFor(encoder, data).asInstanceOf[Deserializer[T]] + } + + private[arrow] def deserializerFor( + encoder: AgnosticEncoder[_], + data: AnyRef): Deserializer[Any] = { + (encoder, data) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => + new FieldDeserializer[Boolean, BitVector](v) { + def value(i: Int): Boolean = vector.get(i) != 0 + } + case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) => + new FieldDeserializer[Byte, TinyIntVector](v) { + def value(i: Int): Byte = vector.get(i) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) => + new FieldDeserializer[Short, SmallIntVector](v) { + def value(i: Int): Short = vector.get(i) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) => + new FieldDeserializer[Int, IntVector](v) { + def value(i: Int): Int = vector.get(i) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) => + new FieldDeserializer[Long, BigIntVector](v) { + def value(i: Int): Long = 
vector.get(i) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) => + new FieldDeserializer[Float, Float4Vector](v) { + def value(i: Int): Float = vector.get(i) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) => + new FieldDeserializer[Double, Float8Vector](v) { + def value(i: Int): Double = vector.get(i) + } + case (NullEncoder, v: NullVector) => + new FieldDeserializer[Any, NullVector](v) { + def value(i: Int): Any = null + } + case (StringEncoder, v: VarCharVector) => + new FieldDeserializer[String, VarCharVector](v) { + def value(i: Int): String = getString(vector, i) + } + case (JavaEnumEncoder(tag), v: VarCharVector) => + // It would be nice if we can get Enum.valueOf working... + val valueOf = methodLookup.findStatic( + tag.runtimeClass, + "valueOf", + MethodType.methodType(tag.runtimeClass, classOf[String])) + new FieldDeserializer[Enum[_], VarCharVector](v) { + def value(i: Int): Enum[_] = { + valueOf.invoke(getString(vector, i)).asInstanceOf[Enum[_]] + } + } + case (ScalaEnumEncoder(parent, _), v: VarCharVector) => + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(parent).module.asModule + val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] + new FieldDeserializer[Enumeration#Value, VarCharVector](v) { + def value(i: Int): Enumeration#Value = enumeration.withName(getString(vector, i)) + } + case (BinaryEncoder, v: VarBinaryVector) => + new FieldDeserializer[Array[Byte], VarBinaryVector](v) { + def value(i: Int): Array[Byte] = vector.get(i) + } + case (SparkDecimalEncoder(_), v: DecimalVector) => + new FieldDeserializer[Decimal, DecimalVector](v) { + def value(i: Int): Decimal = Decimal(vector.getObject(i)) + } + case (ScalaDecimalEncoder(_), v: DecimalVector) => + new FieldDeserializer[BigDecimal, DecimalVector](v) { + def value(i: Int): BigDecimal = BigDecimal(vector.getObject(i)) + } + case (JavaDecimalEncoder(_, _), v: DecimalVector) => + new 
FieldDeserializer[JBigDecimal, DecimalVector](v) { + def value(i: Int): JBigDecimal = vector.getObject(i) + } + case (ScalaBigIntEncoder, v: DecimalVector) => + new FieldDeserializer[BigInt, DecimalVector](v) { + def value(i: Int): BigInt = new BigInt(vector.getObject(i).toBigInteger) + } + case (JavaBigIntEncoder, v: DecimalVector) => + new FieldDeserializer[JBigInteger, DecimalVector](v) { + def value(i: Int): JBigInteger = vector.getObject(i).toBigInteger + } + case (DayTimeIntervalEncoder, v: DurationVector) => + new FieldDeserializer[Duration, DurationVector](v) { + def value(i: Int): Duration = vector.getObject(i) + } + case (YearMonthIntervalEncoder, v: IntervalYearVector) => + new FieldDeserializer[Period, IntervalYearVector](v) { + def value(i: Int): Period = vector.getObject(i).normalized() + } + case (DateEncoder(_), v: DateDayVector) => + new FieldDeserializer[java.sql.Date, DateDayVector](v) { + def value(i: Int): java.sql.Date = DateTimeUtils.toJavaDate(vector.get(i)) + } + case (LocalDateEncoder(_), v: DateDayVector) => + new FieldDeserializer[LocalDate, DateDayVector](v) { + def value(i: Int): LocalDate = DateTimeUtils.daysToLocalDate(vector.get(i)) + } + case (TimestampEncoder(_), v: TimeStampMicroTZVector) => + new FieldDeserializer[java.sql.Timestamp, TimeStampMicroTZVector](v) { + def value(i: Int): java.sql.Timestamp = DateTimeUtils.toJavaTimestamp(vector.get(i)) + } + case (InstantEncoder(_), v: TimeStampMicroTZVector) => + new FieldDeserializer[Instant, TimeStampMicroTZVector](v) { + def value(i: Int): Instant = DateTimeUtils.microsToInstant(vector.get(i)) + } + case (LocalDateTimeEncoder, v: TimeStampMicroVector) => + new FieldDeserializer[LocalDateTime, TimeStampMicroVector](v) { + def value(i: Int): LocalDateTime = DateTimeUtils.microsToLocalDateTime(vector.get(i)) + } + + case (OptionEncoder(value), v) => + val deserializer = deserializerFor(value, v) + new Deserializer[Any] { + override def get(i: Int): Any = Option(deserializer.get(i)) 
+ } + + case (ArrayEncoder(element, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector) + new FieldDeserializer[AnyRef, ListVector](v) { + def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag) + } + + case (IterableEncoder(tag, element, _, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector) + if (isSubClass(Classes.WRAPPED_ARRAY, tag)) { + // Wrapped array is a bit special because we need to use an array of the element type. + // Some parts of our codebase (unfortunately) rely on this for type inference on results. + new FieldDeserializer[mutable.WrappedArray[Any], ListVector](v) { + def value(i: Int): mutable.WrappedArray[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + ScalaCollectionUtils.wrap(array) + } + } + } else if (isSubClass(Classes.ITERABLE, tag)) { + val companion = ScalaCollectionUtils.getIterableCompanion(tag) + new FieldDeserializer[Iterable[Any], ListVector](v) { + def value(i: Int): Iterable[Any] = { + val builder = companion.newBuilder[Any] + loadListIntoBuilder(vector, i, deserializer, builder) + builder.result() + } + } + } else if (isSubClass(Classes.JLIST, tag)) { + val newInstance = resolveJavaListCreator(tag) + new FieldDeserializer[JList[Any], ListVector](v) { + def value(i: Int): JList[Any] = { + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + val list = newInstance(end - index) + while (index < end) { + list.add(deserializer.get(index)) + index += 1 + } + list + } + } + } else { + throw unsupportedCollectionType(tag.runtimeClass) + } + + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val keyDeserializer = deserializerFor(key, structVector.getChild(MapVector.KEY_NAME)) + val valueDeserializer = + deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME)) + if (isSubClass(Classes.MAP, tag)) { + val companion = 
ScalaCollectionUtils.getMapCompanion(tag) + new FieldDeserializer[Map[Any, Any], MapVector](v) { + def value(i: Int): Map[Any, Any] = { + val builder = companion.newBuilder[Any, Any] + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + builder.sizeHint(end - index) + while (index < end) { + builder += (keyDeserializer.get(index) -> valueDeserializer.get(index)) + index += 1 + } + builder.result() + } + } + } else if (isSubClass(Classes.JMAP, tag)) { + val newInstance = resolveJavaMapCreator(tag) + new FieldDeserializer[JMap[Any, Any], MapVector](v) { + def value(i: Int): JMap[Any, Any] = { + val map = newInstance() + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + while (index < end) { + map.put(keyDeserializer.get(index), valueDeserializer.get(index)) + index += 1 + } + map + } + } + } else { + throw unsupportedCollectionType(tag.runtimeClass) + } + + case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) => + // We should try to make this work with MethodHandles. 
+ val Some(constructor) = + ScalaReflection.findConstructor(tag.runtimeClass, fields.map(_.enc.clsTag.runtimeClass)) + val deserializers = if (isTuple(tag.runtimeClass)) { + fields.zip(vectors).map { case (field, vector) => + deserializerFor(field.enc, vector) + } + } else { + val lookup = createFieldLookup(vectors) + fields.map { field => + deserializerFor(field.enc, lookup(field.name)) + } + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef])) + } + } + + case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => + val lookup = createFieldLookup(vectors) + val deserializers = fields.toArray.map { field => + deserializerFor(field.enc, lookup(field.name)) + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val values = deserializers.map(_.get(i)) + new GenericRowWithSchema(values, r.schema) + } + } + + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + val constructor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + val lookup = createFieldLookup(vectors) + val setters = fields.map { field => + val vector = lookup(field.name) + val deserializer = deserializerFor(field.enc, vector) + val setter = methodLookup.findVirtual( + tag.runtimeClass, + field.writeMethod.get, + MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) + (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val instance = constructor.invoke() + setters.foreach(_(instance, i)) + instance + } + } + + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType) + + case _ => + throw new RuntimeException( + s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") + } + } + + private val methodLookup = MethodHandles.lookup() + + /** + * 
Resolve the companion object for a scala class. In our particular case the class we pass in + * is a Scala collection. We use the companion to create a builder for that collection. + */ + private[arrow] def resolveCompanion[T](tag: ClassTag[_]): T = { + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(tag.runtimeClass).companion.asModule + mirror.reflectModule(module).instance.asInstanceOf[T] + } + + /** + * Create a function that creates a [[util.List]] instance. The int parameter of the creator + * function is a size hint. + * + * If the [[ClassTag]] `tag` points to an interface instead of a concrete class we try to use + * [[util.ArrayList]]. For concrete classes we try to use a constructor that takes a single + * [[Int]] argument, it is assumed this is a size hint. If no such constructor exists we + * fallback to a no-args constructor. + */ + private def resolveJavaListCreator(tag: ClassTag[_]): Int => JList[Any] = { + val cls = tag.runtimeClass + val modifiers = cls.getModifiers + if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) { + // Abstract class or interface; we try to use ArrayList. + if (!cls.isAssignableFrom(classOf[util.ArrayList[_]])) { + unsupportedCollectionType(cls) + } + (size: Int) => new util.ArrayList[Any](size) + } else { + try { + // Try to use a constructor that (hopefully) takes a size argument. + val ctor = methodLookup.findConstructor( + tag.runtimeClass, + MethodType.methodType(classOf[Unit], Integer.TYPE)) + size => ctor.invoke(size).asInstanceOf[JList[Any]] + } catch { + case _: java.lang.NoSuchMethodException => + // Use a no-args constructor. + val ctor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + _ => ctor.invoke().asInstanceOf[JList[Any]] + } + } + } + + /** + * Create a function that creates a [[util.Map]] instance. 
+ * + * If the [[ClassTag]] `tag` points to an interface instead of a concrete class we try to use + * [[util.HashMap]]. For concrete classes we try to use a no-args constructor. + */ + private def resolveJavaMapCreator(tag: ClassTag[_]): () => JMap[Any, Any] = { + val cls = tag.runtimeClass + val modifiers = cls.getModifiers + if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) { + // Abstract class or interface; we try to use HashMap. + if (!cls.isAssignableFrom(classOf[java.util.HashMap[_, _]])) { + unsupportedCollectionType(cls) + } + () => new util.HashMap[Any, Any]() + } else { + // Use a no-args constructor. + val ctor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + () => ctor.invoke().asInstanceOf[JMap[Any, Any]] + } + } + + /** + * Create a function that can lookup one [[FieldVector vectors]] in `fields` by name. This + * lookup is case insensitive. If the schema contains fields with duplicate (with + * case-insensitive resolution) names an exception is thrown. The returned function will throw + * an exception when no column can be found for a name. + * + * A small note on the binding process in general. Over complete schemas are currently allowed, + * meaning that the data can have more column than the encoder. In this the over complete + * (unbound) columns are ignored. 
+ */ + private def createFieldLookup(fields: Seq[FieldVector]): String => FieldVector = { + def toKey(k: String): String = k.toLowerCase(Locale.ROOT) + val lookup = mutable.Map.empty[String, FieldVector] + fields.foreach { field => + val key = toKey(field.getName) + val old = lookup.put(key, field) + if (old.isDefined) { + throw QueryCompilationErrors.ambiguousColumnOrFieldError( + field.getName :: Nil, + fields.count(f => toKey(f.getName) == key)) + } + } + name => { + lookup.getOrElse(toKey(name), throw QueryCompilationErrors.columnNotFoundError(name)) + } + } + + private def isTuple(cls: Class[_]): Boolean = cls.getName.startsWith("scala.Tuple") + + private def getString(v: VarCharVector, i: Int): String = { + // This is currently a bit heavy on allocations: + // - byte array created in VarCharVector.get + // - CharBuffer created CharSetEncoder + // - char array in String + // By using direct buffers and reusing the char buffer + // we could get rid of the first two allocations. + Text.decode(v.get(i)) + } + + private def loadListIntoBuilder( + v: ListVector, + i: Int, + deserializer: Deserializer[Any], + builder: mutable.Builder[Any, _]): Unit = { + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + builder.sizeHint(end - index) + while (index < end) { + builder += deserializer.get(index) + index += 1 + } + } + + private def getArray(v: ListVector, i: Int, deserializer: Deserializer[Any])(implicit + tag: ClassTag[Any]): AnyRef = { + val builder = mutable.ArrayBuilder.make[Any] + loadListIntoBuilder(v, i, deserializer, builder) + builder.result() + } + + abstract class Deserializer[+E] { + def get(i: Int): E + } + + abstract class FieldDeserializer[E, V <: FieldVector](val vector: V) extends Deserializer[E] { + def value(i: Int): E + def isNull(i: Int): Boolean = vector.isNull(i) + override def get(i: Int): E = { + if (!isNull(i)) { + value(i) + } else { + null.asInstanceOf[E] + } + } + } + + abstract class StructFieldSerializer[E](v: 
StructVector) + extends FieldDeserializer[E, StructVector](v) { + override def isNull(i: Int): Boolean = vector != null && vector.isNull(i) + } +} + +class EmptyDeserializingIterator[E](val encoder: AgnosticEncoder[E]) + extends CloseableIterator[E] { + override def close(): Unit = () + override def hasNext: Boolean = false + override def next(): E = throw new NoSuchElementException() +} + +class ArrowDeserializingIterator[E]( + val encoder: AgnosticEncoder[E], + private[this] val reader: ArrowReader) + extends CloseableIterator[E] { + private[this] var index = 0 + private[this] val root = reader.getVectorSchemaRoot + private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, root) + + override def hasNext: Boolean = { + if (index >= root.getRowCount) { + if (reader.loadNextBatch()) { + index = 0 + } + } + index < root.getRowCount + } + + override def next(): E = { + if (!hasNext) { + throw new NoSuchElementException() + } + val result = deserializer.get(index) + index += 1 + result + } + + override def close(): Unit = reader.close() +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala index f6b140bae557b..ed27336985416 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -24,8 +24,11 @@ import org.apache.arrow.vector.complex.StructVector private[arrow] object ArrowEncoderUtils { object Classes { + val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]] val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] + val MAP: Class[_] = classOf[scala.collection.Map[_, _]] val JLIST: Class[_] = classOf[java.util.List[_]] + val JMAP: Class[_] = 
classOf[java.util.Map[_, _]] } def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala new file mode 100644 index 0000000000000..90963c831c252 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import java.io.{InputStream, IOException} +import java.nio.channels.Channels + +import org.apache.arrow.flatbuf.MessageHeader +import org.apache.arrow.memory.{ArrowBuf, BufferAllocator} +import org.apache.arrow.vector.ipc.{ArrowReader, ReadChannel} +import org.apache.arrow.vector.ipc.message.{ArrowDictionaryBatch, ArrowMessage, ArrowRecordBatch, MessageChannelReader, MessageResult, MessageSerializer} +import org.apache.arrow.vector.types.pojo.Schema + +/** + * An [[ArrowReader]] that concatenates multiple [[MessageIterator]]s into a single stream. 
Each + * iterator represents a single IPC stream. The concatenated streams all must have the same + * schema. If the schema is different an exception is thrown. + * + * In some cases we want to retain the messages (see `SparkResult`). Normally a stream reader + * closes its messages when it consumes them. In order to prevent that from happening in + * non-destructive mode we clone the messages before passing them to the reading logic. + */ +class ConcatenatingArrowStreamReader( + allocator: BufferAllocator, + input: Iterator[AbstractMessageIterator], + destructive: Boolean) + extends ArrowReader(allocator) { + + private[this] var totalBytesRead: Long = 0 + private[this] var current: AbstractMessageIterator = _ + + override protected def readSchema(): Schema = { + // readSchema() should only be called once during initialization. + assert(current == null) + if (!input.hasNext) { + // ArrowStreamReader throws the same exception. + throw new IOException("Unexpected end of input. Missing schema.") + } + current = input.next() + current.schema + } + + private def nextMessage(): ArrowMessage = { + // readSchema() should have been invoked at this point so 'current' should be initialized. + assert(current != null) + // Try to find a non-empty message iterator. + while (!current.hasNext && input.hasNext) { + totalBytesRead += current.bytesRead + current = input.next() + if (current.schema != getVectorSchemaRoot.getSchema) { + throw new IllegalStateException() + } + } + if (current.hasNext) { + current.next() + } else { + null + } + } + + override def loadNextBatch(): Boolean = { + // Keep looping until we load a non-empty batch or until we exhaust the input. 
+ var message = nextMessage() + while (message != null) { + message match { + case rb: ArrowRecordBatch => + loadRecordBatch(cloneIfNonDestructive(rb)) + if (getVectorSchemaRoot.getRowCount > 0) { + return true + } + case db: ArrowDictionaryBatch => + loadDictionary(cloneIfNonDestructive(db)) + } + message = nextMessage() + } + false + } + + private def cloneIfNonDestructive(batch: ArrowRecordBatch): ArrowRecordBatch = { + if (destructive) { + return batch + } + cloneRecordBatch(batch) + } + + private def cloneIfNonDestructive(batch: ArrowDictionaryBatch): ArrowDictionaryBatch = { + if (destructive) { + return batch + } + new ArrowDictionaryBatch( + batch.getDictionaryId, + cloneRecordBatch(batch.getDictionary), + batch.isDelta) + } + + private def cloneRecordBatch(batch: ArrowRecordBatch): ArrowRecordBatch = { + new ArrowRecordBatch( + batch.getLength, + batch.getNodes, + batch.getBuffers, + batch.getBodyCompression, + true, + true) + } + + override def bytesRead(): Long = { + if (current != null) { + totalBytesRead + current.bytesRead + } else { + 0 + } + } + + override def closeReadSource(): Unit = () +} + +trait AbstractMessageIterator extends Iterator[ArrowMessage] { + def schema: Schema + def bytesRead: Long +} + +/** + * Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a + * valid IPC stream as its input, otherwise construction will fail. + */ +class MessageIterator(input: InputStream, allocator: BufferAllocator) + extends AbstractMessageIterator { + private[this] val in = new ReadChannel(Channels.newChannel(input)) + private[this] val reader = new MessageChannelReader(in, allocator) + private[this] var result: MessageResult = _ + + // Eagerly read the schema. + val schema: Schema = { + val result = reader.readNext() + if (result == null) { + throw new IOException("Unexpected end of input. 
Missing schema.") + } + MessageSerializer.deserializeSchema(result.getMessage) + } + + override def bytesRead: Long = reader.bytesRead() + + override def hasNext: Boolean = { + if (result == null) { + result = reader.readNext() + } + result != null + } + + override def next(): ArrowMessage = { + if (!hasNext) { + throw new NoSuchElementException() + } + val message = result.getMessage.headerType() match { + case MessageHeader.RecordBatch => + MessageSerializer.deserializeRecordBatch(result.getMessage, bodyBuffer(result)) + case MessageHeader.DictionaryBatch => + MessageSerializer.deserializeDictionaryBatch(result.getMessage, bodyBuffer(result)) + } + result = null + message + } + + private def bodyBuffer(result: MessageResult): ArrowBuf = { + var buffer = result.getBodyBuffer + if (buffer == null) { + buffer = allocator.getEmpty + } + buffer + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 73c04389c0597..07dd2a96bd8f7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { @@ -571,7 +570,8 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM (col("id") / lit(10.0d)).as("b"), col("id"), lit("world").as("d"), - (col("id") % 2).cast("int").as("a")) + // TODO SPARK-44449 make this int again when upcasting is in. 
+ (col("id") % 2).cast("double").as("a")) private def validateMyTypeResult(result: Array[MyType]): Unit = { result.zipWithIndex.foreach { case (MyType(id, a, b), i) => @@ -818,10 +818,11 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM } test("toJSON") { + // TODO SPARK-44449 make this int again when upcasting is in. val expected = Array( - """{"b":0.0,"id":0,"d":"world","a":0}""", - """{"b":0.1,"id":1,"d":"world","a":1}""", - """{"b":0.2,"id":2,"d":"world","a":0}""") + """{"b":0.0,"id":0,"d":"world","a":0.0}""", + """{"b":0.1,"id":1,"d":"world","a":1.0}""", + """{"b":0.2,"id":2,"d":"world","a":0.0}""") val result = spark .range(3) .select(generateMyTypeColumns: _*) @@ -893,14 +894,12 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM test("Dataset result destructive iterator") { // Helper methods for accessing private field `idxToBatches` from SparkResult - val _idxToBatches = - PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches")) + val getResultMap = + PrivateMethod[mutable.Map[Int, Any]](Symbol("resultMap")) - def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = { - val idxToBatches = result invokePrivate _idxToBatches() - - // Sort by key to get stable results. 
- idxToBatches.toSeq.sortBy(_._1).map(_._2) + def assertResultsMapEmpty(result: SparkResult[_]): Unit = { + val resultMap = result invokePrivate getResultMap() + assert(resultMap.isEmpty) } val df = spark @@ -911,25 +910,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM try { // build and verify the destructive iterator val iterator = result.destructiveIterator - // batches is empty before traversing the result iterator - assert(getColumnarBatches(result).isEmpty) - var previousBatch: ColumnarBatch = null - val buffer = mutable.Buffer.empty[Long] + // resultMap Map is empty before traversing the result iterator + assertResultsMapEmpty(result) + val buffer = mutable.Set.empty[Long] while (iterator.hasNext) { - // always having 1 batch, since a columnar batch will be removed and closed after - // its data got consumed. - val batches = getColumnarBatches(result) - assert(batches.size === 1) - assert(batches.head != previousBatch) - previousBatch = batches.head - - buffer.append(iterator.next()) + // resultMap is empty during iteration because results get removed immediately on access. + assertResultsMapEmpty(result) + buffer += iterator.next() } - // Batches should be closed and removed after traversing all the records. - assert(getColumnarBatches(result).isEmpty) + // resultMap Map is empty afterward because all results have been removed. 
+ assertResultsMapEmpty(result) - val expectedResult = Seq(6L, 7L, 8L) - assert(buffer.size === 3 && expectedResult.forall(buffer.contains)) + val expectedResult = Set(6L, 7L, 8L) + assert(buffer.size === 3 && expectedResult == buffer) } finally { result.close() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index e15069f2d9e96..ab3e13da53178 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -68,10 +68,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("keyAs - keys") { + // TODO SPARK-44449 make this long again when upcasting is in. // It is okay to cast from Long to Double, but not Long to Int. val values = spark .range(10) - .groupByKey(v => v % 2) + .groupByKey(v => (v % 2).toDouble) .keyAs[Double] .keys .collectAsList() @@ -232,9 +233,10 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("agg, keyAs") { + // TODO SPARK-44449 make this long again when upcasting is in. val ds = spark .range(10) - .groupByKey(v => v % 2) + .groupByKey(v => (v % 2).toDouble) .keyAs[Double] .agg(count("*")) @@ -244,7 +246,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { test("typed aggregation: expr") { val session: SparkSession = spark import session.implicits._ - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. 
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long]), @@ -254,7 +257,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), @@ -264,7 +268,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), @@ -274,7 +279,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -289,7 +295,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. 
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -305,7 +312,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -322,7 +330,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -340,7 +349,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -473,9 +483,9 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) .toDF("key", "seq", "value") val grouped = ds.groupBy($"value").as[String, (String, Int, Int)] - val keys = grouped.keyAs[String].keys.sort($"value") - - checkDataset(keys, "1", "2", "10", "20") + // TODO SPARK-44449 make this string again when upcasting is in. 
+ val keys = grouped.keyAs[Int].keys.sort($"value") + checkDataset(keys, 1, 2, 10, 20) } test("flatMapGroupsWithState") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 58758a1384031..800ce43a60df0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -208,8 +208,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } test("UDF Registration") { + // TODO SPARK-44449 make this long again when upcasting is in. val input = """ - |class A(x: Int) { def get = x * 100 } + |class A(x: Int) { def get: Long = x * 100 } |val myUdf = udf((x: Int) => new A(x).get) |spark.udf.register("dummyUdf", myUdf) |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect() @@ -219,8 +220,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } test("UDF closure registration") { + // TODO SPARK-44449 make this int again when upcasting is in. 
val input = """ - |class A(x: Int) { def get = x * 15 } + |class A(x: Int) { def get: Long = x * 15 } |spark.udf.register("directUdf", (x: Int) => new A(x).get) |spark.sql("select directUdf(id) from range(5)").as[Long].collect() """.stripMargin diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 0c327484e477d..16eec3eee3110 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -21,24 +21,19 @@ import java.util import java.util.{Collections, Objects} import scala.beans.BeanProperty -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.classTag -import scala.util.control.NonFatal -import com.google.protobuf.ByteString import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.connect.proto -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} -import org.apache.spark.sql.connect.client.SparkResult import 
org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType} @@ -96,15 +91,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } val resultIterator = - try { - deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator) - } catch { - case NonFatal(e) => - arrowIterator.close() - serializerAllocator.close() - deserializerAllocator.close() - throw e - } + ArrowDeserializers.deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator) new CloseableIterator[T] { override def close(): Unit = { arrowIterator.close() @@ -117,25 +104,6 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } - // Temporary hack until we merge the deserializer. - private def deserializeFromArrow[E]( - batches: Iterator[Array[Byte]], - encoder: AgnosticEncoder[E], - allocator: BufferAllocator): CloseableIterator[E] = { - val responses = batches.map { batch => - val builder = proto.ExecutePlanResponse.newBuilder() - builder.getArrowBatchBuilder.setData(ByteString.copyFrom(batch)) - builder.build() - } - val result = new SparkResult[E](responses.asJava, allocator, encoder) - new CloseableIterator[E] { - private val itr = result.iterator - override def close(): Unit = itr.close() - override def hasNext: Boolean = itr.hasNext - override def next(): E = itr.next() - } - } - private def roundTripAndCheck[T]( encoder: AgnosticEncoder[T], toInputIterator: () => Iterator[Any], @@ -246,6 +214,15 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { assert(inspector.sizeInBytes > 0) } + test("deserializing empty iterator") { + withAllocator { allocator => + val iterator = + ArrowDeserializers.deserializeFromArrow(Iterator.empty, singleIntEncoder, allocator) + assert(iterator.isEmpty) + 
assert(allocator.getAllocatedMemory == 0) + } + } + test("single batch") { val inspector = new CountingBatchInspector roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () => @@ -533,15 +510,22 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { val maybeNull = MaybeNull(11) Iterator.tabulate(100) { i => val bean = new JavaMapData - bean.setDummyToDoubleListMap(maybeNull { - val map = new util.HashMap[DummyBean, java.util.List[java.lang.Double]] - (0 until (i % 5)).foreach { j => - val dummy = new DummyBean - dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + bean.setMetricMap(maybeNull { + val map = new util.HashMap[String, util.List[java.lang.Double]] + (0 until (i % 20)).foreach { i => val values = Array.tabulate(i % 40) { j => Double.box(j.toDouble) } - map.put(dummy, maybeNull(util.Arrays.asList(values: _*))) + map.put("k" + i, maybeNull(util.Arrays.asList(values: _*))) + } + map + }) + bean.setDummyToStringMap(maybeNull { + val map = new util.HashMap[DummyBean, String] + (0 until (i % 5)).foreach { j => + val dummy = new DummyBean + dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + map.put(dummy, maybeNull("s" + i + "v" + j)) } map }) @@ -675,6 +659,57 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { .add("Ca", "array") .add("Cb", "binary"))) + test("bind to schema") { + // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly + // different field order, and uses different cased names in a couple of places. 
+ withAllocator { allocator => + val input = Row( + 887, + "foo", + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f), + Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false))) + val expected = Row( + "foo", + Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) + val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) + val result = + ArrowDeserializers.deserializeFromArrow(arrowBatches, narrowSchemaEncoder, allocator) + val actual = result.next() + assert(result.isEmpty) + assert(expected === actual) + result.close() + arrowBatches.close() + } + } + + test("unknown field") { + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow(arrowBatches, wideSchemaEncoder, allocator) + } + arrowBatches.close() + } + } + + test("duplicate fields") { + val duplicateSchemaEncoder = toRowEncoder( + new StructType() + .add("foO", "string") + .add("Foo", "string")) + val fooSchemaEncoder = toRowEncoder( + new StructType() + .add("foo", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow(arrowBatches, fooSchemaEncoder, allocator) + } + arrowBatches.close() + } + } + /* ******************************************************************** * * Arrow serialization/deserialization specific errors * ******************************************************************** */ @@ -833,17 +868,23 @@ case class MapData(intStringMap: Map[Int, String], metricMap: Map[String, Array[ class JavaMapData { @scala.beans.BeanProperty - var dummyToDoubleListMap: java.util.Map[DummyBean, java.util.List[java.lang.Double]] = _ + var dummyToStringMap: java.util.Map[DummyBean, String] = _ + + 
@scala.beans.BeanProperty + var metricMap: java.util.HashMap[String, java.util.List[java.lang.Double]] = _ def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData] override def equals(other: Any): Boolean = other match { case that: JavaMapData if that canEqual this => - dummyToDoubleListMap == that.dummyToDoubleListMap + dummyToStringMap == that.dummyToStringMap && + metricMap == that.metricMap case _ => false } - override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap) + override def hashCode(): Int = { + java.util.Arrays.deepHashCode(Array(dummyToStringMap, metricMap)) + } } class DummyBean { From b20955003f9566bc1caeb625c0eb51402e5eab94 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Thu, 20 Jul 2023 08:59:29 +0900 Subject: [PATCH 026/986] [SPARK-44278][CONNECT] Implement a GRPC server interceptor that cleans up thread local properties ### What changes were proposed in this pull request? This PR implements a GRPC server interceptor that cleans up thread local properties (e.g. SparkContext local properties) for the Spark Connect service. ### Why are the changes needed? To prevent thread-local variables from leaking from one request to the next, ensuring proper isolation between requests. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #41831 from heyihong/SPARK-44278.
Authored-by: Yihong He Signed-off-by: Hyukjin Kwon (cherry picked from commit fa9725fac0c0c0aca5d86f2a5f0a80f821d5d828) Signed-off-by: Hyukjin Kwon --- .../LocalPropertiesCleanupInterceptor.scala | 51 +++++++++++++++++++ .../service/InterceptorRegistrySuite.scala | 10 ++++ 2 files changed, 61 insertions(+) create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalPropertiesCleanupInterceptor.scala diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalPropertiesCleanupInterceptor.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalPropertiesCleanupInterceptor.scala new file mode 100644 index 0000000000000..1d9acc4d75e44 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/LocalPropertiesCleanupInterceptor.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor} +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener + +import org.apache.spark.SparkContext + +/** + * Interceptor for cleaning up local properties in the SparkContext after gRPC server calls. + */ +class LocalPropertiesCleanupInterceptor extends ServerInterceptor { + + override def interceptCall[ReqT, RespT]( + serverCall: ServerCall[ReqT, RespT], + metadata: Metadata, + serverCallHandler: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + new SimpleForwardingServerCallListener[ReqT]( + serverCallHandler.startCall(serverCall, metadata)) { + override def onComplete(): Unit = { + cleanupLocalProperties() + super.onComplete() + } + + override def onCancel(): Unit = { + cleanupLocalProperties() + super.onCancel() + } + + private def cleanupLocalProperties(): Unit = { + SparkContext.getActive.foreach(_.getLocalProperties.clear()) + } + } + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala index 7f85966f0a7b6..33f6627ee0e10 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala @@ -184,4 +184,14 @@ class InterceptorRegistrySuite extends SharedSparkSession { assert(interceptors.head.isInstanceOf[LoggingInterceptor]) } } + + test("LocalPropertiesCleanupInterceptor initializes when configured in spark conf") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.LocalPropertiesCleanupInterceptor") { + val interceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors() + 
assert(interceptors.size == 1) + assert(interceptors.head.isInstanceOf[LocalPropertiesCleanupInterceptor]) + } + } } From 5511a0bd1d354a99e77de4ef132dfdd1ed6bd762 Mon Sep 17 00:00:00 2001 From: Jack Chen Date: Thu, 20 Jul 2023 11:09:50 +0800 Subject: [PATCH 027/986] [SPARK-44431][SQL] Fix behavior of null IN (empty list) in optimization rules ### What changes were proposed in this pull request? `null IN (empty list)` incorrectly evaluates to null, when it should evaluate to false. (It should be false because `a IN (b1, b2)` is defined as `a = b1 OR a = b2`, and an empty IN list is treated as an empty OR, which is false. This is specified by ANSI SQL.) Many places in Spark execution (In, InSet, InSubquery) and optimization (OptimizeIn, NullPropagation) implemented this wrong behavior. This is a longstanding correctness issue which has existed since null support for IN expressions was first added to Spark. This PR fixes the optimization rules OptimizeIn and NullPropagation, which followed the preexisting, incorrect execution behavior. The execution fixes will be in the next PR. The new behavior is guarded by a flag (`spark.sql.legacy.nullInEmptyListBehavior`), which can be used to revert to the legacy behavior if needed. The flag currently defaults to the legacy behavior (i.e. the new behavior is disabled) until all of the fix PRs are complete. See [this doc](https://docs.google.com/document/d/1k8AY8oyT-GI04SnP7eXttPDnDj-Ek-c3luF2zL6DPNU/edit) for more information. ### Why are the changes needed? Fix wrong SQL semantics ### Does this PR introduce _any_ user-facing change? Not yet, but will fix wrong SQL semantics when enabled ### How was this patch tested? Added unit tests and SQL tests. Closes #42007 from jchen5/null-in-empty-opt.
Authored-by: Jack Chen Signed-off-by: Wenchen Fan (cherry picked from commit db357edb7b21b12c9721a86985a3cc92fcd32bf3) Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 3 + .../sql/catalyst/optimizer/expressions.scala | 32 +- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../catalyst/optimizer/OptimizeInSuite.scala | 174 +++++++--- .../optimizer/ReplaceOperatorSuite.scala | 26 +- .../in-subquery/in-null-semantics.sql.out | 304 ++++++++++++++++++ .../in-subquery/in-null-semantics.sql | 57 ++++ .../in-subquery/in-null-semantics.sql.out | 239 ++++++++++++++ .../org/apache/spark/sql/EmptyInSuite.scala | 63 ++++ 9 files changed, 857 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-null-semantics.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fd2ea96a2967f..95cf3aee16f3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -96,6 +96,9 @@ abstract class Optimizer(catalogManager: CatalogManager) OptimizeRepartition, TransposeWindow, NullPropagation, + // NullPropagation may introduce Exists subqueries, so RewriteNonCorrelatedExists must run + // after. 
+ RewriteNonCorrelatedExists, NullDownPropagation, ConstantPropagation, FoldablePropagation, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d756a2dcb744..8cb560199c069 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, TreeNodeTag} import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -283,9 +284,13 @@ object OptimizeIn extends Rule[LogicalPlan] { _.containsPattern(IN), ruleId) { case q: LogicalPlan => q.transformExpressionsDownWithPruning(_.containsPattern(IN), ruleId) { case In(v, list) if list.isEmpty => - // When v is not nullable, the following expression will be optimized - // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) + if (!SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR)) { + FalseLiteral + } else { + // Incorrect legacy behavior optimizes to null if the left side is null, and otherwise + // to false. + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) + } case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.length == 1 @@ -841,9 +846,24 @@ object NullPropagation extends Rule[LogicalPlan] { } } - // If the value expression is NULL then transform the In expression to null literal. 
- case In(Literal(null, _), _) => Literal.create(null, BooleanType) - case InSubquery(Seq(Literal(null, _)), _) => Literal.create(null, BooleanType) + // If the list is empty, transform the In expression to false literal. + case In(_, list) + if list.isEmpty && !SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR) => + Literal.create(false, BooleanType) + // If the value expression is NULL (and the list is non-empty), then transform the + // In expression to null literal. + // If the legacy flag is set, then it becomes null even if the list is empty (which is + // incorrect legacy behavior) + case In(Literal(null, _), list) + if list.nonEmpty || SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR) + => Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), _) + if SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR) => + Literal.create(null, BooleanType) + case InSubquery(Seq(Literal(null, _)), ListQuery(sub, _, _, _, conditions, _)) + if !SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR) + && conditions.isEmpty => + If(Exists(sub), Literal(null, BooleanType), FalseLiteral) // Non-leaf NullIntolerant expressions will return null, if at least one of its children is // a null literal. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4d1d4b45fa303..00bb6f77ef339 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4289,6 +4289,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR = + buildConf("spark.sql.legacy.nullInEmptyListBehavior") + .internal() + .doc("When set to true, restores the legacy incorrect behavior of IN expressions for " + + "NULL values IN an empty list (including IN subqueries and literal IN lists): " + + "`null IN (empty list)` should evaluate to false, but sometimes (not always) " + + "incorrectly evaluates to null in the legacy behavior.") + .version("3.5.0") + .booleanConf + .createWithDefault(true) + val ERROR_MESSAGE_FORMAT = buildConf("spark.sql.error.messageFormat") .doc("When PRETTY, the error message consists of textual representation of error class, " + "message and query context. 
The MINIMAL and STANDARD formats are pretty JSON formats where " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 7f377d18e9def..7418128dd48a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD +import org.apache.spark.sql.internal.SQLConf.{LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR, OPTIMIZER_INSET_CONVERSION_THRESHOLD} import org.apache.spark.sql.types._ class OptimizeInSuite extends PlanTest { @@ -121,19 +121,43 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("OptimizedIn test: NULL IN (subquery) gets transformed to Filter(null)") { - val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) - val originalQuery = - testRelation - .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) - .analyze + test("OptimizedIn test: Legacy behavior: " + + "NULL IN (subquery) gets transformed to Filter(null)") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "true") { + val subquery = ListQuery(testRelation.select(UnresolvedAttribute("a"))) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), subquery)) + .analyze - val optimized 
= Optimize.execute(originalQuery.analyze) - val correctAnswer = - testRelation - .where(Literal.create(null, BooleanType)) - .analyze - comparePlans(optimized, correctAnswer) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(null, BooleanType)) + .analyze + comparePlans(optimized, correctAnswer) + } + } + + test("OptimizedIn test: NULL IN (subquery) gets transformed to " + + "If(Exists(subquery), null, false)") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "false") { + val subquery = testRelation.select(UnresolvedAttribute("a")) + val originalQuery = + testRelation + .where(InSubquery(Seq(Literal.create(null, NullType)), ListQuery(subquery))) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + // Our simplified Optimize results in an extra redundant Project. This gets collapsed in + // the full optimizer. + val correctAnswer = + testRelation + .where(If(Exists(Project(Seq(UnresolvedAttribute("a")), subquery)), + Literal.create(null, BooleanType), Literal(false))) + .analyze + comparePlans(optimized, correctAnswer) + } } test("OptimizedIn test: Inset optimization disabled as " + @@ -219,36 +243,108 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + - "when value is not nullable") { - val originalQuery = - testRelation - .where(In(Literal("a"), Nil)) - .analyze + test("OptimizedIn test: expr IN (empty list) gets transformed to literal false") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "false") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(Literal(false)) - .analyze + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(false, BooleanType)) 
+ .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, correctAnswer) + } } - test("OptimizedIn test: In empty list gets transformed to `If` expression " + - "when value is nullable") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Nil)) - .analyze + test("OptimizedIn test: null IN (empty list) gets transformed to literal false") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "false") { + val originalQuery = + testRelation + .where(In(Literal.create(null, NullType), Nil)) + .analyze - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(If(IsNotNull(UnresolvedAttribute("a")), - Literal(false), Literal.create(null, BooleanType))) - .analyze + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Literal.create(false, BooleanType)) + .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, correctAnswer) + } + } + + test("OptimizedIn test: expr IN (empty list) gets transformed to literal false in select") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "false") { + val originalQuery = + testRelation + .select(In(UnresolvedAttribute("a"), Nil).as("x")) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Literal.create(false, BooleanType).as("x")) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + + test("OptimizedIn test: null IN (empty list) gets transformed to literal false in select") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "false") { + val originalQuery = + testRelation + .select(In(Literal.create(null, NullType), Nil).as("x")) + .analyze + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select(Literal.create(false, BooleanType).as("x")) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + + test("OptimizedIn test: Legacy behavior: 
" + + "In empty list gets transformed to FalseLiteral when value is not nullable") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "true") { + val originalQuery = + testRelation + .where(In(Literal("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(Literal(false)) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + + test("OptimizedIn test: Legacy behavior: " + + "In empty list gets transformed to `If` expression when value is nullable") { + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "true") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 06fcb12acdd08..5d81e96a8e583 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR import org.apache.spark.sql.types.BooleanType class ReplaceOperatorSuite extends PlanTest { @@ -233,13 +234,24 @@ class ReplaceOperatorSuite extends PlanTest { val basePlan = LocalRelation(Seq($"a".int, $"b".int)) val otherPlan = basePlan.where($"a".in(1, 
2) || $"b".in()) val except = Except(basePlan, otherPlan, false) - val result = OptimizeIn(Optimize.execute(except.analyze)) - val correctAnswer = Aggregate(basePlan.output, basePlan.output, - Filter(!Coalesce(Seq( - $"a".in(1, 2) || If($"b".isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)), - Literal.FalseLiteral)), - basePlan)).analyze - comparePlans(result, correctAnswer) + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "false") { + val result = OptimizeIn(Optimize.execute(except.analyze)) + val correctAnswer = Aggregate(basePlan.output, basePlan.output, + Filter(!Coalesce(Seq( + $"a".in(1, 2) || Literal.FalseLiteral, + Literal.FalseLiteral)), + basePlan)).analyze + comparePlans(result, correctAnswer) + } + withSQLConf(LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> "true") { + val result = OptimizeIn(Optimize.execute(except.analyze)) + val correctAnswer = Aggregate(basePlan.output, basePlan.output, + Filter(!Coalesce(Seq( + $"a".in(1, 2) || If($"b".isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)), + Literal.FalseLiteral)), + basePlan)).analyze + comparePlans(result, correctAnswer) + } } test("SPARK-26366: ReplaceExceptWithFilter should not transform non-deterministic") { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-null-semantics.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-null-semantics.sql.out new file mode 100644 index 0000000000000..ac5c41dd307b4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/in-subquery/in-null-semantics.sql.out @@ -0,0 +1,304 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +create temp view v (c) as values (1), (null) +-- !query analysis +CreateViewCommand `v`, [(c,None)], values (1), (null), false, false, LocalTempView, true + +- LocalRelation [col1#x] + + +-- !query +create temp view v_empty (e) as select 1 where false +-- !query analysis +CreateViewCommand `v_empty`, 
[(e,None)], select 1 where false, false, false, LocalTempView, true + +- Project [1 AS 1#x] + +- Filter false + +- OneRowRelation + + +-- !query +create table t(c int) using json +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t`, false + + +-- !query +insert into t values (1), (null) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t, false, JSON, [path=file:[not included in comparison]/{warehouse_dir}/t], Append, `spark_catalog`.`default`.`t`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t), [c] ++- Project [cast(col1#x as int) AS c#x] + +- LocalRelation [col1#x] + + +-- !query +create table t2(d int) using json +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t2`, false + + +-- !query +insert into t2 values (2) +-- !query analysis +InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t2, false, JSON, [path=file:[not included in comparison]/{warehouse_dir}/t2], Append, `spark_catalog`.`default`.`t2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t2), [d] ++- Project [cast(col1#x as int) AS d#x] + +- LocalRelation [col1#x] + + +-- !query +create table t_empty(e int) using json +-- !query analysis +CreateDataSourceTableCommand `spark_catalog`.`default`.`t_empty`, false + + +-- !query +set spark.sql.legacy.nullInEmptyListBehavior = false +-- !query analysis +SetCommand (spark.sql.legacy.nullInEmptyListBehavior,Some(false)) + + +-- !query +select c, c in (select e from t_empty) from t +-- !query analysis +Project [c#x, c#x IN (list#x []) AS (c IN (listquery()))#x] +: +- Project [e#x] +: +- SubqueryAlias spark_catalog.default.t_empty +: +- Relation spark_catalog.default.t_empty[e#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[c#x] json + + +-- !query +select 
c, c in (select e from v_empty) from v +-- !query analysis +Project [c#x, c#x IN (list#x []) AS (c IN (listquery()))#x] +: +- Project [e#x] +: +- SubqueryAlias v_empty +: +- View (`v_empty`, [e#x]) +: +- Project [cast(1#x as int) AS e#x] +: +- Project [1 AS 1#x] +: +- Filter false +: +- OneRowRelation ++- SubqueryAlias v + +- View (`v`, [c#x]) + +- Project [cast(col1#x as int) AS c#x] + +- LocalRelation [col1#x] + + +-- !query +select c, c not in (select e from t_empty) from t +-- !query analysis +Project [c#x, NOT c#x IN (list#x []) AS (NOT (c IN (listquery())))#x] +: +- Project [e#x] +: +- SubqueryAlias spark_catalog.default.t_empty +: +- Relation spark_catalog.default.t_empty[e#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[c#x] json + + +-- !query +select c, c not in (select e from v_empty) from v +-- !query analysis +Project [c#x, NOT c#x IN (list#x []) AS (NOT (c IN (listquery())))#x] +: +- Project [e#x] +: +- SubqueryAlias v_empty +: +- View (`v_empty`, [e#x]) +: +- Project [cast(1#x as int) AS e#x] +: +- Project [1 AS 1#x] +: +- Filter false +: +- OneRowRelation ++- SubqueryAlias v + +- View (`v`, [c#x]) + +- Project [cast(col1#x as int) AS c#x] + +- LocalRelation [col1#x] + + +-- !query +select null in (select e from t_empty) +-- !query analysis +Project [cast(null as int) IN (list#x []) AS (NULL IN (listquery()))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias spark_catalog.default.t_empty +: +- Relation spark_catalog.default.t_empty[e#x] json ++- OneRowRelation + + +-- !query +select null in (select e from v_empty) +-- !query analysis +Project [cast(null as int) IN (list#x []) AS (NULL IN (listquery()))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias v_empty +: +- View (`v_empty`, [e#x]) +: +- Project [cast(1#x as int) AS e#x] +: +- Project [1 AS 1#x] +: +- Filter false +: +- OneRowRelation ++- OneRowRelation + + +-- !query +select null not in (select e from t_empty) +-- !query 
analysis +Project [NOT cast(null as int) IN (list#x []) AS (NOT (NULL IN (listquery())))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias spark_catalog.default.t_empty +: +- Relation spark_catalog.default.t_empty[e#x] json ++- OneRowRelation + + +-- !query +select null not in (select e from v_empty) +-- !query analysis +Project [NOT cast(null as int) IN (list#x []) AS (NOT (NULL IN (listquery())))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias v_empty +: +- View (`v_empty`, [e#x]) +: +- Project [cast(1#x as int) AS e#x] +: +- Project [1 AS 1#x] +: +- Filter false +: +- OneRowRelation ++- OneRowRelation + + +-- !query +select * from t left join t2 on (t.c in (select e from t_empty)) is null +-- !query analysis +Project [c#x, d#x] ++- Join LeftOuter, isnull(c#x IN (list#x [])) + : +- Project [e#x] + : +- SubqueryAlias spark_catalog.default.t_empty + : +- Relation spark_catalog.default.t_empty[e#x] json + :- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[c#x] json + +- SubqueryAlias spark_catalog.default.t2 + +- Relation spark_catalog.default.t2[d#x] json + + +-- !query +select * from t left join t2 on (t.c not in (select e from t_empty)) is null +-- !query analysis +Project [c#x, d#x] ++- Join LeftOuter, isnull(NOT c#x IN (list#x [])) + : +- Project [e#x] + : +- SubqueryAlias spark_catalog.default.t_empty + : +- Relation spark_catalog.default.t_empty[e#x] json + :- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[c#x] json + +- SubqueryAlias spark_catalog.default.t2 + +- Relation spark_catalog.default.t2[d#x] json + + +-- !query +set spark.sql.legacy.nullInEmptyListBehavior = true +-- !query analysis +SetCommand (spark.sql.legacy.nullInEmptyListBehavior,Some(true)) + + +-- !query +select null in (select e from t_empty) +-- !query analysis +Project [cast(null as int) IN (list#x []) AS (NULL IN (listquery()))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias 
spark_catalog.default.t_empty +: +- Relation spark_catalog.default.t_empty[e#x] json ++- OneRowRelation + + +-- !query +select null in (select e from v_empty) +-- !query analysis +Project [cast(null as int) IN (list#x []) AS (NULL IN (listquery()))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias v_empty +: +- View (`v_empty`, [e#x]) +: +- Project [cast(1#x as int) AS e#x] +: +- Project [1 AS 1#x] +: +- Filter false +: +- OneRowRelation ++- OneRowRelation + + +-- !query +select null not in (select e from t_empty) +-- !query analysis +Project [NOT cast(null as int) IN (list#x []) AS (NOT (NULL IN (listquery())))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias spark_catalog.default.t_empty +: +- Relation spark_catalog.default.t_empty[e#x] json ++- OneRowRelation + + +-- !query +select null not in (select e from v_empty) +-- !query analysis +Project [NOT cast(null as int) IN (list#x []) AS (NOT (NULL IN (listquery())))#x] +: +- Project [e#x] +: +- Project [e#x] +: +- SubqueryAlias v_empty +: +- View (`v_empty`, [e#x]) +: +- Project [cast(1#x as int) AS e#x] +: +- Project [1 AS 1#x] +: +- Filter false +: +- OneRowRelation ++- OneRowRelation + + +-- !query +select * from t left join t2 on (t.c in (select e from t_empty)) is null +-- !query analysis +Project [c#x, d#x] ++- Join LeftOuter, isnull(c#x IN (list#x [])) + : +- Project [e#x] + : +- SubqueryAlias spark_catalog.default.t_empty + : +- Relation spark_catalog.default.t_empty[e#x] json + :- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[c#x] json + +- SubqueryAlias spark_catalog.default.t2 + +- Relation spark_catalog.default.t2[d#x] json + + +-- !query +select * from t left join t2 on (t.c not in (select e from t_empty)) is null +-- !query analysis +Project [c#x, d#x] ++- Join LeftOuter, isnull(NOT c#x IN (list#x [])) + : +- Project [e#x] + : +- SubqueryAlias spark_catalog.default.t_empty + : +- Relation spark_catalog.default.t_empty[e#x] json + :- 
SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[c#x] json + +- SubqueryAlias spark_catalog.default.t2 + +- Relation spark_catalog.default.t2[d#x] json + + +-- !query +reset spark.sql.legacy.nullInEmptyListBehavior +-- !query analysis +ResetCommand spark.sql.legacy.nullInEmptyListBehavior + + +-- !query +drop table t +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t + + +-- !query +drop table t2 +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2 + + +-- !query +drop table t_empty +-- !query analysis +DropTable false, false ++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_empty diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql new file mode 100644 index 0000000000000..b893d8970b4d6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql @@ -0,0 +1,57 @@ +create temp view v (c) as values (1), (null); +create temp view v_empty (e) as select 1 where false; + +-- Note: tables and temp views hit different optimization/execution codepaths +create table t(c int) using json; +insert into t values (1), (null); +create table t2(d int) using json; +insert into t2 values (2); +create table t_empty(e int) using json; + + + +set spark.sql.legacy.nullInEmptyListBehavior = false; + +-- null IN (empty subquery) +-- Correct results: c in (emptylist) should always be false + +select c, c in (select e from t_empty) from t; +select c, c in (select e from v_empty) from v; +select c, c not in (select e from t_empty) from t; +select c, c not in (select e from v_empty) from v; + +-- constant null IN (empty subquery) - rewritten by NullPropagation rule + +select null in (select e from t_empty); +select null in (select e from v_empty); 
+select null not in (select e from t_empty); +select null not in (select e from v_empty); + +-- IN subquery which is not rewritten to join - here we use IN in the ON condition because that is a case that doesn't get rewritten to join in RewritePredicateSubquery, so we can observe the execution behavior of InSubquery directly +-- Correct results: column t2.d should be NULL because the ON condition is always false +-- This will be fixed by the execution fixes. +select * from t left join t2 on (t.c in (select e from t_empty)) is null; +select * from t left join t2 on (t.c not in (select e from t_empty)) is null; + + + +-- Test legacy behavior flag +set spark.sql.legacy.nullInEmptyListBehavior = true; + +-- constant null IN (empty subquery) - rewritten by NullPropagation rule + +select null in (select e from t_empty); +select null in (select e from v_empty); +select null not in (select e from t_empty); +select null not in (select e from v_empty); + +-- IN subquery which is not rewritten to join - here we use IN in the ON condition because that is a case that doesn't get rewritten to join in RewritePredicateSubquery, so we can observe the execution behavior of InSubquery directly +-- Correct results: column t2.d should be NULL because the ON condition is always false +select * from t left join t2 on (t.c in (select e from t_empty)) is null; +select * from t left join t2 on (t.c not in (select e from t_empty)) is null; + +reset spark.sql.legacy.nullInEmptyListBehavior; + +drop table t; +drop table t2; +drop table t_empty; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out new file mode 100644 index 0000000000000..39b03576baaf0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out @@ -0,0 +1,239 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +create 
temp view v (c) as values (1), (null) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temp view v_empty (e) as select 1 where false +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t(c int) using json +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t values (1), (null) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t2(d int) using json +-- !query schema +struct<> +-- !query output + + + +-- !query +insert into t2 values (2) +-- !query schema +struct<> +-- !query output + + + +-- !query +create table t_empty(e int) using json +-- !query schema +struct<> +-- !query output + + + +-- !query +set spark.sql.legacy.nullInEmptyListBehavior = false +-- !query schema +struct +-- !query output +spark.sql.legacy.nullInEmptyListBehavior false + + +-- !query +select c, c in (select e from t_empty) from t +-- !query schema +struct +-- !query output +1 false +NULL false + + +-- !query +select c, c in (select e from v_empty) from v +-- !query schema +struct +-- !query output +1 false +NULL false + + +-- !query +select c, c not in (select e from t_empty) from t +-- !query schema +struct +-- !query output +1 true +NULL true + + +-- !query +select c, c not in (select e from v_empty) from v +-- !query schema +struct +-- !query output +1 true +NULL true + + +-- !query +select null in (select e from t_empty) +-- !query schema +struct<(NULL IN (listquery())):boolean> +-- !query output +false + + +-- !query +select null in (select e from v_empty) +-- !query schema +struct<(NULL IN (listquery())):boolean> +-- !query output +false + + +-- !query +select null not in (select e from t_empty) +-- !query schema +struct<(NOT (NULL IN (listquery()))):boolean> +-- !query output +true + + +-- !query +select null not in (select e from v_empty) +-- !query schema +struct<(NOT (NULL IN (listquery()))):boolean> +-- !query output +true + + +-- !query +select * from t left join t2 on (t.c in (select 
e from t_empty)) is null +-- !query schema +struct +-- !query output +1 NULL +NULL 2 + + +-- !query +select * from t left join t2 on (t.c not in (select e from t_empty)) is null +-- !query schema +struct +-- !query output +1 NULL +NULL 2 + + +-- !query +set spark.sql.legacy.nullInEmptyListBehavior = true +-- !query schema +struct +-- !query output +spark.sql.legacy.nullInEmptyListBehavior true + + +-- !query +select null in (select e from t_empty) +-- !query schema +struct<(NULL IN (listquery())):boolean> +-- !query output +NULL + + +-- !query +select null in (select e from v_empty) +-- !query schema +struct<(NULL IN (listquery())):boolean> +-- !query output +NULL + + +-- !query +select null not in (select e from t_empty) +-- !query schema +struct<(NOT (NULL IN (listquery()))):boolean> +-- !query output +NULL + + +-- !query +select null not in (select e from v_empty) +-- !query schema +struct<(NOT (NULL IN (listquery()))):boolean> +-- !query output +NULL + + +-- !query +select * from t left join t2 on (t.c in (select e from t_empty)) is null +-- !query schema +struct +-- !query output +1 NULL +NULL 2 + + +-- !query +select * from t left join t2 on (t.c not in (select e from t_empty)) is null +-- !query schema +struct +-- !query output +1 NULL +NULL 2 + + +-- !query +reset spark.sql.legacy.nullInEmptyListBehavior +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t2 +-- !query schema +struct<> +-- !query output + + + +-- !query +drop table t_empty +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala new file mode 100644 index 0000000000000..c9e016c891e77 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class EmptyInSuite extends QueryTest +with SharedSparkSession { + import testImplicits._ + + val row = identity[(java.lang.Integer, java.lang.Double)](_) + + lazy val t = Seq( + row((1, 1.0)), + row((null, 2.0))).toDF("a", "b") + + test("IN with empty list") { + // This test has to be written in scala to construct a literal empty IN list, since that + // isn't valid syntax in SQL. 
+ val emptylist = Seq.empty[Literal] + + Seq(true, false).foreach { legacyNullInBehavior => + // To observe execution behavior, disable the OptimizeIn rule which optimizes away empty lists + Seq(true, false).foreach { disableOptimizeIn => + // Disable ConvertToLocalRelation since it would collapse the evaluation of the IN + // expression over the LocalRelation + var excludedRules = "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation" + if (disableOptimizeIn) { + excludedRules += ",org.apache.spark.sql.catalyst.optimizer.OptimizeIn" + } + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules, + SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> legacyNullInBehavior.toString) { + // We still get legacy behavior with disableOptimizeIn until execution is also fixed + val expectedResultForNullInEmpty = + if (legacyNullInBehavior || disableOptimizeIn) null else false + val df = t.select(col("a"), col("a").isin(emptylist: _*)) + checkAnswer( + df, + Row(1, false) :: Row(null, expectedResultForNullInEmpty) :: Nil) + } + } + } + } +} From 630888727451d2ec7ba620a303d12bbda05d3801 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 20 Jul 2023 13:51:39 +0800 Subject: [PATCH 028/986] [SPARK-44367][SQL][UI] Show error message on UI for each failed query ### What changes were proposed in this pull request? This PR adds an 'error message' col to the failed query execution table on the SQL/DataFrame tab of UI. ### Why are the changes needed? The SQL tab of UI is not helping to detect SQL errors. This PR will provide users with a clear understanding of why their queries have failed. ### Does this PR introduce _any_ user-facing change? SQL tab of UI shows errors for failed queries ### How was this patch tested? built and tested locally ![image](https://github.com/apache/spark/assets/8326978/cf25e347-bb99-47f8-accd-aabaf8d3a9d8) Closes #41951 from yaooqinn/SPARK-44367. 
Authored-by: Kent Yao Signed-off-by: Kent Yao (cherry picked from commit 399705512a417de460529843eb047d5c2e8f9e22) Signed-off-by: Kent Yao --- .../scala/org/apache/spark/ui/UIUtils.scala | 28 +++++++++++ .../org/apache/spark/ui/jobs/StagePage.scala | 15 +----- .../org/apache/spark/ui/jobs/StageTable.scala | 18 +------ .../org/apache/spark/ui/UIUtilsSuite.scala | 24 ++++++++- .../spark/sql/execution/SQLExecution.scala | 2 +- .../sql/execution/ui/AllExecutionsPage.scala | 50 ++++++++++++++++--- .../thriftserver/ui/ThriftServerPage.scala | 17 ------- 7 files changed, 97 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 0ce647d12c569..f0f8cf1310f00 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -31,6 +31,7 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} +import org.apache.commons.text.StringEscapeUtils import org.glassfish.jersey.internal.util.collection.MultivaluedStringMap import org.apache.spark.internal.Logging @@ -708,4 +709,31 @@ private[spark] object UIUtils extends Logging { Seq.empty[Node] } } + + private final val ERROR_CLASS_REGEX = """\[(?[A-Z][A-Z_.]+[A-Z])]""".r + + private def errorSummary(errorMessage: String): (String, Boolean) = { + var isMultiline = true + val maybeErrorClass = + ERROR_CLASS_REGEX.findFirstMatchIn(errorMessage).map(_.group("errorClass")) + val errorClassOrBrief = if (maybeErrorClass.nonEmpty && maybeErrorClass.get.nonEmpty) { + maybeErrorClass.get + } else if (errorMessage.indexOf('\n') >= 0) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else if (errorMessage.indexOf(":") >= 0) { + errorMessage.substring(0, errorMessage.indexOf(":")) + } else { + isMultiline = false + errorMessage + } + + val errorSummary = StringEscapeUtils.escapeHtml4(errorClassOrBrief) + 
(errorSummary, isMultiline) + } + + def errorMessageCell(errorMessage: String): Seq[Node] = { + val (summary, isMultiline) = errorSummary(errorMessage) + val details = detailsUINode(isMultiline, errorMessage) + {summary}{details} + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 1934e9e58e6b2..02aece6e50a8d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -696,7 +696,7 @@ private[ui] class TaskPagedTable( {formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))} {formatBytes(task.taskMetrics.map(_.diskBytesSpilled))} }} - {errorMessageCell(task.errorMessage.getOrElse(""))} + {UIUtils.errorMessageCell(task.errorMessage.getOrElse(""))} } @@ -713,19 +713,6 @@ private[ui] class TaskPagedTable( private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = { task.taskMetrics.map(fn).getOrElse(Nil) } - - private def errorMessageCell(error: String): Seq[Node] = { - val isMultiline = error.indexOf('\n') >= 0 - // Display the first line by default - val errorSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - error.substring(0, error.indexOf('\n')) - } else { - error - }) - val details = UIUtils.detailsUINode(isMultiline, error) - {errorSummary}{details} - } } private[spark] object ApiHelper { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 9e6eb418fe134..9e78f29e92e5d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -24,8 +24,6 @@ import javax.servlet.http.HttpServletRequest import scala.xml._ -import org.apache.commons.text.StringEscapeUtils - import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1 import org.apache.spark.ui._ @@ 
-217,7 +215,7 @@ private[ui] class StagePagedTable( {data.shuffleWriteWithUnit} ++ { if (isFailedStage) { - failureReasonHtml(info) + UIUtils.errorMessageCell(info.failureReason.getOrElse("")) } else { Seq.empty } @@ -225,20 +223,6 @@ private[ui] class StagePagedTable( } } - private def failureReasonHtml(s: v1.StageData): Seq[Node] = { - val failureReason = s.failureReason.getOrElse("") - val isMultiline = failureReason.indexOf('\n') >= 0 - // Display the first line by default - val failureReasonSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - failureReason.substring(0, failureReason.indexOf('\n')) - } else { - failureReason - }) - val details = UIUtils.detailsUINode(isMultiline, failureReason) - {failureReasonSummary}{details} - } - private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { val basePathUri = UIUtils.prependBaseUri(request, basePath) diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 9d040bb4e1ec7..aecd25f6c8dea 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui import scala.xml.{Node, Text} import scala.xml.Utility.trim -import org.apache.spark.SparkFunSuite +import org.apache.spark.{ErrorMessageFormat, SparkException, SparkFunSuite, SparkThrowableHelper} class UIUtilsSuite extends SparkFunSuite { import UIUtils._ @@ -189,4 +189,26 @@ class UIUtilsSuite extends SparkFunSuite { assert(generated.sameElements(expected), s"\n$errorMsg\n\nExpected:\n$expected\nGenerated:\n$generated") } + + // scalastyle:off line.size.limit + test("SPARK-44367: Extract errorClass from errorMsg with errorMessageCell") { + val e1 = "Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 1) (10.221.98.22 executor driver): 
org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.\n== SQL(line 1, position 8) ==\nselect a/b from src\n ^^^\n\n\tat org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:226)\n\tat org.apache.spark.sql.errors.QueryExecutionErrors.divideByZeroError(QueryExecutionErrors.scala)\n\tat org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(generated.java:54)\n\tat org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)\n\tat org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)\n\tat org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)\n\tat org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:890)\n\tat org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:890)\n\tat org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)\n\tat org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)\n\tat org.apache.spark.rdd.RDD.iterator(RDD.scala:328)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)\n\tat org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:141)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:592)\n\tat org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1474)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:595)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat 
java.lang.Thread.run(Thread.java:750)\n\nDriver stacktrace:" + val cell1 = UIUtils.errorMessageCell(e1) + assert(cell1 === {"DIVIDE_BY_ZERO"}{UIUtils.detailsUINode(isMultiline = true, e1)}) + + val e2 = SparkException.internalError("test") + val cell2 = UIUtils.errorMessageCell(e2.getMessage) + assert(cell2 === {"INTERNAL_ERROR"}{UIUtils.detailsUINode(isMultiline = true, e2.getMessage)}) + + val e3 = new SparkException( + errorClass = "CANNOT_CAST_DATATYPE", + messageParameters = Map("sourceType" -> "long", "targetType" -> "int"), cause = null) + val cell3 = UIUtils.errorMessageCell(SparkThrowableHelper.getMessage(e3, ErrorMessageFormat.PRETTY)) + assert(cell3 === {"CANNOT_CAST_DATATYPE"}{UIUtils.detailsUINode(isMultiline = true, e3.getMessage)}) + + val e4 = "java.lang.RuntimeException: random text" + val cell4 = UIUtils.errorMessageCell(e4) + assert(cell4 === {"java.lang.RuntimeException"}{UIUtils.detailsUINode(isMultiline = true, e4)}) + } + // scalastyle:on line.size.limit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index eeca1669e746a..68b29e9e216f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -124,7 +124,7 @@ object SQLExecution { val endTime = System.nanoTime() val errorMessage = ex.map { case e: SparkThrowable => - SparkThrowableHelper.getMessage(e, ErrorMessageFormat.MINIMAL) + SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) case e => // unexpected behavior SparkThrowableHelper.getMessage(e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index a69ca1bbc80d4..2e088ec8e4bc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -75,8 +75,16 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L if (running.nonEmpty) { val runningPageTable = - executionsTable(request, "running", running.toSeq, - executionIdToSubExecutions.mapValues(_.toSeq).toMap, currentTime, true, true, true) + executionsTable( + request, + "running", + running.toSeq, + executionIdToSubExecutions.mapValues(_.toSeq).toMap, + currentTime, + showErrorMessage = false, + showRunningJobs = true, + showSucceededJobs = true, + showFailedJobs = true) _content ++= }} + {if (showErrorMessage) { + UIUtils.errorMessageCell(executionUIData.errorMessage.getOrElse("")) + }} {if (showSubExecutions) { {executionLinks(executionTableRow.subExecutionData.map(_.executionUIData.executionId))} @@ -536,6 +571,7 @@ private[ui] class ExecutionDataSource( case "Job IDs" | "Succeeded Job IDs" => Ordering by (_.completedJobData.headOption) case "Running Job IDs" => Ordering.by(_.runningJobData.headOption) case "Failed Job IDs" => Ordering.by(_.failedJobData.headOption) + case "Error Message" => Ordering.by(_.executionUIData.errorMessage) case unknownColumn => throw QueryExecutionErrors.unknownColumnError(unknownColumn) } if (desc) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index d0378efd646e3..d47a99466a543 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -23,8 +23,6 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.commons.text.StringEscapeUtils - import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.thriftserver.ui.ToolTips._ 
import org.apache.spark.ui._ @@ -274,21 +272,6 @@ private[ui] class SqlStatsPagedTable( } - - private def errorMessageCell(errorMessage: String): Seq[Node] = { - val isMultiline = errorMessage.indexOf('\n') >= 0 - val errorSummary = StringEscapeUtils.escapeHtml4( - if (isMultiline) { - errorMessage.substring(0, errorMessage.indexOf('\n')) - } else { - errorMessage - }) - val details = detailsUINode(isMultiline, errorMessage) - - {errorSummary}{details} - - } - private def jobURL(request: HttpServletRequest, jobId: String): String = "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) } From 5406982703cd9fcabda69b1a8078354be3ee739c Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 20 Jul 2023 15:25:31 +0800 Subject: [PATCH 029/986] [SPARK-44475][SQL][CONNECT] Relocate DataType and Parser to sql/api ### What changes were proposed in this pull request? This PR relocates Parser and DataType family to sql/api. ### Why are the changes needed? To extract DataType and Parser as an API into a shared module. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #41928 from amaliujia/datatype_1. 
Authored-by: Rui Wang Signed-off-by: Wenchen Fan (cherry picked from commit 8ff6b7a04cbaef9c552789ad5550ceab760cb078) Signed-off-by: Wenchen Fan --- .../org/apache/spark/SparkException.scala | 9 + .../apache/spark/util/SparkClassUtils.scala | 7 + .../spark/util/SparkCollectionUtils.scala | 37 + .../CheckConnectJvmClientCompatibility.scala | 4 + .../planner/SparkConnectPlannerSuite.scala | 2 +- .../planner/SparkConnectProtoSuite.scala | 3 +- .../scala/org/apache/spark/util/Utils.scala | 5 - .../apache/spark/util/collection/Utils.scala | 18 +- dev/.rat-excludes | 1 + project/MimaExcludes.scala | 65 ++ project/SparkBuild.scala | 4 +- sql/api/pom.xml | 35 + .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 0 .../sql/catalyst/parser/SqlBaseLexer.tokens | 665 ++++++++++++++++++ .../sql/catalyst/parser/SqlBaseParser.g4 | 0 .../org/apache/spark/sql/SqlApiConf.scala | 6 + .../catalyst/analysis/SqlApiAnalysis.scala | 25 + .../catalyst/parser/DataTypeAstBuilder.scala | 9 +- .../parser/DataTypeParserInterface.scala | 0 .../parser/LegacyTypeStringParser.scala | 4 +- .../parser/SparkParserErrorStrategy.scala | 0 .../spark/sql/catalyst/parser/parsers.scala | 7 +- .../catalyst/util/AttributeNameParser.scala | 68 ++ .../sql/catalyst/util/DataTypeJsonUtils.scala | 0 .../sql/catalyst/util/QuotingUtils.scala | 19 + .../util/ResolveDefaultColumnsUtils.scala | 46 ++ .../sql/catalyst/util/SparkParserUtils.scala | 156 ++++ .../spark/sql/errors/DataTypeErrors.scala | 123 +++- .../spark/sql/errors/DataTypeErrorsBase.scala | 49 +- .../spark/sql/errors/QueryParsingErrors.scala | 7 +- .../spark/sql/types/AbstractDataType.scala | 0 .../apache/spark/sql/types/ArrayType.scala | 0 .../apache/spark/sql/types/BinaryType.scala | 0 .../apache/spark/sql/types/BooleanType.scala | 0 .../org/apache/spark/sql/types/ByteType.scala | 0 .../sql/types/CalendarIntervalType.scala | 0 .../org/apache/spark/sql/types/CharType.scala | 0 .../org/apache/spark/sql/types/DataType.scala | 8 +- 
.../org/apache/spark/sql/types/DateType.scala | 0 .../spark/sql/types/DayTimeIntervalType.scala | 0 .../org/apache/spark/sql/types/Decimal.scala | 29 +- .../apache/spark/sql/types/DecimalType.scala | 4 +- .../apache/spark/sql/types/DoubleType.scala | 0 .../apache/spark/sql/types/FloatType.scala | 0 .../apache/spark/sql/types/IntegerType.scala | 0 .../org/apache/spark/sql/types/LongType.scala | 0 .../org/apache/spark/sql/types/MapType.scala | 0 .../org/apache/spark/sql/types/Metadata.scala | 0 .../org/apache/spark/sql/types/NullType.scala | 0 .../apache/spark/sql/types/ObjectType.scala | 0 .../apache/spark/sql/types/ShortType.scala | 0 .../apache/spark/sql/types/StringType.scala | 0 .../apache/spark/sql/types/StructField.scala | 11 +- .../apache/spark/sql/types/StructType.scala | 31 +- .../spark/sql/types/TimestampNTZType.scala | 0 .../spark/sql/types/TimestampType.scala | 0 .../spark/sql/types/UDTRegistration.scala | 6 +- .../spark/sql/types/UserDefinedType.scala | 0 .../apache/spark/sql/types/VarcharType.scala | 0 .../sql/types/YearMonthIntervalType.scala | 0 sql/catalyst/pom.xml | 8 - .../analysis/RewriteRowLevelCommand.scala | 6 +- .../spark/sql/catalyst/analysis/package.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 47 +- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../aggregate/HyperLogLogPlusPlus.scala | 3 +- .../expressions/aggregate/PivotFirst.scala | 3 +- .../expressions/aggregate/interfaces.scala | 5 +- .../expressions/bitmapExpressions.scala | 5 +- .../sql/catalyst/optimizer/objects.scala | 2 +- .../sql/catalyst/parser/ParserUtils.scala | 135 +--- .../spark/sql/catalyst/plans/QueryPlan.scala | 3 +- .../plans/logical/LocalRelation.scala | 4 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../sql/catalyst/plans/logical/object.scala | 7 +- .../sql/catalyst/types/DataTypeUtils.scala | 5 +- .../util/ResolveDefaultColumnsUtil.scala | 24 +- .../spark/sql/catalyst/util/package.scala | 15 +- .../sql/errors/QueryCompilationErrors.scala | 4 
+- .../spark/sql/errors/QueryErrorsBase.scala | 72 +- .../sql/errors/QueryExecutionErrors.scala | 29 +- .../spark/sql/catalyst/SQLKeywordSuite.scala | 4 +- .../PropagateEmptyRelationSuite.scala | 7 +- .../spark/sql/types/StructTypeSuite.scala | 24 +- .../sql/execution/CollectMetricsExec.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../aggregate/HashAggregateExec.scala | 7 +- .../aggregate/ObjectAggregationIterator.scala | 5 +- .../aggregate/ObjectAggregationMap.scala | 6 +- .../TungstenAggregationIterator.scala | 10 +- .../execution/columnar/InMemoryRelation.scala | 3 +- .../datasources/DataSourceStrategy.scala | 4 +- .../datasources/PartitioningUtils.scala | 3 +- .../datasources/v2/V2CommandExec.scala | 4 +- .../execution/datasources/v2/V2Writes.scala | 3 +- .../exchange/ShuffleExchangeExec.scala | 4 +- .../apache/spark/sql/execution/objects.scala | 3 +- .../python/FlatMapCoGroupsInPandasExec.scala | 6 +- .../python/FlatMapGroupsInPandasExec.scala | 4 +- .../FlatMapGroupsInPandasWithStateExec.scala | 3 +- .../StreamingAggregationStateManager.scala | 5 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../spark/sql/connector/AlterTableTests.scala | 2 +- .../errors/QueryCompilationErrorsSuite.scala | 4 +- .../CoalesceBucketsInJoinSuite.scala | 5 +- 105 files changed, 1504 insertions(+), 467 deletions(-) create mode 100644 common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala rename sql/{catalyst => api}/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 (100%) create mode 100644 sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens rename sql/{catalyst => api}/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 (100%) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala (96%) rename sql/{catalyst => 
api}/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala (95%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala (98%) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala (100%) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtils.scala create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala (98%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/ArrayType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/BinaryType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/BooleanType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/ByteType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/CharType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/DataType.scala (98%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/DateType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala (100%) rename sql/{catalyst => 
api}/src/main/scala/org/apache/spark/sql/types/Decimal.scala (96%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/DecimalType.scala (98%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/DoubleType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/FloatType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/IntegerType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/LongType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/MapType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/Metadata.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/NullType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/ObjectType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/ShortType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/StringType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/StructField.scala (91%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/StructType.scala (94%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/TimestampType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala (95%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/VarcharType.scala (100%) rename sql/{catalyst => api}/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala (100%) diff --git a/common/utils/src/main/scala/org/apache/spark/SparkException.scala 
b/common/utils/src/main/scala/org/apache/spark/SparkException.scala index feb7bf5b66eda..4dafaba685e71 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkException.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkException.scala @@ -58,6 +58,15 @@ class SparkException( errorClass = Some(errorClass), messageParameters = messageParameters) + def this(errorClass: String, messageParameters: Map[String, String], cause: Throwable, + context: Array[QueryContext]) = + this( + message = SparkThrowableHelper.getMessage(errorClass, messageParameters), + cause = cause, + errorClass = Some(errorClass), + messageParameters = messageParameters, + context = context) + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass.orNull diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala index 011f74de1febe..7401e3762417f 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.util +import scala.util.Try + trait SparkClassUtils { def getSparkClassLoader: ClassLoader = getClass.getClassLoader @@ -39,6 +41,11 @@ trait SparkClassUtils { } // scalastyle:on classforname } + + /** Determines whether the provided class is loadable in the current thread. 
*/ + def classIsLoadable(clazz: String): Boolean = { + Try { classForName(clazz, initialize = false) }.isSuccess + } } object SparkClassUtils extends SparkClassUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala new file mode 100644 index 0000000000000..7fecc9ccb664d --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.util + +import scala.collection.immutable + +trait SparkCollectionUtils { + /** + * Same function as `keys.zipWithIndex.toMap`, but has perf gain. 
+ */ + def toMapWithIndex[K](keys: Iterable[K]): Map[K, Int] = { + val builder = immutable.Map.newBuilder[K, Int] + val keyIter = keys.iterator + var idx = 0 + while (keyIter.hasNext) { + builder += (keyIter.next(), idx).asInstanceOf[(K, Int)] + idx = idx + 1 + } + builder.result() + } +} + +object SparkCollectionUtils extends SparkCollectionUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 70db03df7bdd8..e7f01d6140dec 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -311,6 +311,10 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.SQLImplicits._sqlContext" // protected ), + // Catalyst Refactoring + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils$"), + // New public APIs added in the client // ScalarUserDefinedFunction ProblemFilters diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 595f9d65c269b..40d83b07b756d 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -105,7 +105,7 @@ trait SparkConnectPlanTest extends SharedSparkSession { val bytes = ArrowConverters .toBatchWithSchemaIterator( data.iterator, - 
StructType.fromAttributes(attrs.map(_.toAttribute)), + DataTypeUtils.fromAttributes(attrs.map(_.toAttribute)), Long.MaxValue, Long.MaxValue, null, diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 82941d8d72e50..63b6f775d7b14 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.dsl.MockRemoteSession @@ -1050,7 +1051,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val buffer = ArrowConverters .toBatchWithSchemaIterator( Iterator.empty, - StructType.fromAttributes(attributes), + DataTypeUtils.fromAttributes(attributes), Long.MaxValue, Long.MaxValue, null, diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 048c6ebfa7489..5c9eea7a15176 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -190,11 +190,6 @@ private[spark] object Utils extends Logging with SparkClassUtils { weakStringInterner.intern(s) } - /** Determines whether 
the provided class is loadable in the current thread. */ - def classIsLoadable(clazz: String): Boolean = { - Try { classForName(clazz, initialize = false) }.isSuccess - } - /** * Run a segment of code using a different context class loader in the current thread */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index 1c3699058e462..b2ced00e8d6c5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -24,10 +24,12 @@ import scala.collection.immutable import com.google.common.collect.{Iterators => GuavaIterators, Ordering => GuavaOrdering} +import org.apache.spark.sql.catalyst.util.SparkCollectionUtils + /** * Utility functions for collections. */ -private[spark] object Utils { +private[spark] object Utils extends SparkCollectionUtils { /** * Returns the first K elements from the input as defined by the specified implicit Ordering[T] @@ -79,20 +81,6 @@ private[spark] object Utils { builder.result() } - /** - * Same function as `keys.zipWithIndex.toMap`, but has perf gain. - */ - def toMapWithIndex[K](keys: Iterable[K]): Map[K, Int] = { - val builder = immutable.Map.newBuilder[K, Int] - val keyIter = keys.iterator - var idx = 0 - while (keyIter.hasNext) { - builder += (keyIter.next(), idx).asInstanceOf[(K, Int)] - idx = idx + 1 - } - builder.result() - } - /** * Same function as `keys.zip(values).toMap.asJava`, but has perf gain. 
*/ diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 58c60e4485e8e..15ddfddd0e895 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -139,6 +139,7 @@ exported_table/* ansible-for-test-node/* node_modules spark-events-broken/* +SqlBaseLexer.tokens # Spark Connect related files with custom licence any.proto empty.proto diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2e70fd9225c66..d727c8c6917cd 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -105,6 +105,71 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageLevel"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageLevel$"), + // SPARK-44475: Relocate DataType and Parser to sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ArrayType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ArrayType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.BinaryType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.BinaryType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.BooleanType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.BooleanType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ByteType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ByteType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.CalendarIntervalType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.CalendarIntervalType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.CharType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.CharType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DataType"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DataType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DateType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DateType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DayTimeIntervalType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DayTimeIntervalType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.Decimal"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.Decimal$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ShortType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ShortType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.StringType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.StringType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.StructField"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.StructField$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.StructType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.StructType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.TimestampNTZType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.TimestampNTZType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.TimestampType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.TimestampType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.UDTRegistration"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.UDTRegistration$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.VarcharType"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.VarcharType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.YearMonthIntervalType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.YearMonthIntervalType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DecimalType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DecimalType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DoubleType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DoubleType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DoubleType$DoubleAsIfIntegral"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DoubleType$DoubleAsIfIntegral$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.DoubleType$DoubleIsConflicted"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.FloatType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.FloatType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.FloatType$FloatAsIfIntegral"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.FloatType$FloatAsIfIntegral$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.FloatType$FloatIsConflicted"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.IntegerType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.IntegerType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.LongType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.LongType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.MapType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.MapType$"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.Metadata"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.Metadata$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.MetadataBuilder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.NullType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.NullType$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ObjectType"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.types.ObjectType$"), + (problem: Problem) => problem match { case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") && !cls.fullName.startsWith("org.sparkproject.dmg.pmml") diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e14886a916509..cc27686b6b335 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -448,8 +448,8 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) - /* Catalyst ANTLR generation settings */ - enable(Catalyst.settings)(catalyst) + /* Sql-api ANTLR generation settings */ + enable(Catalyst.settings)(sqlApi) /* Spark SQL Core console settings */ enable(SQL.settings)(sql) diff --git a/sql/api/pom.xml b/sql/api/pom.xml index 6add5679ce7c4..f51b34d83314b 100644 --- a/sql/api/pom.xml +++ b/sql/api/pom.xml @@ -35,6 +35,10 @@ + + org.scala-lang.modules + scala-parser-combinators_${scala.binary.version} + org.apache.spark spark-common-utils_${scala.binary.version} @@ -45,6 +49,21 @@ spark-unsafe_${scala.binary.version} ${project.version} + + org.json4s + json4s-jackson_${scala.binary.version} + 3.7.0-M11 + + + com.fasterxml.jackson.core + * + + + + + org.antlr + antlr4-runtime + target/scala-${scala.binary.version}/classes @@ -68,6 +87,22 @@ + + org.antlr + antlr4-maven-plugin + + + + antlr4 + + + + + true + ../api/src/main/antlr4 + true 
+ + \ No newline at end of file diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 similarity index 100% rename from sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 rename to sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens new file mode 100644 index 0000000000000..459749d8ffe6f --- /dev/null +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens @@ -0,0 +1,665 @@ +SEMICOLON=1 +LEFT_PAREN=2 +RIGHT_PAREN=3 +COMMA=4 +DOT=5 +LEFT_BRACKET=6 +RIGHT_BRACKET=7 +ADD=8 +AFTER=9 +ALL=10 +ALTER=11 +ANALYZE=12 +AND=13 +ANTI=14 +ANY=15 +ANY_VALUE=16 +ARCHIVE=17 +ARRAY=18 +AS=19 +ASC=20 +AT=21 +AUTHORIZATION=22 +BETWEEN=23 +BOTH=24 +BUCKET=25 +BUCKETS=26 +BY=27 +CACHE=28 +CASCADE=29 +CASE=30 +CAST=31 +CATALOG=32 +CATALOGS=33 +CHANGE=34 +CHECK=35 +CLEAR=36 +CLUSTER=37 +CLUSTERED=38 +CODEGEN=39 +COLLATE=40 +COLLECTION=41 +COLUMN=42 +COLUMNS=43 +COMMENT=44 +COMMIT=45 +COMPACT=46 +COMPACTIONS=47 +COMPUTE=48 +CONCATENATE=49 +CONSTRAINT=50 +COST=51 +CREATE=52 +CROSS=53 +CUBE=54 +CURRENT=55 +CURRENT_DATE=56 +CURRENT_TIME=57 +CURRENT_TIMESTAMP=58 +CURRENT_USER=59 +DAY=60 +DAYS=61 +DAYOFYEAR=62 +DATA=63 +DATABASE=64 +DATABASES=65 +DATEADD=66 +DATEDIFF=67 +DBPROPERTIES=68 +DEFAULT=69 +DEFINED=70 +DELETE=71 +DELIMITED=72 +DESC=73 +DESCRIBE=74 +DFS=75 +DIRECTORIES=76 +DIRECTORY=77 +DISTINCT=78 +DISTRIBUTE=79 +DIV=80 +DROP=81 +ELSE=82 +END=83 +ESCAPE=84 +ESCAPED=85 +EXCEPT=86 +EXCHANGE=87 +EXCLUDE=88 +EXISTS=89 +EXPLAIN=90 +EXPORT=91 +EXTENDED=92 +EXTERNAL=93 +EXTRACT=94 +FALSE=95 +FETCH=96 +FIELDS=97 +FILTER=98 +FILEFORMAT=99 +FIRST=100 +FOLLOWING=101 +FOR=102 +FOREIGN=103 +FORMAT=104 +FORMATTED=105 
+FROM=106 +FULL=107 +FUNCTION=108 +FUNCTIONS=109 +GLOBAL=110 +GRANT=111 +GROUP=112 +GROUPING=113 +HAVING=114 +HOUR=115 +HOURS=116 +IF=117 +IGNORE=118 +IMPORT=119 +IN=120 +INCLUDE=121 +INDEX=122 +INDEXES=123 +INNER=124 +INPATH=125 +INPUTFORMAT=126 +INSERT=127 +INTERSECT=128 +INTERVAL=129 +INTO=130 +IS=131 +ITEMS=132 +JOIN=133 +KEYS=134 +LAST=135 +LATERAL=136 +LAZY=137 +LEADING=138 +LEFT=139 +LIKE=140 +ILIKE=141 +LIMIT=142 +LINES=143 +LIST=144 +LOAD=145 +LOCAL=146 +LOCATION=147 +LOCK=148 +LOCKS=149 +LOGICAL=150 +MACRO=151 +MAP=152 +MATCHED=153 +MERGE=154 +MICROSECOND=155 +MICROSECONDS=156 +MILLISECOND=157 +MILLISECONDS=158 +MINUTE=159 +MINUTES=160 +MONTH=161 +MONTHS=162 +MSCK=163 +NAMESPACE=164 +NAMESPACES=165 +NANOSECOND=166 +NANOSECONDS=167 +NATURAL=168 +NO=169 +NOT=170 +NULL=171 +NULLS=172 +OF=173 +OFFSET=174 +ON=175 +ONLY=176 +OPTION=177 +OPTIONS=178 +OR=179 +ORDER=180 +OUT=181 +OUTER=182 +OUTPUTFORMAT=183 +OVER=184 +OVERLAPS=185 +OVERLAY=186 +OVERWRITE=187 +PARTITION=188 +PARTITIONED=189 +PARTITIONS=190 +PERCENTILE_CONT=191 +PERCENTILE_DISC=192 +PERCENTLIT=193 +PIVOT=194 +PLACING=195 +POSITION=196 +PRECEDING=197 +PRIMARY=198 +PRINCIPALS=199 +PROPERTIES=200 +PURGE=201 +QUARTER=202 +QUERY=203 +RANGE=204 +RECORDREADER=205 +RECORDWRITER=206 +RECOVER=207 +REDUCE=208 +REFERENCES=209 +REFRESH=210 +RENAME=211 +REPAIR=212 +REPEATABLE=213 +REPLACE=214 +RESET=215 +RESPECT=216 +RESTRICT=217 +REVOKE=218 +RIGHT=219 +RLIKE=220 +ROLE=221 +ROLES=222 +ROLLBACK=223 +ROLLUP=224 +ROW=225 +ROWS=226 +SECOND=227 +SECONDS=228 +SCHEMA=229 +SCHEMAS=230 +SELECT=231 +SEMI=232 +SEPARATED=233 +SERDE=234 +SERDEPROPERTIES=235 +SESSION_USER=236 +SET=237 +SETMINUS=238 +SETS=239 +SHOW=240 +SKEWED=241 +SOME=242 +SORT=243 +SORTED=244 +SOURCE=245 +START=246 +STATISTICS=247 +STORED=248 +STRATIFY=249 +STRUCT=250 +SUBSTR=251 +SUBSTRING=252 +SYNC=253 +SYSTEM_TIME=254 +SYSTEM_VERSION=255 +TABLE=256 +TABLES=257 +TABLESAMPLE=258 +TARGET=259 +TBLPROPERTIES=260 +TEMPORARY=261 +TERMINATED=262 +THEN=263 
+TIME=264 +TIMESTAMP=265 +TIMESTAMPADD=266 +TIMESTAMPDIFF=267 +TO=268 +TOUCH=269 +TRAILING=270 +TRANSACTION=271 +TRANSACTIONS=272 +TRANSFORM=273 +TRIM=274 +TRUE=275 +TRUNCATE=276 +TRY_CAST=277 +TYPE=278 +UNARCHIVE=279 +UNBOUNDED=280 +UNCACHE=281 +UNION=282 +UNIQUE=283 +UNKNOWN=284 +UNLOCK=285 +UNPIVOT=286 +UNSET=287 +UPDATE=288 +USE=289 +USER=290 +USING=291 +VALUES=292 +VERSION=293 +VIEW=294 +VIEWS=295 +WEEK=296 +WEEKS=297 +WHEN=298 +WHERE=299 +WINDOW=300 +WITH=301 +WITHIN=302 +YEAR=303 +YEARS=304 +ZONE=305 +EQ=306 +NSEQ=307 +NEQ=308 +NEQJ=309 +LT=310 +LTE=311 +GT=312 +GTE=313 +PLUS=314 +MINUS=315 +ASTERISK=316 +SLASH=317 +PERCENT=318 +TILDE=319 +AMPERSAND=320 +PIPE=321 +CONCAT_PIPE=322 +HAT=323 +COLON=324 +ARROW=325 +HENT_START=326 +HENT_END=327 +STRING=328 +DOUBLEQUOTED_STRING=329 +BIGINT_LITERAL=330 +SMALLINT_LITERAL=331 +TINYINT_LITERAL=332 +INTEGER_VALUE=333 +EXPONENT_VALUE=334 +DECIMAL_VALUE=335 +FLOAT_LITERAL=336 +DOUBLE_LITERAL=337 +BIGDECIMAL_LITERAL=338 +IDENTIFIER=339 +BACKQUOTED_IDENTIFIER=340 +SIMPLE_COMMENT=341 +BRACKETED_COMMENT=342 +WS=343 +UNRECOGNIZED=344 +';'=1 +'('=2 +')'=3 +','=4 +'.'=5 +'['=6 +']'=7 +'ADD'=8 +'AFTER'=9 +'ALL'=10 +'ALTER'=11 +'ANALYZE'=12 +'AND'=13 +'ANTI'=14 +'ANY'=15 +'ANY_VALUE'=16 +'ARCHIVE'=17 +'ARRAY'=18 +'AS'=19 +'ASC'=20 +'AT'=21 +'AUTHORIZATION'=22 +'BETWEEN'=23 +'BOTH'=24 +'BUCKET'=25 +'BUCKETS'=26 +'BY'=27 +'CACHE'=28 +'CASCADE'=29 +'CASE'=30 +'CAST'=31 +'CATALOG'=32 +'CATALOGS'=33 +'CHANGE'=34 +'CHECK'=35 +'CLEAR'=36 +'CLUSTER'=37 +'CLUSTERED'=38 +'CODEGEN'=39 +'COLLATE'=40 +'COLLECTION'=41 +'COLUMN'=42 +'COLUMNS'=43 +'COMMENT'=44 +'COMMIT'=45 +'COMPACT'=46 +'COMPACTIONS'=47 +'COMPUTE'=48 +'CONCATENATE'=49 +'CONSTRAINT'=50 +'COST'=51 +'CREATE'=52 +'CROSS'=53 +'CUBE'=54 +'CURRENT'=55 +'CURRENT_DATE'=56 +'CURRENT_TIME'=57 +'CURRENT_TIMESTAMP'=58 +'CURRENT_USER'=59 +'DAY'=60 +'DAYS'=61 +'DAYOFYEAR'=62 +'DATA'=63 +'DATABASE'=64 +'DATABASES'=65 +'DATEADD'=66 +'DATEDIFF'=67 +'DBPROPERTIES'=68 +'DEFAULT'=69 +'DEFINED'=70 
+'DELETE'=71 +'DELIMITED'=72 +'DESC'=73 +'DESCRIBE'=74 +'DFS'=75 +'DIRECTORIES'=76 +'DIRECTORY'=77 +'DISTINCT'=78 +'DISTRIBUTE'=79 +'DIV'=80 +'DROP'=81 +'ELSE'=82 +'END'=83 +'ESCAPE'=84 +'ESCAPED'=85 +'EXCEPT'=86 +'EXCHANGE'=87 +'EXCLUDE'=88 +'EXISTS'=89 +'EXPLAIN'=90 +'EXPORT'=91 +'EXTENDED'=92 +'EXTERNAL'=93 +'EXTRACT'=94 +'FALSE'=95 +'FETCH'=96 +'FIELDS'=97 +'FILTER'=98 +'FILEFORMAT'=99 +'FIRST'=100 +'FOLLOWING'=101 +'FOR'=102 +'FOREIGN'=103 +'FORMAT'=104 +'FORMATTED'=105 +'FROM'=106 +'FULL'=107 +'FUNCTION'=108 +'FUNCTIONS'=109 +'GLOBAL'=110 +'GRANT'=111 +'GROUP'=112 +'GROUPING'=113 +'HAVING'=114 +'HOUR'=115 +'HOURS'=116 +'IF'=117 +'IGNORE'=118 +'IMPORT'=119 +'IN'=120 +'INCLUDE'=121 +'INDEX'=122 +'INDEXES'=123 +'INNER'=124 +'INPATH'=125 +'INPUTFORMAT'=126 +'INSERT'=127 +'INTERSECT'=128 +'INTERVAL'=129 +'INTO'=130 +'IS'=131 +'ITEMS'=132 +'JOIN'=133 +'KEYS'=134 +'LAST'=135 +'LATERAL'=136 +'LAZY'=137 +'LEADING'=138 +'LEFT'=139 +'LIKE'=140 +'ILIKE'=141 +'LIMIT'=142 +'LINES'=143 +'LIST'=144 +'LOAD'=145 +'LOCAL'=146 +'LOCATION'=147 +'LOCK'=148 +'LOCKS'=149 +'LOGICAL'=150 +'MACRO'=151 +'MAP'=152 +'MATCHED'=153 +'MERGE'=154 +'MICROSECOND'=155 +'MICROSECONDS'=156 +'MILLISECOND'=157 +'MILLISECONDS'=158 +'MINUTE'=159 +'MINUTES'=160 +'MONTH'=161 +'MONTHS'=162 +'MSCK'=163 +'NAMESPACE'=164 +'NAMESPACES'=165 +'NANOSECOND'=166 +'NANOSECONDS'=167 +'NATURAL'=168 +'NO'=169 +'NULL'=171 +'NULLS'=172 +'OF'=173 +'OFFSET'=174 +'ON'=175 +'ONLY'=176 +'OPTION'=177 +'OPTIONS'=178 +'OR'=179 +'ORDER'=180 +'OUT'=181 +'OUTER'=182 +'OUTPUTFORMAT'=183 +'OVER'=184 +'OVERLAPS'=185 +'OVERLAY'=186 +'OVERWRITE'=187 +'PARTITION'=188 +'PARTITIONED'=189 +'PARTITIONS'=190 +'PERCENTILE_CONT'=191 +'PERCENTILE_DISC'=192 +'PERCENT'=193 +'PIVOT'=194 +'PLACING'=195 +'POSITION'=196 +'PRECEDING'=197 +'PRIMARY'=198 +'PRINCIPALS'=199 +'PROPERTIES'=200 +'PURGE'=201 +'QUARTER'=202 +'QUERY'=203 +'RANGE'=204 +'RECORDREADER'=205 +'RECORDWRITER'=206 +'RECOVER'=207 +'REDUCE'=208 +'REFERENCES'=209 +'REFRESH'=210 
+'RENAME'=211 +'REPAIR'=212 +'REPEATABLE'=213 +'REPLACE'=214 +'RESET'=215 +'RESPECT'=216 +'RESTRICT'=217 +'REVOKE'=218 +'RIGHT'=219 +'ROLE'=221 +'ROLES'=222 +'ROLLBACK'=223 +'ROLLUP'=224 +'ROW'=225 +'ROWS'=226 +'SECOND'=227 +'SECONDS'=228 +'SCHEMA'=229 +'SCHEMAS'=230 +'SELECT'=231 +'SEMI'=232 +'SEPARATED'=233 +'SERDE'=234 +'SERDEPROPERTIES'=235 +'SESSION_USER'=236 +'SET'=237 +'MINUS'=238 +'SETS'=239 +'SHOW'=240 +'SKEWED'=241 +'SOME'=242 +'SORT'=243 +'SORTED'=244 +'SOURCE'=245 +'START'=246 +'STATISTICS'=247 +'STORED'=248 +'STRATIFY'=249 +'STRUCT'=250 +'SUBSTR'=251 +'SUBSTRING'=252 +'SYNC'=253 +'SYSTEM_TIME'=254 +'SYSTEM_VERSION'=255 +'TABLE'=256 +'TABLES'=257 +'TABLESAMPLE'=258 +'TARGET'=259 +'TBLPROPERTIES'=260 +'TERMINATED'=262 +'THEN'=263 +'TIME'=264 +'TIMESTAMP'=265 +'TIMESTAMPADD'=266 +'TIMESTAMPDIFF'=267 +'TO'=268 +'TOUCH'=269 +'TRAILING'=270 +'TRANSACTION'=271 +'TRANSACTIONS'=272 +'TRANSFORM'=273 +'TRIM'=274 +'TRUE'=275 +'TRUNCATE'=276 +'TRY_CAST'=277 +'TYPE'=278 +'UNARCHIVE'=279 +'UNBOUNDED'=280 +'UNCACHE'=281 +'UNION'=282 +'UNIQUE'=283 +'UNKNOWN'=284 +'UNLOCK'=285 +'UNPIVOT'=286 +'UNSET'=287 +'UPDATE'=288 +'USE'=289 +'USER'=290 +'USING'=291 +'VALUES'=292 +'VERSION'=293 +'VIEW'=294 +'VIEWS'=295 +'WEEK'=296 +'WEEKS'=297 +'WHEN'=298 +'WHERE'=299 +'WINDOW'=300 +'WITH'=301 +'WITHIN'=302 +'YEAR'=303 +'YEARS'=304 +'ZONE'=305 +'<=>'=307 +'<>'=308 +'!='=309 +'<'=310 +'>'=312 +'+'=314 +'-'=315 +'*'=316 +'/'=317 +'%'=318 +'~'=319 +'&'=320 +'|'=321 +'||'=322 +'^'=323 +':'=324 +'->'=325 +'/*+'=326 +'*/'=327 diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 similarity index 100% rename from sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 rename to sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala index 3a074c666149a..c297a2e067a3d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/SqlApiConf.scala @@ -16,10 +16,12 @@ */ package org.apache.spark.sql + import java.util.concurrent.atomic.AtomicReference import scala.util.Try +import org.apache.spark.sql.types.{AtomicType, TimestampType} import org.apache.spark.util.SparkClassUtils /** @@ -36,6 +38,8 @@ private[sql] trait SqlApiConf { def exponentLiteralAsDecimalEnabled: Boolean def enforceReservedKeywords: Boolean def doubleQuotedIdentifiers: Boolean + def timestampType: AtomicType + def allowNegativeScaleOfDecimalEnabled: Boolean } private[sql] object SqlApiConf { @@ -68,4 +72,6 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf { override def exponentLiteralAsDecimalEnabled: Boolean = false override def enforceReservedKeywords: Boolean = false override def doubleQuotedIdentifiers: Boolean = false + override def timestampType: AtomicType = TimestampType + override def allowNegativeScaleOfDecimalEnabled: Boolean = false } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala new file mode 100644 index 0000000000000..9f5a5b8875b33 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +object SqlApiAnalysis { + /** + * Resolver should return true if the first string refers to the same entity as the second string. + * For example, by using case insensitive equality. + */ + type Resolver = (String, String) => Boolean +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index 84a8bc71b3fb0..f42137e2d3f6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -23,14 +23,13 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.Token import org.antlr.v4.runtime.tree.ParseTree -import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.parser.ParserUtils.{string, withOrigin} +import org.apache.spark.sql.SqlApiConf import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, 
CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType} -class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper { +class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -69,7 +68,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHe case (FLOAT | REAL, Nil) => FloatType case (DOUBLE, Nil) => DoubleType case (DATE, Nil) => DateType - case (TIMESTAMP, Nil) => SQLConf.get.timestampType + case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType case (TIMESTAMP_NTZ, Nil) => TimestampNTZType case (TIMESTAMP_LTZ, Nil) => TimestampType case (STRING, Nil) => StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala similarity index 95% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala index f36fcade382bf..8ac5939bca944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types._ /** @@ -88,6 +88,6 @@ object LegacyTypeStringParser extends RegexParsers { def parseString(asString: String): DataType = parseAll(dataType, asString) match { case Success(result, _) => result case failure: NoSuccess => - throw QueryExecutionErrors.dataTypeUnsupportedError(asString, failure.toString) + throw DataTypeErrors.dataTypeUnsupportedError(asString, failure.toString) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index d4c1c1d9db6c5..ac285d54c1fe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -27,6 +27,7 @@ import org.apache.spark.{QueryContext, SparkException, SparkThrowable, SparkThro import org.apache.spark.internal.Logging import org.apache.spark.sql.SqlApiConf import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import 
org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.types.{DataType, StructType} @@ -197,10 +198,10 @@ class ParseException( queryContext) { def this(errorClass: String, messageParameters: Map[String, String], ctx: ParserRuleContext) = - this(Option(ParserUtils.command(ctx)), + this(Option(SparkParserUtils.command(ctx)), SparkThrowableHelper.getMessage(errorClass, messageParameters), - ParserUtils.position(ctx.getStart), - ParserUtils.position(ctx.getStop), + SparkParserUtils.position(ctx.getStart), + SparkParserUtils.position(ctx.getStop), Some(errorClass), messageParameters) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala new file mode 100644 index 0000000000000..e47ab1978d0ed --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.errors.DataTypeErrors + +trait AttributeNameParser { + /** + * Used to split attribute name by dot with backticks rule. + * Backticks must appear in pairs, and the quoted string must be a complete name part, + * which means `ab..c`e.f is not allowed. + * We can use backtick only inside quoted name parts. + */ + def parseAttributeName(name: String): Seq[String] = { + def e = DataTypeErrors.attributeNameSyntaxError(name) + val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] + val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] + var inBacktick = false + var i = 0 + while (i < name.length) { + val char = name(i) + if (inBacktick) { + if (char == '`') { + if (i + 1 < name.length && name(i + 1) == '`') { + tmp += '`' + i += 1 + } else { + inBacktick = false + if (i + 1 < name.length && name(i + 1) != '.') throw e + } + } else { + tmp += char + } + } else { + if (char == '`') { + if (tmp.nonEmpty) throw e + inBacktick = true + } else if (char == '.') { + if (name(i - 1) == '.' 
|| i == name.length - 1) throw e + nameParts += tmp.mkString + tmp.clear() + } else { + tmp += char + } + } + i += 1 + } + if (inBacktick) throw e + nameParts += tmp.mkString + nameParts.toSeq + } +} + +object AttributeNameParser extends AttributeNameParser diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala index 7d5b6946244be..62015fe206a30 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/QuotingUtils.scala @@ -28,4 +28,23 @@ object QuotingUtils { def toSQLSchema(schema: String): String = { quoteByDefault(schema) } + + def quoteIfNeeded(part: String): String = { + if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) { + part + } else { + s"`${part.replace("`", "``")}`" + } + } + + def escapeSingleQuotedString(str: String): String = { + val builder = new StringBuilder + + str.foreach { + case '\'' => builder ++= s"\\\'" + case ch => builder += ch + } + + builder.toString() + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtils.scala new file mode 100644 index 0000000000000..4314f85740461 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +trait ResolveDefaultColumnsUtils { + // This column metadata indicates the default value associated with a particular table column that + // is in effect at any given time. Its value begins at the time of the initial CREATE/REPLACE + // TABLE statement with DEFAULT column definition(s), if any. It then changes whenever an ALTER + // TABLE statement SETs the DEFAULT. The intent is for this "current default" to be used by + // UPDATE, INSERT and MERGE, which evaluate each default expression for each row. + val CURRENT_DEFAULT_COLUMN_METADATA_KEY = "CURRENT_DEFAULT" + + // This column metadata represents the default value for all existing rows in a table after a + // column has been added. This value is determined at time of CREATE TABLE, REPLACE TABLE, or + // ALTER TABLE ADD COLUMN, and never changes thereafter. The intent is for this "exist default" to + // be used by any scan when the columns in the source row are missing data. For example, consider + // the following sequence: + // CREATE TABLE t (c1 INT) + // INSERT INTO t VALUES (42) + // ALTER TABLE t ADD COLUMNS (c2 INT DEFAULT 43) + // SELECT c1, c2 FROM t + // In this case, the final query is expected to return 42, 43. 
The ALTER TABLE ADD COLUMNS command + // executed after there was already data in the table, so in order to enforce this invariant, we + // need either (1) an expensive backfill of value 43 at column c2 into all previous rows, or (2) + // indicate to each data source that selected columns missing data are to generate the + // corresponding DEFAULT value instead. We choose option (2) for efficiency, and represent this + // value as the text representation of a folded constant in the "EXISTS_DEFAULT" column metadata. + val EXISTS_DEFAULT_COLUMN_METADATA_KEY = "EXISTS_DEFAULT" +} + +object ResolveDefaultColumnsUtils extends ResolveDefaultColumnsUtils diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala new file mode 100644 index 0000000000000..c318f208255f1 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.util + +import java.lang.{Long => JLong} +import java.nio.CharBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode + +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} + +trait SparkParserUtils { + val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r + val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r + val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r + val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r + + /** Unescape backslash-escaped string enclosed by quotes. */ + def unescapeSQLString(b: String): String = { + val sb = new StringBuilder(b.length()) + + def appendEscapedChar(n: Char): Unit = { + n match { + case '0' => sb.append('\u0000') + case '\'' => sb.append('\'') + case '"' => sb.append('\"') + case 'b' => sb.append('\b') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'Z' => sb.append('\u001A') + case '\\' => sb.append('\\') + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%' => sb.append("\\%") + case '_' => sb.append("\\_") + case _ => sb.append(n) + } + } + + if (b.startsWith("r") || b.startsWith("R")) { + b.substring(2, b.length - 1) + } else { + // Skip the first and last quotations enclosing the string literal. + val charBuffer = CharBuffer.wrap(b, 1, b.length - 1) + + while (charBuffer.remaining() > 0) { + charBuffer match { + case U16_CHAR_PATTERN(cp) => + // \u0000 style 16-bit unicode character literals. + sb.append(Integer.parseInt(cp, 16).toChar) + charBuffer.position(charBuffer.position() + 6) + case U32_CHAR_PATTERN(cp) => + // \U00000000 style 32-bit unicode character literals. + // Use Long to treat codePoint as unsigned in the range of 32-bit. 
+ val codePoint = JLong.parseLong(cp, 16) + if (codePoint < 0x10000) { + sb.append((codePoint & 0xFFFF).toChar) + } else { + val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xD800 + val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xDC00 + sb.append(highSurrogate.toChar) + sb.append(lowSurrogate.toChar) + } + charBuffer.position(charBuffer.position() + 10) + case OCTAL_CHAR_PATTERN(cp) => + // \000 style character literals. + sb.append(Integer.parseInt(cp, 8).toChar) + charBuffer.position(charBuffer.position() + 4) + case ESCAPED_CHAR_PATTERN(c) => + // escaped character literals. + appendEscapedChar(c.charAt(0)) + charBuffer.position(charBuffer.position() + 2) + case _ => + // non-escaped character literals. + sb.append(charBuffer.get()) + } + } + sb.toString() + } + } + + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) + + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) + + /** Get the origin (line and position) of the token. */ + def position(token: Token): Origin = { + val opt = Option(token) + Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine)) + } + + /** + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. 
+ */ + def withOrigin[T](ctx: ParserRuleContext, sqlText: Option[String] = None)(f: => T): T = { + val current = CurrentOrigin.get + val text = sqlText.orElse(current.sqlText) + if (text.isEmpty) { + CurrentOrigin.set(position(ctx.getStart)) + } else { + CurrentOrigin.set(positionAndText(ctx.getStart, ctx.getStop, text.get, + current.objectType, current.objectName)) + } + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + def positionAndText( + startToken: Token, + stopToken: Token, + sqlText: String, + objectType: Option[String], + objectName: Option[String]): Origin = { + val startOpt = Option(startToken) + val stopOpt = Option(stopToken) + Origin( + line = startOpt.map(_.getLine), + startPosition = startOpt.map(_.getCharPositionInLine), + startIndex = startOpt.map(_.getStartIndex), + stopIndex = stopOpt.map(_.getStopIndex), + sqlText = Some(sqlText), + objectType = objectType, + objectName = objectName) + } + + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(0, stream.size() - 1)) + } +} + +object SparkParserUtils extends SparkParserUtils diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index f39b5b13456ad..69156e1165877 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -16,8 +16,10 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.{SparkArithmeticException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} import 
org.apache.spark.sql.catalyst.util.QuotingUtils +import org.apache.spark.sql.types.{DataType, Decimal, StringType} import org.apache.spark.unsafe.types.UTF8String /** @@ -146,4 +148,123 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { s" Set the config ${sqlConf}" + " to \"true\" to allow it.") } + + def attributeNameSyntaxError(name: String): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_1049", + messageParameters = Map("name" -> name), + cause = null) + } + + def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { + new SparkException( + errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", + messageParameters = Map( + "left" -> toSQLType(left), + "right" -> toSQLType(right)), + cause = null) + } + + def cannotMergeDecimalTypesWithIncompatibleScaleError( + leftScale: Int, rightScale: Int): Throwable = { + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2124", + messageParameters = Map( + "leftScale" -> leftScale.toString(), + "rightScale" -> rightScale.toString()), + cause = null) + } + + def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = { + new SparkIllegalArgumentException( + errorClass = "UNSUPPORTED_DATATYPE", + messageParameters = Map("typeName" -> (dataType + failure))) + } + + def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = { + new SparkException( + errorClass = "INVALID_FIELD_NAME", + messageParameters = Map( + "fieldName" -> toSQLId(fieldName), + "path" -> toSQLId(path)), + cause = null, + context = context.getQueryContext) + } + + def unscaledValueTooLargeForPrecisionError( + value: Decimal, + decimalPrecision: Int, + decimalScale: Int, + context: SQLQueryContext = null): ArithmeticException = { + new SparkArithmeticException( + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + messageParameters = Map( + "value" -> value.toPlainString, + "precision" -> decimalPrecision.toString, + "scale" -> 
decimalScale.toString, + "config" -> toSQLConf("spark.sql.ansi.enabled")), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def cannotChangeDecimalPrecisionError( + value: Decimal, + decimalPrecision: Int, + decimalScale: Int, + context: SQLQueryContext = null): ArithmeticException = { + new SparkArithmeticException( + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + messageParameters = Map( + "value" -> value.toPlainString, + "precision" -> decimalPrecision.toString, + "scale" -> decimalScale.toString, + "config" -> toSQLConf("spark.sql.ansi.enabled")), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def invalidInputInCastToNumberError( + to: DataType, + s: UTF8String, + context: SQLQueryContext): SparkNumberFormatException = { + val convertedValueStr = "'" + s.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" + new SparkNumberFormatException( + errorClass = "CAST_INVALID_INPUT", + messageParameters = Map( + "expression" -> convertedValueStr, + "sourceType" -> toSQLType(StringType), + "targetType" -> toSQLType(to), + "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")), + context = getQueryContext(context), + summary = getSummary(context)) + } + + def ambiguousColumnOrFieldError( + name: Seq[String], numMatches: Int, context: Origin): Throwable = { + new SparkException( + errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", + messageParameters = Map( + "name" -> toSQLId(name), + "n" -> numMatches.toString), + cause = null, + context = context.getQueryContext) + } + + def castingCauseOverflowError(t: String, from: DataType, to: DataType): ArithmeticException = { + new SparkArithmeticException( + errorClass = "CAST_OVERFLOW", + messageParameters = Map( + "value" -> t, + "sourceType" -> toSQLType(from), + "targetType" -> toSQLType(to), + "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")), + context = Array.empty, + summary = "") + } + + def failedParsingStructTypeError(raw: String): SparkRuntimeException = { + 
new SparkRuntimeException( + errorClass = "FAILED_PARSE_STRUCT_TYPE", + messageParameters = Map("raw" -> s"'$raw'")) + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala index dc95e76a820bb..4a8847959c289 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala @@ -16,9 +16,18 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.sql.catalyst.util.SparkStringUtils +import java.util.Locale + +import org.apache.spark.QueryContext +import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils, SparkStringUtils} +import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection} private[sql] trait DataTypeErrorsBase { + def toSQLId(parts: String): String = { + toSQLId(AttributeNameParser.parseAttributeName(parts)) + } + def toSQLId(parts: Seq[String]): String = { val cleaned = parts match { case Seq("__auto_generated_subquery_name", rest @ _*) if rest != Nil => rest @@ -26,4 +35,42 @@ private[sql] trait DataTypeErrorsBase { } cleaned.map(SparkStringUtils.quoteIdentifier).mkString(".") } + + def toSQLStmt(text: String): String = { + text.toUpperCase(Locale.ROOT) + } + + def toSQLConf(conf: String): String = { + QuotingUtils.toSQLConf(conf) + } + + def toSQLType(text: String): String = { + quoteByDefault(text.toUpperCase(Locale.ROOT)) + } + + def toSQLType(t: AbstractDataType): String = t match { + case TypeCollection(types) => types.map(toSQLType).mkString("(", " or ", ")") + case dt: DataType => quoteByDefault(dt.sql) + case at => quoteByDefault(at.simpleString.toUpperCase(Locale.ROOT)) + } + + def dataTypeToSQLValue(value: String): String = { + if (value == null) { + "NULL" + } else { + "'" + value.replace("\\", "\\\\").replace("'", "\\'") 
+ "'" + } + } + + protected def quoteByDefault(elem: String): String = { + "\"" + elem + "\"" + } + + def getSummary(sqlContext: SQLQueryContext): String = { + if (sqlContext == null) "" else sqlContext.summary + } + + def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { + if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala rename to sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index ced27626be777..2bd385787dd6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -25,13 +25,12 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.StringType /** * Object for grouping all error messages of the query parsing. * Currently it includes all ParseException. 
*/ -private[sql] object QueryParsingErrors extends QueryErrorsBase { +private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def invalidInsertIntoError(ctx: InsertIntoContext): Throwable = { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0001", ctx) @@ -185,7 +184,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { def invalidEscapeStringError(invalidEscape: String, ctx: PredicateContext): Throwable = { new ParseException( errorClass = "INVALID_ESC", - messageParameters = Map("invalidEscape" -> toSQLValue(invalidEscape, StringType)), + messageParameters = Map("invalidEscape" -> dataTypeToSQLValue(invalidEscape)), ctx) } @@ -209,7 +208,7 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { errorClass = "INVALID_TYPED_LITERAL", messageParameters = Map( "valueType" -> toSQLType(valueType), - "value" -> toSQLValue(value, StringType) + "value" -> dataTypeToSQLValue(value) ), ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/CharType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index c9346b9bc1431..632091a9a8c07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -30,14 +30,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkThrowable import 
org.apache.spark.annotation.Stable import org.apache.spark.sql.SqlApiConf -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis import org.apache.spark.sql.catalyst.parser.DataTypeParser import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer} import org.apache.spark.sql.catalyst.util.StringConcat import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.DayTimeIntervalType._ import org.apache.spark.sql.types.YearMonthIntervalType._ -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkClassUtils /** * The base type of all Spark SQL data types. @@ -230,7 +230,7 @@ object DataType { ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => - Utils.classForName[UserDefinedType[_]](udtClass).getConstructor().newInstance() + SparkClassUtils.classForName[UserDefinedType[_]](udtClass).getConstructor().newInstance() // Python UDT case JSortedObject( @@ -381,7 +381,7 @@ object DataType { def equalsStructurallyByName( from: DataType, to: DataType, - resolver: Resolver): Boolean = { + resolver: SqlApiAnalysis.Resolver): Boolean = { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurallyByName(left.elementType, right.elementType, resolver) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala rename to 
sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f1529285294e2..40186dab822f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,10 +22,9 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.SqlApiConf import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.DataTypeErrors -import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String /** @@ -83,7 +82,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def set(unscaled: Long, precision: Int, scale: Int): Decimal = { if (setOrNull(unscaled, precision, scale) == null) { - throw QueryExecutionErrors.unscaledValueTooLargeForPrecisionError(this, precision, scale) + throw DataTypeErrors.unscaledValueTooLargeForPrecisionError(this, precision, scale) } this } @@ -143,7 +142,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { // the scale is 2. The expected precision should be 2. 
this._precision = decimal.scale this._scale = decimal.scale - } else if (decimal.scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) { + } else if (decimal.scale < 0 && !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { this._precision = decimal.precision - decimal.scale this._scale = 0 // set scale to 0 to correct unscaled value @@ -278,6 +277,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { private[sql] def roundToInt(): Int = roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue) (_.toInt) (_.toInt) + private def toSqlValue: String = this + "BD" + private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) (f1: Long => T) (f2: Double => T): T = { if (decimalVal.eq(null)) { @@ -285,16 +286,16 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == numericVal) { numericVal } else { - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), integralType) + throw DataTypeErrors.castingCauseOverflowError( + toSqlValue, DecimalType(this.precision, this.scale), integralType) } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= maxValue && Math.ceil(doubleVal) >= minValue) { f2(doubleVal) } else { - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), integralType) + throw DataTypeErrors.castingCauseOverflowError( + toSqlValue, DecimalType(this.precision, this.scale), integralType) } } } @@ -314,8 +315,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { decimalVal.bigDecimal.toBigInteger.longValueExact() } catch { case _: ArithmeticException => - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), LongType) + throw DataTypeErrors.castingCauseOverflowError( + toSqlValue, DecimalType(this.precision, this.scale), LongType) } } } @@ -348,7 +349,7 @@ final class Decimal extends 
Ordered[Decimal] with Serializable { if (nullOnOverflow) { null } else { - throw QueryExecutionErrors.cannotChangeDecimalPrecisionError( + throw DataTypeErrors.cannotChangeDecimalPrecisionError( this, precision, scale, context) } } @@ -602,7 +603,7 @@ object Decimal { // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. // For example: Decimal("6.0790316E+25569151") if (numDigitsInIntegralPart(bigDecimal) > DecimalType.MAX_PRECISION && - !SQLConf.get.allowNegativeScaleOfDecimalEnabled) { + !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { null } else { Decimal(bigDecimal) @@ -622,14 +623,14 @@ object Decimal { // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. // For example: Decimal("6.0790316E+25569151") if (numDigitsInIntegralPart(bigDecimal) > DecimalType.MAX_PRECISION && - !SQLConf.get.allowNegativeScaleOfDecimalEnabled) { + !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { throw DataTypeErrors.outOfDecimalTypeRangeError(str) } else { Decimal(bigDecimal) } } catch { case _: NumberFormatException => - throw QueryExecutionErrors.invalidInputInCastToNumberError(to, str, context) + throw DataTypeErrors.invalidInputInCastToNumberError(to, str, context) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index e460746239201..bc06888e4cf45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -22,8 +22,8 @@ import java.util.Locale import scala.annotation.tailrec import org.apache.spark.annotation.Stable +import org.apache.spark.sql.SqlApiConf import org.apache.spark.sql.errors.DataTypeErrors 
-import org.apache.spark.sql.internal.SQLConf /** * The data type representing `java.math.BigDecimal` values. @@ -147,7 +147,7 @@ object DecimalType extends AbstractDataType { } private[sql] def checkNegativeScale(scale: Int): Unit = { - if (scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) { + if (scale < 0 && !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) { throw DataTypeErrors.negativeScaleNotAllowedError(scale) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala similarity index 100% rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ObjectType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/ObjectType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala similarity index 91% rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala index dd267ed763e70..ca15d23b601ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -21,9 +21,8 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded} -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.catalyst.util.StringConcat +import org.apache.spark.sql.catalyst.util.{QuotingUtils, StringConcat} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY} import org.apache.spark.util.SparkSchemaUtils /** @@ -145,7 +144,7 @@ case class StructField( .getOrElse("") private def getDDLComment = getComment() - .map(escapeSingleQuotedString) + .map(QuotingUtils.escapeSingleQuotedString) .map(" COMMENT '" + _ + "'") .getOrElse("") @@ -153,7 +152,7 @@ case class StructField( * Returns a string containing a schema in SQL format. For example the following value: * `StructField("eventId", IntegerType)` will be converted to `eventId`: INT. */ - private[sql] def sql = s"${quoteIfNeeded(name)}: ${dataType.sql}$getDDLComment" + private[sql] def sql = s"${QuotingUtils.quoteIfNeeded(name)}: ${dataType.sql}$getDDLComment" /** * Returns a string containing a schema in DDL format. 
For example, the following value: @@ -163,6 +162,6 @@ case class StructField( */ def toDDL: String = { val nullString = if (nullable) "" else " NOT NULL" - s"${quoteIfNeeded(name)} ${dataType.sql}${nullString}$getDDLDefault$getDDLComment" + s"${QuotingUtils.quoteIfNeeded(name)} ${dataType.sql}${nullString}$getDDLDefault$getDDLComment" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala similarity index 94% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index df903c57887fa..9eabc6a70bc8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -25,13 +25,11 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable import org.apache.spark.sql.SqlApiConf -import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis import org.apache.spark.sql.catalyst.parser.{DataTypeParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.catalyst.util.{SparkStringUtils, StringConcat} -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.util.collection.Utils +import org.apache.spark.sql.catalyst.util.{SparkCollectionUtils, SparkStringUtils, StringConcat} +import org.apache.spark.sql.errors.DataTypeErrors /** * A [[StructType]] object can be constructed by @@ -116,7 +114,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap - private lazy val 
nameToIndex: Map[String, Int] = Utils.toMapWithIndex(fieldNames) + private lazy val nameToIndex: Map[String, Int] = SparkCollectionUtils.toMapWithIndex(fieldNames) override def equals(that: Any): Boolean = { that match { @@ -322,7 +320,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def findNestedField( fieldNames: Seq[String], includeCollections: Boolean = false, - resolver: Resolver = _ == _, + resolver: SqlApiAnalysis.Resolver = _ == _, context: Origin = Origin()): Option[(Seq[String], StructField)] = { def findFieldInStruct( @@ -333,7 +331,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val searchName = searchPath.head val found = struct.fields.filter(f => resolver(searchName, f.name)) if (found.length > 1) { - throw QueryCompilationErrors.ambiguousColumnOrFieldError(fieldNames, found.length, context) + throw DataTypeErrors.ambiguousColumnOrFieldError(fieldNames, found.length, context) } else if (found.isEmpty) { None } else { @@ -358,7 +356,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru findFieldInStruct(s, searchPath, currentPath) case _ if !includeCollections => - throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context) + throw DataTypeErrors.invalidFieldName(fieldNames, currentPath, context) case (Seq("key", rest @ _*), MapType(keyType, _, _)) => findField(StructField("key", keyType, nullable = false), rest, currentPath) @@ -370,7 +368,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru findField(StructField("element", elementType, isNullable), rest, currentPath) case _ => - throw QueryCompilationErrors.invalidFieldName(fieldNames, currentPath, context) + throw DataTypeErrors.invalidFieldName(fieldNames, currentPath, context) } } } @@ -512,7 +510,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { 
Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match { case t: StructType => t - case _ => throw QueryExecutionErrors.failedParsingStructTypeError(raw) + case _ => throw DataTypeErrors.failedParsingStructTypeError(raw) } } @@ -531,9 +529,6 @@ object StructType extends AbstractDataType { StructType(fields.asScala.toArray) } - private[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) - private[sql] def removeMetadata(key: String, dt: DataType): DataType = dt match { case StructType(fields) => @@ -583,7 +578,7 @@ object StructType extends AbstractDataType { nullable = leftNullable || rightNullable) } catch { case NonFatal(e) => - throw QueryExecutionErrors.cannotMergeIncompatibleDataTypesError( + throw DataTypeErrors.cannotMergeIncompatibleDataTypesError( leftType, rightType) } } @@ -628,7 +623,7 @@ object StructType extends AbstractDataType { if (leftScale == rightScale) { DecimalType(leftPrecision.max(rightPrecision), leftScale) } else { - throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatibleScaleError( + throw DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError( leftScale, rightScale) } @@ -645,7 +640,7 @@ object StructType extends AbstractDataType { leftType case _ => - throw QueryExecutionErrors.cannotMergeIncompatibleDataTypesError(left, right) + throw DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right) } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { @@ -663,7 +658,7 @@ object StructType extends AbstractDataType { def findMissingFields( source: StructType, target: StructType, - resolver: Resolver): Option[StructType] = { + resolver: SqlApiAnalysis.Resolver): Option[StructType] = { def bothStructType(dt1: DataType, dt2: DataType): Boolean = dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType] diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala similarity index 95% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala index 293687a4d61db..42c8c783e54c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.sql.errors.DataTypeErrors -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkClassUtils /** * This object keeps the mappings between user classes and their User Defined Types (UDTs). 
@@ -73,8 +73,8 @@ object UDTRegistration extends Serializable with Logging { */ def getUDTFor(userClass: String): Option[Class[_]] = { udtMap.get(userClass).map { udtClassName => - if (Utils.classIsLoadable(udtClassName)) { - val udtClass = Utils.classForName(udtClassName) + if (SparkClassUtils.classIsLoadable(udtClassName)) { + val udtClass = SparkClassUtils.classForName(udtClassName) if (classOf[UserDefinedType[_]].isAssignableFrom(udtClass)) { udtClass } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VarcharType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/VarcharType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala rename to sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9dbc8d625d079..12b6da94c8d52 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -39,10 +39,6 @@ org.scala-lang scala-reflect - - org.scala-lang.modules - scala-parser-combinators_${scala.binary.version} - org.apache.spark @@ -111,10 +107,6 @@ org.codehaus.janino commons-compiler - - org.antlr - antlr4-runtime - commons-codec commons-codec diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala index 9d120e6fc2b06..916c88c376610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.ProjectingInternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, ExprId, Literal, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.logical.{Assignment, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.RowDeltaUtils._ import org.apache.spark.sql.catalyst.util.WriteDeltaProjections import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations @@ -31,7 +32,6 @@ import org.apache.spark.sql.connector.write.{RowLevelOperation, RowLevelOperatio import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap trait RewriteRowLevelCommand extends Rule[LogicalPlan] { @@ -162,7 +162,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { attrs: Seq[Attribute]): ProjectingInternalRow = { val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name)) - val schema = StructType.fromAttributes(attrs) + val schema = DataTypeUtils.fromAttributes(attrs) ProjectingInternalRow(schema, colOrdinals) } @@ -176,7 +176,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] { val originalValueIndex = findColOrdinal(plan, 
ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name) if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name) } - val schema = StructType.fromAttributes(rowIdAttrs) + val schema = DataTypeUtils.fromAttributes(rowIdAttrs) ProjectingInternalRow(schema, colOrdinals) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index fc683e98b9795..c0689eb121679 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -35,7 +35,7 @@ package object analysis { * Resolver should return true if the first string refers to the same entity as the second string. * For example, by using case insensitive equality. */ - type Resolver = (String, String) => Boolean + type Resolver = SqlApiAnalysis.Resolver val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b) val caseSensitiveResolution = (a: String, b: String) => a == b diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 0141e66c47990..1c72ec0d69980 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -237,7 +237,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un } } -object UnresolvedAttribute { +object UnresolvedAttribute extends AttributeNameParser { /** * Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.'). 
*/ @@ -257,51 +257,6 @@ object UnresolvedAttribute { */ def quotedString(name: String): UnresolvedAttribute = new UnresolvedAttribute(parseAttributeName(name)) - - /** - * Used to split attribute name by dot with backticks rule. - * Backticks must appear in pairs, and the quoted string must be a complete name part, - * which means `ab..c`e.f is not allowed. - * We can use backtick only inside quoted name parts. - */ - def parseAttributeName(name: String): Seq[String] = { - def e = QueryCompilationErrors.attributeNameSyntaxError(name) - val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] - val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] - var inBacktick = false - var i = 0 - while (i < name.length) { - val char = name(i) - if (inBacktick) { - if (char == '`') { - if (i + 1 < name.length && name(i + 1) == '`') { - tmp += '`' - i += 1 - } else { - inBacktick = false - if (i + 1 < name.length && name(i + 1) != '.') throw e - } - } else { - tmp += char - } - } else { - if (char == '`') { - if (tmp.nonEmpty) throw e - inBacktick = true - } else if (char == '.') { - if (name(i - 1) == '.' 
|| i == name.length - 1) throw e - nameParts += tmp.mkString - tmp.clear() - } else { - tmp += char - } - } - i += 1 - } - if (inBacktick) throw e - nameParts += tmp.mkString - nameParts.toSeq - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 27d05f3bac756..08fda36390593 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -353,7 +353,7 @@ package object dsl { def struct(structType: StructType): AttributeReference = attrRef(structType) def struct(attrs: AttributeReference*): AttributeReference = - struct(StructType.fromAttributes(attrs)) + struct(DataTypeUtils.fromAttributes(attrs)) /** Creates a new AttributeReference of object type */ def obj(cls: Class[_]): AttributeReference = attrRef(ObjectType(cls)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 9b0493f3e68a4..bffec270b6f21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.HyperLogLogPlusPlusHelper import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -87,7 +88,7 @@ case class HyperLogLogPlusPlus( override def dataType: DataType = 
LongType - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) override def defaultResult: Option[Literal] = Option(Literal.create(0L, dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index b90e46e1545d8..001add6d48207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.{HashMap, TreeMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ @@ -149,7 +150,7 @@ case class PivotFirst( AttributeReference(Option(kv._1).getOrElse("null").toString, valueDataType)() } - override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override val aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index e29714d3a9a1d..bb78aa7dad2f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala 
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, TreePattern} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -419,7 +420,7 @@ abstract class DeclarativeAggregate val evaluateExpression: Expression /** An expression-based aggregate's bufferSchema is derived from bufferAttributes. */ - final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + final override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) lazy val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) @@ -606,7 +607,7 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { final override lazy val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) - final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + final override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) /** * In-place replaces the aggregation buffer object stored at buffer's index diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala index bd8f4efa059bf..2adfddb9383e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala @@ -22,6 +22,7 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, LongType, StructType} @@ -175,7 +176,7 @@ case class BitmapConstructAgg(child: Expression, override def nullable: Boolean = false - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) // The aggregation buffer is a fixed size binary. private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)() @@ -268,7 +269,7 @@ case class BitmapOrAgg(child: Expression, override def nullable: Boolean = false - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes) // The aggregation buffer is a fixed size binary. private val bitmapAttr = AttributeReference("bitmap", BinaryType, false)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 80ec5e40bec1c..6655a09402d09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -225,7 +225,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { if (conf.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) { // Prunes nested fields in serializers. 
val prunedSchema = SchemaPruning.pruneSchema( - StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields) + DataTypeUtils.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields) val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) => pruneSerializer(serializer, prunedSchema(idx).dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index bb40aca9129d5..9d00edfe7c41c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,34 +16,20 @@ */ package org.apache.spark.sql.catalyst.parser -import java.lang.{Long => JLong} -import java.nio.CharBuffer import java.util import java.util.Locale import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval -import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode, TerminalNodeImpl} +import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl} -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors /** * A collection of utility methods for use during the parsing process. */ -object ParserUtils { - - val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r - val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r - val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r - val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r - - /** Get the command which created the token. 
*/ - def command(ctx: ParserRuleContext): String = { - val stream = ctx.getStart.getInputStream - stream.getText(Interval.of(0, stream.size() - 1)) - } - +object ParserUtils extends SparkParserUtils { def operationNotAllowed(message: String, ctx: ParserRuleContext): Nothing = { throw QueryParsingErrors.operationNotAllowedError(message, ctx) } @@ -89,12 +75,6 @@ object ParserUtils { start.getInputStream.getText(interval) } - /** Convert a string token into a string. */ - def string(token: Token): String = unescapeSQLString(token.getText) - - /** Convert a string node into a string. */ - def string(node: TerminalNode): String = unescapeSQLString(node.getText) - /** Convert a string node into a string without unescaping. */ def stringWithoutUnescape(node: Token): String = { // STRING parser rule forces that the input always has quotes at the starting and ending. @@ -106,30 +86,6 @@ object ParserUtils { Option(value).toSeq.map(x => key -> string(x)) } - /** Get the origin (line and position) of the token. */ - def position(token: Token): Origin = { - val opt = Option(token) - Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine)) - } - - def positionAndText( - startToken: Token, - stopToken: Token, - sqlText: String, - objectType: Option[String], - objectName: Option[String]): Origin = { - val startOpt = Option(startToken) - val stopOpt = Option(stopToken) - Origin( - line = startOpt.map(_.getLine), - startPosition = startOpt.map(_.getCharPositionInLine), - startIndex = startOpt.map(_.getStartIndex), - stopIndex = stopOpt.map(_.getStopIndex), - sqlText = Some(sqlText), - objectType = objectType, - objectName = objectName) - } - /** Validate the condition. If it doesn't throw a parse exception. */ def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { if (!f) { @@ -140,91 +96,6 @@ object ParserUtils { } } - /** - * Register the origin of the context. Any TreeNode created in the closure will be assigned the - * registered origin. 
This method restores the previously set origin after completion of the - * closure. - */ - def withOrigin[T](ctx: ParserRuleContext, sqlText: Option[String] = None)(f: => T): T = { - val current = CurrentOrigin.get - val text = sqlText.orElse(current.sqlText) - if (text.isEmpty) { - CurrentOrigin.set(position(ctx.getStart)) - } else { - CurrentOrigin.set(positionAndText(ctx.getStart, ctx.getStop, text.get, - current.objectType, current.objectName)) - } - try { - f - } finally { - CurrentOrigin.set(current) - } - } - - /** Unescape backslash-escaped string enclosed by quotes. */ - def unescapeSQLString(b: String): String = { - val sb = new StringBuilder(b.length()) - - def appendEscapedChar(n: Char): Unit = { - n match { - case '0' => sb.append('\u0000') - case '\'' => sb.append('\'') - case '"' => sb.append('\"') - case 'b' => sb.append('\b') - case 'n' => sb.append('\n') - case 'r' => sb.append('\r') - case 't' => sb.append('\t') - case 'Z' => sb.append('\u001A') - case '\\' => sb.append('\\') - // The following 2 lines are exactly what MySQL does TODO: why do we do this? - case '%' => sb.append("\\%") - case '_' => sb.append("\\_") - case _ => sb.append(n) - } - } - - if (b.startsWith("r") || b.startsWith("R")) { - b.substring(2, b.length - 1) - } else { - // Skip the first and last quotations enclosing the string literal. - val charBuffer = CharBuffer.wrap(b, 1, b.length - 1) - - while (charBuffer.remaining() > 0) { - charBuffer match { - case U16_CHAR_PATTERN(cp) => - // \u0000 style 16-bit unicode character literals. - sb.append(Integer.parseInt(cp, 16).toChar) - charBuffer.position(charBuffer.position() + 6) - case U32_CHAR_PATTERN(cp) => - // \U00000000 style 32-bit unicode character literals. - // Use Long to treat codePoint as unsigned in the range of 32-bit. 
- val codePoint = JLong.parseLong(cp, 16) - if (codePoint < 0x10000) { - sb.append((codePoint & 0xFFFF).toChar) - } else { - val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xD800 - val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xDC00 - sb.append(highSurrogate.toChar) - sb.append(lowSurrogate.toChar) - } - charBuffer.position(charBuffer.position() + 10) - case OCTAL_CHAR_PATTERN(cp) => - // \000 style character literals. - sb.append(Integer.parseInt(cp, 8).toChar) - charBuffer.position(charBuffer.position() + 4) - case ESCAPED_CHAR_PATTERN(c) => - // escaped character literals. - appendEscapedChar(c.charAt(0)) - charBuffer.position(charBuffer.position() + 2) - case _ => - // non-escaped character literals. - sb.append(charBuffer.get()) - } - } - sb.toString() - } - } - /** the column name pattern in quoted regex without qualifier */ val escapedIdentifier = "`((?s).+)`".r diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 59ce4e8558018..aee4790eb42aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.rules.UnknownRuleId import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin, TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.trees.TreePattern.{OUTER_REFERENCE, PLAN_EXPRESSION} import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.collection.BitSet @@ -406,7 +407,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] } } - lazy val schema: StructType = StructType.fromAttributes(output) + lazy val schema: StructType = 
DataTypeUtils.fromAttributes(output) /** Returns the output schema in the tree format. */ def schemaString: String = schema.treeString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 78c1087a1b069..8b9d8c91815fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -39,13 +39,13 @@ object LocalRelation { } def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { - val schema = StructType.fromAttributes(output) + val schema = DataTypeUtils.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { - val schema = StructType.fromAttributes(output) + val schema = DataTypeUtils.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f8ba042009b2e..4bb830662a33f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1183,7 +1183,7 @@ object Aggregate { } def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { - val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + val aggregationBufferSchema = 
DataTypeUtils.fromAttributes(aggregateBufferAttributes) isAggregateBufferMutable(aggregationBufferSchema) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 75dab55dccfda..d79d55bc9646d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftAnti, LeftSemi, ReferenceAllColumns} import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -193,7 +194,7 @@ case class MapPartitionsInRWithArrow( override lazy val references: AttributeSet = child.outputSet override protected def stringArgs: Iterator[Any] = Iterator( - inputSchema, StructType.fromAttributes(output), child) + inputSchema, DataTypeUtils.fromAttributes(output), child) override val producedAttributes = AttributeSet(output) @@ -658,7 +659,7 @@ case class FlatMapGroupsInRWithArrow( override lazy val references: AttributeSet = child.outputSet override protected def stringArgs: Iterator[Any] = Iterator( - inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child) + inputSchema, DataTypeUtils.fromAttributes(output), keyDeserializer, groupingAttributes, child) override val producedAttributes = AttributeSet(output) @@ -678,7 +679,7 @@ object CoGroup { rightOrder: Seq[SortOrder], left: LogicalPlan, right: LogicalPlan): LogicalPlan = { - require(StructType.fromAttributes(leftGroup) == 
StructType.fromAttributes(rightGroup)) + require(DataTypeUtils.fromAttributes(leftGroup) == DataTypeUtils.fromAttributes(rightGroup)) val cogrouped = CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala index feff781a4038d..1fbe0a41678ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, Literal} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy @@ -200,6 +200,9 @@ object DataTypeUtils { schema.map(toAttribute) } + def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + /** * Convert a literal to a DecimalType. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 6489fb9aaafd5..50ff3eeab0c16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -41,29 +41,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * This object contains fields to help process DEFAULT columns. 
*/ -object ResolveDefaultColumns extends QueryErrorsBase { - // This column metadata indicates the default value associated with a particular table column that - // is in effect at any given time. Its value begins at the time of the initial CREATE/REPLACE - // TABLE statement with DEFAULT column definition(s), if any. It then changes whenever an ALTER - // TABLE statement SETs the DEFAULT. The intent is for this "current default" to be used by - // UPDATE, INSERT and MERGE, which evaluate each default expression for each row. - val CURRENT_DEFAULT_COLUMN_METADATA_KEY = "CURRENT_DEFAULT" - // This column metadata represents the default value for all existing rows in a table after a - // column has been added. This value is determined at time of CREATE TABLE, REPLACE TABLE, or - // ALTER TABLE ADD COLUMN, and never changes thereafter. The intent is for this "exist default" to - // be used by any scan when the columns in the source row are missing data. For example, consider - // the following sequence: - // CREATE TABLE t (c1 INT) - // INSERT INTO t VALUES (42) - // ALTER TABLE t ADD COLUMNS (c2 INT DEFAULT 43) - // SELECT c1, c2 FROM t - // In this case, the final query is expected to return 42, 43. The ALTER TABLE ADD COLUMNS command - // executed after there was already data in the table, so in order to enforce this invariant, we - // need either (1) an expensive backfill of value 43 at column c2 into all previous rows, or (2) - // indicate to each data source that selected columns missing data are to generate the - // corresponding DEFAULT value instead. We choose option (2) for efficiency, and represent this - // value as the text representation of a folded constant in the "EXISTS_DEFAULT" column metadata. 
- val EXISTS_DEFAULT_COLUMN_METADATA_KEY = "EXISTS_DEFAULT" +object ResolveDefaultColumns extends QueryErrorsBase with ResolveDefaultColumnsUtils { // Name of attributes representing explicit references to the value stored in the above // CURRENT_DEFAULT_COLUMN_METADATA. val CURRENT_DEFAULT_COLUMN_NAME = "DEFAULT" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 35d9b256800f6..8721bb809fc19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -131,24 +131,13 @@ package object util extends Logging { } def quoteIfNeeded(part: String): String = { - if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) { - part - } else { - s"`${part.replace("`", "``")}`" - } + QuotingUtils.quoteIfNeeded(part) } def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql def escapeSingleQuotedString(str: String): String = { - val builder = new StringBuilder - - str.foreach { - case '\'' => builder ++= s"\\\'" - case ch => builder += ch - } - - builder.toString() + QuotingUtils.escapeSingleQuotedString(str) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 346f25580aaeb..6fd20b7c34a16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -794,9 +794,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { } def attributeNameSyntaxError(name: String): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1049", - messageParameters = Map("name" -> name)) + DataTypeErrors.attributeNameSyntaxError(name) } def 
starExpandDataTypeNotSupportedError(attributes: Seq[String]): Throwable = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala index 0a44ededebd4d..db256fbee8785 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryErrorsBase.scala @@ -17,14 +17,9 @@ package org.apache.spark.sql.errors -import java.util.Locale - -import org.apache.spark.QueryContext -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{toPrettySQL, QuotingUtils} -import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, FloatType, TypeCollection} +import org.apache.spark.sql.types.{DataType, DoubleType, FloatType} /** * The trait exposes util methods for preparing error messages such as quoting of error elements. @@ -47,43 +42,6 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType, Float * For example: "earnings + 1". 
*/ private[sql] trait QueryErrorsBase extends DataTypeErrorsBase { - // Converts an error class parameter to its SQL representation - def toSQLValue(v: Any, t: DataType): String = Literal.create(v, t) match { - case Literal(null, _) => "NULL" - case Literal(v: Float, FloatType) => - if (v.isNaN) "NaN" - else if (v.isPosInfinity) "Infinity" - else if (v.isNegInfinity) "-Infinity" - else v.toString - case l @ Literal(v: Double, DoubleType) => - if (v.isNaN) "NaN" - else if (v.isPosInfinity) "Infinity" - else if (v.isNegInfinity) "-Infinity" - else l.sql - case l => l.sql - } - - private def quoteByDefault(elem: String): String = { - "\"" + elem + "\"" - } - - def toSQLStmt(text: String): String = { - text.toUpperCase(Locale.ROOT) - } - - def toSQLId(parts: String): String = { - toSQLId(UnresolvedAttribute.parseAttributeName(parts)) - } - - def toSQLType(t: AbstractDataType): String = t match { - case TypeCollection(types) => types.map(toSQLType).mkString("(", " or ", ")") - case dt: DataType => quoteByDefault(dt.sql) - case at => quoteByDefault(at.simpleString.toUpperCase(Locale.ROOT)) - } - - def toSQLType(text: String): String = { - quoteByDefault(text.toUpperCase(Locale.ROOT)) - } def toSQLConfVal(conf: String): String = { quoteByDefault(conf) @@ -93,24 +51,28 @@ private[sql] trait QueryErrorsBase extends DataTypeErrorsBase { quoteByDefault(option) } - def toSQLConf(conf: String): String = { - QuotingUtils.toSQLConf(conf) - } - def toSQLExpr(e: Expression): String = { quoteByDefault(toPrettySQL(e)) } - def getSummary(sqlContext: SQLQueryContext): String = { - if (sqlContext == null) "" else sqlContext.summary - } - - def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { - if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) - } - def toSQLSchema(schema: String): String = { QuotingUtils.toSQLSchema(schema) } + + // Converts an error class parameter to its SQL representation + def toSQLValue(v: Any, t: DataType): 
String = Literal.create(v, t) match { + case Literal(null, _) => "NULL" + case Literal(v: Float, FloatType) => + if (v.isNaN) "NaN" + else if (v.isPosInfinity) "Infinity" + else if (v.isNegInfinity) "-Infinity" + else v.toString + case l @ Literal(v: Double, DoubleType) => + if (v.isNaN) "NaN" + else if (v.isPosInfinity) "Infinity" + else if (v.isNegInfinity) "-Infinity" + else l.sql + case l => l.sql + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 08955b3216894..433d23eaf3f15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -202,9 +202,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { } def dataTypeUnsupportedError(dataType: String, failure: String): Throwable = { - new SparkIllegalArgumentException( - errorClass = "UNSUPPORTED_DATATYPE", - messageParameters = Map("typeName" -> (dataType + failure))) + DataTypeErrors.dataTypeUnsupportedError(dataType, failure) } def failedExecuteUserDefinedFunctionError(functionName: String, inputTypes: String, @@ -1278,15 +1276,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { decimalPrecision: Int, decimalScale: Int, context: SQLQueryContext = null): ArithmeticException = { - new SparkArithmeticException( - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", - messageParameters = Map( - "value" -> value.toPlainString, - "precision" -> decimalPrecision.toString, - "scale" -> decimalScale.toString, - "config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = getQueryContext(context), - summary = getSummary(context)) + DataTypeErrors.unscaledValueTooLargeForPrecisionError( + value, decimalPrecision, decimalScale, context) } def decimalPrecisionExceedsMaxPrecisionError( @@ -1319,21 +1310,11 @@ private[sql] object 
QueryExecutionErrors extends QueryErrorsBase { def cannotMergeDecimalTypesWithIncompatibleScaleError( leftScale: Int, rightScale: Int): Throwable = { - new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2124", - messageParameters = Map( - "leftScale" -> leftScale.toString(), - "rightScale" -> rightScale.toString()), - cause = null) + DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(leftScale, rightScale) } def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { - new SparkException( - errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE", - messageParameters = Map( - "left" -> toSQLType(left), - "right" -> toSQLType(right)), - cause = null) + DataTypeErrors.cannotMergeIncompatibleDataTypesError(left, right) } def exceedMapSizeLimitError(size: Int): SparkRuntimeException = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala index 5957adcd75e44..2c8bb8a6ac92c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -31,11 +31,11 @@ trait SQLKeywordUtils extends SparkFunSuite with SQLHelper { val sqlSyntaxDefs = { val sqlBaseParserPath = - getWorkspaceFilePath("sql", "catalyst", "src", "main", "antlr4", "org", + getWorkspaceFilePath("sql", "api", "src", "main", "antlr4", "org", "apache", "spark", "sql", "catalyst", "parser", "SqlBaseParser.g4").toFile val sqlBaseLexerPath = - getWorkspaceFilePath("sql", "catalyst", "src", "main", "antlr4", "org", + getWorkspaceFilePath("sql", "api", "src", "main", "antlr4", "org", "apache", "spark", "sql", "catalyst", "parser", "SqlBaseLexer.g4").toFile (fileToString(sqlBaseParserPath) + fileToString(sqlBaseLexerPath)).split("\n") diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index fe45e02c67fac..e8d2ca1ff75de 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Expand, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructType} +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -201,7 +202,7 @@ class PropagateEmptyRelationSuite extends PlanTest { test("propagate empty streaming relation through multiple UnaryNode") { val output = Seq($"a".int) val data = Seq(Row(1)) - val schema = StructType.fromAttributes(output) + val schema = DataTypeUtils.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) val relation = LocalRelation( output, @@ -224,7 +225,7 @@ class PropagateEmptyRelationSuite extends PlanTest { test("don't propagate empty streaming relation through agg") { val output = Seq($"a".int) val data = Seq(Row(1)) - val schema = StructType.fromAttributes(output) + val schema = DataTypeUtils.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) val relation = LocalRelation( output, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 6eedd7f9b6f1b..623c6c69165d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -324,7 +324,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { check(Seq("s1", "S12"), Some(Seq("s1") -> StructField("s12", IntegerType))) caseSensitiveCheck(Seq("s1", "S12"), None) check(Seq("S1.non_exist"), None) - var e = intercept[AnalysisException] { + var e = intercept[SparkException] { check(Seq("S1", "S12", "S123"), None) } checkError( @@ -335,7 +335,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { "path" -> "`s1`.`s12`")) // ambiguous name - var e2 = intercept[AnalysisException] { + var e2 = intercept[SparkException] { check(Seq("S2", "x"), None) } checkError( @@ -345,7 +345,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { caseSensitiveCheck(Seq("s2", "x"), Some(Seq("s2") -> StructField("x", IntegerType))) // simple map type - e = intercept[AnalysisException] { + e = intercept[SparkException] { check(Seq("m1", "key"), None) } checkError( @@ -356,7 +356,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { "path" -> "`m1`")) checkCollection(Seq("m1", "key"), Some(Seq("m1") -> StructField("key", IntegerType, false))) checkCollection(Seq("M1", "value"), Some(Seq("m1") -> StructField("value", IntegerType))) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("M1", "key", "name"), None) } checkError( @@ -365,7 +365,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { parameters = Map( "fieldName" -> "`M1`.`key`.`name`", "path" -> "`m1`.`key`")) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("M1", "value", "name"), None) } checkError( @@ -382,7 +382,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { 
checkCollection(Seq("M2", "value", "b"), Some(Seq("m2", "value") -> StructField("b", IntegerType))) checkCollection(Seq("M2", "value", "non_exist"), None) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("m2", "key", "A", "name"), None) } checkError( @@ -391,7 +391,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { parameters = Map( "fieldName" -> "`m2`.`key`.`A`.`name`", "path" -> "`m2`.`key`.`a`")) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("M2", "value", "b", "name"), None) } checkError( @@ -401,7 +401,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { "fieldName" -> "`M2`.`value`.`b`.`name`", "path" -> "`m2`.`value`.`b`")) // simple array type - e = intercept[AnalysisException] { + e = intercept[SparkException] { check(Seq("A1", "element"), None) } checkError( @@ -411,7 +411,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { "fieldName" -> "`A1`.`element`", "path" -> "`a1`")) checkCollection(Seq("A1", "element"), Some(Seq("a1") -> StructField("element", IntegerType))) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("A1", "element", "name"), None) } checkError( @@ -425,7 +425,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { checkCollection(Seq("A2", "element", "C"), Some(Seq("a2", "element") -> StructField("c", IntegerType))) checkCollection(Seq("A2", "element", "non_exist"), None) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("a2", "element", "C", "name"), None) } checkError( @@ -439,7 +439,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { checkCollection(Seq("M3", "value", "value", "MA"), Some(Seq("m3", "value", "value") -> StructField("ma", IntegerType))) checkCollection(Seq("M3", "value", "value", "non_exist"), None) - e = intercept[AnalysisException] { + e = intercept[SparkException] { 
checkCollection(Seq("M3", "value", "value", "MA", "name"), None) } checkError( @@ -453,7 +453,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { checkCollection(Seq("A3", "element", "element", "D"), Some(Seq("a3", "element", "element") -> StructField("d", IntegerType))) checkCollection(Seq("A3", "element", "element", "non_exist"), None) - e = intercept[AnalysisException] { + e = intercept[SparkException] { checkCollection(Seq("A3", "element", "element", "D", "name"), None) } checkError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala index 6f0b0bb2231a3..dc918e51d0550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.types.StructType @@ -42,7 +43,7 @@ case class CollectMetricsExec( } val metricsSchema: StructType = { - StructType.fromAttributes(metricExpressions.map(_.toAttribute)) + DataTypeUtils.fromAttributes(metricExpressions.map(_.toAttribute)) } // This is not used very frequently (once a query); it is not useful to use code generation here. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0eca4a1b04a15..5c93c72e36d62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -41,7 +42,6 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType /** * Converts a logical plan into zero or more SparkPlans. 
This API is exposed for experimenting @@ -776,7 +776,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => - val encoder = RowEncoder(StructType.fromAttributes(output)) + val encoder = RowEncoder(DataTypeUtils.fromAttributes(output)) val toRow = encoder.createSerializer() LocalTableScanExec(output, sink.allData.map(r => toRow(r).copy())) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index c81ef40398090..b942907b6752f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.truncatedString @@ -38,7 +39,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils @@ -133,11 +134,11 @@ case class HashAggregateExec( } private val groupingAttributes = 
groupingExpressions.map(_.toAttribute) - private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val groupingKeySchema = DataTypeUtils.fromAttributes(groupingAttributes) private val declFunctions = aggregateExpressions.map(_.aggregateFunction) .filter(_.isInstanceOf[DeclarativeAggregate]) .map(_.asInstanceOf[DeclarativeAggregate]) - private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + private val bufferSchema = DataTypeUtils.fromAttributes(aggregateBufferAttributes) // The name for Fast HashMap private var fastHashMapTerm: String = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 1d89e56eebde0..fac3f7d6d8ae2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf @@ -189,8 +190,8 @@ class ObjectAggregationIterator( .sortedIterator() sortBasedAggregationStore = new SortBasedAggregator( sortIteratorFromHashMap, - StructType.fromAttributes(originalInputAttributes), - StructType.fromAttributes(groupingAttributes), + DataTypeUtils.fromAttributes(originalInputAttributes), + DataTypeUtils.fromAttributes(groupingAttributes), processRow, mergeAggregationBuffers, createNewAggregationBuffer()) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala index 6aede04b069ef..9b68e6f02a859 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala @@ -24,8 +24,8 @@ import org.apache.spark.internal.config import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.UnsafeKVExternalSorter -import org.apache.spark.sql.types.StructType /** * An aggregation map that supports using safe `SpecificInternalRow`s aggregation buffers, so that @@ -73,8 +73,8 @@ class ObjectAggregationMap() { aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = { val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) val sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(groupingAttributes), - StructType.fromAttributes(aggBufferAttributes), + DataTypeUtils.fromAttributes(groupingAttributes), + DataTypeUtils.fromAttributes(aggBufferAttributes), SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 36405fe927273..7de2215037b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator /** @@ -140,8 +140,8 @@ class TungstenAggregationIterator( // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection val groupingAttributes = groupingExpressions.map(_.toAttribute) val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) + val groupingKeySchema = DataTypeUtils.fromAttributes(groupingAttributes) + val bufferSchema = DataTypeUtils.fromAttributes(bufferAttributes) val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => { @@ -165,8 +165,8 @@ class TungstenAggregationIterator( // all groups and their corresponding aggregation buffers for hash-based aggregation. 
private[this] val hashMap = new UnsafeFixedWidthAggregationMap( initialAggregationBuffer, - StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), - StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), + DataTypeUtils.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), + DataTypeUtils.fromAttributes(groupingExpressions.map(_.toAttribute)), TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index b48493e4ff027..45d006b58e879 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} import org.apache.spark.sql.execution.{ColumnarToRowTransition, InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec} @@ -144,7 +145,7 @@ class DefaultCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer { selectedAttributes: Seq[Attribute], conf: SQLConf): RDD[ColumnarBatch] = { val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled - val outputSchema = StructType.fromAttributes(selectedAttributes) + val outputSchema = DataTypeUtils.fromAttributes(selectedAttributes) val columnIndices = 
selectedAttributes.map(a => cacheAttributes.map(o => o.exprId).indexOf(a.exprId)).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 454cc0b5f56e7..3ddb897f7082c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -740,7 +741,8 @@ object DataSourceStrategy output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.needConversion) { - val toRow = RowEncoder(StructType.fromAttributes(output), lenient = true).createSerializer() + val toRow = + RowEncoder(DataTypeUtils.fromAttributes(output), lenient = true).createSerializer() rdd.mapPartitions { iterator => iterator.map(toRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index b3d68003be33c..8011cc0074389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -34,6 +34,7 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.getPartitionValueString import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types._ @@ -364,7 +365,7 @@ object PartitioningUtils extends SQLConfHelper { } def getPathFragment(spec: TablePartitionSpec, partitionColumns: Seq[Attribute]): String = { - getPathFragment(spec, StructType.fromAttributes(partitionColumns)) + getPathFragment(spec, DataTypeUtils.fromAttributes(partitionColumns)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala index 31e4a772dc1a6..42ad83c9821bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2CommandExec.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema} import org.apache.spark.sql.catalyst.trees.LeafLike +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.StructType /** * A physical operator that executes run() and saves the result to prevent multiple executions. 
@@ -65,7 +65,7 @@ abstract class V2CommandExec extends SparkPlan { } private lazy val rowSerializer = { - RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer() + RowEncoder(DataTypeUtils.fromAttributes(output)).resolveAndBind().createSerializer() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index de43d67e621c5..7e534b92dd92e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.PredicateHelper import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.WriteDeltaProjections import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -96,7 +97,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics) case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) => - val rowSchema = StructType.fromAttributes(rd.dataInput) + val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput) val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema) val write = writeBuilder.build() val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 8d967458ad7ce..91f2099ce2d53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} import org.apache.spark.util.random.XORShiftRandom @@ -364,7 +364,7 @@ object ShuffleExchangeExec { val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = UnsafeExternalRowSorter.createWithRecordComparator( - StructType.fromAttributes(outputAttributes), + DataTypeUtils.fromAttributes(outputAttributes), recordComparatorSupplier, prefixComparator, prefixComputer, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c8d575016fc20..0ae699240ca08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns import 
org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.python.BatchIterator import org.apache.spark.sql.execution.r.ArrowRRunner import org.apache.spark.sql.execution.streaming.GroupStateImpl @@ -593,7 +594,7 @@ case class FlatMapGroupsInRWithArrowExec( // binary in a batch due to the limitation of R API. See also ARROW-4512. val columnarBatchIter = runner.compute(groupedByRKey, -1) val outputProject = UnsafeProjection.create(output, output) - val outputTypes = StructType.fromAttributes(output).map(_.dataType) + val outputTypes = DataTypeUtils.fromAttributes(output).map(_.dataType) columnarBatchIter.flatMap { batch => val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index 4da9fb8a60bbe..9ef133c6bea74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan} import org.apache.spark.sql.execution.python.PandasGroupUtils._ -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -95,8 +95,8 @@ case class FlatMapCoGroupsInPandasExec( chainedFunc, 
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, Array(leftArgOffsets ++ rightArgOffsets), - StructType.fromAttributes(leftDedup), - StructType.fromAttributes(rightDedup), + DataTypeUtils.fromAttributes(leftDedup), + DataTypeUtils.fromAttributes(rightDedup), sessionLocalTimeZone, pythonRunnerConf, pythonMetrics, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 77385e9a7d316..0ae5a99894337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -23,9 +23,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.python.PandasGroupUtils._ -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -90,7 +90,7 @@ case class FlatMapGroupsInPandasExec( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, Array(argOffsets), - StructType.fromAttributes(dedupAttributes), + DataTypeUtils.fromAttributes(dedupAttributes), sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 8366c0c25ae4f..d80320404b0f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets import org.apache.spark.sql.execution.streaming._ @@ -170,7 +171,7 @@ case class FlatMapGroupsInPandasWithStateExec( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, Array(argOffsets), - StructType.fromAttributes(dedupAttributesWithNull), + DataTypeUtils.fromAttributes(dedupAttributesWithNull), sessionLocalTimeZone, pythonRunnerConf, stateEncoder.asInstanceOf[ExpressionEncoder[Row]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala index 36138f11b76cf..97feb9b579af9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.StructType /** @@ -160,8 +161,8 @@ class 
StreamingAggregationStateManagerImplV2( GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) @transient private lazy val joiner = - GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), - StructType.fromAttributes(valueExpressions)) + GenerateUnsafeRowJoiner.create(DataTypeUtils.fromAttributes(keyExpressions), + DataTypeUtils.fromAttributes(valueExpressions)) @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( inputRowAttributes, keyValueJoinedExpressions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9f210ae437108..39db165148cdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1721,7 +1721,7 @@ class DataFrameSuite extends QueryTest def checkSyntaxError(name: String): Unit = { checkError( - exception = intercept[org.apache.spark.sql.AnalysisException] { + exception = intercept[SparkException] { df(name) }, errorClass = "_LEGACY_ERROR_TEMP_1049", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index a18c767570f01..248783dd6c6f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -838,7 +838,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase { assert(e2.getMessage.contains("Missing field point.non_exist")) // `AlterTable.resolved` checks column existence. 
- intercept[AnalysisException]( + intercept[SparkException]( sql(s"ALTER TABLE $t ALTER COLUMN a.y AFTER x")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 7f938deaaa645..04f9da5312c38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.errors -import org.apache.spark.SPARK_DOC_ROOT +import org.apache.spark.{SPARK_DOC_ROOT, SparkException} import org.apache.spark.sql.{AnalysisException, ClassData, IntegratedUDFTestUtils, QueryTest, Row} import org.apache.spark.sql.api.java.{UDF1, UDF2, UDF23Test} import org.apache.spark.sql.catalyst.parser.ParseException @@ -532,7 +532,7 @@ class QueryCompilationErrorsSuite val query = "ALTER TABLE t CHANGE COLUMN c.X COMMENT 'new comment'" checkError( - exception = intercept[AnalysisException] { + exception = intercept[SparkException] { sql(query) }, errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala index 63964665fc81c..b065c9a27a459 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoinSuite.scala @@ -21,13 +21,14 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.types.DataTypeUtils import 
org.apache.spark.sql.execution.{BinaryExecNode, FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, PartitionSpec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.IntegerType class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { private val SORT_MERGE_JOIN = "sortMergeJoin" @@ -68,7 +69,7 @@ class CoalesceBucketsInJoinSuite extends SQLTestUtils with SharedSparkSession { val relation = HadoopFsRelation( location = new InMemoryFileIndex(spark, Nil, Map.empty, None), partitionSchema = PartitionSpec.emptySpec.partitionColumns, - dataSchema = StructType.fromAttributes(setting.cols), + dataSchema = DataTypeUtils.fromAttributes(setting.cols), bucketSpec = Some(BucketSpec(setting.numBuckets, setting.cols.map(_.name), Nil)), fileFormat = new ParquetFileFormat(), options = Map.empty)(spark) From 39c732f89cdbd89c3e8bc39cf233eba338ad705a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 20 Jul 2023 17:44:14 +0800 Subject: [PATCH 030/986] [SPARK-43839][SQL][FOLLOWUP] Convert _LEGACY_ERROR_TEMP_1337 to UNSUPPORTED_FEATURE.TIME_TRAVEL ### What changes were proposed in this pull request? - The pr is following up https://github.com/apache/spark/pull/41349. - The pr aims to simplify code logic after merge `_LEGACY_ERROR_TEMP_1337` to `UNSUPPORTED_FEATURE.TIME_TRAVEL`. ### Why are the changes needed? Simplify code logic. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. Closes #42082 from panbingkun/SPARK-43839_FOLLOWUP. 
Authored-by: panbingkun Signed-off-by: Wenchen Fan (cherry picked from commit 6700f3ce8b1020186a2f0871caecb74354650922) Signed-off-by: Wenchen Fan --- .../datasources/v2/V2SessionCatalog.scala | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index f311ccbb6309d..a7062a9a596c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -88,19 +88,11 @@ class V2SessionCatalog(catalog: SessionCatalog) } private def failTimeTravel(ident: Identifier, t: Table): Table = { - t match { - case V1Table(catalogTable) => - if (catalogTable.tableType == CatalogTableType.VIEW) { - throw QueryCompilationErrors.timeTravelUnsupportedError( - toSQLId(catalogTable.identifier.nameParts)) - } else { - throw QueryCompilationErrors.timeTravelUnsupportedError( - toSQLId(catalogTable.identifier.nameParts)) - } - - case _ => throw QueryCompilationErrors.timeTravelUnsupportedError( - toSQLId(ident.asTableIdentifier.nameParts)) + val nameParts = t match { + case V1Table(catalogTable) => catalogTable.identifier.nameParts + case _ => ident.asTableIdentifier.nameParts } + throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(nameParts)) } override def invalidateTable(ident: Identifier): Unit = { From 337248d201250af7b244406abf2888b0a7e08ea2 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 20 Jul 2023 03:22:34 -0700 Subject: [PATCH 031/986] [SPARK-44494][INFRA] Use `minikube` v1.30.1 for `k8s-integration-tests` ### What changes were proposed in this pull request? This pr change to use `minikube` v1.30.1 for `k8s-integration-tests` on GitHub Action, this is a temporary solution. 
This PR also leaves a TODO: - SPARK-44495: Resume to use the latest minikube for `k8s-integration-tests` on GitHub Action ### Why are the changes needed? Restore `k8s-integration-tests` GA testing ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - `k8s-integration-tests` test pass on GitHub Action Closes #42091 from LuciferYang/SPARK-44494. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun (cherry picked from commit 57b985c5d57eea614024a4e125eba63663978de3) Signed-off-by: Dongjoon Hyun --- .github/workflows/build_and_test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index f5a109d95de5f..54fe9f38dddf2 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -998,7 +998,9 @@ jobs: - name: start minikube run: | # See more in "Installation" https://minikube.sigs.k8s.io/docs/start/ - curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64 + # curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64 + # TODO(SPARK-44495): Resume to use the latest minikube for k8s-integration-tests. + curl -LO https://storage.googleapis.com/minikube/releases/v1.30.1/minikube-linux-amd64 sudo install minikube-linux-amd64 /usr/local/bin/minikube # Github Action limit cpu:2, memory: 6947MB, limit to 2U6G for better resource statistic minikube start --cpus 2 --memory 6144 From c13b0ffc4377a0d7ee9446deef3ed032532d3be3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 20 Jul 2023 13:22:59 -0700 Subject: [PATCH 032/986] [SPARK-44501][K8S] Ignore checksum files in KubernetesLocalDiskShuffleExecutorComponents ### What changes were proposed in this pull request? This PR aims to improve `KubernetesLocalDiskShuffleExecutorComponents` by ignoring checksum files. ### Why are the changes needed? 
To reduce the overhead of `BlockManager.TempFileBasedBlockStoreUpdater` API call. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. Closes #42094 from dongjoon-hyun/SPARK-44501. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun (cherry picked from commit ca822d5f25d5e0bfa4594aea76a26d53bd01b109) Signed-off-by: Dongjoon Hyun --- ...esLocalDiskShuffleExecutorComponents.scala | 11 +++++++++-- ...ubernetesLocalDiskShuffleDataIOSuite.scala | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleExecutorComponents.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleExecutorComponents.scala index 3d6379b871388..8f0729067b9c3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleExecutorComponents.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleExecutorComponents.scala @@ -25,6 +25,7 @@ import scala.reflect.ClassTag import org.apache.commons.io.FileExistsException import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.deploy.k8s.Config.KUBERNETES_DRIVER_REUSE_PVC import org.apache.spark.internal.Logging import org.apache.spark.shuffle.api.{ShuffleExecutorComponents, ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents @@ -41,12 +42,15 @@ class KubernetesLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) appId: String, execId: String, extraConfigs: java.util.Map[String, String]): Unit = { delegate.initializeExecutor(appId, execId, extraConfigs) blockManager = SparkEnv.get.blockManager - if (sparkConf.getBoolean("spark.kubernetes.driver.reusePersistentVolumeClaim", false)) { + if 
(sparkConf.getBoolean(KUBERNETES_DRIVER_REUSE_PVC.key, false)) { + logInfo("Try to recover shuffle data.") // Turn off the deletion of the shuffle data in order to reuse blockManager.diskBlockManager.deleteFilesOnStop = false Utils.tryLogNonFatalError { KubernetesLocalDiskShuffleExecutorComponents.recoverDiskStore(sparkConf, blockManager) } + } else { + logInfo(s"Skip recovery because ${KUBERNETES_DRIVER_REUSE_PVC.key} is disabled.") } } @@ -80,7 +84,7 @@ object KubernetesLocalDiskShuffleExecutorComponents extends Logging { .flatMap(_.listFiles).filter(_.isDirectory) // executor-xxx .flatMap(_.listFiles).filter(_.isDirectory) // blockmgr-xxx .flatMap(_.listFiles).filter(_.isDirectory) // 00 - .flatMap(_.listFiles) + .flatMap(_.listFiles).filterNot(_.getName.contains(".checksum")) if (files != null) files.toSeq else Seq.empty } @@ -91,14 +95,17 @@ object KubernetesLocalDiskShuffleExecutorComponents extends Logging { val level = StorageLevel.DISK_ONLY val (indexFiles, dataFiles) = files.partition(_.getName.endsWith(".index")) (dataFiles ++ indexFiles).foreach { f => + logInfo(s"Try to recover ${f.getAbsolutePath}") try { val id = BlockId(f.getName) val decryptedSize = f.length() bm.TempFileBasedBlockStoreUpdater(id, level, classTag, f, decryptedSize).save() } catch { case _: UnrecognizedBlockId => + logInfo("Skip due to UnrecognizedBlockId.") case _: FileExistsException => // This may happen due to recompute, but we continue to recover next files + logInfo("Ignore due to FileExistsException.") } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleDataIOSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleDataIOSuite.scala index f3d45ced1bb65..d105ac0417182 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleDataIOSuite.scala +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/shuffle/KubernetesLocalDiskShuffleDataIOSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.shuffle +import java.io.File +import java.nio.file.Files + import scala.concurrent.duration._ +import org.mockito.Mockito.{mock, when} import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} import org.apache.spark.{LocalRootDirsTest, MapOutputTrackerMaster, SparkContext, SparkFunSuite, TestUtils} @@ -26,6 +30,7 @@ import org.apache.spark.LocalSparkContext.withSpark import org.apache.spark.deploy.k8s.Config.KUBERNETES_DRIVER_REUSE_PVC import org.apache.spark.internal.config._ import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend +import org.apache.spark.storage.BlockManager class KubernetesLocalDiskShuffleDataIOSuite extends SparkFunSuite with LocalRootDirsTest { @@ -219,4 +224,18 @@ class KubernetesLocalDiskShuffleDataIOSuite extends SparkFunSuite with LocalRoot } } } + + test("SPARK-44501: Ignore checksum files") { + val sparkConf = conf.clone.set("spark.local.dir", + conf.get("spark.local.dir") + "/spark-x/executor-y") + val dir = sparkConf.get("spark.local.dir") + "/blockmgr-z/00" + Files.createDirectories(new File(dir).toPath()) + Seq("ADLER32", "CRC32").foreach { algorithm => + new File(dir, s"1.checksum.$algorithm").createNewFile() + } + + val bm = mock(classOf[BlockManager]) + when(bm.TempFileBasedBlockStoreUpdater).thenAnswer(_ => throw new Exception()) + KubernetesLocalDiskShuffleExecutorComponents.recoverDiskStore(sparkConf, bm) + } } From b2d9b62ea5a2da47e299e0dad4e6cf7e975cc661 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 21 Jul 2023 08:23:45 +0800 Subject: [PATCH 033/986] [SPARK-44292][SQL][FOLLOWUP] Make TYPE_CHECK_FAILURE_WITH_HINT use correct name ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/41850 uses `TYPE_CHECK_FAILURE_WITH_HINT`, it should be `DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT`. 
The first commit comes from https://github.com/apache/spark/pull/34747. ### Why are the changes needed? Fix a bug. ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? N/A Closes #42084 from beliefer/SPARK-44292_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan (cherry picked from commit 325888bc521e80abde2ed88d2c82ba2337e8cc6f) Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7085a040d66c9..af9ea814a51da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -279,7 +279,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB e.setTagValue(DATA_TYPE_MISMATCH_ERROR, true) val extraHint = extraHintForAnsiTypeCoercionExpression(operator) e.failAnalysis( - errorClass = "TYPE_CHECK_FAILURE_WITH_HINT", + errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", messageParameters = Map( "expr" -> toSQLExpr(e), "msg" -> message, From 4c9e3eb5c640f861d385d1e5079d338cbd2d1c14 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 21 Jul 2023 09:44:45 +0900 Subject: [PATCH 034/986] [SPARK-44477][SQL] Treat TYPE_CHECK_FAILURE_WITH_HINT as an error subclass ### What changes were proposed in this pull request? In `CheckAnalysis#checkAnalysis0`, qualify the error subclass `TYPE_CHECK_FAILURE_WITH_HINT` with the error class `DATATYPE_MISMATCH`. ### Why are the changes needed? `CheckAnalysis` treats `TYPE_CHECK_FAILURE_WITH_HINT` as an error class, but it is actually an error subclass of `DATATYPE_MISMATCH`.
``` spark-sql (default)> select bitmap_count(12); [INTERNAL_ERROR] Cannot find main error class 'TYPE_CHECK_FAILURE_WITH_HINT' org.apache.spark.SparkException: [INTERNAL_ERROR] Cannot find main error class 'TYPE_CHECK_FAILURE_WITH_HINT' at org.apache.spark.SparkException$.internalError(SparkException.scala:83) at org.apache.spark.SparkException$.internalError(SparkException.scala:87) at org.apache.spark.ErrorClassesJsonReader.$anonfun$getMessageTemplate$1(ErrorClassesJSONReader.scala:68) at scala.collection.immutable.HashMap$HashMap1.getOrElse0(HashMap.scala:361) at scala.collection.immutable.HashMap$HashTrieMap.getOrElse0(HashMap.scala:594) at scala.collection.immutable.HashMap$HashTrieMap.getOrElse0(HashMap.scala:589) at scala.collection.immutable.HashMap.getOrElse(HashMap.scala:73) ``` This issue only occurs when an expression uses `TypeCheckResult.TypeCheckFailure` to indicate input type check failure. `TypeCheckResult.TypeCheckFailure` appears to be deprecated in favor of `TypeCheckResult.DataTypeMismatch`, but recently two expressions were added that use `TypeCheckResult.TypeCheckFailure`: `BitmapCount` and `BitmapOrAgg`. `BitmapCount` and `BitmapOrAgg` should probably be fixed to use `TypeCheckResult.DataTypeMismatch`. Regardless, the code in `CheckAnalysis` that handles `TypeCheckResult.TypeCheckFailure` should either be fixed or removed. In this PR, I chose to fix it. ### Does this PR introduce _any_ user-facing change? No, except for the user seeing the correct error message. ### How was this patch tested? New unit test. Closes #42064 from bersprockets/type_check_issue. 
Authored-by: Bruce Robbins Signed-off-by: Hyukjin Kwon (cherry picked from commit 9619ada35059faf601ebefd5c225ae1ebf86f5ef) Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../analysis/AnalysisErrorSuite.scala | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index af9ea814a51da..d933ea26d5d99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -281,7 +281,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB e.failAnalysis( errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", messageParameters = Map( - "expr" -> toSQLExpr(e), + "sqlExpr" -> toSQLExpr(e), "msg" -> message, "hint" -> extraHint)) case checkRes: TypeCheckResult.InvalidFormat => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 6c43f84e8d0b7..e2e980073307d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -94,6 +94,28 @@ case class TestFunction( copy(children = newChildren) } +case class TestFunctionWithTypeCheckFailure( + children: Seq[Expression], + inputTypes: Seq[AbstractDataType]) + extends Expression with Unevaluable { + + override def checkInputDataTypes(): TypeCheckResult = { + for ((child, idx) <- children.zipWithIndex) { + val expectedDataType = inputTypes(idx) + if (child.dataType != expectedDataType) { + return TypeCheckResult.TypeCheckFailure( + 
s"Expression must be a ${expectedDataType.simpleString}") + } + } + TypeCheckResult.TypeCheckSuccess + } + + override def nullable: Boolean = true + override def dataType: DataType = StringType + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + copy(children = newChildren) +} + case class UnresolvedTestPlan() extends UnresolvedLeafNode class AnalysisErrorSuite extends AnalysisTest { @@ -168,6 +190,16 @@ class AnalysisErrorSuite extends AnalysisTest { "inputType" -> "\"DATE\"", "requiredType" -> "\"INT\"")) + errorClassTest( + "SPARK-44477: type check failure", + testRelation.select( + TestFunctionWithTypeCheckFailure(dateLit :: Nil, BinaryType :: Nil).as("a")), + errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + messageParameters = Map( + "sqlExpr" -> "\"testfunctionwithtypecheckfailure(NULL)\"", + "msg" -> "Expression must be a binary", + "hint" -> "")) + errorClassTest( "invalid window function", testRelation2.select( From dae7a576ec7a103f2486fa4121731de0b1347a52 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 21 Jul 2023 10:41:24 +0800 Subject: [PATCH 035/986] [SPARK-43952][SQL][FOLLOWUP] Correct AQE cancel broadcast job tag ### What changes were proposed in this pull request? This pr changes `cancelJobGroup` to `cancelJobsWithTag ` in AQE, so that broadcast exchange can be cancelled correctly. Since we do not set job id when executing broadcast job and use job tag to cancel it, this pr adds `jobTag` to `BroadcastExchangeLike`. ### Why are the changes needed? fix regression ### Does this PR introduce _any_ user-facing change? no, not released yet ### How was this patch tested? test manully ```sql select * from t1 join (select c1, java_method('java.lang.Thread', 'sleep', 5000l) from t2)t2 on t1.c1 = t2.c1 join (select c1, raise_error('force_fail') from t3)t3 on t1.c1 = t3.c1 ``` before: image after: image Closes #41979 from ulysses-you/jobtag-followup. 
Authored-by: ulysses-you Signed-off-by: Xiduo You (cherry picked from commit 99f9df564ef3f3223b4789d111426d5be5854c4a) Signed-off-by: Xiduo You --- .../sql/execution/adaptive/QueryStageExec.scala | 2 +- .../execution/exchange/BroadcastExchangeExec.scala | 14 +++++++------- .../spark/sql/SparkSessionExtensionSuite.scala | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index d48b4fe175174..c6234a4072604 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -250,7 +250,7 @@ case class BroadcastQueryStageExec( override def cancel(): Unit = { if (!broadcast.relationFuture.isDone) { - sparkContext.cancelJobGroup(broadcast.runId.toString) + sparkContext.cancelJobsWithTag(broadcast.jobTag) broadcast.relationFuture.cancel(true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 15141b09b6c07..866a62a3a0776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -45,9 +45,14 @@ import org.apache.spark.util.{SparkFatalException, ThreadUtils} trait BroadcastExchangeLike extends Exchange { /** - * The broadcast job group ID + * The broadcast run ID in job tag */ - def runId: UUID = UUID.randomUUID + val runId: UUID = UUID.randomUUID + + /** + * The broadcast job tag + */ + def jobTag: String = s"broadcast exchange (runId ${runId.toString})" /** * The asynchronous job that prepares the broadcast relation. 
@@ -80,8 +85,6 @@ case class BroadcastExchangeExec( child: SparkPlan) extends BroadcastExchangeLike { import BroadcastExchangeExec._ - override val runId: UUID = UUID.randomUUID - override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -124,9 +127,6 @@ case class BroadcastExchangeExec( case _ => 512000000 } - @transient - private lazy val jobTag = s"broadcast exchange (runId ${runId.toString})" - @transient override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 87237f467fcba..043a3b1a7e58f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -1003,7 +1003,7 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE * whether AQE is enabled. */ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike { - override def runId: UUID = delegate.runId + override val runId: UUID = delegate.runId override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] = delegate.relationFuture override def completionFuture: Future[Broadcast[Any]] = delegate.submitBroadcastJob From 57dcc373a1bc02d63db2dc87f99bee45569735f7 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 21 Jul 2023 12:09:59 +0900 Subject: [PATCH 036/986] [SPARK-44502][DOC][SS][PYTHON] Add missing versionchanged field to streaming functions Add missing doc field Better doc No Don't need Closes #42097 from WweiL/SPARK-44502-doc. 
Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon (cherry picked from commit f1e310b84d5f1d2bde1bbafbb737a9cc612d573a) Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/dataframe.py | 3 + python/pyspark/sql/streaming/query.py | 51 +++++++++++++++++ python/pyspark/sql/streaming/readwriter.py | 66 ++++++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 99eb2a48bb21c..8e655dc3a88e5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1052,6 +1052,9 @@ def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame": .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- eventTime : str diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 6d18c93f0191d..443e7dbee39b7 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -58,6 +58,9 @@ def id(self) -> str: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- str @@ -85,6 +88,9 @@ def runId(self) -> str: .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- str @@ -114,6 +120,9 @@ def name(self) -> str: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- str @@ -140,6 +149,9 @@ def isActive(self) -> bool: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- bool @@ -171,6 +183,9 @@ def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- timeout : int, optional @@ -212,6 +227,9 @@ def status(self) -> Dict[str, Any]: .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. 
+ Returns ------- dict @@ -240,6 +258,9 @@ def recentProgress(self) -> List[Dict[str, Any]]: .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- list @@ -268,6 +289,9 @@ def lastProgress(self) -> Optional[Dict[str, Any]]: .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- dict, optional @@ -297,6 +321,9 @@ def processAllAvailable(self) -> None: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- In the case of continually arriving data, this method may block forever. @@ -325,6 +352,9 @@ def stop(self) -> None: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Examples -------- >>> sdf = spark.readStream.format("rate").load() @@ -347,6 +377,9 @@ def explain(self, extended: bool = False) -> None: .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- extended : bool, optional @@ -387,6 +420,9 @@ def exception(self) -> Optional[StreamingQueryException]: """ .. versionadded:: 2.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`StreamingQueryException` @@ -406,6 +442,9 @@ class StreamingQueryManager: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -421,6 +460,9 @@ def active(self) -> List[StreamingQuery]: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- list @@ -451,6 +493,9 @@ def get(self, id: str) -> Optional[StreamingQuery]: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- id : str @@ -513,6 +558,9 @@ def awaitAnyTermination(self, timeout: Optional[int] = None) -> Optional[bool]: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- timeout : int, optional @@ -554,6 +602,9 @@ def resetTerminated(self) -> None: .. 
versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Examples -------- >>> spark.streams.resetTerminated() diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index 3547dd060f63d..08d7396ba864a 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -87,6 +87,9 @@ def format(self, source: str) -> "DataStreamReader": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- source : str @@ -128,6 +131,9 @@ def schema(self, schema: Union[StructType, str]) -> "DataStreamReader": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- schema : :class:`pyspark.sql.types.StructType` or str @@ -177,6 +183,9 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader" .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -203,6 +212,9 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -238,6 +250,9 @@ def load( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- path : str, optional @@ -324,6 +339,9 @@ def json( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- path : str @@ -407,6 +425,9 @@ def orc( .. versionadded:: 2.3.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Other Parameters ---------------- Extra options @@ -458,6 +479,9 @@ def parquet( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- path : str @@ -521,6 +545,9 @@ def text( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. 
+ Parameters ---------- path : str or list @@ -619,6 +646,9 @@ def csv( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Other Parameters ---------------- Extra options @@ -695,6 +725,9 @@ def table(self, tableName: str) -> "DataFrame": .. versionadded:: 3.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- tableName : str @@ -745,6 +778,9 @@ class DataStreamWriter: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -776,6 +812,9 @@ def outputMode(self, outputMode: str) -> "DataStreamWriter": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Options include: * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to @@ -818,6 +857,9 @@ def format(self, source: str) -> "DataStreamWriter": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- source : str @@ -857,6 +899,9 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamWriter" .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -885,6 +930,9 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamWriter": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -925,6 +973,9 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc] .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- cols : str or list @@ -968,6 +1019,9 @@ def queryName(self, queryName: str) -> "DataStreamWriter": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- queryName : str @@ -1023,6 +1077,9 @@ def trigger( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. 
+ Parameters ---------- processingTime : str, optional @@ -1280,6 +1337,9 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt .. versionadded:: 2.4.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -1388,6 +1448,9 @@ def start( .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- path : str, optional @@ -1473,6 +1536,9 @@ def toTable( .. versionadded:: 3.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- tableName : str From 40f30dd036c7df949ce11c59a009bd8ebafe1f0d Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Fri, 21 Jul 2023 12:21:14 +0800 Subject: [PATCH 037/986] [SPARK-43966][SQL][PYTHON] Support non-deterministic table-valued functions ### What changes were proposed in this pull request? This PR supports non-deterministic table-valued functions. More specifically, it supports running non-deterministic Python UDTFs and built-in table-valued generator functions with non-deterministic input values. ### Why are the changes needed? To make table-valued functions more versatile. ### Does this PR introduce _any_ user-facing change? Yes. Before this PR, Spark will throw an exception when running a non-deterministic Python UDTF: ``` select * from random_udtf(1) AnalysisException: [INVALID_NON_DETERMINISTIC_EXPRESSIONS] The operator expects a deterministic expression, ``` After this PR, it is supported. ### How was this patch tested? Existing and new unit tests. Closes #42075 from allisonwang-db/spark-43966-non-det-udtf. 
Authored-by: allisonwang-db Signed-off-by: Wenchen Fan (cherry picked from commit 1fb3e16a48d826aed1ca9688a661281f750bbf5a) Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests/test_udtf.py | 18 ++++++------------ .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../table-valued-functions.sql.out | 6 ++++++ .../inputs/table-valued-functions.sql | 3 +++ .../results/table-valued-functions.sql.out | 8 ++++++++ 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index ec3379accca60..2c76d2f7e152e 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -355,14 +355,12 @@ def test_nondeterministic_udtf(self): class RandomUDTF: def eval(self, a: int): - yield a * int(random.random() * 100), + yield a + int(random.random()), random_udtf = udtf(RandomUDTF, returnType="x: int").asNondeterministic() - # TODO(SPARK-43966): support non-deterministic UDTFs - with self.assertRaisesRegex( - AnalysisException, "The operator expects a deterministic expression" - ): - random_udtf(lit(1)).collect() + assertDataFrameEqual(random_udtf(lit(1)), [Row(x=1)]) + self.spark.udtf.register("random_udtf", random_udtf) + assertDataFrameEqual(self.spark.sql("select * from random_udtf(1)"), [Row(x=1)]) def test_udtf_with_nondeterministic_input(self): from pyspark.sql.functions import rand @@ -370,13 +368,9 @@ def test_udtf_with_nondeterministic_input(self): @udtf(returnType="x: int") class TestUDTF: def eval(self, a: int): - yield a + 1, + yield 1 if a > 100 else 0, - # TODO(SPARK-43966): support non-deterministic UDTFs - with self.assertRaisesRegex( - AnalysisException, " The operator expects a deterministic expression" - ): - TestUDTF(rand(0) * 100).collect() + assertDataFrameEqual(TestUDTF(rand(0) * 100), [Row(x=0)]) def test_udtf_with_invalid_return_type(self): @udtf(returnType="int") diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d933ea26d5d99..e198fd58953dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -752,6 +752,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] && !o.isInstanceOf[Expand] && + !o.isInstanceOf[Generate] && // Lateral join is checked in checkSubqueryExpression. !o.isInstanceOf[LateralJoin] => // The rule above is used to check Aggregate operator. diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out index 49ad4bf19f76b..6c29a0ec1db77 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/table-valued-functions.sql.out @@ -205,6 +205,12 @@ Project [k#x, v#x] +- OneRowRelation +-- !query +select * from explode(array(rand(0))) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + -- !query select * from explode(null) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 2b809f9a7c8a2..79d427bc2099d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -43,6 +43,9 @@ select * from explode(map()); select * from explode(array(1, 2)) t(c1); select * from explode(map('a', 1, 'b', 2)) t(k, v); +-- explode with 
non-deterministic values +select * from explode(array(rand(0))); + -- explode with erroneous input select * from explode(null); select * from explode(null) t(c1); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index 578461d164a1b..1348110a83a37 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -242,6 +242,14 @@ a 1 b 2 +-- !query +select * from explode(array(rand(0))) +-- !query schema +struct +-- !query output +0.7604953758285915 + + -- !query select * from explode(null) -- !query schema From 86237bba2625ad0cf5325c85e342e6230d7a0699 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Fri, 21 Jul 2023 13:46:01 +0900 Subject: [PATCH 038/986] [SPARK-44504][SS] Unload provider thereby forcing DB instance close and releasing resources on maintenance task error ### What changes were proposed in this pull request? Unload provider thereby forcing DB instance close and releasing resources on maintenance task error ### Why are the changes needed? If we don't do the close, the DB instance and corresponding resources (memory, file descriptors etc) are always left open and the pointer to these objects is lost since loadedProviders is cleared. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing unit tests ``` ), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-17 (daemon=true), shuffle-boss-6-1 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), ForkJoinPool.commonPool-worker-31 (daemon=true), ForkJoinPool.commonPool-worker-23 (daemon=true), state-store-maintenance-task (daemon=true), ForkJoinPool.commonPool-worker-9 (daemon=true) ===== [info] Run completed in 2 minutes, 49 seconds. 
[info] Total number of tests run: 32 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 32, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` Closes #42098 from anishshri-db/task/SPARK-44504. Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../apache/spark/sql/execution/streaming/state/RocksDB.scala | 4 ++++ .../spark/sql/execution/streaming/state/StateStore.scala | 3 +++ 2 files changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 65299ea37eff2..386df61a9e08f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -462,6 +462,8 @@ class RocksDB( /** Release all resources */ def close(): Unit = { try { + // Acquire DB instance lock and release at the end to allow for synchronized access + acquire() closeDB() readOptions.close() @@ -476,6 +478,8 @@ class RocksDB( } catch { case e: Exception => logWarning("Error closing RocksDB", e) + } finally { + release() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 96c7b61f2057f..8a09b226a0cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -611,6 +611,9 @@ object StateStore extends Logging { onError = { loadedProviders.synchronized { logInfo("Stopping maintenance task since an error was encountered.") stopMaintenanceTask() + // SPARK-44504 - Unload explicitly to force closing underlying DB instance + // and releasing allocated resources, especially for RocksDBStateStoreProvider. 
+ loadedProviders.keySet.foreach { key => unload(key) } loadedProviders.clear() } } From e8dd144abcab58870aa730517b6cea5121b5868e Mon Sep 17 00:00:00 2001 From: Venki Korukanti Date: Fri, 21 Jul 2023 13:16:23 +0800 Subject: [PATCH 039/986] [SPARK-39634][SQL] Allow file splitting in combination with row index generation ### What changes were proposed in this pull request? - Parquet version `1.13.1` has a fix for [PARQUET-2161](https://issues.apache.org/jira/browse/PARQUET-2161) which allows splitting the parquet files when row index metadata column is selected. Currently the file splitting is disabled. Enable file splitting with row index column. ### Why are the changes needed? Splitting parquet files allows better parallelization when row index metadata column is selected. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Uncomment the existing unittests. Closes #40728 from vkorukanti/SPARK-39634. Authored-by: Venki Korukanti Signed-off-by: Wenchen Fan (cherry picked from commit 679ea56dc8be7566c32a606639aa052421136afc) Signed-off-by: Wenchen Fan --- .../apache/spark/sql/execution/DataSourceScanExec.scala | 8 ++------ .../datasources/parquet/ParquetRowIndexUtil.scala | 4 ---- .../execution/datasources/v2/parquet/ParquetScan.scala | 7 ++----- .../datasources/parquet/ParquetRowIndexSuite.scala | 5 +---- 4 files changed, 5 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 243aaabc0cbdd..6375cdacaa07e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.catalyst.util.{truncatedString, CaseInsensitiveMap} import 
org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource, ParquetRowIndexUtil} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.ConstantColumnVector @@ -691,11 +691,7 @@ case class FileSourceScanExec( partition.files.flatMap { file => if (shouldProcess(file.getPath)) { val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, relation.options, file.getPath) && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. - (!relation.fileFormat.isInstanceOf[ParquetSource] - || !ParquetRowIndexUtil.isNeededForSchema(requiredSchema)) + relation.sparkSession, relation.options, file.getPath) PartitionedFileUtil.splitFiles( sparkSession = relation.sparkSession, file = file, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexUtil.scala index b1cd2ebee4265..a5d8494cfa77c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexUtil.scala @@ -126,10 +126,6 @@ object ParquetRowIndexUtil { } } - def isNeededForSchema(sparkSchema: StructType): Boolean = { - findRowIndexColumnIndexInSchema(sparkSchema) >= 0 - } - def isRowIndexColumn(column: ParquetColumn): Boolean = { column.path.length == 1 && column.path.last == ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index c7819d39cbcd5..0e77b419ff524 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} -import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetRowIndexUtil, ParquetWriteSupport} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter @@ -50,10 +50,7 @@ case class ParquetScan( override def isSplitable(path: Path): Boolean = { // If aggregate is pushed down, only the file footer will be read once, // so file should not be split across multiple tasks. - pushedAggregate.isEmpty && - // SPARK-39634: Allow file splitting in combination with row index generation once - // the fix for PARQUET-2161 is available. 
- !ParquetRowIndexUtil.isNeededForSchema(readSchema) + pushedAggregate.isEmpty } override def readSchema(): StructType = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala index dd350ffd31510..27c2a2148fd3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala @@ -249,10 +249,7 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { assert(numOutputRows > 0) if (conf.useSmallSplits) { - // SPARK-39634: Until the fix the fix for PARQUET-2161 is available is available, - // it is not possible to split Parquet files into multiple partitions while generating - // row indexes. - // assert(numPartitions >= 2 * conf.numFiles) + assert(numPartitions >= 2 * conf.numFiles) } // Assert that every rowIdx value matches the value in `expectedRowIdx`. From cc58fe3efe97417cc8c03e926895f48c2abb669d Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 21 Jul 2023 14:23:47 +0900 Subject: [PATCH 040/986] [SPARK-44422][CONNECT] Spark Connect fine grained interrupt ### What changes were proposed in this pull request? Currently, Spark Connect only allows to cancel all operations in a session by using SparkSession.interruptAll(). In this PR we are adding a mechanism to interrupt by tag (similar to SparkContext.cancelJobsWithTag), and to interrupt individual operations. Also, add the new tags to SparkListenerConnectOperationStarted. ### Why are the changes needed? Better control of query cancelation in Spark Connect ### Does this PR introduce _any_ user-facing change? Yes. 
New Apis in Spark Connect scala client: ``` SparkSession.addTag SparkSession.removeTag SparkSession.getTags SparkSession.clearTags SparkSession.interruptTag SparkSession.interruptOperation ``` and also `SparkResult.operationId`, to be able to get the id for `SparkSession.interruptOperation`. Python client APIs will be added in a followup PR. ### How was this patch tested? Added tests in SparkSessionE2ESuite. Closes #42009 from juliuszsompolski/sc-fine-grained-cancel. Authored-by: Juliusz Sompolski Signed-off-by: Hyukjin Kwon (cherry picked from commit dda37841189a753d7a31d22e091b51903f6cd624) Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 18 ++ .../org/apache/spark/sql/SparkSession.scala | 80 ++++++++- .../connect/client/SparkConnectClient.scala | 58 +++++++ .../sql/connect/client/SparkResult.scala | 27 +++ .../spark/sql/SparkSessionE2ESuite.scala | 154 +++++++++++++++++- .../CheckConnectJvmClientCompatibility.scala | 12 ++ .../main/protobuf/spark/connect/base.proto | 37 ++++- .../spark/sql/connect/common/ProtoUtils.scala | 24 +++ .../execution/ExecuteResponseObserver.scala | 25 ++- .../execution/ExecuteThreadRunner.scala | 25 ++- .../service/ExecuteEventsManager.scala | 8 +- .../sql/connect/service/ExecuteHolder.scala | 36 +++- .../sql/connect/service/SessionHolder.scala | 75 +++++++-- .../SparkConnectInterruptHandler.scala | 24 ++- .../service/ExecuteEventsManagerSuite.scala | 1 + .../scala/org/apache/spark/SparkContext.scala | 1 + ...r-conditions-invalid-handle-error-class.md | 36 ++++ docs/sql-error-conditions.md | 8 + python/pyspark/sql/connect/proto/base_pb2.py | 144 ++++++++-------- python/pyspark/sql/connect/proto/base_pb2.pyi | 93 ++++++++++- 20 files changed, 778 insertions(+), 108 deletions(-) create mode 100644 docs/sql-error-conditions-invalid-handle-error-class.md diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 
b136878e6d2c0..d61b17216641f 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -1383,6 +1383,24 @@ ], "sqlState" : "22023" }, + "INVALID_HANDLE" : { + "message" : [ + "The handle is invalid." + ], + "subClass" : { + "ALREADY_EXISTS" : { + "message" : [ + "Handle already exists." + ] + }, + "FORMAT" : { + "message" : [ + "Handle has invalid format. Handle must an UUID string of the format '00112233-4455-6677-8899-aabbccddeeff'" + ] + } + }, + "sqlState" : "HY000" + }, "INVALID_HIVE_COLUMN_NAME" : { "message" : [ "Cannot create the table having the nested column whose name contains invalid characters in Hive metastore." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index fb9959c994289..b37e3884038b2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -613,14 +613,40 @@ class SparkSession private[sql] ( /** * Interrupt all operations of this session currently running on the connected server. * - * TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from - * ExecutePlan requests. If an operation is not running a Spark Job, it becomes an noop and the - * operation will continue afterwards, possibly with more Spark Jobs. + * @return + * sequence of operationIds of interrupted operations. Note: there is still a possibility of + * operation finishing just as it is interrupted. * * @since 3.5.0 */ - def interruptAll(): Unit = { - client.interruptAll() + def interruptAll(): Seq[String] = { + client.interruptAll().getInterruptedIdsList.asScala.toSeq + } + + /** + * Interrupt all operations of this session with the given operation tag.
+ * + * @return + * sequence of operationIds of interrupted operations. Note: there is still a possibility of + * operation finishing just as it is interrupted. + * + * @since 3.5.0 + */ + def interruptTag(tag: String): Seq[String] = { + client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq + } + + /** + * Interrupt an operation of this session with the given operationId. + * + * @return + * sequence of operationIds of interrupted operations. Note: there is still a possibility of + * operation finishing just as it is interrupted. + * + * @since 3.5.0 + */ + def interruptOperation(operationId: String): Seq[String] = { + client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } /** @@ -641,6 +667,50 @@ class SparkSession private[sql] ( allocator.close() SparkSession.onSessionClose(this) } + + /** + * Add a tag to be assigned to all the operations started by this thread in this session. + * + * @param tag + * The tag to be added. Cannot contain ',' (comma) character or be an empty string. + * + * @since 3.5.0 + */ + def addTag(tag: String): Unit = { + client.addTag(tag) + } + + /** + * Remove a tag previously added to be assigned to all the operations started by this thread in + * this session. Noop if such a tag was not added earlier. + * + * @param tag + * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + * + * @since 3.5.0 + */ + def removeTag(tag: String): Unit = { + client.removeTag(tag) + } + + /** + * Get the tags that are currently set to be assigned to all the operations started by this + * thread. + * + * @since 3.5.0 + */ + def getTags(): Set[String] = { + client.getTags() + } + + /** + * Clear the current thread's operation tags. + * + * @since 3.5.0 + */ + def clearTags(): Unit = { + client.clearTags() + } } // The minimal builder needed to create a spark session.
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index b41ae5555bf3a..d03d27a6f53d5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -21,11 +21,15 @@ import java.net.URI import java.util.UUID import java.util.concurrent.Executor +import scala.collection.JavaConverters._ +import scala.collection.mutable + import com.google.protobuf.ByteString import io.grpc._ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.UserContext +import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.common.config.ConnectCommon /** @@ -76,6 +80,7 @@ private[sql] class SparkConnectClient( .setUserContext(userContext) .setSessionId(sessionId) .setClientType(userAgent) + .addAllTags(tags.get.toSeq.asJava) .build() bstub.executePlan(request) } @@ -195,6 +200,59 @@ private[sql] class SparkConnectClient( bstub.interrupt(request) } + private[sql] def interruptTag(tag: String): proto.InterruptResponse = { + val builder = proto.InterruptRequest.newBuilder() + val request = builder + .setUserContext(userContext) + .setSessionId(sessionId) + .setClientType(userAgent) + .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG) + .setOperationTag(tag) + .build() + bstub.interrupt(request) + } + + private[sql] def interruptOperation(id: String): proto.InterruptResponse = { + val builder = proto.InterruptRequest.newBuilder() + val request = builder + .setUserContext(userContext) + .setSessionId(sessionId) + .setClientType(userAgent) + .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID) + .setOperationId(id) + .build() + bstub.interrupt(request) + 
} + + private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] { + override def childValue(parent: mutable.Set[String]): mutable.Set[String] = { + // Note: make a clone such that changes in the parent tags aren't reflected in + // those of the children threads. + parent.clone() + } + override protected def initialValue(): mutable.Set[String] = new mutable.HashSet[String]() + } + + private[sql] def addTag(tag: String): Unit = { + // validation is also done server side, but this will give error earlier. + ProtoUtils.throwIfInvalidTag(tag) + tags.get += tag + } + + private[sql] def removeTag(tag: String): Unit = { + // validation is also done server side, but this will give error earlier. + ProtoUtils.throwIfInvalidTag(tag) + tags.get.remove(tag) + } + + private[sql] def getTags(): Set[String] = { + tags.get.toSet + } + + private[sql] def clearTags(): Unit = { + tags.get.clear() + } + def copy(): SparkConnectClient = configuration.toSparkConnectClient /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 1cdc2035de60b..eed8bd3f37d90 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -40,6 +40,7 @@ private[sql] class SparkResult[T]( extends AutoCloseable with Cleanable { self => + private[this] var opId: String = _ private[this] var numRecords: Int = 0 private[this] var structType: StructType = _ private[this] var arrowSchema: pojo.Schema = _ @@ -72,6 +73,7 @@ private[sql] class SparkResult[T]( } private def processResponses( + stopOnOperationId: Boolean = false, stopOnSchema: Boolean = false, stopOnArrowSchema: Boolean = false, stopOnFirstNonEmptyResponse: Boolean = false): Boolean = { @@ -79,6 +81,20 @@ private[sql] class 
SparkResult[T]( var stop = false while (!stop && responses.hasNext) { val response = responses.next() + + // Save and validate operationId + if (opId == null) { + opId = response.getOperationId + } + if (opId != response.getOperationId) { + // backwards compatibility: + // response from an old server without operationId field would have getOperationId == "". + throw new IllegalStateException( + "Received response with wrong operationId. " + + s"Expected '$opId' but received '${response.getOperationId}'.") + } + stop |= stopOnOperationId + if (response.hasSchema) { // The original schema should arrive before ArrowBatches. structType = @@ -148,6 +164,17 @@ private[sql] class SparkResult[T]( structType } + /** + * @return + * the operationId of the result. + */ + def operationId: String = { + if (opId == null) { + processResponses(stopOnOperationId = true) + } + opId + } + /** * Create an Array with the contents of the result. */ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 70eeb6c2c41df..5afafaaa6b92f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql +import scala.collection.mutable import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration._ import scala.util.{Failure, Success} @@ -64,13 +65,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession { } // 20 seconds is < 30 seconds the queries should be running, // because it should be interrupted sooner + val interrupted = mutable.ListBuffer[String]() eventually(timeout(20.seconds), interval(1.seconds)) { // keep interrupting every second, until both queries get interrupted. 
- spark.interruptAll() + val ids = spark.interruptAll() + interrupted ++= ids assert(error.isEmpty, s"Error not empty: $error") assert(q1Interrupted) assert(q2Interrupted) } + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } test("interrupt all - foreground queries, background interrupt") { @@ -79,9 +83,12 @@ class SparkSessionE2ESuite extends RemoteSparkSession { implicit val ec: ExecutionContextExecutor = ExecutionContext.global @volatile var finished = false + val interrupted = mutable.ListBuffer[String]() + val interruptor = Future { eventually(timeout(20.seconds), interval(1.seconds)) { - spark.interruptAll() + val ids = spark.interruptAll() + interrupted ++= ids assert(finished) } finished @@ -96,5 +103,148 @@ class SparkSessionE2ESuite extends RemoteSparkSession { assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2") finished = true assert(ThreadUtils.awaitResult(interruptor, 10.seconds)) + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + + test("interrupt tag") { + val session = spark + import session.implicits._ + + // global ExecutionContext has only 2 threads in Apache Spark CI + // create own thread pool for four Futures used in this test + val numThreads = 4 + val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + val q1 = Future { + assert(spark.getTags() == Set()) + spark.addTag("two") + assert(spark.getTags() == Set("two")) + spark.clearTags() // check that clearing all tags works + assert(spark.getTags() == Set()) + spark.addTag("one") + assert(spark.getTags() == Set("one")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val q2 = Future { + assert(spark.getTags() == Set()) + spark.addTag("one") + 
spark.addTag("two") + spark.addTag("one") + spark.addTag("two") // duplicates shouldn't matter + assert(spark.getTags() == Set("one", "two")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val q3 = Future { + assert(spark.getTags() == Set()) + spark.addTag("foo") + spark.removeTag("foo") + assert(spark.getTags() == Set()) // check that remove works removing the last tag + spark.addTag("two") + assert(spark.getTags() == Set("two")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val q4 = Future { + assert(spark.getTags() == Set()) + spark.addTag("one") + spark.addTag("two") + spark.addTag("two") + assert(spark.getTags() == Set("one", "two")) + spark.removeTag("two") // check that remove works, despite duplicate add + assert(spark.getTags() == Set("one")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val interrupted = mutable.ListBuffer[String]() + + // q2 and q3 should be cancelled + interrupted.clear() + eventually(timeout(20.seconds), interval(1.seconds)) { + val ids = spark.interruptTag("two") + interrupted ++= ids + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + val e2 = intercept[SparkException] { + ThreadUtils.awaitResult(q2, 1.minute) + } + assert(e2.getCause.getMessage contains "OPERATION_CANCELED") + val e3 = intercept[SparkException] { + ThreadUtils.awaitResult(q3, 1.minute) + } + assert(e3.getCause.getMessage contains "OPERATION_CANCELED") + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + + // q1 and q4 should be cancelled + 
interrupted.clear() + eventually(timeout(20.seconds), interval(1.seconds)) { + val ids = spark.interruptTag("one") + interrupted ++= ids + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + val e1 = intercept[SparkException] { + ThreadUtils.awaitResult(q1, 1.minute) + } + assert(e1.getCause.getMessage contains "OPERATION_CANCELED") + val e4 = intercept[SparkException] { + ThreadUtils.awaitResult(q4, 1.minute) + } + assert(e4.getCause.getMessage contains "OPERATION_CANCELED") + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + + test("interrupt operation") { + val session = spark + import session.implicits._ + + val result = spark + .range(10) + .map(n => { + Thread.sleep(5000); n + }) + .collectResult() + // cancel + val operationId = result.operationId + val canceledId = spark.interruptOperation(operationId) + assert(canceledId == Seq(operationId)) + // and check that it got canceled + val e = intercept[SparkException] { + result.toArray + } + assert(e.getMessage contains "OPERATION_CANCELED") } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index e7f01d6140dec..deb2ff631fdf2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -365,6 +365,18 @@ object CheckConnectJvmClientCompatibility { // public ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.interruptAll"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.interruptTag"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + 
"org.apache.spark.sql.SparkSession.interruptOperation"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.addTag"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.removeTag"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.getTags"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.clearTags"), // SparkSession#Builder ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.remote"), diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index e869712858a3d..d935ae65328d4 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -281,6 +281,12 @@ message ExecutePlanRequest { // server side. UserContext user_context = 2; + // (Optional) + // Provide an id for this request. If not provided, it will be generated by the server. + // It is returned in every ExecutePlanResponse.operation_id of the ExecutePlan response stream. + // The id must be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` + optional string operation_id = 6; + // (Required) The logical plan to be executed / analyzed. Plan plan = 3; @@ -299,6 +305,11 @@ message ExecutePlanRequest { google.protobuf.Any extension = 999; } } + + // Tags to tag the given execution with. + // Tags cannot contain ',' character and cannot be empty strings. + // Used by Interrupt with interrupt.tag. + repeated string tags = 7; } // The response of a query, can be one or more for each request. Responses belonging to the @@ -306,6 +317,12 @@ message ExecutePlanRequest { message ExecutePlanResponse { string session_id = 1; + // Identifies the ExecutePlan execution. 
+ // If set by the client in ExecutePlanRequest.operationId, that value is returned. + // Otherwise generated by the server. + // It is an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` + string operation_id = 12; + // Union type for the different response messages. oneof response_type { ArrowBatch arrow_batch = 2; @@ -616,13 +633,31 @@ message InterruptRequest { enum InterruptType { INTERRUPT_TYPE_UNSPECIFIED = 0; - // Interrupt all running executions within session with provided session_id. + // Interrupt all running executions within the session with the provided session_id. INTERRUPT_TYPE_ALL = 1; + + // Interrupt all running executions within the session with the provided operation_tag. + INTERRUPT_TYPE_TAG = 2; + + // Interrupt the running execution within the session with the provided operation_id. + INTERRUPT_TYPE_OPERATION_ID = 3; + } + + oneof interrupt { + // if interrupt_type == INTERRUPT_TYPE_TAG, interrupt operation with this tag. + string operation_tag = 5; + + // if interrupt_type == INTERRUPT_TYPE_OPERATION_ID, interrupt operation with this operation_id. + string operation_id = 6; } } message InterruptResponse { + // Session id in which the interrupt was running. string session_id = 1; + + // Operation ids of the executions which were interrupted. + repeated string interrupted_ids = 2; } // Main interface for the SparkConnect service.
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index e0c7d267c604e..e2934b5674495 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -81,4 +81,28 @@ private[connect] object ProtoUtils { private def createString(prefix: String, size: Int): String = { s"$prefix[truncated(size=${format.format(size)})]" } + + // Because Spark Connect operation tags are also set as SparkContext Job tags, they cannot contain + // SparkContext.SPARK_JOB_TAGS_SEP + private val SPARK_JOB_TAGS_SEP = ',' // SparkContext.SPARK_JOB_TAGS_SEP + + /** + * Validate if a tag for ExecutePlanRequest.tags is valid. Throw IllegalArgumentException if + * not. + */ + def throwIfInvalidTag(tag: String): Unit = { + // Same format rules apply to Spark Connect execution tags as to SparkContext job tags, + // because the Spark Connect job tag is also used as part of SparkContext job tag.
+ // See SparkContext.throwIfInvalidTag and ExecuteHolder.tagToSparkJobTag + if (tag == null) { + throw new IllegalArgumentException("Spark Connect tag cannot be null.") + } + if (tag.contains(SPARK_JOB_TAGS_SEP)) { + throw new IllegalArgumentException( + s"Spark Connect tag cannot contain '$SPARK_JOB_TAGS_SEP'.") + } + if (tag.isEmpty) { + throw new IllegalArgumentException("Spark Connect tag cannot be an empty string.") + } + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala index 5aecbdfce163e..ae89c150a68f2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala @@ -21,7 +21,9 @@ import scala.collection.mutable import io.grpc.stub.StreamObserver +import org.apache.spark.connect.proto import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.service.ExecuteHolder /** * This StreamObserver is running on the execution thread. Execution pushes responses to it, it @@ -40,7 +42,9 @@ import org.apache.spark.internal.Logging * @see * attachConsumer */ -private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] with Logging { +private[connect] class ExecuteResponseObserver[T](val executeHolder: ExecuteHolder) + extends StreamObserver[T] + with Logging { /** * Cached responses produced by the execution. Map from response index -> response. 
Response @@ -77,7 +81,9 @@ private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] wi throw new IllegalStateException("Stream onNext can't be called after stream completed") } lastProducedIndex += 1 - responses += ((lastProducedIndex, CachedStreamResponse[T](r, lastProducedIndex))) + val processedResponse = setCommonResponseFields(r) + responses += + ((lastProducedIndex, CachedStreamResponse[T](processedResponse, lastProducedIndex))) logDebug(s"Saved response with index=$lastProducedIndex") notifyAll() } @@ -158,4 +164,19 @@ private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] wi i -= 1 } } + + /** + * Populate response fields that are common and should be set in every response. + */ + private def setCommonResponseFields(response: T): T = { + response match { + case executePlanResponse: proto.ExecutePlanResponse => + executePlanResponse + .toBuilder() + .setSessionId(executeHolder.sessionHolder.sessionId) + .setOperationId(executeHolder.operationId) + .build() + .asInstanceOf[T] + } + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 6c2ffa4654747..6758df0d7e6d7 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -54,11 +54,20 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends executionThread.join() } - /** Interrupt the executing thread. */ - def interrupt(): Unit = { + /** + * Interrupt the executing thread. + * @return + * true if it was not interrupted before, false if it was already interrupted. 
+ */ + def interrupt(): Boolean = { synchronized { - interrupted = true - executionThread.interrupt() + if (!interrupted) { + interrupted = true + executionThread.interrupt() + true + } else { + false + } } } @@ -85,6 +94,10 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends } } finally { executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag) + executeHolder.sparkSessionTags.foreach { tag => + executeHolder.sessionHolder.session.sparkContext + .removeJobTag(executeHolder.tagToSparkJobTag(tag)) + } } } catch { ErrorUtils.handleError( @@ -113,6 +126,10 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends // Set tag for query cancellation session.sparkContext.addJobTag(executeHolder.jobTag) + // Also set all user defined tags as Spark Job tags. + executeHolder.sparkSessionTags.foreach { tag => + session.sparkContext.addJobTag(executeHolder.tagToSparkJobTag(tag)) + } session.sparkContext.setJobDescription( s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") session.sparkContext.setInterruptOnCancel(true) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala index 0af54f034a254..5e831aaa98f2f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala @@ -59,6 +59,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { private def jobTag = executeHolder.jobTag + private def sparkSessionTags = executeHolder.sparkSessionTags + private def listenerBus = sessionHolder.session.sparkContext.listenerBus private def sessionHolder = executeHolder.sessionHolder @@ -119,7 +121,8 @@ case class 
ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { Utils.redact( sessionHolder.session.sessionState.conf.stringRedactionPattern, ProtoUtils.abbreviate(plan, ExecuteEventsManager.MAX_STATEMENT_TEXT_SIZE).toString), - Some(request))) + Some(request), + sparkSessionTags)) } /** @@ -270,6 +273,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) { * The connect request plan converted to text. * @param planRequest: * The Connect request. None if the operation is not of type @link proto.ExecutePlanRequest + * @param sparkSessionTags: + * Extra tags set by the user (via SparkSession.addTag). * @param extraTags: * Additional metadata during the request. */ @@ -282,6 +287,7 @@ case class SparkListenerConnectOperationStarted( userName: String, statementText: String, planRequest: Option[proto.ExecutePlanRequest], + sparkSessionTags: Set[String], extraTags: Map[String, String] = Map.empty) extends SparkListenerEvent diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index 1f70973b60e0d..74530ad032f13 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.connect.service +import scala.collection.JavaConverters._ + import org.apache.spark.connect.proto import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner} import org.apache.spark.util.SystemClock @@ -31,16 +34,34 @@ private[connect] class ExecuteHolder( val sessionHolder: SessionHolder) extends Logging { + /** + * Tag that is set for this execution on SparkContext, via 
SparkContext.addJobTag. Used + * (internally) for cancallation of the Spark Jobs ran by this execution. + */ val jobTag = s"SparkConnect_Execute_" + s"User_${sessionHolder.userId}_" + s"Session_${sessionHolder.sessionId}_" + s"Request_${operationId}" + /** + * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group + * executions, and for user cancellation using SparkSession.interruptTag. + */ + val sparkSessionTags: Set[String] = request + .getTagsList() + .asScala + .toSeq + .map { tag => + ProtoUtils.throwIfInvalidTag(tag) + tag + } + .toSet + val session = sessionHolder.session val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] = - new ExecuteResponseObserver[proto.ExecutePlanResponse]() + new ExecuteResponseObserver[proto.ExecutePlanResponse](this) val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock()) @@ -85,8 +106,19 @@ private[connect] class ExecuteHolder( /** * Interrupt the execution. Interrupts the running thread, which cancels all running Spark Jobs * and makes the execution throw an OPERATION_CANCELED error. + * @return + * true if it was not interrupted before, false if it was already interrupted. */ - def interrupt(): Unit = { + def interrupt(): Boolean = { runner.interrupt() } + + /** + * Spark Connect tags are also added as SparkContext job tags, but to make the tag unique, they + * need to be combined with userId and sessionId. 
+ */ + def tagToSparkJobTag(tag: String): String = { + "SparkConnect_Tag_" + + s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}" + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 5ac4f6db82aa3..ae53d1d171f0c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -22,10 +22,9 @@ import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal +import scala.collection.mutable -import org.apache.spark.JobArtifactSet -import org.apache.spark.SparkException +import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException} import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame @@ -57,9 +56,25 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio new ConcurrentHashMap() private[connect] def createExecuteHolder(request: proto.ExecutePlanRequest): ExecuteHolder = { - val operationId = UUID.randomUUID().toString + val operationId = if (request.hasOperationId) { + try { + UUID.fromString(request.getOperationId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> request.getOperationId)) + } + } else { + UUID.randomUUID().toString + } val executePlanHolder = new ExecuteHolder(request, operationId, this) - assert(executions.putIfAbsent(operationId, executePlanHolder) == null) + val oldExecute = executions.putIfAbsent(operationId, executePlanHolder) + if (oldExecute != null) { + throw new SparkSQLException( + errorClass = 
"INVALID_HANDLE.ALREADY_EXISTS", + messageParameters = Map("handle" -> operationId)) + } executePlanHolder } @@ -71,17 +86,51 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio executions.remove(operationId) } - private[connect] def interruptAll(): Unit = { + /** + * Interrupt all executions in the session. + * @return + * list of operationIds of interrupted executions + */ + private[connect] def interruptAll(): Seq[String] = { + val interruptedIds = new mutable.ArrayBuffer[String]() executions.asScala.values.foreach { execute => - // Eat exception while trying to interrupt a given execution and move forward. - try { - logDebug(s"Interrupting execution ${execute.operationId}") - execute.interrupt() - } catch { - case NonFatal(e) => - logWarning(s"Exception $e while trying to interrupt execution ${execute.operationId}") + if (execute.interrupt()) { + interruptedIds += execute.operationId + } + } + interruptedIds.toSeq + } + + /** + * Interrupt executions in the session with a given tag. 
+ * @return + * list of operationIds of interrupted executions + */ + private[connect] def interruptTag(tag: String): Seq[String] = { + val interruptedIds = new mutable.ArrayBuffer[String]() + executions.asScala.values.foreach { execute => + if (execute.sparkSessionTags.contains(tag)) { + if (execute.interrupt()) { + interruptedIds += execute.operationId + } + } + } + interruptedIds.toSeq + } + + /** + * Interrupt the execution with the given operation_id + * @return + * list of operationIds of interrupted executions (one element or empty) + */ + private[connect] def interruptOperation(operationId: String): Seq[String] = { + val interruptedIds = new mutable.ArrayBuffer[String]() + Option(executions.get(operationId)).foreach { execute => + if (execute.interrupt()) { + interruptedIds += execute.operationId } } + interruptedIds.toSeq } private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala index b0923e277e426..a9ed391460ca9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.service +import scala.collection.JavaConverters._ + import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto @@ -30,16 +32,32 @@ class SparkConnectInterruptHandler(responseObserver: StreamObserver[proto.Interr SparkConnectService .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) - v.getInterruptType match { + val interruptedIds = v.getInterruptType match { case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL => sessionHolder.interruptAll() + case 
proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG => + if (!v.hasOperationTag) { + throw new IllegalArgumentException( + s"INTERRUPT_TYPE_TAG requested, but no operation_tag provided.") + } + sessionHolder.interruptTag(v.getOperationTag) + case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID => + if (!v.hasOperationId) { + throw new IllegalArgumentException( + s"INTERRUPT_TYPE_OPERATION_ID requested, but no operation_id provided.") + } + sessionHolder.interruptOperation(v.getOperationId) case other => throw new UnsupportedOperationException(s"Unknown InterruptType $other!") } - val builder = proto.InterruptResponse.newBuilder().setSessionId(v.getSessionId) + val response = proto.InterruptResponse + .newBuilder() + .setSessionId(v.getSessionId) + .addAllInterruptedIds(interruptedIds.asJava) + .build() - responseObserver.onNext(builder.build()) + responseObserver.onNext(response) responseObserver.onCompleted() } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index 365b17632a742..27c57e0d759f0 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -64,6 +64,7 @@ class ExecuteEventsManagerSuite DEFAULT_USER_NAME, DEFAULT_TEXT, Some(events.executeHolder.request), + Set.empty, Map.empty)) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 80f7eaf00ed22..26fdb86d29903 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2975,6 +2975,7 @@ object SparkContext extends Logging { /** Separator of tags in SPARK_JOB_TAGS 
property */ private[spark] val SPARK_JOB_TAGS_SEP = "," + // Same rules apply to Spark Connect execution tags, see ExecuteHolder.throwIfInvalidTag private[spark] def throwIfInvalidTag(tag: String) = { if (tag == null) { throw new IllegalArgumentException("Spark job tag cannot be null.") diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md b/docs/sql-error-conditions-invalid-handle-error-class.md new file mode 100644 index 0000000000000..7c083bc5f50c8 --- /dev/null +++ b/docs/sql-error-conditions-invalid-handle-error-class.md @@ -0,0 +1,36 @@ +--- +layout: global +title: INVALID_HANDLE error class +displayTitle: INVALID_HANDLE error class +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +[SQLSTATE: HY000](sql-error-conditions-sqlstates.html#class-HY-cli-specific-condition) + +The handle `` is invalid. + +This error class has the following derived error classes: + +## ALREADY_EXISTS + +Handle already exists. + +## FORMAT + +Handle has invalid format. 
Handle must an UUID string of the format '00112233-4455-6677-8899-aabbccddeeff' + + diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 5686324a0558b..6dbbf7bf05cd2 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -910,6 +910,14 @@ For more details see [INVALID_FORMAT](sql-error-conditions-invalid-format-error- The fraction of sec must be zero. Valid range is [0, 60]. If necessary set `` to "false" to bypass this error. +### [INVALID_HANDLE](sql-error-conditions-invalid-handle-error-class.html) + +[SQLSTATE: HY000](sql-error-conditions-sqlstates.html#class-HY-cli-specific-condition) + +The handle `` is invalid. + +For more details see [INVALID_HANDLE](sql-error-conditions-invalid-handle-error-class.html) + ### INVALID_HIVE_COLUMN_NAME SQLSTATE: none assigned diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 7bf93ed58fa86..04044d4cdcf3f 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 
\x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 
\x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 
\x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 
\x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\x85\x03\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x1aX\n\rRequestOption\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0e\n\x0c_client_type"\xe5\r\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\x35\n\textension\x18\xe7\x07 
\x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a`\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06valuesB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 
\x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 
\x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 
\x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xc5\x02\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType"G\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x42\x0e\n\x0c_client_type"2\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId2\xa4\x04\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 
\x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 
\x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xd2\x03\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1aX\n\rRequestOption\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\x88\x0e\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t 
\x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a`\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 
\x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06valuesB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 
\x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 
\x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 
\x03(\tR\x0einterruptedIds2\xa4\x04\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -115,75 +115,75 @@ _ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_start = 4482 _ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_end = 4565 _EXECUTEPLANREQUEST._serialized_start = 4578 - _EXECUTEPLANREQUEST._serialized_end = 4967 - _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4863 - _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 4951 - _EXECUTEPLANRESPONSE._serialized_start = 4970 - _EXECUTEPLANRESPONSE._serialized_end = 6735 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 5966 - _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6037 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6039 - _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 6100 - _EXECUTEPLANRESPONSE_METRICS._serialized_start = 6103 - _EXECUTEPLANRESPONSE_METRICS._serialized_end = 6620 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 6198 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 6530 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 6407 - _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 6530 - 
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6532 - _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6620 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6622 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 6718 - _KEYVALUE._serialized_start = 6737 - _KEYVALUE._serialized_end = 6802 - _CONFIGREQUEST._serialized_start = 6805 - _CONFIGREQUEST._serialized_end = 7833 - _CONFIGREQUEST_OPERATION._serialized_start = 7025 - _CONFIGREQUEST_OPERATION._serialized_end = 7523 - _CONFIGREQUEST_SET._serialized_start = 7525 - _CONFIGREQUEST_SET._serialized_end = 7577 - _CONFIGREQUEST_GET._serialized_start = 7579 - _CONFIGREQUEST_GET._serialized_end = 7604 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 7606 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 7669 - _CONFIGREQUEST_GETOPTION._serialized_start = 7671 - _CONFIGREQUEST_GETOPTION._serialized_end = 7702 - _CONFIGREQUEST_GETALL._serialized_start = 7704 - _CONFIGREQUEST_GETALL._serialized_end = 7752 - _CONFIGREQUEST_UNSET._serialized_start = 7754 - _CONFIGREQUEST_UNSET._serialized_end = 7781 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 7783 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 7817 - _CONFIGRESPONSE._serialized_start = 7835 - _CONFIGRESPONSE._serialized_end = 7957 - _ADDARTIFACTSREQUEST._serialized_start = 7960 - _ADDARTIFACTSREQUEST._serialized_end = 8831 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8347 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8400 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8402 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8513 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8515 - _ADDARTIFACTSREQUEST_BATCH._serialized_end = 8608 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 8611 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 8804 - _ADDARTIFACTSRESPONSE._serialized_start = 8834 - _ADDARTIFACTSRESPONSE._serialized_end = 9022 - 
_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 8941 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9022 - _ARTIFACTSTATUSESREQUEST._serialized_start = 9025 - _ARTIFACTSTATUSESREQUEST._serialized_end = 9220 - _ARTIFACTSTATUSESRESPONSE._serialized_start = 9223 - _ARTIFACTSTATUSESRESPONSE._serialized_end = 9491 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9334 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9374 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9376 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9491 - _INTERRUPTREQUEST._serialized_start = 9494 - _INTERRUPTREQUEST._serialized_end = 9819 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 9732 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 9803 - _INTERRUPTRESPONSE._serialized_start = 9821 - _INTERRUPTRESPONSE._serialized_end = 9871 - _SPARKCONNECTSERVICE._serialized_start = 9874 - _SPARKCONNECTSERVICE._serialized_end = 10422 + _EXECUTEPLANREQUEST._serialized_end = 5044 + _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4923 + _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5011 + _EXECUTEPLANRESPONSE._serialized_start = 5047 + _EXECUTEPLANRESPONSE._serialized_end = 6847 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6078 + _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6149 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6151 + _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 6212 + _EXECUTEPLANRESPONSE_METRICS._serialized_start = 6215 + _EXECUTEPLANRESPONSE_METRICS._serialized_end = 6732 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 6310 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 6642 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 6519 + _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 6642 + _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6644 + 
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6732 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6734 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 6830 + _KEYVALUE._serialized_start = 6849 + _KEYVALUE._serialized_end = 6914 + _CONFIGREQUEST._serialized_start = 6917 + _CONFIGREQUEST._serialized_end = 7945 + _CONFIGREQUEST_OPERATION._serialized_start = 7137 + _CONFIGREQUEST_OPERATION._serialized_end = 7635 + _CONFIGREQUEST_SET._serialized_start = 7637 + _CONFIGREQUEST_SET._serialized_end = 7689 + _CONFIGREQUEST_GET._serialized_start = 7691 + _CONFIGREQUEST_GET._serialized_end = 7716 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 7718 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 7781 + _CONFIGREQUEST_GETOPTION._serialized_start = 7783 + _CONFIGREQUEST_GETOPTION._serialized_end = 7814 + _CONFIGREQUEST_GETALL._serialized_start = 7816 + _CONFIGREQUEST_GETALL._serialized_end = 7864 + _CONFIGREQUEST_UNSET._serialized_start = 7866 + _CONFIGREQUEST_UNSET._serialized_end = 7893 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 7895 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 7929 + _CONFIGRESPONSE._serialized_start = 7947 + _CONFIGRESPONSE._serialized_end = 8069 + _ADDARTIFACTSREQUEST._serialized_start = 8072 + _ADDARTIFACTSREQUEST._serialized_end = 8943 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8459 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8512 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8514 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8625 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8627 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 8720 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 8723 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 8916 + _ADDARTIFACTSRESPONSE._serialized_start = 8946 + _ADDARTIFACTSRESPONSE._serialized_end = 9134 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9053 + 
_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9134 + _ARTIFACTSTATUSESREQUEST._serialized_start = 9137 + _ARTIFACTSTATUSESREQUEST._serialized_end = 9332 + _ARTIFACTSTATUSESRESPONSE._serialized_start = 9335 + _ARTIFACTSTATUSESRESPONSE._serialized_end = 9603 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9446 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9486 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9488 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9603 + _INTERRUPTREQUEST._serialized_start = 9606 + _INTERRUPTREQUEST._serialized_end = 10078 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 9921 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10049 + _INTERRUPTRESPONSE._serialized_start = 10080 + _INTERRUPTRESPONSE._serialized_end = 10171 + _SPARKCONNECTSERVICE._serialized_start = 10174 + _SPARKCONNECTSERVICE._serialized_end = 10722 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 633058f33ed4e..651438ea4385d 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -1031,9 +1031,11 @@ class ExecutePlanRequest(google.protobuf.message.Message): SESSION_ID_FIELD_NUMBER: builtins.int USER_CONTEXT_FIELD_NUMBER: builtins.int + OPERATION_ID_FIELD_NUMBER: builtins.int PLAN_FIELD_NUMBER: builtins.int CLIENT_TYPE_FIELD_NUMBER: builtins.int REQUEST_OPTIONS_FIELD_NUMBER: builtins.int + TAGS_FIELD_NUMBER: builtins.int session_id: builtins.str """(Required) @@ -1048,6 +1050,12 @@ class ExecutePlanRequest(google.protobuf.message.Message): user_context.user_id and session+id both identify a unique remote spark session on the server side. """ + operation_id: builtins.str + """(Optional) + Provide an id for this request. If not provided, it will be generated by the server. 
+ It is returned in every ExecutePlanResponse.operation_id of the ExecutePlan response stream. + The id must be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` + """ @property def plan(self) -> global___Plan: """(Required) The logical plan to be executed / analyzed.""" @@ -1065,23 +1073,37 @@ class ExecutePlanRequest(google.protobuf.message.Message): """Repeated element for options that can be passed to the request. This element is currently unused but allows to pass in an extension value used for arbitrary options. """ + @property + def tags( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Tags to tag the given execution with. + Tags cannot contain ',' character and cannot be empty strings. + Used by Interrupt with interrupt.tag. + """ def __init__( self, *, session_id: builtins.str = ..., user_context: global___UserContext | None = ..., + operation_id: builtins.str | None = ..., plan: global___Plan | None = ..., client_type: builtins.str | None = ..., request_options: collections.abc.Iterable[global___ExecutePlanRequest.RequestOption] | None = ..., + tags: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ "_client_type", b"_client_type", + "_operation_id", + b"_operation_id", "client_type", b"client_type", + "operation_id", + b"operation_id", "plan", b"plan", "user_context", @@ -1093,21 +1115,32 @@ class ExecutePlanRequest(google.protobuf.message.Message): field_name: typing_extensions.Literal[ "_client_type", b"_client_type", + "_operation_id", + b"_operation_id", "client_type", b"client_type", + "operation_id", + b"operation_id", "plan", b"plan", "request_options", b"request_options", "session_id", b"session_id", + "tags", + b"tags", "user_context", b"user_context", ], ) -> None: ... 
+ @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] ) -> typing_extensions.Literal["client_type"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_operation_id", b"_operation_id"] + ) -> typing_extensions.Literal["operation_id"] | None: ... global___ExecutePlanRequest = ExecutePlanRequest @@ -1290,6 +1323,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): ) -> None: ... SESSION_ID_FIELD_NUMBER: builtins.int + OPERATION_ID_FIELD_NUMBER: builtins.int ARROW_BATCH_FIELD_NUMBER: builtins.int SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int WRITE_STREAM_OPERATION_START_RESULT_FIELD_NUMBER: builtins.int @@ -1301,6 +1335,12 @@ class ExecutePlanResponse(google.protobuf.message.Message): OBSERVED_METRICS_FIELD_NUMBER: builtins.int SCHEMA_FIELD_NUMBER: builtins.int session_id: builtins.str + operation_id: builtins.str + """Identifies the ExecutePlan execution. + If set by the client in ExecutePlanRequest.operationId, that value is returned. + Otherwise generated by the server. + It is an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff` + """ @property def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ... 
@property @@ -1348,6 +1388,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): self, *, session_id: builtins.str = ..., + operation_id: builtins.str = ..., arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ..., sql_command_result: global___ExecutePlanResponse.SqlCommandResult | None = ..., write_stream_operation_start_result: pyspark.sql.connect.proto.commands_pb2.WriteStreamOperationStartResult @@ -1402,6 +1443,8 @@ class ExecutePlanResponse(google.protobuf.message.Message): b"metrics", "observed_metrics", b"observed_metrics", + "operation_id", + b"operation_id", "response_type", b"response_type", "schema", @@ -2208,17 +2251,27 @@ class InterruptRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest._InterruptType.ValueType # 0 INTERRUPT_TYPE_ALL: InterruptRequest._InterruptType.ValueType # 1 - """Interrupt all running executions within session with provided session_id.""" + """Interrupt all running executions within the session with the provided session_id.""" + INTERRUPT_TYPE_TAG: InterruptRequest._InterruptType.ValueType # 2 + """Interrupt all running executions within the session with the provided operation_tag.""" + INTERRUPT_TYPE_OPERATION_ID: InterruptRequest._InterruptType.ValueType # 3 + """Interrupt the running execution within the session with the provided operation_id.""" class InterruptType(_InterruptType, metaclass=_InterruptTypeEnumTypeWrapper): ... 
INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest.InterruptType.ValueType # 0 INTERRUPT_TYPE_ALL: InterruptRequest.InterruptType.ValueType # 1 - """Interrupt all running executions within session with provided session_id.""" + """Interrupt all running executions within the session with the provided session_id.""" + INTERRUPT_TYPE_TAG: InterruptRequest.InterruptType.ValueType # 2 + """Interrupt all running executions within the session with the provided operation_tag.""" + INTERRUPT_TYPE_OPERATION_ID: InterruptRequest.InterruptType.ValueType # 3 + """Interrupt the running execution within the session with the provided operation_id.""" SESSION_ID_FIELD_NUMBER: builtins.int USER_CONTEXT_FIELD_NUMBER: builtins.int CLIENT_TYPE_FIELD_NUMBER: builtins.int INTERRUPT_TYPE_FIELD_NUMBER: builtins.int + OPERATION_TAG_FIELD_NUMBER: builtins.int + OPERATION_ID_FIELD_NUMBER: builtins.int session_id: builtins.str """(Required) @@ -2236,6 +2289,10 @@ class InterruptRequest(google.protobuf.message.Message): """ interrupt_type: global___InterruptRequest.InterruptType.ValueType """(Required) The type of interrupt to execute.""" + operation_tag: builtins.str + """if interrupt_tag == INTERRUPT_TYPE_TAG, interrupt operation with this tag.""" + operation_id: builtins.str + """if interrupt_tag == INTERRUPT_TYPE_OPERATION_ID, interrupt operation with this operation_id.""" def __init__( self, *, @@ -2243,6 +2300,8 @@ class InterruptRequest(google.protobuf.message.Message): user_context: global___UserContext | None = ..., client_type: builtins.str | None = ..., interrupt_type: global___InterruptRequest.InterruptType.ValueType = ..., + operation_tag: builtins.str = ..., + operation_id: builtins.str = ..., ) -> None: ... 
def HasField( self, @@ -2251,6 +2310,12 @@ class InterruptRequest(google.protobuf.message.Message): b"_client_type", "client_type", b"client_type", + "interrupt", + b"interrupt", + "operation_id", + b"operation_id", + "operation_tag", + b"operation_tag", "user_context", b"user_context", ], @@ -2262,17 +2327,28 @@ class InterruptRequest(google.protobuf.message.Message): b"_client_type", "client_type", b"client_type", + "interrupt", + b"interrupt", "interrupt_type", b"interrupt_type", + "operation_id", + b"operation_id", + "operation_tag", + b"operation_tag", "session_id", b"session_id", "user_context", b"user_context", ], ) -> None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] ) -> typing_extensions.Literal["client_type"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["interrupt", b"interrupt"] + ) -> typing_extensions.Literal["operation_tag", "operation_id"] | None: ... global___InterruptRequest = InterruptRequest @@ -2280,14 +2356,25 @@ class InterruptResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor SESSION_ID_FIELD_NUMBER: builtins.int + INTERRUPTED_IDS_FIELD_NUMBER: builtins.int session_id: builtins.str + """Session id in which the interrupt was running.""" + @property + def interrupted_ids( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Operation ids of the executions which were interrupted.""" def __init__( self, *, session_id: builtins.str = ..., + interrupted_ids: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def ClearField( - self, field_name: typing_extensions.Literal["session_id", b"session_id"] + self, + field_name: typing_extensions.Literal[ + "interrupted_ids", b"interrupted_ids", "session_id", b"session_id" + ], ) -> None: ... 
global___InterruptResponse = InterruptResponse From a8e2977fd9b6e3224b014e0b0572a4d7b83c1106 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 22 Jul 2023 13:11:04 +0800 Subject: [PATCH 041/986] [SPARK-44510][UI] Update dataTables to 1.13.5 and remove some unreached png files ### What changes were proposed in this pull request? This PR updates datatables from 1.13.2 to 1.13.5, related license files, and removes some pictures for sorting orientation but unused. FYI, https://cdn.datatables.net/releases.html ### Why are the changes needed? updating web resources and cleanup ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? built and tested locally ![image](https://github.com/apache/spark/assets/8326978/87a899d2-a45a-4aba-8c65-ab694919923d) Closes #42108 from yaooqinn/SPARK-44510. Authored-by: Kent Yao Signed-off-by: Kent Yao (cherry picked from commit 6409284963d6c5edecc374db4027dee7f6f490c1) Signed-off-by: Kent Yao --- .../dataTables.bootstrap4.1.13.2.min.css | 1 - .../dataTables.bootstrap4.1.13.2.min.js | 4 --- .../dataTables.bootstrap4.1.13.5.min.css | 1 + .../dataTables.bootstrap4.1.13.5.min.js | 4 +++ .../spark/ui/static/images/sort_asc.png | Bin 160 -> 0 bytes .../ui/static/images/sort_asc_disabled.png | Bin 148 -> 0 bytes .../spark/ui/static/images/sort_both.png | Bin 201 -> 0 bytes .../spark/ui/static/images/sort_desc.png | Bin 158 -> 0 bytes .../ui/static/images/sort_desc_disabled.png | Bin 146 -> 0 bytes .../ui/static/jquery.dataTables.1.13.2.min.js | 4 --- .../ui/static/jquery.dataTables.1.13.5.min.js | 4 +++ .../scala/org/apache/spark/ui/UIUtils.scala | 6 ++--- dev/.rat-excludes | 6 ++--- licenses-binary/LICENSE-datatables.txt | 25 +++++++++++++++--- licenses/LICENSE-datatables.txt | 25 +++++++++++++++--- 15 files changed, 57 insertions(+), 23 deletions(-) delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css delete mode 100644 
core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_asc.png delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_asc_disabled.png delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_both.png delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_desc.png delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/images/sort_desc_disabled.png delete mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.13.2.min.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.13.5.min.js diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css deleted file mode 100644 index b9c16ca78a01c..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css +++ /dev/null @@ -1 +0,0 @@ -:root{--dt-row-selected: 2, 117, 216;--dt-row-selected-text: 255, 255, 255;--dt-row-selected-link: 9, 10, 11}table.dataTable td.dt-control{text-align:center;cursor:pointer}table.dataTable td.dt-control:before{height:1em;width:1em;margin-top:-9px;display:inline-block;color:white;border:.15em solid white;border-radius:1em;box-shadow:0 0 .2em #444;box-sizing:content-box;text-align:center;text-indent:0 !important;font-family:"Courier New",Courier,monospace;line-height:1em;content:"+";background-color:#31b131}table.dataTable tr.dt-hasChild td.dt-control:before{content:"-";background-color:#d33333}table.dataTable thead>tr>th.sorting,table.dataTable 
thead>tr>th.sorting_asc,table.dataTable thead>tr>th.sorting_desc,table.dataTable thead>tr>th.sorting_asc_disabled,table.dataTable thead>tr>th.sorting_desc_disabled,table.dataTable thead>tr>td.sorting,table.dataTable thead>tr>td.sorting_asc,table.dataTable thead>tr>td.sorting_desc,table.dataTable thead>tr>td.sorting_asc_disabled,table.dataTable thead>tr>td.sorting_desc_disabled{cursor:pointer;position:relative;padding-right:26px}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after{position:absolute;display:block;opacity:.125;right:10px;line-height:9px;font-size:.8em}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable 
thead>tr>td.sorting_desc_disabled:before{bottom:50%;content:"▲"}table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:after{top:50%;content:"▼"}table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:after{opacity:.6}table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting_asc_disabled:before{display:none}table.dataTable thead>tr>th:active,table.dataTable thead>tr>td:active{outline:none}div.dataTables_scrollBody table.dataTable thead>tr>th:before,div.dataTables_scrollBody table.dataTable thead>tr>th:after,div.dataTables_scrollBody table.dataTable thead>tr>td:before,div.dataTables_scrollBody table.dataTable thead>tr>td:after{display:none}div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:2px}div.dataTables_processing>div:last-child{position:relative;width:80px;height:15px;margin:1em auto}div.dataTables_processing>div:last-child>div{position:absolute;top:0;width:13px;height:13px;border-radius:50%;background:2 117 216;animation-timing-function:cubic-bezier(0, 1, 1, 0)}div.dataTables_processing>div:last-child>div:nth-child(1){left:8px;animation:datatables-loader-1 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(2){left:8px;animation:datatables-loader-2 .6s 
infinite}div.dataTables_processing>div:last-child>div:nth-child(3){left:32px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(4){left:56px;animation:datatables-loader-3 .6s infinite}@keyframes datatables-loader-1{0%{transform:scale(0)}100%{transform:scale(1)}}@keyframes datatables-loader-3{0%{transform:scale(1)}100%{transform:scale(0)}}@keyframes datatables-loader-2{0%{transform:translate(0, 0)}100%{transform:translate(24px, 0)}}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th,table.dataTable thead td,table.dataTable tfoot th,table.dataTable tfoot td{text-align:left}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody 
th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable{clear:both;margin-top:6px !important;margin-bottom:6px !important;max-width:none !important;border-collapse:separate !important;border-spacing:0}table.dataTable td,table.dataTable th{-webkit-box-sizing:content-box;box-sizing:content-box}table.dataTable td.dataTables_empty,table.dataTable th.dataTables_empty{text-align:center}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.table-striped>tbody>tr:nth-of-type(2n+1){background-color:transparent}table.dataTable>tbody>tr{background-color:transparent}table.dataTable>tbody>tr.selected>*{box-shadow:inset 0 0 0 9999px rgb(2, 117, 216);box-shadow:inset 0 0 0 9999px rgb(var(--dt-row-selected));color:rgb(255, 255, 255);color:rgb(var(--dt-row-selected-text))}table.dataTable>tbody>tr.selected a{color:rgb(9, 10, 11);color:rgb(var(--dt-row-selected-link))}table.dataTable.table-striped>tbody>tr.odd>*{box-shadow:inset 0 0 0 9999px rgba(0, 0, 0, 0.05)}table.dataTable.table-striped>tbody>tr.odd.selected>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.95);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.95)}table.dataTable.table-hover>tbody>tr:hover>*{box-shadow:inset 0 0 0 9999px rgba(0, 0, 0, 0.075)}table.dataTable.table-hover>tbody>tr.selected:hover>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.975);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.975)}div.dataTables_wrapper div.dataTables_length label{font-weight:normal;text-align:left;white-space:nowrap}div.dataTables_wrapper 
div.dataTables_length select{width:auto;display:inline-block}div.dataTables_wrapper div.dataTables_filter{text-align:right}div.dataTables_wrapper div.dataTables_filter label{font-weight:normal;white-space:nowrap;text-align:left}div.dataTables_wrapper div.dataTables_filter input{margin-left:.5em;display:inline-block;width:auto}div.dataTables_wrapper div.dataTables_info{padding-top:.85em}div.dataTables_wrapper div.dataTables_paginate{margin:0;white-space:nowrap;text-align:right}div.dataTables_wrapper div.dataTables_paginate ul.pagination{margin:2px 0;white-space:nowrap;justify-content:flex-end}div.dataTables_wrapper div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:1em 0}div.dataTables_scrollHead table.dataTable{margin-bottom:0 !important}div.dataTables_scrollBody>table{border-top:none;margin-top:0 !important;margin-bottom:0 !important}div.dataTables_scrollBody>table>thead .sorting:before,div.dataTables_scrollBody>table>thead .sorting_asc:before,div.dataTables_scrollBody>table>thead .sorting_desc:before,div.dataTables_scrollBody>table>thead .sorting:after,div.dataTables_scrollBody>table>thead .sorting_asc:after,div.dataTables_scrollBody>table>thead .sorting_desc:after{display:none}div.dataTables_scrollBody>table>tbody tr:first-child th,div.dataTables_scrollBody>table>tbody tr:first-child td{border-top:none}div.dataTables_scrollFoot>.dataTables_scrollFootInner{box-sizing:content-box}div.dataTables_scrollFoot>.dataTables_scrollFootInner>table{margin-top:0 !important;border-top:none}@media screen and (max-width: 767px){div.dataTables_wrapper div.dataTables_length,div.dataTables_wrapper div.dataTables_filter,div.dataTables_wrapper div.dataTables_info,div.dataTables_wrapper div.dataTables_paginate{text-align:center}div.dataTables_wrapper div.dataTables_paginate ul.pagination{justify-content:center 
!important}}table.dataTable.table-sm>thead>tr>th:not(.sorting_disabled){padding-right:20px}table.table-bordered.dataTable{border-right-width:0}table.table-bordered.dataTable th,table.table-bordered.dataTable td{border-left-width:0}table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable td:last-child,table.table-bordered.dataTable td:last-child{border-right-width:1px}table.table-bordered.dataTable tbody th,table.table-bordered.dataTable tbody td{border-bottom-width:0}div.dataTables_scrollHead table.table-bordered{border-bottom-width:0}div.table-responsive>div.dataTables_wrapper>div.row{margin:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:first-child{padding-left:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:last-child{padding-right:0} diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js deleted file mode 100644 index 2937bc3c90c2c..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js +++ /dev/null @@ -1,4 +0,0 @@ -/*! 
DataTables Bootstrap 4 integration - * ©2011-2017 SpryMedia Ltd - datatables.net/license - */ -!function(t){"function"==typeof define&&define.amd?define(["jquery","datatables.net"],function(e){return t(e,window,document)}):"object"==typeof exports?module.exports=function(e,a){return e=e||window,(a=a||("undefined"!=typeof window?require("jquery"):require("jquery")(e))).fn.dataTable||require("datatables.net")(e,a),t(a,0,e.document)}:t(jQuery,window,document)}(function(x,e,n,r){"use strict";var s=x.fn.dataTable;return x.extend(!0,s.defaults,{dom:"<'row'<'col-sm-12 col-md-6'l><'col-sm-12 col-md-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-12 col-md-5'i><'col-sm-12 col-md-7'p>>",renderer:"bootstrap"}),x.extend(s.ext.classes,{sWrapper:"dataTables_wrapper dt-bootstrap4",sFilterInput:"form-control form-control-sm",sLengthSelect:"custom-select custom-select-sm form-control form-control-sm",sProcessing:"dataTables_processing card",sPageButton:"paginate_button page-item"}),s.ext.renderer.pageButton.bootstrap=function(i,e,d,a,l,c){function u(e,a){for(var t,n,r=function(e){e.preventDefault(),x(e.currentTarget).hasClass("disabled")||b.page()==e.data.action||b.page(e.data.action).draw("page")},s=0,o=a.length;s",{class:m.sPageButton+" "+f,id:0===d&&"string"==typeof t?i.sTableId+"_"+t:null}).append(x("",{href:n?null:"#","aria-controls":i.sTableId,"aria-disabled":n?"true":null,"aria-label":w[t],"aria-role":"link","aria-current":"active"===f?"page":null,"data-dt-idx":t,tabindex:i.iTabIndex,class:"page-link"}).html(p)).appendTo(e),i.oApi._fnBindAction(n,{action:t},r))}}var p,f,t,b=new s.Api(i),m=i.oClasses,g=i.oLanguage.oPaginate,w=i.oLanguage.oAria.paginate||{};try{t=x(e).find(n.activeElement).data("dt-idx")}catch(e){}u(x(e).empty().html('