This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5533c81e3453 [SPARK-48355][SQL] Support for CASE statement
5533c81e3453 is described below

commit 5533c81e34534d43ae90fc2ce5ac1d174d4e8289
Author: Dušan Tišma <dusan.ti...@databricks.com>
AuthorDate: Fri Sep 13 15:01:09 2024 +0200

    [SPARK-48355][SQL] Support for CASE statement
    
    ### What changes were proposed in this pull request?
    Add support for [case statements](https://docs.google.com/document/d/1cpSuR3KxRuTSJ4ZMQ73FJ4_-hjouNNU2zfI4vri6yhs/edit#heading=h.ofijhkunigv)
    to SQL scripting. There are two types of case statement, simple and searched
    (examples below). The proposed changes are:
    
    - Add `caseStatement` grammar rule to SqlBaseParser.g4
    - Add `visitSearchedCaseStatement` and `visitSimpleCaseStatement` methods to `AstBuilder`
    - Add `CaseStatement` and `CaseStatementExec` classes, so that case
      statements can be run in SQL scripts.
    
    The reason only searched case nodes are added is that, in the current 
implementation, a simple case is parsed into a searched case, by creating 
internal `EqualTo` expressions to compare the main case expression to the 
expressions in the when clauses. This approach is similar to the existing case 
**expressions**, which are parsed in the same way. The problem with this 
approach is that the main expression is unnecessarily evaluated N times, where 
N is the number of when clauses, which c [...]
    
    A simple case compares one expression (the case variable) to the expression
    in each WHEN clause, in order, until an equal one is found. The ELSE clause
    is optional.
    ```
    BEGIN
      CASE 1
        WHEN 1 THEN
          SELECT 1;
        WHEN 2 THEN
          SELECT 2;
        ELSE
          SELECT 3;
      END CASE;
    END
    ```
    
    A searched case evaluates the boolean expression of each WHEN clause in order
    and runs the body of the first one that is true. The ELSE clause is optional.
    ```
    BEGIN
      CASE
        WHEN 1 = 1 THEN
          SELECT 1;
        WHEN 2 IN (1,2,3) THEN
          SELECT 2;
        ELSE
          SELECT 3;
      END CASE;
    END
    ```
    
    ### Why are the changes needed?
    Case statements are currently not implemented in SQL scripting.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users will now be able to use case statements in their SQL scripts.
    
    ### How was this patch tested?
    Tests for both simple and searched case statements are added to 
SqlScriptingParserSuite, SqlScriptingExecutionNodeSuite and 
SqlScriptingInterpreterSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47672 from dusantism-db/sql-scripting-case-statement.
    
    Authored-by: Dušan Tišma <dusan.ti...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   8 +
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  48 ++-
 .../parser/SqlScriptingLogicalOperators.scala      |  14 +
 .../catalyst/parser/SqlScriptingParserSuite.scala  | 297 +++++++++++++++-
 .../sql/scripting/SqlScriptingExecutionNode.scala  |  72 ++++
 .../sql/scripting/SqlScriptingInterpreter.scala    |  13 +-
 .../scripting/SqlScriptingExecutionNodeSuite.scala |  93 +++++
 .../scripting/SqlScriptingInterpreterSuite.scala   | 379 ++++++++++++++++++++-
 8 files changed, 920 insertions(+), 4 deletions(-)

diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 42f0094de351..73d5cb55295a 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -64,6 +64,7 @@ compoundStatement
     | setStatementWithOptionalVarKeyword
     | beginEndCompoundBlock
     | ifElseStatement
+    | caseStatement
     | whileStatement
     | repeatStatement
     | leaveStatement
@@ -98,6 +99,13 @@ iterateStatement
     : ITERATE multipartIdentifier
     ;
 
+caseStatement
+    : CASE (WHEN conditions+=booleanExpression THEN 
conditionalBodies+=compoundBody)+
+        (ELSE elseBody=compoundBody)? END CASE                
#searchedCaseStatement
+    | CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN 
conditionalBodies+=compoundBody)+
+        (ELSE elseBody=compoundBody)? END CASE                
#simpleCaseStatement
+    ;
+
 singleStatement
     : (statement|setResetStatement) SEMICOLON* EOF
     ;
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 924b5c2cfeb1..9620ce13d92e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -261,6 +261,52 @@ class AstBuilder extends DataTypeAstBuilder
     WhileStatement(condition, body, Some(labelText))
   }
 
