diff --git a/jenkins/databricks/build.sh b/jenkins/databricks/build.sh
index 831197c61d9..25bade91968 100755
--- a/jenkins/databricks/build.sh
+++ b/jenkins/databricks/build.sh
@@ -57,6 +57,7 @@ initialize()
     if [[ ! -d $HOME/apache-maven-3.6.3 ]]; then
         wget https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz -P /tmp
         tar xf /tmp/apache-maven-3.6.3-bin.tar.gz -C $HOME
+        rm -f /tmp/apache-maven-3.6.3-bin.tar.gz
         sudo ln -s $HOME/apache-maven-3.6.3/bin/mvn /usr/local/bin/mvn
     fi
 
diff --git a/jenkins/databricks/test.sh b/jenkins/databricks/test.sh
index 5ea4fce625b..38728161d12 100755
--- a/jenkins/databricks/test.sh
+++ b/jenkins/databricks/test.sh
@@ -49,6 +49,7 @@ source jenkins/databricks/common_vars.sh
 BASE_SPARK_VERSION=${BASE_SPARK_VERSION:-$(< /databricks/spark/VERSION)}
 SHUFFLE_SPARK_SHIM=${SHUFFLE_SPARK_SHIM:-spark${BASE_SPARK_VERSION//./}db}
 SHUFFLE_SPARK_SHIM=${SHUFFLE_SPARK_SHIM//\-SNAPSHOT/}
+WITH_DEFAULT_UPSTREAM_SHIM=${WITH_DEFAULT_UPSTREAM_SHIM:-1}
 IS_SPARK_321_OR_LATER=0
 [[ "$(printf '%s\n' "3.2.1" "$BASE_SPARK_VERSION" | sort -V | head -n1)" = "3.2.1" ]] && IS_SPARK_321_OR_LATER=1
 
@@ -90,6 +91,18 @@ run_pyarrow_tests() {
 
 ## Separate the integration tests into "CI_PART1" and "CI_PART2", run each part in parallel on separate Databricks clusters to speed up the testing process.
 if [[ $TEST_MODE == "DEFAULT" || $TEST_MODE == "CI_PART1" ]]; then
+    # Run the two-shim smoke test with the default upstream Spark build
+    if [[ "$WITH_DEFAULT_UPSTREAM_SHIM" != "0" ]]; then
+        if [[ ! -d $HOME/spark-3.2.0-bin-hadoop3.2 ]]; then
+            wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz -P /tmp
+            tar xf /tmp/spark-3.2.0-bin-hadoop3.2.tgz -C $HOME
+            rm -f /tmp/spark-3.2.0-bin-hadoop3.2.tgz
+        fi
+        SPARK_HOME=$HOME/spark-3.2.0-bin-hadoop3.2 \
+        SPARK_SHELL_SMOKE_TEST=1 \
+        PYSP_TEST_spark_shuffle_manager=com.nvidia.spark.rapids.spark320.RapidsShuffleManager \
+        bash integration_tests/run_pyspark_from_build.sh
+    fi
     bash integration_tests/run_pyspark_from_build.sh --runtime_env="databricks" --test_type=$TEST_TYPE
 fi
 
diff --git a/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/DatabricksShimServiceProvider.scala b/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/DatabricksShimServiceProvider.scala
index e8a27aaecc8..cedaee9fe69 100644
--- a/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/DatabricksShimServiceProvider.scala
+++ b/sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/DatabricksShimServiceProvider.scala
@@ -21,12 +21,10 @@ spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids
 
-import scala.util.Try
-
 object DatabricksShimServiceProvider {
   val log = org.slf4j.LoggerFactory.getLogger(getClass().getName().stripSuffix("$"))
 
   def matchesVersion(dbrVersion: String): Boolean = {
-    Try {
+    try {
       val sparkBuildInfo = org.apache.spark.BuildInfo
       val databricksBuildInfo = com.databricks.BuildInfo
       val matchRes = sparkBuildInfo.dbrVersion.startsWith(dbrVersion)
@@ -44,13 +42,13 @@
       if (matchRes) {
         log.warn(logMessage)
       } else {
         log.debug(logMessage)
       }
       matchRes
-    }.recover {
+    } catch {
       case x: Throwable =>
         log.debug("Databricks detection failed: " + x, x)
         false
-    }.getOrElse(false)
+    }
   }
 }
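
Note: the two-shim smoke test added to jenkins/databricks/test.sh is opt-out. WITH_DEFAULT_UPSTREAM_SHIM defaults to 1 and the guard only skips the block when the value is exactly "0", so a job that cannot (or should not) fetch the Apache Spark 3.2.0 tarball from archive.apache.org can disable the smoke test by exporting WITH_DEFAULT_UPSTREAM_SHIM=0 before test.sh runs; how that variable reaches the script depends on the Jenkins job configuration.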
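
The DatabricksShimServiceProvider change is more than a style cleanup. scala.util.Try only intercepts NonFatal throwables, and NoClassDefFoundError, which probing com.databricks.BuildInfo raises when the Databricks classes are absent (exactly the situation under the upstream Spark build the new smoke test exercises), is a LinkageError that NonFatal classifies as fatal. The old Try { ... }.recover { ... }.getOrElse(false) chain would therefore let that error propagate instead of returning false, while a plain try/catch on Throwable handles it. A minimal sketch of the difference, where TryVsTryCatch and detect() are hypothetical stand-ins for the shim provider and its BuildInfo lookup:

import scala.util.Try

object TryVsTryCatch {
  // Hypothetical stand-in for reading com.databricks.BuildInfo on a cluster
  // where the class does not exist (e.g. upstream Apache Spark).
  def detect(): Boolean = throw new NoClassDefFoundError("com/databricks/BuildInfo")

  def main(args: Array[String]): Unit = {
    // New style: a Throwable pattern in catch handles fatal errors too.
    val viaCatch =
      try detect()
      catch { case _: Throwable => false }
    println(s"try/catch   -> $viaCatch") // prints: try/catch   -> false

    // Old style: Try.apply only catches NonFatal throwables, so the
    // LinkageError is rethrown before .recover or .getOrElse can run.
    val viaTry = Try(detect()).recover { case _: Throwable => false }.getOrElse(false)
    println(s"Try/recover -> $viaTry") // never reached
  }
}

Running the object prints the try/catch result and then dies on the second expression with the propagating NoClassDefFoundError, which is the failure mode the patched detection avoids.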