This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f15ad40640c [SPARK-47994][SQL] Fix bug with CASE WHEN column filter push down in SQLServer
3f15ad40640c is described below

commit 3f15ad40640ce71764d1d00b8fae7d88df5e2194
Author: Stefan Bukorovic <stefan.bukoro...@databricks.com>
AuthorDate: Mon Apr 29 19:42:16 2024 +0800

    [SPARK-47994][SQL] Fix bug with CASE WHEN column filter push down in SQLServer
    
    ### What changes were proposed in this pull request?
    
    In this PR I propose a change to the SQL builder of the SQL Server JDBC
    dialect (MsSqlServerDialect). It changes how a predicate on a column
    generated by a CASE WHEN construct is pushed down: a simple ` = 1`
    comparison is appended to the pushed-down CASE WHEN expression, which
    makes the query valid on SQL Server.
    
    ### Why are the changes needed?
    
    SQL Server does not support 0 or 1 as boolean values. In certain
    situations the Spark optimizer rewrites filters that contain CASE WHEN
    columns in a way that introduces 1 or 0 as boolean values (for example,
    a CASE WHEN whose branches yield 1 or 0 used directly as a WHERE
    condition). This fails on the SQL Server side with the error "An
    expression of non-boolean type specified in a context where a condition
    is expected". With these changes, such filters are pushed down with an
    explicit ` = 1` comparison, and the error no longer occurs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    A new test case was added to MsSqlServerIntegrationSuite; it fails without these changes.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46231 from stefanbuk-db/SQLServer_case_when_bugfix.
    
    Authored-by: Stefan Bukorovic <stefan.bukoro...@databricks.com>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala     | 13 +++++++++++++
 .../spark/sql/connector/util/V2ExpressionSQLBuilder.java    |  2 +-
 .../org/apache/spark/sql/jdbc/MsSqlServerDialect.scala      |  1 +
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index f5f5d00d6bda..65f7579de820 100644
--- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -131,4 +131,17 @@ class MsSqlServerIntegrationSuite extends 
DockerJDBCIntegrationV2Suite with V2JD
       "WHERE (dept > 1 AND ((name LIKE 'am%') = (name LIKE '%y')))")
     assert(df3.collect().length == 3)
   }
+
+  test("SPARK-47994: SQLServer does not support 1 or 0 as boolean type in CASE 
WHEN filter") {
+    val df = sql(
+      s"""
+        |WITH tbl AS (
+        |SELECT CASE
+        |WHEN e.dept = 1 THEN 'first' WHEN e.dept = 2 THEN 'second' ELSE 
'third' END
+        |AS deptString FROM $catalogName.employee as e)
+        |SELECT * FROM tbl
+        |WHERE deptString = 'first'
+        |""".stripMargin)
+    assert(df.collect().length == 2)
+  }
 }
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 61d68d4a3e88..e42d9193ea39 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -356,7 +356,7 @@ public class V2ExpressionSQLBuilder {
     return joiner.toString();
   }
 
-  private String[] expressionsToStringArray(Expression[] expressions) {
+  protected String[] expressionsToStringArray(Expression[] expressions) {
     String[] result = new String[expressions.length];
     for (int i = 0; i < expressions.length; i++) {
       result[i] = build(expressions[i]);
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 5535545efba8..e341bf3720f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -92,6 +92,7 @@ private case class MsSqlServerDialect() extends JdbcDialect {
               case o => inputToSQL(o)
             }
             visitBinaryComparison(e.name(), l, r)
+          case "CASE_WHEN" => 
visitCaseWhen(expressionsToStringArray(e.children())) + " = 1"
           case _ => super.build(expr)
         }
         case _ => super.build(expr)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to