+  override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): 
CaseStatement = {
+    val conditions = ctx.conditions.asScala.toList.map(boolExpr => 
withOrigin(boolExpr) {
+      SingleStatement(
+        Project(
+          Seq(Alias(expression(boolExpr), "condition")()),
+          OneRowRelation()))
+    })
+    val conditionalBodies =
+      ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
+
+    if (conditions.length != conditionalBodies.length) {
+      throw SparkException.internalError(
+        s"Mismatched number of conditions ${conditions.length} and condition 
bodies" +
+          s" ${conditionalBodies.length} in case statement")
+    }
+
+    CaseStatement(
+      conditions = conditions,
+      conditionalBodies = conditionalBodies,
+      elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
+  }
+
+  override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): 
CaseStatement = {
+    // uses EqualTo to compare the case variable(the main case expression)
+    // to the WHEN clause expressions
+    val conditions = ctx.conditionExpressions.asScala.toList.map(expr => 
withOrigin(expr) {
+      SingleStatement(
+        Project(
+          Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), 
"condition")()),
+          OneRowRelation()))
+    })
+    val conditionalBodies =
+      ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
+
+    if (conditions.length != conditionalBodies.length) {
+      throw SparkException.internalError(
+        s"Mismatched number of conditions ${conditions.length} and condition 
bodies" +
+          s" ${conditionalBodies.length} in case statement")
+    }
+
+    CaseStatement(
+      conditions = conditions,
+      conditionalBodies = conditionalBodies,
+      elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
+  }
+
   override def visitRepeatStatement(ctx: RepeatStatementContext): 
RepeatStatement = {
     val labelText = generateLabelText(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
     val boolExpr = ctx.booleanExpression()
@@ -292,7 +338,7 @@ class AstBuilder extends DataTypeAstBuilder
       case c: RepeatStatementContext
         if Option(c.beginLabel()).isDefined &&
           
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
-      => true
+        => true
       case _ => false
     }
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
index 5e7e8b0b4fc9..ed40a5fd734b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
@@ -124,3 +124,17 @@ case class LeaveStatement(label: String) extends 
CompoundPlanStatement
  * @param label Label of the loop to iterate.
  */
 case class IterateStatement(label: String) extends CompoundPlanStatement
+
+/**
+ * Logical operator for CASE statement.
+ * @param conditions Collection of conditions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of bodies that have a corresponding 
condition,
+ *                          in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, 
i.e. ELSE branch.
+ */
+case class CaseStatement(
+    conditions: Seq[SingleStatement],
+    conditionalBodies: Seq[CompoundBody],
+    elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
+  assert(conditions.length == conditionalBodies.length)
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index bf527b9c3bd7..24ad32c5300b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, 
In, Literal, ScalarSubquery}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.CreateVariable
+import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, Project}
 import org.apache.spark.sql.exceptions.SqlScriptingException
 
 class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
@@ -1111,6 +1112,287 @@ class SqlScriptingParserSuite extends SparkFunSuite 
with SQLHelper {
 
   }
 
