This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5533c81e3453 [SPARK-48355][SQL] Support for CASE statement 5533c81e3453 is described below commit 5533c81e34534d43ae90fc2ce5ac1d174d4e8289 Author: Dušan Tišma <dusan.ti...@databricks.com> AuthorDate: Fri Sep 13 15:01:09 2024 +0200 [SPARK-48355][SQL] Support for CASE statement ### What changes were proposed in this pull request? Add support for [case statements](https://docs.google.com/document/d/1cpSuR3KxRuTSJ4ZMQ73FJ4_-hjouNNU2zfI4vri6yhs/edit#heading=h.ofijhkunigv) to SQL scripting. There are two types of case statement, simple and searched (examples below). Proposed changes are: - Add `caseStatement` grammar rule to SqlBaseParser.g4 - Add `visitSearchedCaseStatement` and `visitSimpleCaseStatement` methods to `AstBuilder` - Add `CaseStatement` and `CaseStatementExec` classes, so that case statements can be run in SQL scripts. Only searched case nodes are added because, in the current implementation, a simple case is parsed into a searched case by creating internal `EqualTo` expressions that compare the main case expression to the expressions in the WHEN clauses. This approach is similar to the existing case **expressions**, which are parsed in the same way. The problem with this approach is that the main expression is unnecessarily evaluated N times, where N is the number of WHEN clauses, which c [...] A simple case compares one expression (the case variable) to others until an equal one is found. The ELSE clause is optional. ``` BEGIN CASE 1 WHEN 1 THEN SELECT 1; WHEN 2 THEN SELECT 2; ELSE SELECT 3; END CASE; END ``` A searched case evaluates boolean expressions. The ELSE clause is optional. ``` BEGIN CASE WHEN 1 = 1 THEN SELECT 1; WHEN 2 IN (1,2,3) THEN SELECT 2; ELSE SELECT 3; END CASE; END ``` ### Why are the changes needed? Case statements are currently not implemented in SQL scripting. ### Does this PR introduce _any_ user-facing change? Yes, users will now be able to use case statements in their SQL scripts. 
### How was this patch tested? Tests for both simple and searched case statements are added to SqlScriptingParserSuite, SqlScriptingExecutionNodeSuite and SqlScriptingInterpreterSuite. ### Was this patch authored or co-authored using generative AI tooling? No Closes #47672 from dusantism-db/sql-scripting-case-statement. Authored-by: Dušan Tišma <dusan.ti...@databricks.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../spark/sql/catalyst/parser/SqlBaseParser.g4 | 8 + .../spark/sql/catalyst/parser/AstBuilder.scala | 48 ++- .../parser/SqlScriptingLogicalOperators.scala | 14 + .../catalyst/parser/SqlScriptingParserSuite.scala | 297 +++++++++++++++- .../sql/scripting/SqlScriptingExecutionNode.scala | 72 ++++ .../sql/scripting/SqlScriptingInterpreter.scala | 13 +- .../scripting/SqlScriptingExecutionNodeSuite.scala | 93 +++++ .../scripting/SqlScriptingInterpreterSuite.scala | 379 ++++++++++++++++++++- 8 files changed, 920 insertions(+), 4 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 42f0094de351..73d5cb55295a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -64,6 +64,7 @@ compoundStatement | setStatementWithOptionalVarKeyword | beginEndCompoundBlock | ifElseStatement + | caseStatement | whileStatement | repeatStatement | leaveStatement @@ -98,6 +99,13 @@ iterateStatement : ITERATE multipartIdentifier ; +caseStatement + : CASE (WHEN conditions+=booleanExpression THEN conditionalBodies+=compoundBody)+ + (ELSE elseBody=compoundBody)? END CASE #searchedCaseStatement + | CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN conditionalBodies+=compoundBody)+ + (ELSE elseBody=compoundBody)? 
END CASE #simpleCaseStatement + ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 924b5c2cfeb1..9620ce13d92e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -261,6 +261,52 @@ class AstBuilder extends DataTypeAstBuilder WhileStatement(condition, body, Some(labelText)) } + override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = { + val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) { + SingleStatement( + Project( + Seq(Alias(expression(boolExpr), "condition")()), + OneRowRelation())) + }) + val conditionalBodies = + ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)) + + if (conditions.length != conditionalBodies.length) { + throw SparkException.internalError( + s"Mismatched number of conditions ${conditions.length} and condition bodies" + + s" ${conditionalBodies.length} in case statement") + } + + CaseStatement( + conditions = conditions, + conditionalBodies = conditionalBodies, + elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))) + } + + override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = { + // uses EqualTo to compare the case variable(the main case expression) + // to the WHEN clause expressions + val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) { + SingleStatement( + Project( + Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()), + OneRowRelation())) + }) + val conditionalBodies = + ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)) + + if (conditions.length != conditionalBodies.length) { + 
throw SparkException.internalError( + s"Mismatched number of conditions ${conditions.length} and condition bodies" + + s" ${conditionalBodies.length} in case statement") + } + + CaseStatement( + conditions = conditions, + conditionalBodies = conditionalBodies, + elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))) + } + override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = { val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel())) val boolExpr = ctx.booleanExpression() @@ -292,7 +338,7 @@ class AstBuilder extends DataTypeAstBuilder case c: RepeatStatementContext if Option(c.beginLabel()).isDefined && c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + => true case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index 5e7e8b0b4fc9..ed40a5fd734b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -124,3 +124,17 @@ case class LeaveStatement(label: String) extends CompoundPlanStatement * @param label Label of the loop to iterate. */ case class IterateStatement(label: String) extends CompoundPlanStatement + +/** + * Logical operator for CASE statement. + * @param conditions Collection of conditions which correspond to WHEN clauses. + * @param conditionalBodies Collection of bodies that have a corresponding condition, + * in WHEN branches. + * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch. 
+ */ +case class CaseStatement( + conditions: Seq[SingleStatement], + conditionalBodies: Seq[CompoundBody], + elseBody: Option[CompoundBody]) extends CompoundPlanStatement { + assert(conditions.length == conditionalBodies.length) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index bf527b9c3bd7..24ad32c5300b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.CreateVariable +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, Project} import org.apache.spark.sql.exceptions.SqlScriptingException class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { @@ -1111,6 +1112,287 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } + test("searched case statement") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 = 1") + } + + test("searched case statement - multi when") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 IN (1,2,3) THEN + | SELECT 1; + | WHEN (SELECT * FROM 
t) THEN + | SELECT * FROM b; + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 3) + assert(caseStmt.conditionalBodies.length == 3) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 IN (1,2,3)") + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(caseStmt.conditions(1).isInstanceOf[SingleStatement]) + assert(caseStmt.conditions(1).getText == "(SELECT * FROM t)") + + assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT * FROM b") + + assert(caseStmt.conditions(2).isInstanceOf[SingleStatement]) + assert(caseStmt.conditions(2).getText == "1 = 1") + + assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + + test("searched case statement with else") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.elseBody.isDefined) + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + 
assert(caseStmt.conditions.head.getText == "1 = 1") + + assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 43") + } + + test("searched case statement nested") { + val sqlScriptText = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | CASE + | WHEN 2 = 1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditionalBodies.length == 1) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditions.head.getText == "1 = 1") + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement]) + val nestedCaseStmt = + caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement] + + assert(nestedCaseStmt.conditions.length == 1) + assert(nestedCaseStmt.conditionalBodies.length == 1) + assert(nestedCaseStmt.elseBody.isDefined) + + assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.conditions.head.getText == "2 = 1") + + assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 41") + + assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + + test("simple case statement") { + val sqlScriptText = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 1; + | END CASE; + |END + |""".stripMargin + val tree = 
parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) + } + + + test("simple case statement - multi when") { + val sqlScriptText = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 1; + | WHEN (SELECT 2) THEN + | SELECT * FROM b; + | WHEN 3 IN (1,2,3) THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 3) + assert(caseStmt.conditionalBodies.length == 3) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(caseStmt.conditions(1).isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery]) + + assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT * FROM b") + + assert(caseStmt.conditions(2).isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In]) + + assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement]) + 
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + + test("simple case statement with else") { + val sqlScriptText = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.elseBody.isDefined) + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) + + assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 43") + } + + test("simple case statement nested") { + val sqlScriptText = + """ + |BEGIN + | CASE (SELECT 1) + | WHEN 1 THEN + | CASE 2 + | WHEN 2 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CaseStatement]) + + val caseStmt = tree.collection.head.asInstanceOf[CaseStatement] + assert(caseStmt.conditions.length == 1) + assert(caseStmt.conditionalBodies.length == 1) + assert(caseStmt.elseBody.isEmpty) + + assert(caseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == Literal(1)) + + assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement]) + val nestedCaseStmt = + caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement] + + assert(nestedCaseStmt.conditions.length == 1) + assert(nestedCaseStmt.conditionalBodies.length == 1) + 
assert(nestedCaseStmt.elseBody.isDefined) + + assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement]) + checkSimpleCaseStatementCondition( + nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2)) + + assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 41") + + assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 42") + } + // Helper methods def cleanupStatementString(statementStr: String): String = { statementStr @@ -1119,4 +1401,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .replace("END", "") .trim } + + private def checkSimpleCaseStatementCondition( + conditionStatement: SingleStatement, + predicateLeft: Expression => Boolean, + predicateRight: Expression => Boolean): Unit = { + assert(conditionStatement.parsedPlan.isInstanceOf[Project]) + val project = conditionStatement.parsedPlan.asInstanceOf[Project] + assert(project.projectList.head.isInstanceOf[Alias]) + assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo]) + val equalTo = project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo] + assert(predicateLeft(equalTo.left)) + assert(predicateRight(equalTo.right)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index cae797614314..af9fd5464277 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -405,6 +405,78 @@ class WhileStatementExec( } } +/** + * Executable node for CaseStatement. 
+ * @param conditions Collection of executable conditions which correspond to WHEN clauses. + * @param conditionalBodies Collection of executable bodies that have a corresponding condition, + * in WHEN branches. + * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch. + * @param session Spark session that SQL script is executed within. + */ +class CaseStatementExec( + conditions: Seq[SingleStatementExec], + conditionalBodies: Seq[CompoundBodyExec], + elseBody: Option[CompoundBodyExec], + session: SparkSession) extends NonLeafStatementExec { + private object CaseState extends Enumeration { + val Condition, Body = Value + } + + private var state = CaseState.Condition + private var curr: Option[CompoundStatementExec] = Some(conditions.head) + + private var clauseIdx: Int = 0 + private val conditionsCount = conditions.length + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = curr.nonEmpty + + override def next(): CompoundStatementExec = state match { + case CaseState.Condition => + val condition = curr.get.asInstanceOf[SingleStatementExec] + if (evaluateBooleanCondition(session, condition)) { + state = CaseState.Body + curr = Some(conditionalBodies(clauseIdx)) + } else { + clauseIdx += 1 + if (clauseIdx < conditionsCount) { + // There are WHEN clauses remaining. + state = CaseState.Condition + curr = Some(conditions(clauseIdx)) + } else if (elseBody.isDefined) { + // ELSE clause exists. + state = CaseState.Body + curr = Some(elseBody.get) + } else { + // No remaining clauses. 
+ curr = None + } + } + condition + case CaseState.Body => + assert(curr.get.isInstanceOf[CompoundBodyExec]) + val currBody = curr.get.asInstanceOf[CompoundBodyExec] + val retStmt = currBody.getTreeIterator.next() + if (!currBody.getTreeIterator.hasNext) { + curr = None + } + retStmt + } + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = CaseState.Condition + curr = Some(conditions.head) + clauseIdx = 0 + conditions.foreach(c => c.reset()) + conditionalBodies.foreach(b => b.reset()) + elseBody.foreach(b => b.reset()) + } +} + /** * Executable node for RepeatStatement. * @param condition Executable node for the condition - evaluates to a row with a single boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 865b33999655..917b4d6f45ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -95,6 +95,17 @@ case class SqlScriptingInterpreter() { new IfElseStatementExec( conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, 
session) + case CaseStatement(conditions, conditionalBodies, elseBody) => + val conditionsExec = conditions.map(condition => + // todo: what to put here for isInternal, in case of simple case statement + new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)) + val conditionalBodiesExec = conditionalBodies.map(body => + transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) + val unconditionalBodiesExec = elseBody.map(body => + transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]) + new CaseStatementExec( + conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session) + case WhileStatement(condition, body, label) => val conditionExec = new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 4b72ca8ecaa9..83d8191d01ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -576,4 +576,97 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "body1", "lbl", "con1", "body1", "lbl", "con1")) } + + test("searched case - enter first WHEN clause") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = true, description = "con1"), + TestIfElseCondition(condVal = false, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + 
assert(statements === Seq("con1", "body1")) + } + + test("searched case - enter body of the ELSE clause") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))) + ), + elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "body2")) + } + + test("searched case - enter second WHEN clause") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1"), + TestIfElseCondition(condVal = true, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2", "body2")) + } + + test("searched case - without else (successful check)") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, description = "con1"), + TestIfElseCondition(condVal = true, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = None, + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2", "body2")) + } + + test("searched case - without else (unsuccessful checks)") { + val iter = new CompoundBodyExec(Seq( + new CaseStatementExec( + conditions = Seq( + TestIfElseCondition(condVal = false, 
description = "con1"), + TestIfElseCondition(condVal = false, description = "con2") + ), + conditionalBodies = Seq( + new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + new CompoundBodyExec(Seq(TestLeafStatement("body2"))) + ), + elseBody = None, + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("con1", "con2")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8d9cd1d8c780..4851faf897a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.scripting -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkNumberFormatException} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.exceptions.SqlScriptingException @@ -368,6 +368,383 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("searched case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case nested") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1=1 THEN + | CASE + | WHEN 2=1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case second case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = (SELECT 2) THEN + | SELECT 1; + | WHEN 2 = 2 THEN + | SELECT 42; + | WHEN 
(SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case going in else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 2 = 1 THEN + | SELECT 1; + | WHEN 3 IN (1,2) THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("searched case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (1, 'a', 1.0); + | CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | WHEN (SELECT COUNT(*) > 1 FROM t) THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 2 THEN + | SELECT 42; + | WHEN 1 = 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq() + verifySqlScriptResult(commands, expected) + } + + test("searched case when evaluates to null") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a BOOLEAN) USING parquet; + | CASE + | WHEN (SELECT * FROM t) THEN + | 
SELECT 42; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] ( + runSqlScript(commands) + ), + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + parameters = Map("invalidStatement" -> "(SELECT * FROM T)") + ) + } + } + + test("searched case with non boolean condition - constant") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] ( + runSqlScript(commands) + ), + condition = "INVALID_BOOLEAN_STATEMENT", + parameters = Map("invalidStatement" -> "1") + ) + } + + test("searched case with too many rows in subquery condition") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a BOOLEAN) USING parquet; + | INSERT INTO t VALUES (true); + | INSERT INTO t VALUES (true); + | CASE + | WHEN (SELECT * FROM t) THEN + | SELECT 1; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SparkException] ( + runSqlScript(commands) + ), + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + parameters = Map.empty, + context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 124, stop = 140) + ) + } + } + + test("simple case") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case nested") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | CASE 2 + | WHEN (SELECT 3) THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case second case") { + val commands = + """ + |BEGIN + | CASE (SELECT 2) + | WHEN 1 THEN + | SELECT 1; + | WHEN 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + 
val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case going in else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 1; + | WHEN 3 THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("simple case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 41; + | WHEN 2 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (2, 'b', 2.0); + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(44))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq() + verifySqlScriptResult(commands, expected) + } + + test("simple case mismatched types") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN "one" THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + + checkError( + exception = intercept[SparkNumberFormatException] ( + runSqlScript(commands) + ), + condition = "CAST_INVALID_INPUT", + 
parameters = Map( + "expression" -> "'one'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"BIGINT\""), + context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27) + ) + } + + test("simple case compare with null") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT) USING parquet; + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + + val expected = Seq(Seq.empty[Row], Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + test("if's condition must be a boolean statement") { withTable("t") { val commands = --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org