chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541513308


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends 
UnaryExpression
   override protected def withNewChildInternal(newChild: Expression): ParseJson 
=
     copy(child = newChild)
 }
+
+// A path segment in the `VariantGet` expression. It represents either an 
object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      PathSegment(null, index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      PathSegment(key, 0)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | 
index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. 
Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON 
path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an 
exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by 
timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    schema: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
+    } else if (!VariantGet.checkDataType(schema)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters = Map(
+          "srcType" -> toSQLType(VariantType),
+          "targetType" -> toSQLType(schema)
+        )
+      )
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = schema.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = 
Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else 
"try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {

Review Comment:
   I didn't mean that we should write everything by hand. Essentially, we create a 
method that implements `VariantGet`, and the class only needs some boilerplate 
code to call this method (similar to the code in `StaticInvoke` itself).
   
   There is another reason why I don't like `StaticInvoke`. In the 
future, I will write some optimizer rules on `VariantGet` (e.g., to push it 
down through a join). This is why I added a new `TreePattern`, `VARIANT_GET`, 
in this PR. Those optimizer rules will run after `RuntimeReplaceable` 
expressions are replaced, so by then the expression will have become a 
`StaticInvoke` that no longer has this tree pattern, and the rules can no 
longer use it to prune expressions. Plus, matching against `StaticInvoke` is 
also more complex than matching against `VariantGet`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to