diff --git a/.devcontainer.json b/.devcontainer.json new file mode 100644 index 0000000000000..ab44853849448 --- /dev/null +++ b/.devcontainer.json @@ -0,0 +1 @@ +{"image":"mcr.microsoft.com/devcontainers/java"} \ No newline at end of file diff --git a/.devcontainer/.gitignore b/.devcontainer/.gitignore new file mode 100644 index 0000000000000..044efa3bcc5ae --- /dev/null +++ b/.devcontainer/.gitignore @@ -0,0 +1 @@ +.devpod-internal \ No newline at end of file diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000000..e68f61b5aa801 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,2 @@ +# Explicitly force the platform here to ensure the build happens on amd64 +FROM --platform=linux/amd64 mcr.microsoft.com/devcontainers/base:ubuntu-22.04 \ No newline at end of file diff --git a/.devcontainer/aws-config b/.devcontainer/aws-config new file mode 100644 index 0000000000000..ad3dd3c9788ce --- /dev/null +++ b/.devcontainer/aws-config @@ -0,0 +1,13 @@ +[profile default-engineering] +sso_start_url = https://6si-mgmt.awsapps.com/start +sso_region = us-east-1 +sso_account_id = 242358675102 +sso_role_name = BigData +region = us-east-1 +output = json +sso_session = default + +[sso-session default] +sso_start_url = https://6si-mgmt.awsapps.com/start +sso_region = us-east-1 +sso_registration_scopes = sso:account:access \ No newline at end of file diff --git a/.devcontainer/bash_additional.sh b/.devcontainer/bash_additional.sh new file mode 100644 index 0000000000000..1dd40707abf35 --- /dev/null +++ b/.devcontainer/bash_additional.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +__spark_git_branch() { + command -v git >/dev/null 2>&1 || return + git rev-parse --is-inside-work-tree >/dev/null 2>&1 || return + git branch --show-current 2>/dev/null +} + +__spark_prompt_git() { + local b + b="$(__spark_git_branch)" + if [ -n "$b" ]; then + printf '(%s)' "$b" + fi +} + +if [ -n "$DEVPOD_WORKSPACE_ID" ]; then + export 
PS1="\[\e[32m\]\u@\h\[\e[33m\]($DEVPOD_WORKSPACE_ID)\[\e[34m\]:\w \[\e[33m\]$(__spark_prompt_git)\[\e[0m\]\n\$ " +else + export PS1="\[\e[32m\]\u@\h:\[\e[34m\]\w \[\e[33m\]$(__spark_prompt_git)\[\e[0m\]\n\$ " +fi diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000..d0b3ad6c76bf6 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,52 @@ +{ + // Devcontainer config for x86_64 (amd64) development + "name": "Apache Spark Development (Multi-JDK, x86_64)", + // CHANGE: Use 'build' instead of 'image' + "build": { + "dockerfile": "Dockerfile" + }, + // Force running the container as linux/amd64 via Docker run args + "runArgs": ["--platform=linux/amd64"], + + "features": { + // Installs SDKMAN and multiple Java versions + // 1. Install multiple JDKs (8, 11, 17, 21) + "ghcr.io/devcontainers/features/java:1": { + "version": "17", + "installMaven": "true", + "installGradle": "false", + "jdkDistro": "amzn" + }, + // Common utilities for Spark building + "ghcr.io/devcontainers/features/common-utils:1": { + "installZsh": "true", + "configureZshAsDefaultShell": "true" + } + }, + + "customizations": { + "vscode": { + "extensions": [ + "GitHub.vscode-pull-request-github", + "vscjava.vscode-java-pack", + "vscjava.vscode-maven", + "vscjava.vscode-java-dependency", + "rebornix.project-config", + "mtxr.sqltools" + ], + "settings": { + "java.configuration.runtimes": [ + { "name": "JavaSE-1.8", "path": "/usr/local/sdkman/candidates/java/8.0.402-amzn" }, + { "name": "JavaSE-11", "path": "/usr/local/sdkman/candidates/java/11.0.22-amzn" }, + { "name": "JavaSE-17", "path": "/usr/local/sdkman/candidates/java/17.0.10-amzn", "default": true } + ], + "files.watcherExclude": { + "**/target/**": true + } + } + } + }, + + "postCreateCommand": "bash .devcontainer/setup-sparkdev.sh", + "remoteUser": "vscode" +} \ No newline at end of file diff --git a/.devcontainer/setup-sparkdev.sh b/.devcontainer/setup-sparkdev.sh new 
file mode 100644 index 0000000000000..68d4dde621384 --- /dev/null +++ b/.devcontainer/setup-sparkdev.sh @@ -0,0 +1,611 @@ +#!/bin/bash +set +e # Don't exit on errors - we want to track them + +#================================================================ +# Logging and Tracking Setup +#================================================================ + +# Log file paths +SETUP_LOG_DIR="/tmp/spark-setup" +SETUP_LOG_FILE="${SETUP_LOG_DIR}/setup.log" +mkdir -p "${SETUP_LOG_DIR}" + +# Arrays to track successes and failures +declare -a FAILED_COMMANDS=() +declare -a SUCCESSFUL_COMMANDS=() +declare -a WARNING_MESSAGES=() +declare -a INFO_MESSAGES=() + +# Initialize log file with timestamp +{ + echo "========================================================================" + echo "Spark Development Container Setup Log" + echo "Started: $(date '+%Y-%m-%d %H:%M:%S')" + echo "User: $(whoami)" + echo "Platform: $(uname -a)" + echo "========================================================================" + echo "" +} | tee "${SETUP_LOG_FILE}" + +# Function to log messages +log_message() { + local level="$1" + local message="$2" + local timestamp=$(date '+%H:%M:%S') + echo "[${timestamp}] [${level}] ${message}" | tee -a "${SETUP_LOG_FILE}" +} + +# Function to track failures +track_failure() { + local component="$1" + local reason="$2" + FAILED_COMMANDS+=("$component: $reason") + log_message "FAILED" "$component - $reason" +} + +# Function to track successes +track_success() { + local component="$1" + SUCCESSFUL_COMMANDS+=("$component") + log_message "SUCCESS" "$component installed successfully" +} + +# Function to track warnings +track_warning() { + local message="$1" + WARNING_MESSAGES+=("$message") + log_message "WARNING" "$message" +} + +# Function to track info +track_info() { + local message="$1" + INFO_MESSAGES+=("$message") + log_message "INFO" "$message" +} + +# Enable extended error reporting for line numbers +trap 'log_message "ERROR" "Script failed on line 
$LINENO"; show_final_report; exit 1' ERR + +# Ensure SDKMAN is loaded +export SDKMAN_DIR="/usr/local/sdkman" +if [[ -s "${SDKMAN_DIR}/bin/sdkman-init.sh" ]]; then + source "${SDKMAN_DIR}/bin/sdkman-init.sh" + track_info "SDKMAN loaded successfully" +else + track_warning "SDKMAN initialization file not found at ${SDKMAN_DIR}/bin/sdkman-init.sh" +fi + +#================================================================ +# Helper Functions for Version Detection +#================================================================ + +# Function to get available versions for a given SDK +get_available_versions() { + local sdk_name="$1" + sdk list "$sdk_name" 2>/dev/null | grep -oE '[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9]+)?' | sort -V | uniq || echo "" +} + +# Function to find the latest version matching a pattern +find_latest_version() { + local sdk_name="$1" + local pattern="$2" # e.g., "17\.|21\." for Java versions + + get_available_versions "$sdk_name" | grep -E "$pattern" | tail -1 +} + +# Function to install an SDK with error handling and logging +install_sdk_with_retry() { + local sdk_name="$1" + local version="$2" + local default_choice="${3:-y}" # Default to 'y' (set as default) + + if [ -z "$version" ]; then + track_warning "No version found matching criteria for $sdk_name" + return 1 + fi + + log_message "ACTION" "Installing $sdk_name $version..." + + # Suppress SDKMAN post-installation hook errors, but capture the real output + local sdk_output + local sdk_exit_code + + # Run installation, temporarily ignoring hook errors + sdk_output=$(echo "$default_choice" | sdk install "$sdk_name" "$version" 2>&1) + sdk_exit_code=$? 
+ + # Log the full output + echo "$sdk_output" >> "${SETUP_LOG_FILE}" + + # Check if the output contains the post-installation hook error + # This error is non-critical and often occurs even when installation succeeds + if echo "$sdk_output" | grep -q "__sdkman_post_installation_hook: command not found"; then + track_warning "$sdk_name ($version): SDKMAN post-installation hook error (non-critical)" + # Continue to verify installation + sdk_exit_code=0 + fi + + # Verify installation by checking if binary exists + if [ $sdk_exit_code -eq 0 ] || grep -q "Successfully installed" <<< "$sdk_output"; then + # Reload SDKMAN to ensure the new version is available + if [[ -s "${SDKMAN_DIR}/bin/sdkman-init.sh" ]]; then + source "${SDKMAN_DIR}/bin/sdkman-init.sh" 2>/dev/null || true + fi + + # Double-check by trying to get version + local version_check=$(sdk current "$sdk_name" 2>/dev/null || echo "") + if [ -n "$version_check" ]; then + track_success "$sdk_name ($version)" + return 0 + else + # If we can't verify immediately, consider it success if SDK said so + if grep -q "Successfully installed" <<< "$sdk_output"; then + track_success "$sdk_name ($version)" + return 0 + fi + fi + fi + + track_failure "$sdk_name ($version)" "Installation failed (exit code: $sdk_exit_code)" + return 1 +} + +#================================================================ +# Java Installation +#================================================================ +echo "" +echo "=== Installing Java versions ===" +log_message "STAGE" "Starting Java version installation" + +echo "Detecting available Java versions..." 
+track_info "Detecting available Java versions on this platform" + +# Find available Java versions +java11=$(find_latest_version "java" "11\.") +java17=$(find_latest_version "java" "17\.") +java21=$(find_latest_version "java" "21\.") + +# Log detected versions +[ -n "$java11" ] && track_info "Detected Java 11: $java11" || track_warning "Java 11 not available" +[ -n "$java17" ] && track_info "Detected Java 17: $java17" || track_warning "Java 17 not available" +[ -n "$java21" ] && track_info "Detected Java 21: $java21" || track_warning "Java 21 not available" + +# Prefer Java 21, then 17, then 11 +preferred_java="${java21:-${java17:-${java11}}}" + +if [ -n "$java11" ]; then + install_sdk_with_retry "java" "$java11" "y" || true +fi + +if [ -n "$java17" ]; then + install_sdk_with_retry "java" "$java17" "n" || true +fi + +if [ -n "$java21" ]; then + install_sdk_with_retry "java" "$java21" "n" || true +fi + +if [ -n "$preferred_java" ]; then + track_info "Default Java version set to: $preferred_java" +else + track_failure "Java Installation" "No Java versions could be installed from SDKMAN" +fi + +#================================================================ +# Maven Installation +#================================================================ +echo "" +echo "=== Installing Maven ===" +log_message "STAGE" "Starting Maven installation" + +echo "Detecting available Maven versions..." 
+track_info "Detecting available Maven versions" + +# Find latest Maven 3.9 or 3.8 +maven39=$(find_latest_version "maven" "3\.9\.") +maven38=$(find_latest_version "maven" "3\.8\.") + +# Log detected versions +[ -n "$maven39" ] && track_info "Detected Maven 3.9: $maven39" || track_warning "Maven 3.9 not available" +[ -n "$maven38" ] && track_info "Detected Maven 3.8: $maven38" || track_warning "Maven 3.8 not available" + +# Prefer Maven 3.9, then 3.8 +preferred_maven="${maven39:-${maven38}}" + +if [ -n "$preferred_maven" ]; then + install_sdk_with_retry "maven" "$preferred_maven" "y" || { + track_failure "Maven Installation" "Failed to install Maven $preferred_maven" + } +else + track_failure "Maven Installation" "No Maven 3.8+ versions found - SDKMAN may be unavailable or versions not released yet" +fi + + +#================================================================ +# Just (Task Runner) Installation +#================================================================ +echo "" +echo "=== Installing Just (task runner) ===" +log_message "STAGE" "Starting Just installation" + +install_just() { + echo "Fetching latest Just release..." 
+ track_info "Fetching latest Just release information" + + JUST_VERSION=$(curl -s \ + -H "Accept: application/vnd.github+json" \ + -H "User-Agent: curl" \ + https://api.github.com/repos/casey/just/releases/latest \ + | jq -r '.tag_name // empty' \ + | sed 's/^v//') + + if [ -z "$JUST_VERSION" ]; then + # Fallback version + JUST_VERSION="1.46.0" + track_warning "Could not detect latest Just version, using fallback: $JUST_VERSION" + echo " Using fallback version: $JUST_VERSION" + else + track_info "Found Just version: $JUST_VERSION" + echo " Found Just version: $JUST_VERSION" + fi + + JUST_URL="https://github.com/casey/just/releases/download/${JUST_VERSION}/just-${JUST_VERSION}-x86_64-unknown-linux-musl.tar.gz" + + if wget -q -O /tmp/just.tar.gz "$JUST_URL" 2>/dev/null; then + if tar -xzf /tmp/just.tar.gz -C /tmp >> "${SETUP_LOG_FILE}" 2>&1 && sudo mv /tmp/just /usr/local/bin/ >> "${SETUP_LOG_FILE}" 2>&1; then + track_success "Just ($JUST_VERSION)" + echo " ✓ Successfully installed Just $JUST_VERSION" + rm -f /tmp/just.tar.gz + return 0 + else + track_failure "Just ($JUST_VERSION)" "Failed to extract or move Just binary" + fi + else + track_failure "Just ($JUST_VERSION)" "Failed to download from GitHub" + fi + + echo " ✗ Warning: Could not install Just - continuing without it" + rm -f /tmp/just.tar.gz + return 1 +} + +install_just || true + +#================================================================ +# AWS Tools Installation +#================================================================ +echo "" +echo "=== Installing AWS tools ===" +log_message "STAGE" "Starting AWS CLI installation" + +install_aws_cli() { + echo "Installing AWS CLI v2..." + track_info "Installing AWS CLI v2" + + # Update package manager + echo "Updating system packages..." 
+ if sudo apt-get update -qq >> "${SETUP_LOG_FILE}" 2>&1 && \ + sudo apt-get install -y -qq jq curl git unzip >> "${SETUP_LOG_FILE}" 2>&1; then + track_info "System packages updated successfully" + else + track_warning "Some system packages could not be installed" + fi + + # Download and install AWS CLI + echo "Downloading AWS CLI..." + if curl -sSL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "/tmp/awscliv2.zip" >> "${SETUP_LOG_FILE}" 2>&1; then + track_info "AWS CLI archive downloaded" + + if unzip -q /tmp/awscliv2.zip -d /tmp >> "${SETUP_LOG_FILE}" 2>&1; then + track_info "AWS CLI archive extracted" + + if sudo /tmp/aws/install --update >> "${SETUP_LOG_FILE}" 2>&1; then + track_success "AWS CLI v2" + echo " ✓ AWS CLI v2 installed successfully" + rm -rf /tmp/awscliv2.zip /tmp/aws + return 0 + else + track_failure "AWS CLI v2" "Installation of AWS CLI binaries failed" + fi + else + track_failure "AWS CLI v2" "Failed to extract AWS CLI archive" + fi + rm -f /tmp/awscliv2.zip /tmp/aws + else + track_failure "AWS CLI v2" "Failed to download from awscli.amazonaws.com" + fi + + echo " ✗ Warning: Could not install AWS CLI - continuing without it" + return 1 +} + +install_aws_cli || true + +# Configure AWS if config file exists +echo "Configuring AWS..." 
+if [ -f "/workspaces/spark/.devcontainer/aws-config" ]; then + mkdir -p ~/.aws + if cp /workspaces/spark/.devcontainer/aws-config ~/.aws/config >> "${SETUP_LOG_FILE}" 2>&1; then + track_success "AWS Configuration" + echo " ✓ AWS configuration copied" + else + track_failure "AWS Configuration" "Failed to copy AWS config file" + fi +else + track_info "AWS config file not found - skipping AWS configuration" + echo " ℹ AWS config file not found - skipping AWS configuration" +fi + + +#================================================================ +# Shell Configuration +#================================================================ +echo "" +echo "=== Configuring shell environment ===" + +# Set bash as default shell +if sudo chsh -s /bin/bash vscode 2>/dev/null; then + echo " ✓ Set bash as default shell" +else + echo " ℹ Could not set bash as default (already set or insufficient permissions)" +fi + +# Configure bash prompt and environment +PROMPT_FILE_SRC="/workspaces/spark/.devcontainer/bash_additional.sh" +PROMPT_FILE_DST="/home/vscode/.bash_additional.sh" + +if [ -f "${PROMPT_FILE_SRC}" ]; then + cp "${PROMPT_FILE_SRC}" "${PROMPT_FILE_DST}" + chown vscode:vscode "${PROMPT_FILE_DST}" 2>/dev/null || true + chmod 0644 "${PROMPT_FILE_DST}" 2>/dev/null || true + echo " ✓ Bash additional config installed" +else + echo " ℹ Bash additional config not found - using defaults" +fi + +# Add source line to .bashrc if not already present +if ! 
grep -q 'bash_additional\.sh' /home/vscode/.bashrc 2>/dev/null; then + cat >> /home/vscode/.bashrc << 'EOF' + +if [ -f "$HOME/.bash_additional.sh" ]; then + source "$HOME/.bash_additional.sh" +fi +EOF + echo " ✓ Updated .bashrc with bash_additional.sh sourcing" +fi + +#================================================================ +# Spark Dependencies Pre-fetching +#================================================================ +echo "" +echo "=== Pre-fetching Spark build dependencies ===" +echo "This may take several minutes on first run..." + +if [ -d "/workspaces/spark" ] && [ -f "/workspaces/spark/pom.xml" ]; then + cd /workspaces/spark + + # Pre-fetch Maven dependencies (non-critical, so we don't fail on error) + if command -v mvn &> /dev/null; then + if mvn dependency:go-offline -DskipTests -q 2>/dev/null; then + echo " ✓ Successfully pre-fetched dependencies" + else + echo " ⚠️ Partial dependency fetch completed (some mirrors may be unavailable)" + fi + else + echo " ℹ Maven not available - skipping dependency pre-fetch" + fi +else + echo " ℹ Workspace or pom.xml not yet available - skipping dependency pre-fetch" +fi + +#================================================================ +# Shell Configuration +#================================================================ +echo "" +echo "=== Configuring shell environment ===" +log_message "STAGE" "Configuring shell environment" + +# Set bash as default shell +if sudo chsh -s /bin/bash vscode 2>/dev/null; then + track_success "Bash Default Shell" + echo " ✓ Set bash as default shell" +else + track_warning "Could not set bash as default (already set or insufficient permissions)" + echo " ℹ Could not set bash as default (already set or insufficient permissions)" +fi + +# Configure bash prompt and environment +PROMPT_FILE_SRC="/workspaces/spark/.devcontainer/bash_additional.sh" +PROMPT_FILE_DST="/home/vscode/.bash_additional.sh" + +if [ -f "${PROMPT_FILE_SRC}" ]; then + if cp "${PROMPT_FILE_SRC}" 
"${PROMPT_FILE_DST}" >> "${SETUP_LOG_FILE}" 2>&1; then + chown vscode:vscode "${PROMPT_FILE_DST}" 2>/dev/null || true + chmod 0644 "${PROMPT_FILE_DST}" 2>/dev/null || true + track_success "Bash Additional Config" + echo " ✓ Bash additional config installed" + else + track_failure "Bash Additional Config" "Failed to copy bash_additional.sh" + fi +else + track_info "Bash additional config file not found - using defaults" + echo " ℹ Bash additional config not found - using defaults" +fi + +# Add source line to .bashrc if not already present +if ! grep -q 'bash_additional\.sh' /home/vscode/.bashrc 2>/dev/null; then + if cat >> /home/vscode/.bashrc << 'EOF' >> "${SETUP_LOG_FILE}" 2>&1 +if [ -f "$HOME/.bash_additional.sh" ]; then + source "$HOME/.bash_additional.sh" +fi +EOF + then + track_success "Bashrc Update" + echo " ✓ Updated .bashrc with bash_additional.sh sourcing" + else + track_failure "Bashrc Update" "Failed to update .bashrc" + fi +fi + +#================================================================ +# Spark Dependencies Pre-fetching +#================================================================ +echo "" +echo "=== Pre-fetching Spark build dependencies ===" +log_message "STAGE" "Pre-fetching Spark dependencies" + +echo "This may take several minutes on first run..." +track_info "Starting dependency pre-fetch (this may take several minutes)" + +if [ -d "/workspaces/spark" ] && [ -f "/workspaces/spark/pom.xml" ]; then + cd /workspaces/spark + + # Pre-fetch Maven dependencies (non-critical, so we don't fail on error) + if command -v mvn &> /dev/null; then + echo "Running: mvn dependency:go-offline..." 
+ if mvn dependency:go-offline -DskipTests -q >> "${SETUP_LOG_FILE}" 2>&1; then + track_success "Spark Dependencies Pre-fetch" + echo " ✓ Successfully pre-fetched dependencies" + else + track_warning "Partial dependency fetch completed (some mirrors may be unavailable)" + echo " ⚠️ Partial dependency fetch completed (some mirrors may be unavailable)" + fi + else + track_failure "Spark Dependencies Pre-fetch" "Maven not available - skipping dependency pre-fetch" + echo " ℹ Maven not available - skipping dependency pre-fetch" + fi +else + track_info "Workspace or pom.xml not yet available - skipping dependency pre-fetch" + echo " ℹ Workspace or pom.xml not yet available - skipping dependency pre-fetch" +fi + +#================================================================ +# Final Report Function +#================================================================ + +show_final_report() { + local success_count=${#SUCCESSFUL_COMMANDS[@]} + local failure_count=${#FAILED_COMMANDS[@]} + local warning_count=${#WARNING_MESSAGES[@]} + + echo "" + log_message "REPORT" "Setup process completed" + log_message "STATS" "Successes: $success_count, Failures: $failure_count, Warnings: $warning_count" + + echo "" + echo "╔════════════════════════════════════════════════════════════════╗" + echo "║ Spark Development Container Setup Report ║" + echo "╚════════════════════════════════════════════════════════════════╝" + echo "" + + # Summary stats + echo "📊 Setup Summary:" + echo " ✓ Successful: $success_count" + echo " ✗ Failed: $failure_count" + echo " ⚠️ Warnings: $warning_count" + echo "" + + # Successful installations + if [ $success_count -gt 0 ]; then + echo "✅ Successfully Installed:" + for cmd in "${SUCCESSFUL_COMMANDS[@]}"; do + echo " • $cmd" + done + echo "" + fi + + # Failed installations - IMPORTANT FOR DEVELOPERS + if [ $failure_count -gt 0 ]; then + echo "❌ Failed Installations (Developer Action Required):" + for failed in "${FAILED_COMMANDS[@]}"; do + echo " • $failed" 
+ done + echo "" + echo "💡 Troubleshooting Steps:" + echo " 1. Check the full log at: ${SETUP_LOG_FILE}" + echo " 2. For SDKMAN issues:" + echo " - Check SDKMAN is installed: ls ${SDKMAN_DIR}/bin/sdkman-init.sh" + echo " - Reload SDKMAN: source ${SDKMAN_DIR}/bin/sdkman-init.sh" + echo " - List available versions: sdk list java (or maven)" + echo " 3. For network issues, verify internet connectivity" + echo " 4. Try running the setup script again" + echo "" + fi + + # Warnings + if [ $warning_count -gt 0 ]; then + echo "⚠️ Warnings (Review Recommended):" + for warning in "${WARNING_MESSAGES[@]}"; do + echo " • $warning" + done + echo "" + fi + + # Installed versions + echo "🔍 Installed Component Versions:" + if command -v java &> /dev/null; then + echo " • Java: $(java -version 2>&1 | head -1)" + else + echo " • Java: ⚠️ Not available" + fi + + if command -v mvn &> /dev/null; then + echo " • Maven: $(mvn -v 2>&1 | head -1)" + else + echo " • Maven: ⚠️ Not available" + fi + + if command -v just &> /dev/null; then + echo " • Just: $(just --version 2>/dev/null)" + else + echo " • Just: ⚠️ Not available" + fi + + if command -v aws &> /dev/null; then + echo " • AWS CLI: $(aws --version 2>&1)" + else + echo " • AWS CLI: ⚠️ Not available" + fi + echo "" + + # Next steps + echo "📝 Next Steps:" + echo " 1. Review the full log: cat ${SETUP_LOG_FILE}" + echo " 2. Source your shell config: source ~/.bashrc" + echo " 3. Verify SDK versions: sdk current" + echo " 4. 
Start building: mvn clean install" + echo "" + + # Log details + echo "📂 Log File: ${SETUP_LOG_FILE}" + echo "" + + { + echo "" + echo "========================================================================" + echo "Setup completed at: $(date '+%Y-%m-%d %H:%M:%S')" + echo "Total successes: $success_count" + echo "Total failures: $failure_count" + echo "Total warnings: $warning_count" + echo "========================================================================" + } >> "${SETUP_LOG_FILE}" +} + +#================================================================ +# Setup Summary +#================================================================ +echo "" +echo "=== Setup Process Complete ===" +log_message "STAGE" "Setup process finished, generating report" + +show_final_report + +# Exit with appropriate code +if [ ${#FAILED_COMMANDS[@]} -gt 0 ]; then + echo "⚠️ Setup completed with some failures. See report above." + exit 0 # Don't fail - allow container to start even with some failures +else + echo "✅ Setup completed successfully!" + exit 0 +fi \ No newline at end of file diff --git a/.devcontainer/setup.md b/.devcontainer/setup.md new file mode 100644 index 0000000000000..38a090cf1c814 --- /dev/null +++ b/.devcontainer/setup.md @@ -0,0 +1,3 @@ +# Local setup + +> checkout /tmp/setup.log to understand setup was successful or not and take appropriate action. 
\ No newline at end of file diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 9e5b0f424838a..23888ed8ac64e 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -46,7 +46,7 @@ on: jobs: matrix-gen: name: Generate matrix for job splits - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} env: @@ -60,7 +60,7 @@ jobs: tpcds-1g-gen: name: "Generate an input dataset for TPCDSQueryBenchmark with SF=1" if: contains(github.event.inputs.class, 'TPCDSQueryBenchmark') || contains(github.event.inputs.class, '*') - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 env: SPARK_LOCAL_IP: localhost steps: @@ -98,7 +98,7 @@ jobs: uses: actions/checkout@v3 with: repository: databricks/tpcds-kit - ref: 2a5078a782192ddb6efbcead8de9973d6ab4f069 + ref: 1b7fb7529edae091684201fab142d956d6afd881 path: ./tpcds-kit - name: Build tpcds-kit if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' @@ -117,8 +117,7 @@ jobs: name: "Run benchmarks: ${{ github.event.inputs.class }} (JDK ${{ github.event.inputs.jdk }}, Scala ${{ github.event.inputs.scala }}, ${{ matrix.split }} out of ${{ github.event.inputs.num-splits }} splits)" if: always() needs: [matrix-gen, tpcds-1g-gen] - # Ubuntu 20.04 is the latest LTS. The next LTS is 22.04. 
- runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: @@ -188,7 +187,7 @@ jobs: echo "Preparing the benchmark results:" tar -cvf benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar `git diff --name-only` `git ls-files --others --exclude=tpcds-sf-1 --exclude-standard` - name: Upload benchmark results - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}-${{ matrix.split }} path: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 98da34f7cdef1..c21e137dd9bfa 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -30,8 +30,7 @@ on: description: Branch to run the build against required: false type: string - # Change 'master' to 'branch-3.5' in branch-3.5 branch after cutting it. - default: master + default: branch-3.5 hadoop: description: Hadoop version to run with. HADOOP_PROFILE environment variable should accept it. required: false @@ -80,25 +79,34 @@ jobs: id: set-outputs run: | if [ -z "${{ inputs.jobs }}" ]; then - pyspark=true; sparkr=true; tpcds=true; docker=true; pyspark_modules=`cd dev && python -c "import sparktestsupport.modules as m; print(','.join(m.name for m in m.all_modules if m.name.startswith('pyspark')))"` pyspark=`./dev/is-changed.py -m $pyspark_modules` - sparkr=`./dev/is-changed.py -m sparkr` - tpcds=`./dev/is-changed.py -m sql` - docker=`./dev/is-changed.py -m docker-integration-tests` - # 'build', 'scala-213', and 'java-11-17' are always true for now. - # It does not save significant time and most of PRs trigger the build. 
+ if [[ "${{ github.repository }}" != 'apache/spark' ]]; then + pandas=$pyspark + kubernetes=`./dev/is-changed.py -m kubernetes` + sparkr=`./dev/is-changed.py -m sparkr` + tpcds=`./dev/is-changed.py -m sql` + docker=`./dev/is-changed.py -m docker-integration-tests` + else + pandas=false + kubernetes=false + sparkr=false + tpcds=false + docker=false + fi + build=`./dev/is-changed.py -m "core,unsafe,kvstore,avro,utils,network-common,network-shuffle,repl,launcher,examples,sketch,graphx,catalyst,hive-thriftserver,streaming,sql-kafka-0-10,streaming-kafka-0-10,mllib-local,mllib,yarn,mesos,kubernetes,hadoop-cloud,spark-ganglia-lgpl,sql,hive,connect,protobuf,api"` precondition=" { - \"build\": \"true\", + \"build\": \"$build\", \"pyspark\": \"$pyspark\", + \"pyspark-pandas\": \"$pandas\", \"sparkr\": \"$sparkr\", \"tpcds-1g\": \"$tpcds\", \"docker-integration-tests\": \"$docker\", - \"scala-213\": \"true\", - \"java-11-17\": \"true\", + \"scala-213\": \"$build\", + \"java-11-17\": \"$build\", \"lint\" : \"true\", - \"k8s-integration-tests\" : \"true\", + \"k8s-integration-tests\" : \"$kubernetes\", \"breaking-changes-buf\" : \"true\", }" echo $precondition # For debugging @@ -205,6 +213,9 @@ jobs: HIVE_PROFILE: ${{ matrix.hive }} GITHUB_PREV_SHA: ${{ github.event.before }} SPARK_LOCAL_IP: localhost + SKIP_UNIDOC: true + SKIP_MIMA: true + SKIP_PACKAGING: true steps: - name: Checkout Spark repository uses: actions/checkout@v3 @@ -256,7 +267,7 @@ jobs: - name: Install Python packages (Python 3.8) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) run: | - python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3' + python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3' python3.8 -m pip list # Run the tests. 
- name: Run tests @@ -271,13 +282,13 @@ jobs: ./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST" --included-tags "$INCLUDED_TAGS" --excluded-tags "$EXCLUDED_TAGS" - name: Upload test results to report if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: test-results-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/unit-tests.log" @@ -344,6 +355,8 @@ jobs: java: - ${{ inputs.java }} modules: + - >- + pyspark-errors - >- pyspark-sql, pyspark-mllib, pyspark-resource, pyspark-testing - >- @@ -353,11 +366,19 @@ jobs: - >- pyspark-pandas-slow - >- - pyspark-connect, pyspark-errors + pyspark-connect - >- pyspark-pandas-connect - >- pyspark-pandas-slow-connect + exclude: + # Always run if pyspark-pandas == 'true', even infra-image is skip (such as non-master job) + # In practice, the build will run in individual PR, but not against the individual commit + # in Apache Spark repository. 
+ - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas' }} + - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas-slow' }} + - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas-connect' }} + - modules: ${{ fromJson(needs.precondition.outputs.required).pyspark-pandas != 'true' && 'pyspark-pandas-slow-connect' }} env: MODULES_TO_TEST: ${{ matrix.modules }} HADOOP_PROFILE: ${{ inputs.hadoop }} @@ -366,6 +387,7 @@ jobs: SPARK_LOCAL_IP: localhost SKIP_UNIDOC: true SKIP_MIMA: true + SKIP_PACKAGING: true METASPACE_SIZE: 1g steps: - name: Checkout Spark repository @@ -404,6 +426,8 @@ jobs: key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | pyspark-coursier- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java ${{ matrix.java }} uses: actions/setup-java@v3 with: @@ -414,14 +438,20 @@ jobs: python3.9 -m pip list pypy3 -m pip list - name: Install Conda for pip packaging test + if: ${{ matrix.modules == 'pyspark-errors' }} run: | curl -s https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh > miniconda.sh bash miniconda.sh -b -p $HOME/miniconda # Run the tests. - name: Run tests env: ${{ fromJSON(inputs.envs) }} + shell: 'script -q -e -c "bash {0}"' run: | - export PATH=$PATH:$HOME/miniconda/bin + if [[ "$MODULES_TO_TEST" == "pyspark-errors" ]]; then + export PATH=$PATH:$HOME/miniconda/bin + export SKIP_PACKAGING=false + echo "Python Packaging Tests Enabled!" 
+ fi ./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST" - name: Upload coverage to Codecov if: fromJSON(inputs.envs).PYSPARK_CODECOV == 'true' @@ -432,13 +462,13 @@ jobs: name: PySpark - name: Upload test results to report if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: test-results-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: unit-tests-log-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -457,6 +487,7 @@ jobs: GITHUB_PREV_SHA: ${{ github.event.before }} SPARK_LOCAL_IP: localhost SKIP_MIMA: true + SKIP_PACKAGING: true steps: - name: Checkout Spark repository uses: actions/checkout@v3 @@ -494,6 +525,8 @@ jobs: key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | sparkr-coursier- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java ${{ inputs.java }} uses: actions/setup-java@v3 with: @@ -509,7 +542,7 @@ jobs: ./dev/run-tests --parallelism 1 --modules sparkr - name: Upload test results to report if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: test-results-sparkr--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" @@ -602,6 +635,8 @@ jobs: key: docs-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | docs-maven- + - name: Free up disk space + run: ./dev/free_disk_space_container - name: Install Java 8 uses: actions/setup-java@v3 with: @@ -611,6 +646,8 @@ jobs: run: ./dev/check-license - name: Dependencies test run: ./dev/test-dependencies.sh + - name: MIMA test + run: ./dev/mima - name: Scala linter run: ./dev/lint-scala - name: Java linter @@ -662,16 +699,16 @@ jobs: # See also https://issues.apache.org/jira/browse/SPARK-35375. 
# Pin the MarkupSafe to 2.0.1 to resolve the CI error. # See also https://issues.apache.org/jira/browse/SPARK-38279. - python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0' + python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme 'sphinx-copybutton==0.5.2' 'nbsphinx==0.9.3' numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0' 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' 'nest-asyncio==1.5.8' 'rpds-py==0.16.2' 'alabaster==0.7.13' python3.9 -m pip install ipython_genutils # See SPARK-38517 - python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' + python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas 'plotly>=4.8' python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 apt-get update -y apt-get install -y ruby ruby-dev Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'markdown', 'e1071', 'roxygen2', 'ggplot2', 'mvtnorm', 'statmod'), repos='https://cloud.r-project.org/')" Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" - gem install bundler + gem install bundler -v 2.4.22 cd docs bundle install - name: R linter @@ -794,8 +831,7 @@ jobs: needs: precondition if: fromJson(needs.precondition.outputs.required).tpcds-1g == 'true' name: Run TPC-DS queries with SF=1 - # Pin to 'Ubuntu 20.04' due to 'databricks/tpcds-kit' compilation - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 env: SPARK_LOCAL_IP: localhost steps: @@ -845,7 +881,7 @@ jobs: uses: actions/checkout@v3 with: repository: databricks/tpcds-kit - ref: 2a5078a782192ddb6efbcead8de9973d6ab4f069 + ref: 
1b7fb7529edae091684201fab142d956d6afd881 path: ./tpcds-kit - name: Build tpcds-kit if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' @@ -878,13 +914,13 @@ jobs: spark.sql.join.forceApplyShuffledHashJoin=true - name: Upload test results to report if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: test-results-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: unit-tests-log-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -901,6 +937,7 @@ jobs: SPARK_LOCAL_IP: localhost ORACLE_DOCKER_IMAGE_NAME: gvenzl/oracle-xe:21.3.0 SKIP_MIMA: true + SKIP_PACKAGING: true steps: - name: Checkout Spark repository uses: actions/checkout@v3 @@ -943,13 +980,13 @@ jobs: ./dev/run-tests --parallelism 1 --modules docker-integration-tests --included-tags org.apache.spark.tags.DockerTest - name: Upload test results to report if: always() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: test-results-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: unit-tests-log-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -1017,10 +1054,10 @@ jobs: kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.7.0/installer/volcano-development.yaml || true eval $(minikube docker-env) - build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 
-Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" + build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files if: failure() - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: spark-on-kubernetes-it-log path: "**/target/integration-tests.log" diff --git a/.github/workflows/notify_test_workflow.yml b/.github/workflows/notify_test_workflow.yml index 6fb776d708346..3079cacb60c6c 100644 --- a/.github/workflows/notify_test_workflow.yml +++ b/.github/workflows/notify_test_workflow.yml @@ -30,7 +30,7 @@ on: jobs: notify: name: Notify test workflow - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 permissions: actions: read checks: write diff --git a/.github/workflows/update_build_status.yml b/.github/workflows/update_build_status.yml index 05cf4914a25ca..aee9e9c9a601a 100644 --- a/.github/workflows/update_build_status.yml +++ b/.github/workflows/update_build_status.yml @@ -26,7 +26,7 @@ on: jobs: update: name: Update build status - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 permissions: actions: read checks: write diff --git a/.gitignore b/.gitignore index 11141961bf805..06c6660900d66 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ .scala_dependencies .settings .vscode +artifacts/ /lib/ R-unit-tests.log R/unit-tests.out @@ -50,6 +51,7 @@ dev/create-release/*final dev/create-release/*txt dev/pr-deps/ dist/ +docs/_generated/ docs/_site/ docs/api docs/.local_ruby_bundle @@ -117,6 +119,6 @@ spark-warehouse/ node_modules # For Antlr -sql/catalyst/gen/ -sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens -sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ +sql/api/gen/ 
+sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.tokens +sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/gen/ diff --git a/LICENSE b/LICENSE index 012fdbca4c90d..2eda694eaae9f 100644 --- a/LICENSE +++ b/LICENSE @@ -213,17 +213,10 @@ Apache Software Foundation License 2.0 common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java core/src/main/java/org/apache/spark/util/collection/TimSort.java core/src/main/resources/org/apache/spark/ui/static/bootstrap* -core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* core/src/main/resources/org/apache/spark/ui/static/vis* -docs/js/vendor/bootstrap.js connector/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java -Python Software Foundation License ----------------------------------- - -python/docs/source/_static/copybutton.js - BSD 3-Clause ------------ @@ -245,9 +238,11 @@ core/src/main/resources/org/apache/spark/ui/static/*dataTables* core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js core/src/main/resources/org/apache/spark/ui/static/jquery* core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/bootstrap* docs/js/vendor/anchor.min.js docs/js/vendor/jquery* docs/js/vendor/modernizer* +docs/js/vendor/docsearch.min.js Creative Commons CC0 1.0 Universal Public Domain Dedication diff --git a/LICENSE-binary b/LICENSE-binary index 9472d28e509ac..05645977a0ba5 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -412,7 +412,6 @@ com.google.cloud.bigdataoss:gcs-connector core/src/main/java/org/apache/spark/util/collection/TimSort.java core/src/main/resources/org/apache/spark/ui/static/bootstrap* -core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* core/src/main/resources/org/apache/spark/ui/static/vis* docs/js/vendor/bootstrap.js diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 1c093a4a98046..5eca59375425e 100644 --- a/R/pkg/DESCRIPTION +++ 
b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 3.5.0 +Version: 3.5.5 Title: R Front End for 'Apache Spark' Description: Provides an R Front end for 'Apache Spark' . Authors@R: diff --git a/assembly/pom.xml b/assembly/pom.xml index 09d6bd8a33f79..2066bbeb7e4d5 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../pom.xml @@ -159,6 +159,12 @@ org.apache.spark spark-connect_${scala.binary.version} ${project.version} + + + org.apache.spark + spark-connect-common_${scala.binary.version} + + org.apache.spark @@ -166,6 +172,12 @@ ${project.version} provided + + org.apache.spark + spark-protobuf_${scala.binary.version} + ${project.version} + provided + @@ -248,6 +260,7 @@ hive-provided provided + provided diff --git a/bin/pyspark b/bin/pyspark index 1ae28b1f507cd..a0972ccee3301 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +export HADOOP_USER_HOME=$(whoami) if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home diff --git a/bin/spark-class b/bin/spark-class index fc343ca29fddd..01f810ecdb31c 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +export HADOOP_USER_HOME=$(whoami) if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi diff --git a/bin/spark-sql b/bin/spark-sql index b08b944ebd319..3a46b41a003e2 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +export HADOOP_USER_HOME=$(whoami) if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home diff --git a/bin/spark-submit b/bin/spark-submit index 4e9d3614e6370..3a398928c6c11 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -17,6 +17,8 @@ # limitations under the License. # +export HADOOP_USER_HOME=$(whoami) + if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi diff --git a/bin/sparkR b/bin/sparkR index 8ecc755839fe3..136419dc2982f 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +export HADOOP_USER_HOME=$(whoami) if [ -z "${SPARK_HOME}" ]; then source "$(dirname "$0")"/find-spark-home fi diff --git a/binder/Dockerfile b/binder/Dockerfile new file mode 100644 index 0000000000000..6e3dd9155fb7a --- /dev/null +++ b/binder/Dockerfile @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +FROM python:3.10-slim +# install the notebook package +RUN pip install --no-cache notebook jupyterlab + +# create user with a home directory +ARG NB_USER +ARG NB_UID +ENV USER ${NB_USER} +ENV HOME /home/${NB_USER} + +RUN adduser --disabled-password \ + --gecos "Default user" \ + --uid ${NB_UID} \ + ${NB_USER} +WORKDIR ${HOME} +USER ${USER} + +# Make sure the contents of our repo are in ${HOME} +COPY . ${HOME} +USER root +RUN chown -R ${NB_UID} ${HOME} +RUN apt-get update && apt-get install -y openjdk-17-jre git coreutils +USER ${NB_USER} + +RUN binder/postBuild + diff --git a/binder/apt.txt b/binder/apt.txt deleted file mode 100644 index 3d86667d4b910..0000000000000 --- a/binder/apt.txt +++ /dev/null @@ -1,2 +0,0 @@ -openjdk-8-jre -git diff --git a/binder/postBuild b/binder/postBuild old mode 100644 new mode 100755 index 70ae23b393707..c17816d4a5009 --- a/binder/postBuild +++ b/binder/postBuild @@ -20,8 +20,13 @@ # This file is used for Binder integration to install PySpark available in # Jupyter notebook. +# SPARK-45706: Should fail fast. Otherwise, the Binder image is successfully +# built, and it cannot be rebuilt. +set -o pipefail +set -e + VERSION=$(python -c "exec(open('python/pyspark/version.py').read()); print(__version__)") -TAG=$(git describe --tags --exact-match 2>/dev/null) +TAG=$(git describe --tags --exact-match 2> /dev/null || true) # If a commit is tagged, exactly specified version of pyspark should be installed to avoid # a kind of accident that an old version of pyspark is installed in the live notebook environment. @@ -33,9 +38,9 @@ else fi if [[ ! 
$VERSION < "3.4.0" ]]; then - pip install plotly "pandas<2.0.0" "pyspark[sql,ml,mllib,pandas_on_spark,connect]$SPECIFIER$VERSION" + pip install plotly "pandas<2.0.0" "numpy>=1.15,<2" "pyspark[sql,ml,mllib,pandas_on_spark,connect]$SPECIFIER$VERSION" else - pip install plotly "pandas<2.0.0" "pyspark[sql,ml,mllib,pandas_on_spark]$SPECIFIER$VERSION" + pip install plotly "pandas<2.0.0" "numpy>=1.15,<2" "pyspark[sql,ml,mllib,pandas_on_spark]$SPECIFIER$VERSION" fi # Set 'PYARROW_IGNORE_TIMEZONE' to surpress warnings from PyArrow. diff --git a/build/mvn b/build/mvn index 3179099304c7a..2c778fd6c71a7 100755 --- a/build/mvn +++ b/build/mvn @@ -56,7 +56,7 @@ install_app() { local binary="${_DIR}/$6" local remote_tarball="${mirror_host}/${url_path}${url_query}" local local_checksum="${local_tarball}.${checksum_suffix}" - local remote_checksum="https://archive.apache.org/dist/${url_path}.${checksum_suffix}" + local remote_checksum="${mirror_host}/${url_path}.${checksum_suffix}${url_query}" local curl_opts="--silent --show-error -L" local wget_opts="--no-verbose" diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index bef8303874b20..a1ec2748329b9 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml @@ -66,6 +66,11 @@ commons-io test + + org.apache.commons + commons-lang3 + test + org.apache.logging.log4j diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8a63e999c53cd..6ae7863161b1e 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml @@ -150,7 +150,8 @@ org.apache.spark - spark-tags_${scala.binary.version} + spark-common-utils_${scala.binary.version} + ${project.version} + ${project.build.directory}/extra-resources + true + + + + org.apache.maven.plugins + maven-antrun-plugin + + + choose-shell-and-script + validate + + 
run + + + true + + + + + + + + + + + + Shell to use for generating spark-version-info.properties file = + ${shell} + + Script to use for generating spark-version-info.properties file = + ${spark-build-info-script} + + + + + + generate-spark-build-info + generate-resources + + + + + + + + + + + + run + + + + diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/FilterFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/FilterFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java 
similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/common/utils/src/main/java/org/apache/spark/api/java/function/Function.java similarity index 100% rename from 
core/src/main/java/org/apache/spark/api/java/function/Function.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/Function.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/common/utils/src/main/java/org/apache/spark/api/java/function/Function0.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/Function0.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function2.java b/common/utils/src/main/java/org/apache/spark/api/java/function/Function2.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/Function2.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/Function2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function3.java b/common/utils/src/main/java/org/apache/spark/api/java/function/Function3.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/Function3.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/Function3.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/common/utils/src/main/java/org/apache/spark/api/java/function/Function4.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/Function4.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/Function4.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/MapFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/MapFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/MapFunction.java diff --git 
a/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/PairFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/PairFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/PairFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/common/utils/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java 
b/common/utils/src/main/java/org/apache/spark/api/java/function/VoidFunction.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/VoidFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/common/utils/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/package-info.java b/common/utils/src/main/java/org/apache/spark/api/java/function/package-info.java similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/package-info.java rename to common/utils/src/main/java/org/apache/spark/api/java/function/package-info.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/common/utils/src/main/java/org/apache/spark/api/java/function/package.scala similarity index 100% rename from core/src/main/java/org/apache/spark/api/java/function/package.scala rename to common/utils/src/main/java/org/apache/spark/api/java/function/package.scala diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/utils/src/main/java/org/apache/spark/network/util/ByteUnit.java similarity index 100% rename from common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java rename to common/utils/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 89% rename from common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java 
rename to common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7e410e9eab223..d6603dcbee1ae 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -23,15 +23,11 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.attribute.BasicFileAttributes; -import java.util.Locale; -import java.util.UUID; +import java.util.*; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; -import io.netty.buffer.Unpooled; import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,7 +68,7 @@ public static int nonNegativeHash(Object obj) { * converted back to the same string through {@link #bytesToString(ByteBuffer)}. */ public static ByteBuffer stringToBytes(String s) { - return Unpooled.wrappedBuffer(s.getBytes(StandardCharsets.UTF_8)).nioBuffer(); + return ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8)); } /** @@ -80,7 +76,7 @@ public static ByteBuffer stringToBytes(String s) { * converted back to the same byte buffer through {@link #stringToBytes(String)}. 
*/ public static String bytesToString(ByteBuffer b) { - return Unpooled.wrappedBuffer(b).toString(StandardCharsets.UTF_8); + return StandardCharsets.UTF_8.decode(b.slice()).toString(); } /** @@ -124,6 +120,7 @@ public static void deleteRecursively(File file, FilenameFilter filter) throws IO private static void deleteRecursivelyUsingJavaIO( File file, FilenameFilter filter) throws IOException { + if (!file.exists()) return; BasicFileAttributes fileAttributes = Files.readAttributes(file.toPath(), BasicFileAttributes.class); if (fileAttributes.isDirectory() && !isSymlink(file)) { @@ -191,7 +188,7 @@ private static File[] listFilesSafely(File file, FilenameFilter filter) throws I } private static boolean isSymlink(File file) throws IOException { - Preconditions.checkNotNull(file); + Objects.requireNonNull(file); File fileInCanonicalDir = null; if (file.getParent() == null) { fileInCanonicalDir = file; @@ -201,31 +198,35 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap timeSuffixes = - ImmutableMap.builder() - .put("us", TimeUnit.MICROSECONDS) - .put("ms", TimeUnit.MILLISECONDS) - .put("s", TimeUnit.SECONDS) - .put("m", TimeUnit.MINUTES) - .put("min", TimeUnit.MINUTES) - .put("h", TimeUnit.HOURS) - .put("d", TimeUnit.DAYS) - .build(); - - private static final ImmutableMap byteSuffixes = - ImmutableMap.builder() - .put("b", ByteUnit.BYTE) - .put("k", ByteUnit.KiB) - .put("kb", ByteUnit.KiB) - .put("m", ByteUnit.MiB) - .put("mb", ByteUnit.MiB) - .put("g", ByteUnit.GiB) - .put("gb", ByteUnit.GiB) - .put("t", ByteUnit.TiB) - .put("tb", ByteUnit.TiB) - .put("p", ByteUnit.PiB) - .put("pb", ByteUnit.PiB) - .build(); + private static final Map timeSuffixes; + + private static final Map byteSuffixes; + + static { + final Map timeSuffixesBuilder = new HashMap<>(); + timeSuffixesBuilder.put("us", TimeUnit.MICROSECONDS); + 
timeSuffixesBuilder.put("ms", TimeUnit.MILLISECONDS); + timeSuffixesBuilder.put("s", TimeUnit.SECONDS); + timeSuffixesBuilder.put("m", TimeUnit.MINUTES); + timeSuffixesBuilder.put("min", TimeUnit.MINUTES); + timeSuffixesBuilder.put("h", TimeUnit.HOURS); + timeSuffixesBuilder.put("d", TimeUnit.DAYS); + timeSuffixes = Collections.unmodifiableMap(timeSuffixesBuilder); + + final Map byteSuffixesBuilder = new HashMap<>(); + byteSuffixesBuilder.put("b", ByteUnit.BYTE); + byteSuffixesBuilder.put("k", ByteUnit.KiB); + byteSuffixesBuilder.put("kb", ByteUnit.KiB); + byteSuffixesBuilder.put("m", ByteUnit.MiB); + byteSuffixesBuilder.put("mb", ByteUnit.MiB); + byteSuffixesBuilder.put("g", ByteUnit.GiB); + byteSuffixesBuilder.put("gb", ByteUnit.GiB); + byteSuffixesBuilder.put("t", ByteUnit.TiB); + byteSuffixesBuilder.put("tb", ByteUnit.TiB); + byteSuffixesBuilder.put("p", ByteUnit.PiB); + byteSuffixesBuilder.put("pb", ByteUnit.PiB); + byteSuffixes = Collections.unmodifiableMap(byteSuffixesBuilder); + } /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count in the given unit. 
diff --git a/common/utils/src/main/scala/org/apache/spark/unsafe/array/ByteArrayUtils.java b/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java similarity index 100% rename from common/utils/src/main/scala/org/apache/spark/unsafe/array/ByteArrayUtils.java rename to common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java diff --git a/common/utils/src/main/resources/error/README.md b/common/utils/src/main/resources/error/README.md index dfcb42d49e79a..aed2c0becd311 100644 --- a/common/utils/src/main/resources/error/README.md +++ b/common/utils/src/main/resources/error/README.md @@ -666,6 +666,7 @@ The following SQLSTATEs are collated from: |4274C |42 |Syntax Error or Access Rule Violation |74C |The specified attribute was not found in the trusted context.|DB2 |N |DB2 | |4274D |42 |Syntax Error or Access Rule Violation |74D |The specified attribute already exists in the trusted context.|DB2 |N |DB2 | |4274E |42 |Syntax Error or Access Rule Violation |74E |The specified attribute is not supported in the trusted context.|DB2 |N |DB2 | +|4274K |42 |Syntax Error or Access Rule Violation |74K |Invalid use of a named argument when invoking a routine.|DB2 |N |DB2 | |4274M |42 |Syntax Error or Access Rule Violation |74M |An undefined period name was detected. 
|DB2 |N |DB2 | |42801 |42 |Syntax Error or Access Rule Violation |801 |Isolation level UR is invalid, because the result table is not read-only.|DB2 |N |DB2 | |42802 |42 |Syntax Error or Access Rule Violation |802 |The number of target values is not the same as the number of source values.|DB2 |N |DB2 | diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index e8cdaa6c63b3f..f1943a8ff3e04 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -28,6 +28,15 @@ ], "sqlState" : "42702" }, + "AMBIGUOUS_COLUMN_REFERENCE" : { + "message" : [ + "Column is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same.", + "This column points to one of the DataFrame but Spark is unable to figure out which one.", + "Please alias the DataFrames with different names via `DataFrame.alias` before joining them,", + "and specify the column using qualified name, e.g. `df.alias(\"a\").join(df.alias(\"b\"), col(\"a.id\") > col(\"b.id\"))`." + ], + "sqlState" : "42702" + }, "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : { "message" : [ "Lateral column alias is ambiguous and has matches." @@ -69,14 +78,9 @@ } } }, - "AVRO_INCORRECT_TYPE" : { - "message" : [ - "Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which would lead to an incorrect answer. To allow reading this field, enable the SQL configuration: ." - ] - }, - "AVRO_LOWER_PRECISION" : { + "AVRO_INCOMPATIBLE_READ_TYPE" : { "message" : [ - "Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which leads to data being read as null. Please provide a wider decimal type to get the correct result. To allow reading null to this field, enable the SQL configuration: ." 
+ "Cannot convert Avro to SQL because the original encoded data type is , however you're trying to read the field as , which would lead to an incorrect answer. To allow reading this field, enable the SQL configuration: \"spark.sql.legacy.avro.allowIncompatibleSchema\"." ] }, "BATCH_METADATA_NOT_FOUND" : { @@ -350,6 +354,11 @@ "message" : [ "Error instantiating Spark Connect plugin: " ] + }, + "SESSION_NOT_SAME" : { + "message" : [ + "Both Datasets must belong to the same SparkSession." + ] } } }, @@ -738,6 +747,24 @@ ], "sqlState" : "23505" }, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT" : { + "message" : [ + "Call to function is invalid because it includes multiple argument assignments to the same parameter name ." + ], + "subClass" : { + "BOTH_POSITIONAL_AND_NAMED" : { + "message" : [ + "A positional argument and named argument both referred to the same parameter. Please remove the named argument referring to this parameter." + ] + }, + "DOUBLE_NAMED_ARGUMENT_REFERENCE" : { + "message" : [ + "More than one named argument referred to the same parameter. Please assign a value only once." + ] + } + }, + "sqlState" : "4274K" + }, "EMPTY_JSON_FIELD_VALUE" : { "message" : [ "Failed to parse an empty string for data type ." @@ -924,6 +951,11 @@ "Cannot safely cast to ." ] }, + "EXTRA_COLUMNS" : { + "message" : [ + "Cannot write extra columns ." + ] + }, "EXTRA_STRUCT_FIELDS" : { "message" : [ "Cannot write extra fields to the struct ." @@ -1222,6 +1254,34 @@ ], "sqlState" : "42000" }, + "INVALID_CURSOR" : { + "message" : [ + "The cursor is invalid." + ], + "subClass" : { + "DISCONNECTED" : { + "message" : [ + "The cursor has been disconnected by the server." + ] + }, + "NOT_REATTACHABLE" : { + "message" : [ + "The cursor is not reattachable." + ] + }, + "POSITION_NOT_AVAILABLE" : { + "message" : [ + "The cursor position id is no longer available at index ." + ] + }, + "POSITION_NOT_FOUND" : { + "message" : [ + "The cursor position id is not found." 
+ ] + } + }, + "sqlState" : "HY109" + }, "INVALID_DEFAULT_VALUE" : { "message" : [ "Failed to execute command because the destination table column has a DEFAULT value ," @@ -1232,6 +1292,11 @@ "which requires type, but the statement provided a value of incompatible type." ] }, + "NOT_CONSTANT" : { + "message" : [ + "which is not a constant expression whose equivalent value is known at query planning time." + ] + }, "SUBQUERY_EXPRESSION" : { "message" : [ "which contains subquery expressions." @@ -1365,6 +1430,44 @@ ], "sqlState" : "22023" }, + "INVALID_HANDLE" : { + "message" : [ + "The handle is invalid." + ], + "subClass" : { + "FORMAT" : { + "message" : [ + "Handle must be an UUID string of the format '00112233-4455-6677-8899-aabbccddeeff'" + ] + }, + "OPERATION_ABANDONED" : { + "message" : [ + "Operation was considered abandoned because of inactivity and removed." + ] + }, + "OPERATION_ALREADY_EXISTS" : { + "message" : [ + "Operation already exists." + ] + }, + "OPERATION_NOT_FOUND" : { + "message" : [ + "Operation not found." + ] + }, + "SESSION_ALREADY_EXISTS" : { + "message" : [ + "Session already exists." + ] + }, + "SESSION_NOT_FOUND" : { + "message" : [ + "Session not found." + ] + } + }, + "sqlState" : "HY000" + }, "INVALID_HIVE_COLUMN_NAME" : { "message" : [ "Cannot create the table having the nested column whose name contains invalid characters in Hive metastore." @@ -1956,7 +2059,13 @@ "Not allowed to implement multiple UDF interfaces, UDF class ." ] }, - "NAMED_ARGUMENTS_SUPPORT_DISABLED" : { + "NAMED_PARAMETERS_NOT_SUPPORTED" : { + "message" : [ + "Named parameters are not supported for function ; please retry the query with positional arguments to the function call instead." + ], + "sqlState" : "4274K" + }, + "NAMED_PARAMETER_SUPPORT_DISABLED" : { "message" : [ "Cannot call function because named argument references are not enabled here. In this case, the named argument reference was . 
Set \"spark.sql.allowNamedFunctionArguments\" to \"true\" to turn on feature." ] @@ -2295,6 +2404,12 @@ ], "sqlState" : "42614" }, + "REQUIRED_PARAMETER_NOT_FOUND" : { + "message" : [ + "Cannot invoke function because the parameter named is required, but the function call did not supply a value. Please update the function call to supply an argument value (either positionally at index or by name) and retry the query again." + ], + "sqlState" : "4274K" + }, "REQUIRES_SINGLE_PART_NAMESPACE" : { "message" : [ " requires a single-part namespace, but got ." @@ -2316,6 +2431,12 @@ ], "sqlState" : "42883" }, + "RULE_ID_NOT_FOUND" : { + "message" : [ + "Not found an id for the rule name \"\". Please modify RuleIdCollection.scala if you are adding a new rule." + ], + "sqlState" : "22023" + }, "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION" : { "message" : [ "The correlated scalar subquery '' is neither present in GROUP BY, nor in an aggregate function. Add it to GROUP BY using ordinal position or wrap it in `first()` (or `first_value`) if you don't care which value you get." @@ -2485,6 +2606,12 @@ ], "sqlState" : "42K09" }, + "UNEXPECTED_POSITIONAL_ARGUMENT" : { + "message" : [ + "Cannot invoke function because it contains positional argument(s) following the named argument assigned to ; please rearrange them so the positional arguments come first and then retry the query again." + ], + "sqlState" : "4274K" + }, "UNKNOWN_PROTOBUF_MESSAGE_TYPE" : { "message" : [ "Attempting to treat as a Message, but it was ." @@ -2514,6 +2641,12 @@ ], "sqlState" : "428C4" }, + "UNRECOGNIZED_PARAMETER_NAME" : { + "message" : [ + "Cannot invoke function because the function call included a named argument reference for the argument named , but this function does not include any signature containing an argument with this name. Did you mean one of the following? []." + ], + "sqlState" : "4274K" + }, "UNRECOGNIZED_SQL_TYPE" : { "message" : [ "Unrecognized SQL type - name: , id: ." 
@@ -2641,11 +2774,6 @@ ], "sqlState" : "0A000" }, - "UNSUPPORTED_DATA_SOURCE_FOR_DIRECT_QUERY" : { - "message" : [ - "The direct query on files does not support the data source type: . Please try a different data source type or consider using a different query method." - ] - }, "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE" : { "message" : [ "The datasource doesn't support the column of the type ." @@ -2853,6 +2981,16 @@ "Pivoting by the value '' of the column data type ." ] }, + "PURGE_PARTITION" : { + "message" : [ + "Partition purge." + ] + }, + "PURGE_TABLE" : { + "message" : [ + "Purge table." + ] + }, "PYTHON_UDF_IN_ON_CLAUSE" : { "message" : [ "Python UDF in the ON clause of a JOIN. In case of an INNNER JOIN consider rewriting to a CROSS JOIN with a WHERE clause." @@ -2923,7 +3061,7 @@ "subClass" : { "MULTI_GENERATOR" : { "message" : [ - "only one generator allowed per clause but found : ." + "only one generator allowed per SELECT clause but found : ." ] }, "NESTED_IN_EXPRESSIONS" : { @@ -5312,11 +5450,6 @@ "." ] }, - "_LEGACY_ERROR_TEMP_2175" : { - "message" : [ - "Rule id not found for . Please modify RuleIdCollection.scala if you are adding a new rule." - ] - }, "_LEGACY_ERROR_TEMP_2176" : { "message" : [ "Cannot create array with elements of data due to exceeding the limit elements for ArrayData. " diff --git a/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala b/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala new file mode 100644 index 0000000000000..ebc62460d2318 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +import java.util.Properties + +private[spark] object SparkBuildInfo { + + val ( + spark_version: String, + spark_branch: String, + spark_revision: String, + spark_build_user: String, + spark_repo_url: String, + spark_build_date: String, + spark_doc_root: String) = { + + val resourceStream = Thread.currentThread().getContextClassLoader. + getResourceAsStream("spark-version-info.properties") + if (resourceStream == null) { + throw new SparkException("Could not find spark-version-info.properties") + } + + try { + val unknownProp = "" + val props = new Properties() + props.load(resourceStream) + ( + props.getProperty("version", unknownProp), + props.getProperty("branch", unknownProp), + props.getProperty("revision", unknownProp), + props.getProperty("user", unknownProp), + props.getProperty("url", unknownProp), + props.getProperty("date", unknownProp), + props.getProperty("docroot", unknownProp) + ) + } catch { + case e: Exception => + throw new SparkException("Error loading properties from spark-version-info.properties", e) + } finally { + if (resourceStream != null) { + try { + resourceStream.close() + } catch { + case e: Exception => + throw new SparkException("Error closing spark build info resource stream", e) + } + } + } + } +} diff --git a/common/utils/src/main/scala/org/apache/spark/SparkException.scala 
b/common/utils/src/main/scala/org/apache/spark/SparkException.scala index feb7bf5b66eda..5c5bf17c942d6 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkException.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkException.scala @@ -58,6 +58,15 @@ class SparkException( errorClass = Some(errorClass), messageParameters = messageParameters) + def this(errorClass: String, messageParameters: Map[String, String], cause: Throwable, + context: Array[QueryContext]) = + this( + message = SparkThrowableHelper.getMessage(errorClass, messageParameters), + cause = cause, + errorClass = Some(errorClass), + messageParameters = messageParameters, + context = context) + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass.orNull @@ -124,50 +133,107 @@ private[spark] case class ExecutorDeadException(message: String) /** * Exception thrown when Spark returns different result after upgrading to a new version. 
*/ -private[spark] class SparkUpgradeException( +private[spark] class SparkUpgradeException private( + message: String, + cause: Option[Throwable], + errorClass: Option[String], + messageParameters: Map[String, String]) + extends RuntimeException(message, cause.orNull) with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], - cause: Throwable) - extends RuntimeException( - SparkThrowableHelper.getMessage(errorClass, messageParameters), cause) - with SparkThrowable { + cause: Throwable) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters), + Option(cause), + Option(errorClass), + messageParameters + ) + } + + def this(message: String, cause: Option[Throwable]) = { + this( + message, + cause = cause, + errorClass = None, + messageParameters = Map.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull } /** * Arithmetic exception thrown from Spark with an error class. 
*/ -private[spark] class SparkArithmeticException( +private[spark] class SparkArithmeticException private( + message: String, + errorClass: Option[String], + messageParameters: Map[String, String], + context: Array[QueryContext]) + extends ArithmeticException(message) with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], context: Array[QueryContext], - summary: String) - extends ArithmeticException( - SparkThrowableHelper.getMessage(errorClass, messageParameters, summary)) - with SparkThrowable { + summary: String) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(errorClass), + messageParameters, + context + ) + } + + def this(message: String) = { + this( + message, + errorClass = None, + messageParameters = Map.empty, + context = Array.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } /** * Unsupported operation exception thrown from Spark with an error class. 
*/ -private[spark] class SparkUnsupportedOperationException( +private[spark] class SparkUnsupportedOperationException private( + message: String, + errorClass: Option[String], + messageParameters: Map[String, String]) + extends UnsupportedOperationException(message) with SparkThrowable { + + def this( errorClass: String, - messageParameters: Map[String, String]) - extends UnsupportedOperationException( - SparkThrowableHelper.getMessage(errorClass, messageParameters)) - with SparkThrowable { + messageParameters: Map[String, String]) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters), + Option(errorClass), + messageParameters + ) + } + + def this(message: String) = { + this( + message, + errorClass = None, + messageParameters = Map.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull } /** @@ -205,18 +271,38 @@ private[spark] class SparkConcurrentModificationException( /** * Datetime exception thrown from Spark with an error class. 
*/ -private[spark] class SparkDateTimeException( +private[spark] class SparkDateTimeException private( + message: String, + errorClass: Option[String], + messageParameters: Map[String, String], + context: Array[QueryContext]) + extends DateTimeException(message) with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], context: Array[QueryContext], - summary: String) - extends DateTimeException( - SparkThrowableHelper.getMessage(errorClass, messageParameters, summary)) - with SparkThrowable { + summary: String) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(errorClass), + messageParameters, + context + ) + } + + def this(message: String) = { + this( + message, + errorClass = None, + messageParameters = Map.empty, + context = Array.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -238,54 +324,122 @@ private[spark] class SparkFileNotFoundException( /** * Number format exception thrown from Spark with an error class. 
*/ -private[spark] class SparkNumberFormatException( +private[spark] class SparkNumberFormatException private( + message: String, + errorClass: Option[String], + messageParameters: Map[String, String], + context: Array[QueryContext]) + extends NumberFormatException(message) + with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], context: Array[QueryContext], - summary: String) - extends NumberFormatException( - SparkThrowableHelper.getMessage(errorClass, messageParameters, summary)) - with SparkThrowable { + summary: String) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(errorClass), + messageParameters, + context + ) + } + + def this(message: String) = { + this( + message, + errorClass = None, + messageParameters = Map.empty, + context = Array.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } /** * Illegal argument exception thrown from Spark with an error class. 
*/ -private[spark] class SparkIllegalArgumentException( +private[spark] class SparkIllegalArgumentException private( + message: String, + cause: Option[Throwable], + errorClass: Option[String], + messageParameters: Map[String, String], + context: Array[QueryContext]) + extends IllegalArgumentException(message, cause.orNull) + with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], context: Array[QueryContext] = Array.empty, summary: String = "", - cause: Throwable = null) - extends IllegalArgumentException( - SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), cause) - with SparkThrowable { + cause: Throwable = null) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(cause), + Option(errorClass), + messageParameters, + context + ) + } + + def this(message: String, cause: Option[Throwable]) = { + this( + message, + cause = cause, + errorClass = None, + messageParameters = Map.empty, + context = Array.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } -private[spark] class SparkRuntimeException( +private[spark] class SparkRuntimeException private( + message: String, + cause: Option[Throwable], + errorClass: Option[String], + messageParameters: Map[String, String], + context: Array[QueryContext]) + extends RuntimeException(message, cause.orNull) + with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], cause: Throwable = null, context: Array[QueryContext] = Array.empty, - summary: String = "") - extends RuntimeException( - SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), - cause) - with SparkThrowable { + summary: String = "") = { + this( + 
SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(cause), + Option(errorClass), + messageParameters, + context + ) + } + + def this(message: String, cause: Option[Throwable]) = { + this( + message, + cause = cause, + errorClass = None, + messageParameters = Map.empty, + context = Array.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -326,18 +480,39 @@ private[spark] class SparkSecurityException( /** * Array index out of bounds exception thrown from Spark with an error class. */ -private[spark] class SparkArrayIndexOutOfBoundsException( +private[spark] class SparkArrayIndexOutOfBoundsException private( + message: String, + errorClass: Option[String], + messageParameters: Map[String, String], + context: Array[QueryContext]) + extends ArrayIndexOutOfBoundsException(message) + with SparkThrowable { + + def this( errorClass: String, messageParameters: Map[String, String], context: Array[QueryContext], - summary: String) - extends ArrayIndexOutOfBoundsException( - SparkThrowableHelper.getMessage(errorClass, messageParameters, summary)) - with SparkThrowable { + summary: String) = { + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters, summary), + Option(errorClass), + messageParameters, + context + ) + } + + def this(message: String) = { + this( + message, + errorClass = None, + messageParameters = Map.empty, + context = Array.empty + ) + } override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getErrorClass: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala 
b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 5106460d145c7..2331a8e67b28e 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -19,8 +19,6 @@ package org.apache.spark import scala.collection.JavaConverters._ -import com.fasterxml.jackson.core.util.MinimalPrettyPrinter - import org.apache.spark.util.JsonUtils.toJsonString import org.apache.spark.util.SparkClassUtils @@ -63,7 +61,7 @@ private[spark] object SparkThrowableHelper { } def isInternalError(errorClass: String): Boolean = { - errorClass.startsWith("INTERNAL_ERROR") + errorClass != null && errorClass.startsWith("INTERNAL_ERROR") } def getMessage(e: SparkThrowable with Throwable, format: ErrorMessageFormat.Value): String = { @@ -121,16 +119,4 @@ private[spark] object SparkThrowableHelper { } } } - - def getMessage(throwable: Throwable): String = { - toJsonString { generator => - val g = generator.setPrettyPrinter(new MinimalPrettyPrinter) - g.writeStartObject() - g.writeStringField("errorClass", throwable.getClass.getCanonicalName) - g.writeObjectFieldStart("messageParameters") - g.writeStringField("message", throwable.getMessage) - g.writeEndObject() - g.writeEndObject() - } - } } diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 83e01330ce3f6..bd82ce962b8d0 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -196,7 +196,7 @@ private[spark] object Logging { val initLock = new Object() try { // We use reflection here to handle the case where users remove the - // slf4j-to-jul bridge order to route their logs to JUL. + // jul-to-slf4j bridge order to route their logs to JUL. 
val bridgeClass = SparkClassUtils.classForName("org.slf4j.bridge.SLF4JBridgeHandler") bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala index 011f74de1febe..5984eaee42e73 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala @@ -16,7 +16,13 @@ */ package org.apache.spark.util -trait SparkClassUtils { +import java.util.Random + +import scala.util.Try + +private[spark] trait SparkClassUtils { + val random = new Random() + def getSparkClassLoader: ClassLoader = getClass.getClassLoader def getContextOrSparkClassLoader: ClassLoader = @@ -39,6 +45,39 @@ trait SparkClassUtils { } // scalastyle:on classforname } + + /** Determines whether the provided class is loadable in the current thread. */ + def classIsLoadable(clazz: String): Boolean = { + Try { classForName(clazz, initialize = false) }.isSuccess + } + + /** + * Returns true if and only if the underlying class is a member class. + * + * Note: jdk8u throws a "Malformed class name" error if a given class is a deeply-nested + * inner class (See SPARK-34607 for details). This issue has already been fixed in jdk9+, so + * we can remove this helper method safely if we drop the support of jdk8u. + */ + def isMemberClass(cls: Class[_]): Boolean = { + try { + cls.isMemberClass + } catch { + case _: InternalError => + // We emulate jdk8u `Class.isMemberClass` below: + // public boolean isMemberClass() { + // return getSimpleBinaryName() != null && !isLocalOrAnonymousClass(); + // } + // `getSimpleBinaryName()` returns null if a given class is a top-level class, + // so we replace it with `cls.getEnclosingClass != null`. 
The second condition checks + // if a given class is not a local or an anonymous class, so we replace it with + // `cls.getEnclosingMethod == null` because `cls.getEnclosingMethod()` return a value + // only in either case (JVM Spec 4.8.6). + // + // Note: The newer jdk evaluates `!isLocalOrAnonymousClass()` first, + // we reorder the conditions to follow it. + cls.getEnclosingMethod == null && cls.getEnclosingClass != null + } + } } -object SparkClassUtils extends SparkClassUtils +private[spark] object SparkClassUtils extends SparkClassUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala new file mode 100644 index 0000000000000..be8282db31bee --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkCollectionUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.util + +import scala.collection.immutable + +private[spark] trait SparkCollectionUtils { + /** + * Same function as `keys.zipWithIndex.toMap`, but has perf gain. 
+ */ + def toMapWithIndex[K](keys: Iterable[K]): Map[K, Int] = { + val builder = immutable.Map.newBuilder[K, Int] + val keyIter = keys.iterator + var idx = 0 + while (keyIter.hasNext) { + builder += (keyIter.next(), idx).asInstanceOf[(K, Int)] + idx = idx + 1 + } + builder.result() + } +} + +private[spark] object SparkCollectionUtils extends SparkCollectionUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala index 8e4de01885e02..8194d1e424173 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala @@ -16,13 +16,14 @@ */ package org.apache.spark.util -import java.io.IOException +import java.io.{Closeable, IOException, PrintWriter} +import java.nio.charset.StandardCharsets.UTF_8 import scala.util.control.NonFatal import org.apache.spark.internal.Logging -object SparkErrorUtils extends Logging { +private[spark] trait SparkErrorUtils extends Logging { /** * Execute a block of code that returns a value, re-throwing any non-fatal uncaught * exceptions as IOException. This is used when implementing Externalizable and Serializable's @@ -41,4 +42,52 @@ object SparkErrorUtils extends Logging { throw new IOException(e) } } + + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } + + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. 
This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable + } + } + } + + def stackTraceToString(t: Throwable): String = { + val out = new java.io.ByteArrayOutputStream + SparkErrorUtils.tryWithResource(new PrintWriter(out)) { writer => + t.printStackTrace(writer) + writer.flush() + } + new String(out.toByteArray, UTF_8) + } } + +private[spark] object SparkErrorUtils extends SparkErrorUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala index 63d1ab4799ab2..e12f8acdadd3c 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala @@ -18,8 +18,12 @@ package org.apache.spark.util import java.io.File import java.net.{URI, URISyntaxException} +import java.nio.file.Files -private[spark] object SparkFileUtils { +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils + +private[spark] trait SparkFileUtils extends Logging { /** * Return a well-formed URI for the file described by a user input string. * @@ -44,4 +48,78 @@ private[spark] object SparkFileUtils { } new File(path).getCanonicalFile().toURI() } + + /** + * Lists files recursively. 
+ */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val result = f.listFiles.toBuffer + val dirList = result.filter(_.isDirectory) + while (dirList.nonEmpty) { + val curDir = dirList.remove(0) + val files = curDir.listFiles() + result ++= files + dirList ++= files.filter(_.isDirectory) + } + result.toArray + } + + /** + * Create a directory given the abstract pathname + * @return true, if the directory is successfully created; otherwise, return false. + */ + def createDirectory(dir: File): Boolean = { + try { + // SPARK-35907: The check was required by File.mkdirs() because it could sporadically + // fail silently. After switching to Files.createDirectories(), ideally, there should + // no longer be silent fails. But the check is kept for the safety concern. We can + // remove the check when we're sure that Files.createDirectories() would never fail silently. + Files.createDirectories(dir.toPath) + if ( !dir.exists() || !dir.isDirectory) { + logError(s"Failed to create directory " + dir) + } + dir.isDirectory + } catch { + case e: Exception => + logError(s"Failed to create directory " + dir, e) + false + } + } + + /** + * Create a directory inside the given parent directory. The directory is guaranteed to be + * newly created, and is not marked for automatic deletion. + */ + def createDirectory(root: String, namePrefix: String = "spark"): File = { + JavaUtils.createDirectory(root, namePrefix) + } + + /** + * Create a temporary directory inside the `java.io.tmpdir` prefixed with `spark`. + * The directory will be automatically deleted when the VM shuts down. + */ + def createTempDir(): File = + createTempDir(System.getProperty("java.io.tmpdir"), "spark") + + /** + * Create a temporary directory inside the given parent directory. The directory will be + * automatically deleted when the VM shuts down. 
+ */ + def createTempDir( + root: String = System.getProperty("java.io.tmpdir"), + namePrefix: String = "spark"): File = { + createDirectory(root, namePrefix) + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + def deleteRecursively(file: File): Unit = { + JavaUtils.deleteRecursively(file) + } } + +private[spark] object SparkFileUtils extends SparkFileUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala index 88d1d6bdba8dd..2cc14fea5f307 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.util -import java.io.{ByteArrayOutputStream, ObjectOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, ObjectStreamClass} -object SparkSerDeUtils { +private[spark] trait SparkSerDeUtils { /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -27,4 +27,28 @@ object SparkSerDeUtils { oos.close() bos.toByteArray } + + /** Deserialize an object using Java serialization */ + def deserialize[T](bytes: Array[Byte]): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) + ois.readObject.asInstanceOf[T] + } + + /** + * Deserialize an object using Java serialization and the given ClassLoader + */ + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass): Class[_] = { + // scalastyle:off classforname + Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + 
} + } + ois.readObject.asInstanceOf[T] + } } + +private[spark] object SparkSerDeUtils extends SparkSerDeUtils diff --git a/connector/avro/pom.xml b/connector/avro/pom.xml index 597e3c2235f7a..11811ed080bca 100644 --- a/connector/avro/pom.xml +++ b/connector/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 59f2999bdd395..2c2a45fc3f14f 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -61,7 +61,8 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val reader = new GenericDatumReader[Any](actualSchema, expectedSchema) @transient private lazy val deserializer = - new AvroDeserializer(expectedSchema, dataType, avroOptions.datetimeRebaseModeInRead) + new AvroDeserializer(expectedSchema, dataType, + avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 092a8f80b771b..ec34d10a5ffe8 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -50,18 +49,21 @@ private[sql] class AvroDeserializer( rootCatalystType: DataType, positionalFieldMatch: Boolean, datetimeRebaseSpec: RebaseSpec, - filters: StructFilters) { + filters: StructFilters, + useStableIdForUnionType: Boolean) { def this( rootAvroType: Schema, rootCatalystType: DataType, - datetimeRebaseMode: String) = { + datetimeRebaseMode: String, + useStableIdForUnionType: Boolean) = { this( rootAvroType, rootCatalystType, positionalFieldMatch = false, RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), - new NoopFilters) + new NoopFilters, + useStableIdForUnionType) } private lazy val decimalConversions = new DecimalConversion() @@ -103,6 +105,9 @@ private[sql] class AvroDeserializer( s"Cannot convert Avro type $rootAvroType to SQL type ${rootCatalystType.sql}.", ise) } + private lazy val preventReadingIncorrectType = !SQLConf.get + .getConf(SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA) + def deserialize(data: Any): Option[Any] = converter(data) /** @@ -119,268 +124,203 @@ private[sql] class AvroDeserializer( val incompatibleMsg = errorPrefix + s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" - val confKey = SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA - val preventReadingIncorrectType = !SQLConf.get.getConf(confKey) + val realDataType = SchemaConverters.toSqlType(avroType, useStableIdForUnionType).dataType + + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) - val logicalDataType = SchemaConverters.toSqlType(avroType).dataType - avroType.getType match { - case NULL => - (logicalDataType, catalystType) match { - case (_, NullType) => (updater, ordinal, _) => - updater.setNullAt(ordinal) - case _ => throw new 
IncompatibleSchemaException(incompatibleMsg) - } // TODO: we can avoid boxing if future version of avro provide primitive accessors. - case BOOLEAN => - (logicalDataType, catalystType) match { - case (_, BooleanType) => (updater, ordinal, value) => - updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, dt: DatetimeType) + if preventReadingIncorrectType && realDataType.isInstanceOf[YearMonthIntervalType] => + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) + + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) + + case (LONG, dt: DatetimeType) + if preventReadingIncorrectType && realDataType.isInstanceOf[DayTimeIntervalType] => + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => avroType.getLogicalType match { + // For backward compatibility, if the Avro type is Long and it is not logical type + // (the `null` case), the value is processed as timestamp type with millisecond precision. 
+ case null | _: TimestampMillis => (updater, ordinal, value) => + val millis = value.asInstanceOf[Long] + val micros = DateTimeUtils.millisToMicros(millis) + updater.setLong(ordinal, timestampRebaseFunc(micros)) + case _: TimestampMicros => (updater, ordinal, value) => + val micros = value.asInstanceOf[Long] + updater.setLong(ordinal, timestampRebaseFunc(micros)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.") + } - case INT => - (logicalDataType, catalystType) match { - case (IntegerType, IntegerType) => (updater, ordinal, value) => - updater.setInt(ordinal, value.asInstanceOf[Int]) - case (IntegerType, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) - case (DateType, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) - case (_: YearMonthIntervalType, _: YearMonthIntervalType) => (updater, ordinal, value) => - updater.setInt(ordinal, value.asInstanceOf[Int]) - case (_: YearMonthIntervalType, _) if preventReadingIncorrectType => - throw QueryCompilationErrors.avroIncorrectTypeError( - toFieldStr(avroPath), toFieldStr(catalystPath), - logicalDataType.catalogString, catalystType.catalogString, confKey.key) - case _ if !preventReadingIncorrectType => (updater, ordinal, value) => - updater.setInt(ordinal, value.asInstanceOf[Int]) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case LONG => - (logicalDataType, catalystType) match { - case (LongType, LongType) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - case (TimestampType, LongType) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - case (TimestampNTZType, LongType) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - case (LongType, TimestampType) - | 
(TimestampType, TimestampType) - |(TimestampNTZType, TimestampType) => avroType.getLogicalType match { - // For backward compatibility, if the Avro type is Long and it is not logical type - // (the `null` case), the value is processed as timestamp type with - // millisecond precision. - case null | _: TimestampMillis => (updater, ordinal, value) => - val millis = value.asInstanceOf[Long] - val micros = DateTimeUtils.millisToMicros(millis) - updater.setLong(ordinal, timestampRebaseFunc(micros)) - case _: TimestampMicros => (updater, ordinal, value) => - val micros = value.asInstanceOf[Long] - updater.setLong(ordinal, timestampRebaseFunc(micros)) - case other => throw new IncompatibleSchemaException(errorPrefix + - s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.") - } - case (LongType, TimestampNTZType) - | (TimestampNTZType, TimestampNTZType) - | (TimestampType, TimestampNTZType) => avroType.getLogicalType match { - // To keep consistent with TimestampType, if the Avro type is Long and it is not - // logical type (the `null` case), the value is processed as TimestampNTZ - // with millisecond precision. - case null | _: LocalTimestampMillis => (updater, ordinal, value) => - val millis = value.asInstanceOf[Long] - val micros = DateTimeUtils.millisToMicros(millis) - updater.setLong(ordinal, micros) - case _: LocalTimestampMicros => (updater, ordinal, value) => - val micros = value.asInstanceOf[Long] - updater.setLong(ordinal, micros) - case other => throw new IncompatibleSchemaException(errorPrefix + - s"Avro logical type $other cannot be converted to SQL type ${TimestampNTZType.sql}.") - } - // Before we upgrade Avro to 1.8 for logical type support, - // spark-avro converts Long to Date. - // For backward compatibility, we still keep this conversion. 
- case (LongType, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt) - case (DateType, DateType) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - case (_: DayTimeIntervalType, _: DayTimeIntervalType) => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - case (_: DayTimeIntervalType, _) if preventReadingIncorrectType => - throw QueryCompilationErrors.avroIncorrectTypeError( - toFieldStr(avroPath), toFieldStr(catalystPath), - logicalDataType.catalogString, catalystType.catalogString, confKey.key) - case (_: DayTimeIntervalType, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt) - case (_, dt: DecimalType) => (updater, ordinal, value) => - val d = avroType.getLogicalType.asInstanceOf[CustomDecimal] - updater.setDecimal(ordinal, Decimal(value.asInstanceOf[Long], d.precision, d.scale)) - case _ if !preventReadingIncorrectType => (updater, ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[Long]) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case FLOAT => - (logicalDataType, catalystType) match { - case (_, FloatType) => (updater, ordinal, value) => - updater.setFloat(ordinal, value.asInstanceOf[Float]) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case DOUBLE => - (logicalDataType, catalystType) match { - case (_, DoubleType) => (updater, ordinal, value) => - updater.setDouble(ordinal, value.asInstanceOf[Double]) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case STRING => - (logicalDataType, catalystType) match { - case (_, StringType) => (updater, ordinal, value) => - val str = value match { - case s: String => UTF8String.fromString(s) - case s: Utf8 => - val bytes = new Array[Byte](s.getByteLength) - System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) - 
UTF8String.fromBytes(bytes) - } - updater.set(ordinal, str) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case ENUM => - (logicalDataType, catalystType) match { - case (_, StringType) => (updater, ordinal, value) => - updater.set(ordinal, UTF8String.fromString(value.toString)) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) + case (LONG, TimestampNTZType) => avroType.getLogicalType match { + // To keep consistent with TimestampType, if the Avro type is Long and it is not + // logical type (the `null` case), the value is processed as TimestampNTZ + // with millisecond precision. + case null | _: LocalTimestampMillis => (updater, ordinal, value) => + val millis = value.asInstanceOf[Long] + val micros = DateTimeUtils.millisToMicros(millis) + updater.setLong(ordinal, micros) + case _: LocalTimestampMicros => (updater, ordinal, value) => + val micros = value.asInstanceOf[Long] + updater.setLong(ordinal, micros) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"Avro logical type $other cannot be converted to SQL type ${TimestampNTZType.sql}.") + } + + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. 
+ case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) } - case FIXED => - (logicalDataType, catalystType) match { - case (_, BinaryType) => (updater, ordinal, value) => - updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) - case (_, dt: DecimalType) => - val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] - if (preventReadingIncorrectType && - d.getPrecision - d.getScale > dt.precision - dt.scale) { - throw QueryCompilationErrors.avroLowerPrecisionError(toFieldStr(avroPath), - toFieldStr(catalystPath), logicalDataType.catalogString, - dt.catalogString, confKey.key) - } - (updater, ordinal, value) => - val bigDecimal = - decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d) - val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) - updater.setDecimal(ordinal, decimal) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + // Do not forget to reset 
the position + b.rewind() + bytes + case b: Array[Byte] => b + case other => + throw new RuntimeException(errorPrefix + s"$other is not a valid avro binary.") } - case BYTES => - (logicalDataType, catalystType) match { - case (_, BinaryType) => (updater, ordinal, value) => - val bytes = value match { - case b: ByteBuffer => - val bytes = new Array[Byte](b.remaining) - b.get(bytes) - // Do not forget to reset the position - b.rewind() - bytes - case b: Array[Byte] => b - case other => - throw new RuntimeException(errorPrefix + s"$other is not a valid avro binary.") - } - updater.set(ordinal, bytes) - case (_, dt: DecimalType) => - val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] - if (preventReadingIncorrectType && - d.getPrecision - d.getScale > dt.precision - dt.scale) { - throw QueryCompilationErrors.avroLowerPrecisionError(toFieldStr(avroPath), - toFieldStr(catalystPath), logicalDataType.catalogString, - dt.catalogString, confKey.key) - } - (updater, ordinal, value) => - val bigDecimal = decimalConversions - .fromBytes(value.asInstanceOf[ByteBuffer], avroType, d) - val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) - updater.setDecimal(ordinal, decimal) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) + updater.set(ordinal, bytes) + + case (FIXED, dt: DecimalType) => + val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] + if (preventReadingIncorrectType && + d.getPrecision - d.getScale > dt.precision - dt.scale) { + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) } - case RECORD => - (logicalDataType, catalystType) match { - case (_, st: StructType) => - // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328. - // We can always return `false` from `applyFilters` for nested records. 
- val writeRecord = - getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = _ => false) - (updater, ordinal, value) => - val row = new SpecificInternalRow(st) - writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) - updater.set(ordinal, row) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) + (updater, ordinal, value) => + val bigDecimal = + decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) + + case (BYTES, dt: DecimalType) => + val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] + if (preventReadingIncorrectType && + d.getPrecision - d.getScale > dt.precision - dt.scale) { + throw QueryCompilationErrors.avroIncompatibleReadError(toFieldStr(avroPath), + toFieldStr(catalystPath), realDataType.catalogString, dt.catalogString) } - case ARRAY => - (logicalDataType, catalystType) match { - case (_, ArrayType(elementType, containsNull)) => - val avroElementPath = avroPath :+ "element" - val elementWriter = newWriter(avroType.getElementType, elementType, - avroElementPath, catalystPath :+ "element") - (updater, ordinal, value) => - val collection = value.asInstanceOf[java.util.Collection[Any]] - val result = createArrayData(elementType, collection.size()) - val elementUpdater = new ArrayDataUpdater(result) - - var i = 0 - val iter = collection.iterator() - while (iter.hasNext) { - val element = iter.next() - if (element == null) { - if (!containsNull) { - throw new RuntimeException( - s"Array value at path ${toFieldStr(avroElementPath)}" + - s" is not allowed to be null") - } else { - elementUpdater.setNullAt(i) - } - } else { - elementWriter(elementUpdater, i, element) - } - i += 1 + (updater, ordinal, value) => + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + 
updater.setDecimal(ordinal, decimal) + + case (RECORD, st: StructType) => + // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328. + // We can always return `false` from `applyFilters` for nested records. + val writeRecord = + getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = _ => false) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val avroElementPath = avroPath :+ "element" + val elementWriter = newWriter(avroType.getElementType, elementType, + avroElementPath, catalystPath :+ "element") + (updater, ordinal, value) => + val collection = value.asInstanceOf[java.util.Collection[Any]] + val result = createArrayData(elementType, collection.size()) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + val iter = collection.iterator() + while (iter.hasNext) { + val element = iter.next() + if (element == null) { + if (!containsNull) { + throw new RuntimeException( + s"Array value at path ${toFieldStr(avroElementPath)} is not allowed to be null") + } else { + elementUpdater.setNullAt(i) } - updater.set(ordinal, result) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case MAP => - (logicalDataType, catalystType) match { - case (_, MapType(keyType, valueType, valueContainsNull)) - if keyType == StringType => - val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, - avroPath :+ "key", catalystPath :+ "key") - val valueWriter = newWriter(avroType.getValueType, valueType, - avroPath :+ "value", catalystPath :+ "value") - (updater, ordinal, value) => - val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] - val keyArray = createArrayData(keyType, map.size()) - val keyUpdater = new ArrayDataUpdater(keyArray) - val valueArray = createArrayData(valueType, map.size()) - val valueUpdater = 
new ArrayDataUpdater(valueArray) - val iter = map.entrySet().iterator() - var i = 0 - while (iter.hasNext) { - val entry = iter.next() - assert(entry.getKey != null) - keyWriter(keyUpdater, i, entry.getKey) - if (entry.getValue == null) { - if (!valueContainsNull) { - throw new RuntimeException( - s"Map value at path ${toFieldStr(avroPath :+ "value")}" + - s" is not allowed to be null") - } else { - valueUpdater.setNullAt(i) - } - } else { - valueWriter(valueUpdater, i, entry.getValue) - } - i += 1 + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, + avroPath :+ "key", catalystPath :+ "key") + val valueWriter = newWriter(avroType.getValueType, valueType, + avroPath :+ "value", catalystPath :+ "value") + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException( + s"Map value at path ${toFieldStr(avroPath :+ "value")} is not allowed to be null") + } else { + valueUpdater.setNullAt(i) } - // The Avro map will never have null or duplicated map keys, it's safe to create a - // ArrayBasedMapData directly here. 
- updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) - case _ => throw new IncompatibleSchemaException(incompatibleMsg) - } - case UNION => + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + // The Avro map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => val nonNullTypes = nonNullUnionBranches(avroType) val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava) if (nonNullTypes.nonEmpty) { @@ -389,18 +329,20 @@ private[sql] class AvroDeserializer( } else { nonNullTypes.map(_.getType).toSeq match { case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => - (updater, ordinal, value) => value match { - case null => updater.setNullAt(ordinal) - case l: java.lang.Long => updater.setLong(ordinal, l) - case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) - } + (updater, ordinal, value) => + value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => - (updater, ordinal, value) => value match { - case null => updater.setNullAt(ordinal) - case d: java.lang.Double => updater.setDouble(ordinal, d) - case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) - } + (updater, ordinal, value) => + value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } case _ => catalystType match { @@ -424,6 +366,17 @@ private[sql] class AvroDeserializer( } else { (updater, ordinal, _) => updater.setNullAt(ordinal) } + + case (INT, _: YearMonthIntervalType) => (updater, ordinal, value) => + updater.setInt(ordinal, 
value.asInstanceOf[Int]) + + case (LONG, _: DayTimeIntervalType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, _: DecimalType) => (updater, ordinal, value) => + val d = avroType.getLogicalType.asInstanceOf[CustomDecimal] + updater.setDecimal(ordinal, Decimal(value.asInstanceOf[Long], d.precision, d.scale)) + case _ => throw new IncompatibleSchemaException(incompatibleMsg) } } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 53562a3afdb5b..7b0292df43c2f 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -141,7 +141,8 @@ private[sql] class AvroFileFormat extends FileFormat requiredSchema, parsedOptions.positionalFieldMatching, datetimeRebaseMode, - avroFilters) + avroFilters, + parsedOptions.useStableIdForUnionType) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index edaaa8835cc01..5fd39393335d4 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -81,14 +81,14 @@ private[sql] class AvroOptions( /** * Top level record name in write result, which is required in Avro spec. - * See https://avro.apache.org/docs/1.11.2/specification/#schema-record . + * See https://avro.apache.org/docs/1.11.4/specification/#schema-record . * Default value is "topLevelRecord" */ val recordName: String = parameters.getOrElse(RECORD_NAME, "topLevelRecord") /** * Record namespace in write result. Default value is "". 
- * See Avro spec for details: https://avro.apache.org/docs/1.11.2/specification/#schema-record . + * See Avro spec for details: https://avro.apache.org/docs/1.11.4/specification/#schema-record . */ val recordNamespace: String = parameters.getOrElse(RECORD_NAMESPACE, "") diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 44520bb92d14e..23dfe86b6bce1 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -32,8 +32,7 @@ import org.apache.spark.SPARK_VERSION_SHORT import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index c95d731f0dedd..34bf47613e7bf 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSourceUtils -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ /** diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 2554106d78e9d..67e4583fe4822 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -29,6 +29,7 @@ import org.apache.avro.mapred.{AvroOutputFormat, FsInput} import org.apache.avro.mapreduce.AvroJob import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.hdfs.BlockMissingException import org.apache.hadoop.mapreduce.Job import org.apache.spark.SparkException @@ -140,6 +141,8 @@ private[sql] object AvroUtils extends Logging { try { Some(DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]())) } catch { + case e: BlockMissingException => + throw new SparkException(s"Could not read file: $path", e) case e: IOException => if (ignoreCorruptFiles) { logWarning(s"Skipped the footer in the corrupted file: $path", e) diff --git 
a/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CustomDecimal.scala similarity index 95% rename from connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/CustomDecimal.scala index d76f40c7635c4..fab3d4493e344 100644 --- a/connector/avro/src/main/java/org/apache/spark/sql/avro/CustomDecimal.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CustomDecimal.scala @@ -22,14 +22,14 @@ import org.apache.avro.Schema import org.apache.spark.sql.types.DecimalType -object CustomDecimal { +private[spark] object CustomDecimal { val TYPE_NAME = "custom-decimal" } // A customized logical type, which will be registered to Avro. This logical type is similar to // Avro's builtin Decimal type, but is meant to be registered for long type. It indicates that // the long type should be converted to Spark's Decimal type, with provided precision and scale. -private class CustomDecimal(schema: Schema) extends LogicalType(CustomDecimal.TYPE_NAME) { +private[spark] class CustomDecimal(schema: Schema) extends LogicalType(CustomDecimal.TYPE_NAME) { val scale : Int = { val obj = schema.getObjectProp("scale") obj match { diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 6f21639e28d68..af358a8d1c961 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -46,16 +46,24 @@ object SchemaConverters { */ case class SchemaType(dataType: DataType, nullable: Boolean) + /** + * Converts an Avro schema to a corresponding Spark SQL schema. 
+ * + * @since 4.0.0 + */ + def toSqlType(avroSchema: Schema, useStableIdForUnionType: Boolean): SchemaType = { + toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType) + } /** * Converts an Avro schema to a corresponding Spark SQL schema. * * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(Map())) + toSqlType(avroSchema, false) } def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options)) + toSqlTypeHelper(avroSchema, Set.empty, AvroOptions(options).useStableIdForUnionType) } // The property specifies Catalyst type of the given field @@ -64,7 +72,7 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, existingRecordNames: Set[String], - avroOptions: AvroOptions): SchemaType = { + useStableIdForUnionType: Boolean): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -117,7 +125,7 @@ object SchemaConverters { } val newRecordNames = existingRecordNames + avroSchema.getFullName val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, avroOptions) + val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, useStableIdForUnionType) StructField(f.name, schemaType.dataType, schemaType.nullable) } @@ -127,13 +135,14 @@ object SchemaConverters { val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, - avroOptions) + useStableIdForUnionType) SchemaType( ArrayType(schemaType.dataType, containsNull = schemaType.nullable), nullable = false) case MAP => - val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames, avroOptions) + val schemaType = toSqlTypeHelper(avroSchema.getValueType, + existingRecordNames, useStableIdForUnionType) SchemaType( MapType(StringType, schemaType.dataType, 
valueContainsNull = schemaType.nullable), nullable = false) @@ -143,17 +152,18 @@ object SchemaConverters { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) if (remainingUnionTypes.size == 1) { - toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, avroOptions) + toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames, useStableIdForUnionType) .copy(nullable = true) } else { toSqlTypeHelper( Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames, - avroOptions).copy(nullable = true) + useStableIdForUnionType).copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => - toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames, avroOptions) + toSqlTypeHelper(avroSchema.getTypes.get(0), + existingRecordNames, useStableIdForUnionType) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -167,20 +177,20 @@ object SchemaConverters { val fieldNameSet : mutable.Set[String] = mutable.Set() val fields = avroSchema.getTypes.asScala.zipWithIndex.map { case (s, i) => - val schemaType = toSqlTypeHelper(s, existingRecordNames, avroOptions) + val schemaType = toSqlTypeHelper(s, existingRecordNames, useStableIdForUnionType) - val fieldName = if (avroOptions.useStableIdForUnionType) { + val fieldName = if (useStableIdForUnionType) { // Avro's field name may be case sensitive, so field names for two named type // could be "a" and "A" and we need to distinguish them. In this case, we throw // an exception. - val temp_name = s"member_${s.getName.toLowerCase(Locale.ROOT)}" - if (fieldNameSet.contains(temp_name)) { + // Stable id prefix can be empty so the name of the field can be just the type. 
+ val tempFieldName = s"member_${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { throw new IncompatibleSchemaException( - "Cannot generate stable indentifier for Avro union type due to name " + + "Cannot generate stable identifier for Avro union type due to name " + s"conflict of type name ${s.getName}") } - fieldNameSet.add(temp_name) - temp_name + tempFieldName } else { s"member$i" } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index cc7bd180e8477..2c85c1b067392 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -103,7 +103,8 @@ case class AvroPartitionReaderFactory( readDataSchema, options.positionalFieldMatching, datetimeRebaseMode, - avroFilters) + avroFilters, + options.useStableIdForUnionType) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 5c0d64b4d55eb..250b5e0615ad8 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopF import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import 
org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.LegacyBehaviorPolicy import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -59,7 +59,7 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite val expected = { val avroSchema = new Schema.Parser().parse(schema) - SchemaConverters.toSqlType(avroSchema).dataType match { + SchemaConverters.toSqlType(avroSchema, false).dataType match { case st: StructType => Row.fromSeq((0 until st.length).map(_ => null)) case _ => null } @@ -281,13 +281,14 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite data: GenericData.Record, expected: Option[Any], filters: StructFilters = new NoopFilters): Unit = { - val dataType = SchemaConverters.toSqlType(schema).dataType + val dataType = SchemaConverters.toSqlType(schema, false).dataType val deserializer = new AvroDeserializer( schema, dataType, false, - RebaseSpec(SQLConf.LegacyBehaviorPolicy.CORRECTED), - filters) + RebaseSpec(LegacyBehaviorPolicy.CORRECTED), + filters, + false) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index cc0e178c617af..965e3a0c1cba6 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters} import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.internal.LegacyBehaviorPolicy._ import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.v2.avro.AvroScan @@ -75,7 +75,8 @@ class AvroRowReaderSuite StructType(new StructField("value", IntegerType, true) :: Nil), false, RebaseSpec(CORRECTED), - new NoopFilters) + new NoopFilters, + false) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index bfd56613fd64c..a21f3f008fdc7 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -22,7 +22,7 @@ import org.apache.avro.generic.GenericRecordBuilder import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.NoopFilters import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy.CORRECTED +import org.apache.spark.sql.internal.LegacyBehaviorPolicy.CORRECTED import org.apache.spark.sql.types.{IntegerType, StructType} /** @@ -226,7 +226,8 @@ object AvroSerdeSuite { sql, isPositional(matchType), RebaseSpec(CORRECTED), - new NoopFilters) + new NoopFilters, + false) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 35e9f43289c16..01c9dfb57a191 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -46,9 +46,9 @@ import org.apache.spark.sql.execution.{FormattedMode, SparkPlan} import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, FilePartition} import 
org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.internal.LegacyBehaviorPolicy._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.v2.avro.AvroScan @@ -370,7 +370,7 @@ abstract class AvroSuite "", Seq()) } - assert(e.getMessage.contains("Cannot generate stable indentifier")) + assert(e.getMessage.contains("Cannot generate stable identifier")) } { val e = intercept[Exception] { @@ -381,7 +381,7 @@ abstract class AvroSuite "", Seq()) } - assert(e.getMessage.contains("Cannot generate stable indentifier")) + assert(e.getMessage.contains("Cannot generate stable identifier")) } // Two array types or two map types are not allowed in union. 
{ @@ -434,6 +434,33 @@ abstract class AvroSuite } } + test("SPARK-47904: Test that field name case is preserved") { + checkUnionStableId( + List( + Schema.createEnum("myENUM", "", null, List[String]("E1", "e2").asJava), + Schema.createRecord("myRecord", "", null, false, + List[Schema.Field](new Schema.Field("f", Schema.createFixed("myField", "", null, 6))) + .asJava), + Schema.createRecord("myRecord2", "", null, false, + List[Schema.Field](new Schema.Field("F", Schema.create(Type.FLOAT))) + .asJava)), + "struct, " + + "member_myRecord2: struct>", + Seq()) + + { + val e = intercept[Exception] { + checkUnionStableId( + List( + Schema.createRecord("myRecord", "", null, false, List[Schema.Field]().asJava), + Schema.createRecord("myrecord", "", null, false, List[Schema.Field]().asJava)), + "", + Seq()) + } + assert(e.getMessage.contains("Cannot generate stable identifier")) + } + } + test("SPARK-27858 Union type: More than one non-null type") { Seq(true, false).foreach { isStableUnionMember => withTempDir { dir => @@ -816,7 +843,7 @@ abstract class AvroSuite } test("SPARK-43380: Fix Avro data type conversion" + - " of decimal type to avoid producing incorrect results") { + " of decimal type to avoid producing incorrect results") { withTempPath { path => val confKey = SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key sql("SELECT 13.1234567890 a").write.format("avro").save(path.toString) @@ -829,12 +856,11 @@ abstract class AvroSuite case ex: AnalysisException => checkError( exception = ex, - errorClass = "AVRO_LOWER_PRECISION", + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "decimal\\(12,10\\)", - "sqlType" -> "\"DECIMAL\\(4,3\\)\"", - "key" -> SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key), + "sqlType" -> "\"DECIMAL\\(4,3\\)\""), matchPVals = true ) case other => @@ -880,12 +906,11 @@ abstract class AvroSuite case ex: AnalysisException => checkError( exception = ex, - errorClass = 
"AVRO_INCORRECT_TYPE", + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval day to second", - "sqlType" -> s""""$sqlType"""", - "key" -> SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key), + "sqlType" -> s""""$sqlType""""), matchPVals = true ) case other => @@ -923,12 +948,11 @@ abstract class AvroSuite case ex: AnalysisException => checkError( exception = ex, - errorClass = "AVRO_INCORRECT_TYPE", + errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval year to month", - "sqlType" -> s""""$sqlType"""", - "key" -> SQLConf.LEGACY_AVRO_ALLOW_INCOMPATIBLE_SCHEMA.key), + "sqlType" -> s""""$sqlType""""), matchPVals = true ) case other => @@ -2140,7 +2164,7 @@ abstract class AvroSuite private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { val message = intercept[IncompatibleSchemaException] { - SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema)) + SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false) }.getMessage assert(message.contains("Found recursive reference in Avro schema")) diff --git a/connector/connect/bin/spark-connect-build b/connector/connect/bin/spark-connect-build index ca8d4cf6e9005..63c17d8f7aa06 100755 --- a/connector/connect/bin/spark-connect-build +++ b/connector/connect/bin/spark-connect-build @@ -29,7 +29,5 @@ SCALA_BINARY_VER=`grep "scala.binary.version" "${SPARK_HOME}/pom.xml" | head -n1 SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VER} | head -n1 | awk -F '[<>]' '{print $3}'` SCALA_ARG="-Pscala-${SCALA_BINARY_VER}" -# Build the jars needed for spark submit and spark connect -build/sbt "${SCALA_ARG}" -Phive -Pconnect package || exit 1 -# Build the jars needed for spark connect JVM client -build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly" || exit 1 +# Build the jars needed for spark 
submit and spark connect JVM client +build/sbt "${SCALA_ARG}" -Phive -Pconnect package "connect-client-jvm/package" || exit 1 diff --git a/connector/connect/bin/spark-connect-scala-client b/connector/connect/bin/spark-connect-scala-client index ef394df4e0f25..ffa77f708421f 100755 --- a/connector/connect/bin/spark-connect-scala-client +++ b/connector/connect/bin/spark-connect-scala-client @@ -45,7 +45,7 @@ SCALA_ARG="-Pscala-${SCALA_BINARY_VER}" SCBUILD="${SCBUILD:-1}" if [ "$SCBUILD" -eq "1" ]; then # Build the jars needed for spark connect JVM client - build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly" || exit 1 + build/sbt "${SCALA_ARG}" "connect-client-jvm/package" || exit 1 fi if [ -z "$SCCLASSPATH" ]; then diff --git a/connector/connect/bin/spark-connect-scala-client-classpath b/connector/connect/bin/spark-connect-scala-client-classpath index 99a22f3d5ffeb..9d33e90bf09cb 100755 --- a/connector/connect/bin/spark-connect-scala-client-classpath +++ b/connector/connect/bin/spark-connect-scala-client-classpath @@ -30,6 +30,5 @@ SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VE SCALA_ARG="-Pscala-${SCALA_BINARY_VER}" CONNECT_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export connect-client-jvm/fullClasspath" | grep jar | tail -n1)" -SQL_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export sql/fullClasspath" | grep jar | tail -n1)" -echo "$CONNECT_CLASSPATH:$CLASSPATH" +echo "$CONNECT_CLASSPATH" diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 8a51bf65d6a88..f2630bfb9303f 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../../../pom.xml @@ -39,31 +39,21 @@ org.apache.spark spark-connect-common_${scala.binary.version} ${project.version} - - - com.google.guava - guava - - - org.apache.spark - 
spark-catalyst_${scala.binary.version} + spark-sql-api_${scala.binary.version} ${project.version} - provided - - - com.google.guava - guava - - org.apache.spark - spark-common-utils_${scala.binary.version} + spark-sketch_${scala.binary.version} ${project.version} + com.google.protobuf protobuf-java @@ -81,54 +71,25 @@ ${guava.failureaccess.version} compile - - io.netty - netty-codec-http2 - ${netty.version} - - - io.netty - netty-handler-proxy - ${netty.version} - - - io.netty - netty-transport-native-unix-common - ${netty.version} - com.lihaoyi ammonite_${scala.version} ${ammonite.version} provided - - - org.scala-lang.modules - scala-xml_${scala.binary.version} - - - org.apache.spark - spark-connect-common_${scala.binary.version} - ${project.version} - test-jar + commons-io + commons-io test - - - com.google.guava - guava - - - org.scalacheck - scalacheck_${scala.binary.version} + org.apache.commons + commons-lang3 test - org.mockito - mockito-core + org.scalacheck + scalacheck_${scala.binary.version} test @@ -140,6 +101,7 @@ + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes @@ -149,60 +111,78 @@ maven-shade-plugin false + true + com.google.guava:* com.google.android:* com.google.api.grpc:* com.google.code.findbugs:* com.google.code.gson:* com.google.errorprone:* - com.google.guava:* com.google.j2objc:* com.google.protobuf:* + com.google.flatbuffers:* io.grpc:* io.netty:* io.perfmark:* + org.apache.arrow:* org.codehaus.mojo:* org.checkerframework:* org.apache.spark:spark-connect-common_${scala.binary.version} - org.apache.spark:spark-common-utils_${scala.binary.version} + org.apache.spark:spark-sql-api_${scala.binary.version} + + com.google.common + ${spark.shade.packageName}.connect.guava + + com.google.common.** + + io.grpc - ${spark.shade.packageName}.connect.client.io.grpc + ${spark.shade.packageName}.io.grpc io.grpc.** com.google - ${spark.shade.packageName}.connect.client.com.google + 
${spark.shade.packageName}.com.google + + + com.google.common.** + io.netty - ${spark.shade.packageName}.connect.client.io.netty + ${spark.shade.packageName}.io.netty org.checkerframework - ${spark.shade.packageName}.connect.client.org.checkerframework + ${spark.shade.packageName}.org.checkerframework javax.annotation - ${spark.shade.packageName}.connect.client.javax.annotation + ${spark.shade.packageName}.javax.annotation io.perfmark - ${spark.shade.packageName}.connect.client.io.perfmark + ${spark.shade.packageName}.io.perfmark org.codehaus - ${spark.shade.packageName}.connect.client.org.codehaus + ${spark.shade.packageName}.org.codehaus + + + org.apache.arrow + ${spark.shade.packageName}.org.apache.arrow android.annotation - ${spark.shade.packageName}.connect.client.android.annotation + ${spark.shade.packageName}.android.annotation @@ -224,6 +204,24 @@ + + org.codehaus.mojo + build-helper-maven-plugin + + + add-sources + generate-sources + + add-source + + + + src/main/scala-${scala.binary.version} + + + + + \ No newline at end of file diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java deleted file mode 100644 index 95af157687c85..0000000000000 --- a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql; - -import org.apache.spark.annotation.Stable; - -/** - * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. - * - * @since 3.4.0 - */ -@Stable -public enum SaveMode { - /** - * Append mode means that when saving a DataFrame to a data source, if data/table already exists, - * contents of the DataFrame are expected to be appended to existing data. - * - * @since 3.4.0 - */ - Append, - /** - * Overwrite mode means that when saving a DataFrame to a data source, - * if data/table already exists, existing data is expected to be overwritten by the contents of - * the DataFrame. - * - * @since 3.4.0 - */ - Overwrite, - /** - * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists, - * an exception is expected to be thrown. - * - * @since 3.4.0 - */ - ErrorIfExists, - /** - * Ignore mode means that when saving a DataFrame to a data source, if data already exists, - * the save operation is expected to not save the contents of the DataFrame and to not - * change the existing data. 
- * - * @since 3.4.0 - */ - Ignore -} diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java deleted file mode 100644 index 27ffe67d9909c..0000000000000 --- a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming; - -import java.util.concurrent.TimeUnit; - -import scala.concurrent.duration.Duration; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.AvailableNowTrigger$; -import org.apache.spark.sql.execution.streaming.ContinuousTrigger; -import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; -import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger; - -/** - * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. 
- * - * @since 3.5.0 - */ -@Evolving -public class Trigger { - // This is a copy of the same class in sql/core/.../streaming/Trigger.java - - /** - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is 0, the query will run as fast as possible. - * - * @since 3.5.0 - */ - public static Trigger ProcessingTime(long intervalMs) { - return ProcessingTimeTrigger.create(intervalMs, TimeUnit.MILLISECONDS); - } - - /** - * (Java-friendly) - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is 0, the query will run as fast as possible. - * - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { - return ProcessingTimeTrigger.create(interval, timeUnit); - } - - /** - * (Scala-friendly) - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `duration` is 0, the query will run as fast as possible. - * - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) - * }}} - * @since 3.5.0 - */ - public static Trigger ProcessingTime(Duration interval) { - return ProcessingTimeTrigger.apply(interval); - } - - /** - * A trigger policy that runs a query periodically based on an interval in processing time. - * If `interval` is effectively 0, the query will run as fast as possible. - * - * {{{ - * df.writeStream.trigger(Trigger.ProcessingTime("10 seconds")) - * }}} - * @since 3.5.0 - */ - public static Trigger ProcessingTime(String interval) { - return ProcessingTimeTrigger.apply(interval); - } - - /** - * A trigger that processes all available data in a single batch then terminates the query. - * - * @since 3.5.0 - * @deprecated This is deprecated as of Spark 3.4.0. 
Use {@link #AvailableNow()} to leverage - * better guarantee of processing, fine-grained scale of batches, and better gradual - * processing of watermark advancement including no-data batch. - * See the NOTES in {@link #AvailableNow()} for details. - */ - @Deprecated - public static Trigger Once() { - return OneTimeTrigger$.MODULE$; - } - - /** - * A trigger that processes all available data at the start of the query in one or multiple - * batches, then terminates the query. - * - * Users are encouraged to set the source options to control the size of the batch as similar as - * controlling the size of the batch in {@link #ProcessingTime(long)} trigger. - * - * NOTES: - * - This trigger provides a strong guarantee of processing: regardless of how many batches were - * left over in previous run, it ensures all available data at the time of execution gets - * processed before termination. All uncommitted batches will be processed first. - * - Watermark gets advanced per each batch, and no-data batch gets executed before termination - * if the last batch advances the watermark. This helps to maintain smaller and predictable - * state size and smaller latency on the output of stateful operators. - * - * @since 3.5.0 - */ - public static Trigger AvailableNow() { - return AvailableNowTrigger$.MODULE$; - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * @since 3.5.0 - */ - public static Trigger Continuous(long intervalMs) { - return ContinuousTrigger.apply(intervalMs); - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. 
- * - * {{{ - * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(Trigger.Continuous(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - public static Trigger Continuous(long interval, TimeUnit timeUnit) { - return ContinuousTrigger.create(interval, timeUnit); - } - - /** - * (Scala-friendly) - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * import scala.concurrent.duration._ - * df.writeStream.trigger(Trigger.Continuous(10.seconds)) - * }}} - * @since 3.5.0 - */ - public static Trigger Continuous(Duration interval) { - return ContinuousTrigger.apply(interval); - } - - /** - * A trigger that continuously processes streaming data, asynchronously checkpointing at - * the specified interval. - * - * {{{ - * df.writeStream.trigger(Trigger.Continuous("10 seconds")) - * }}} - * @since 3.5.0 - */ - public static Trigger Continuous(String interval) { - return ContinuousTrigger.apply(interval); - } -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala index 6a660a7482e27..4a527040d80cf 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala @@ -24,7 +24,7 @@ import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.parser.DataTypeParser import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -1087,7 +1087,7 @@ class Column private[sql] (@DeveloperApi 
val expr: proto.Expression) extends Log * @group expr_ops * @since 3.4.0 */ - def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to)) + def cast(to: String): Column = cast(DataTypeParser.parseDataType(to)) /** * Returns a sort expression based on the descending order of the column. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 40f9ac1df2b22..10d2af094a08c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -25,9 +25,9 @@ import org.apache.spark.annotation.Stable import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} import org.apache.spark.sql.connect.common.DataTypeProtoConverter -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -58,7 +58,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging */ def schema(schema: StructType): DataFrameReader = { if (schema != null) { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] this.userSpecifiedSchema = Option(replaced) } this @@ -563,7 +563,7 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging */ private def assertNoSpecifiedSchema(operation: String): Unit = { if (userSpecifiedSchema.nonEmpty) { - throw 
QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation) + throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 0d4372b8738ee..4d35b4e876795 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} +import java.io.ByteArrayInputStream import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.CountMinSketch +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. @@ -584,6 +586,90 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } CountMinSketch.readFrom(ds.head()) } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.5.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. 
+ * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param fpp + * expected false positive probability of the filter. + * @since 3.5.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + buildBloomFilter(col, expectedNumItems, -1L, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName + * name of the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. + * @since 3.5.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col + * the column over which the filter is built + * @param expectedNumItems + * expected number of items which will be put into the filter. + * @param numBits + * expected number of bits of the filter. 
+ * @since 3.5.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + buildBloomFilter(col, expectedNumItems, numBits, Double.NaN) + } + + private def buildBloomFilter( + col: Column, + expectedNumItems: Long, + numBits: Long, + fpp: Double): BloomFilter = { + def numBitsValue: Long = if (!fpp.isNaN) { + BloomFilter.optimalNumOfBits(expectedNumItems, fpp) + } else { + numBits + } + + if (fpp <= 0d || fpp >= 1d) { + throw new SparkException("False positive probability must be within range (0.0, 1.0)") + } + val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue)) + + val ds = sparkSession.newDataset(BinaryEncoder) { builder => + builder.getProjectBuilder + .setInput(root) + .addExpressions(agg.expr) + } + BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) + } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 47e361c96795b..865596a669a09 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,15 +21,17 @@ import java.util.{Collections, Locale} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ -import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.expressions.OrderUtils import 
org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils} import org.apache.spark.sql.expressions.ScalarUserDefinedFunction @@ -37,7 +39,7 @@ import org.apache.spark.sql.functions.{struct, to_json} import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkClassUtils /** * A Dataset is a strongly typed collection of domain-specific objects that can be transformed in @@ -128,11 +130,13 @@ import org.apache.spark.util.Utils class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, - val encoder: AgnosticEncoder[T]) + val encoder: Encoder[T]) extends Serializable { // Make sure we don't forget to set plan id. assert(plan.getRoot.getCommon.hasPlanId) + private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder) + override def toString: String = { try { val builder = new mutable.StringBuilder @@ -828,7 +832,7 @@ class Dataset[T] private[sql] ( } private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { - sparkSession.newDataset(encoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => builder.getSortBuilder .setInput(plan.getRoot) .setIsGlobal(global) @@ -876,10 +880,11 @@ class Dataset[T] private[sql] ( val tupleEncoder = ProductEncoder[(T, U)]( - ClassTag(Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")), + ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")), Seq( - EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty), - EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty))) + EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty), + EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)), + None) 
sparkSession.newDataset(tupleEncoder) { builder => val joinBuilder = builder.getJoinBuilder @@ -889,8 +894,8 @@ class Dataset[T] private[sql] ( .setJoinType(joinTypeValue) .setJoinCondition(condition.expr) .setJoinDataType(joinBuilder.getJoinDataTypeBuilder - .setIsLeftStruct(this.encoder.isStruct) - .setIsRightStruct(other.encoder.isStruct)) + .setIsLeftStruct(this.agnosticEncoder.isStruct) + .setIsRightStruct(other.agnosticEncoder.isStruct)) } } @@ -1010,13 +1015,13 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) { - builder => + def hint(name: String, parameters: Any*): Dataset[T] = + sparkSession.newDataset(agnosticEncoder) { builder => builder.getHintBuilder .setInput(plan.getRoot) .setName(name) .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava) - } + } private def getPlanId: Option[Long] = if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) { @@ -1056,7 +1061,7 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder => + def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getSubqueryAliasBuilder .setInput(plan.getRoot) .setAlias(alias) @@ -1238,8 +1243,9 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder => - builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr) + def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { + builder => + builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr) } /** @@ -1285,7 +1291,7 @@ class Dataset[T] private[sql] ( val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) .addAllIds(ids.toSeq.map(_.expr).asJava) - 
.setValueColumnName(variableColumnName) + .setVariableColumnName(variableColumnName) .setValueColumnName(valueColumnName) valuesOption.foreach { values => unpivot.getValuesBuilder @@ -1355,12 +1361,12 @@ class Dataset[T] private[sql] ( def reduce(func: (T, T) => T): T = { val udf = ScalarUserDefinedFunction( function = func, - inputEncoders = encoder :: encoder :: Nil, - outputEncoder = encoder) + inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil, + outputEncoder = agnosticEncoder) val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr val result = sparkSession - .newDataset(encoder) { builder => + .newDataset(agnosticEncoder) { builder => builder.getAggregateBuilder .setInput(plan.getRoot) .addAggregateExpressions(reduceExpr) @@ -1718,7 +1724,7 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder => + def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getLimitBuilder .setInput(plan.getRoot) .setLimit(n) @@ -1730,7 +1736,7 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder => + def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getOffsetBuilder .setInput(plan.getRoot) .setOffset(n) @@ -1739,7 +1745,7 @@ class Dataset[T] private[sql] ( private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)( f: proto.SetOperation.Builder => Unit): Dataset[T] = { checkSameSparkSession(right) - sparkSession.newDataset(encoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => f( builder.getSetOpBuilder .setSetOpType(setOpType) @@ -1750,7 +1756,10 @@ class Dataset[T] private[sql] ( private def checkSameSparkSession(other: Dataset[_]): Unit = { if (this.sparkSession.sessionId != other.sparkSession.sessionId) { - throw new 
SparkException("Both Datasets must belong to the same SparkSession") + throw new SparkException( + errorClass = "CONNECT.SESSION_NOT_SAME", + messageParameters = Map.empty, + cause = null) } } @@ -2009,7 +2018,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { - sparkSession.newDataset(encoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => builder.getSampleBuilder .setInput(plan.getRoot) .setWithReplacement(withReplacement) @@ -2035,7 +2044,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { - sample(withReplacement, fraction, Utils.random.nextLong) + sample(withReplacement, fraction, SparkClassUtils.random.nextLong) } /** @@ -2069,7 +2078,7 @@ class Dataset[T] private[sql] ( // between construction and execution the query might fail or produce wrong results. Another // problem can come from data that arrives between the execution of the returned datasets. 
val sortOrder = schema.collect { - case f if RowOrdering.isOrderable(f.dataType) => col(f.name).asc + case f if OrderUtils.isOrderable(f.dataType) => col(f.name).asc } val sortedInput = sortWithinPartitions(sortOrder: _*).plan.getRoot val sum = weights.sum @@ -2077,7 +2086,7 @@ class Dataset[T] private[sql] ( normalizedCumWeights .sliding(2) .map { case Array(low, high) => - sparkSession.newDataset(encoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => builder.getSampleBuilder .setInput(sortedInput) .setWithReplacement(false) @@ -2114,7 +2123,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def randomSplit(weights: Array[Double]): Array[Dataset[T]] = { - randomSplit(weights, Utils.random.nextLong) + randomSplit(weights, SparkClassUtils.random.nextLong) } private def withColumns(names: Seq[String], values: Seq[Column]): DataFrame = { @@ -2396,6 +2405,20 @@ class Dataset[T] private[sql] ( .addAllColumnNames(cols.asJava) } + private def buildDropDuplicates( + columns: Option[Seq[String]], + withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { + builder => + val dropBuilder = builder.getDeduplicateBuilder + .setInput(plan.getRoot) + .setWithinWatermark(withinWaterMark) + if (columns.isDefined) { + dropBuilder.addAllColumnNames(columns.get.asJava) + } else { + dropBuilder.setAllColumnsAsKeys(true) + } + } + /** * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias * for `distinct`. 
@@ -2403,11 +2426,7 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder => - builder.getDeduplicateBuilder - .setInput(plan.getRoot) - .setAllColumnsAsKeys(true) - } + def dropDuplicates(): Dataset[T] = buildDropDuplicates(None, withinWaterMark = false) /** * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the @@ -2416,11 +2435,8 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) { - builder => - builder.getDeduplicateBuilder - .setInput(plan.getRoot) - .addAllColumnNames(colNames.asJava) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = { + buildDropDuplicates(Option(colNames), withinWaterMark = false) } /** @@ -2440,16 +2456,14 @@ class Dataset[T] private[sql] ( */ @scala.annotation.varargs def dropDuplicates(col1: String, cols: String*): Dataset[T] = { - val colNames: Seq[String] = col1 +: cols - dropDuplicates(colNames) + dropDuplicates(col1 +: cols) } - def dropDuplicatesWithinWatermark(): Dataset[T] = { - dropDuplicatesWithinWatermark(this.columns) - } + def dropDuplicatesWithinWatermark(): Dataset[T] = + buildDropDuplicates(None, withinWaterMark = true) def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = { - throw new UnsupportedOperationException("dropDuplicatesWithinWatermark is not implemented.") + buildDropDuplicates(Option(colNames), withinWaterMark = true) } def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = { @@ -2458,8 +2472,7 @@ class Dataset[T] private[sql] ( @scala.annotation.varargs def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = { - val colNames: Seq[String] = col1 +: cols - dropDuplicatesWithinWatermark(colNames) + dropDuplicatesWithinWatermark(col1 +: cols) } /** @@ -2624,9 +2637,9 @@ class Dataset[T] private[sql] ( def 
filter(func: T => Boolean): Dataset[T] = { val udf = ScalarUserDefinedFunction( function = func, - inputEncoders = encoder :: Nil, + inputEncoders = agnosticEncoder :: Nil, outputEncoder = PrimitiveBooleanEncoder) - sparkSession.newDataset[T](encoder) { builder => + sparkSession.newDataset[T](agnosticEncoder) { builder => builder.getFilterBuilder .setInput(plan.getRoot) .setCondition(udf.apply(col("*")).expr) @@ -2677,7 +2690,7 @@ class Dataset[T] private[sql] ( val outputEncoder = encoderFor[U] val udf = ScalarUserDefinedFunction( function = func, - inputEncoders = encoder :: Nil, + inputEncoders = agnosticEncoder :: Nil, outputEncoder = outputEncoder) sparkSession.newDataset(outputEncoder) { builder => builder.getMapPartitionsBuilder @@ -2718,6 +2731,74 @@ class Dataset[T] private[sql] ( flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder) } + /** + * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows + * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the + * input row are implicitly joined with each row that is output by the function. + * + * Given that this is deprecated, as an alternative, you can explode columns either using + * `functions.explode()` or `flatMap()`. 
The following example uses these alternatives to count + * the number of books that contain a given word: + * + * {{{ + * case class Book(title: String, words: String) + * val ds: Dataset[Book] + * + * val allWords = ds.select($"title", explode(split($"words", " ")).as("word")) + * + * val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title")) + * }}} + * + * Using `flatMap()` this can similarly be exploded as: + * + * {{{ + * ds.flatMap(_.words.split(" ")) + * }}} + * + * @group untypedrel + * @since 3.5.0 + */ + @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0") + def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + val generator = ScalarUserDefinedFunction( + UdfUtils.traversableOnceToSeq(f), + UnboundRowEncoder :: Nil, + ScalaReflection.encoderFor[Seq[A]]) + select(col("*"), functions.inline(generator(struct(input: _*)))) + } + + /** + * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or + * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the + * function. 
+ * + * Given that this is deprecated, as an alternative, you can explode columns either using + * `functions.explode()`: + * + * {{{ + * ds.select(explode(split($"words", " ")).as("word")) + * }}} + * + * or `flatMap()`: + * + * {{{ + * ds.flatMap(_.words.split(" ")) + * }}} + * + * @group untypedrel + * @since 3.5.0 + */ + @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0") + def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame = { + val generator = ScalarUserDefinedFunction( + UdfUtils.traversableOnceToSeq(f), + Nil, + ScalaReflection.encoderFor[Seq[B]]) + select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn))) + } + /** * Applies a function `f` to all rows. * @@ -2779,7 +2860,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def tail(n: Int): Array[T] = { - val lastN = sparkSession.newDataset(encoder) { builder => + val lastN = sparkSession.newDataset(agnosticEncoder) { builder => builder.getTailBuilder .setInput(plan.getRoot) .setLimit(n) @@ -2829,7 +2910,7 @@ class Dataset[T] private[sql] ( /** * Returns an iterator that contains all rows in this Dataset. * - * The returned iterator implements [[AutoCloseable]]. For memory management it is better to + * The returned iterator implements [[AutoCloseable]]. For resource management it is better to * close it once you are done. If you don't close it, it and the underlying data will be cleaned * up once the iterator is garbage collected. 
* @@ -2837,7 +2918,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def toLocalIterator(): java.util.Iterator[T] = { - collectResult().destructiveIterator + collectResult().destructiveIterator.asJava } /** @@ -2850,7 +2931,7 @@ class Dataset[T] private[sql] ( } private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = { - sparkSession.newDataset(encoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => builder.getRepartitionBuilder .setInput(plan.getRoot) .setNumPartitions(numPartitions) @@ -2860,11 +2941,12 @@ class Dataset[T] private[sql] ( private def buildRepartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder => - val repartitionBuilder = builder.getRepartitionByExpressionBuilder - .setInput(plan.getRoot) - .addAllPartitionExprs(partitionExprs.map(_.expr).asJava) - numPartitions.foreach(repartitionBuilder.setNumPartitions) + partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { + builder => + val repartitionBuilder = builder.getRepartitionByExpressionBuilder + .setInput(plan.getRoot) + .addAllPartitionExprs(partitionExprs.map(_.expr).asJava) + numPartitions.foreach(repartitionBuilder.setNumPartitions) } /** @@ -3177,7 +3259,7 @@ class Dataset[T] private[sql] ( * @since 3.5.0 */ def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = { - sparkSession.newDataset(encoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => builder.getWithWatermarkBuilder .setInput(plan.getRoot) .setEventTime(eventTime) @@ -3245,7 +3327,7 @@ class Dataset[T] private[sql] ( sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA) } - def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder) + def collectResult(): SparkResult[T] = sparkSession.execute(plan, agnosticEncoder) private[sql] def withResult[E](f: SparkResult[T] => E): E = { val result = 
collectResult() @@ -3254,4 +3336,10 @@ class Dataset[T] private[sql] ( result.close() } } + + /** + * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We + * null out the instance for now. + */ + private def writeReplace(): Any = null } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala new file mode 100644 index 0000000000000..74f0133803137 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.types.StructType + +/** + * Methods for creating an [[Encoder]]. + * + * @since 3.5.0 + */ +object Encoders { + + /** + * An encoder for nullable boolean type. 
The Scala primitive encoder is available as + * [[scalaBoolean]]. + * @since 3.5.0 + */ + def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder + + /** + * An encoder for nullable byte type. The Scala primitive encoder is available as [[scalaByte]]. + * @since 3.5.0 + */ + def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder + + /** + * An encoder for nullable short type. The Scala primitive encoder is available as + * [[scalaShort]]. + * @since 3.5.0 + */ + def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder + + /** + * An encoder for nullable int type. The Scala primitive encoder is available as [[scalaInt]]. + * @since 3.5.0 + */ + def INT: Encoder[java.lang.Integer] = BoxedIntEncoder + + /** + * An encoder for nullable long type. The Scala primitive encoder is available as [[scalaLong]]. + * @since 3.5.0 + */ + def LONG: Encoder[java.lang.Long] = BoxedLongEncoder + + /** + * An encoder for nullable float type. The Scala primitive encoder is available as + * [[scalaFloat]]. + * @since 3.5.0 + */ + def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder + + /** + * An encoder for nullable double type. The Scala primitive encoder is available as + * [[scalaDouble]]. + * @since 3.5.0 + */ + def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder + + /** + * An encoder for nullable string type. + * + * @since 3.5.0 + */ + def STRING: Encoder[java.lang.String] = StringEncoder + + /** + * An encoder for nullable decimal type. + * + * @since 3.5.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER + + /** + * An encoder for nullable date type. + * + * @since 3.5.0 + */ + def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false) + + /** + * Creates an encoder that serializes instances of the `java.time.LocalDate` class to the + * internal representation of nullable Catalyst's DateType. 
+ * + * @since 3.5.0 + */ + def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER + + /** + * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class to the + * internal representation of nullable Catalyst's TimestampNTZType. + * + * @since 3.5.0 + */ + def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder + + /** + * An encoder for nullable timestamp type. + * + * @since 3.5.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER + + /** + * Creates an encoder that serializes instances of the `java.time.Instant` class to the internal + * representation of nullable Catalyst's TimestampType. + * + * @since 3.5.0 + */ + def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER + + /** + * An encoder for arrays of bytes. + * + * @since 3.5.0 + */ + def BINARY: Encoder[Array[Byte]] = BinaryEncoder + + /** + * Creates an encoder that serializes instances of the `java.time.Duration` class to the + * internal representation of nullable Catalyst's DayTimeIntervalType. + * + * @since 3.5.0 + */ + def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder + + /** + * Creates an encoder that serializes instances of the `java.time.Period` class to the internal + * representation of nullable Catalyst's YearMonthIntervalType. + * + * @since 3.5.0 + */ + def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder + + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal, java.math.BigInteger + * - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant + * - collection types: array, java.util.List, and map + * - nested java bean. 
+ * + * @since 3.5.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass) + + /** + * Creates a [[Row]] encoder for schema `schema`. + * + * @since 3.5.0 + */ + def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema) + + private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = { + ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]] + } + + /** + * An encoder for 2-ary tuples. + * + * @since 3.5.0 + */ + def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = tupleEncoder(e1, e2) + + /** + * An encoder for 3-ary tuples. + * + * @since 3.5.0 + */ + def tuple[T1, T2, T3]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = tupleEncoder(e1, e2, e3) + + /** + * An encoder for 4-ary tuples. + * + * @since 3.5.0 + */ + def tuple[T1, T2, T3, T4]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = tupleEncoder(e1, e2, e3, e4) + + /** + * An encoder for 5-ary tuples. + * + * @since 3.5.0 + */ + def tuple[T1, T2, T3, T4, T5]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = tupleEncoder(e1, e2, e3, e4, e5) + + /** + * An encoder for Scala's product type (tuples, case classes, etc). + * @since 3.5.0 + */ + def product[T <: Product: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] + + /** + * An encoder for Scala's primitive int type. + * @since 3.5.0 + */ + def scalaInt: Encoder[Int] = PrimitiveIntEncoder + + /** + * An encoder for Scala's primitive long type. + * @since 3.5.0 + */ + def scalaLong: Encoder[Long] = PrimitiveLongEncoder + + /** + * An encoder for Scala's primitive double type. + * @since 3.5.0 + */ + def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder + + /** + * An encoder for Scala's primitive float type. 
+ * @since 3.5.0 + */ + def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder + + /** + * An encoder for Scala's primitive byte type. + * @since 3.5.0 + */ + def scalaByte: Encoder[Byte] = PrimitiveByteEncoder + + /** + * An encoder for Scala's primitive short type. + * @since 3.5.0 + */ + def scalaShort: Encoder[Short] = PrimitiveShortEncoder + + /** + * An encoder for Scala's primitive boolean type. + * @since 3.5.0 + */ + def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index e67ef1c0fa7e2..88c8b6a4f8bad 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -979,6 +979,12 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( outputEncoder = outputEncoder) udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction } + + /** + * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the + * server side. We null out the instance for now. 
+ */ + private def writeReplace(): Any = null } private object KeyValueGroupedDatasetImpl { @@ -988,15 +994,15 @@ private object KeyValueGroupedDatasetImpl { groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = { val gf = ScalarUserDefinedFunction( function = groupingFunc, - inputEncoders = ds.encoder :: Nil, // Using the original value and key encoders + inputEncoders = ds.agnosticEncoder :: Nil, // Using the original value and key encoders outputEncoder = kEncoder) new KeyValueGroupedDatasetImpl( ds.sparkSession, ds.plan, kEncoder, kEncoder, - ds.encoder, - ds.encoder, + ds.agnosticEncoder, + ds.agnosticEncoder, Arrays.asList(gf.apply(col("*")).expr), UdfUtils.identical(), () => ds.map(groupingFunc)(kEncoder)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index fb9959c994289..421f37b9e8a62 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.connect.client.util.Cleaner import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto -import org.apache.spark.sql.internal.CatalogImpl +import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager 
import org.apache.spark.sql.types.StructType @@ -126,17 +126,14 @@ class SparkSession private[sql] ( private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = { newDataset(encoder) { builder => if (data.nonEmpty) { - val timeZoneId = conf.get("spark.sql.session.timeZone") val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId) - if (arrowData.size() <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) { + if (arrowData.size() <= conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) { builder.getLocalRelationBuilder .setSchema(encoder.schema.json) .setData(arrowData) } else { val hash = client.cacheLocalRelation(arrowData, encoder.schema.json) builder.getCachedLocalRelationBuilder - .setUserId(client.userId) - .setSessionId(client.sessionId) .setHash(hash) } } else { @@ -253,9 +250,10 @@ class SparkSession private[sql] ( .setSql(sqlText) .addAllPosArgs(args.map(toLiteralProto).toIterable.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseIter = client.execute(plan.build()) + // .toBuffer forces that the iterator is consumed and closed + val responseSeq = client.execute(plan.build()).toBuffer.toSeq - val response = responseIter.asScala + val response = responseSeq .find(_.hasSqlCommandResult) .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) @@ -309,9 +307,10 @@ class SparkSession private[sql] ( .setSql(sqlText) .putAllArgs(args.asScala.mapValues(toLiteralProto).toMap.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) - val responseIter = client.execute(plan.build()) + // .toBuffer forces that the iterator is consumed and closed + val responseSeq = client.execute(plan.build()).toBuffer.toSeq - val response = responseIter.asScala + val response = responseSeq .find(_.hasSqlCommandResult) .getOrElse(throw new RuntimeException("SQLCommandResult must be present")) @@ -529,9 +528,11 @@ class SparkSession private[sql] ( 
client.semanticHash(plan).getSemanticHash.getResult } + private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY) + private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = { val value = client.execute(plan) - val result = new SparkResult(value, allocator, encoder) + val result = new SparkResult(value, allocator, encoder, timeZoneId) cleaner.register(result) result } @@ -541,19 +542,19 @@ class SparkSession private[sql] ( f(builder) builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement()) val plan = proto.Plan.newBuilder().setRoot(builder).build() - client.execute(plan).asScala.foreach(_ => ()) + // .toBuffer forces that the iterator is consumed and closed + client.execute(plan).toBuffer } private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = { val plan = proto.Plan.newBuilder().setCommand(command).build() - client.execute(plan).asScala.toSeq + // .toBuffer forces that the iterator is consumed and closed + client.execute(plan).toBuffer.toSeq } private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { val command = proto.Command.newBuilder().setRegisterFunction(udf).build() - val plan = proto.Plan.newBuilder().setCommand(command).build() - - client.execute(plan) + execute(command) } @DeveloperApi @@ -613,14 +614,40 @@ class SparkSession private[sql] ( /** * Interrupt all operations of this session currently running on the connected server. * - * TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from - * ExecutePlan requests. If an operation is not running a Spark Job, it becomes an noop and the - * operation will continue afterwards, possibly with more Spark Jobs. + * @return + * sequence of operationIds of interrupted operations. Note: there is still a possibility of + * operation finishing just as it is interrupted. 
+ * + * @since 3.5.0 + */ + def interruptAll(): Seq[String] = { + client.interruptAll().getInterruptedIdsList.asScala.toSeq + } + + /** + * Interrupt all operations of this session with the given operation tag. + * + * @return + * sequence of operationIds of interrupted operations. Note: there is still a possibility of + * operation finishing just as it is interrupted. + * + * @since 3.5.0 + */ + def interruptTag(tag: String): Seq[String] = { + client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq + } + + /** + * Interrupt an operation of this session with the given operationId. + * + * @return + * sequence of operationIds of interrupted operations. Note: there is still a possibility of + * operation finishing just as it is interrupted. * * @since 3.5.0 */ - def interruptAll(): Unit = { - client.interruptAll() + def interruptOperation(operationId: String): Seq[String] = { + client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } /** @@ -641,6 +668,56 @@ class SparkSession private[sql] ( allocator.close() SparkSession.onSessionClose(this) } + + /** + * Add a tag to be assigned to all the operations started by this thread in this session. + * + * @param tag + * The tag to be added. Cannot contain ',' (comma) character or be an empty string. + * + * @since 3.5.0 + */ + def addTag(tag: String): Unit = { + client.addTag(tag) + } + + /** + * Remove a tag previously added to be assigned to all the operations started by this thread in + * this session. Noop if such a tag was not added earlier. + * + * @param tag + * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + * + * @since 3.5.0 + */ + def removeTag(tag: String): Unit = { + client.removeTag(tag) + } + + /** + * Get the tags that are currently set to be assigned to all the operations started by this + * thread. + * + * @since 3.5.0 + */ + def getTags(): Set[String] = { + client.getTags() + } + + /** + * Clear the current thread's operation tags. 
+ * + * @since 3.5.0 + */ + def clearTags(): Unit = { + client.clearTags() + } + + /** + * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. + * We null out the instance for now. + */ + private def writeReplace(): Any = null } // The minimal builder needed to create a spark session. @@ -657,6 +734,23 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet. + */ + private def setDefaultAndActiveSession(session: SparkSession): Unit = { + defaultSession.compareAndSet(null, session) + if (getActiveSession.isEmpty) { + setActiveSession(session) + } + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -669,8 +763,17 @@ object SparkSession extends Logging { */ private[sql] def onSessionClose(session: SparkSession): Unit = { sessions.invalidate(session.client.configuration) + defaultSession.compareAndSet(session, null) + if (getActiveSession.contains(session)) { + clearActiveSession() + } } + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 3.4.0 + */ def builder(): Builder = new Builder() private[sql] lazy val cleaner = { @@ -680,8 +783,12 @@ object SparkSession extends Logging { } class Builder() extends Logging { - private val builder = SparkConnectClient.builder() + // Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE + // by default, if it exists. The connection string can be overridden using + // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable. 
+ private val builder = SparkConnectClient.builder().loadFromEnvironment() private var client: SparkConnectClient = _ + private[this] val options = new scala.collection.mutable.HashMap[String, String] def remote(connectionString: String): Builder = { builder.connectionString(connectionString) @@ -705,6 +812,84 @@ object SparkSession extends Logging { this } + /** + * Sets a config option. Options set using this method are automatically propagated to the + * Spark Connect session. Only runtime options are supported. + * + * @since 3.5.0 + */ + def config(key: String, value: String): Builder = synchronized { + options += key -> value + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to the + * Spark Connect session. Only runtime options are supported. + * + * @since 3.5.0 + */ + def config(key: String, value: Long): Builder = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to the + * Spark Connect session. Only runtime options are supported. + * + * @since 3.5.0 + */ + def config(key: String, value: Double): Builder = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to the + * Spark Connect session. Only runtime options are supported. + * + * @since 3.5.0 + */ + def config(key: String, value: Boolean): Builder = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a map of config options. Options set using this method are automatically propagated + * to the Spark Connect session. Only runtime options are supported. + * + * @since 3.5.0 + */ + def config(map: Map[String, Any]): Builder = synchronized { + map.foreach { kv: (String, Any) => + { + options += kv._1 -> kv._2.toString + } + } + this + } + + /** + * Sets a map of config options. 
Options set using this method are automatically propagated + * to the Spark Connect session. Only runtime options are supported. + * + * @since 3.5.0 + */ + def config(map: java.util.Map[String, Any]): Builder = synchronized { + config(map.asScala.toMap) + } + + @deprecated("enableHiveSupport does not work in Spark Connect") + def enableHiveSupport(): Builder = this + + @deprecated("master does not work in Spark Connect, please use remote instead") + def master(master: String): Builder = this + + @deprecated("appName does not work in Spark Connect") + def appName(name: String): Builder = this + private def tryCreateSessionFromClient(): Option[SparkSession] = { if (client != null) { Option(new SparkSession(client, cleaner, planIdGenerator)) @@ -713,6 +898,12 @@ object SparkSession extends Logging { } } + private def applyOptions(session: SparkSession): Unit = { + options.foreach { case (key, value) => + session.conf.set(key, value) + } + } + /** * Build the [[SparkSession]]. * @@ -726,10 +917,16 @@ object SparkSession extends Logging { * * This will always return a newly created session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def create(): SparkSession = { - tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(SparkSession.this.create(builder.configuration)) + setDefaultAndActiveSession(session) + applyOptions(session) + session } /** @@ -738,30 +935,82 @@ object SparkSession extends Logging { * If a session exist with the same configuration that is returned instead of creating a new * session. * + * This method will update the default and/or active session if they are not set. This method + * will always set the specified configuration options on the session, even when it is not + * newly created. 
+ * * @since 3.5.0 */ def getOrCreate(): SparkSession = { - tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(sessions.get(builder.configuration)) + setDefaultAndActiveSession(session) + applyOptions(session) + session } } - def getActiveSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getActiveSession is not supported") + /** + * Returns the default SparkSession. + * + * @since 3.5.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + + /** + * Sets the default SparkSession. + * + * @since 3.5.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) } - def getDefaultSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getDefaultSession is not supported") + /** + * Clears the default SparkSession. + * + * @since 3.5.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) } + /** + * Returns the active SparkSession for the current thread. + * + * @since 3.5.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * an isolated SparkSession. + * + * @since 3.5.0 + */ def setActiveSession(session: SparkSession): Unit = { - throw new UnsupportedOperationException("setActiveSession is not supported") + activeThreadSession.set(session) } + /** + * Clears the active SparkSession for current thread. + * + * @since 3.5.0 + */ def clearActiveSession(): Unit = { - throw new UnsupportedOperationException("clearActiveSession is not supported") + activeThreadSession.remove() } + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. 
+ * + * @since 3.5.0 + */ def active: SparkSession = { - throw new UnsupportedOperationException("active is not supported") + getActiveSession + .orElse(getDefaultSession) + .getOrElse(throw new IllegalStateException("No active or default Spark session found")) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 426709b8f1867..2e8211a0966e7 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -20,7 +20,10 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.internal.Logging +import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction} +import org.apache.spark.sql.types.DataType /** * Functions for registering user-defined functions. Use `SparkSession.udf` to access this: @@ -1024,5 +1027,268 @@ class UDFRegistration(session: SparkSession) extends Logging { typeTag[A22]) register(name, udf) } + + // (0 to 22).foreach { i => + // val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + // val version = "3.5.0" + // println(s""" + // |/** + // | * Register a deterministic Java UDF$i instance as user-defined function (UDF). + // | * @since $version + // | */ + // |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { + // | val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + // | register(name, udf) + // |}""".stripMargin) + // } + + /** + * Register a deterministic Java UDF0 instance as user-defined function (UDF). 
+ * @since 3.5.0 + */ + def register(name: String, f: UDF0[_], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF1 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF2 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF3 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF4 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF5 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF6 instance as user-defined function (UDF). 
+ * @since 3.5.0 + */ + def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF7 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF8 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF9 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF9[_, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF10 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF10[_, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF11 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF12 instance as user-defined function (UDF). 
+ * @since 3.5.0 + */ + def register( + name: String, + f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF13 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF14 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF15 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF16 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF17 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF18 instance as user-defined function (UDF). 
+ * @since 3.5.0 + */ + def register( + name: String, + f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF19 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF20 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF21 instance as user-defined function (UDF). + * @since 3.5.0 + */ + def register( + name: String, + f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } + + /** + * Register a deterministic Java UDF22 instance as user-defined function (UDF). 
+ * @since 3.5.0 + */ + def register( + name: String, + f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): Unit = { + val udf = ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + register(name, udf) + } // scalastyle:on line.size.limit } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index e6ada566398c7..0360a40578869 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -22,7 +22,8 @@ import java.util.concurrent.Semaphore import scala.util.control.NonFatal import ammonite.compiler.CodeClassWrapper -import ammonite.util.Bind +import ammonite.compiler.iface.CodeWrapper +import ammonite.util.{Bind, Imports, Name, Util} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SparkSession @@ -94,8 +95,8 @@ object ConnectRepl { val main = ammonite.Main( welcomeBanner = Option(splash), predefCode = predefCode, - replCodeWrapper = CodeClassWrapper, - scriptCodeWrapper = CodeClassWrapper, + replCodeWrapper = ExtendedCodeClassWrapper, + scriptCodeWrapper = ExtendedCodeClassWrapper, inputStream = inputStream, outputStream = outputStream, errorStream = errorStream) @@ -107,3 +108,25 @@ object ConnectRepl { } } } + +/** + * [[CodeWrapper]] that makes sure new Helper classes are always registered as an outer scope. 
+ */ +@DeveloperApi +object ExtendedCodeClassWrapper extends CodeWrapper { + override def wrapperPath: Seq[Name] = CodeClassWrapper.wrapperPath + override def apply( + code: String, + source: Util.CodeSource, + imports: Imports, + printCode: String, + indexedWrapper: Name, + extraCode: String): (String, String, Int) = { + val (top, bottom, level) = + CodeClassWrapper(code, source, imports, printCode, indexedWrapper, extraCode) + // Make sure we register the Helper before anything else, so outer scopes work as expected. + val augmentedTop = top + + "\norg.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)\n" + (augmentedTop, bottom, level) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala new file mode 100644 index 0000000000000..4ebc22202b0b7 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client + +import java.net.URL +import java.nio.file.Paths + +import ammonite.repl.api.Session +import ammonite.runtime.SpecialClassLoader + +/** + * A special [[ClassFinder]] for the Ammonite REPL to handle in-memory class files. + * + * @param session the Ammonite session whose frames' class loaders are inspected + */ +class AmmoniteClassFinder(session: Session) extends ClassFinder { + + override def findClasses(): Iterator[Artifact] = { + session.frames.iterator.flatMap { frame => + val classloader = frame.classloader.asInstanceOf[SpecialClassLoader] + val signatures: Seq[(Either[String, URL], Long)] = classloader.classpathSignature + signatures.iterator.collect { case (Left(name), _) => + val parts = name.split('.') + parts(parts.length - 1) += ".class" // last name segment becomes the .class file name + val path = Paths.get(parts.head, parts.tail: _*) + val bytes = classloader.newFileDict(name) + Artifact.newClassArtifact(path, new Artifact.InMemory(bytes)) + } + } + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala deleted file mode 100644 index 1a42ec821d84f..0000000000000 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.connect.client - -import io.grpc.StatusRuntimeException -import io.grpc.protobuf.StatusProto - -import org.apache.spark.{SparkException, SparkThrowable} - -private[client] object GrpcExceptionConverter { - def convert[T](f: => T): T = { - try { - f - } catch { - case e: StatusRuntimeException => - throw toSparkThrowable(e) - } - } - - def convertIterator[T](iter: java.util.Iterator[T]): java.util.Iterator[T] = { - new java.util.Iterator[T] { - override def hasNext: Boolean = { - convert { - iter.hasNext - } - } - - override def next(): T = { - convert { - iter.next() - } - } - } - } - - private def toSparkThrowable(ex: StatusRuntimeException): SparkThrowable with Throwable = { - val status = StatusProto.fromThrowable(ex) - // TODO: Add finer grained error conversion - new SparkException(status.getMessage, ex.getCause) - } -} - - diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala deleted file mode 100644 index a727c86f70fc6..0000000000000 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.connect.client - -import java.util.Collections - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.FieldVector -import org.apache.arrow.vector.ipc.ArrowStreamReader - -import org.apache.spark.connect.proto -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer -import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable} -import org.apache.spark.sql.connect.common.DataTypeProtoConverter -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} - -private[sql] class SparkResult[T]( - responses: java.util.Iterator[proto.ExecutePlanResponse], - allocator: BufferAllocator, - encoder: AgnosticEncoder[T]) - extends AutoCloseable - with Cleanable { - - private[this] var numRecords: Int = 0 - private[this] var structType: StructType = _ - private[this] var boundEncoder: ExpressionEncoder[T] 
= _ - private[this] var nextBatchIndex: Int = 0 - private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch] - - private def createEncoder(schema: StructType): ExpressionEncoder[T] = { - val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]] - ExpressionEncoder(agnosticEncoder) - } - - /** - * Update RowEncoder and recursively update the fields of the ProductEncoder if found. - */ - private def createEncoder(enc: AgnosticEncoder[_], dataType: DataType): AgnosticEncoder[_] = { - enc match { - case UnboundRowEncoder => - // Replace the row encoder with the encoder inferred from the schema. - RowEncoder.encoderFor(dataType.asInstanceOf[StructType]) - case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) => - // Recursively continue updating the tuple product encoder - val schema = dataType.asInstanceOf[StructType] - assert(fields.length <= schema.fields.length) - val updatedFields = fields.zipWithIndex.map { case (f, id) => - f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType)) - } - ProductEncoder(clsTag, updatedFields) - case _ => - enc - } - } - - private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = { - while (responses.hasNext) { - val response = responses.next() - if (response.hasSchema) { - // The original schema should arrive before ArrowBatches. - structType = - DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType] - } else if (response.hasArrowBatch) { - val ipcStreamBytes = response.getArrowBatch.getData - val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator) - try { - val root = reader.getVectorSchemaRoot - if (structType == null) { - // If the schema is not available yet, fallback to the schema from Arrow. - structType = ArrowUtils.fromArrowSchema(root.getSchema) - } - // TODO: create encoders that directly operate on arrow vectors. 
- if (boundEncoder == null) { - boundEncoder = createEncoder(structType) - .resolveAndBind(DataTypeUtils.toAttributes(structType)) - } - while (reader.loadNextBatch()) { - val rowCount = root.getRowCount - if (rowCount > 0) { - val vectors = root.getFieldVectors.asScala - .map(v => new ArrowColumnVector(transferToNewVector(v))) - .toArray[ColumnVector] - idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount)) - nextBatchIndex += 1 - numRecords += rowCount - if (stopOnFirstNonEmptyResponse) { - return true - } - } - } - } finally { - reader.close() - } - } - } - false - } - - private def transferToNewVector(in: FieldVector): FieldVector = { - val pair = in.getTransferPair(allocator) - pair.transfer() - pair.getTo.asInstanceOf[FieldVector] - } - - /** - * Returns the number of elements in the result. - */ - def length: Int = { - // We need to process all responses to make sure numRecords is correct. - processResponses(stopOnFirstNonEmptyResponse = false) - numRecords - } - - /** - * @return - * the schema of the result. - */ - def schema: StructType = { - processResponses(stopOnFirstNonEmptyResponse = true) - structType - } - - /** - * Create an Array with the contents of the result. - */ - def toArray: Array[T] = { - val result = encoder.clsTag.newArray(length) - val rows = iterator - var i = 0 - while (rows.hasNext) { - result(i) = rows.next() - assert(i < numRecords) - i += 1 - } - result - } - - /** - * Returns an iterator over the contents of the result. - */ - def iterator: java.util.Iterator[T] with AutoCloseable = - buildIterator(destructive = false) - - /** - * Returns an destructive iterator over the contents of the result. 
- */ - def destructiveIterator: java.util.Iterator[T] with AutoCloseable = - buildIterator(destructive = true) - - private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = { - new java.util.Iterator[T] with AutoCloseable { - private[this] var batchIndex: Int = -1 - private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator() - private[this] var deserializer: Deserializer[T] = _ - - override def hasNext: Boolean = { - if (iterator.hasNext) { - return true - } - - val nextBatchIndex = batchIndex + 1 - if (destructive) { - idxToBatches.remove(batchIndex).foreach(_.close()) - } - - val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) { - processResponses(stopOnFirstNonEmptyResponse = true) - } else { - true - } - if (hasNextBatch) { - batchIndex = nextBatchIndex - iterator = idxToBatches(nextBatchIndex).rowIterator() - if (deserializer == null) { - deserializer = boundEncoder.createDeserializer() - } - } - hasNextBatch - } - - override def next(): T = { - if (!hasNext) { - throw new NoSuchElementException - } - deserializer(iterator.next()) - } - - override def close(): Unit = SparkResult.this.close() - } - } - - /** - * Close this result, freeing any underlying resources. - */ - override def close(): Unit = { - idxToBatches.values.foreach(_.close()) - } - - override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq) -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala new file mode 100644 index 0000000000000..9a5fda1189d2d --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +/** + * Class used to test stubbing. This needs to be in the main source tree, because this is not + * synced with the connect server during tests. + */ +case class ToStub(value: Long) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 18aef8a2e4cfd..e060dba0b7e42 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -18,15 +18,18 @@ package org.apache.spark.sql.expressions import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import com.google.protobuf.ByteString +import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import 
org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket} -import org.apache.spark.util.SparkSerDeUtils +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils} /** * A user-defined function. To create one, use the `udf` functions in `functions`. @@ -92,7 +95,7 @@ sealed abstract class UserDefinedFunction { /** * Holder class for a scalar user-defined function and it's input/output encoder(s). */ -case class ScalarUserDefinedFunction private ( +case class ScalarUserDefinedFunction private[sql] ( // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. serializedUdfPacket: Array[Byte], inputTypes: Seq[proto.DataType], @@ -143,6 +146,25 @@ case class ScalarUserDefinedFunction private ( } object ScalarUserDefinedFunction { + private val LAMBDA_DESERIALIZATION_ERR_MSG: String = + "cannot assign instance of java.lang.invoke.SerializedLambda to field" + + private def checkDeserializable(bytes: Array[Byte]): Unit = { + try { + SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader) + } catch { + case e: ClassCastException if e.getMessage.contains(LAMBDA_DESERIALIZATION_ERR_MSG) => + throw new SparkException( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized. " + + "This is very likely to be caused by the lambda function (the UDF) having a " + + "self-reference. 
This is not supported by java serialization.") + case NonFatal(e) => + throw new SparkException( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized.", + e) + } + } + private[sql] def apply( function: AnyRef, returnType: TypeTag[_], @@ -163,6 +185,7 @@ object ScalarUserDefinedFunction { outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = { val udfPacketBytes = SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) + checkDeserializable(udfPacketBytes) ScalarUserDefinedFunction( serializedUdfPacket = udfPacketBytes, inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType), @@ -171,4 +194,11 @@ object ScalarUserDefinedFunction { nullable = true, deterministic = true) } + + private[sql] def apply(function: AnyRef, returnType: DataType): ScalarUserDefinedFunction = { + ScalarUserDefinedFunction( + function = function, + inputEncoders = Seq.empty[AgnosticEncoder[_]], + outputEncoder = RowEncoder.encoderForDataType(returnType, lenient = false)) + } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 17d1cdca350d5..93762037ece52 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,8 +22,10 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.connect.proto +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder import org.apache.spark.sql.connect.common.LiteralValueProtoConverter._ +import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.types.{DataType, StructType} import 
org.apache.spark.sql.types.DataType.parseTypeWithFallback @@ -144,6 +146,7 @@ object functions { * @group normal_funcs * @since 3.5.0 */ + @scala.annotation.varargs def named_struct(cols: Column*): Column = Column.fn("named_struct", cols: _*) ////////////////////////////////////////////////////////////////////////////////////////////// @@ -589,6 +592,7 @@ object functions { * @group agg_funcs * @since 3.4.0 */ + @scala.annotation.varargs def grouping_id(cols: Column*): Column = Column.fn("grouping_id", cols: _*) /** @@ -604,6 +608,7 @@ object functions { * @group agg_funcs * @since 3.4.0 */ + @scala.annotation.varargs def grouping_id(colName: String, colNames: String*): Column = grouping_id((Seq(colName) ++ colNames).map(n => Column(n)): _*) @@ -985,7 +990,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def std(e: Column): Column = stddev(e) + def std(e: Column): Column = Column.fn("std", e) /** * Aggregate function: alias for `stddev_samp`. @@ -2335,7 +2340,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column, scale: Column): Column = ceil(e, scale) + def ceiling(e: Column, scale: Column): Column = Column.fn("ceiling", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2343,7 +2348,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column): Column = ceil(e) + def ceiling(e: Column): Column = Column.fn("ceiling", e) /** * Convert a number in a string column from one base to another. @@ -2622,7 +2627,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ln(e: Column): Column = log(e) + def ln(e: Column): Column = Column.fn("ln", e) /** * Computes the natural logarithm of the given value. @@ -2630,7 +2635,7 @@ object functions { * @group math_funcs * @since 3.4.0 */ - def log(e: Column): Column = Column.fn("log", e) + def log(e: Column): Column = ln(e) /** * Computes the natural logarithm of the given column. 
@@ -2798,7 +2803,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def power(l: Column, r: Column): Column = pow(l, r) + def power(l: Column, r: Column): Column = Column.fn("power", l, r) /** * Returns the positive value of dividend mod divisor. @@ -2935,7 +2940,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def sign(e: Column): Column = signum(e) + def sign(e: Column): Column = Column.fn("sign", e) /** * Computes the signum of the given value. @@ -3475,7 +3480,7 @@ object functions { mode: Column, padding: Column, aad: Column): Column = - Column.fn("aes_encrypt", input, key, mode, padding, aad) + Column.fn("aes_decrypt", input, key, mode, padding, aad) /** * Returns a decrypted value of `input`. @@ -3487,7 +3492,7 @@ object functions { * @since 3.5.0 */ def aes_decrypt(input: Column, key: Column, mode: Column, padding: Column): Column = - Column.fn("aes_encrypt", input, key, mode, padding) + Column.fn("aes_decrypt", input, key, mode, padding) /** * Returns a decrypted value of `input`. @@ -3499,7 +3504,7 @@ object functions { * @since 3.5.0 */ def aes_decrypt(input: Column, key: Column, mode: Column): Column = - Column.fn("aes_encrypt", input, key, mode) + Column.fn("aes_decrypt", input, key, mode) /** * Returns a decrypted value of `input`. 
@@ -3511,7 +3516,7 @@ object functions { * @since 3.5.0 */ def aes_decrypt(input: Column, key: Column): Column = - Column.fn("aes_encrypt", input, key) + Column.fn("aes_decrypt", input, key) /** * This is a special version of `aes_decrypt` that performs the same operation, but returns a @@ -3609,6 +3614,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ + @scala.annotation.varargs def reflect(cols: Column*): Column = Column.fn("reflect", cols: _*) /** @@ -3617,6 +3623,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ + @scala.annotation.varargs def java_method(cols: Column*): Column = Column.fn("java_method", cols: _*) /** @@ -3643,6 +3650,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ + @scala.annotation.varargs def stack(cols: Column*): Column = Column.fn("stack", cols: _*) /** @@ -4417,8 +4425,9 @@ object functions { * @group string_funcs * @since 3.5.0 */ + @scala.annotation.varargs def printf(format: Column, arguments: Column*): Column = - Column.fn("format_string", lit(format) +: arguments: _*) + Column.fn("printf", (format +: arguments): _*) /** * Decodes a `str` in 'application/x-www-form-urlencoded' format using a specific encoding @@ -7903,8 +7912,197 @@ object functions { typeTag[A9], typeTag[A10]) } + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. 
To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). The caller must specify the + * output data type, and there is no automatic input type coercion. By default the returned UDF + * is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.5.0 + */ + def udf( + f: UDF10[_, _, _, _, _, _, _, _, _, _, _], + returnType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(UdfUtils.wrap(f), returnType) + } // scalastyle:off line.size.limit + /** + * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, + * the caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * Note that, although the Scala closure can have primitive-type function argument, it doesn't + * work well with null values. Because the Scala closure is passed in as Any type, there is no + * type information for the function arguments. 
Without the type information, Spark may blindly + * pass null to the Scala closure with primitive-type argument, and the closure will see the + * default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`, + * the result is 0 for null input. + * + * @param f + * A closure in Scala + * @param dataType + * The output data type of the UDF + * + * @group udf_funcs + * @since 3.5.0 + */ + @deprecated( + "Scala `udf` method with return type parameter is deprecated. " + + "Please use Scala `udf` method without return type parameter.", + "3.0.0") + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(f, dataType) + } + + /** + * Call an user-defined function. + * + * @group udf_funcs + * @since 3.5.0 + */ + @scala.annotation.varargs + @deprecated("Use call_udf") + def callUDF(udfName: String, cols: Column*): Column = + call_function(udfName, cols: _*) + /** * Call an user-defined function. Example: * {{{ @@ -7923,15 +8121,19 @@ object functions { def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) /** - * Call a builtin or temp function. + * Call a SQL function. 
* * @param funcName - * function name + * function name that follows the SQL identifier syntax (can be quoted, can be qualified) * @param cols * the expression parameters of function * @since 3.5.0 */ @scala.annotation.varargs - def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*) + def call_function(funcName: String, cols: Column*): Column = Column { builder => + builder.getCallFunctionBuilder + .setFunctionName(funcName) + .addAllArguments(cols.map(_.expr).asJava) + } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 57ce013065e2b..293490928a278 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -16,19 +16,16 @@ */ package org.apache.spark.sql.protobuf -import java.io.File import java.io.FileNotFoundException -import java.nio.file.NoSuchFileException +import java.nio.file.{Files, NoSuchFileException, Paths} import java.util.Collections import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark.annotation.Experimental import org.apache.spark.sql.Column -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.CompilationErrors import org.apache.spark.sql.functions.{fnWithOptions, lit} // scalastyle:off: object.name @@ -309,13 +306,13 @@ object functions { // This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils private def readDescriptorFileContent(filePath: String): Array[Byte] = { try { - FileUtils.readFileToByteArray(new File(filePath)) + Files.readAllBytes(Paths.get(filePath)) } catch { case ex: FileNotFoundException => - throw 
QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex) + throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex) case ex: NoSuchFileException => - throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex) - case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex) + throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex) + case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex) } } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ad76ab4a1bc66..54eb6e761407c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -25,20 +25,19 @@ import scala.collection.JavaConverters._ import com.google.protobuf.ByteString import org.apache.spark.annotation.Evolving +import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.WriteStreamOperationStart import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.connect.common.DataTypeProtoConverter -import org.apache.spark.sql.connect.common.ForeachWriterPacket +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, UdfUtils} import org.apache.spark.sql.execution.streaming.AvailableNowTrigger import org.apache.spark.sql.execution.streaming.ContinuousTrigger import org.apache.spark.sql.execution.streaming.OneTimeTrigger import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger import org.apache.spark.sql.types.NullType import org.apache.spark.util.SparkSerDeUtils -import org.apache.spark.util.Utils /** * 
Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -217,7 +216,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { * @since 3.5.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { - val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.encoder)) + val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder)) val scalaWriterBuilder = proto.ScalarScalaUDF .newBuilder() .setPayload(ByteString.copyFrom(serialized)) @@ -240,7 +239,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { */ @Evolving def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { - val serializedFn = Utils.serialize(function) + val serializedFn = SparkSerDeUtils.serialize(function) sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder .setPayload(ByteString.copyFrom(serializedFn)) .setOutputType(DataTypeProtoConverter.toConnectProtoType(NullType)) // Unused. @@ -248,6 +247,24 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The + * batchId can be used to deduplicate and transactionally write the output (that is, the + * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the + * same for the same batchId (assuming all operations are deterministic in the query). 
+ * + * @since 3.5.0 + */ + @Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { + foreachBatch(UdfUtils.foreachBatchFuncToScalaFunc(function)) + } + /** * Starts the execution of the streaming query, which will continually output results to the * given path as new data arrives. The returned [[StreamingQuery]] object can be used to diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 8cef421becd11..404bd1b078ba4 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -25,7 +25,6 @@ import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc} import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.annotation.Evolving -import org.apache.spark.scheduler.SparkListenerEvent /** * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. @@ -76,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable { def onQueryTerminated(event: QueryTerminatedEvent): Unit } -/** - * Py4J allows a pure interface so this proxy is required. 
- */ -private[spark] trait PythonStreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit - - def onQueryProgress(event: QueryProgressEvent): Unit - - def onQueryIdle(event: QueryIdleEvent): Unit - - def onQueryTerminated(event: QueryTerminatedEvent): Unit -} - -private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener) - extends StreamingQueryListener { - import StreamingQueryListener._ - - def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event) - - def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event) - - override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event) - - def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event) -} - /** * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 3.5.0 @@ -116,7 +87,7 @@ object StreamingQueryListener extends Serializable { * @since 3.5.0 */ @Evolving - trait Event extends SparkListenerEvent + trait Event /** * Event representing the start of a query diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 8f9e768d23f7d..d16638e594599 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -31,7 +31,7 @@ import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.common.{InvalidPlanInput, StreamingListenerPacket} -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkSerDeUtils /** * A class to 
manage all the [[StreamingQuery]] active in a `SparkSession`. @@ -155,8 +155,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo cacheListenerById(id, listener) executeManagerCmd( _.getAddListenerBuilder - .setListenerPayload(ByteString.copyFrom(Utils - .serialize(StreamingListenerPacket(id, listener))))) + .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils + .serialize(StreamingListenerPacket(id, listener)))) + .setId(id)) } /** @@ -168,8 +169,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo val id = getIdByListener(listener) executeManagerCmd( _.getRemoveListenerBuilder - .setListenerPayload(ByteString.copyFrom(Utils - .serialize(StreamingListenerPacket(id, listener))))) + .setId(id)) removeCachedListener(id) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 123b3306f2a76..593311efb9c93 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDoubleToJValue, safeMapToJValue} import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS -import org.apache.spark.sql.util.ToJsonUtil /** * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. 
@@ -190,9 +189,7 @@ class StreamingQueryProgress private[spark] ( ("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~ ("sources" -> JArray(sources.map(_.jsonValue).toList)) ~ ("sink" -> sink.jsonValue) ~ - ("observedMetrics" -> safeMapToJValue[Row]( - observedMetrics, - row => ToJsonUtil.jsonValue(row))) + ("observedMetrics" -> safeMapToJValue[Row](observedMetrics, row => row.jsonValue)) } } diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java new file mode 100644 index 0000000000000..d5fdede774f47 --- /dev/null +++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; + +import org.junit.*; +import static org.junit.Assert.*; + +import static org.apache.spark.sql.Encoders.*; +import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.RowFactory.create; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.test.SparkConnectServerUtils; +import org.apache.spark.sql.types.StructType; + +/** + * Tests for the encoders class. + */ +public class JavaEncoderSuite implements Serializable { + private static SparkSession spark; + + @BeforeClass + public static void setup() { + spark = SparkConnectServerUtils.createSparkSession(); + } + + @AfterClass + public static void tearDown() { + spark.stop(); + spark = null; + SparkConnectServerUtils.stop(); + } + + private static BigDecimal bigDec(long unscaled, int scale) { + return BigDecimal.valueOf(unscaled, scale); + } + + + private Dataset dataset(Encoder encoder, T... 
elements) { + return spark.createDataset(Arrays.asList(elements), encoder); + } + + @Test + public void testSimpleEncoders() { + final Column v = col("value"); + assertFalse( + dataset(BOOLEAN(), false, true, false).select(every(v)).as(BOOLEAN()).head()); + assertEquals( + 7L, + dataset(BYTE(), (byte) -120, (byte)127).select(sum(v)).as(LONG()).head().longValue()); + assertEquals( + (short) 16, + dataset(SHORT(), (short)16, (short)2334).select(min(v)).as(SHORT()).head().shortValue()); + assertEquals( + 10L, + dataset(INT(), 1, 2, 3, 4).select(sum(v)).as(LONG()).head().longValue()); + assertEquals( + 96L, + dataset(LONG(), 77L, 19L).select(sum(v)).as(LONG()).head().longValue()); + assertEquals( + 0.12f, + dataset(FLOAT(), 0.12f, 0.3f, 44f).select(min(v)).as(FLOAT()).head(), + 0.0001f); + assertEquals( + 789d, + dataset(DOUBLE(), 789d, 12.213d, 10.01d).select(max(v)).as(DOUBLE()).head(), + 0.0001f); + assertEquals( + bigDec(1002, 2), + dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2)) + .select(sum(v)).as(DECIMAL()).head().setScale(2)); + } + + @Test + public void testRowEncoder() { + final StructType schema = new StructType() + .add("a", "int") + .add("b", "string"); + final Dataset df = spark.range(3) + .map(new MapFunction() { + @Override + public Row call(Long i) { + return create(i.intValue(), "s" + i); + } + }, + Encoders.row(schema)) + .filter(col("a").geq(1)); + final List expected = Arrays.asList(create(1, "s1"), create(2, "s2")); + Assert.assertEquals(expected, df.collectAsList()); + } +} diff --git a/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala new file mode 100644 index 0000000000000..ff1b3deafafdc --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +// To generate a jar from the source file: +// `scalac StubClassDummyUdf.scala -d udf.jar` +// To remove class A from the jar: +// `jar -xvf udf.jar` -> delete A.class and A$.class +// `jar -cvf udf_noA.jar org/` +class StubClassDummyUdf { + val udf: Int => Int = (x: Int) => x + 1 + val dummy = (x: Int) => A(x) +} + +case class A(x: Int) { def get: Int = x + 5 } + +// The code to generate the udf file +object StubClassDummyUdf { + import java.io.{BufferedOutputStream, File, FileOutputStream} + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder + import org.apache.spark.sql.connect.common.UdfPacket + import org.apache.spark.util.Utils + + def packDummyUdf(): String = { + val byteArray = + Utils.serialize[UdfPacket]( + new UdfPacket( + new StubClassDummyUdf().udf, + Seq(PrimitiveIntEncoder), + PrimitiveIntEncoder + ) + ) + val file = new File("src/test/resources/udf") + val target = new BufferedOutputStream(new FileOutputStream(file)) + try { + target.write(byteArray) + file.getAbsolutePath + } finally { + target.close + } + } +} diff --git a/connector/connect/client/jvm/src/test/resources/log4j2.properties b/connector/connect/client/jvm/src/test/resources/log4j2.properties index 
ab02104c69697..550fd261b6fb5 100644 --- a/connector/connect/client/jvm/src/test/resources/log4j2.properties +++ b/connector/connect/client/jvm/src/test/resources/log4j2.properties @@ -32,7 +32,7 @@ appender.console.type = Console appender.console.name = console appender.console.target = SYSTEM_ERR appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %t: %m%n%ex +appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n # Ignore messages below warning level from Jetty, because it's a bit verbose logger.jetty.name = org.sparkproject.jetty diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12 b/connector/connect/client/jvm/src/test/resources/udf2.12 new file mode 100644 index 0000000000000..1090bc90d9b4b Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12 differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12.jar b/connector/connect/client/jvm/src/test/resources/udf2.12.jar new file mode 100644 index 0000000000000..6ce6799678f99 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12.jar differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13 b/connector/connect/client/jvm/src/test/resources/udf2.13 new file mode 100644 index 0000000000000..863ac32a76dc9 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13 differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13.jar b/connector/connect/client/jvm/src/test/resources/udf2.13.jar new file mode 100644 index 0000000000000..c89830f127c0c Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13.jar differ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 00a6bcc9b5c45..cefa63ecd353e 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -22,7 +22,7 @@ import java.io.{File, FilenameFilter} import org.apache.commons.io.FileUtils import org.apache.spark.SparkException -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.test.{RemoteSparkSession, SQLHelper} import org.apache.spark.sql.types.{DoubleType, LongType, StructType} import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(databasesWithPattern.length == 0) val database = spark.catalog.getDatabase(db) assert(database.name == db) - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getDatabase("notExists") }.getMessage assert(message.contains("SCHEMA_NOT_FOUND")) @@ -141,7 +141,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.listTables().collect().map(_.name).toSet == Set(parquetTableName)) } } - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getTable(parquetTableName) }.getMessage assert(message.contains("TABLE_OR_VIEW_NOT_FOUND")) @@ -207,7 +207,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.getFunction(absFunctionName).name == absFunctionName) val notExistsFunction = "notExists" assert(!spark.catalog.functionExists(notExistsFunction)) - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { spark.catalog.getFunction(notExistsFunction) }.getMessage assert(message.contains("UNRESOLVED_ROUTINE")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala similarity index 65% rename from 
connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 12683e54d989f..069d8ec502f52 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -21,10 +21,10 @@ import java.util.Random import org.scalatest.matchers.must.Matchers._ -import org.apache.spark.SparkException -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.{SparkException, SparkIllegalArgumentException} +import org.apache.spark.sql.test.RemoteSparkSession -class DataFrameStatSuite extends RemoteSparkSession { +class ClientDataFrameStatSuite extends RemoteSparkSession { private def toLetter(i: Int): String = (i + 97).toChar.toString test("approxQuantile") { @@ -87,7 +87,7 @@ class DataFrameStatSuite extends RemoteSparkSession { val results = df.stat.cov("singles", "doubles") assert(math.abs(results - 55.0 / 3) < 1e-12) - intercept[SparkException] { + intercept[SparkIllegalArgumentException] { df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes } val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b") @@ -176,4 +176,91 @@ class DataFrameStatSuite extends RemoteSparkSession { assert(sketch.relativeError() === 0.001) assert(sketch.confidence() === 0.99 +- 5e-3) } + + test("Bloom filter -- Long Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toLong) + val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767).map(_.toLong) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- Int Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997) + 
val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- Short Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toShort) + val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767).map(_.toShort) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- Byte Column") { + val session = spark + import session.implicits._ + val data = Seq(-32, -5, 1, 17, 39, 43, 101, 127).map(_.toByte) + val df = data.toDF("id") + val negativeValues = Seq(-101, 55, 113).map(_.toByte) + checkBloomFilter(data, negativeValues, df) + } + + test("Bloom filter -- String Column") { + val session = spark + import session.implicits._ + val data = Seq(-143, -32, -5, 1, 17, 39, 43, 101, 127, 997).map(_.toString) + val df = data.toDF("id") + val negativeValues = Seq(-11, 1021, 32767).map(_.toString) + checkBloomFilter(data, negativeValues, df) + } + + private def checkBloomFilter( + data: Seq[Any], + notContainValues: Seq[Any], + df: DataFrame): Unit = { + val filter1 = df.stat.bloomFilter("id", 1000, 0.03) + assert(filter1.expectedFpp() - 0.03 < 1e-3) + assert(data.forall(filter1.mightContain)) + assert(notContainValues.forall(n => !filter1.mightContain(n))) + val filter2 = df.stat.bloomFilter("id", 1000, 64 * 5) + assert(filter2.bitSize() == 64 * 5) + assert(data.forall(filter2.mightContain)) + assert(notContainValues.forall(n => !filter2.mightContain(n))) + } + + test("Bloom filter -- Wrong dataType Column") { + val session = spark + import session.implicits._ + val data = Range(0, 1000).map(_.toDouble) + val message = intercept[AnalysisException] { + data.toDF("id").stat.bloomFilter("id", 1000, 0.03) + }.getMessage + assert(message.contains("DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE")) + } + + test("Bloom filter test invalid inputs") { + val df = spark.range(1000).toDF("id") + val message1 = 
intercept[SparkException] { + df.stat.bloomFilter("id", -1000, 100) + }.getMessage + assert(message1.contains("Expected insertions must be positive")) + + val message2 = intercept[SparkException] { + df.stat.bloomFilter("id", 1000, -100) + }.getMessage + assert(message2.contains("Number of bits must be positive")) + + val message3 = intercept[SparkException] { + df.stat.bloomFilter("id", 1000, -1.0) + }.getMessage + assert(message3.contains("False positive probability must be within range (0.0, 1.0)")) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala similarity index 93% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index f70e9b4ce13e5..aab31d97e8c9d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -26,14 +26,15 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.util.SparkSerDeUtils // Add sample tests. 
// - sample fraction: simple.sample(0.1) // - sample withReplacement_fraction: simple.sample(withReplacement = true, 0.11) // Add tests for exceptions thrown -class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { +class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { private var server: Server = _ private var service: DummySparkConnectService = _ @@ -172,4 +173,11 @@ class DatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { val actualPlan = service.getAndClearLatestInputPlan() assert(actualPlan.equals(expectedPlan)) } + + test("serialize as null") { + val session = newSparkSession() + val ds = session.range(10) + val bytes = SparkSerDeUtils.serialize(ds) + assert(SparkSerDeUtils.deserialize[Dataset[Long]](bytes) == null) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 73c04389c0597..feefd19000d1d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -29,20 +29,114 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalactic.TolerantNumerics import org.scalatest.PrivateMethodTester -import org.apache.spark.{SPARK_VERSION, SparkException} +import org.apache.spark.{SparkArithmeticException, SparkException} +import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException import 
org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} -import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} -import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper} +import org.apache.spark.sql.test.SparkConnectServerUtils.port import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { + test("throw SparkArithmeticException") { + withSQLConf("spark.sql.ansi.enabled" -> "true") { + intercept[SparkArithmeticException] { + spark.sql("select 1/0").collect() + } + } + } + + test("throw NoSuchDatabaseException") { + intercept[NoSuchDatabaseException] { + spark.sql("use database123") + } + } + + test("throw NoSuchTableException") { + intercept[NoSuchTableException] { + spark.catalog.getTable("test_table") + } + } + + test("throw NamespaceAlreadyExistsException") { + try { + spark.sql("create database test_db") + intercept[NamespaceAlreadyExistsException] { + spark.sql("create database test_db") + } + } finally { + spark.sql("drop database test_db") + } + } + + test("throw TempTableAlreadyExistsException") { + try { + spark.sql("create temporary view test_view as select 1") + intercept[TempTableAlreadyExistsException] { + spark.sql("create temporary view test_view as select 1") + } + } finally { + spark.sql("drop view test_view") + } + } + + test("throw TableAlreadyExistsException") { + withTable("testcat.test_table") { + spark.sql(s"create table testcat.test_table (id int)") + intercept[TableAlreadyExistsException] { + spark.sql(s"create table testcat.test_table (id int)") + } + } + } + + test("throw ParseException") { + intercept[ParseException] { + 
spark.sql("selet 1").collect() + } + } + + test("spark deep recursion") { + var df = spark.range(1) + for (a <- 1 to 500) { + df = df.union(spark.range(a, a + 1)) + } + assert(df.collect().length == 501) + } + + test("handle unknown exception") { + var df = spark.range(1) + val limit = spark.conf.get("spark.connect.grpc.marshallerRecursionLimit").toInt + 1 + for (a <- 1 to limit) { + df = df.union(spark.range(a, a + 1)) + } + val ex = intercept[SparkException] { + df.collect() + } + assert(ex.getMessage.contains("io.grpc.StatusRuntimeException: UNKNOWN")) + } + + test("many tables") { + withSQLConf("spark.sql.execution.arrow.maxRecordsPerBatch" -> "10") { + val numTables = 20 + try { + for (i <- 0 to numTables) { + spark.sql(s"create table testcat.table${i} (id int)") + } + assert(spark.sql("show tables in testcat").collect().length == numTables + 1) + } finally { + for (i <- 0 to numTables) { + spark.sql(s"drop table if exists testcat.table${i}") + } + } + } + } + // Spark Result test("spark result schema") { val df = spark.sql("select val from (values ('Hello'), ('World')) as t(val)") @@ -64,7 +158,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assume(IntegrationTestUtils.isSparkHiveJarAvailable) withTable("test_martin") { // Fails, because table does not exist. - assertThrows[SparkException] { + assertThrows[AnalysisException] { spark.sql("select * from test_martin").collect() } // Execute eager, DML @@ -153,7 +247,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM StructField("job", StringType) :: Nil)) .csv(testDataPath.toString) // Failed because the path cannot be provided both via option and load method (csv). 
- assertThrows[SparkException] { + assertThrows[AnalysisException] { df.collect() } } @@ -357,7 +451,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM val df = spark.range(10) val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath // Failed because the path cannot be provided both via option and save method. - assertThrows[SparkException] { + assertThrows[AnalysisException] { df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString) } } @@ -571,7 +665,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM (col("id") / lit(10.0d)).as("b"), col("id"), lit("world").as("d"), - (col("id") % 2).cast("int").as("a")) + (col("id") % 2).as("a")) private def validateMyTypeResult(result: Array[MyType]): Unit = { result.zipWithIndex.foreach { case (MyType(id, a, b), i) => @@ -673,6 +767,64 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assert(joined2.schema.catalogString === "struct") } + test("SPARK-45509: ambiguous column reference") { + val session = spark + import session.implicits._ + val df1 = Seq(1 -> "a").toDF("i", "j") + val df1_filter = df1.filter(df1("i") > 0) + val df2 = Seq(2 -> "b").toDF("i", "y") + + checkSameResult( + Seq(Row(1)), + // df1("i") is not ambiguous, and it's still valid in the filtered df. + df1_filter.select(df1("i"))) + + val e1 = intercept[AnalysisException] { + // df1("i") is not ambiguous, but it's not valid in the projected df. + df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect() + } + assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT")) + + checkSameResult( + Seq(Row(1, "a")), + // All these column references are not ambiguous and are still valid after join. 
+ df1.join(df2, df1("i") + 1 === df2("i")).sort(df1("i").desc).select(df1("i"), df1("j"))) + + val e2 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides. + df1.join(df1, df1("i") === 1).collect() + } + assert(e2.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + val e3 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides. + df1.join(df1).select(df1("i")).collect() + } + assert(e3.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + val e4 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both join sides (df1_filter contains df1). + df1.join(df1_filter, df1("i") === 1).collect() + } + assert(e4.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + checkSameResult( + Seq(Row("a")), + // df1_filter("i") is not ambiguous as df1_filter does not exist in the join left side. + df1.join(df1_filter, df1_filter("i") === 1).select(df1_filter("j"))) + + val e5 = intercept[AnalysisException] { + // df1("i") is ambiguous as df1 appears in both sides of the first join. + df1.join(df1_filter, df1_filter("i") === 1).join(df2, df1("i") === 1).collect() + } + assert(e5.getMessage.contains("AMBIGUOUS_COLUMN_REFERENCE")) + + checkSameResult( + Seq(Row("a")), + // df1_filter("i") is not ambiguous as df1_filter only appears once. 
+ df1.join(df1_filter).join(df2, df1_filter("i") === 1).select(df1_filter("j"))) + } + test("broadcast join") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "-1") { val left = spark.range(100).select(col("id"), rand(10).as("a")) @@ -731,7 +883,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM private def checkSameResult[E](expected: scala.collection.Seq[E], dataset: Dataset[E]): Unit = { dataset.withResult { result => - assert(expected === result.iterator.asScala.toBuffer) + assert(expected === result.iterator.toBuffer) } } @@ -893,14 +1045,12 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM test("Dataset result destructive iterator") { // Helper methods for accessing private field `idxToBatches` from SparkResult - val _idxToBatches = - PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches")) - - def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = { - val idxToBatches = result invokePrivate _idxToBatches() + val getResultMap = + PrivateMethod[mutable.Map[Int, Any]](Symbol("resultMap")) - // Sort by key to get stable results. 
- idxToBatches.toSeq.sortBy(_._1).map(_._2) + def assertResultsMapEmpty(result: SparkResult[_]): Unit = { + val resultMap = result invokePrivate getResultMap() + assert(resultMap.isEmpty) } val df = spark @@ -911,25 +1061,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM try { // build and verify the destructive iterator val iterator = result.destructiveIterator - // batches is empty before traversing the result iterator - assert(getColumnarBatches(result).isEmpty) - var previousBatch: ColumnarBatch = null - val buffer = mutable.Buffer.empty[Long] + // resultMap Map is empty before traversing the result iterator + assertResultsMapEmpty(result) + val buffer = mutable.Set.empty[Long] while (iterator.hasNext) { - // always having 1 batch, since a columnar batch will be removed and closed after - // its data got consumed. - val batches = getColumnarBatches(result) - assert(batches.size === 1) - assert(batches.head != previousBatch) - previousBatch = batches.head - - buffer.append(iterator.next()) + // resultMap is empty during iteration because results get removed immediately on access. + assertResultsMapEmpty(result) + buffer += iterator.next() } - // Batches should be closed and removed after traversing all the records. - assert(getColumnarBatches(result).isEmpty) + // resultMap Map is empty afterward because all results have been removed. 
+ assertResultsMapEmpty(result) - val expectedResult = Seq(6L, 7L, 8L) - assert(buffer.size === 3 && expectedResult.forall(buffer.contains)) + val expectedResult = Set(6L, 7L, 8L) + assert(buffer.size === 3 && expectedResult == buffer) } finally { result.close() } @@ -938,7 +1082,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM test("SparkSession.createDataFrame - large data set") { val threshold = 1024 * 1024 - withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) { + withSQLConf(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY -> threshold.toString) { val count = 2 val suffix = "abcdef" val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString + suffix @@ -1161,6 +1305,26 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined) } + + test("dropDuplicatesWithinWatermark not supported in batch DataFrame") { + def testAndVerify(df: Dataset[_]): Unit = { + val exc = intercept[AnalysisException] { + df.write.format("noop").mode(SaveMode.Append).save() + } + + assert(exc.getMessage.contains("dropDuplicatesWithinWatermark is not supported")) + assert(exc.getMessage.contains("batch DataFrames/DataSets")) + } + + val result = spark.range(10).dropDuplicatesWithinWatermark() + testAndVerify(result) + + val result2 = spark + .range(10) + .withColumn("newcol", col("id")) + .dropDuplicatesWithinWatermark("newcol") + testAndVerify(result2) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index 0d361fe1007f7..a88d6ec116a42 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream import scala.collection.JavaConverters._ import org.apache.spark.sql.{functions => fn} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types._ /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala index c44b515bdedf4..393fa19fa70b4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ -import org.apache.spark.SparkException -import org.apache.spark.sql.connect.client.util.QueryTest -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SqlApiConf +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { @@ -279,7 +278,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { test("drop with col(*)") { val df = createDF() - val ex = intercept[SparkException] { + val ex = intercept[AnalysisException] { df.na.drop("any", Seq("*")).collect() } assert(ex.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION")) @@ -388,7 +387,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { } test("replace float with nan") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + withSQLConf(SqlApiConf.ANSI_ENABLED_KEY -> false.toString) { checkAnswer( createNaNDF().na.replace("*", Map(1.0f -> Float.NaN)), Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: @@ 
-397,7 +396,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { } test("replace double with nan") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + withSQLConf(SqlApiConf.ANSI_ENABLED_KEY -> false.toString) { checkAnswer( createNaNDF().na.replace("*", Map(1.0 -> Double.NaN)), Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 32004b6bcc11d..78cc26d627c7c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -21,9 +21,9 @@ import java.util.Collections import scala.collection.JavaConverters._ import org.apache.spark.sql.avro.{functions => avroFn} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.functions._ import org.apache.spark.sql.protobuf.{functions => pbFn} +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.{DataType, StructType} /** @@ -249,6 +249,8 @@ class FunctionTestSuite extends ConnectFunSuite { pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes(), Map.empty[String, String].asJava), pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes())) + testEquals("call_udf", callUDF("bob", lit(1)), call_udf("bob", lit(1))) + test("assert_true no message") { val e = assert_true(a).expr assert(e.hasUnresolvedFunction) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index e15069f2d9e96..98a947826e3de 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import java.sql.Timestamp import java.util.Arrays -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append -import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types._ +import org.apache.spark.util.SparkSerDeUtils case class ClickEvent(id: String, timestamp: Timestamp) @@ -68,10 +68,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("keyAs - keys") { + // TODO SPARK-44449 make this long again when upcasting is in. // It is okay to cast from Long to Double, but not Long to Int. val values = spark .range(10) - .groupByKey(v => v % 2) + .groupByKey(v => (v % 2).toDouble) .keyAs[Double] .keys .collectAsList() @@ -178,7 +179,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { assert(values == Arrays.asList[String]("0", "8,6,4,2,0", "1", "9,7,5,3,1")) // Star is not allowed as group sort column - val message = intercept[SparkException] { + val message = intercept[AnalysisException] { grouped .flatMapSortedGroups(col("*")) { (g, iter) => Iterator(String.valueOf(g), iter.mkString(",")) @@ -232,9 +233,10 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("agg, keyAs") { + // TODO SPARK-44449 make this long again when upcasting is in. 
val ds = spark .range(10) - .groupByKey(v => v % 2) + .groupByKey(v => (v % 2).toDouble) .keyAs[Double] .agg(count("*")) @@ -244,7 +246,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { test("typed aggregation: expr") { val session: SparkSession = spark import session.implicits._ - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long]), @@ -254,7 +257,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), @@ -264,7 +268,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), @@ -274,7 +279,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. 
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -289,7 +295,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -305,7 +312,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -322,7 +330,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -340,7 +349,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. 
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -474,7 +484,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { .toDF("key", "seq", "value") val grouped = ds.groupBy($"value").as[String, (String, Int, Int)] val keys = grouped.keyAs[String].keys.sort($"value") - checkDataset(keys, "1", "2", "10", "20") } @@ -622,6 +631,12 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { 30, 3) } + + test("serialize as null") { + val kvgds = session.range(10).groupByKey(_ % 2) + val bytes = SparkSerDeUtils.serialize(kvgds) + assert(SparkSerDeUtils.deserialize[KeyValueGroupedDataset[Long, Long]](bytes) == null) + } } case class K1(a: Long) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 7e4e0f24f4fea..97fa5d5fe53b4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -37,14 +37,13 @@ import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.connect.client.SparkConnectClient -import org.apache.spark.sql.connect.client.util.ConnectFunSuite -import org.apache.spark.sql.connect.client.util.IntegrationTestUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.{functions => pbFn} +import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.Utils +import 
org.apache.spark.util.SparkFileUtils // scalastyle:off /** @@ -131,7 +130,7 @@ class PlanGenerationTestSuite private def cleanOrphanedGoldenFile(): Unit = { val allTestNames = testNames.map(_.replace(' ', '_')) - val orphans = Utils + val orphans = SparkFileUtils .recursiveList(queryFilePath.toFile) .filter(g => g.getAbsolutePath.endsWith(".proto.bin") || @@ -139,7 +138,7 @@ class PlanGenerationTestSuite .filter(g => !allTestNames.contains(g.getName.stripSuffix(".proto.bin")) && !allTestNames.contains(g.getName.stripSuffix(".json"))) - orphans.foreach(Utils.deleteRecursively) + orphans.foreach(SparkFileUtils.deleteRecursively) } private def test(name: String)(f: => Dataset[_]): Unit = super.test(name) { @@ -3232,11 +3231,15 @@ class PlanGenerationTestSuite private val testDescFilePath: String = s"${IntegrationTestUtils.sparkHome}/connector/" + "connect/common/src/test/resources/protobuf-tests/common.desc" - test("from_protobuf messageClassName") { + // TODO(SPARK-45030): Re-enable this test when all Maven test scenarios succeed and there + // are no other negative impacts. For the problem description, please refer to SPARK-45029 + ignore("from_protobuf messageClassName") { binary.select(pbFn.from_protobuf(fn.col("bytes"), classOf[StorageLevel].getName)) } - test("from_protobuf messageClassName options") { + // TODO(SPARK-45030): Re-enable this test when all Maven test scenarios succeed and there + // are no other negative impacts. 
For the problem description, please refer to SPARK-45029 + ignore("from_protobuf messageClassName options") { binary.select( pbFn.from_protobuf( fn.col("bytes"), diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 3b5d7dae1b383..680380c91a0c2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -22,12 +22,13 @@ import java.time.temporal.ChronoUnit import java.util.concurrent.atomic.AtomicLong import io.grpc.inprocess.InProcessChannelBuilder +import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder} import org.apache.spark.sql.connect.client.SparkConnectClient -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} +import org.apache.spark.sql.test.ConnectFunSuite /** * Test suite for SQL implicits. 
@@ -54,12 +55,28 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { val spark = session import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { - val encoder = implicitly[Encoder[T]].asInstanceOf[AgnosticEncoder[T]] - val expressionEncoder = ExpressionEncoder(encoder).resolveAndBind() - val serializer = expressionEncoder.createSerializer() - val deserializer = expressionEncoder.createDeserializer() - val actual = deserializer(serializer(expected)) - assert(actual === expected) + val encoder = encoderFor[T] + val allocator = new RootAllocator() + try { + val batch = ArrowSerializer.serialize( + input = Iterator.single(expected), + enc = encoder, + allocator = allocator, + timeZoneId = "UTC") + val fromArrow = ArrowDeserializers.deserializeFromArrow( + input = Iterator.single(batch.toByteArray), + encoder = encoder, + allocator = allocator, + timeZoneId = "UTC") + try { + assert(fromArrow.next() === expected) + assert(!fromArrow.hasNext) + } finally { + fromArrow.close() + } + } finally { + allocator.close() + } } val booleans = Array(false, true, false, false) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 70eeb6c2c41df..e9c2f0c457508 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -16,6 +16,9 @@ */ package org.apache.spark.sql +import java.util.concurrent.ForkJoinPool + +import scala.collection.mutable import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration._ import scala.util.{Failure, Success} @@ -23,8 +26,8 @@ import scala.util.{Failure, Success} import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkException -import 
org.apache.spark.sql.connect.client.util.RemoteSparkSession -import org.apache.spark.util.ThreadUtils +import org.apache.spark.sql.test.RemoteSparkSession +import org.apache.spark.util.SparkThreadUtils.awaitResult /** * NOTE: Do not import classes that only exist in `spark-connect-client-jvm.jar` into the this @@ -64,13 +67,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession { } // 20 seconds is < 30 seconds the queries should be running, // because it should be interrupted sooner + val interrupted = mutable.ListBuffer[String]() eventually(timeout(20.seconds), interval(1.seconds)) { // keep interrupting every second, until both queries get interrupted. - spark.interruptAll() + val ids = spark.interruptAll() + interrupted ++= ids assert(error.isEmpty, s"Error not empty: $error") assert(q1Interrupted) assert(q2Interrupted) } + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } test("interrupt all - foreground queries, background interrupt") { @@ -79,9 +85,12 @@ class SparkSessionE2ESuite extends RemoteSparkSession { implicit val ec: ExecutionContextExecutor = ExecutionContext.global @volatile var finished = false + val interrupted = mutable.ListBuffer[String]() + val interruptor = Future { eventually(timeout(20.seconds), interval(1.seconds)) { - spark.interruptAll() + val ids = spark.interruptAll() + interrupted ++= ids assert(finished) } finished @@ -95,6 +104,196 @@ class SparkSessionE2ESuite extends RemoteSparkSession { } assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2") finished = true - assert(ThreadUtils.awaitResult(interruptor, 10.seconds)) + assert(awaitResult(interruptor, 10.seconds)) + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + + // TODO(SPARK-48139): Re-enable `SparkSessionE2ESuite.interrupt tag` + ignore("interrupt tag") { + val session = spark + import session.implicits._ + + // global ExecutionContext has only 2 threads in Apache Spark CI + // create 
own thread pool for four Futures used in this test + val numThreads = 4 + val fpool = new ForkJoinPool(numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + val q1 = Future { + assert(spark.getTags() == Set()) + spark.addTag("two") + assert(spark.getTags() == Set("two")) + spark.clearTags() // check that clearing all tags works + assert(spark.getTags() == Set()) + spark.addTag("one") + assert(spark.getTags() == Set("one")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val q2 = Future { + assert(spark.getTags() == Set()) + spark.addTag("one") + spark.addTag("two") + spark.addTag("one") + spark.addTag("two") // duplicates shouldn't matter + assert(spark.getTags() == Set("one", "two")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val q3 = Future { + assert(spark.getTags() == Set()) + spark.addTag("foo") + spark.removeTag("foo") + assert(spark.getTags() == Set()) // check that remove works removing the last tag + spark.addTag("two") + assert(spark.getTags() == Set("two")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val q4 = Future { + assert(spark.getTags() == Set()) + spark.addTag("one") + spark.addTag("two") + spark.addTag("two") + assert(spark.getTags() == Set("one", "two")) + spark.removeTag("two") // check that remove works, despite duplicate add + assert(spark.getTags() == Set("one")) + try { + spark + .range(10) + .map(n => { + Thread.sleep(30000); n + }) + .collect() + } finally { + spark.clearTags() // clear for the case of thread reuse by another Future + } + 
}(executionContext) + val interrupted = mutable.ListBuffer[String]() + + // q2 and q3 should be cancelled + interrupted.clear() + eventually(timeout(20.seconds), interval(1.seconds)) { + val ids = spark.interruptTag("two") + interrupted ++= ids + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + val e2 = intercept[SparkException] { + awaitResult(q2, 1.minute) + } + assert(e2.getCause.getMessage contains "OPERATION_CANCELED") + val e3 = intercept[SparkException] { + awaitResult(q3, 1.minute) + } + assert(e3.getCause.getMessage contains "OPERATION_CANCELED") + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + + // q1 and q4 should be cancelled + interrupted.clear() + eventually(timeout(20.seconds), interval(1.seconds)) { + val ids = spark.interruptTag("one") + interrupted ++= ids + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + val e1 = intercept[SparkException] { + awaitResult(q1, 1.minute) + } + assert(e1.getCause.getMessage contains "OPERATION_CANCELED") + val e4 = intercept[SparkException] { + awaitResult(q4, 1.minute) + } + assert(e4.getCause.getMessage contains "OPERATION_CANCELED") + assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") + } + + test("interrupt operation") { + val session = spark + import session.implicits._ + + val result = spark + .range(10) + .map(n => { + Thread.sleep(5000); n + }) + .collectResult() + // cancel + val operationId = result.operationId + val canceledId = spark.interruptOperation(operationId) + assert(canceledId == Seq(operationId)) + // and check that it got canceled + val e = intercept[SparkException] { + result.toArray + } + assert(e.getMessage contains "OPERATION_CANCELED") + } + + test("option propagation") { + val remote = s"sc://localhost:$serverPort" + val session1 = SparkSession + .builder() + .remote(remote) + .config("foo", 12L) + .config("bar", value = true) + .config("bob", 12.0) + .config("heading",
"north") + .getOrCreate() + assert(session1.conf.get("foo") == "12") + assert(session1.conf.get("bar") == "true") + assert(session1.conf.get("bob") == String.valueOf(12.0)) + assert(session1.conf.get("heading") == "north") + + // Check if new options are applied to an existing session. + val session2 = SparkSession + .builder() + .remote(remote) + .config("heading", "south") + .getOrCreate() + assert(session2 == session1) + assert(session2.conf.get("heading") == "south") + + // Create a completely different session, confs are not supposed to leak. + val session3 = SparkSession + .builder() + .remote(remote) + .config(Map("foo" -> "13", "baar" -> "false", "heading" -> "east")) + .create() + assert(session3 != session1) + assert(session3.conf.get("foo") == "13") + assert(session3.conf.get("baar") == "false") + assert(session3.conf.getOption("bob").isEmpty) + assert(session3.conf.get("heading") == "east") + + // Try to set a static conf. + intercept[Exception] { + SparkSession + .builder() + .remote(remote) + .config("spark.sql.globalTempDatabase", "not_gonna_happen") + .create() + } } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 97fb46bf48af4..4c858262c6ef5 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,14 +16,23 @@ */ package org.apache.spark.sql +import java.util.concurrent.{Executors, Phaser} + +import scala.util.control.NonFatal + import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.util.SparkSerDeUtils /** * Tests for non-dataframe related SparkSession operations.
*/ class SparkSessionSuite extends ConnectFunSuite { + private val connectionString1: String = "sc://test.it:17845" + private val connectionString2: String = "sc://test.me:14099" + private val connectionString3: String = "sc://doit:16845" + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") @@ -32,16 +41,15 @@ class SparkSessionSuite extends ConnectFunSuite { } test("remote") { - val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate() + val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) session.close() } test("getOrCreate") { - val connectionString = "sc://test.it:17865" - val session1 = SparkSession.builder().remote(connectionString).getOrCreate() - val session2 = SparkSession.builder().remote(connectionString).getOrCreate() + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + val session2 = SparkSession.builder().remote(connectionString1).getOrCreate() try { assert(session1 eq session2) } finally { @@ -51,9 +59,8 @@ class SparkSessionSuite extends ConnectFunSuite { } test("create") { - val connectionString = "sc://test.it:17845" - val session1 = SparkSession.builder().remote(connectionString).create() - val session2 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() try { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) @@ -64,8 +71,7 @@ class SparkSessionSuite extends ConnectFunSuite { } test("newSession") { - val connectionString = "sc://doit:16845" - val session1 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString3).create() val 
session2 = session1.newSession() try { assert(session1 ne session2) @@ -92,5 +98,174 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } + session.close() + } + + test("Default/Active session") { + // Make sure we start with a clean slate. + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + intercept[IllegalStateException](SparkSession.active) + + // Create a session + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + assert(SparkSession.active == session1) + + // Create another session... + val session2 = SparkSession.builder().remote(connectionString2).create() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Clear the default and the active session separately, verifying each one right after. + SparkSession.clearDefaultSession() + assert(SparkSession.getDefaultSession.isEmpty) + SparkSession.clearActiveSession() + assert(SparkSession.getActiveSession.isEmpty) + + // Flip sessions + SparkSession.setActiveSession(session1) + SparkSession.setDefaultSession(session2) + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.contains(session1)) + + // Close session1 + session1.close() + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.isEmpty) + + // Close session2 + session2.close() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + + test("active session in multiple threads") { + SparkSession.clearDefaultSession() +
SparkSession.clearActiveSession() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + val phaser = new Phaser(2) + val executor = Executors.newFixedThreadPool(2) + def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = { + executor.submit[Boolean] { () => + try { + block(phaser) + true + } catch { + case NonFatal(e) => + phaser.forceTermination() + throw e + } + } + } + + try { + val script1 = execute { phaser => + // Step 0 - check initial state + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Step 1 - new active session in script 2 + phaser.arriveAndAwaitAdvance() + + // Step2 - script 1 is unchanged, script 2 has new active session + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Step 3 - close session 1, no more default session in both scripts + phaser.arriveAndAwaitAdvance() + session1.close() + + // Step 4 - no default session, same active session. + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(session2)) + + // Step 5 - clear active session in script 1 + phaser.arriveAndAwaitAdvance() + SparkSession.clearActiveSession() + + // Step 6 - no default/no active session in script 1, script2 unchanged. 
+ phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + + // Step 7 - close active session in script2 + phaser.arriveAndAwaitAdvance() + } + val script2 = execute { phaser => + // Step 0 - check initial state + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Step 1 - new active session in script 2 + phaser.arriveAndAwaitAdvance() + SparkSession.clearActiveSession() + val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate() + + // Step2 - script 1 is unchanged, script 2 has new active session + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(internalSession)) + + // Step 3 - close session 1, no more default session in both scripts + phaser.arriveAndAwaitAdvance() + + // Step 4 - no default session, same active session. + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + // Step 5 - clear active session in script 1 + phaser.arriveAndAwaitAdvance() + + // Step 6 - no default/no active session in script 1, script2 unchanged. 
+ phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + // Step 7 - close active session in script2 + phaser.arriveAndAwaitAdvance() + internalSession.close() + assert(SparkSession.getActiveSession.isEmpty) + } + assert(script1.get()) + assert(script2.get()) + assert(SparkSession.getActiveSession.contains(session2)) + session2.close() + assert(SparkSession.getActiveSession.isEmpty) + } finally { + executor.shutdown() + } + } + + test("deprecated methods") { + SparkSession + .builder() + .master("yayay") + .appName("bob") + .enableHiveSupport() + .create() + .close() + } + + test("serialize as null") { + val session = SparkSession.builder().create() + val bytes = SparkSerDeUtils.serialize(session) + assert(SparkSerDeUtils.deserialize[SparkSession](bytes) == null) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala new file mode 100644 index 0000000000000..b9c5888e5cb77 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.sql.connect.client.ToStub +import org.apache.spark.sql.test.RemoteSparkSession + +class StubbingTestSuite extends RemoteSparkSession { + private def eval[T](f: => T): T = f + + test("capture of to-be stubbed class") { + val session = spark + import session.implicits._ + val result = spark + .range(0, 10, 1, 1) + .map(n => n + 1) + .as[ToStub] + .head() + eval { + assert(result.value == 1) + } + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala new file mode 100644 index 0000000000000..a76e046db2e3a --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import java.io.File +import java.nio.file.{Files, Paths} + +import scala.util.Properties + +import org.apache.spark.sql.connect.common.ProtoDataTypes +import org.apache.spark.sql.expressions.ScalarUserDefinedFunction +import org.apache.spark.sql.test.RemoteSparkSession + +class UDFClassLoadingE2ESuite extends RemoteSparkSession { + + private val scalaVersion = Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + + // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created. + private val udfByteArray: Array[Byte] = + Files.readAllBytes(Paths.get(s"src/test/resources/udf$scalaVersion")) + private val udfJar = + new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL + + private def registerUdf(session: SparkSession): Unit = { + val udf = ScalarUserDefinedFunction( + serializedUdfPacket = udfByteArray, + inputTypes = Seq(ProtoDataTypes.IntegerType), + outputType = ProtoDataTypes.IntegerType, + name = Some("dummyUdf"), + nullable = true, + deterministic = true) + session.registerUdf(udf.toProto) + } + + test("update class loader after stubbing: new session") { + // Session1 should stub the missing class, but fail to call methods on it + val session1 = spark.newSession() + + assert( + intercept[Exception] { + registerUdf(session1) + }.getMessage.contains( + "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf")) + + // Session2 uses the real class + val session2 = spark.newSession() + session2.addArtifact(udfJar.toURI) + registerUdf(session2) + } + + test("update class loader after stubbing: same session") { + // Session should stub the missing class, but fail to call methods on it + val session = spark.newSession() + + assert( + intercept[Exception] { + registerUdf(session) + }.getMessage.contains( + "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf")) + + // Session uses the real class + 
session.addArtifact(udfJar.toURI) + registerUdf(session) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index a4f1a61cf3997..fbc2c1c266262 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -23,11 +23,12 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ -import org.apache.spark.TaskContext import org.apache.spark.api.java.function._ +import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} -import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions.{col, struct, udf} +import org.apache.spark.sql.test.QueryTest +import org.apache.spark.sql.types.IntegerType /** * All tests in this class requires client UDF defined in this test class synced with the server. 
@@ -94,6 +95,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { rows.forEach(x => assert(x == 42)) } + test("(deprecated) Dataset explode") { + val session: SparkSession = spark + import session.implicits._ + val result1 = spark + .range(3) + .filter(col("id") =!= 1L) + .explode(col("id") + 41, col("id") + 10) { case Row(x: Long, y: Long) => + Iterator((x, x - 1), (y, y + 1)) + } + .as[(Long, Long, Long)] + .collect() + .toSeq + assert(result1 === Seq((0L, 41L, 40L), (0L, 10L, 11L), (2L, 43L, 42L), (2L, 12L, 13L))) + + val result2 = Seq((1, "a b c"), (2, "a b"), (3, "a")) + .toDF("number", "letters") + .explode('letters) { case Row(letters: String) => + letters.split(' ').map(Tuple1.apply).toSeq + } + .as[(Int, String, String)] + .collect() + .toSeq + assert( + result2 === Seq( + (1, "a b c", "a"), + (1, "a b c", "b"), + (1, "a b c", "c"), + (2, "a b", "a"), + (2, "a b", "b"), + (3, "a", "a"))) + + val result3 = Seq("a b c", "d e") + .toDF("words") + .explode("words", "word") { word: String => + word.split(' ').toSeq + } + .select(col("word")) + .as[String] + .collect() + .toSeq + assert(result3 === Seq("a", "b", "c", "d", "e")) + + val result4 = Seq("a b c", "d e") + .toDF("words") + .explode("words", "word") { word: String => + word.split(' ').map(s => s -> s.head.toInt).toSeq + } + .select(col("word"), col("words")) + .as[((String, Int), String)] + .collect() + .toSeq + assert( + result4 === Seq( + (("a", 97), "a b c"), + (("b", 98), "a b c"), + (("c", 99), "a b c"), + (("d", 100), "d e"), + (("e", 101), "d e"))) + } + test("Dataset typed flat map - java") { val rows = spark .range(5) @@ -154,39 +215,31 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { } test("Dataset foreachPartition") { - val sum = new AtomicLong() val func: Iterator[JLong] => Unit = f => { + val sum = new AtomicLong() f.foreach(v => sum.addAndGet(v)) - TaskContext - .get() - .addTaskCompletionListener(_ => - // The value should be 45 - assert(sum.get() == -1)) + throw 
new Exception("Success, processed records: " + sum.get()) } val exception = intercept[Exception] { spark.range(10).repartition(1).foreachPartition(func) } - assert(exception.getMessage.contains("45 did not equal -1")) + assert(exception.getMessage.contains("Success, processed records: 45")) } test("Dataset foreachPartition - java") { val sum = new AtomicLong() val exception = intercept[Exception] { spark - .range(10) + .range(11) .repartition(1) .foreachPartition(new ForeachPartitionFunction[JLong] { override def call(t: JIterator[JLong]): Unit = { t.asScala.foreach(v => sum.addAndGet(v)) - TaskContext - .get() - .addTaskCompletionListener(_ => - // The value should be 45 - assert(sum.get() == -1)) + throw new Exception("Success, processed records: " + sum.get()) } }) } - assert(exception.getMessage.contains("45 did not equal -1")) + assert(exception.getMessage.contains("Success, processed records: 55")) } test("Dataset foreach: change not visible to client") { @@ -257,4 +310,37 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { "b", "c") } + + test("(deprecated) scala UDF with dataType") { + val session: SparkSession = spark + import session.implicits._ + val fn = udf(((i: Long) => (i + 1).toInt), IntegerType) + checkDataset(session.range(2).select(fn($"id")).as[Int], 1, 2) + } + + test("java UDF") { + val session: SparkSession = spark + import session.implicits._ + val fn = udf( + new UDF2[Long, Long, Int] { + override def call(t1: Long, t2: Long): Int = (t1 + t2 + 1).toInt + }, + IntegerType) + checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5) + } + + test("nullified SparkSession/Dataset/KeyValueGroupedDataset in UDF") { + val session: SparkSession = spark + import session.implicits._ + val df = session.range(0, 10, 1, 1) + val kvgds = df.groupByKey(_ / 2) + val f = udf { (i: Long) => + assert(session == null) + assert(df == null) + assert(kvgds == null) + i + 1 + } + val result = df.select(f($"id")).as[Long].head + 
assert(result == 1L) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala index 1c4ee21773749..923aa5af75ba8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql import scala.reflect.runtime.universe.typeTag -import org.scalatest.BeforeAndAfterEach - +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.connect.common.UdfPacket import org.apache.spark.sql.functions.udf -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.util.SparkSerDeUtils -class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach { +class UserDefinedFunctionSuite extends ConnectFunSuite { test("udf and encoder serialization") { def func(x: Int): Int = x + 1 @@ -42,10 +41,49 @@ class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(udfObj.getNullable) - val deSer = Utils.deserialize[UdfPacket](udfObj.getPayload.toByteArray) + val deSer = SparkSerDeUtils.deserialize[UdfPacket](udfObj.getPayload.toByteArray) assert(deSer.function.asInstanceOf[Int => Int](5) == func(5)) assert(deSer.outputEncoder == ScalaReflection.encoderFor(typeTag[Int])) assert(deSer.inputEncoders == Seq(ScalaReflection.encoderFor(typeTag[Int]))) } + + private def testNonDeserializable(f: Int => Int): Unit = { + val e = intercept[SparkException](udf(f)) + assert( + e.getMessage.contains( + "UDF cannot be executed on a Spark cluster: it cannot be deserialized.")) + assert(e.getMessage.contains("This is not supported by java 
serialization.")) + } + + test("non deserializable UDFs") { + testNonDeserializable(Command2(Command1()).indirect) + testNonDeserializable(MultipleLambdas().indirect) + testNonDeserializable(SelfRef(22).method) + } + + test("serializable UDFs") { + val direct = (i: Int) => i + 1 + val indirect = (i: Int) => direct(i) + udf(indirect) + udf(Command1().direct) + udf(MultipleLambdas().direct) + } +} + +case class Command1() extends Serializable { + val direct: Int => Int = (i: Int) => i + 1 +} + +case class Command2(prev: Command1) extends Serializable { + val indirect: Int => Int = (i: Int) => prev.direct(i) +} + +case class SelfRef(start: Int) extends Serializable { + val method: Int => Int = (i: Int) => i + start +} + +case class MultipleLambdas() extends Serializable { + val direct: Int => Int = (i: Int) => i + 1 + val indirect: Int => Int = (i: Int) => direct(i) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 58758a1384031..9d61b4d56e1ed 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -23,10 +23,9 @@ import java.util.concurrent.{Executors, Semaphore, TimeUnit} import scala.util.Properties import org.apache.commons.io.output.ByteArrayOutputStream -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} +import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession} class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { @@ -51,29 +50,26 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } override def beforeAll(): Unit = { - // TODO(SPARK-44121) Remove 
this check condition - if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) { - super.beforeAll() - ammoniteOut = new ByteArrayOutputStream() - testSuiteOut = new PipedOutputStream() - // Connect the `testSuiteOut` and `ammoniteIn` pipes - ammoniteIn = new PipedInputStream(testSuiteOut) - errorStream = new ByteArrayOutputStream() - - val args = Array("--port", serverPort.toString) - val task = new Runnable { - override def run(): Unit = { - ConnectRepl.doMain( - args = args, - semaphore = Some(semaphore), - inputStream = ammoniteIn, - outputStream = ammoniteOut, - errorStream = errorStream) - } - } + super.beforeAll() + ammoniteOut = new ByteArrayOutputStream() + testSuiteOut = new PipedOutputStream() + // Connect the `testSuiteOut` and `ammoniteIn` pipes + ammoniteIn = new PipedInputStream(testSuiteOut) + errorStream = new ByteArrayOutputStream() - executorService.submit(task) + val args = Array("--port", serverPort.toString) + val task = new Runnable { + override def run(): Unit = { + ConnectRepl.doMain( + args = args, + semaphore = Some(semaphore), + inputStream = ammoniteIn, + outputStream = ammoniteOut, + errorStream = errorStream) + } } + + executorService.submit(task) } override def afterAll(): Unit = { @@ -86,6 +82,7 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } def runCommandsInShell(input: String): String = { + ammoniteOut.reset() require(input.nonEmpty) // Pad the input with a semaphore release so that we know when the execution of the provided // input is complete. 
@@ -106,6 +103,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { getCleanString(ammoniteOut) } + def runCommandsUsingSingleCellInShell(input: String): String = { + runCommandsInShell("{\n" + input + "\n}") + } + def assertContains(message: String, output: String): Unit = { val isContain = output.contains(message) assert( @@ -134,20 +135,6 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { assertContains("Array[Int] = Array(19, 24, 29, 34, 39)", output) } - // SPARK-43198: Switching REPL to CodeClass generation mode causes UDFs defined through lambda - // expressions to hit deserialization issues. - // TODO(SPARK-43227): Enable test after fixing deserialization issue. - ignore("UDF containing lambda expression") { - val input = """ - |class A(x: Int) { def get = x * 20 + 5 } - |val dummyUdf = (x: Int) => new A(x).get - |val myUdf = udf(dummyUdf) - |spark.range(5).select(myUdf(col("id"))).as[Int].collect() - """.stripMargin - val output = runCommandsInShell(input) - assertContains("Array[Int] = Array(5, 25, 45, 65, 85)", output) - } - test("UDF containing in-place lambda") { val input = """ |class A(x: Int) { def get = x * 42 + 5 } @@ -207,6 +194,36 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { // scalastyle:on classforname line.size.limit } + test("Java UDF") { + val input = + """ + |import org.apache.spark.sql.api.java._ + |import org.apache.spark.sql.types.LongType + | + |val javaUdf = udf(new UDF1[Long, Long] { + | override def call(num: Long): Long = num * num + 25L + |}, LongType).asNondeterministic() + |spark.range(5).select(javaUdf(col("id"))).as[Long].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Long] = Array(25L, 26L, 29L, 34L, 41L)", output) + } + + test("Java UDF Registration") { + val input = + """ + |import org.apache.spark.sql.api.java._ + |import org.apache.spark.sql.types.LongType + | + |spark.udf.register("javaUdf", new 
UDF1[Long, Long] { + | override def call(num: Long): Long = num * num * num + 250L + |}, LongType) + |spark.sql("select javaUdf(id) from range(5)").as[Long].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Long] = Array(250L, 251L, 258L, 277L, 314L)", output) + } + test("UDF Registration") { val input = """ |class A(x: Int) { def get = x * 100 } @@ -237,4 +254,128 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { val output = runCommandsInShell(input) assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output) } + + test("call_function") { + val input = """ + |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + |spark.udf.register("simpleUDF", (v: Int) => v * v) + |df.select($"id", call_function("simpleUDF", $"value")).collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output) + } + + test("Single Cell Compilation") { + val input = + """ + |case class C1(value: Int) + |case class C2(value: Int) + |val h1 = classOf[C1].getDeclaringClass + |val h2 = classOf[C2].getDeclaringClass + |val same = h1 == h2 + |""".stripMargin + assertContains("same: Boolean = false", runCommandsInShell(input)) + assertContains("same: Boolean = true", runCommandsUsingSingleCellInShell(input)) + } + + test("Local relation containing REPL generated class") { + val input = + """ + |case class MyTestClass(value: Int) + |val data = (0 to 10).map(MyTestClass) + |spark.createDataset(data).map(mtc => mtc.value).select(sum($"value")).as[Long].head + |""".stripMargin + val expected = "Long = 55L" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) + } + + test("Collect REPL generated class") { + val input = + """ + |case class MyTestClass(value: Int) + |spark.range(4). + | filter($"id" % 2 === 1). 
+ | select($"id".cast("int").as("value")). + | as[MyTestClass]. + | collect(). + | map(mtc => s"MyTestClass(${mtc.value})"). + | mkString("[", ", ", "]") + """.stripMargin + val expected = """String = "[MyTestClass(1), MyTestClass(3)]"""" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) + } + + test("REPL class in encoder") { + val input = """ + |case class MyTestClass(value: Int) + |spark.range(3). + | select(col("id").cast("int").as("value")). + | as[MyTestClass]. + | map(mtc => mtc.value). + | collect() + """.stripMargin + val expected = "Array[Int] = Array(0, 1, 2)" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) + } + + test("REPL class in UDF") { + val input = """ + |case class MyTestClass(value: Int) + |spark.range(2). + | map(i => MyTestClass(i.toInt)). + | collect(). + | map(mtc => s"MyTestClass(${mtc.value})"). + | mkString("[", ", ", "]") + """.stripMargin + val expected = """String = "[MyTestClass(0), MyTestClass(1)]"""" + assertContains(expected, runCommandsInShell(input)) + assertContains(expected, runCommandsUsingSingleCellInShell(input)) + } + + test("streaming works with REPL generated code") { + val input = + """ + |val add1 = udf((i: Long) => i + 1) + |val query = { + | spark.readStream + | .format("rate") + | .option("rowsPerSecond", "10") + | .option("numPartitions", "1") + | .load() + | .withColumn("value", add1($"value")) + | .writeStream + | .format("memory") + | .queryName("my_sink") + | .start() + |} + |var progress = query.lastProgress + |while (query.isActive && (progress == null || progress.numInputRows == 0)) { + | query.awaitTermination(100) + | progress = query.lastProgress + |} + |val noException = query.exception.isEmpty + |query.stop() + |""".stripMargin + val output = runCommandsInShell(input) + assertContains("noException: Boolean = true", output) + } + + test("broadcast works with 
REPL generated code") { + val input = + """ + |val add1 = udf((i: Long) => i + 1) + |val tableA = spark.range(2).alias("a") + |val tableB = broadcast(spark.range(2).select(add1(col("id")).alias("id"))).alias("b") + |tableA.join(tableB). + | where(col("a.id")===col("b.id")). + | select(col("a.id").alias("a_id"), col("b.id").alias("b_id")). + | collect(). + | mkString("[", ", ", "]") + |""".stripMargin + val output = runCommandsInShell(input) + assertContains("""String = "[[1,1]]"""", output) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index a15d7562f19e1..770143f2e9b4e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -28,9 +28,9 @@ import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.scalatest.BeforeAndAfterEach -import org.apache.spark.connect.proto import org.apache.spark.connect.proto.AddArtifactsRequest -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration +import org.apache.spark.sql.test.ConnectFunSuite class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach { @@ -57,7 +57,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach { retryPolicy = GrpcRetryHandler.RetryPolicy() bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy) stub = new CustomSparkConnectStub(channel, retryPolicy) - artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), "", bstub, stub) + artifactManager = new ArtifactManager(Configuration(), "", bstub, stub) } override def beforeEach(): Unit = { diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 70db03df7bdd8..0cc1a44b27327 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -24,7 +24,8 @@ import java.util.regex.Pattern import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.lib.MiMaLib -import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._ +import org.apache.spark.SparkBuildInfo.spark_version +import org.apache.spark.sql.test.IntegrationTestUtils._ /** * A tool for checking the binary compatibility of the connect client API against the spark SQL @@ -46,18 +47,38 @@ object CheckConnectJvmClientCompatibility { sys.env("SPARK_HOME") } + private val sqlJar = { + val path = Paths.get( + sparkHome, + "sql", + "core", + "target", + "scala-" + scalaVersion, + "spark-sql_" + scalaVersion + "-" + spark_version + ".jar") + assert(Files.exists(path), s"$path does not exist") + path.toFile + } + + private val clientJar = { + val path = Paths.get( + sparkHome, + "connector", + "connect", + "client", + "jvm", + "target", + "scala-" + scalaVersion, + "spark-connect-client-jvm_" + scalaVersion + "-" + spark_version + ".jar") + assert(Files.exists(path), s"$path does not exist") + path.toFile + } + def main(args: Array[String]): Unit = { var resultWriter: Writer = null try { resultWriter = Files.newBufferedWriter( Paths.get(s"$sparkHome/.connect-mima-check-result"), StandardCharsets.UTF_8) - val clientJar: File = - findJar( - "connector/connect/client/jvm", - "spark-connect-client-jvm-assembly", - "spark-connect-client-jvm") - val sqlJar: File = findJar("sql/core", "spark-sql", 
"spark-sql") val problemsWithSqlModule = checkMiMaCompatibilityWithSqlModule(clientJar, sqlJar) appendMimaCheckErrorMessageIfNeeded( resultWriter, @@ -163,9 +184,6 @@ object CheckConnectJvmClientCompatibility { // DataFrameNaFunctions ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameNaFunctions.fillValue"), - // DataFrameStatFunctions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameStatFunctions.bloomFilter"), - // Dataset ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.Dataset$" // private[sql] @@ -181,19 +199,14 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected - ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), // deprecated ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"), // functions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.call_udf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), @@ -208,24 +221,18 @@ object CheckConnectJvmClientCompatibility { 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sqlContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udf"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"), ProblemFilters.exclude[Problem]( "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), - // TODO(SPARK-44068): Support positional parameters in Scala connect client - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"), // SparkSession#implicits @@ -233,14 +240,8 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.SparkSession#implicits._sqlContext"), // SparkSession#Builder - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.appName"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.config"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.master"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession#Builder.enableHiveSupport"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.withExtensions"), @@ -250,9 +251,6 @@ object CheckConnectJvmClientCompatibility { // DataStreamWriter ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.DataStreamWriter$"), - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.streaming.DataStreamWriter.foreachBatch" // TODO(SPARK-42944) - ), ProblemFilters.exclude[Problem]( "org.apache.spark.sql.streaming.DataStreamWriter.SOURCE*" // These are constant vals. ), @@ -270,8 +268,6 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.StreamingQueryException.time"), // Classes missing from streaming API - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.TestGroupState"), ProblemFilters.exclude[MissingClassProblem]( @@ -280,6 +276,24 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.PythonStreamingQueryListener"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.PythonStreamingQueryListenerWrapper"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener$Event"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryIdleEvent"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener#QueryIdleEvent.logEvent"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgressEvent.logEvent"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.logEvent"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminatedEvent.logEvent"), // SQLImplicits ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"), @@ -311,6 +325,10 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.SQLImplicits._sqlContext" // protected ), + // Catalyst Refactoring + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils$"), + // New public APIs added in the client // ScalarUserDefinedFunction ProblemFilters @@ -323,8 +341,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.Dataset.plan" ), // developer API - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.sql.Dataset.encoder"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.Dataset.collectResult"), @@ -342,6 +358,12 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.application.ConnectRepl$" // developer API ), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.application.ExtendedCodeClassWrapper" // developer API + ), + ProblemFilters.exclude[MissingClassProblem]( + 
"org.apache.spark.sql.application.ExtendedCodeClassWrapper$" // developer API + ), // SparkSession // developer API @@ -361,6 +383,18 @@ object CheckConnectJvmClientCompatibility { // public ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.interruptAll"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.interruptTag"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.interruptOperation"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.addTag"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.removeTag"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.getTags"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.clearTags"), // SparkSession#Builder ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.remote"), @@ -384,7 +418,11 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.RemoteStreamingQuery"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.RemoteStreamingQuery$")) + "org.apache.spark.sql.streaming.RemoteStreamingQuery$"), + + // Encoders are in the wrong JAR + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala index c9066615bb572..ca23436675f87 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala @@ -20,15 +20,15 @@ import java.nio.file.Paths import org.apache.commons.io.FileUtils -import org.apache.spark.sql.connect.client.util.ConnectFunSuite -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.util.SparkFileUtils class ClassFinderSuite extends ConnectFunSuite { private val classResourcePath = commonResourcePath.resolve("artifact-tests") test("REPLClassDirMonitor functionality test") { - val copyDir = Utils.createTempDir().toPath + val copyDir = SparkFileUtils.createTempDir().toPath FileUtils.copyDirectory(classResourcePath.toFile, copyDir.toFile) val monitor = new REPLClassDirMonitor(copyDir.toAbsolutePath.toString) @@ -47,7 +47,7 @@ class ClassFinderSuite extends ConnectFunSuite { checkClasses(monitor) // Add new class file into directory - val subDir = Utils.createTempDir(copyDir.toAbsolutePath.toString) + val subDir = SparkFileUtils.createTempDir(copyDir.toAbsolutePath.toString) val classToCopy = copyDir.resolve("Hello.class") val copyLocation = subDir.toPath.resolve("HelloDup.class") FileUtils.copyFile(classToCopy.toFile, copyLocation.toFile) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala index 2c6886d0386c5..e1d4a18d0ff60 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.sql.connect.client -import 
org.apache.spark.sql.connect.client.util.ConnectFunSuite +import java.util.UUID + +import org.apache.spark.sql.test.ConnectFunSuite /** * Test suite for [[SparkConnectClient.Builder]] parsing and configuration. @@ -46,6 +48,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite { argumentTest("user_id", "U1238", _.userId.get) argumentTest("user_name", "alice", _.userName.get) argumentTest("user_agent", "MY APP", _.userAgent) + argumentTest("session_id", UUID.randomUUID().toString, _.sessionId.get) test("Argument - remote") { val builder = @@ -55,6 +58,7 @@ class SparkConnectClientBuilderParseTestSuite extends ConnectFunSuite { assert(builder.token.contains("nahnah")) assert(builder.userId.contains("x127")) assert(builder.options === Map(("user_name", "Q"), ("param1", "x"))) + assert(builder.sessionId.isEmpty) } test("Argument - use_ssl") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 50182f1517816..89acc2c60ac21 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.connect.client +import java.util.UUID import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -30,8 +31,8 @@ import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.client.util.ConnectFunSuite import 
org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.test.ConnectFunSuite class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { @@ -85,6 +86,24 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(response.getSessionId === "abc123") } + private def withEnvs(pairs: (String, String)*)(f: => Unit): Unit = { + val readonlyEnv = System.getenv() + val field = readonlyEnv.getClass.getDeclaredField("m") + field.setAccessible(true) + val modifiableEnv = field.get(readonlyEnv).asInstanceOf[java.util.Map[String, String]] + try { + for ((k, v) <- pairs) { + assert(!modifiableEnv.containsKey(k)) + modifiableEnv.put(k, v) + } + f + } finally { + for ((k, _) <- pairs) { + modifiableEnv.remove(k) + } + } + } + test("Test connection") { testClientConnection() { testPort => SparkConnectClient.builder().port(testPort).build() } } @@ -111,6 +130,49 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } + test("SparkSession create with SPARK_REMOTE") { + startDummyServer(0) + + withEnvs("SPARK_REMOTE" -> s"sc://localhost:${server.getPort}") { + val session = SparkSession.builder().create() + val df = session.range(10) + df.analyze // Trigger RPC + assert(df.plan === service.getAndClearLatestInputPlan()) + + val session2 = SparkSession.builder().create() + assert(session != session2) + } + } + + test("SparkSession getOrCreate with SPARK_REMOTE") { + startDummyServer(0) + + withEnvs("SPARK_REMOTE" -> s"sc://localhost:${server.getPort}") { + val session = SparkSession.builder().getOrCreate() + + val df = session.range(10) + df.analyze // Trigger RPC + assert(df.plan === service.getAndClearLatestInputPlan()) + + val session2 = SparkSession.builder().getOrCreate() + assert(session === session2) + } + } + + test("Builder.remote takes precedence over SPARK_REMOTE") { + startDummyServer(0) + val incorrectUrl = s"sc://localhost:${server.getPort + 1}" + + 
withEnvs("SPARK_REMOTE" -> incorrectUrl) { + val session = + SparkSession.builder().remote(s"sc://localhost:${server.getPort}").getOrCreate() + + val df = session.range(10) + df.analyze // Trigger RPC + assert(df.plan === service.getAndClearLatestInputPlan()) + } + } + test("SparkSession initialisation with connection string") { startDummyServer(0) client = SparkConnectClient @@ -163,6 +225,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { client => { assert(client.configuration.host == "localhost") assert(client.configuration.port == 1234) + assert(client.sessionId != null) + // Must be able to parse the UUID + assert(UUID.fromString(client.sessionId) != null) }), TestPackURI( "sc://localhost/;", @@ -192,6 +257,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { TestPackURI("sc://host:123/;use_ssl=true", isCorrect = true), TestPackURI("sc://host:123/;token=mySecretToken", isCorrect = true), TestPackURI("sc://host:123/;token=", isCorrect = false), + TestPackURI("sc://host:123/;session_id=", isCorrect = false), + TestPackURI("sc://host:123/;session_id=abcdefgh", isCorrect = false), + TestPackURI(s"sc://host:123/;session_id=${UUID.randomUUID().toString}", isCorrect = true), TestPackURI("sc://host:123/;use_ssl=true;token=mySecretToken", isCorrect = true), TestPackURI("sc://host:123/;token=mySecretToken;use_ssl=true", isCorrect = true), TestPackURI("sc://host:123/;use_ssl=false;token=mySecretToken", isCorrect = false), @@ -213,10 +281,10 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } - private class DummyFn(val e: Throwable) { + private class DummyFn(val e: Throwable, numFails: Int = 3) { var counter = 0 def fn(): Int = { - if (counter < 3) { + if (counter < numFails) { counter += 1 throw e } else { @@ -225,6 +293,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } + test("SPARK-44721: Retries run for a minimum period") { + 
// repeat test few times to avoid random flakes + for (_ <- 1 to 10) { + var totalSleepMs: Long = 0 + + def sleep(t: Long): Unit = { + totalSleepMs += t + } + + val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100) + val retryHandler = new GrpcRetryHandler(GrpcRetryHandler.RetryPolicy(), sleep) + + assertThrows[StatusRuntimeException] { + retryHandler.retry { + dummyFn.fn() + } + } + + assert(totalSleepMs >= 10 * 60 * 1000) // waited at least 10 minutes + } + } + test("SPARK-44275: retry actually retries") { val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) val retryPolicy = GrpcRetryHandler.RetryPolicy() @@ -292,12 +382,30 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { // Reply with a dummy response using the same client ID val requestSessionId = request.getSessionId + val operationId = if (request.hasOperationId) { + request.getOperationId + } else { + UUID.randomUUID().toString + } inputPlan = request.getPlan val response = ExecutePlanResponse .newBuilder() .setSessionId(requestSessionId) + .setOperationId(operationId) .build() responseObserver.onNext(response) + // Reattachable execute must end with ResultComplete + if (request.getRequestOptionsList.asScala.exists { option => + option.hasReattachOptions && option.getReattachOptions.getReattachable == true + }) { + val resultComplete = ExecutePlanResponse + .newBuilder() + .setSessionId(requestSessionId) + .setOperationId(operationId) + .setResultComplete(proto.ExecutePlanResponse.ResultComplete.newBuilder().build()) + .build() + responseObserver.onNext(resultComplete) + } responseObserver.onCompleted() } @@ -362,4 +470,37 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver.onNext(builder.build()) responseObserver.onCompleted() } + + override def interrupt( + request: proto.InterruptRequest, + 
responseObserver: StreamObserver[proto.InterruptResponse]): Unit = { + val response = proto.InterruptResponse.newBuilder().setSessionId(request.getSessionId).build() + responseObserver.onNext(response) + responseObserver.onCompleted() + } + + override def reattachExecute( + request: proto.ReattachExecuteRequest, + responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = { + // Reply with a dummy response using the same client ID + val requestSessionId = request.getSessionId + val response = ExecutePlanResponse + .newBuilder() + .setSessionId(requestSessionId) + .build() + responseObserver.onNext(response) + responseObserver.onCompleted() + } + + override def releaseExecute( + request: proto.ReleaseExecuteRequest, + responseObserver: StreamObserver[proto.ReleaseExecuteResponse]): Unit = { + val response = proto.ReleaseExecuteResponse + .newBuilder() + .setSessionId(request.getSessionId) + .setOperationId(request.getOperationId) + .build() + responseObserver.onNext(response) + responseObserver.onCompleted() + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 0c327484e477d..b6ad27d3e5287 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -17,31 +17,34 @@ package org.apache.spark.sql.connect.client.arrow import java.math.BigInteger +import java.time.{Duration, Period, ZoneOffset} +import java.time.temporal.ChronoUnit import java.util import java.util.{Collections, Objects} import scala.beans.BeanProperty -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.classTag -import scala.util.control.NonFatal -import 
com.google.protobuf.ByteString import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.connect.proto -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} -import org.apache.spark.sql.connect.client.SparkResult +import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND +import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE +import 
org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ +import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ +import org.apache.spark.sql.connect.client.CloseableIterator import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum -import org.apache.spark.sql.connect.client.util.ConnectFunSuite -import org.apache.spark.sql.types.{ArrayType, DataType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType} +import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType} /** * Tests for encoding external data to and from arrow. @@ -73,13 +76,31 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { maxBatchSize: Long = 16 * 1024, batchSizeCheckInterval: Int = 128, inspectBatch: Array[Byte] => Unit = null): CloseableIterator[T] = { + roundTripWithDifferentIOEncoders( + encoder, + encoder, + iterator, + maxRecordsPerBatch, + maxBatchSize, + batchSizeCheckInterval, + inspectBatch) + } + + private def roundTripWithDifferentIOEncoders[I, O]( + inputEncoder: AgnosticEncoder[I], + outputEncoder: AgnosticEncoder[O], + iterator: Iterator[I], + maxRecordsPerBatch: Int = 4 * 1024, + maxBatchSize: Long = 16 * 1024, + batchSizeCheckInterval: Int = 128, + inspectBatch: Array[Byte] => Unit = null): CloseableIterator[O] = { // Use different allocators so we can pinpoint memory leaks better. 
val serializerAllocator = newAllocator("serialization") val deserializerAllocator = newAllocator("deserialization") val arrowIterator = ArrowSerializer.serialize( input = iterator, - enc = encoder, + enc = inputEncoder, allocator = serializerAllocator, maxRecordsPerBatch = maxRecordsPerBatch, maxBatchSize = maxBatchSize, @@ -96,16 +117,12 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } val resultIterator = - try { - deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator) - } catch { - case NonFatal(e) => - arrowIterator.close() - serializerAllocator.close() - deserializerAllocator.close() - throw e - } - new CloseableIterator[T] { + ArrowDeserializers.deserializeFromArrow( + inspectedIterator, + outputEncoder, + deserializerAllocator, + timeZoneId = "UTC") + new CloseableIterator[O] { override def close(): Unit = { arrowIterator.close() resultIterator.close() @@ -113,26 +130,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { deserializerAllocator.close() } override def hasNext: Boolean = resultIterator.hasNext - override def next(): T = resultIterator.next() - } - } - - // Temporary hack until we merge the deserializer. 
- private def deserializeFromArrow[E]( - batches: Iterator[Array[Byte]], - encoder: AgnosticEncoder[E], - allocator: BufferAllocator): CloseableIterator[E] = { - val responses = batches.map { batch => - val builder = proto.ExecutePlanResponse.newBuilder() - builder.getArrowBatchBuilder.setData(ByteString.copyFrom(batch)) - builder.build() - } - val result = new SparkResult[E](responses.asJava, allocator, encoder) - new CloseableIterator[E] { - private val itr = result.iterator - override def close(): Unit = itr.close() - override def hasNext: Boolean = itr.hasNext - override def next(): E = itr.next() + override def next(): O = resultIterator.next() } } @@ -188,11 +186,11 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } private def compareIterators[T](expected: Iterator[T], actual: Iterator[T]): Unit = { - expected.zipAll(actual, null, null).foreach { case (expected, actual) => - assert(expected != null) - assert(actual != null) - assert(actual == expected) + while (expected.hasNext && actual.hasNext) { + assert(expected.next() == actual.next()) } + assert(!expected.hasNext, "Less results produced than expected.") + assert(!actual.hasNext, "More results produced than expected.") } private class CountingBatchInspector extends (Array[Byte] => Unit) { @@ -246,6 +244,18 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { assert(inspector.sizeInBytes > 0) } + test("deserializing empty iterator") { + withAllocator { allocator => + val iterator = ArrowDeserializers.deserializeFromArrow( + Iterator.empty, + singleIntEncoder, + allocator, + timeZoneId = "UTC") + assert(iterator.isEmpty) + assert(allocator.getAllocatedMemory == 0) + } + } + test("single batch") { val inspector = new CountingBatchInspector roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () => @@ -353,8 +363,10 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { test("nullable fields") { val encoder = 
ScalaReflection.encoderFor[NullableData] - val instant = java.time.Instant.now() - val now = java.time.LocalDateTime.now() + // SPARK-44457: Similar to SPARK-42770, calling `truncatedTo(ChronoUnit.MICROS)` + // on `Instant.now()` and `LocalDateTime.now()` to ensure microsecond accuracy is used. + val instant = java.time.Instant.now().truncatedTo(ChronoUnit.MICROS) + val now = java.time.LocalDateTime.now().truncatedTo(ChronoUnit.MICROS) val today = java.time.LocalDate.now() roundTripAndCheckIdentical(encoder) { () => val maybeNull = MaybeNull(3) @@ -533,15 +545,22 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { val maybeNull = MaybeNull(11) Iterator.tabulate(100) { i => val bean = new JavaMapData - bean.setDummyToDoubleListMap(maybeNull { - val map = new util.HashMap[DummyBean, java.util.List[java.lang.Double]] - (0 until (i % 5)).foreach { j => - val dummy = new DummyBean - dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + bean.setMetricMap(maybeNull { + val map = new util.HashMap[String, util.List[java.lang.Double]] + (0 until (i % 20)).foreach { i => val values = Array.tabulate(i % 40) { j => Double.box(j.toDouble) } - map.put(dummy, maybeNull(util.Arrays.asList(values: _*))) + map.put("k" + i, maybeNull(util.Arrays.asList(values: _*))) + } + map + }) + bean.setDummyToStringMap(maybeNull { + val map = new util.HashMap[DummyBean, String] + (0 until (i % 5)).foreach { j => + val dummy = new DummyBean + dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + map.put(dummy, maybeNull("s" + i + "v" + j)) } map }) @@ -587,7 +606,9 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } test("lenient field serialization - timestamp/instant") { - val base = java.time.Instant.now() + // SPARK-44457: Similar to SPARK-42770, calling `truncatedTo(ChronoUnit.MICROS)` + // on `Instant.now()` to ensure microsecond accuracy is used. 
+ val base = java.time.Instant.now().truncatedTo(ChronoUnit.MICROS) val instants = () => Iterator.tabulate(10)(i => base.plusSeconds(i * i * 60)) val timestamps = () => instants().map(java.sql.Timestamp.from) val combo = () => instants() ++ timestamps() @@ -675,6 +696,223 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { .add("Ca", "array") .add("Cb", "binary"))) + test("bind to schema") { + // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly + // different field order, and uses different cased names in a couple of places. + withAllocator { allocator => + val input = Row( + 887, + "foo", + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f), + Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false))) + val expected = Row( + "foo", + Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) + val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) + val result = + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + narrowSchemaEncoder, + allocator, + timeZoneId = "UTC") + val actual = result.next() + assert(result.isEmpty) + assert(expected === actual) + result.close() + arrowBatches.close() + } + } + + test("unknown field") { + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + wideSchemaEncoder, + allocator, + timeZoneId = "UTC") + } + arrowBatches.close() + } + } + + test("duplicate fields") { + val duplicateSchemaEncoder = toRowEncoder( + new StructType() + .add("foO", "string") + .add("Foo", "string")) + val fooSchemaEncoder = toRowEncoder( + new StructType() + .add("foo", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) + 
intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + fooSchemaEncoder, + allocator, + timeZoneId = "UTC") + } + arrowBatches.close() + } + } + + case class MyTestClass(value: Int) + OuterScopes.addOuterScope(this) + + test("REPL generated classes") { + val encoder = ScalaReflection.encoderFor[MyTestClass] + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10)(MyTestClass) + } + } + + /* ******************************************************************** * + * Arrow deserialization upcasting + * ******************************************************************** */ + // Not supported: UDT, CalendarInterval + // Not tested: Char/Varchar. + private case class UpCastTestCase[I](input: AgnosticEncoder[I], generator: Int => I) { + def test[O](output: AgnosticEncoder[O], convert: I => O): this.type = { + val name = "upcast " + input.dataType.catalogString + " to " + output.dataType.catalogString + ArrowEncoderSuite.this.test(name) { + def data: Iterator[I] = Iterator.tabulate(5)(generator) + val result = roundTripWithDifferentIOEncoders(input, output, data) + try { + compareIterators(data.map(convert), result) + } finally { + result.close() + } + } + this + } + + def nullTest[O](e: AgnosticEncoder[O]): this.type = { + test(e, _.asInstanceOf[O]) + } + } + + private val timestampFormatter = TimestampFormatter.getFractionFormatter(ZoneOffset.UTC) + private val dateFormatter = DateFormatter() + + private def scalaDecimalEncoder(precision: Int, scale: Int = 0): ScalaDecimalEncoder = { + ScalaDecimalEncoder(DecimalType(precision, scale)) + } + + UpCastTestCase(NullEncoder, _ => null) + .nullTest(BoxedBooleanEncoder) + .nullTest(BoxedByteEncoder) + .nullTest(BoxedShortEncoder) + .nullTest(BoxedIntEncoder) + .nullTest(BoxedLongEncoder) + .nullTest(BoxedFloatEncoder) + .nullTest(BoxedDoubleEncoder) + .nullTest(StringEncoder) + .nullTest(DateEncoder(false)) + .nullTest(TimestampEncoder(false)) + 
UpCastTestCase(PrimitiveBooleanEncoder, _ % 2 == 0) + .test(StringEncoder, _.toString) + UpCastTestCase(PrimitiveByteEncoder, i => i.toByte) + .test(PrimitiveShortEncoder, _.toShort) + .test(PrimitiveIntEncoder, _.toInt) + .test(PrimitiveLongEncoder, _.toLong) + .test(PrimitiveFloatEncoder, _.toFloat) + .test(PrimitiveDoubleEncoder, _.toDouble) + .test(scalaDecimalEncoder(3), BigDecimal(_)) + .test(scalaDecimalEncoder(5, 2), BigDecimal(_)) + .test(StringEncoder, _.toString) + UpCastTestCase(PrimitiveShortEncoder, i => i.toShort) + .test(PrimitiveIntEncoder, _.toInt) + .test(PrimitiveLongEncoder, _.toLong) + .test(PrimitiveFloatEncoder, _.toFloat) + .test(PrimitiveDoubleEncoder, _.toDouble) + .test(scalaDecimalEncoder(5), BigDecimal(_)) + .test(scalaDecimalEncoder(10, 5), BigDecimal(_)) + .test(StringEncoder, _.toString) + UpCastTestCase(PrimitiveIntEncoder, i => i) + .test(PrimitiveLongEncoder, _.toLong) + .test(PrimitiveFloatEncoder, _.toFloat) + .test(PrimitiveDoubleEncoder, _.toDouble) + .test(scalaDecimalEncoder(10), BigDecimal(_)) + .test(scalaDecimalEncoder(13, 3), BigDecimal(_)) + .test(StringEncoder, _.toString) + UpCastTestCase(PrimitiveLongEncoder, i => i.toLong) + .test(PrimitiveFloatEncoder, _.toFloat) + .test(PrimitiveDoubleEncoder, _.toDouble) + .test(scalaDecimalEncoder(20), BigDecimal(_)) + .test(scalaDecimalEncoder(25, 5), BigDecimal(_)) + .test(TimestampEncoder(false), s => toJavaTimestamp(s * MICROS_PER_SECOND)) + .test(StringEncoder, _.toString) + UpCastTestCase(PrimitiveFloatEncoder, i => i.toFloat) + .test(PrimitiveDoubleEncoder, _.toDouble) + .test(StringEncoder, _.toString) + UpCastTestCase(PrimitiveDoubleEncoder, i => i.toDouble) + .test(StringEncoder, _.toString) + UpCastTestCase(scalaDecimalEncoder(2), BigDecimal(_)) + .test(PrimitiveByteEncoder, _.toByte) + .test(PrimitiveShortEncoder, _.toShort) + .test(PrimitiveIntEncoder, _.toInt) + .test(PrimitiveLongEncoder, _.toLong) + .test(scalaDecimalEncoder(7, 5), identity) + 
.test(StringEncoder, _.toString()) + UpCastTestCase(scalaDecimalEncoder(4), BigDecimal(_)) + .test(PrimitiveShortEncoder, _.toShort) + .test(PrimitiveIntEncoder, _.toInt) + .test(PrimitiveLongEncoder, _.toLong) + .test(scalaDecimalEncoder(10, 1), identity) + .test(StringEncoder, _.toString()) + UpCastTestCase(scalaDecimalEncoder(9), BigDecimal(_)) + .test(PrimitiveIntEncoder, _.toInt) + .test(PrimitiveLongEncoder, _.toLong) + .test(scalaDecimalEncoder(13, 4), identity) + .test(StringEncoder, _.toString()) + UpCastTestCase(scalaDecimalEncoder(19), BigDecimal(_)) + .test(PrimitiveLongEncoder, _.toLong) + .test(scalaDecimalEncoder(23, 1), identity) + .test(StringEncoder, _.toString()) + UpCastTestCase(scalaDecimalEncoder(7, 3), BigDecimal(_)) + .test(scalaDecimalEncoder(9, 5), identity) + .test(scalaDecimalEncoder(23, 3), identity) + UpCastTestCase(DateEncoder(false), i => toJavaDate(i)) + .test( + TimestampEncoder(false), + date => toJavaTimestamp(daysToMicros(fromJavaDate(date), ZoneOffset.UTC))) + .test( + LocalDateTimeEncoder, + date => microsToLocalDateTime(daysToMicros(fromJavaDate(date), ZoneOffset.UTC))) + .test(StringEncoder, date => dateFormatter.format(date)) + UpCastTestCase(TimestampEncoder(false), i => toJavaTimestamp(i)) + .test(PrimitiveLongEncoder, ts => Math.floorDiv(fromJavaTimestamp(ts), MICROS_PER_SECOND)) + .test(LocalDateTimeEncoder, ts => microsToLocalDateTime(fromJavaTimestamp(ts))) + .test(StringEncoder, ts => timestampFormatter.format(ts)) + UpCastTestCase(LocalDateTimeEncoder, i => microsToLocalDateTime(i)) + .test(TimestampEncoder(false), ldt => toJavaTimestamp(localDateTimeToMicros(ldt))) + .test(StringEncoder, ldt => timestampFormatter.format(ldt)) + UpCastTestCase(DayTimeIntervalEncoder, i => Duration.ofDays(i)) + .test( + StringEncoder, + { i => + toDayTimeIntervalString( + durationToMicros(i), + ANSI_STYLE, + DayTimeIntervalType.DEFAULT.startField, + DayTimeIntervalType.DEFAULT.endField) + }) + UpCastTestCase(YearMonthIntervalEncoder, 
i => Period.ofMonths(i)) + .test( + StringEncoder, + { i => + toYearMonthIntervalString( + periodToMonths(i), + ANSI_STYLE, + YearMonthIntervalType.DEFAULT.startField, + YearMonthIntervalType.DEFAULT.endField) + }) + UpCastTestCase(BinaryEncoder, i => Array.tabulate(10)(j => (64 + j + i).toByte)) + .test(StringEncoder, bytes => SparkStringUtils.getHexString(bytes)) + /* ******************************************************************** * * Arrow serialization/deserialization specific errors * ******************************************************************** */ @@ -833,17 +1071,23 @@ case class MapData(intStringMap: Map[Int, String], metricMap: Map[String, Array[ class JavaMapData { @scala.beans.BeanProperty - var dummyToDoubleListMap: java.util.Map[DummyBean, java.util.List[java.lang.Double]] = _ + var dummyToStringMap: java.util.Map[DummyBean, String] = _ + + @scala.beans.BeanProperty + var metricMap: java.util.HashMap[String, java.util.List[java.lang.Double]] = _ def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData] override def equals(other: Any): Boolean = other match { case that: JavaMapData if that canEqual this => - dummyToDoubleListMap == that.dummyToDoubleListMap + dummyToStringMap == that.dummyToStringMap && + metricMap == that.metricMap case _ => false } - override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap) + override def hashCode(): Int = { + java.util.Arrays.deepHashCode(Array(dummyToStringMap, metricMap)) + } } class DummyBean { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala deleted file mode 100644 index 8d84dffc9d5bd..0000000000000 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.connect.client.util - -import java.io.{BufferedOutputStream, File} -import java.util.concurrent.TimeUnit - -import scala.io.Source - -import org.apache.commons.lang3.{JavaVersion, SystemUtils} -import org.scalactic.source.Position -import org.scalatest.{BeforeAndAfterAll, Tag} -import sys.process._ - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.client.SparkConnectClient -import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._ -import org.apache.spark.sql.connect.common.config.ConnectCommon - -/** - * An util class to start a local spark connect server in a different process for local E2E tests. - * Pre-running the tests, the spark connect artifact needs to be built using e.g. `build/sbt - * package`. It is designed to start the server once but shared by all tests. 
It is equivalent to - * use the following command to start the connect server via command line: - * - * {{{ - * bin/spark-shell \ - * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \ - * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin - * }}} - * - * Set system property `spark.test.home` or env variable `SPARK_HOME` if the test is not executed - * from the Spark project top folder. Set system property `spark.debug.sc.jvm.client=true` to - * print the server process output in the console to debug server start stop problems. - */ -object SparkConnectServerUtils { - - // Server port - private[spark] val port: Int = - ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) - - @volatile private var stopped = false - - private var consoleOut: BufferedOutputStream = _ - private val serverStopCommand = "q" - - private lazy val sparkConnect: Process = { - debug("Starting the Spark Connect Server...") - val connectJar = findJar( - "connector/connect/server", - "spark-connect-assembly", - "spark-connect").getCanonicalPath - - val builder = Process( - Seq( - "bin/spark-submit", - "--driver-class-path", - connectJar, - "--conf", - s"spark.connect.grpc.binding.port=$port") ++ testConfigs ++ debugConfigs ++ Seq( - "--class", - "org.apache.spark.sql.connect.SimpleSparkConnectService", - connectJar), - new File(sparkHome)) - - val io = new ProcessIO( - in => consoleOut = new BufferedOutputStream(in), - out => Source.fromInputStream(out).getLines.foreach(debug), - err => Source.fromInputStream(err).getLines.foreach(debug)) - val process = builder.run(io) - - // Adding JVM shutdown hook - sys.addShutdownHook(stop()) - process - } - - /** - * As one shared spark will be started for all E2E tests, for tests that needs some special - * configs, we add them here - */ - private def testConfigs: Seq[String] = { - // Use InMemoryTableCatalog for V2 writer tests - val writerV2Configs = { - val catalystTestJar = 
findJar( // To find InMemoryTableCatalog for V2 writer tests - "sql/catalyst", - "spark-catalyst", - "spark-catalyst", - test = true).getCanonicalPath - Seq( - "--jars", - catalystTestJar, - "--conf", - "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog") - } - - // Run tests using hive - val hiveTestConfigs = { - val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) { - "hive" - } else { - // scalastyle:off println - println( - "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " + - "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" + - "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" + - "2. Test with sbt: run test with `-Phive` profile") - // scalastyle:on println - // SPARK-43647: Proactively cleaning the `classes` and `test-classes` dir of hive - // module to avoid unexpected loading of `DataSourceRegister` in hive module during - // testing without `-Phive` profile. - IntegrationTestUtils.cleanUpHiveClassesDirIfNeeded() - "in-memory" - } - Seq("--conf", s"spark.sql.catalogImplementation=$catalogImplementation") - } - - // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests. - val udfTestConfigs = tryFindJar( - "connector/connect/client/jvm", - // SBT passes the client & test jars to the server process automatically. - // So we skip building or finding this jar for SBT. 
- "sbt-tests-do-not-need-this-jar", - "spark-connect-client-jvm", - test = true) - .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath)) - .getOrElse(Seq.empty) - - writerV2Configs ++ hiveTestConfigs ++ udfTestConfigs - } - - def start(): Unit = { - assert(!stopped) - sparkConnect - } - - def stop(): Int = { - stopped = true - debug("Stopping the Spark Connect Server...") - try { - consoleOut.write(serverStopCommand.getBytes) - consoleOut.flush() - consoleOut.close() - } catch { - case e: Throwable => - debug(e) - sparkConnect.destroy() - } - - val code = sparkConnect.exitValue() - debug(s"Spark Connect Server is stopped with exit code: $code") - code - } -} - -trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { - import SparkConnectServerUtils._ - var spark: SparkSession = _ - protected lazy val serverPort: Int = port - - override def beforeAll(): Unit = { - // TODO(SPARK-44121) Remove this check condition - if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) { - super.beforeAll() - SparkConnectServerUtils.start() - spark = SparkSession - .builder() - .client(SparkConnectClient.builder().port(serverPort).build()) - .create() - - // Retry and wait for the server to start - val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min - var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff - var success = false - val error = new RuntimeException(s"Failed to start the test server on port $serverPort.") - - while (!success && System.nanoTime() < stop) { - try { - // Run a simple query to verify the server is really up and ready - val result = spark - .sql("select val from (values ('Hello'), ('World')) as t(val)") - .collect() - assert(result.length == 2) - success = true - debug("Spark Connect Server is up.") - } catch { - // ignored the error - case e: Throwable => - error.addSuppressed(e) - Thread.sleep(sleepInternalMs) - sleepInternalMs *= 2 - } - } - - // Throw error if failed - if (!success) { 
- debug(error) - throw error - } - } - } - - override def afterAll(): Unit = { - try { - if (spark != null) spark.stop() - } catch { - case e: Throwable => debug(e) - } - spark = null - super.afterAll() - } - - /** - * SPARK-44259: override test function to skip `RemoteSparkSession-based` tests as default, we - * should delete this function after SPARK-44121 is completed. - */ - override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit - pos: Position): Unit = { - super.test(testName, testTags: _*) { - // TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) - testFun - } - } -} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala similarity index 94% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 1287176d76e88..dc4d441ec3015 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -27,15 +27,16 @@ import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ +import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper} -import org.apache.spark.sql.connect.client.util.QueryTest +import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession} import org.apache.spark.sql.functions.col import org.apache.spark.sql.functions.window import 
org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent} -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.{QueryTest, SQLHelper} +import org.apache.spark.util.SparkFileUtils -class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { +class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging { test("Streaming API with windowed aggregate query") { // This verifies standard streaming API by starting a streaming query with windowed count. @@ -268,6 +269,8 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { q.stop() assert(!q1.isActive) + + assert(spark.streams.get(q.id) == null) } test("streaming query listener") { @@ -294,7 +297,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { spark.sql("DROP TABLE IF EXISTS my_listener_table") } - // List listeners after adding a new listener, length should be 2. + // List listeners after adding a new listener, length should be 1. val listeners = spark.streams.listListeners() assert(listeners.length == 1) @@ -335,7 +338,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { .start() eventually(timeout(30.seconds)) { // Wait for first progress. 
- assert(q.lastProgress != null) + assert(q.lastProgress != null, "Failed to make progress") assert(q.lastProgress.numInputRows > 0) } @@ -346,7 +349,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { .collect() .toSeq assert(rows.size > 0) - log.info(s"Rows in $tableName: $rows") + logInfo(s"Rows in $tableName: $rows") } q.stop() @@ -359,7 +362,7 @@ class TestForeachWriter[T] extends ForeachWriter[T] { var path: File = _ def open(partitionId: Long, version: Long): Boolean = { - path = Utils.createTempDir() + path = SparkFileUtils.createTempDir() fileWriter = new FileWriter(path, true) true } @@ -371,7 +374,7 @@ class TestForeachWriter[T] extends ForeachWriter[T] { def close(errorOrNull: Throwable): Unit = { fileWriter.close() - Utils.deleteRecursively(path) + SparkFileUtils.deleteRecursively(path) } } @@ -410,11 +413,13 @@ class EventCollector extends StreamingQueryListener { } } -class ForeachBatchFn(val viewName: String) extends ((DataFrame, Long) => Unit) with Serializable { - override def apply(df: DataFrame, batchId: Long): Unit = { +class ForeachBatchFn(val viewName: String) + extends VoidFunction2[DataFrame, java.lang.Long] + with Serializable { + override def call(df: DataFrame, batchId: java.lang.Long): Unit = { val count = df.count() df.sparkSession - .createDataFrame(Seq((batchId, count))) + .createDataFrame(Seq((batchId.toLong, count))) .createOrReplaceGlobalTempView(viewName) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala index cdb6b9a2e9c10..2fab6e8e3c843 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala @@ -23,9 
+23,9 @@ import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.timeout import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.{SparkSession, SQLHelper} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append -import org.apache.spark.sql.connect.client.util.QueryTest +import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} case class ClickEvent(id: String, timestamp: Timestamp) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala index a6a44c1bd7164..1a72252a41795 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.sql.types.StructType class StreamingQueryProgressSuite extends ConnectFunSuite { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala index 0a1e794c8e72e..8d69d91a34f7d 100755 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.connect.client.util +package org.apache.spark.sql.test import java.nio.file.Path diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala similarity index 79% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala index 0eaca7577b922..3ae9b9fc73b48 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.connect.client.util +package org.apache.spark.sql.test import java.io.File import java.nio.file.{Files, Paths} @@ -23,14 +23,19 @@ import scala.util.Properties.versionNumberString import org.scalatest.Assertions.fail -import org.apache.spark.util.Utils +import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.util.SparkFileUtils object IntegrationTestUtils { // System properties used for testing and debugging private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client" - // Enable this flag to print all client debug log + server logs to the console - private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean + private val DEBUG_SC_JVM_CLIENT_ENV = "SPARK_DEBUG_SC_JVM_CLIENT" + // Enable this flag to print all server logs to the console + private[sql] val isDebug = { + System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean || + Option(System.getenv(DEBUG_SC_JVM_CLIENT_ENV)).exists(_.toBoolean) + } private[sql] lazy val scalaVersion = { versionNumberString.split('.') match { @@ -48,8 +53,14 @@ object IntegrationTestUtils { sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) } - private[connect] def debugConfigs: Seq[String] = { - val log4j2 = s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties" + private[sql] lazy val connectClientHomeDir = s"$sparkHome/connector/connect/client/jvm" + + private[sql] lazy val connectClientTestClassDir = { + s"$connectClientHomeDir/target/$scalaDir/test-classes" + } + + private[sql] def debugConfigs: Seq[String] = { + val log4j2 = s"$connectClientHomeDir/src/test/resources/log4j2.properties" if (isDebug) { Seq( // Enable to see the server plan change log @@ -63,26 +74,26 @@ object IntegrationTestUtils { // Redirect server log into console "--conf", - s"spark.driver.extraJavaOptions=-Dlog4j.configuration=$log4j2") + s"spark.driver.extraJavaOptions=-Dlog4j.configurationFile=$log4j2") } else 
Seq.empty } // Log server start stop debug info into console // scalastyle:off println - private[connect] def debug(msg: String): Unit = if (isDebug) println(msg) + private[sql] def debug(msg: String): Unit = if (isDebug) println(msg) // scalastyle:on println - private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace() + private[sql] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace() private[sql] lazy val isSparkHiveJarAvailable: Boolean = { val filePath = s"$sparkHome/assembly/target/$scalaDir/jars/" + - s"spark-hive_$scalaVersion-${org.apache.spark.SPARK_VERSION}.jar" + s"spark-hive_$scalaVersion-$SPARK_VERSION.jar" Files.exists(Paths.get(filePath)) } private[sql] def cleanUpHiveClassesDirIfNeeded(): Unit = { def delete(f: File): Unit = { if (f.exists()) { - Utils.deleteRecursively(f) + SparkFileUtils.deleteRecursively(f) } } delete(new File(s"$sparkHome/sql/hive/target/$scalaDir/classes")) @@ -105,7 +116,7 @@ object IntegrationTestUtils { val jar = tryFindJar(path, sbtName, mvnName, test).getOrElse({ val suffix = if (test) "-tests.jar" else ".jar" val sbtFileName = s"$sbtName(.*)$suffix" - val mvnFileName = s"$mvnName(.*)${org.apache.spark.SPARK_VERSION}$suffix" + val mvnFileName = s"$mvnName(.*)$SPARK_VERSION$suffix" throw new RuntimeException(s"Failed to find the jar: $sbtFileName or $mvnFileName " + s"inside folder: ${getTargetFilePath(path)}. 
This file can be generated by similar to " + s"the following command: build/sbt package|assembly") @@ -136,7 +147,7 @@ object IntegrationTestUtils { // Maven Jar (f.getParent.endsWith("target") && f.getName.startsWith(mvnName) && - f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix")) + f.getName.endsWith(s"$SPARK_VERSION$suffix")) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala similarity index 97% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala index fdbb3edbf84a1..adbd8286090d9 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala @@ -15,14 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.connect.client.util +package org.apache.spark.sql.test import java.util.TimeZone import org.scalatest.Assertions import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide abstract class QueryTest extends RemoteSparkSession { @@ -122,7 +122,7 @@ object QueryTest extends Assertions { |${df.analyze} |== Exception == |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + |${org.apache.spark.util.SparkErrorUtils.stackTraceToString(e)} """.stripMargin return Some(errorMessage) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala new file mode 100644 index 0000000000000..8a8f739a7c502 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.test + +import java.io.{File, IOException, OutputStream} +import java.lang.ProcessBuilder +import java.lang.ProcessBuilder.Redirect +import java.nio.file.Paths +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.FiniteDuration + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkBuildInfo +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryPolicy +import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.test.IntegrationTestUtils._ + +/** + * An util class to start a local spark connect server in a different process for local E2E tests. + * Pre-running the tests, the spark connect artifact needs to be built using e.g. `build/sbt + * package`. It is designed to start the server once but shared by all tests. It is equivalent to + * use the following command to start the connect server via command line: + * + * {{{ + * bin/spark-shell \ + * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \ + * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin + * }}} + * + * Set system property `spark.test.home` or env variable `SPARK_HOME` if the test is not executed + * from the Spark project top folder. Set system property `spark.debug.sc.jvm.client=true` or + * environment variable `SPARK_DEBUG_SC_JVM_CLIENT=true` to print the server process output in the + * console to debug server start stop problems. 
+ */ +object SparkConnectServerUtils { + + // Server port + val port: Int = + ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) + + @volatile private var stopped = false + + private var consoleOut: OutputStream = _ + private val serverStopCommand = "q" + + private lazy val sparkConnect: java.lang.Process = { + debug("Starting the Spark Connect Server...") + val connectJar = findJar( + "connector/connect/server", + "spark-connect-assembly", + "spark-connect").getCanonicalPath + + val command = Seq.newBuilder[String] + command += "bin/spark-submit" + command += "--driver-class-path" += connectJar + command += "--class" += "org.apache.spark.sql.connect.SimpleSparkConnectService" + command += "--conf" += s"spark.connect.grpc.binding.port=$port" + command ++= testConfigs + command ++= debugConfigs + command += connectJar + val builder = new ProcessBuilder(command.result(): _*) + builder.directory(new File(sparkHome)) + val environment = builder.environment() + environment.remove("SPARK_DIST_CLASSPATH") + if (isDebug) { + builder.redirectError(Redirect.INHERIT) + builder.redirectOutput(Redirect.INHERIT) + } + + val process = builder.start() + consoleOut = process.getOutputStream + + // Adding JVM shutdown hook + sys.addShutdownHook(stop()) + process + } + + /** + * As one shared spark will be started for all E2E tests, for tests that needs some special + * configs, we add them here + */ + private def testConfigs: Seq[String] = { + // To find InMemoryTableCatalog for V2 writer tests + val catalystTestJar = + findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath + + val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) { + "hive" + } else { + // scalastyle:off println + println( + "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " + + "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" + + "1. 
Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" + + "2. Test with sbt: run test with `-Phive` profile") + // scalastyle:on println + // SPARK-43647: Proactively cleaning the `classes` and `test-classes` dir of hive + // module to avoid unexpected loading of `DataSourceRegister` in hive module during + // testing without `-Phive` profile. + IntegrationTestUtils.cleanUpHiveClassesDirIfNeeded() + "in-memory" + } + val confs = Seq( + // Use InMemoryTableCatalog for V2 writer tests + "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog", + // Try to use the hive catalog, fallback to in-memory if it is not there. + "spark.sql.catalogImplementation=" + catalogImplementation, + // Make the server terminate reattachable streams every 1 second and 123 bytes, + // to make the tests exercise reattach. + "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", + "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Disable UI + "spark.ui.enabled=false") + Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) + } + + def start(): Unit = { + assert(!stopped) + sparkConnect + } + + def stop(): Int = { + stopped = true + debug("Stopping the Spark Connect Server...") + try { + consoleOut.write(serverStopCommand.getBytes) + consoleOut.flush() + consoleOut.close() + if (!sparkConnect.waitFor(2, TimeUnit.SECONDS)) { + sparkConnect.destroyForcibly() + } + val code = sparkConnect.exitValue() + debug(s"Spark Connect Server is stopped with exit code: $code") + code + } catch { + case e: IOException if e.getMessage.contains("Stream closed") => + -1 + case e: Throwable => + debug(e) + sparkConnect.destroyForcibly() + throw e + } + } + + def syncTestDependencies(spark: SparkSession): Unit = { + // Both SBT & Maven pass the test-classes as a directory instead of a jar. 
+ val testClassesPath = Paths.get(IntegrationTestUtils.connectClientTestClassDir) + spark.client.artifactManager.addClassDir(testClassesPath) + + // We need scalatest & scalactic on the session's classpath to make the tests work. + val jars = System + .getProperty("java.class.path") + .split(File.pathSeparatorChar) + .filter { e: String => + val fileName = e.substring(e.lastIndexOf(File.separatorChar) + 1) + fileName.endsWith(".jar") && + (fileName.startsWith("scalatest") || fileName.startsWith("scalactic")) + } + .map(e => Paths.get(e).toUri) + spark.client.artifactManager.addArtifacts(jars) + } + + def createSparkSession(): SparkSession = { + SparkConnectServerUtils.start() + + val spark = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .retryPolicy(RetryPolicy(maxRetries = 7, maxBackoff = FiniteDuration(10, "s"))) + .build()) + .create() + + // Execute an RPC which will get retried until the server is up. + assert(spark.version == SparkBuildInfo.spark_version) + + // Auto-sync dependencies. 
+ SparkConnectServerUtils.syncTestDependencies(spark) + + spark + } +} + +trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll { + import SparkConnectServerUtils._ + var spark: SparkSession = _ + protected lazy val serverPort: Int = port + + override def beforeAll(): Unit = { + super.beforeAll() + spark = createSparkSession() + } + + override def afterAll(): Unit = { + try { + if (spark != null) spark.stop() + } catch { + case e: Throwable => debug(e) + } + spark = null + super.afterAll() + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala similarity index 89% rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala index 5603099fd4975..12212492e370b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala @@ -14,15 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.test import java.io.File import java.util.UUID import org.scalatest.Assertions.fail -import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE -import org.apache.spark.util.Utils +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils} trait SQLHelper { @@ -77,7 +77,7 @@ trait SQLHelper { try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE $DEFAULT_DATABASE") + spark.sql(s"USE default") } spark.sql(s"DROP DATABASE $dbName CASCADE") } @@ -88,17 +88,17 @@ trait SQLHelper { * If a file/directory is created there by `f`, it will be delete after `f` returns. */ protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() + val path = SparkFileUtils.createTempDir() path.delete() try f(path) - finally Utils.deleteRecursively(path) + finally SparkFileUtils.deleteRecursively(path) } /** * Drops table `tableName` after calling `f`. 
*/ protected def withTable(tableNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f) { + SparkErrorUtils.tryWithSafeFinally(f) { tableNames.foreach { name => spark.sql(s"DROP TABLE IF EXISTS $name").collect() } diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml index 1890384b51db5..a4f010f7076d4 100644 --- a/connector/connect/common/pom.xml +++ b/connector/connect/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../../pom.xml @@ -36,15 +36,8 @@ org.apache.spark - spark-catalyst_${scala.binary.version} + spark-sql-api_${scala.binary.version} ${project.version} - provided - - - com.google.guava - guava - - org.scala-lang @@ -53,7 +46,6 @@ com.google.protobuf protobuf-java - compile io.grpc @@ -79,25 +71,21 @@ io.netty netty-codec-http2 ${netty.version} - provided io.netty netty-handler-proxy ${netty.version} - provided io.netty netty-transport-native-unix-common ${netty.version} - provided org.apache.tomcat annotations-api ${tomcat.annotations.api.version} - provided + + {formatDurationVerbose(executionTime)} + + + {formatDurationVerbose(duration)} + + + + {info.statement} + + + + {if (info.isExecutionActive) "RUNNING" else info.state} + + + {info.operationId} + + + {info.jobTag} + + + {sqlStatsTableRow.sparkSessionTags.mkString(", ")} + + {errorMessageCell(Option(info.detail))} + + } + + private def errorMessageCell(errorMessageOption: Option[String]): Seq[Node] = { + val errorMessage = errorMessageOption.getOrElse("") + val isMultiline = errorMessage.indexOf('\n') >= 0 + val errorSummary = StringEscapeUtils.escapeHtml4(if (isMultiline) { + errorMessage.substring(0, errorMessage.indexOf('\n')) + } else { + errorMessage + }) + val details = detailsUINode(isMultiline, errorMessage) + + {errorSummary}{details} + + } + + private def jobURL(request: HttpServletRequest, jobId: String): String = + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + 
+ private def sqlURL(request: HttpServletRequest, sqlExecId: String): String = + "%s/SQL/execution/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), sqlExecId) +} + +private[ui] class SessionStatsPagedTable( + request: HttpServletRequest, + parent: SparkConnectServerTab, + data: Seq[SessionInfo], + subPath: String, + basePath: String, + sessionStatsTableTag: String) + extends PagedTable[SessionInfo] { + + private val (sortColumn, desc, pageSize) = + getTableParameters(request, sessionStatsTableTag, "Start Time") + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + + private val parameterPath = + s"$basePath/$subPath/?${getParameterOtherTable(request, sessionStatsTableTag)}" + + override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) + + override def tableId: String = sessionStatsTableTag + + override def tableCssClass: String = + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" + + override def pageLink(page: Int): String = { + parameterPath + + s"&$pageNumberFormField=$page" + + s"&$sessionStatsTableTag.sort=$encodedSortColumn" + + s"&$sessionStatsTableTag.desc=$desc" + + s"&$pageSizeFormField=$pageSize" + + s"#$sessionStatsTableTag" + } + + override def pageSizeFormField: String = s"$sessionStatsTableTag.pageSize" + + override def pageNumberFormField: String = s"$sessionStatsTableTag.page" + + override def goButtonFormPath: String = + s"$parameterPath&$sessionStatsTableTag.sort=$encodedSortColumn" + + s"&$sessionStatsTableTag.desc=$desc#$sessionStatsTableTag" + + override def headers: Seq[Node] = { + val sessionTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = + Seq( + ("User", true, None), + ("Session ID", true, None), + ("Start Time", true, None), + ("Finish Time", true, None), + ("Duration", true, Some(SPARK_CONNECT_SESSION_DURATION)), + ("Total Execute", true, Some(SPARK_CONNECT_SESSION_TOTAL_EXECUTE))) + + 
isSortColumnValid(sessionTableHeadersAndTooltips, sortColumn) + + headerRow( + sessionTableHeadersAndTooltips, + desc, + pageSize, + sortColumn, + parameterPath, + sessionStatsTableTag, + sessionStatsTableTag) + } + + override def row(session: SessionInfo): Seq[Node] = { + val sessionLink = "%s/%s/session/?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), + parent.prefix, + session.sessionId) + + {session.userId} + {session.sessionId} + {formatDate(session.startTimestamp)} + {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} + {formatDurationVerbose(session.totalTime)} + {session.totalExecution.toString} + + } +} + +private[ui] class SqlStatsTableRow( + val jobTag: String, + val jobId: Seq[String], + val sqlExecId: Seq[String], + val duration: Long, + val executionTime: Long, + val sparkSessionTags: Seq[String], + val executionInfo: ExecutionInfo) + +private[ui] class SqlStatsTableDataSource( + info: Seq[ExecutionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) + extends PagedDataSource[SqlStatsTableRow](pageSize) { + + // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in + // the table so that we can avoid creating duplicate contents during sorting the data + private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = data.slice(from, to) + + private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { + val duration = executionInfo.totalTime(executionInfo.closeTimestamp) + val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) + val jobId = executionInfo.jobId.toSeq.sorted + val sqlExecId = executionInfo.sqlExecId.toSeq.sorted + val sparkSessionTags = executionInfo.sparkSessionTags.toSeq.sorted + + new SqlStatsTableRow( + executionInfo.jobTag, + jobId, + sqlExecId, + duration, + executionTime, + 
sparkSessionTags, + executionInfo) + } + + /** + * Return Ordering according to sortColumn and desc. + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { + val ordering: Ordering[SqlStatsTableRow] = sortColumn match { + case "User" => Ordering.by(_.executionInfo.userId) + case "Operation ID" => Ordering.by(_.executionInfo.operationId) + case "Job ID" => Ordering.by(_.jobId.headOption) + case "SQL Query ID" => Ordering.by(_.sqlExecId.headOption) + case "Session ID" => Ordering.by(_.executionInfo.sessionId) + case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) + case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) + case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) + case "Execution Time" => Ordering.by(_.executionTime) + case "Duration" => Ordering.by(_.duration) + case "Statement" => Ordering.by(_.executionInfo.statement) + case "State" => Ordering.by(_.executionInfo.state) + case "Detail" => Ordering.by(_.executionInfo.detail) + case "Job Tag" => Ordering.by(_.executionInfo.jobTag) + case "Spark Session Tags" => Ordering.by(_.sparkSessionTags.headOption) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } +} + +private[ui] class SessionStatsTableDataSource( + info: Seq[SessionInfo], + pageSize: Int, + sortColumn: String, + desc: Boolean) + extends PagedDataSource[SessionInfo](pageSize) { + + // Sorting SessionInfo data + private val data = info.sorted(ordering(sortColumn, desc)) + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[SessionInfo] = data.slice(from, to) + + /** + * Return Ordering according to sortColumn and desc. 
+ */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = { + val ordering: Ordering[SessionInfo] = sortColumn match { + case "User" => Ordering.by(_.userId) + case "Session ID" => Ordering.by(_.sessionId) + case "Start Time" => Ordering.by(_.startTimestamp) + case "Finish Time" => Ordering.by(_.finishTimestamp) + case "Duration" => Ordering.by(_.totalTime) + case "Total Execute" => Ordering.by(_.totalExecution) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala new file mode 100644 index 0000000000000..fde6e8da8b63f --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerSessionPage.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.internal.Logging +import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.util.Utils + +/** Page for Spark UI that contains information pertaining to a single Spark Connect session */ +private[ui] class SparkConnectServerSessionPage(parent: SparkConnectServerTab) + extends WebUIPage("session") + with Logging { + + val store = parent.store + private val startTime = parent.startTime + + /** Render the page */ + def render(request: HttpServletRequest): Seq[Node] = { + val sessionId = request.getParameter("id") + require(sessionId != null && sessionId.nonEmpty, "Missing id parameter") + + val content = store.synchronized { // make sure all parts in this page are consistent + store + .getSession(sessionId) + .map { sessionStat => + generateBasicStats() ++ +
++ +

+ User + {sessionStat.userId} + , + Session created at + {formatDate(sessionStat.startTimestamp)} + , + Total run + {sessionStat.totalExecution} + Request(s) +

++ + generateSQLStatsTable(request, sessionStat.sessionId) + } + .getOrElse(
No information to display for session {sessionId}
) + } + UIUtils.headerSparkPage(request, "Spark Connect Session", content, parent) + } + + /** Generate basic stats of the Spark Connect Server */ + private def generateBasicStats(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime.getTime +
    +
  • + Started at: {formatDate(startTime)} +
  • +
  • + Time since start: {formatDurationVerbose(timeSinceStart)} +
  • +
+ } + + /** Generate stats of batch statements of the Spark Connect server */ + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { + val executionList = store.getExecutionList + .filter(_.sessionId == sessionID) + val numStatement = executionList.size + val table = if (numStatement > 0) { + + val sqlTableTag = "sqlsessionstat" + + val sqlTablePage = + Option(request.getParameter(s"$sqlTableTag.page")).map(_.toInt).getOrElse(1) + + try { + Some( + new SqlStatsPagedTable( + request, + parent, + executionList, + "connect/session", + UIUtils.prependBaseUri(request, parent.basePath), + sqlTableTag, + showSessionLink = false).table(sqlTablePage)) + } catch { + case e @ (_: IllegalArgumentException | _: IndexOutOfBoundsException) => + Some(
+

Error while rendering job table:

+
+              {Utils.exceptionString(e)}
+            
+
) + } + } else { + None + } + val content = + +

+ + Request Statistics +

+
++ +
+ {table.getOrElse("No statistics have been generated yet.")} +
+ + content + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerTab.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerTab.scala new file mode 100644 index 0000000000000..c5ea0bf618b52 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerTab.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.ui + +import java.util.Date + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.ui.{SparkUI, SparkUITab} + +private[connect] class SparkConnectServerTab( + val store: SparkConnectServerAppStatusStore, + sparkUI: SparkUI) + extends SparkUITab(sparkUI, "connect") + with Logging { + + override val name = "Connect" + + val parent = sparkUI + val startTime = + try { + sparkUI.store.applicationInfo().attempts.head.startTime + } catch { + case _: NoSuchElementException => new Date(System.currentTimeMillis()) + } + + attachPage(new SparkConnectServerPage(this)) + attachPage(new SparkConnectServerSessionPage(this)) + parent.attachTab(this) + def detach(): Unit = { + parent.detachTab(this) + } + + override def displayOrder: Int = 3 +} + +private[connect] object SparkConnectServerTab { + def getSparkUI(sparkContext: SparkContext): SparkUI = { + sparkContext.ui.getOrElse { + throw QueryExecutionErrors.parentSparkUIToAttachTabNotFoundError() + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/ToolTips.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/ToolTips.scala new file mode 100644 index 0000000000000..9b51ace83c6c1 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/ToolTips.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.ui + +private[ui] object ToolTips { + val SPARK_CONNECT_SERVER_FINISH_TIME = + "Execution finish time, before fetching the results" + + val SPARK_CONNECT_SERVER_CLOSE_TIME = + "Operation close time after fetching the results" + + val SPARK_CONNECT_SERVER_EXECUTION = + "Difference between start time and finish time" + + val SPARK_CONNECT_SERVER_DURATION = + "Difference between start time and close time" + + val SPARK_CONNECT_SESSION_TOTAL_EXECUTE = + "Number of operations submitted in this session" + + val SPARK_CONNECT_SESSION_DURATION = + "Elapsed time since session start, or until closed if the session was closed" + +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index d0f754827dad8..2050ebc01aa01 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} import org.apache.spark.api.python.PythonException import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.ExecuteEventsManager import org.apache.spark.sql.connect.service.SparkConnectService import org.apache.spark.sql.internal.SQLConf @@ -103,32 +104,59 @@ private[connect] object ErrorUtils extends 
Logging { opType: String, observer: StreamObserver[V], userId: String, - sessionId: String): PartialFunction[Throwable, Unit] = { + sessionId: String, + events: Option[ExecuteEventsManager] = None, + isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { val session = SparkConnectService .getOrCreateIsolatedSession(userId, sessionId) .session val stackTraceEnabled = session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) - { + val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", se) - observer.onError( + ( + se, StatusProto.toStatusRuntimeException( buildStatusFromThrowable(se.getCause, stackTraceEnabled))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( - StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) case e: Throwable => - logError(s"Error during: $opType. UserId: $userId. SessionId: $sessionId.", e) - observer.onError( + ( + e, Status.UNKNOWN .withCause(e) .withDescription(StringUtils.abbreviate(e.getMessage, 2048)) .asRuntimeException()) } + partial + .andThen { case (original, wrapped) => + if (events.isDefined) { + // Errors thrown inside execution are user query errors, return then as INFO. + logInfo( + s"Spark Connect error " + + s"during: $opType. UserId: $userId. SessionId: $sessionId.", + original) + } else { + // Other errors are server RPC errors, return them as ERROR. + logError( + s"Spark Connect RPC error " + + s"during: $opType. UserId: $userId. SessionId: $sessionId.", + original) + } + + // If ExecuteEventsManager is present, this this is an execution error that needs to be + // posted to it. 
+ events.foreach { executeEventsManager => + if (isInterrupted) { + executeEventsManager.postCanceled() + } else { + executeEventsManager.postFailed(wrapped.getMessage) + } + } + observer.onError(wrapped) + } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala index 88120e616efdb..6395fb588ab84 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/MetricGenerator.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper /** * Helper object for generating responses with metrics from queries. 
@@ -47,12 +47,6 @@ private[connect] object MetricGenerator extends AdaptiveSparkPlanHelper { allChildren(p).flatMap(c => transformPlan(c, p.id)) } - private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { - case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) - case s: QueryStageExec => Seq(s.plan) - case _ => p.children - } - private def transformPlan( p: SparkPlan, parentId: Int): Seq[ExecutePlanResponse.Metrics.MetricObject] = { diff --git a/connector/connect/server/src/test/resources/log4j2.properties b/connector/connect/server/src/test/resources/log4j2.properties index ab02104c69697..f782d7f3aaa48 100644 --- a/connector/connect/server/src/test/resources/log4j2.properties +++ b/connector/connect/server/src/test/resources/log4j2.properties @@ -32,8 +32,13 @@ appender.console.type = Console appender.console.name = console appender.console.target = SYSTEM_ERR appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %t: %m%n%ex +appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n # Ignore messages below warning level from Jetty, because it's a bit verbose logger.jetty.name = org.sparkproject.jetty logger.jetty.level = warn + +# SPARK-44922: Disable o.a.p.h.InternalParquetRecordWriter logs for tests +logger.parquet_recordwriter.name = org.apache.parquet.hadoop.InternalParquetRecordWriter +logger.parquet_recordwriter.additivity = false +logger.parquet_recordwriter.level = off diff --git a/connector/connect/server/src/test/resources/udf b/connector/connect/server/src/test/resources/udf new file mode 100644 index 0000000000000..55a3264a017fd Binary files /dev/null and b/connector/connect/server/src/test/resources/udf differ diff --git a/connector/connect/server/src/test/resources/udf_noA.jar b/connector/connect/server/src/test/resources/udf_noA.jar new file mode 100644 index 0000000000000..4d8c423ab6dfb Binary files /dev/null and b/connector/connect/server/src/test/resources/udf_noA.jar differ diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala new file mode 100644 index 0000000000000..234ee526d438a --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect + +import java.util.UUID + +import org.scalatest.concurrent.{Eventually, TimeLimits} +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, GrpcRetryHandler, SparkConnectClient, WrappedCloseableIterator} +import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.dsl.MockRemoteSession +import org.apache.spark.sql.connect.dsl.plans._ +import org.apache.spark.sql.connect.service.{ExecuteHolder, SparkConnectService} +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Base class and utilities for a test suite that starts and tests the real SparkConnectService + * with a real SparkConnectClient, communicating over RPC, but both in-process. + */ +trait SparkConnectServerTest extends SharedSparkSession { + + // Server port + val serverPort: Int = + ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) + + val eventuallyTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + // Other suites using mocks leave a mess in the global executionManager, + // shut it down so that it's cleared before starting server. + SparkConnectService.executionManager.shutdown() + // Start the real service. 
+ withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { + SparkConnectService.start(spark.sparkContext) + } + } + + override def afterAll(): Unit = { + SparkConnectService.stop() + super.afterAll() + } + + override def beforeEach(): Unit = { + super.beforeEach() + clearAllExecutions() + } + + override def afterEach(): Unit = { + clearAllExecutions() + super.afterEach() + } + + protected def clearAllExecutions(): Unit = { + SparkConnectService.executionManager.listExecuteHolders.foreach(_.close()) + SparkConnectService.executionManager.periodicMaintenance(0) + assertNoActiveExecutions() + } + + protected val defaultSessionId = UUID.randomUUID.toString() + protected val defaultUserId = UUID.randomUUID.toString() + + // We don't have the real SparkSession/Dataset api available, + // so use mock for generating simple query plans. + protected val dsl = new MockRemoteSession() + + protected val userContext = proto.UserContext + .newBuilder() + .setUserId(defaultUserId) + .build() + + protected def buildExecutePlanRequest( + plan: proto.Plan, + sessionId: String = defaultSessionId, + operationId: String = UUID.randomUUID.toString) = { + proto.ExecutePlanRequest + .newBuilder() + .setUserContext(userContext) + .setSessionId(sessionId) + .setOperationId(operationId) + .setPlan(plan) + .addRequestOptions( + proto.ExecutePlanRequest.RequestOption + .newBuilder() + .setReattachOptions(proto.ReattachOptions.newBuilder().setReattachable(true).build()) + .build()) + .build() + } + + protected def buildReattachExecuteRequest(operationId: String, responseId: Option[String]) = { + val req = proto.ReattachExecuteRequest + .newBuilder() + .setUserContext(userContext) + .setSessionId(defaultSessionId) + .setOperationId(operationId) + + if (responseId.isDefined) { + req.setLastResponseId(responseId.get) + } + + req.build() + } + + protected def buildPlan(query: String) = { + proto.Plan.newBuilder().setRoot(dsl.sql(query)).build() + } + + protected def 
getReattachableIterator( + stubIterator: CloseableIterator[proto.ExecutePlanResponse]) = { + // This depends on the wrapping in CustomSparkConnectBlockingStub.executePlanReattachable: + // GrpcExceptionConverter.convertIterator + stubIterator + .asInstanceOf[WrappedCloseableIterator[proto.ExecutePlanResponse]] + // ExecutePlanResponseReattachableIterator + .innerIterator + .asInstanceOf[ExecutePlanResponseReattachableIterator] + } + + protected def assertNoActiveRpcs(): Unit = { + SparkConnectService.executionManager.listActiveExecutions match { + case Left(_) => // nothing running, good + case Right(executions) => + // all rpc detached. + assert( + executions.forall(_.lastAttachedRpcTime.isDefined), + s"Expected no RPCs, but got $executions") + } + } + + protected def assertEventuallyNoActiveRpcs(): Unit = { + Eventually.eventually(timeout(eventuallyTimeout)) { + assertNoActiveRpcs() + } + } + + protected def assertNoActiveExecutions(): Unit = { + SparkConnectService.executionManager.listActiveExecutions match { + case Left(_) => // cleaned up + case Right(executions) => fail(s"Expected empty, but got $executions") + } + } + + protected def assertEventuallyNoActiveExecutions(): Unit = { + Eventually.eventually(timeout(eventuallyTimeout)) { + assertNoActiveExecutions() + } + } + + protected def assertExecutionReleased(operationId: String): Unit = { + SparkConnectService.executionManager.listActiveExecutions match { + case Left(_) => // cleaned up + case Right(executions) => assert(!executions.exists(_.operationId == operationId)) + } + } + + protected def assertEventuallyExecutionReleased(operationId: String): Unit = { + Eventually.eventually(timeout(eventuallyTimeout)) { + assertExecutionReleased(operationId) + } + } + + // Get ExecutionHolder, assuming that only one execution is active + protected def getExecutionHolder: ExecuteHolder = { + val executions = SparkConnectService.executionManager.listExecuteHolders + assert(executions.length == 1) + executions.head 
+ } + + protected def withClient(f: SparkConnectClient => Unit): Unit = { + val client = SparkConnectClient + .builder() + .port(serverPort) + .sessionId(defaultSessionId) + .userId(defaultUserId) + .enableReattachableExecute() + .build() + try f(client) + finally { + client.shutdown() + } + } + + protected def withRawBlockingStub( + f: proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub => Unit): Unit = { + val conf = SparkConnectClient.Configuration(port = serverPort) + val channel = conf.createChannel() + val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) + try f(bstub) + finally { + channel.shutdownNow() + } + } + + protected def withCustomBlockingStub( + retryPolicy: GrpcRetryHandler.RetryPolicy = GrpcRetryHandler.RetryPolicy())( + f: CustomSparkConnectBlockingStub => Unit): Unit = { + val conf = SparkConnectClient.Configuration(port = serverPort) + val channel = conf.createChannel() + val bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy) + try f(bstub) + finally { + channel.shutdownNow() + } + } + + protected def runQuery(plan: proto.Plan, queryTimeout: Span, iterSleep: Long): Unit = { + withClient { client => + TimeLimits.failAfter(queryTimeout) { + val iter = client.execute(plan) + var operationId: Option[String] = None + var r: proto.ExecutePlanResponse = null + val reattachableIter = getReattachableIterator(iter) + while (iter.hasNext) { + r = iter.next() + operationId match { + case None => operationId = Some(r.getOperationId) + case Some(id) => assert(r.getOperationId == id) + } + if (iterSleep > 0) { + Thread.sleep(iterSleep) + } + } + // Check that last response had ResultComplete indicator + assert(r != null) + assert(r.hasResultComplete) + // ... that client sent ReleaseExecute based on it + assert(reattachableIter.resultComplete) + // ... and that the server released the execution. 
+ assert(operationId.isDefined) + assertEventuallyExecutionReleased(operationId.get) + } + } + } + + protected def runQuery(query: String, queryTimeout: Span, iterSleep: Long = 0): Unit = { + val plan = buildPlan(query) + runQuery(plan, queryTimeout, iterSleep) + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index 199290327cf89..fa3b7d52379c9 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.artifact import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} +import java.util.UUID import org.apache.commons.io.FileUtils @@ -96,7 +97,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString()) sessionHolder.addArtifact(remotePath, stagingPath, None) val movedClassFile = SparkConnectArtifactManager.artifactRootPath @@ -208,9 +210,11 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { } test("Classloaders for spark sessions are isolated") { - val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1") - val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2") - val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", "session3") + // use same sessionId - different users should still make it isolated. 
+ val sessionId = UUID.randomUUID.toString() + val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId) + val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", sessionId) + val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", sessionId) def addHelloClass(holder: SessionHolder): Unit = { val copyDir = Utils.createTempDir().toPath @@ -267,7 +271,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") - val sessionHolder = SparkConnectService.getOrCreateIsolatedSession("c1", "session") + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString) sessionHolder.addArtifact(remotePath, stagingPath, None) val sessionDirectory = diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala new file mode 100644 index 0000000000000..bde9a71fa17e6 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.artifact + +import java.io.File + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader} + +class StubClassLoaderSuite extends SparkFunSuite { + + // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created. + private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL + private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf" + private val classA = "org.apache.spark.sql.connect.client.A" + + test("find class with stub class") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + val cls = cl.findClass("my.name.HelloWorld") + assert(cls.getName === "my.name.HelloWorld") + assert(cl.lastStubbed === "my.name.HelloWorld") + } + + test("class for name with stub class") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + // scalastyle:off classforname + val cls = Class.forName("my.name.HelloWorld", false, cl) + // scalastyle:on classforname + assert(cls.getName === "my.name.HelloWorld") + assert(cl.lastStubbed === "my.name.HelloWorld") + } + + test("filter class to stub") { + val list = "my.name" :: Nil + val cl = StubClassLoader(getClass().getClassLoader(), list) + val cls = cl.findClass("my.name.HelloWorld") + assert(cls.getName === "my.name.HelloWorld") + + intercept[ClassNotFoundException] { + cl.findClass("name.my.GoodDay") + } + } + + test("call stub class default constructor") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + // scalastyle:off classforname + val cls = Class.forName("my.name.HelloWorld", false, cl) + // scalastyle:on classforname + assert(cl.lastStubbed === "my.name.HelloWorld") + val error = intercept[java.lang.reflect.InvocationTargetException] { + 
cls.getDeclaredConstructor().newInstance() + } + assert( + error.getCause != null && error.getCause.getMessage.contains( + "Fail to initiate the class my.name.HelloWorld because it is stubbed"), + error) + } + + test("stub missing class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + // Install artifact without class A. + val sessionClassLoader = + new ChildFirstURLClassLoader(Array(udfNoAJar), stubClassLoader, sysClassLoader) + // Load udf with A used in the same class. + loadDummyUdf(sessionClassLoader) + // Class A should be stubbed. + assert(stubClassLoader.lastStubbed === classA) + } + + test("unload stub class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + val cl1 = new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader) + + // Failed to load DummyUdf + intercept[Exception] { + loadDummyUdf(cl1) + } + // Successfully stubbed the missing class. + assert(stubClassLoader.lastStubbed === classDummyUdf) + + // Creating a new class loader will unpack the udf correctly. + val cl2 = new ChildFirstURLClassLoader( + Array(udfNoAJar), + stubClassLoader, // even with the same stub class loader. + sysClassLoader) + // Should be able to load after the artifact is added + loadDummyUdf(cl2) + } + + test("throw no such method if trying to access methods on stub class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + val sessionClassLoader = + new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader) + + // Failed to load DummyUdf because of missing methods + assert(intercept[NoSuchMethodException] { + loadDummyUdf(sessionClassLoader) + }.getMessage.contains(classDummyUdf)) + // Successfully stubbed the missing class. 
+ assert(stubClassLoader.lastStubbed === classDummyUdf) + } + + private def loadDummyUdf(sessionClassLoader: ClassLoader): Unit = { + // Load DummyUdf and call a method on it. + // scalastyle:off classforname + val cls = Class.forName(classDummyUdf, false, sessionClassLoader) + // scalastyle:on classforname + cls.getDeclaredMethod("dummy") + + // Load class A used inside DummyUdf + // scalastyle:off classforname + Class.forName(classA, false, sessionClassLoader) + // scalastyle:on classforname + } +} + +class RecordedStubClassLoader(parent: ClassLoader, shouldStub: String => Boolean) + extends StubClassLoader(parent, shouldStub) { + var lastStubbed: String = _ + + override def findClass(name: String): Class[_] = { + if (shouldStub(name)) { + lastStubbed = name + } + super.findClass(name) + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala new file mode 100644 index 0000000000000..00de9fb6fd260 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.execution + +import java.util.UUID + +import io.grpc.StatusRuntimeException +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.sql.connect.SparkConnectServerTest +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.SparkConnectService + +class ReattachableExecuteSuite extends SparkConnectServerTest { + + // Tests assume that this query will result in at least a couple ExecutePlanResponses on the + // stream. If this is no longer the case because of changes in how much is returned in a single + // ExecutePlanResponse, it may need to be adjusted. + val MEDIUM_RESULTS_QUERY = "select * from range(10000000)" + + test("reattach after initial RPC ends") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + + // open the iterator + iter.next() + // expire all RPCs on server + SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1) + assertEventuallyNoActiveRpcs() + // iterator should reattach + // (but not necessarily at first next, as there might have been messages buffered client side) + while (iter.hasNext && (reattachableIter.innerIterator eq initialInnerIter)) { + iter.next() + } + assert( + reattachableIter.innerIterator ne initialInnerIter + ) // reattach changed the inner iter + } + } + + test("reattach after connection expired") { + withClient { client => + withRawBlockingStub { stub => + // emulate session expiration + SparkConnectService.invalidateSession(defaultUserId, defaultSessionId) + + // session closed, bound to fail immediately + val operationId = UUID.randomUUID().toString + 
val iter = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + val e = intercept[StatusRuntimeException] { + iter.next() + } + assert(e.getMessage.contains("INVALID_HANDLE.SESSION_NOT_FOUND")) + } + } + } + + test("raw interrupted RPC results in INVALID_CURSOR.DISCONNECTED error") { + withRawBlockingStub { stub => + val iter = stub.executePlan(buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY))) + iter.next() // open the iterator + // interrupt all RPCs on server + SparkConnectService.executionManager.interruptAllRPCs() + assertEventuallyNoActiveRpcs() + val e = intercept[StatusRuntimeException] { + while (iter.hasNext) iter.next() + } + assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + } + } + + test("raw new RPC interrupts previous RPC with INVALID_CURSOR.DISCONNECTED error") { + // Raw stub does not have retries, auto reattach etc. + withRawBlockingStub { stub => + val operationId = UUID.randomUUID().toString + val iter = stub.executePlan( + buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId)) + iter.next() // open the iterator + + // send reattach + val iter2 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + iter2.next() // open the iterator + + // should result in INVALID_CURSOR.DISCONNECTED error on the original iterator + val e = intercept[StatusRuntimeException] { + while (iter.hasNext) iter.next() + } + assert(e.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + + // send another reattach + val iter3 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + assert(iter3.hasNext) + iter3.next() // open the iterator + + // should result in INVALID_CURSOR.DISCONNECTED error on the previous reattach iterator + val e2 = intercept[StatusRuntimeException] { + while (iter2.hasNext) iter2.next() + } + assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + } + } + + test("client INVALID_CURSOR.DISCONNECTED error is retried when rpc sender gets 
interrupted") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + val operationId = getReattachableIterator(iter).operationId + + // open the iterator + iter.next() + + // interrupt all RPCs on server + SparkConnectService.executionManager.interruptAllRPCs() + assertEventuallyNoActiveRpcs() + + // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error + iter.next() + // iterator changed because it had to reconnect + assert(reattachableIter.innerIterator ne initialInnerIter) + } + } + + test("client INVALID_CURSOR.DISCONNECTED error is retried when other RPC preempts this one") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + val operationId = getReattachableIterator(iter).operationId + + // open the iterator + val response = iter.next() + + // Send another Reattach request, it should preempt this request with an + // INVALID_CURSOR.DISCONNECTED error. 
+ withRawBlockingStub { stub => + val reattachIter = stub.reattachExecute( + buildReattachExecuteRequest(operationId, Some(response.getResponseId))) + assert(reattachIter.hasNext) + } + + // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error + iter.next() + // iterator changed because it had to reconnect + assert(reattachableIter.innerIterator ne initialInnerIter) + } + } + + test("abandoned query gets INVALID_HANDLE.OPERATION_ABANDONED error") { + withClient { client => + val plan = buildPlan("select * from range(100000)") + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val operationId = getReattachableIterator(iter).operationId + // open the iterator + iter.next() + // disconnect and remove on server + SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1) + assertEventuallyNoActiveRpcs() + SparkConnectService.executionManager.periodicMaintenance(0) + assertNoActiveExecutions() + // check that it throws abandoned error + val e = intercept[SparkException] { + while (iter.hasNext) iter.next() + } + assert(e.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + // check that afterwards, new operation can't be created with the same operationId. 
+ withCustomBlockingStub() { stub => + val executePlanReq = buildExecutePlanRequest(plan, operationId = operationId) + + val iterNonReattachable = stub.executePlan(executePlanReq) + val eNonReattachable = intercept[SparkException] { + iterNonReattachable.hasNext + } + assert(eNonReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + val iterReattachable = stub.executePlanReattachable(executePlanReq) + val eReattachable = intercept[SparkException] { + iterReattachable.hasNext + } + assert(eReattachable.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + } + } + } + + test("client releases responses directly after consuming them") { + withClient { client => + val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + val reattachableIter = getReattachableIterator(iter) + val initialInnerIter = reattachableIter.innerIterator + val operationId = getReattachableIterator(iter).operationId + + assert(iter.hasNext) // open iterator + val execution = getExecutionHolder + assert(execution.responseObserver.releasedUntilIndex == 0) + + // get two responses, check on the server that ReleaseExecute releases them afterwards + val response1 = iter.next() + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(execution.responseObserver.releasedUntilIndex == 1) + } + + val response2 = iter.next() + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(execution.responseObserver.releasedUntilIndex == 2) + } + + withRawBlockingStub { stub => + // Reattach after response1 should fail with INVALID_CURSOR.POSITION_NOT_AVAILABLE + val reattach1 = stub.reattachExecute( + buildReattachExecuteRequest(operationId, Some(response1.getResponseId))) + val e = intercept[StatusRuntimeException] { + reattach1.hasNext() + } + assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE")) + + // Reattach after response2 should work + val reattach2 = stub.reattachExecute( + buildReattachExecuteRequest(operationId, Some(response2.getResponseId))) + 
val response3 = reattach2.next() + val response4 = reattach2.next() + val response5 = reattach2.next() + + // The original client iterator will handle the INVALID_CURSOR.DISCONNECTED error, + // and reconnect back. Since the raw iterator was not releasing responses, client iterator + // should be able to continue where it left off (server shouldn't have released yet) + assert(execution.responseObserver.releasedUntilIndex == 2) + assert(iter.hasNext) + + val r3 = iter.next() + assert(r3.getResponseId == response3.getResponseId) + val r4 = iter.next() + assert(r4.getResponseId == response4.getResponseId) + val r5 = iter.next() + assert(r5.getResponseId == response5.getResponseId) + // inner iterator changed because it had to reconnect + assert(reattachableIter.innerIterator ne initialInnerIter) + } + } + } + + test("server releases responses automatically when client moves ahead") { + withRawBlockingStub { stub => + val operationId = UUID.randomUUID().toString + val iter = stub.executePlan( + buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId)) + var lastSeenResponse: String = null + val serverRetryBuffer = SparkEnv.get.conf + .get(Connect.CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE) + .toLong + + iter.hasNext // open iterator + val execution = getExecutionHolder + + // after consuming enough from the iterator, server should automatically start releasing + var lastSeenIndex = 0 + var totalSizeSeen = 0 + while (iter.hasNext && totalSizeSeen <= 1.1 * serverRetryBuffer) { + val r = iter.next() + lastSeenResponse = r.getResponseId() + totalSizeSeen += r.getSerializedSize + lastSeenIndex += 1 + } + assert(iter.hasNext) + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(execution.responseObserver.releasedUntilIndex > 0) + } + + // Reattach from the beginning is not available. 
+ val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None)) + val e = intercept[StatusRuntimeException] { + reattach.hasNext() + } + assert(e.getMessage.contains("INVALID_CURSOR.POSITION_NOT_AVAILABLE")) + + // Original iterator got disconnected by the reattach and gets INVALID_CURSOR.DISCONNECTED + val e2 = intercept[StatusRuntimeException] { + while (iter.hasNext) iter.next() + } + assert(e2.getMessage.contains("INVALID_CURSOR.DISCONNECTED")) + + Eventually.eventually(timeout(eventuallyTimeout)) { + // Even though we didn't consume more from the iterator, the server thinks that + // it sent more, because GRPC stream onNext() can push into internal GRPC buffer without + // client picking it up. + assert(execution.responseObserver.highestConsumedIndex > lastSeenIndex) + } + // but CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE is big enough that the last + // response we've seen is still in range + assert(execution.responseObserver.releasedUntilIndex < lastSeenIndex) + + // and a new reattach can continue from there. + val reattach2 = + stub.reattachExecute(buildReattachExecuteRequest(operationId, Some(lastSeenResponse))) + assert(reattach2.hasNext) + while (reattach2.hasNext) reattach2.next() + } + } + + // A few integration tests with large results. + // They should run significantly faster than the LARGE_QUERY_TIMEOUT + // - big query (4 seconds, 871 milliseconds) + // - big query and slow client (7 seconds, 288 milliseconds) + // - big query with frequent reattach (1 second, 527 milliseconds) + // - big query with frequent reattach and slow client (7 seconds, 365 milliseconds) + // - long sleeping query (10 seconds, 805 milliseconds) + + // intentionally smaller than CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, + // so that reattach deadline doesn't "unstick" things if something got stuck.
+ val LARGE_QUERY_TIMEOUT = 100.seconds + + val LARGE_RESULTS_QUERY = s"select id, " + + (1 to 20).map(i => s"cast(id as string) c$i").mkString(", ") + + s" from range(1000000)" + + test("big query") { + // regular query with large results + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + + test("big query and slow client") { + // regular query with large results, but client is slow so sender will need to control flow + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + + test("big query with frequent reattach") { + // will reattach every 100kB + withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) { + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + } + + test("big query with frequent reattach and slow client") { + // will reattach every 100kB, and in addition the client is slow, + // so sender will need to control flow + withSparkEnvConfs((Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE.key, "100k")) { + runQuery(LARGE_RESULTS_QUERY, LARGE_QUERY_TIMEOUT, iterSleep = 50) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + } + + test("long sleeping query") { + // register udf directly on the server, we're not testing client UDFs here... 
+ val serverSession = + SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session + serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) + // query will be sleeping and not returning results, while having multiple reattach + withSparkEnvConfs( + (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) { + runQuery("select sleep(10000) as s", 30.seconds) + // Check that execution is released on the server. + assertEventuallyNoActiveExecutions() + } + } + + test("Async cleanup callback gets called after the execution is closed") { + withClient { client => + val query1 = client.execute(buildPlan(MEDIUM_RESULTS_QUERY)) + // just creating the iterator is lazy, trigger query1 to be sent. + query1.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 1) + } + val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head + // Close execution + SparkConnectService.executionManager.removeExecuteHolder(executeHolder1.key) + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(SparkConnectService.executionManager.listExecuteHolders.length == 0) + } + // Check the async execute cleanup gets called + Eventually.eventually(timeout(eventuallyTimeout)) { + assert(executeHolder1.completionCallbackCalled) + } + } + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index a10540676b04e..0caa02a0b6112 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -21,7 +21,6 @@ import
scala.collection.JavaConverters._ import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto @@ -35,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto -import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus, SparkConnectService} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -58,8 +57,9 @@ trait SparkConnectPlanTest extends SharedSparkSession { } def transform(cmd: proto.Command): Unit = { - new SparkConnectPlanner(SessionHolder.forTesting(spark)) - .process(cmd, "clientId", "sessionId", new MockObserver()) + val executeHolder = buildExecutePlanHolder(cmd) + new SparkConnectPlanner(executeHolder.sessionHolder) + .process(cmd, new MockObserver(), executeHolder) } def readRel: proto.Relation = @@ -104,7 +104,7 @@ trait SparkConnectPlanTest extends SharedSparkSession { val bytes = ArrowConverters .toBatchWithSchemaIterator( data.iterator, - StructType.fromAttributes(attrs.map(_.toAttribute)), + DataTypeUtils.fromAttributes(attrs.map(_.toAttribute)), Long.MaxValue, Long.MaxValue, null, @@ -114,6 +114,29 @@ trait SparkConnectPlanTest extends SharedSparkSession { localRelationBuilder.setData(ByteString.copyFrom(bytes)) proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } + + def buildExecutePlanHolder(command: proto.Command): ExecuteHolder = { + val sessionHolder = SessionHolder.forTesting(spark) + 
sessionHolder.eventManager.status_(SessionStatus.Started) + + val context = proto.UserContext + .newBuilder() + .setUserId(sessionHolder.userId) + .build() + val plan = proto.Plan + .newBuilder() + .setCommand(command) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setSessionId(sessionHolder.sessionId) + .setUserContext(context) + .build() + val executeHolder = SparkConnectService.executionManager.createExecuteHolder(request) + executeHolder.eventsManager.status_(ExecuteStatus.Started) + executeHolder + } } /** @@ -455,8 +478,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("transform LocalRelation") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val rows = (0 until 10).map { i => InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i)) } @@ -558,8 +579,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("transform UnresolvedStar and ExpressionString") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val sql = "SELECT * FROM VALUES (1,'spark',1), (2,'hadoop',2), (3,'kafka',3) AS tab(id, name, value)" val input = proto.Relation @@ -596,8 +615,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("transform UnresolvedStar with target field") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val rows = (0 until 10).map { i => InternalRow(InternalRow(InternalRow(i, i + 1))) } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 82941d8d72e50..8bc4de8351248 100644 --- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -21,7 +21,6 @@ import java.nio.file.{Files, Paths} import scala.collection.JavaConverters._ import com.google.protobuf.ByteString -import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.apache.spark.{SparkClassNotFoundException, SparkIllegalArgumentException} import org.apache.spark.connect.proto @@ -31,7 +30,8 @@ import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Observation, import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, Distinct, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.dsl.MockRemoteSession @@ -694,8 +694,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with create") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -723,8 +721,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with create and using") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - 
assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName()) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -762,8 +758,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with append") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -795,8 +789,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with overwrite") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -850,8 +842,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("WriteTo with overwritePartitions") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("testcat.table_name") { spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) @@ -1050,7 +1040,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val buffer = ArrowConverters .toBatchWithSchemaIterator( Iterator.empty, - StructType.fromAttributes(attributes), + DataTypeUtils.fromAttributes(attributes), Long.MaxValue, Long.MaxValue, null, @@ -1077,7 +1067,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { // Compares proto plan with LogicalPlan. 
private def comparePlans(connectPlan: proto.Relation, sparkPlan: LogicalPlan): Unit = { + def normalizeDataframeId(plan: LogicalPlan): LogicalPlan = plan transform { + case cm: CollectMetrics => cm.copy(dataframeId = 0) + } val connectAnalyzed = analyzePlan(transform(connectPlan)) - comparePlans(connectAnalyzed, sparkPlan, false) + comparePlans(normalizeDataframeId(connectAnalyzed), normalizeDataframeId(sparkPlan), false) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index bceaada9051e3..06508bfc6a7c2 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -16,30 +16,56 @@ */ package org.apache.spark.sql.connect.planner +import java.util.UUID +import java.util.concurrent.Semaphore + import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.protobuf +import com.google.protobuf.ByteString import io.grpc.StatusRuntimeException import io.grpc.stub.StreamObserver import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader -import org.apache.commons.lang3.{JavaVersion, SystemUtils} +import org.mockito.Mockito.when +import org.scalatest.Tag +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime +import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.CreateDataFrameViewCommand +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import 
org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ -import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService} -import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionHolder, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted} +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog +import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * Testing Connect Service implementation. 
*/ -class SparkConnectServiceSuite extends SharedSparkSession { +class SparkConnectServiceSuite + extends SharedSparkSession + with MockitoSugar + with Logging + with SparkConnectPlanTest { private def sparkSessionHolder = SessionHolder.forTesting(spark) + private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093") test("Test schema in analyze response") { withTable("test") { @@ -131,126 +157,530 @@ class SparkConnectServiceSuite extends SharedSparkSession { } test("SPARK-41224: collect data using arrow") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) - val instance = new SparkConnectService(false) - val connect = new MockRemoteSession() - val context = proto.UserContext - .newBuilder() - .setUserId("c1") - .build() - val plan = proto.Plan - .newBuilder() - .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)")) - .build() - val request = proto.ExecutePlanRequest - .newBuilder() - .setPlan(plan) - .setUserContext(context) - .build() - - // Execute plan. - @volatile var done = false - val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] - instance.executePlan( - request, - new StreamObserver[proto.ExecutePlanResponse] { - override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v - - override def onError(throwable: Throwable): Unit = throw throwable - - override def onCompleted(): Unit = done = true - }) - - // The current implementation is expected to be blocking. This is here to make sure it is. 
- assert(done) - - // 4 Partitions + Metrics - assert(responses.size == 6) - - // Make sure the first response is schema only - val head = responses.head - assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) - - // Make sure the last response is metrics only - val last = responses.last - assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) - - val allocator = new RootAllocator() - - // Check the 'data' batches - var expectedId = 0L - var previousEId = 0.0d - responses.tail.dropRight(1).foreach { response => - assert(response.hasArrowBatch) - val batch = response.getArrowBatch - assert(batch.getData != null) - assert(batch.getRowCount == 25) - - val reader = new ArrowStreamReader(batch.getData.newInput(), allocator) - while (reader.loadNextBatch()) { - val root = reader.getVectorSchemaRoot - val idVector = root.getVector(0).asInstanceOf[BigIntVector] - val eidVector = root.getVector(1).asInstanceOf[Float8Vector] - val numRows = root.getRowCount - var i = 0 - while (i < numRows) { - assert(idVector.get(i) == expectedId) - expectedId += 1 - val eid = eidVector.get(i) - assert(eid > previousEId) - previousEId = eid - i += 1 + withEvents { verifyEvents => + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) + .build() + + // Execute plan. 
+ @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted(Some(100)) + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + + // 4 Partitions + Metrics + assert(responses.size == 6) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) + + val allocator = new RootAllocator() + + // Check the 'data' batches + var expectedId = 0L + var previousEId = 0.0d + responses.tail.dropRight(1).foreach { response => + assert(response.hasArrowBatch) + val batch = response.getArrowBatch + assert(batch.getData != null) + assert(batch.getRowCount == 25) + + val reader = new ArrowStreamReader(batch.getData.newInput(), allocator) + while (reader.loadNextBatch()) { + val root = reader.getVectorSchemaRoot + val idVector = root.getVector(0).asInstanceOf[BigIntVector] + val eidVector = root.getVector(1).asInstanceOf[Float8Vector] + val numRows = root.getRowCount + var i = 0 + while (i < numRows) { + assert(idVector.get(i) == expectedId) + expectedId += 1 + val eid = eidVector.get(i) + assert(eid > previousEId) + previousEId = eid + i += 1 + } } + reader.close() } - reader.close() + allocator.close() } - allocator.close() } - test("SPARK-41165: failures in the arrow collect path should not cause hangs") { - val instance = new SparkConnectService(false) + test("SPARK-44776: 
LocalTableScanExec") { + withEvents { verifyEvents => + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + + val rows = (0L to 5L).map { i => + new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar))) + } + + val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType))) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + + val localRelation = createLocalRelationProto(schema, inputRows) + val plan = proto.Plan + .newBuilder() + .setRoot(localRelation) + .build() + + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) + .build() - // Add an always crashing UDF - val session = SparkConnectService.getOrCreateIsolatedSession("c1", "session").session - val instaKill: Long => Long = { _ => - throw new Exception("Kaboom") + // Execute plan. + @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted(Some(6)) + // The current implementation is expected to be blocking. This is here to make sure it is. 
+ assert(done) + + // 1 Partitions + Metrics + assert(responses.size == 3) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) } - session.udf.register("insta_kill", instaKill) - - val connect = new MockRemoteSession() - val context = proto.UserContext - .newBuilder() - .setUserId("c1") - .build() - val plan = proto.Plan - .newBuilder() - .setRoot(connect.sql("select insta_kill(id) from range(10)")) - .build() - val request = proto.ExecutePlanRequest - .newBuilder() - .setPlan(plan) - .setUserContext(context) - .setSessionId("session") - .build() - - // The observer is executed inside this thread. So - // we can perform the checks inside the observer. - instance.executePlan( - request, - new StreamObserver[proto.ExecutePlanResponse] { - override def onNext(v: proto.ExecutePlanResponse): Unit = { - fail("this should not receive responses") - } + } - override def onError(throwable: Throwable): Unit = { - assert(throwable.isInstanceOf[StatusRuntimeException]) - } + test("SPARK-44657: Arrow batches respect max batch size limit") { + // Set 10 KiB as the batch size limit + val batchSize = 10 * 1024 + withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> batchSize.toString) { + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select * from range(0, 15000, 1, 1)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) + .build() + + // Execute plan. 
+ @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + } + + override def onError(throwable: Throwable): Unit = { + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + // The current implementation is expected to be blocking. This is here to make sure it is. + assert(done) + + // 1 schema + 1 metric + at least 2 data batches + assert(responses.size > 3) + + val allocator = new RootAllocator() + + // Check the 'data' batches + responses.tail.dropRight(1).foreach { response => + assert(response.hasArrowBatch) + val batch = response.getArrowBatch + assert(batch.getData != null) + // Batch size must be <= 70% since we intentionally use this multiplier for the size + // estimator. + assert(batch.getData.size() <= batchSize * 0.7) + } + } + } + + gridTest("SPARK-43923: commands send events")( + Seq( + ( + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()), + Some(0L)), + ( + proto.Command + .newBuilder() + .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()), + Some(1L)), + ( + proto.Command + .newBuilder() + .setWriteOperation( + proto.WriteOperation + .newBuilder() + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1"))) + .setPath(Utils.createTempDir().getAbsolutePath) + .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)), + None), + ( + proto.Command + .newBuilder() + .setWriteOperationV2( + proto.WriteOperationV2 + .newBuilder() + .setInput(proto.Relation.newBuilder.setRange( + proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L))) + .setTableName("testcat.testtable") + .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)), + None), + ( + proto.Command + .newBuilder() + .setCreateDataframeView( + 
CreateDataFrameViewCommand + .newBuilder() + .setName("testview") + .setInput( + proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))), + None), + ( + proto.Command + .newBuilder() + .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()), + None), + ( + proto.Command + .newBuilder() + .setExtension( + protobuf.Any.pack( + proto.ExamplePluginCommand + .newBuilder() + .setCustomField("SPARK-43923") + .build())), + None), + ( + proto.Command + .newBuilder() + .setWriteStreamOperationStart( + proto.WriteStreamOperationStart + .newBuilder() + .setInput( + proto.Relation + .newBuilder() + .setRead(proto.Read + .newBuilder() + .setIsStreaming(true) + .setDataSource(proto.Read.DataSource.newBuilder().setFormat("rate").build()) + .build()) + .build()) + .setOutputMode("Append") + .setAvailableNow(true) + .setQueryName("test") + .setFormat("memory") + .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath) + .setPath("test-path") + .build()), + None), + ( + proto.Command + .newBuilder() + .setStreamingQueryCommand( + proto.StreamingQueryCommand + .newBuilder() + .setQueryId( + proto.StreamingQueryInstanceId + .newBuilder() + .setId(DEFAULT_UUID.toString) + .setRunId(DEFAULT_UUID.toString) + .build()) + .setStop(true)), + None), + ( + proto.Command + .newBuilder() + .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand + .newBuilder() + .setListListeners(true)), + None), + ( + proto.Command + .newBuilder() + .setRegisterFunction( + proto.CommonInlineUserDefinedFunction + .newBuilder() + .setFunctionName("function") + .setPythonUdf( + proto.PythonUDF + .newBuilder() + .setEvalType(100) + .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType)) + .setCommand(ByteString.copyFrom("command".getBytes())) + .setPythonVer("3.10") + .build())), + None))) { case (command, producedNumRows) => + val sessionId = UUID.randomUUID.toString() + withCommandTest(sessionId) { verifyEvents => + val instance = new 
SparkConnectService(false) + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setCommand(command) + .build() + + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setSessionId(sessionId) + .setUserContext(context) + .build() + + // Execute plan. + @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + verifyEvents.onNext(v) + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onError(throwable) + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + verifyEvents.onCompleted(producedNumRows) + // The current implementation is expected to be blocking. + // This is here to make sure it is. + assert(done) + + // Result + Metrics + if (responses.size > 1) { + assert(responses.size == 2) + + // Make sure the first response result only + val head = responses.head + assert(head.hasSqlCommandResult && !head.hasMetrics) + + // Make sure the last response is metrics only + val last = responses.last + assert(last.hasMetrics && !last.hasSqlCommandResult) + } + } + } + + test("SPARK-43923: canceled request send events") { + val sessionId = UUID.randomUUID.toString + withEvents { verifyEvents => + val instance = new SparkConnectService(false) + + // Add an always crashing UDF + val session = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId).session + val sleep: Long => Long = { time => + Thread.sleep(time) + time + } + session.udf.register("sleep", sleep) + + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select sleep(10000)")) + .build() + val request = proto.ExecutePlanRequest + 
.newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(sessionId) + .build() - override def onCompleted(): Unit = { - fail("this should not complete") + val thread = new Thread { + override def run: Unit = { + verifyEvents.listener.semaphoreStarted.acquire() + instance.interrupt( + proto.InterruptRequest + .newBuilder() + .setSessionId(sessionId) + .setUserContext(context) + .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL) + .build(), + new StreamObserver[proto.InterruptResponse] { + override def onNext(v: proto.InterruptResponse): Unit = {} + + override def onError(throwable: Throwable): Unit = {} + + override def onCompleted(): Unit = {} + }) } - }) + } + thread.start() + // The observer is executed inside this thread. So + // we can perform the checks inside the observer. + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + logInfo(s"$v") + } + + override def onError(throwable: Throwable): Unit = { + verifyEvents.onCanceled + } + + override def onCompleted(): Unit = { + fail("this should not complete") + } + }) + thread.join() + verifyEvents.onCompleted() + } + } + + test("SPARK-41165: failures in the arrow collect path should not cause hangs") { + val sessionId = UUID.randomUUID.toString + withEvents { verifyEvents => + val instance = new SparkConnectService(false) + + // Add an always crashing UDF + val session = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId).session + val instaKill: Long => Long = { _ => + throw new Exception("Kaboom") + } + session.udf.register("insta_kill", instaKill) + + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select insta_kill(id) from range(10)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + 
.setUserContext(context) + .setSessionId(sessionId) + .build() + + // Even though the observer is executed inside this thread, this thread is also executing + // the SparkConnectService. If we throw an exception inside it, it will be caught by + // the ErrorUtils.handleError wrapping instance.executePlan and turned into an onError + // call with StatusRuntimeException, which will be eaten here. + var failures: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer[String]() + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + // The query receives some pre-execution responses such as schema, but should + // never proceed to execution and get query results. + if (v.hasArrowBatch) { + failures += s"this should not receive query results but got $v" + } + } + + override def onError(throwable: Throwable): Unit = { + try { + assert(throwable.isInstanceOf[StatusRuntimeException]) + verifyEvents.onError(throwable) + } catch { + case t: Throwable => + failures += s"assertion $t validating processing onError($throwable)." + } + } + + override def onCompleted(): Unit = { + failures += "this should not complete" + } + }) + assert(failures.isEmpty, s"this should have no failures but got $failures") + verifyEvents.onCompleted() + } } test("Test explain mode in analyze response") { @@ -307,8 +737,6 @@ class SparkConnectServiceSuite extends SharedSparkSession { } test("Test observe response") { - // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 - assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) withTable("test") { spark.sql(""" | CREATE TABLE test (col1 INT, col2 STRING) @@ -341,6 +769,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .newBuilder() .setPlan(plan) .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) .build() // Execute plan. 
@@ -378,4 +807,112 @@ class SparkConnectServiceSuite extends SharedSparkSession { assert(valuesList.last.hasLong && valuesList.last.getLong == 99) } } + + protected def withCommandTest(sessionId: String)(f: VerifyEvents => Unit): Unit = { + withView("testview") { + withTable("testcat.testtable") { + withSparkConf( + "spark.sql.catalog.testcat" -> classOf[InMemoryPartitionTableCatalog].getName, + Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES.key -> + "org.apache.spark.sql.connect.plugin.ExampleCommandPlugin") { + withEvents { verifyEvents => + val restartedQuery = mock[StreamingQuery] + when(restartedQuery.id).thenReturn(DEFAULT_UUID) + when(restartedQuery.runId).thenReturn(DEFAULT_UUID) + SparkConnectService.streamingSessionManager.registerNewStreamingQuery( + SparkConnectService.getOrCreateIsolatedSession("c1", sessionId), + restartedQuery) + f(verifyEvents) + } + } + } + } + } + + protected def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + pairs.foreach { kv => conf.set(kv._1, kv._2) } + try f + finally { + pairs.foreach { kv => conf.remove(kv._1) } + } + } + + protected def withEvents(f: VerifyEvents => Unit): Unit = { + val verifyEvents = new VerifyEvents(spark.sparkContext) + spark.sparkContext.addSparkListener(verifyEvents.listener) + Utils.tryWithSafeFinally({ + f(verifyEvents) + SparkConnectService.invalidateAllSessions() + verifyEvents.onSessionClosed() + }) { + verifyEvents.waitUntilEmpty() + spark.sparkContext.removeSparkListener(verifyEvents.listener) + SparkConnectService.invalidateAllSessions() + SparkConnectPluginRegistry.reset() + } + } + + protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])( + testFun: A => Unit): Unit = { + for (param <- params) { + test(testNamePrefix + s" ($param)", testTags: _*)(testFun(param)) + } + } + + class VerifyEvents(val sparkContext: SparkContext) { + val listener: MockSparkListener = new MockSparkListener() + val listenerBus = 
sparkContext.listenerBus + val LISTENER_BUS_TIMEOUT = 30000 + def executeHolder: ExecuteHolder = { + assert(listener.executeHolder.isDefined) + listener.executeHolder.get + } + def onNext(v: proto.ExecutePlanResponse): Unit = { + if (v.hasSchema) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Analyzed) + } + if (v.hasMetrics) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Finished) + } + } + def onError(throwable: Throwable): Unit = { + assert(executeHolder.eventsManager.hasCanceled.isEmpty) + assert(executeHolder.eventsManager.hasError.isDefined) + } + def onCompleted(producedRowCount: Option[Long] = None): Unit = { + assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) + // The eventsManager is closed asynchronously + Eventually.eventually(timeout(1.seconds)) { + assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + } + } + def onCanceled(): Unit = { + assert(executeHolder.eventsManager.hasCanceled.contains(true)) + assert(executeHolder.eventsManager.hasError.isEmpty) + } + def onSessionClosed(): Unit = { + assert(executeHolder.sessionHolder.eventManager.status == SessionStatus.Closed) + } + def onSessionStarted(): Unit = { + assert(executeHolder.sessionHolder.eventManager.status == SessionStatus.Started) + } + def waitUntilEmpty(): Unit = { + listenerBus.waitUntilEmpty(LISTENER_BUS_TIMEOUT) + } + } + class MockSparkListener() extends SparkListener { + val semaphoreStarted = new Semaphore(0) + var executeHolder = Option.empty[ExecuteHolder] + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case e: SparkListenerConnectOperationStarted => + semaphoreStarted.release() + val sessionHolder = + SparkConnectService.getOrCreateIsolatedSession(e.userId, e.sessionId) + executeHolder = sessionHolder.executeHolder(e.operationId) + case _ => + } + } + } } diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelperSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelperSuite.scala new file mode 100644 index 0000000000000..820a1b047957b --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelperSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.planner + +import java.util.UUID + +import org.mockito.Mockito.times +import org.mockito.Mockito.verify +import org.mockito.Mockito.when +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.sql.test.SharedSparkSession + +class StreamingForeachBatchHelperSuite extends SharedSparkSession with MockitoSugar { + + private def mockQuery(): StreamingQuery = { + val query = mock[StreamingQuery] + val (queryId, runId) = (UUID.randomUUID(), UUID.randomUUID()) + when(query.id).thenReturn(queryId) + when(query.runId).thenReturn(runId) + query + } + + test("CleanerCache functionality: register queries, terminate, full cleanup") { + + val cleaner1 = mock[AutoCloseable] + val cleaner2 = mock[AutoCloseable] + + val query1 = mockQuery() + val query2 = mockQuery() + + val cache = new StreamingForeachBatchHelper.CleanerCache(SessionHolder.forTesting(spark)) + + cache.registerCleanerForQuery(query1, cleaner1) + + // Verify listener is registered. + assert(spark.streams.listListeners().contains(cache.listenerForTesting)) + + cache.registerCleanerForQuery(query2, cleaner2) + + assert(cache.listEntriesForTesting().size == 2) + + // No calls to close yet. + verify(cleaner1, times(0)).close() + + // Terminate query1 + val terminatedEvent = + new StreamingQueryListener.QueryTerminatedEvent(id = query1.id, runId = query1.runId, None) + cache.listenerForTesting.onQueryTerminated(terminatedEvent) + + // This should close 'cleaner1' and remove it from the cache. + verify(cleaner1, times(1)).close() + assert(cache.listEntriesForTesting().size == 1) + + // Clean up remaining entries + verify(cleaner2, times(0)).close() // cleaner2 is not closed yet. + cache.cleanUpAll() // It should be closed now. 
+ verify(cleaner2, times(1)).close() + + // No more entries left in it now. + assert(cache.listEntriesForTesting().isEmpty) + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala index 2bdabc7ccc214..fdb9032379419 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.{SparkConnectPlanner, SparkConnectPlanTest} -import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.test.SharedSparkSession class DummyPlugin extends RelationPlugin { @@ -196,8 +195,9 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne .build())) .build() - new SparkConnectPlanner(SessionHolder.forTesting(spark)) - .process(plan, "clientId", "sessionId", new MockObserver()) + val executeHolder = buildExecutePlanHolder(plan) + new SparkConnectPlanner(executeHolder.sessionHolder) + .process(plan, new MockObserver(), executeHolder) assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin")) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala index f11c9b2969e52..2e199bff5e7dc 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala +++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.io.InputStream import java.nio.file.{Files, Path} +import java.util.UUID import scala.collection.JavaConverters._ import scala.collection.mutable @@ -37,6 +38,8 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { private val CHUNK_SIZE: Int = 32 * 1024 + private val sessionId = UUID.randomUUID.toString() + class DummyStreamObserver(p: Promise[AddArtifactsResponse]) extends StreamObserver[AddArtifactsResponse] { override def onNext(v: AddArtifactsResponse): Unit = p.success(v) @@ -125,7 +128,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -168,7 +171,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val requestBuilder = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBeginChunk(beginChunkedArtifact) @@ -295,7 +298,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -336,7 +339,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val singleChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBatch( proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build()) @@ -353,7 
+356,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper { val beginChunkArtifactRequest = AddArtifactsRequest .newBuilder() - .setSessionId("abc") + .setSessionId(sessionId) .setUserContext(context) .setBeginChunk(beginChunkedArtifact) .build() diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala new file mode 100644 index 0000000000000..12e67f2c59c60 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import java.util.UUID + +import scala.util.matching.Regex + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{ExecutePlanRequest, Plan, UserContext} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connect.planner.SparkConnectPlanTest +import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.util.{JsonProtocol, ManualClock} + +class ExecuteEventsManagerSuite + extends SparkFunSuite + with MockitoSugar + with SparkConnectPlanTest { + + val DEFAULT_ERROR = "error" + val DEFAULT_CLOCK = new ManualClock() + val DEFAULT_NODE_NAME = "nodeName" + val DEFAULT_TEXT = """limit { + limit: 10 +} +""" + val DEFAULT_USER_ID = "1" + val DEFAULT_USER_NAME = "userName" + val DEFAULT_SESSION_ID = UUID.randomUUID.toString + val DEFAULT_QUERY_ID = UUID.randomUUID.toString + val DEFAULT_CLIENT_TYPE = "clientType" + + test("SPARK-43923: post started") { + val events = setupEvents(ExecuteStatus.Pending) + events.postStarted() + val expectedEvent = SparkListenerConnectOperationStarted( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_USER_NAME, + DEFAULT_TEXT, + Set.empty, + Map.empty) + expectedEvent.planRequest = Some(events.executeHolder.request) + + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + .isInstanceOf[SparkListenerConnectOperationStarted]) + } + + test("SPARK-43923: post analyzed with plan") { + val events = setupEvents(ExecuteStatus.Started) + + val 
mockPlan = mock[LogicalPlan] + events.postAnalyzed(Some(mockPlan)) + val expectedEvent = SparkListenerConnectOperationAnalyzed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + expectedEvent.analyzedPlan = Some(mockPlan) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + .isInstanceOf[SparkListenerConnectOperationAnalyzed]) + } + + test("SPARK-43923: post analyzed with empty plan") { + val events = setupEvents(ExecuteStatus.Started) + events.postAnalyzed() + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationAnalyzed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis())) + } + + test("SPARK-43923: post readyForExecution") { + val events = setupEvents(ExecuteStatus.Analyzed) + events.postReadyForExecution() + val expectedEvent = SparkListenerConnectOperationReadyForExecution( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + .isInstanceOf[SparkListenerConnectOperationReadyForExecution]) + } + + test("SPARK-43923: post canceled") { + val events = setupEvents(ExecuteStatus.Started) + events.postCanceled() + val expectedEvent = SparkListenerConnectOperationCanceled( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + .isInstanceOf[SparkListenerConnectOperationCanceled]) + } + + 
test("SPARK-43923: post failed") { + val events = setupEvents(ExecuteStatus.Started) + events.postFailed(DEFAULT_ERROR) + val expectedEvent = SparkListenerConnectOperationFailed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + DEFAULT_ERROR, + Map.empty[String, String]) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + .isInstanceOf[SparkListenerConnectOperationFailed]) + } + + test("SPARK-43923: post finished") { + val events = setupEvents(ExecuteStatus.Started) + events.postFinished() + val expectedEvent = SparkListenerConnectOperationFinished( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + .isInstanceOf[SparkListenerConnectOperationFinished]) + } + + test("SPARK-44776: post finished with row number") { + val events = setupEvents(ExecuteStatus.Started) + events.postFinished(Some(100)) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectOperationFinished( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis(), + Some(100))) + } + + test("SPARK-43923: post closed") { + val events = setupEvents(ExecuteStatus.Finished) + events.postClosed() + val expectedEvent = SparkListenerConnectOperationClosed( + events.executeHolder.jobTag, + DEFAULT_QUERY_ID, + DEFAULT_CLOCK.getTimeMillis()) + verify(events.executeHolder.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post(expectedEvent) + + assert( + JsonProtocol + .sparkEventFromJson(JsonProtocol.sparkEventToJsonString(expectedEvent)) + 
.isInstanceOf[SparkListenerConnectOperationClosed]) + } + + test("SPARK-43923: Closed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Closed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postCanceled() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Finished wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Finished) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + } + + test("SPARK-43923: Failed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Finished) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + } + + test("SPARK-43923: Canceled wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Canceled) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postCanceled() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postFailed(DEFAULT_ERROR) + } + } 
+ + test("SPARK-43923: ReadyForExecution wrong order throws exception") { + val events = setupEvents(ExecuteStatus.ReadyForExecution) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postAnalyzed() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Analyzed wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Analyzed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postFinished() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Started wrong order throws exception") { + val events = setupEvents(ExecuteStatus.Started) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postReadyForExecution() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + test("SPARK-43923: Started wrong session status") { + val events = setupEvents(ExecuteStatus.Started, SessionStatus.Pending) + assertThrows[IllegalStateException] { + events.postStarted() + } + } + + def setupEvents( + executeStatus: ExecuteStatus, + sessionStatus: SessionStatus = SessionStatus.Started): ExecuteEventsManager = { + val mockSession = mock[SparkSession] + val sessionHolder = SessionHolder(DEFAULT_USER_ID, DEFAULT_SESSION_ID, mockSession) + sessionHolder.eventManager.status_(sessionStatus) + val mockContext = mock[SparkContext] + val mockListenerBus = mock[LiveListenerBus] + val mockSessionState = mock[SessionState] + val mockConf = mock[SQLConf] + when(mockSession.sessionState).thenReturn(mockSessionState) + when(mockSessionState.conf).thenReturn(mockConf) + when(mockConf.stringRedactionPattern).thenReturn(Option.empty[Regex]) + when(mockContext.listenerBus).thenReturn(mockListenerBus) + 
when(mockSession.sparkContext).thenReturn(mockContext) + + val relation = proto.Relation.newBuilder + .setLimit(proto.Limit.newBuilder.setLimit(10)) + .build() + + val executePlanRequest = ExecutePlanRequest + .newBuilder() + .setPlan(Plan.newBuilder().setRoot(relation)) + .setUserContext( + UserContext + .newBuilder() + .setUserId(DEFAULT_USER_ID) + .setUserName(DEFAULT_USER_NAME)) + .setSessionId(DEFAULT_SESSION_ID) + .setOperationId(DEFAULT_QUERY_ID) + .setClientType(DEFAULT_CLIENT_TYPE) + .build() + + val executeHolder = new ExecuteHolder(executePlanRequest, sessionHolder) + + val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK) + eventsManager.status_(executeStatus) + eventsManager + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala index 7f85966f0a7b6..33f6627ee0e10 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala @@ -184,4 +184,14 @@ class InterceptorRegistrySuite extends SharedSparkSession { assert(interceptors.head.isInstanceOf[LoggingInterceptor]) } } + + test("LocalPropertiesCleanupInterceptor initializes when configured in spark conf") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.LocalPropertiesCleanupInterceptor") { + val interceptors = SparkConnectInterceptorRegistry.createConfiguredInterceptors() + assert(interceptors.size == 1) + assert(interceptors.head.isInstanceOf[LocalPropertiesCleanupInterceptor]) + } + } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala new file mode 100644 index 0000000000000..7025146b0295b --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SessionEventsManagerSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.service + +import org.mockito.Mockito._ +import org.scalatestplus.mockito.MockitoSugar + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.planner.SparkConnectPlanTest +import org.apache.spark.util.ManualClock + +class SessionEventsManagerSuite + extends SparkFunSuite + with MockitoSugar + with SparkConnectPlanTest { + + val DEFAULT_ERROR = "error" + val DEFAULT_CLOCK = new ManualClock() + val DEFAULT_NODE_NAME = "nodeName" + val DEFAULT_TEXT = """limit { + limit: 10 +} +""" + val DEFAULT_USER_ID = "1" + val DEFAULT_USER_NAME = "userName" + val DEFAULT_SESSION_ID = "2" + val DEFAULT_QUERY_ID = "3" + val DEFAULT_CLIENT_TYPE = "clientType" + + test("SPARK-43923: post started") { + val events = setupEvents(SessionStatus.Pending) + events.postStarted() + + verify(events.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectSessionStarted( + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_CLOCK.getTimeMillis(), + Map.empty)) + } + + test("SPARK-43923: post closed") { + val events = setupEvents(SessionStatus.Started) + events.postClosed() + + verify(events.sessionHolder.session.sparkContext.listenerBus, times(1)) + .post( + SparkListenerConnectSessionClosed( + DEFAULT_SESSION_ID, + DEFAULT_USER_ID, + DEFAULT_CLOCK.getTimeMillis(), + Map.empty)) + } + + test("SPARK-43923: Started wrong order throws exception") { + val events = setupEvents(SessionStatus.Started) + assertThrows[IllegalStateException] { + events.postStarted() + } + } + + test("SPARK-43923: Closed wrong order throws exception") { + val events = setupEvents(SessionStatus.Closed) + assertThrows[IllegalStateException] { + events.postStarted() + } + assertThrows[IllegalStateException] { + events.postClosed() + } + } + + def setupEvents(status: SessionStatus): SessionEventsManager = { + val mockSession 
= mock[SparkSession] + val sessionHolder = SessionHolder(DEFAULT_USER_ID, DEFAULT_SESSION_ID, mockSession) + val mockContext = mock[SparkContext] + val mockListenerBus = mock[LiveListenerBus] + when(mockContext.listenerBus).thenReturn(mockListenerBus) + when(mockSession.sparkContext).thenReturn(mockContext) + + val eventsManager = SessionEventsManager(sessionHolder, DEFAULT_CLOCK) + eventsManager.status_(status) + eventsManager + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala similarity index 100% rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala rename to connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala index 36f284ec3ca2e..ed3da2c0f7156 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.service import java.util.UUID -import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.duration.DurationInt @@ -36,9 +35,8 @@ import org.apache.spark.util.ManualClock class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSugar { // Creates a manager with short durations for periodic check and expiry. 
- private def createSessionManager(keepAliveFn: (String, String) => Unit) = { + private def createSessionManager() = { new SparkConnectStreamingQueryCache( - keepAliveFn, clock = new ManualClock(), stoppedQueryInactivityTimeout = 1.minute, // This is on manual clock. sessionPollingPeriod = 20.milliseconds // This is real clock. Used for periodic task. @@ -48,8 +46,6 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug test("Session cache functionality with a streaming query") { // Verifies common happy path for the query cache. Runs a query through its life cycle. - val numKeepAliveCalls = new AtomicInteger(0) - val queryId = UUID.randomUUID().toString val runId = UUID.randomUUID().toString val mockSession = mock[SparkSession] @@ -59,11 +55,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug val sessionHolder = SessionHolder(userId = "test_user_1", sessionId = "test_session_1", session = mockSession) - val sessionMgr = createSessionManager(keepAliveFn = { case (userId, sessionId) => - assert(userId == sessionHolder.userId) - assert(sessionId == sessionHolder.sessionId) - numKeepAliveCalls.incrementAndGet() - }) + val sessionMgr = createSessionManager() val clock = sessionMgr.clock.asInstanceOf[ManualClock] @@ -77,11 +69,6 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery) - eventually(timeout(1.minute)) { - // Verify keep alive function is called a few times. - assert(numKeepAliveCalls.get() >= 5) - } - sessionMgr.getCachedValue(queryId, runId) match { case Some(v) => assert(v.sessionId == sessionHolder.sessionId) @@ -95,7 +82,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug // Query is returned when correct session is used assert(sessionMgr.getCachedQuery(queryId, runId, mockSession).contains(mockQuery)) - // Stop the query. 
+ // Cleanup the query and verify if stop() method has been called. when(mockQuery.isActive).thenReturn(false) val expectedExpiryTimeMs = sessionMgr.clock.getTimeMillis() + 1.minute.toMillis diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala new file mode 100644 index 0000000000000..c9c110dd1e626 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connect.ui + +import java.util.Properties + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SharedSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config.Status.{ASYNC_TRACKING_ENABLED, LIVE_ENTITY_UPDATE_PERIOD} +import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.sql.connect.config.Connect.{CONNECT_UI_SESSION_LIMIT, CONNECT_UI_STATEMENT_LIMIT} +import org.apache.spark.sql.connect.service._ +import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart +import org.apache.spark.status.ElementTrackingStore +import org.apache.spark.util.kvstore.InMemoryStore + +class SparkConnectServerListenerSuite + extends SparkFunSuite + with BeforeAndAfter + with SharedSparkContext { + + private var kvstore: ElementTrackingStore = _ + + private val jobTag = ExecuteJobTag("userId", "sessionId", "operationId") + + after { + if (kvstore != null) { + kvstore.close() + kvstore = null + } + } + + Seq(true, false).foreach { live => + test(s"listener events should store successfully (live = $live)") { + val (statusStore: SparkConnectServerAppStatusStore, listener: SparkConnectServerListener) = + createAppStatusStore(live) + listener.onOtherEvent( + SparkListenerConnectSessionStarted("sessionId", "user", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationStarted( + jobTag, + "operationId", + System.currentTimeMillis(), + "sessionId", + "userId", + "userName", + "dummy query", + Set())) + listener.onOtherEvent( + SparkListenerConnectOperationAnalyzed(jobTag, "operationId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerSQLExecutionStart( + 0, + None, + null, + null, + null, + null, + System.currentTimeMillis(), + null, + Set(jobTag))) + listener.onJobStart( + SparkListenerJobStart(0, System.currentTimeMillis(), Nil, createProperties)) + listener.onOtherEvent( + 
SparkListenerConnectOperationFinished(jobTag, "sessionId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationClosed(jobTag, "sessionId", System.currentTimeMillis())) + + if (live) { + assert(statusStore.getOnlineSessionNum === 1) + } + + listener.onOtherEvent( + SparkListenerConnectSessionClosed("sessionId", "userId", System.currentTimeMillis())) + + if (!live) { + // To update history store + kvstore.close(false) + } + assert(statusStore.getOnlineSessionNum === 0) + assert(statusStore.getExecutionList.size === 1) + + val storeExecData = statusStore.getExecutionList.head + + assert(storeExecData.jobTag === jobTag) + assert(storeExecData.sessionId === "sessionId") + assert(storeExecData.statement === "dummy query") + assert(storeExecData.jobId === Seq("0")) + assert(storeExecData.sqlExecId === Set("0")) + assert(listener.noLiveData()) + } + } + + Seq(true, false).foreach { live => + test(s"cleanup session if exceeds the threshold (live = $live)") { + val (statusStore: SparkConnectServerAppStatusStore, listener: SparkConnectServerListener) = + createAppStatusStore(live) + var time = 0 + listener.onOtherEvent( + SparkListenerConnectSessionStarted("sessionId1", "user", System.currentTimeMillis())) + time += 1 + listener.onOtherEvent( + SparkListenerConnectSessionStarted("sessionId2", "user", System.currentTimeMillis())) + time += 1 + listener.onOtherEvent(SparkListenerConnectSessionClosed("sessionId1", "userId", time)) + time += 1 + listener.onOtherEvent(SparkListenerConnectSessionClosed("sessionId2", "userId", time)) + listener.onOtherEvent( + SparkListenerConnectSessionStarted("sessionId3", "user", System.currentTimeMillis())) + time += 1 + listener.onOtherEvent(SparkListenerConnectSessionClosed("sessionId3", "userId", time)) + + if (!live) { + kvstore.close(false) + } + assert(statusStore.getOnlineSessionNum === 0) + assert(statusStore.getSessionCount === 1) + assert(statusStore.getSession("sessionId1") === None) + 
assert(listener.noLiveData()) + } + } + + test( + "update execution info when event reordering causes job and sql" + + " start to come after operation closed") { + val (statusStore: SparkConnectServerAppStatusStore, listener: SparkConnectServerListener) = + createAppStatusStore(true) + listener.onOtherEvent( + SparkListenerConnectSessionStarted("sessionId", "userId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationStarted( + jobTag, + "operationId", + System.currentTimeMillis(), + "sessionId", + "userId", + "userName", + "dummy query", + Set())) + listener.onOtherEvent( + SparkListenerConnectOperationAnalyzed(jobTag, "operationId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationFinished(jobTag, "operationId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationClosed(jobTag, "operationId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerSQLExecutionStart( + 0, + None, + null, + null, + null, + null, + System.currentTimeMillis(), + null, + Set(jobTag))) + listener.onJobStart( + SparkListenerJobStart(0, System.currentTimeMillis(), Nil, createProperties)) + listener.onOtherEvent( + SparkListenerConnectSessionClosed("sessionId", "userId", System.currentTimeMillis())) + val exec = statusStore.getExecution(ExecuteJobTag("userId", "sessionId", "operationId")) + assert(exec.isDefined) + assert(exec.get.jobId === Seq("0")) + assert(exec.get.sqlExecId === Set("0")) + assert(listener.noLiveData()) + } + + test("SPARK-31387 - listener update methods should not throw exception with unknown input") { + val (statusStore: SparkConnectServerAppStatusStore, listener: SparkConnectServerListener) = + createAppStatusStore(true) + + val unknownSession = "unknown_session" + val unknownJob = "unknown_job_tag" + listener.onOtherEvent(SparkListenerConnectSessionClosed(unknownSession, "userId", 0)) + listener.onOtherEvent( + SparkListenerConnectOperationStarted( + 
ExecuteJobTag("userId", "sessionId", "operationId"), + "operationId", + System.currentTimeMillis(), + unknownSession, + "userId", + "userName", + "dummy query", + Set())) + listener.onOtherEvent( + SparkListenerConnectOperationAnalyzed( + unknownJob, + "operationId", + System.currentTimeMillis())) + listener.onOtherEvent(SparkListenerConnectOperationCanceled(unknownJob, "userId", 0)) + listener.onOtherEvent( + SparkListenerConnectOperationFailed(unknownJob, "operationId", 0, "msg")) + listener.onOtherEvent(SparkListenerConnectOperationFinished(unknownJob, "operationId", 0)) + listener.onOtherEvent(SparkListenerConnectOperationClosed(unknownJob, "operationId", 0)) + } + + private def createProperties: Properties = { + val properties = new Properties() + properties.setProperty(SparkContext.SPARK_JOB_TAGS, jobTag) + properties + } + + private def createAppStatusStore(live: Boolean) = { + val sparkConf = new SparkConf() + sparkConf + .set(ASYNC_TRACKING_ENABLED, false) + .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + SparkEnv.get.conf + .set(CONNECT_UI_SESSION_LIMIT, 1) + .set(CONNECT_UI_STATEMENT_LIMIT, 10) + kvstore = new ElementTrackingStore(new InMemoryStore, sparkConf) + if (live) { + val listener = new SparkConnectServerListener(kvstore, sparkConf) + (new SparkConnectServerAppStatusStore(kvstore), listener) + } else { + ( + new SparkConnectServerAppStatusStore(kvstore), + new SparkConnectServerListener(kvstore, sparkConf, false)) + } + } +} diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerPageSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerPageSuite.scala new file mode 100644 index 0000000000000..d352f4e2c32f7 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerPageSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.ui + +import java.util.{Calendar, Locale} +import javax.servlet.http.HttpServletRequest + +import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.sql.connect.service._ +import org.apache.spark.status.ElementTrackingStore +import org.apache.spark.util.kvstore.InMemoryStore + +class SparkConnectServerPageSuite + extends SparkFunSuite + with BeforeAndAfter + with SharedSparkContext { + + private var kvstore: ElementTrackingStore = _ + + after { + if (kvstore != null) { + kvstore.close() + kvstore = null + } + } + + /** + * Run a dummy session and return the store + */ + private def getStatusStore: SparkConnectServerAppStatusStore = { + kvstore = new ElementTrackingStore(new InMemoryStore, new SparkConf()) + // val server = mock(classOf[SparkConnectServer], RETURNS_SMART_NULLS) + val sparkConf = new SparkConf + + val listener = new SparkConnectServerListener(kvstore, sparkConf) + val statusStore = new SparkConnectServerAppStatusStore(kvstore) + + listener.onOtherEvent( + SparkListenerConnectSessionStarted("sessionId", "userId", 
System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationStarted( + "jobTag", + "operationId", + System.currentTimeMillis(), + "sessionId", + "userId", + "userName", + "dummy query", + Set())) + listener.onOtherEvent( + SparkListenerConnectOperationAnalyzed("jobTag", "dummy plan", System.currentTimeMillis())) + listener.onOtherEvent(SparkListenerJobStart(0, System.currentTimeMillis(), Seq())) + listener.onOtherEvent( + SparkListenerConnectOperationFinished("jobTag", "operationId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectOperationClosed("jobTag", "operationId", System.currentTimeMillis())) + listener.onOtherEvent( + SparkListenerConnectSessionClosed("sessionId", "userId", System.currentTimeMillis())) + + statusStore + } + + test("Spark Connect Server page should load successfully") { + val store = getStatusStore + + val request = mock(classOf[HttpServletRequest]) + val tab = mock(classOf[SparkConnectServerTab], RETURNS_SMART_NULLS) + when(tab.startTime).thenReturn(Calendar.getInstance().getTime) + when(tab.store).thenReturn(store) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + val page = new SparkConnectServerPage(tab) + val html = page.render(request).toString().toLowerCase(Locale.ROOT) + + // session statistics and sql statistics tables should load successfully + assert(html.contains("session statistics (1)")) + assert(html.contains("request statistics (1)")) + assert(html.contains("dummy query")) + + // Pagination support + assert(html.contains("")) + + // Hiding table support + assert( + html.contains("class=\"collapse-aggregated-sessionstat" + + " collapse-table\" onclick=\"collapsetable")) + } + + test("Spark Connect Server session page should load successfully") { + val store = getStatusStore + + val request = mock(classOf[HttpServletRequest]) + when(request.getParameter("id")).thenReturn("sessionId") + val tab = mock(classOf[SparkConnectServerTab], 
RETURNS_SMART_NULLS) + when(tab.startTime).thenReturn(Calendar.getInstance().getTime) + when(tab.store).thenReturn(store) + when(tab.appName).thenReturn("testing") + when(tab.headerTabs).thenReturn(Seq.empty) + val page = new SparkConnectServerSessionPage(tab) + val html = page.render(request).toString().toLowerCase(Locale.ROOT) + + // session sql statistics table should load successfully + assert(html.contains("request statistics")) + assert(html.contains("userid")) + assert(html.contains("jobtag")) + + // Pagination support + assert(html.contains("")) + + // Hiding table support + assert( + html.contains("collapse-aggregated-sqlsessionstat collapse-table\"" + + " onclick=\"collapsetable")) + } +} diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index cc549487a8b57..d79b232ebeff1 100644 --- a/connector/docker-integration-tests/pom.xml +++ b/connector/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml @@ -46,22 +46,6 @@ - - com.spotify - docker-client - test - shaded - - - org.apache.httpcomponents - httpclient - test - - - org.apache.httpcomponents - httpcore - test - com.google.guava @@ -112,14 +96,6 @@ hadoop-minikdc test - - - org.glassfish.jersey.bundles.repackaged - jersey-guava - 2.25.1 - test - org.mariadb.jdbc mariadb-java-client @@ -167,5 +143,15 @@ mysql-connector-j test + + com.github.docker-java + docker-java + test + + + com.github.docker-java + docker-java-transport-zerodep + test + diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala index 9b518d61d252f..66e2afbb6effd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2KrbIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.security.PrivilegedExceptionAction import java.sql.Connection import javax.security.auth.login.Configuration -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} +import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS import org.scalatest.time.SpanSugar._ @@ -66,14 +66,15 @@ class DB2KrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { } override def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = { + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = { copyExecutableResource("db2_krb_setup.sh", initDbDir, replaceIp) - hostConfigBuilder.appendBinds( - HostConfig.Bind.from(initDbDir.getAbsolutePath) - .to("/var/custom").readOnly(true).build() - ) + val newBind = new Bind( + initDbDir.getAbsolutePath, + new Volume("/var/custom"), + AccessMode.ro) + hostConfigBuilder.withBinds(hostConfigBuilder.getBinds :+ newBind: _*) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index 40e8cbb6546b5..837382239514a 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -20,14 +20,18 @@ package org.apache.spark.sql.jdbc import java.net.ServerSocket import java.sql.{Connection, DriverManager} import java.util.Properties +import java.util.concurrent.TimeUnit 
import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import com.spotify.docker.client._ -import com.spotify.docker.client.DockerClient.{ListContainersParam, LogsParam} -import com.spotify.docker.client.exceptions.ImageNotFoundException -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import com.github.dockerjava.api.DockerClient +import com.github.dockerjava.api.async.{ResultCallback, ResultCallbackTemplate} +import com.github.dockerjava.api.command.CreateContainerResponse +import com.github.dockerjava.api.exception.NotFoundException +import com.github.dockerjava.api.model._ +import com.github.dockerjava.core.{DefaultDockerClientConfig, DockerClientImpl} +import com.github.dockerjava.zerodep.ZerodepDockerHttpClient import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ @@ -88,8 +92,8 @@ abstract class DatabaseOnDocker { * Optional step before container starts */ def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = {} + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = {} } abstract class DockerJDBCIntegrationSuite @@ -97,7 +101,7 @@ abstract class DockerJDBCIntegrationSuite protected val dockerIp = DockerUtils.getDockerIp() val db: DatabaseOnDocker - val connectionTimeout = timeout(5.minutes) + val connectionTimeout = timeout(10.minutes) val keepContainer = sys.props.getOrElse("spark.test.docker.keepContainer", "false").toBoolean val removePulledImage = @@ -111,56 +115,75 @@ abstract class DockerJDBCIntegrationSuite sock.close() port } - private var containerId: String = _ + private var container: CreateContainerResponse = _ private var pulled: Boolean = false protected var jdbcUrl: String = _ override def beforeAll(): Unit = runIfTestsEnabled(s"Prepare for ${this.getClass.getName}") { super.beforeAll() try { - docker = DefaultDockerClient.fromEnv.build() + val config = 
DefaultDockerClientConfig.createDefaultConfigBuilder.build + val httpClient = new ZerodepDockerHttpClient.Builder() + .dockerHost(config.getDockerHost) + .sslConfig(config.getSSLConfig) + .build() + docker = DockerClientImpl.getInstance(config, httpClient) // Check that Docker is actually up try { - docker.ping() + docker.pingCmd().exec() } catch { case NonFatal(e) => log.error("Exception while connecting to Docker. Check whether Docker is running.") throw e } - // Ensure that the Docker image is installed: try { - docker.inspectImage(db.imageName) + // Ensure that the Docker image is installed: + docker.inspectImageCmd(db.imageName).exec() } catch { - case e: ImageNotFoundException => + case e: NotFoundException => log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") - docker.pull(db.imageName) + docker.pullImageCmd(db.imageName) + .start() + .awaitCompletion(connectionTimeout.value.toSeconds, TimeUnit.SECONDS) pulled = true } - val hostConfigBuilder = HostConfig.builder() - .privileged(db.privileged) - .networkMode("bridge") - .ipcMode(if (db.usesIpc) "host" else "") - .portBindings( - Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) - // Create the database container: - val containerConfigBuilder = ContainerConfig.builder() - .image(db.imageName) - .networkDisabled(false) - .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) - .exposedPorts(s"${db.jdbcPort}/tcp") - if (db.getEntryPoint.isDefined) { - containerConfigBuilder.entrypoint(db.getEntryPoint.get) - } - if (db.getStartupProcessName.isDefined) { - containerConfigBuilder.cmd(db.getStartupProcessName.get) + + docker.pullImageCmd(db.imageName) + .start() + .awaitCompletion(connectionTimeout.value.toSeconds, TimeUnit.SECONDS) + + val hostConfig = HostConfig + .newHostConfig() + .withNetworkMode("bridge") + .withPrivileged(db.privileged) + .withPortBindings(PortBinding.parse(s"$externalPort:${db.jdbcPort}")) + + if (db.usesIpc) { + 
hostConfig.withIpcMode("host") } - db.beforeContainerStart(hostConfigBuilder, containerConfigBuilder) - containerConfigBuilder.hostConfig(hostConfigBuilder.build()) - val config = containerConfigBuilder.build() + + val containerConfig = new ContainerConfig() + + db.beforeContainerStart(hostConfig, containerConfig) + // Create the database container: - containerId = docker.createContainer(config).id + val createContainerCmd = docker.createContainerCmd(db.imageName) + .withHostConfig(hostConfig) + .withExposedPorts(ExposedPort.tcp(db.jdbcPort)) + .withEnv(db.env.map { case (k, v) => s"$k=$v" }.toList.asJava) + .withNetworkDisabled(false) + + + db.getEntryPoint.foreach(ep => createContainerCmd.withEntrypoint(ep)) + db.getStartupProcessName.foreach(n => createContainerCmd.withCmd(n)) + + container = createContainerCmd.exec() // Start the container and wait until the database can accept JDBC connections: - docker.startContainer(containerId) + docker.startContainerCmd(container.getId).exec() + eventually(connectionTimeout, interval(1.second)) { + val response = docker.inspectContainerCmd(container.getId).exec() + assert(response.getState.getRunning) + } jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) var conn: Connection = null eventually(connectionTimeout, interval(1.second)) { @@ -174,6 +197,7 @@ abstract class DockerJDBCIntegrationSuite } } catch { case NonFatal(e) => + logError(s"Failed to initialize Docker container for ${this.getClass.getName}", e) try { afterAll() } finally { @@ -206,36 +230,35 @@ abstract class DockerJDBCIntegrationSuite def dataPreparation(connection: Connection): Unit private def cleanupContainer(): Unit = { - if (docker != null && containerId != null && !keepContainer) { + if (docker != null && container != null && !keepContainer) { try { - docker.killContainer(containerId) + docker.killContainerCmd(container.getId).exec() } catch { case NonFatal(e) => - val exitContainerIds = - 
docker.listContainers(ListContainersParam.withStatusExited()).asScala.map(_.id()) - if (exitContainerIds.contains(containerId)) { - logWarning(s"Container $containerId already stopped") - } else { - logWarning(s"Could not stop container $containerId", e) - } + val response = docker.inspectContainerCmd(container.getId).exec() + logWarning(s"Container $container already stopped") + val status = Option(response).map(_.getState.getStatus).getOrElse("unknown") + logWarning(s"Could not stop container $container at stage '$status'", e) } finally { logContainerOutput() - docker.removeContainer(containerId) + docker.removeContainerCmd(container.getId).exec() if (removePulledImage && pulled) { - docker.removeImage(db.imageName) + docker.removeImageCmd(db.imageName).exec() } } } } private def logContainerOutput(): Unit = { - val logStream = docker.logs(containerId, LogsParam.stdout(), LogsParam.stderr()) - try { - logInfo("\n\n===== CONTAINER LOGS FOR container Id: " + containerId + " =====") - logInfo(logStream.readFully()) - logInfo("\n\n===== END OF CONTAINER LOGS FOR container Id: " + containerId + " =====") - } finally { - logStream.close() - } + logInfo("\n\n===== CONTAINER LOGS FOR container Id: " + container + " =====") + docker.logContainerCmd(container.getId) + .withStdOut(true) + .withStdErr(true) + .withFollowStream(true) + .withSince(0).exec( + new ResultCallbackTemplate[ResultCallback[Frame], Frame] { + override def onNext(f: Frame): Unit = logInfo(f.toString) + }) + logInfo("\n\n===== END OF CONTAINER LOGS FOR container Id: " + container + " =====") } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala index 873d5ad1ee43b..49c9e3dba0d7f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MariaDBKrbIntegrationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc import javax.security.auth.login.Configuration -import com.spotify.docker.client.messages.{ContainerConfig, HostConfig} +import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider import org.apache.spark.tags.DockerTest @@ -52,17 +52,17 @@ class MariaDBKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { Some("/docker-entrypoint/mariadb_docker_entrypoint.sh") override def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = { + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = { copyExecutableResource("mariadb_docker_entrypoint.sh", entryPointDir, replaceIp) copyExecutableResource("mariadb_krb_setup.sh", initDbDir, replaceIp) - hostConfigBuilder.appendBinds( - HostConfig.Bind.from(entryPointDir.getAbsolutePath) - .to("/docker-entrypoint").readOnly(true).build(), - HostConfig.Bind.from(initDbDir.getAbsolutePath) - .to("/docker-entrypoint-initdb.d").readOnly(true).build() - ) + val binds = + Seq(entryPointDir -> "/docker-entrypoint", initDbDir -> "/docker-entrypoint-initdb.d") + .map { case (from, to) => + new Bind(from.getAbsolutePath, new Volume(to), AccessMode.ro) + } + hostConfigBuilder.withBinds(hostConfigBuilder.getBinds ++ binds: _*) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala new file mode 100644 index 0000000000000..b351b2ad1ec7d --- /dev/null +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala @@ -0,0 +1,32 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +class MsSQLServerDatabaseOnDocker extends DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", + "mcr.microsoft.com/mssql/server:2022-CU12-GDR1-ubuntu-22.04") + override val env = Map( + "SA_PASSWORD" -> "Sapass123", + "ACCEPT_EULA" -> "Y" + ) + override val usesIpc = false + override val jdbcPort: Int = 1433 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index f2614f46bc3f6..443000050a476 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -38,19 +38,7 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { - override val 
db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") - override val env = Map( - "SA_PASSWORD" -> "Sapass123", - "ACCEPT_EULA" -> "Y" - ) - override val usesIpc = false - override val jdbcPort: Int = 1433 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - } + override val db = new MsSQLServerDatabaseOnDocker override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE TABLE tbl (x INT, y VARCHAR (50))").executeUpdate() diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index dc3acb66ff1f4..cefbe41b64bd3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -43,7 +43,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val usesIpc = false override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&disableMariaDbDriver" } override def dataPreparation(conn: Connection): Unit = { @@ -56,10 +56,14 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " - + "dbl DOUBLE)").executeUpdate() + + "dbl DOUBLE, tiny TINYINT, u_tiny TINYINT UNSIGNED)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + "17, 77777, 123456789, 
123456789012345, 123456789012345.123456789012345, " - + "42.75, 1.0000000000000002)").executeUpdate() + + "42.75, 1.0000000000000002, -128, 255)").executeUpdate() + + conn.prepareStatement("INSERT INTO numbers VALUES (null, null, " + + "null, null, null, null, null, null, null, null, null)").executeUpdate() conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + "yr YEAR)").executeUpdate() @@ -74,6 +78,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { "'jumps', 'over', 'the', 'lazy', 'dog', '{\"status\": \"merrily\"}')").executeUpdate() } + def testConnection(): Unit = { + val conn = getConnection() + try { + assert(conn.getClass.getName === "com.mysql.cj.jdbc.ConnectionImpl") + } finally { + conn.close() + } + } + + test("SPARK-47537: ensure use the right jdbc driver") { + testConnection() + } + test("Basic test") { val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) val rows = df.collect() @@ -87,9 +104,9 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { test("Numeric types") { val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) val rows = df.collect() - assert(rows.length == 1) + assert(rows.length == 2) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 9) + assert(types.length == 11) assert(types(0).equals("class java.lang.Boolean")) assert(types(1).equals("class java.lang.Long")) assert(types(2).equals("class java.lang.Integer")) @@ -99,6 +116,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(types(6).equals("class java.math.BigDecimal")) assert(types(7).equals("class java.lang.Double")) assert(types(8).equals("class java.lang.Double")) + assert(types(9).equals("class java.lang.Byte")) + assert(types(10).equals("class java.lang.Short")) assert(rows(0).getBoolean(0) == false) assert(rows(0).getLong(1) == 0x225) assert(rows(0).getInt(2) == 17) @@ -109,6 +128,8 @@ class MySQLIntegrationSuite extends 
DockerJDBCIntegrationSuite { assert(rows(0).getAs[BigDecimal](6).equals(bd)) assert(rows(0).getDouble(7) == 42.75) assert(rows(0).getDouble(8) == 1.0000000000000002) + assert(rows(0).getByte(9) == 0x80.toByte) + assert(rows(0).getShort(10) == 0xff.toShort) } test("Date types") { @@ -194,4 +215,50 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { """.stripMargin.replaceAll("\n", " ")) assert(sql("select x, y from queryOption").collect.toSet == expectedResult) } + + test("SPARK-47666: Check nulls for result set getters") { + val nulls = spark.read.jdbc(jdbcUrl, "numbers", new Properties).tail(1).head + assert(nulls === Row(null, null, null, null, null, null, null, null, null, null, null)) + } + + test("SPARK-44638: Char/Varchar in Custom Schema") { + val df = spark.read.option("url", jdbcUrl) + .option("query", "SELECT c, d from strings") + .option("customSchema", "c CHAR(10), d VARCHAR(10)") + .format("jdbc") + .load() + assert(df.head === Row("brown ", "fox")) + } +} + +/** + * To run this test suite for a specific version (e.g., mysql:8.3.0): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.3.0 + * ./build/sbt -Pdocker-integration-tests + * "docker-integration-tests/testOnly *MySQLOverMariaConnectorIntegrationSuite" + * }}} + */ +@DockerTest +class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite { + + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + } + + override def testConnection(): Unit = { + val conn = getConnection() + try { + assert(conn.getClass.getName === "org.mariadb.jdbc.MariaDbConnection") + } finally { + conn.close() + } + } } diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 407da028b7eb8..70afad781ca25 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -49,7 +49,7 @@ import org.apache.spark.tags.DockerTest * 4. Start docker: sudo service docker start * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME * 5. Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests - * "testOnly org.apache.spark.sql.jdbc.OracleIntegrationSuite" + * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.OracleIntegrationSuite" * * A sequence of commands to build the Oracle XE database container image: * $ git clone https://github.com/oracle/docker-images.git @@ -173,8 +173,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark } - // SPARK-43049: Use CLOB instead of VARCHAR(255) for StringType for Oracle jdbc-am"" - test("SPARK-12941: String datatypes to be mapped to CLOB in Oracle") { + test("SPARK-12941: String datatypes to be mapped to VARCHAR(255) in Oracle") { // create a sample dataframe with string type val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x") // write the dataframe to the oracle table tbl @@ -521,4 +520,19 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSpark assert(types(0).equals("class java.lang.String")) assert(!rows(0).getString(0).isEmpty) } + + test("SPARK-44885: query row with ROWID type containing NULL value") { + val rows = spark.read.format("jdbc") + .option("url", jdbcUrl) + // Rename column to `row_id` to prevent the following SQL error: + // ORA-01446: cannot select ROWID from view with DISTINCT, GROUP BY, etc. 
+ // See also https://stackoverflow.com/a/42632686/13300239 + .option("query", "SELECT rowid as row_id from datetime where d = {d '1991-11-09'}\n" + + "union all\n" + + "select null from dual") + .load() + .collect() + assert(rows(0).getString(0).nonEmpty) + assert(rows(1).getString(0) == null) + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index c539452bb9ae0..23fbf39db3be0 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -26,13 +26,13 @@ import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} +import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} @@ -40,7 +40,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ 
-148,6 +148,11 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { |('2013-04-05 18:01:02.123'), |('2013-04-05 18:01:02.123456')""".stripMargin).executeUpdate() + conn.prepareStatement("CREATE TABLE infinity_timestamp" + + "(id SERIAL PRIMARY KEY, timestamp_column TIMESTAMP);").executeUpdate() + conn.prepareStatement("INSERT INTO infinity_timestamp (timestamp_column)" + + " VALUES ('infinity'), ('-infinity');").executeUpdate() + conn.prepareStatement("CREATE DOMAIN not_null_text AS TEXT DEFAULT ''").executeUpdate() conn.prepareStatement("create table custom_type(type_array not_null_text[]," + "type not_null_text)").executeUpdate() @@ -432,4 +437,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(row(0).getSeq[String](0) == Seq("1", "fds", "fdsa")) assert(row(0).getString(1) == "fdasfasdf") } + + test("SPARK-44280: infinity timestamp test") { + val df = sqlContext.read.jdbc(jdbcUrl, "infinity_timestamp", new Properties) + val row = df.collect() + + assert(row.length == 2) + val infinity = row(0).getAs[Timestamp]("timestamp_column") + val negativeInfinity = row(1).getAs[Timestamp]("timestamp_column") + val minTimeStamp = -62135596800000L + val maxTimestamp = 253402300799999L + assert(infinity.getTime == maxTimestamp) + assert(negativeInfinity.getTime == minTimeStamp) + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index 4debe24754de3..1dcf101b394a4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql.jdbc import javax.security.auth.login.Configuration -import 
com.spotify.docker.client.messages.{ContainerConfig, HostConfig} +import com.github.dockerjava.api.model.{AccessMode, Bind, ContainerConfig, HostConfig, Volume} import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnectionProvider import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresKrbIntegrationSuite" * }}} */ @@ -37,7 +37,7 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "postgres.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ -48,14 +48,14 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { s"jdbc:postgresql://$ip:$port/postgres?user=$principal&gsslib=gssapi" override def beforeContainerStart( - hostConfigBuilder: HostConfig.Builder, - containerConfigBuilder: ContainerConfig.Builder): Unit = { + hostConfigBuilder: HostConfig, + containerConfigBuilder: ContainerConfig): Unit = { copyExecutableResource("postgres_krb_setup.sh", initDbDir, replaceIp) - - hostConfigBuilder.appendBinds( - HostConfig.Bind.from(initDbDir.getAbsolutePath) - .to("/docker-entrypoint-initdb.d").readOnly(true).build() - ) + val newBind = new Bind( + initDbDir.getAbsolutePath, + new Volume("/docker-entrypoint-initdb.d"), + AccessMode.ro) + hostConfigBuilder.withBinds(hostConfigBuilder.getBinds :+ newBind: _*) } } diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 661b1277e9f03..5bcc8afefb1dd 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -80,16 +80,24 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { connection.prepareStatement( "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)") .executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE DOUBLE") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", DoubleType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", DoubleType, true, defaultMetadata(DoubleType)) assert(t.schema === expectedSchema) // Update column type from DOUBLE to STRING val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE VARCHAR(10)" @@ -112,7 +120,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { sql(s"CREATE TABLE $tbl (ID INT)" + s" TBLPROPERTIES('CCSID'='UNICODE')") val t = spark.table(tbl) - val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("ID", IntegerType, true, 
defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala index 72edfc9f1bf1c..60345257f2dc4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala @@ -38,6 +38,25 @@ abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite { .executeUpdate() connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)") .executeUpdate() + + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_quote''_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_quote_not_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_percent%_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_percent_not_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_underscore_present')") + .executeUpdate() + connection.prepareStatement("INSERT INTO pattern_testing_table " + + "VALUES ('special_character_underscorenot_present')") + .executeUpdate() } def tablePreparation(connection: Connection): Unit diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index fc93f5cba4c03..78fdbe7158bb7 100644 --- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -22,9 +22,13 @@ import java.sql.Connection import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.DatabaseOnDocker +import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -39,6 +43,17 @@ import org.apache.spark.tags.DockerTest @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + def getExternalEngineQuery(executedPlan: SparkPlan): String = { + getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery + } + + def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = { + val queryNode = executedPlan.collect { case r: RowDataSourceScanExec => + r + }.head + queryNode.rdd + } + override def excluded: Seq[String] = Seq( "simple scan with OFFSET", "simple scan with LIMIT and OFFSET", @@ -60,19 +75,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD "scan with aggregate push-down: REGR_SXY without DISTINCT") override val catalogName: String = "mssql" - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") - override val env = Map( - "SA_PASSWORD" -> "Sapass123", - "ACCEPT_EULA" -> "Y" - ) - 
override val usesIpc = false - override val jdbcPort: Int = 1433 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - } + override val db = new MsSQLServerDatabaseOnDocker override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName) @@ -86,6 +89,12 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD connection.prepareStatement( "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") .executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def notSupportsTableComment: Boolean = true @@ -93,11 +102,13 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update column type from STRING to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -125,4 +136,84 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD }, errorClass = "_LEGACY_ERROR_TEMP_2271") } + + test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { + val df1 = sql("SELECT name FROM " + + s"$catalogName.employee WHERE ((name LIKE 'am%') = (name LIKE '%y'))") + 
assert(df1.collect().length == 4) + + val df2 = sql("SELECT name FROM " + + s"$catalogName.employee " + + "WHERE ((name NOT LIKE 'am%') = (name NOT LIKE '%y'))") + assert(df2.collect().length == 4) + + val df3 = sql("SELECT name FROM " + + s"$catalogName.employee " + + "WHERE (dept > 1 AND ((name LIKE 'am%') = (name LIKE '%y')))") + assert(df3.collect().length == 3) + } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END + | ELSE (name = 'Sauron') END + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" 
= 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """ + ) + // scalastyle:on + df.collect() + } + + test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") { + val df = sql( + s"""|SELECT * FROM $catalogName.employee + |WHERE CASE WHEN (name = 'Legolas') THEN + | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END + | ELSE 'Sauron' END = name + |""".stripMargin + ) + + // scalastyle:off + assert(getExternalEngineQuery(df.queryExecution.executedPlan) == + """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """ + ) + // scalastyle:on + df.collect() + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index b0a2d37e465ac..de0ae5d59716b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import scala.collection.JavaConverters._ -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, MsSQLServerDatabaseOnDocker} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest @@ -35,20 +35,7 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { - override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - 
"mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") - override val env = Map( - "SA_PASSWORD" -> "Sapass123", - "ACCEPT_EULA" -> "Y" - ) - override val usesIpc = false - override val jdbcPort: Int = 1433 - - override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" - } - + override val db = new MsSQLServerDatabaseOnDocker val map = new CaseInsensitiveStringMap( Map("url" -> db.getJdbcUrl(dockerIp, externalPort), "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver").asJava) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 5e340f135c85d..faf9f14b260d4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -68,8 +68,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:mysql://$ip:$port/" + - s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + + "&useSSL=false&disableMariaDbDriver" } override def sparkConf: SparkConf = super.sparkConf @@ -88,16 +88,24 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest connection.prepareStatement( "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," + " bonus DOUBLE)").executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col LONGTEXT + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE 
TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update column type from STRING to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -145,7 +153,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest sql(s"CREATE TABLE $tbl (ID INT)" + s" TBLPROPERTIES('ENGINE'='InnoDB', 'DEFAULT CHARACTER SET'='utf8')") val t = spark.table(tbl) - val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) } @@ -164,3 +173,32 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest } } } + +/** + * To run this test suite for a specific version (e.g., mysql:8.3.0): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.3.0 + * ./build/sbt -Pdocker-integration-tests + * "docker-integration-tests/testOnly *MySQLOverMariaConnectorIntegrationSuite" + * }}} + */ +@DockerTest +class MySQLOverMariaConnectorIntegrationSuite extends MySQLIntegrationSuite { + override def defaultMetadata(dataType: DataType = StringType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", true) + .build() + + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc 
= false + override val jdbcPort: Int = 3306 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + + "&useSSL=false" + } +} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index d58146fecdf42..8b889f8509f56 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -45,8 +45,8 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac override val jdbcPort: Int = 3306 override def getJdbcUrl(ip: String, port: Int): String = - s"jdbc:mysql://$ip:$port/" + - s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true" + + "&useSSL=false&disableMariaDbDriver" } val map = new CaseInsensitiveStringMap( diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 5124199328ce2..002091b6a0d80 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -22,8 +22,9 @@ import java.util.Locale import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkRuntimeException} import org.apache.spark.sql.AnalysisException +import 
org.apache.spark.sql.catalyst.util.CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ @@ -86,6 +87,12 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes s"jdbc:oracle:thin:system/$oracle_password@//$ip:$port/xe" } + override def defaultMetadata(dataType: DataType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType]) + .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, "varchar(255)") + .build() + override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort)) @@ -99,16 +106,24 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes connection.prepareStatement( "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + " bonus BINARY_DOUBLE)").executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", DecimalType(10, 0), true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", DecimalType(10, 0), true, super.defaultMetadata(DecimalType(10, 0))) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE LONG") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", DecimalType(19, 0), true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", DecimalType(19, 0), true, super.defaultMetadata(DecimalType(19, 0))) assert(t.schema === expectedSchema) // 
Update column type from LONG to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -129,12 +144,17 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) - test("SPARK-43049: Use CLOB instead of VARCHAR(255) for StringType for Oracle JDBC") { + test("SPARK-46478: Revert SPARK-43049 to use varchar(255) for string") { val tableName = catalogName + ".t1" withTable(tableName) { sql(s"CREATE TABLE $tableName(c1 string)") - sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") - assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 256) + checkError( + exception = intercept[SparkRuntimeException] { + sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") + }, + errorClass = "EXCEED_LIMIT_LENGTH", + parameters = Map("limit" -> "255") + ) } } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 85e85f8bf3803..b0edac3fcdd1f 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -22,15 +22,16 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific 
version (e.g., postgres:16.2) * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresIntegrationSuite" * }}} */ @@ -38,7 +39,7 @@ import org.apache.spark.tags.DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) @@ -59,16 +60,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT connection.prepareStatement( "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + " bonus double precision)").executeUpdate() + connection.prepareStatement( + s"""CREATE TABLE pattern_testing_table ( + |pattern_testing_col VARCHAR(50) + |) + """.stripMargin + ).executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") var t = spark.table(tbl) - var expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $tbl ALTER COLUMN id TYPE STRING") t = spark.table(tbl) - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update column type from STRING to INTEGER val sql1 = s"ALTER TABLE $tbl ALTER COLUMN id TYPE INTEGER" @@ -91,7 +100,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite 
with V2JDBCT sql(s"CREATE TABLE $tbl (ID INT)" + s" TBLPROPERTIES('TABLESPACE'='pg_default')") val t = spark.table(tbl) - val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("ID", IntegerType, true, defaultMetadata(IntegerType)) assert(t.schema === expectedSchema) } @@ -114,4 +124,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT ) } } + + test("SPARK-49695: Postgres fix xor push-down") { + val df = spark.sql(s"select dept, name from $catalogName.employee where dept ^ 6 = 0") + val rows = df.collect() + assert(!df.queryExecution.sparkPlan.exists(_.isInstanceOf[FilterExec])) + assert(rows.length == 1) + assert(rows(0).getInt(0) === 6) + assert(rows(0).getString(1) === "jen") + } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index cf7266e67e325..b725fc8967514 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -26,16 +26,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:15.1): + * To run this test suite for a specific version (e.g., postgres:16.2): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.2 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = 
sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.2-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index b5f5b0e5f20bd..a0f337912c859 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -49,18 +49,21 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def notSupportsTableComment: Boolean = false - val defaultMetadata = new MetadataBuilder().putLong("scale", 0).build() + def defaultMetadata(dataType: DataType = StringType): Metadata = new MetadataBuilder() + .putLong("scale", 0) + .putBoolean("isSigned", dataType.isInstanceOf[NumericType]) + .build() def testUpdateColumnNullability(tbl: String): Unit = { sql(s"CREATE TABLE $catalogName.alt_table (ID STRING NOT NULL)") var t = spark.table(s"$catalogName.alt_table") // nullable is true in the expectedSchema because Spark always sets nullable to true // regardless of the JDBC metadata https://github.com/apache/spark/pull/18445 - var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $catalogName.alt_table ALTER COLUMN ID DROP NOT NULL") t = spark.table(s"$catalogName.alt_table") - expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Update nullability of not existing column val msg = 
intercept[AnalysisException] { @@ -72,8 +75,9 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def testRenameColumn(tbl: String): Unit = { sql(s"ALTER TABLE $tbl RENAME COLUMN ID TO RENAMED") val t = spark.table(s"$tbl") - val expectedSchema = new StructType().add("RENAMED", StringType, true, defaultMetadata) - .add("ID1", StringType, true, defaultMetadata).add("ID2", StringType, true, defaultMetadata) + val expectedSchema = new StructType().add("RENAMED", StringType, true, defaultMetadata()) + .add("ID1", StringType, true, defaultMetadata()) + .add("ID2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) } @@ -83,16 +87,19 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu withTable(s"$catalogName.alt_table") { sql(s"CREATE TABLE $catalogName.alt_table (ID STRING)") var t = spark.table(s"$catalogName.alt_table") - var expectedSchema = new StructType().add("ID", StringType, true, defaultMetadata) + var expectedSchema = new StructType() + .add("ID", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C1 STRING, C2 STRING)") t = spark.table(s"$catalogName.alt_table") - expectedSchema = expectedSchema.add("C1", StringType, true, defaultMetadata) - .add("C2", StringType, true, defaultMetadata) + expectedSchema = expectedSchema + .add("C1", StringType, true, defaultMetadata()) + .add("C2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 STRING)") t = spark.table(s"$catalogName.alt_table") - expectedSchema = expectedSchema.add("C3", StringType, true, defaultMetadata) + expectedSchema = expectedSchema + .add("C3", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Add already existing column checkError( @@ -125,7 +132,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu 
sql(s"ALTER TABLE $catalogName.alt_table DROP COLUMN C1") sql(s"ALTER TABLE $catalogName.alt_table DROP COLUMN c3") val t = spark.table(s"$catalogName.alt_table") - val expectedSchema = new StructType().add("C2", StringType, true, defaultMetadata) + val expectedSchema = new StructType() + .add("C2", StringType, true, defaultMetadata()) assert(t.schema === expectedSchema) // Drop not existing column val msg = intercept[AnalysisException] { @@ -350,6 +358,235 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(scan.schema.names.sameElements(Seq(col))) } + test("SPARK-48172: Test CONTAINS") { + val df1 = spark.sql( + s""" + |SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'quote\\'')""".stripMargin) + df1.explain("formatted") + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'percent%')""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'underscore_')""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE contains(pattern_testing_col, 'character') + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + } + + test("SPARK-48172: Test ENDSWITH") { + val df1 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'quote\\'_present')""".stripMargin) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'percent%_present')""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'underscore_present')""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE endswith(pattern_testing_col, 'present') + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + } + + test("SPARK-48172: Test STARTSWITH") { + val df1 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character_quote\\'')""".stripMargin) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character_percent%')""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character_underscore_')""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE startswith(pattern_testing_col, 'special_character') + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + } + + test("SPARK-48172: Test LIKE") { + // this one should map to contains + val df1 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%quote\\'%'""".stripMargin) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "special_character_quote'_present") + + val df2 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%percent\\%%'""".stripMargin) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "special_character_percent%_present") + + val df3 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%underscore\\_%'""".stripMargin) + val rows3 = df3.collect() + assert(rows3.length === 1) + assert(rows3(0).getString(0) === "special_character_underscore_present") + + val df4 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%character%' + |ORDER BY pattern_testing_col""".stripMargin) + val rows4 = df4.collect() + assert(rows4.length === 6) + assert(rows4(0).getString(0) === "special_character_percent%_present") + assert(rows4(1).getString(0) === "special_character_percent_not_present") + assert(rows4(2).getString(0) === "special_character_quote'_present") + assert(rows4(3).getString(0) === "special_character_quote_not_present") + assert(rows4(4).getString(0) === "special_character_underscore_present") + assert(rows4(5).getString(0) === "special_character_underscorenot_present") + + // map to startsWith + // this one should map to contains + val df5 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character_quote\\'%'""".stripMargin) + val rows5 = df5.collect() + assert(rows5.length === 1) + assert(rows5(0).getString(0) === "special_character_quote'_present") + + val df6 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character_percent\\%%'""".stripMargin) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "special_character_percent%_present") + + val df7 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character_underscore\\_%'""".stripMargin) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "special_character_underscore_present") + + val df8 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE 'special_character%' + |ORDER BY pattern_testing_col""".stripMargin) + val rows8 = df8.collect() + assert(rows8.length === 6) + assert(rows8(0).getString(0) === "special_character_percent%_present") + assert(rows8(1).getString(0) === "special_character_percent_not_present") + assert(rows8(2).getString(0) === "special_character_quote'_present") + assert(rows8(3).getString(0) === "special_character_quote_not_present") + assert(rows8(4).getString(0) === "special_character_underscore_present") + assert(rows8(5).getString(0) === "special_character_underscorenot_present") + // map to endsWith + // this one should map to contains + val df9 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%quote\\'_present'""".stripMargin) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "special_character_quote'_present") + + val df10 = spark.sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%percent\\%_present'""".stripMargin) + val rows10 = df10.collect() + assert(rows10.length === 1) + assert(rows10(0).getString(0) === "special_character_percent%_present") + + val df11 = spark. + sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%underscore\\_present'""".stripMargin) + val rows11 = df11.collect() + assert(rows11.length === 1) + assert(rows11(0).getString(0) === "special_character_underscore_present") + + val df12 = spark. 
+ sql( + s"""SELECT * FROM $catalogAndNamespace.${caseConvert("pattern_testing_table")} + |WHERE pattern_testing_col LIKE '%present' ORDER BY pattern_testing_col""".stripMargin) + val rows12 = df12.collect() + assert(rows12.length === 6) + assert(rows12(0).getString(0) === "special_character_percent%_present") + assert(rows12(1).getString(0) === "special_character_percent_not_present") + assert(rows12(2).getString(0) === "special_character_quote'_present") + assert(rows12(3).getString(0) === "special_character_quote_not_present") + assert(rows12(4).getString(0) === "special_character_underscore_present") + assert(rows12(5).getString(0) === "special_character_underscorenot_present") + } + test("SPARK-37038: Test TABLESAMPLE") { if (supportsTableSample) { withTable(s"$catalogName.new_table") { diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml index 340974cc789bd..f9f6dfc20524b 100644 --- a/connector/kafka-0-10-assembly/pom.xml +++ b/connector/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index fdd1196cd446a..8a2195bde7e0d 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index 3256130c50f3b..b081b67a56ff8 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index 706eb2dd2c399..b59e6401191be 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 
3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml index cd5c0393f6f84..b44c6f1f259d7 100644 --- a/connector/kinesis-asl-assembly/pom.xml +++ b/connector/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index c70a073e73407..608671f47a0c3 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 7d12af3256f1f..d388b480e065d 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -275,7 +275,7 @@ private[streaming] object StreamingExamples extends Logging { // We first log something to initialize Spark's default logging, then we override the // logging level. logInfo("Setting log level to [WARN] for streaming example." 
+ - " To override add a custom log4j.properties to the classpath.") + " To override add a custom log4j2.properties to the classpath.") Configurator.setRootLevel(Level.WARN) } } diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 3d6bbea7d41c5..91df2118e6092 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala index 5c4a5ff068968..d2417674837be 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -22,12 +22,12 @@ import scala.util.control.NonFatal import com.google.protobuf.DynamicMessage import com.google.protobuf.TypeRegistry -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.protobuf.utils.{ProtobufOptions, ProtobufUtils, SchemaConverters} -import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType} private[sql] case class ProtobufDataToCatalyst( child: Expression, @@ -39,16 +39,8 @@ private[sql] case class ProtobufDataToCatalyst( override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) - override lazy val dataType: 
DataType = { - val dt = SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType - parseMode match { - // With PermissiveMode, the output Catalyst row might contain columns of null values for - // corrupt records, even if some of the columns are not nullable in the user-provided schema. - // Therefore we force the schema to be all nullable here. - case PermissiveMode => dt.asNullable - case _ => dt - } - } + override lazy val dataType: DataType = + SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType override def nullable: Boolean = true @@ -87,22 +79,9 @@ private[sql] case class ProtobufDataToCatalyst( mode } - @transient private lazy val nullResultRow: Any = dataType match { - case st: StructType => - val resultRow = new SpecificInternalRow(st.map(_.dataType)) - for (i <- 0 until st.length) { - resultRow.setNullAt(i) - } - resultRow - - case _ => - null - } - private def handleException(e: Throwable): Any = { parseMode match { - case PermissiveMode => - nullResultRow + case PermissiveMode => null case FailFastMode => throw QueryExecutionErrors.malformedProtobufMessageDetectedInMessageParsingError(e) case _ => diff --git a/connector/protobuf/src/test/resources/log4j2.properties b/connector/protobuf/src/test/resources/log4j2.properties index ab02104c69697..550fd261b6fb5 100644 --- a/connector/protobuf/src/test/resources/log4j2.properties +++ b/connector/protobuf/src/test/resources/log4j2.properties @@ -32,7 +32,7 @@ appender.console.type = Console appender.console.name = console appender.console.target = SYSTEM_ERR appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %t: %m%n%ex +appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n # Ignore messages below warning level from Jetty, because it's a bit verbose logger.jetty.name = org.sparkproject.jetty diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index d3e63a11a66bf..62d0efd7459b2 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -79,20 +79,9 @@ class ProtobufCatalystDataConversionSuite .eval() } - val expected = { - val expectedSchema = ProtobufUtils.buildDescriptor(descBytes, badSchema) - SchemaConverters.toSqlType(expectedSchema).dataType match { - case st: StructType => - Row.fromSeq((0 until st.length).map { _ => - null - }) - case _ => null - } - } - checkEvaluation( ProtobufDataToCatalyst(binary, badSchema, Some(descBytes), Map("mode" -> "PERMISSIVE")), - expected) + expected = null) } protected def prepareExpectedResult(expected: Any): Any = expected match { @@ -137,7 +126,8 @@ class ProtobufCatalystDataConversionSuite while ( data != null && (data.get(0) == defaultValue || - (dt == BinaryType && + (dt.fields(0).dataType == BinaryType && + data.get(0) != null && data.get(0).asInstanceOf[Array[Byte]].isEmpty))) data = generator().asInstanceOf[Row] diff --git a/connector/spark-ganglia-lgpl/pom.xml b/connector/spark-ganglia-lgpl/pom.xml index c0dcde1355849..572766941ed93 100644 --- a/connector/spark-ganglia-lgpl/pom.xml +++ b/connector/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 6519b46d96e31..15d318844ce38 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.5.0-SNAPSHOT + 3.5.5 ../pom.xml @@ -243,6 +243,10 @@ org.scala-lang.modules scala-xml_${scala.binary.version}
+ + org.scala-lang.modules + scala-collection-compat_${scala.binary.version} + org.scala-lang scala-library @@ -481,6 +485,14 @@ commons-logging commons-logging + + org.codehaus.jackson + jackson-mapper-asl + + + org.codehaus.jackson + jackson-core-asl + com.fasterxml.jackson.core jackson-core @@ -513,69 +525,7 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - - ${project.basedir}/src/main/resources - - - - ${project.build.directory}/extra-resources - true - - - - org.apache.maven.plugins - maven-antrun-plugin - - - choose-shell-and-script - validate - - run - - - true - - - - - - - - - - - - Shell to use for generating spark-version-info.properties file = - ${shell} - - Script to use for generating spark-version-info.properties file = - ${spark-build-info-script} - - - - - - generate-spark-build-info - generate-resources - - - - - - - - - - - - run - - - - org.apache.maven.plugins maven-dependency-plugin diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 91910b99ac999..2a580e341dc33 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -14,6 +14,7 @@ package org.apache.spark.io; import org.apache.spark.storage.StorageUtils; +import org.apache.spark.unsafe.Platform; import java.io.File; import java.io.IOException; @@ -39,7 +40,7 @@ public final class NioBufferedFileInputStream extends InputStream { private final FileChannel fileChannel; public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException { - byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes); + byteBuffer = Platform.allocateDirectBuffer(bufferSizeInBytes); fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); byteBuffer.flip(); } diff --git 
a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 83352611770fd..08c080f5a5a1d 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -18,6 +18,7 @@ package org.apache.spark.memory; import javax.annotation.concurrent.GuardedBy; +import java.io.InterruptedIOException; import java.io.IOException; import java.nio.channels.ClosedByInterruptException; import java.util.Arrays; @@ -242,7 +243,7 @@ private long trySpillAndAcquire( cList.remove(idx); return 0; } - } catch (ClosedByInterruptException e) { + } catch (ClosedByInterruptException | InterruptedIOException e) { // This called by user to kill a task (e.g: speculative task). logger.error("error while calling spill() on " + consumerToSpill, e); throw new RuntimeException(e.getMessage()); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index a82f691d085d4..b097089282ce3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -150,11 +150,21 @@ public long[] getChecksums() { * Sorts the in-memory records and writes the sorted records to an on-disk file. * This method does not free the sort data structures. * - * @param isLastFile if true, this indicates that we're writing the final output file and that the - * bytes written should be counted towards shuffle spill metrics rather than - * shuffle write metrics. + * @param isFinalFile if true, this indicates that we're writing the final output file and that + * the bytes written should be counted towards shuffle write metrics rather + * than shuffle spill metrics. 
*/ - private void writeSortedFile(boolean isLastFile) { + private void writeSortedFile(boolean isFinalFile) { + // Only emit the log if this is an actual spilling. + if (!isFinalFile) { + logger.info( + "Task {} on Thread {} spilling sort data of {} to disk ({} {} so far)", + taskContext.taskAttemptId(), + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() != 1 ? " times" : " time"); + } // This call performs the actual sort. final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = @@ -167,13 +177,14 @@ private void writeSortedFile(boolean isLastFile) { final ShuffleWriteMetricsReporter writeMetricsToUse; - if (isLastFile) { + if (isFinalFile) { // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. writeMetricsToUse = writeMetrics; } else { // We're spilling, so bytes written should be counted towards spill rather than write. // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count // them towards shuffle bytes written. + // The actual shuffle bytes written will be counted when we merge the spill files. writeMetricsToUse = new ShuffleWriteMetrics(); } @@ -246,7 +257,7 @@ private void writeSortedFile(boolean isLastFile) { spills.add(spillInfo); } - if (!isLastFile) { // i.e. this is a spill file + if (!isFinalFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter // relies on its `recordWritten()` method being called in order to trigger periodic updates to @@ -281,12 +292,6 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { return 0L; } - logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", - Thread.currentThread().getId(), - Utils.bytesToString(getMemoryUsage()), - spills.size(), - spills.size() > 1 ? 
" times" : " time"); - writeSortedFile(false); final long spillSize = freeMemory(); inMemSorter.reset(); @@ -440,8 +445,9 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p */ public SpillInfo[] closeAndGetSpills() throws IOException { if (inMemSorter != null) { - // Do not count the final file towards the spill count. - writeSortedFile(true); + // Here we are spilling the remaining data in the buffer. If there is no spill before, this + // final spill file will be the final shuffle output file. + writeSortedFile(/* isFinalFile = */spills.isEmpty()); freeMemory(); inMemSorter.free(); inMemSorter = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9c54184105951..d5b4eb138b1a6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -327,12 +327,6 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep logger.debug("Using slow merge"); mergeSpillsWithFileStream(spills, mapWriter, compressionCodec); } - // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has - // in-memory records, we write out the in-memory records to a file but do not count that - // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs - // to be counted as shuffle write, but this will lead to double-counting of the final - // SpillInfo's bytes. 
- writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); } catch (Exception e) { try { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index eb4d9d9abc8e3..38f0a60f8b0dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.sort.io; +import java.util.Collections; import java.util.Map; import java.util.Optional; @@ -56,7 +57,10 @@ public void initializeExecutor(String appId, String execId, Map if (blockManager == null) { throw new IllegalStateException("No blockManager available from the SparkEnv."); } - blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + blockResolver = + new IndexShuffleBlockResolver( + sparkConf, blockManager, Collections.emptyMap() /* Shouldn't be accessed */ + ); } @Override diff --git a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java index 57d96756c8bee..2791209e019b6 100644 --- a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java +++ b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java @@ -40,6 +40,15 @@ public ChildFirstURLClassLoader(URL[] urls, ClassLoader parent) { this.parent = new ParentClassLoader(parent); } + /** + * Specify the grandparent if there is a need to load in the order of + * `grandparent -> urls (child) -> parent`. 
+ */ + public ChildFirstURLClassLoader(URL[] urls, ClassLoader parent, ClassLoader grandparent) { + super(urls, grandparent); + this.parent = new ParentClassLoader(parent); + } + @Override public Class loadClass(String name, boolean resolve) throws ClassNotFoundException { try { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index db79efd008530..cf29835b2ce89 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -28,6 +28,8 @@ import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.*; @@ -36,6 +38,7 @@ * of the file format). */ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { + private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); public static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb private InputStream in; @@ -82,6 +85,15 @@ public UnsafeSorterSpillReader( Closeables.close(bs, /* swallowIOException = */ true); throw e; } + if (taskContext != null) { + taskContext.addTaskCompletionListener(context -> { + try { + close(); + } catch (IOException e) { + logger.info("error while closing UnsafeSorterSpillReader", e); + } + }); + } } @Override diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css deleted file mode 100644 index b9c16ca78a01c..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.css +++ /dev/null @@ -1 +0,0 @@ 
-:root{--dt-row-selected: 2, 117, 216;--dt-row-selected-text: 255, 255, 255;--dt-row-selected-link: 9, 10, 11}table.dataTable td.dt-control{text-align:center;cursor:pointer}table.dataTable td.dt-control:before{height:1em;width:1em;margin-top:-9px;display:inline-block;color:white;border:.15em solid white;border-radius:1em;box-shadow:0 0 .2em #444;box-sizing:content-box;text-align:center;text-indent:0 !important;font-family:"Courier New",Courier,monospace;line-height:1em;content:"+";background-color:#31b131}table.dataTable tr.dt-hasChild td.dt-control:before{content:"-";background-color:#d33333}table.dataTable thead>tr>th.sorting,table.dataTable thead>tr>th.sorting_asc,table.dataTable thead>tr>th.sorting_desc,table.dataTable thead>tr>th.sorting_asc_disabled,table.dataTable thead>tr>th.sorting_desc_disabled,table.dataTable thead>tr>td.sorting,table.dataTable thead>tr>td.sorting_asc,table.dataTable thead>tr>td.sorting_desc,table.dataTable thead>tr>td.sorting_asc_disabled,table.dataTable thead>tr>td.sorting_desc_disabled{cursor:pointer;position:relative;padding-right:26px}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable 
thead>tr>td.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after{position:absolute;display:block;opacity:.125;right:10px;line-height:9px;font-size:.8em}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:before{bottom:50%;content:"▲"}table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:after{top:50%;content:"▼"}table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:after{opacity:.6}table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting_asc_disabled:before{display:none}table.dataTable thead>tr>th:active,table.dataTable thead>tr>td:active{outline:none}div.dataTables_scrollBody table.dataTable thead>tr>th:before,div.dataTables_scrollBody table.dataTable thead>tr>th:after,div.dataTables_scrollBody table.dataTable thead>tr>td:before,div.dataTables_scrollBody table.dataTable 
thead>tr>td:after{display:none}div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:2px}div.dataTables_processing>div:last-child{position:relative;width:80px;height:15px;margin:1em auto}div.dataTables_processing>div:last-child>div{position:absolute;top:0;width:13px;height:13px;border-radius:50%;background:2 117 216;animation-timing-function:cubic-bezier(0, 1, 1, 0)}div.dataTables_processing>div:last-child>div:nth-child(1){left:8px;animation:datatables-loader-1 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(2){left:8px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(3){left:32px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(4){left:56px;animation:datatables-loader-3 .6s infinite}@keyframes datatables-loader-1{0%{transform:scale(0)}100%{transform:scale(1)}}@keyframes datatables-loader-3{0%{transform:scale(1)}100%{transform:scale(0)}}@keyframes datatables-loader-2{0%{transform:translate(0, 0)}100%{transform:translate(24px, 0)}}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th,table.dataTable thead td,table.dataTable tfoot th,table.dataTable tfoot td{text-align:left}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead 
td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable{clear:both;margin-top:6px !important;margin-bottom:6px !important;max-width:none !important;border-collapse:separate !important;border-spacing:0}table.dataTable td,table.dataTable th{-webkit-box-sizing:content-box;box-sizing:content-box}table.dataTable td.dataTables_empty,table.dataTable th.dataTables_empty{text-align:center}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.table-striped>tbody>tr:nth-of-type(2n+1){background-color:transparent}table.dataTable>tbody>tr{background-color:transparent}table.dataTable>tbody>tr.selected>*{box-shadow:inset 0 0 0 9999px rgb(2, 117, 216);box-shadow:inset 0 0 0 9999px rgb(var(--dt-row-selected));color:rgb(255, 255, 255);color:rgb(var(--dt-row-selected-text))}table.dataTable>tbody>tr.selected a{color:rgb(9, 10, 
11);color:rgb(var(--dt-row-selected-link))}table.dataTable.table-striped>tbody>tr.odd>*{box-shadow:inset 0 0 0 9999px rgba(0, 0, 0, 0.05)}table.dataTable.table-striped>tbody>tr.odd.selected>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.95);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.95)}table.dataTable.table-hover>tbody>tr:hover>*{box-shadow:inset 0 0 0 9999px rgba(0, 0, 0, 0.075)}table.dataTable.table-hover>tbody>tr.selected:hover>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.975);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.975)}div.dataTables_wrapper div.dataTables_length label{font-weight:normal;text-align:left;white-space:nowrap}div.dataTables_wrapper div.dataTables_length select{width:auto;display:inline-block}div.dataTables_wrapper div.dataTables_filter{text-align:right}div.dataTables_wrapper div.dataTables_filter label{font-weight:normal;white-space:nowrap;text-align:left}div.dataTables_wrapper div.dataTables_filter input{margin-left:.5em;display:inline-block;width:auto}div.dataTables_wrapper div.dataTables_info{padding-top:.85em}div.dataTables_wrapper div.dataTables_paginate{margin:0;white-space:nowrap;text-align:right}div.dataTables_wrapper div.dataTables_paginate ul.pagination{margin:2px 0;white-space:nowrap;justify-content:flex-end}div.dataTables_wrapper div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:1em 0}div.dataTables_scrollHead table.dataTable{margin-bottom:0 !important}div.dataTables_scrollBody>table{border-top:none;margin-top:0 !important;margin-bottom:0 !important}div.dataTables_scrollBody>table>thead .sorting:before,div.dataTables_scrollBody>table>thead .sorting_asc:before,div.dataTables_scrollBody>table>thead .sorting_desc:before,div.dataTables_scrollBody>table>thead .sorting:after,div.dataTables_scrollBody>table>thead .sorting_asc:after,div.dataTables_scrollBody>table>thead 
.sorting_desc:after{display:none}div.dataTables_scrollBody>table>tbody tr:first-child th,div.dataTables_scrollBody>table>tbody tr:first-child td{border-top:none}div.dataTables_scrollFoot>.dataTables_scrollFootInner{box-sizing:content-box}div.dataTables_scrollFoot>.dataTables_scrollFootInner>table{margin-top:0 !important;border-top:none}@media screen and (max-width: 767px){div.dataTables_wrapper div.dataTables_length,div.dataTables_wrapper div.dataTables_filter,div.dataTables_wrapper div.dataTables_info,div.dataTables_wrapper div.dataTables_paginate{text-align:center}div.dataTables_wrapper div.dataTables_paginate ul.pagination{justify-content:center !important}}table.dataTable.table-sm>thead>tr>th:not(.sorting_disabled){padding-right:20px}table.table-bordered.dataTable{border-right-width:0}table.table-bordered.dataTable th,table.table-bordered.dataTable td{border-left-width:0}table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable td:last-child,table.table-bordered.dataTable td:last-child{border-right-width:1px}table.table-bordered.dataTable tbody th,table.table-bordered.dataTable tbody td{border-bottom-width:0}div.dataTables_scrollHead table.table-bordered{border-bottom-width:0}div.table-responsive>div.dataTables_wrapper>div.row{margin:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:first-child{padding-left:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:last-child{padding-right:0} diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js deleted file mode 100644 index 2937bc3c90c2c..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.2.min.js +++ /dev/null @@ -1,4 +0,0 @@ -/*! 
DataTables Bootstrap 4 integration - * ©2011-2017 SpryMedia Ltd - datatables.net/license - */ -!function(t){"function"==typeof define&&define.amd?define(["jquery","datatables.net"],function(e){return t(e,window,document)}):"object"==typeof exports?module.exports=function(e,a){return e=e||window,(a=a||("undefined"!=typeof window?require("jquery"):require("jquery")(e))).fn.dataTable||require("datatables.net")(e,a),t(a,0,e.document)}:t(jQuery,window,document)}(function(x,e,n,r){"use strict";var s=x.fn.dataTable;return x.extend(!0,s.defaults,{dom:"<'row'<'col-sm-12 col-md-6'l><'col-sm-12 col-md-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-12 col-md-5'i><'col-sm-12 col-md-7'p>>",renderer:"bootstrap"}),x.extend(s.ext.classes,{sWrapper:"dataTables_wrapper dt-bootstrap4",sFilterInput:"form-control form-control-sm",sLengthSelect:"custom-select custom-select-sm form-control form-control-sm",sProcessing:"dataTables_processing card",sPageButton:"paginate_button page-item"}),s.ext.renderer.pageButton.bootstrap=function(i,e,d,a,l,c){function u(e,a){for(var t,n,r=function(e){e.preventDefault(),x(e.currentTarget).hasClass("disabled")||b.page()==e.data.action||b.page(e.data.action).draw("page")},s=0,o=a.length;s",{class:m.sPageButton+" "+f,id:0===d&&"string"==typeof t?i.sTableId+"_"+t:null}).append(x("",{href:n?null:"#","aria-controls":i.sTableId,"aria-disabled":n?"true":null,"aria-label":w[t],"aria-role":"link","aria-current":"active"===f?"page":null,"data-dt-idx":t,tabindex:i.iTabIndex,class:"page-link"}).html(p)).appendTo(e),i.oApi._fnBindAction(n,{action:t},r))}}var p,f,t,b=new s.Api(i),m=i.oClasses,g=i.oLanguage.oPaginate,w=i.oLanguage.oAria.paginate||{};try{t=x(e).find(n.activeElement).data("dt-idx")}catch(e){}u(x(e).empty().html('