+  test("searched case statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 = 1 THEN
+        |     SELECT 42;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.conditions.length == 1)
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    assert(caseStmt.conditions.head.getText == "1 = 1")
+  }
+
+  test("searched case statement - multi when") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 IN (1,2,3) THEN
+        |     SELECT 1;
+        |   WHEN (SELECT * FROM t) THEN
+        |     SELECT * FROM b;
+        |   WHEN 1 = 1 THEN
+        |     SELECT 42;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.conditions.length == 3)
+    assert(caseStmt.conditionalBodies.length == 3)
+    assert(caseStmt.elseBody.isEmpty)
+
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    assert(caseStmt.conditions.head.getText == "1 IN (1,2,3)")
+
+    
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+    
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 1")
+
+    assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
+    assert(caseStmt.conditions(1).getText == "(SELECT * FROM t)")
+
+    
assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
+    
assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT * FROM b")
+
+    assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
+    assert(caseStmt.conditions(2).getText == "1 = 1")
+
+    
assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
+    
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 42")
+  }
+
+  test("searched case statement with else") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 = 1 THEN
+        |     SELECT 42;
+        |   ELSE
+        |     SELECT 43;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.elseBody.isDefined)
+    assert(caseStmt.conditions.length == 1)
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    assert(caseStmt.conditions.head.getText == "1 = 1")
+
+    assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+    assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 43")
+  }
+
+  test("searched case statement nested") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 = 1 THEN
+        |     CASE
+        |       WHEN 2 = 1 THEN
+        |         SELECT 41;
+        |       ELSE
+        |         SELECT 42;
+        |     END CASE;
+        |  END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.conditions.length == 1)
+    assert(caseStmt.conditionalBodies.length == 1)
+    assert(caseStmt.elseBody.isEmpty)
+
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    assert(caseStmt.conditions.head.getText == "1 = 1")
+
+    
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+    val nestedCaseStmt =
+      
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+
+    assert(nestedCaseStmt.conditions.length == 1)
+    assert(nestedCaseStmt.conditionalBodies.length == 1)
+    assert(nestedCaseStmt.elseBody.isDefined)
+
+    assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
+    assert(nestedCaseStmt.conditions.head.getText == "2 = 1")
+
+    
assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+    
assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 41")
+
+    
assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+    
assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 42")
+  }
+
+  test("simple case statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 1 THEN
+        |     SELECT 1;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.conditions.length == 1)
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == 
Literal(1), _ == Literal(1))
+  }
+
+
+  test("simple case statement - multi when") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 1 THEN
+        |     SELECT 1;
+        |   WHEN (SELECT 2) THEN
+        |     SELECT * FROM b;
+        |   WHEN 3 IN (1,2,3) THEN
+        |     SELECT 42;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.conditions.length == 3)
+    assert(caseStmt.conditionalBodies.length == 3)
+    assert(caseStmt.elseBody.isEmpty)
+
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == 
Literal(1), _ == Literal(1))
+
+    
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+    
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 1")
+
+    assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(
+    caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery])
+
+    
assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
+    
assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT * FROM b")
+
+    assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(
+      caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In])
+
+    
assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
+    
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 42")
+  }
+
+  test("simple case statement with else") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 1 THEN
+        |     SELECT 42;
+        |   ELSE
+        |     SELECT 43;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.elseBody.isDefined)
+    assert(caseStmt.conditions.length == 1)
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == 
Literal(1), _ == Literal(1))
+
+    assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+    assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 43")
+  }
+
+  test("simple case statement nested") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | CASE (SELECT 1)
+        |   WHEN 1 THEN
+        |     CASE 2
+        |       WHEN 2 THEN
+        |         SELECT 41;
+        |       ELSE
+        |         SELECT 42;
+        |     END CASE;
+        |  END CASE;
+        |END
+        |""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(caseStmt.conditions.length == 1)
+    assert(caseStmt.conditionalBodies.length == 1)
+    assert(caseStmt.elseBody.isEmpty)
+
+    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(
+      caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == 
Literal(1))
+
+    
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+    val nestedCaseStmt =
+      
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+
+    assert(nestedCaseStmt.conditions.length == 1)
+    assert(nestedCaseStmt.conditionalBodies.length == 1)
+    assert(nestedCaseStmt.elseBody.isDefined)
+
+    assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
+    checkSimpleCaseStatementCondition(
+      nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2))
+
+    
assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+    
assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 41")
+
+    
assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+    
assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 42")
+  }
+
   // Helper methods
   def cleanupStatementString(statementStr: String): String = {
     statementStr
@@ -1119,4 +1401,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with 
SQLHelper {
       .replace("END", "")
       .trim
   }
+
+  private def checkSimpleCaseStatementCondition(
+      conditionStatement: SingleStatement,
+      predicateLeft: Expression => Boolean,
+      predicateRight: Expression => Boolean): Unit = {
+    assert(conditionStatement.parsedPlan.isInstanceOf[Project])
+    val project = conditionStatement.parsedPlan.asInstanceOf[Project]
+    assert(project.projectList.head.isInstanceOf[Alias])
+    
assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo])
+    val equalTo = 
project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo]
+    assert(predicateLeft(equalTo.left))
+    assert(predicateRight(equalTo.right))
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index cae797614314..af9fd5464277 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -405,6 +405,78 @@ class WhileStatementExec(
   }
 }
 
