This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 47ac097f61c [SPARK-43207][CONNECT] Add helper functions to extract value from literal expression
47ac097f61c is described below

commit 47ac097f61c0185cd5a0674528020a65917b7b90
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Apr 20 16:06:28 2023 +0800

    [SPARK-43207][CONNECT] Add helper functions to extract value from literal expression
    
    ### What changes were proposed in this pull request?
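    Add helper functions (`extractBoolean`, `extractString` and `extractMapData`) to extract values from literal expressions, and use them in `transformUnregisteredFunction` instead of repeating the same pattern matches in each case.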
    
    ### Why are the changes needed?
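    The same literal-extraction logic (boolean `ignoreNulls` flags, literal string arguments, and `map(...)` options) was duplicated across several functions. Sharing it removes the duplication and makes the error messages consistent; for example, `from_avro` previously reported invalid options as "Options in from_json ...".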
    
    ### Does this PR introduce _any_ user-facing change?
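    No, dev-only. Note that `product`, `when` and `in` called with an invalid number of arguments no longer raise their dedicated `InvalidPlanInput` errors here; they now fall through the guarded cases instead.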
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #40863 from zhengruifeng/connect_helper.
    
    Lead-authored-by: Ruifeng Zheng <ruife...@apache.org>
    Co-authored-by: Ruifeng Zheng <ruife...@foxmail.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 123 +++++++--------------
 1 file changed, 37 insertions(+), 86 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5f39fcd17f7..e4522cea747 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1240,83 +1240,49 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformUnregisteredFunction(
       fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
     fun.getFunctionName match {
-      case "product" =>
-        if (fun.getArgumentsCount != 1) {
-          throw InvalidPlanInput("Product requires single child expression")
-        }
+      case "product" if fun.getArgumentsCount == 1 =>
         Some(
           aggregate
             .Product(transformExpression(fun.getArgumentsList.asScala.head))
             .toAggregateExpression())
 
-      case "when" =>
-        if (fun.getArgumentsCount == 0) {
-          throw InvalidPlanInput("CaseWhen requires at least one child 
expression")
-        }
+      case "when" if fun.getArgumentsCount > 0 =>
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
         Some(CaseWhen.createFromParser(children))
 
-      case "in" =>
-        if (fun.getArgumentsCount == 0) {
-          throw InvalidPlanInput("In requires at least one child expression")
-        }
+      case "in" if fun.getArgumentsCount > 0 =>
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
         Some(In(children.head, children.tail))
 
       case "nth_value" if fun.getArgumentsCount == 3 =>
         // NthValue does not have a constructor which accepts Expression typed 
'ignoreNulls'
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
-        val ignoreNulls = children.last match {
-          case Literal(bool: Boolean, BooleanType) => bool
-          case other =>
-            throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, 
but got $other")
-        }
+        val ignoreNulls = extractBoolean(children(2), "ignoreNulls")
         Some(NthValue(children(0), children(1), ignoreNulls))
 
       case "lag" if fun.getArgumentsCount == 4 =>
         // Lag does not have a constructor which accepts Expression typed 
'ignoreNulls'
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
-        val ignoreNulls = children.last match {
-          case Literal(bool: Boolean, BooleanType) => bool
-          case other =>
-            throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, 
but got $other")
-        }
+        val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
         Some(Lag(children.head, children(1), children(2), ignoreNulls))
 
       case "lead" if fun.getArgumentsCount == 4 =>
         // Lead does not have a constructor which accepts Expression typed 
'ignoreNulls'
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
-        val ignoreNulls = children.last match {
-          case Literal(bool: Boolean, BooleanType) => bool
-          case other =>
-            throw InvalidPlanInput(s"ignoreNulls should be a literal boolean, 
but got $other")
-        }
+        val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
         Some(Lead(children.head, children(1), children(2), ignoreNulls))
 
-      case "window" if 2 <= fun.getArgumentsCount && fun.getArgumentsCount <= 
4 =>
+      case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
         val timeCol = children.head
-        val args = children.tail.map {
-          case Literal(s, StringType) if s != null => s.toString
-          case other =>
-            throw InvalidPlanInput(
-              s"windowDuration,slideDuration,startTime should be literal 
strings, but got $other")
+        val windowDuration = extractString(children(1), "windowDuration")
+        var slideDuration = windowDuration
+        if (fun.getArgumentsCount >= 3) {
+          slideDuration = extractString(children(2), "slideDuration")
         }
-        var windowDuration: String = null
-        var slideDuration: String = null
-        var startTime: String = null
-        if (args.length == 3) {
-          windowDuration = args(0)
-          slideDuration = args(1)
-          startTime = args(2)
-        } else if (args.length == 2) {
-          windowDuration = args(0)
-          slideDuration = args(1)
-          startTime = "0 second"
-        } else {
-          windowDuration = args(0)
-          slideDuration = args(0)
-          startTime = "0 second"
+        var startTime = "0 second"
+        if (fun.getArgumentsCount == 4) {
+          startTime = extractString(children(3), "startTime")
         }
         Some(
           Alias(TimeWindow(timeCol, windowDuration, slideDuration, startTime), 
"window")(
@@ -1373,20 +1339,10 @@ class SparkConnectPlanner(val session: SparkSession) {
         }
 
         if (schema != null) {
-          val options = if (children.length == 3) {
-            // ExprUtils.convertToMapData requires the options to be resolved 
CreateMap,
-            // but the options here is not resolved yet: 
UnresolvedFunction("map", ...)
-            children(2) match {
-              case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
-                ExprUtils.convertToMapData(CreateMap(arguments))
-              case other =>
-                throw InvalidPlanInput(
-                  s"Options in from_json should be created by map, but got 
$other")
-            }
-          } else {
-            Map.empty[String, String]
+          var options = Map.empty[String, String]
+          if (children.length == 3) {
+            options = extractMapData(children(2), "Options")
           }
-
           Some(
             JsonToStructs(
               schema = CharVarcharUtils.failIfHasCharVarchar(schema),
@@ -1399,21 +1355,10 @@ class SparkConnectPlanner(val session: SparkSession) {
       // Avro-specific functions
       case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
-        val jsonFormatSchema = children(1) match {
-          case Literal(s, StringType) if s != null => s.toString
-          case other =>
-            throw InvalidPlanInput(
-              s"jsonFormatSchema in from_avro should be a literal string, but 
got $other")
-        }
+        val jsonFormatSchema = extractString(children(1), "jsonFormatSchema")
         var options = Map.empty[String, String]
         if (fun.getArgumentsCount == 3) {
-          children(2) match {
-            case UnresolvedFunction(Seq("map"), arguments, _, _, _) =>
-              options = ExprUtils.convertToMapData(CreateMap(arguments))
-            case other =>
-              throw InvalidPlanInput(
-                s"Options in from_json should be created by map, but got 
$other")
-          }
+          options = extractMapData(children(2), "Options")
         }
         Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options))
 
@@ -1421,12 +1366,7 @@ class SparkConnectPlanner(val session: SparkSession) {
         val children = 
fun.getArgumentsList.asScala.toSeq.map(transformExpression)
         var jsonFormatSchema = Option.empty[String]
         if (fun.getArgumentsCount == 2) {
-          children(1) match {
-            case Literal(s, StringType) if s != null => jsonFormatSchema = 
Some(s.toString)
-            case other =>
-              throw InvalidPlanInput(
-                s"jsonFormatSchema in to_avro should be a literal string, but 
got $other")
-          }
+          jsonFormatSchema = Some(extractString(children(1), 
"jsonFormatSchema"))
         }
         Some(CatalystDataToAvro(children.head, jsonFormatSchema))
 
@@ -1437,12 +1377,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
         val expr = transformExpression(fun.getArguments(0))
-        val dtype = transformExpression(fun.getArguments(1)) match {
-          case Literal(s, StringType) if s != null => s.toString
-          case other =>
-            throw InvalidPlanInput(
-              s"dtype in vector_to_array should be a literal string, but got 
$other")
-        }
+        val dtype = extractString(transformExpression(fun.getArguments(1)), 
"dtype")
         dtype match {
           case "float64" =>
             Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, 
Seq(expr)))
@@ -1479,6 +1414,22 @@ class SparkConnectPlanner(val session: SparkSession) {
       udfDeterministic = f.deterministic)
   }
 
+  /** Returns the value of `expr` if it is a `BooleanType` [[Literal]]; any other expression throws [[InvalidPlanInput]] whose message names `field`. */ private def extractBoolean(expr: Expression, field: String): Boolean = expr 
match {
+    case Literal(bool: Boolean, BooleanType) => bool  // a null-valued literal fails the `Boolean` type pattern and is rejected by the next case
+    case other => throw InvalidPlanInput(s"$field should be a literal boolean, 
but got $other")
+  }
+  /** Returns `expr` as a String if it is a non-null `StringType` [[Literal]]; any other expression throws [[InvalidPlanInput]] whose message names `field`. */
+  private def extractString(expr: Expression, field: String): String = expr 
match {
+    case Literal(s, StringType) if s != null => s.toString  // null string literals are rejected as well
+    case other => throw InvalidPlanInput(s"$field should be a literal string, 
but got $other")
+  }
+  /** Converts a map-building expression (a [[CreateMap]] or an unresolved `map(...)` call) into a `Map[String, String]` via `ExprUtils.convertToMapData`; any other expression throws [[InvalidPlanInput]] whose message names `field`. */
+  private def extractMapData(expr: Expression, field: String): Map[String, 
String] = expr match {
+    case map: CreateMap => ExprUtils.convertToMapData(map)
+    case UnresolvedFunction(Seq("map"), args, _, _, _) => 
extractMapData(CreateMap(args), field)  // convertToMapData needs a CreateMap, but options built by the client arrive as an unresolved map(...) call
+    case other => throw InvalidPlanInput(s"$field should be created by map, 
but got $other")
+  }
+
   private def transformAlias(alias: proto.Expression.Alias): NamedExpression = 
{
     if (alias.getNameCount == 1) {
       val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to