+/**
+ * Executable node for CaseStatement.
+ * @param conditions Collection of executable conditions which correspond to 
WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a 
corresponding condition,
+ *                 in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, 
i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ */
+class CaseStatementExec(
+    conditions: Seq[SingleStatementExec],
+    conditionalBodies: Seq[CompoundBodyExec],
+    elseBody: Option[CompoundBodyExec],
+    session: SparkSession) extends NonLeafStatementExec {
+  private object CaseState extends Enumeration {
+    val Condition, Body = Value
+  }
+
+  private var state = CaseState.Condition
+  private var curr: Option[CompoundStatementExec] = Some(conditions.head)
+
+  private var clauseIdx: Int = 0
+  private val conditionsCount = conditions.length
+
+  private lazy val treeIterator: Iterator[CompoundStatementExec] =
+    new Iterator[CompoundStatementExec] {
+      override def hasNext: Boolean = curr.nonEmpty
+
+      override def next(): CompoundStatementExec = state match {
+        case CaseState.Condition =>
+          val condition = curr.get.asInstanceOf[SingleStatementExec]
+          if (evaluateBooleanCondition(session, condition)) {
+            state = CaseState.Body
+            curr = Some(conditionalBodies(clauseIdx))
+          } else {
+            clauseIdx += 1
+            if (clauseIdx < conditionsCount) {
+              // There are WHEN clauses remaining.
+              state = CaseState.Condition
+              curr = Some(conditions(clauseIdx))
+            } else if (elseBody.isDefined) {
+              // ELSE clause exists.
+              state = CaseState.Body
+              curr = Some(elseBody.get)
+            } else {
+              // No remaining clauses.
+              curr = None
+            }
+          }
+          condition
+        case CaseState.Body =>
+          assert(curr.get.isInstanceOf[CompoundBodyExec])
+          val currBody = curr.get.asInstanceOf[CompoundBodyExec]
+          val retStmt = currBody.getTreeIterator.next()
+          if (!currBody.getTreeIterator.hasNext) {
+            curr = None
+          }
+          retStmt
+      }
+    }
+
+  override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+  override def reset(): Unit = {
+    state = CaseState.Condition
+    curr = Some(conditions.head)
+    clauseIdx = 0
+    conditions.foreach(c => c.reset())
+    conditionalBodies.foreach(b => b.reset())
+    elseBody.foreach(b => b.reset())
+  }
+}
+
 /**
  * Executable node for RepeatStatement.
  * @param condition Executable node for the condition - evaluates to a row 
with a single boolean
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 865b33999655..917b4d6f45ee 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CompoundBody, 
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, 
RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, 
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, 
RepeatStatement, SingleStatement, WhileStatement}
 import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, 
DropVariable, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.Origin
 
@@ -95,6 +95,17 @@ case class SqlScriptingInterpreter() {
         new IfElseStatementExec(
           conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, 
session)
 
+      case CaseStatement(conditions, conditionalBodies, elseBody) =>
+        val conditionsExec = conditions.map(condition =>
+          // todo: what to put here for isInternal, in case of simple case 
statement
+          new SingleStatementExec(condition.parsedPlan, condition.origin, 
isInternal = false))
+        val conditionalBodiesExec = conditionalBodies.map(body =>
+          transformTreeIntoExecutable(body, 
session).asInstanceOf[CompoundBodyExec])
+        val unconditionalBodiesExec = elseBody.map(body =>
+          transformTreeIntoExecutable(body, 
session).asInstanceOf[CompoundBodyExec])
+        new CaseStatementExec(
+          conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, 
session)
+
       case WhileStatement(condition, body, label) =>
         val conditionExec =
           new SingleStatementExec(condition.parsedPlan, condition.origin, 
isInternal = false)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 4b72ca8ecaa9..83d8191d01ec 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -576,4 +576,97 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite 
with SharedSparkSessi
       "body1", "lbl", "con1",
       "body1", "lbl", "con1"))
   }
+
+  test("searched case - enter first WHEN clause") {
+    val iter = new CompoundBodyExec(Seq(
+      new CaseStatementExec(
+        conditions = Seq(
+          TestIfElseCondition(condVal = true, description = "con1"),
+          TestIfElseCondition(condVal = false, description = "con2")
+        ),
+        conditionalBodies = Seq(
+          new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+          new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("con1", "body1"))
+  }
+
+  test("searched case - enter body of the ELSE clause") {
+    val iter = new CompoundBodyExec(Seq(
+      new CaseStatementExec(
+        conditions = Seq(
+          TestIfElseCondition(condVal = false, description = "con1")
+        ),
+        conditionalBodies = Seq(
+          new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+        ),
+        elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))),
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("con1", "body2"))
+  }
+
+  test("searched case - enter second WHEN clause") {
+    val iter = new CompoundBodyExec(Seq(
+      new CaseStatementExec(
+        conditions = Seq(
+          TestIfElseCondition(condVal = false, description = "con1"),
+          TestIfElseCondition(condVal = true, description = "con2")
+        ),
+        conditionalBodies = Seq(
+          new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+          new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("con1", "con2", "body2"))
+  }
+
+  test("searched case - without else (successful check)") {
+    val iter = new CompoundBodyExec(Seq(
+      new CaseStatementExec(
+        conditions = Seq(
+          TestIfElseCondition(condVal = false, description = "con1"),
+          TestIfElseCondition(condVal = true, description = "con2")
+        ),
+        conditionalBodies = Seq(
+          new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+          new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = None,
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("con1", "con2", "body2"))
+  }
+
+  test("searched case - without else (unsuccessful checks)") {
+    val iter = new CompoundBodyExec(Seq(
+      new CaseStatementExec(
+        conditions = Seq(
+          TestIfElseCondition(condVal = false, description = "con1"),
+          TestIfElseCondition(condVal = false, description = "con2")
+        ),
+        conditionalBodies = Seq(
+          new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+          new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = None,
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("con1", "con2"))
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 8d9cd1d8c780..4851faf897a0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.scripting
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkNumberFormatException}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, 
Row}
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.exceptions.SqlScriptingException
@@ -368,6 +368,383 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("searched case") {
+    val commands =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 = 1 THEN
+        |     SELECT 42;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(42)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("searched case nested") {
+    val commands =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1=1 THEN
+        |   CASE
+        |    WHEN 2=1 THEN
+        |     SELECT 41;
+        |   ELSE
+        |     SELECT 42;
+        |   END CASE;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(42)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("searched case second case") {
+    val commands =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 = (SELECT 2) THEN
+        |     SELECT 1;
+        |   WHEN 2 = 2 THEN
+        |     SELECT 42;
+        |   WHEN (SELECT * FROM t) THEN
+        |     SELECT * FROM b;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(42)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("searched case going in else") {
+    val commands =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 2 = 1 THEN
+        |     SELECT 1;
+        |   WHEN 3 IN (1,2) THEN
+        |     SELECT 2;
+        |   ELSE
+        |     SELECT 43;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(43)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("searched case with count") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+          |INSERT INTO t VALUES (1, 'a', 1.0);
+          |INSERT INTO t VALUES (1, 'a', 1.0);
+          |CASE
+          | WHEN (SELECT COUNT(*) > 2 FROM t) THEN
+          |   SELECT 42;
+          | ELSE
+          |   SELECT 43;
+          | END CASE;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+      verifySqlScriptResult(commands, expected)
+    }
+  }
+
+  test("searched case else with count") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |  CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+          |  INSERT INTO t VALUES (1, 'a', 1.0);
+          |  INSERT INTO t VALUES (1, 'a', 1.0);
+          |  CASE
+          |  WHEN (SELECT COUNT(*) > 2 FROM t) THEN
+          |   SELECT 42;
+          |  WHEN (SELECT COUNT(*) > 1 FROM t) THEN
+          |   SELECT 43;
+          |  ELSE
+          |    SELECT 44;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+      verifySqlScriptResult(commands, expected)
+    }
+  }
+
+  test("searched case no cases matched no else") {
+    val commands =
+      """
+        |BEGIN
+        | CASE
+        |   WHEN 1 = 2 THEN
+        |     SELECT 42;
+        |   WHEN 1 = 3 THEN
+        |     SELECT 43;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq()
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("searched case when evaluates to null") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |  CREATE TABLE t (a BOOLEAN) USING parquet;
+          |  CASE
+          |  WHEN (SELECT * FROM t) THEN
+          |   SELECT 42;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      checkError(
+        exception = intercept[SqlScriptingException] (
+          runSqlScript(commands)
+        ),
+        condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+        parameters = Map("invalidStatement" -> "(SELECT * FROM T)")
+      )
+    }
+  }
+
+  test("searched case with non boolean condition - constant") {
+    val commands =
+      """
+        |BEGIN
+        |  CASE
+        |  WHEN 1 THEN
+        |   SELECT 42;
+        |  END CASE;
+        |END
+        |""".stripMargin
+
+    checkError(
+      exception = intercept[SqlScriptingException] (
+        runSqlScript(commands)
+      ),
+      condition = "INVALID_BOOLEAN_STATEMENT",
+      parameters = Map("invalidStatement" -> "1")
+    )
+  }
+
+  test("searched case with too many rows in subquery condition") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          | CREATE TABLE t (a BOOLEAN) USING parquet;
+          | INSERT INTO t VALUES (true);
+          | INSERT INTO t VALUES (true);
+          | CASE
+          |   WHEN (SELECT * FROM t) THEN
+          |     SELECT 1;
+          | END CASE;
+          |END
+          |""".stripMargin
+
+      checkError(
+        exception = intercept[SparkException] (
+          runSqlScript(commands)
+        ),
+        condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
+        parameters = Map.empty,
+        context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 124, stop = 140)
+      )
+    }
+  }
+
+  test("simple case") {
+    val commands =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 1 THEN
+        |     SELECT 42;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(42)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("simple case nested") {
+    val commands =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 1 THEN
+        |   CASE 2
+        |    WHEN (SELECT 3) THEN
+        |     SELECT 41;
+        |   ELSE
+        |     SELECT 42;
+        |   END CASE;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(42)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("simple case second case") {
+    val commands =
+      """
+        |BEGIN
+        | CASE (SELECT 2)
+        |   WHEN 1 THEN
+        |     SELECT 1;
+        |   WHEN 2 THEN
+        |     SELECT 42;
+        |   WHEN (SELECT * FROM t) THEN
+        |     SELECT * FROM b;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(42)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("simple case going in else") {
+    val commands =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 2 THEN
+        |     SELECT 1;
+        |   WHEN 3 THEN
+        |     SELECT 2;
+        |   ELSE
+        |     SELECT 43;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(43)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("simple case with count") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+          |INSERT INTO t VALUES (1, 'a', 1.0);
+          |INSERT INTO t VALUES (1, 'a', 1.0);
+          |CASE (SELECT COUNT(*) FROM t)
+          | WHEN 1 THEN
+          |   SELECT 41;
+          | WHEN 2 THEN
+          |   SELECT 42;
+          | ELSE
+          |   SELECT 43;
+          | END CASE;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(42)))
+      verifySqlScriptResult(commands, expected)
+    }
+  }
+
+  test("simple case else with count") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |  CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+          |  INSERT INTO t VALUES (1, 'a', 1.0);
+          |  INSERT INTO t VALUES (2, 'b', 2.0);
+          |  CASE (SELECT COUNT(*) FROM t)
+          |   WHEN 1 THEN
+          |     SELECT 42;
+          |   WHEN 3 THEN
+          |     SELECT 43;
+          |   ELSE
+          |     SELECT 44;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(44)))
+      verifySqlScriptResult(commands, expected)
+    }
+  }
+
+  test("simple case no cases matched no else") {
+    val commands =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN 2 THEN
+        |     SELECT 42;
+        |   WHEN 3 THEN
+        |     SELECT 43;
+        | END CASE;
+        |END
+        |""".stripMargin
+    val expected = Seq()
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("simple case mismatched types") {
+    val commands =
+      """
+        |BEGIN
+        | CASE 1
+        |   WHEN "one" THEN
+        |     SELECT 42;
+        | END CASE;
+        |END
+        |""".stripMargin
+
+    checkError(
+      exception = intercept[SparkNumberFormatException] (
+        runSqlScript(commands)
+      ),
+      condition = "CAST_INVALID_INPUT",
+      parameters = Map(
+        "expression" -> "'one'",
+        "sourceType" -> "\"STRING\"",
+        "targetType" -> "\"BIGINT\""),
+      context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27)
+    )
+  }
+
+  test("simple case compare with null") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |  CREATE TABLE t (a INT) USING parquet;
+          |  CASE (SELECT COUNT(*) FROM t)
+          |   WHEN 1 THEN
+          |     SELECT 42;
+          |   ELSE
+          |     SELECT 43;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(Seq.empty[Row], Seq(Row(43)))
+      verifySqlScriptResult(commands, expected)
+    }
+  }
+
   test("if's condition must be a boolean statement") {
     withTable("t") {
       val commands =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to