cloud-fan commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1544183331


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson 
=
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object 
key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | 
index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. 
Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON 
path.
+ * @param targetType The target data type to cast into. Any non-nullable 
annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an 
exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by 
timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> 
toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = 
Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else 
"try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId 
= Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar 
types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | 
TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => 
v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => 
v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant 
represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an 
illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input valid (e.g, 
cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` 
is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: 
Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, 
dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else 
invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This 
case is only intended
+              // to by pass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. 
However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and 
we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, 
EvalMode.TRY).eval()

Review Comment:
   It's risky to evaluate Cast on the fly, as we do not apply any analysis 
checks. Can we define the allowed type mapping before creating `Cast`?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -53,3 +66,311 @@ case class ParseJson(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): ParseJson 
=
     copy(child = newChild)
 }
+
+object VariantPathParser extends RegexParsers {
+  // A path segment in the `VariantGet` expression represents either an object 
key access or an
+  // array index access.
+  type PathSegment = Either[String, Int]
+
+  private def root: Parser[Char] = '$'
+
+  // Parse index segment like `[123]`.
+  private def index: Parser[PathSegment] =
+    for {
+      index <- '[' ~> "\\d+".r <~ ']'
+    } yield {
+      scala.util.Right(index.toInt)
+    }
+
+  // Parse key segment like `.name`, `['name']`, or `["name"]`.
+  private def key: Parser[PathSegment] =
+    for {
+      key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+        "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+    } yield {
+      scala.util.Left(key)
+    }
+
+  private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | 
index))
+
+  def parse(str: String): Option[Array[PathSegment]] = {
+    this.parseAll(parser, str) match {
+      case Success(result, _) => Some(result.toArray)
+      case _ => None
+    }
+  }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions. 
Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON 
path.
+ * @param targetType The target data type to cast into. Any non-nullable 
annotations are ignored.
+ * @param failOnError Controls whether the expression should throw an 
exception or return null if
+ *                    the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by 
timestamp-related casts.
+ */
+case class VariantGet(
+    child: Expression,
+    path: Expression,
+    targetType: DataType,
+    failOnError: Boolean,
+    timeZoneId: Option[String] = None)
+    extends BinaryExpression
+    with TimeZoneAwareExpression
+    with NullIntolerant
+    with ExpectsInputTypes
+    with CodegenFallback
+    with QueryErrorsBase {
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else if (!path.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> toSQLId("path"),
+          "inputType" -> toSQLType(path.dataType),
+          "inputExpr" -> toSQLExpr(path)))
+    } else if (!VariantGet.checkDataType(targetType)) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(VariantType), "targetType" -> 
toSQLType(targetType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override lazy val dataType: DataType = targetType.asNullable
+
+  @transient private lazy val parsedPath = {
+    val pathValue = path.eval().toString
+    VariantPathParser.parse(pathValue).getOrElse {
+      throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+    }
+  }
+
+  final override def nodePatternsInternal(): Seq[TreePattern] = 
Seq(VARIANT_GET)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+  override def prettyName: String = if (failOnError) "variant_get" else 
"try_variant_get"
+
+  override def nullable: Boolean = true
+
+  protected override def nullSafeEval(input: Any, path: Any): Any = {
+    VariantGet.variantGet(
+      input.asInstanceOf[VariantVal],
+      parsedPath,
+      dataType,
+      failOnError,
+      timeZoneId)
+  }
+
+  override def left: Expression = child
+
+  override def right: Expression = path
+
+  override protected def withNewChildrenInternal(
+      newChild: Expression,
+      newPath: Expression): VariantGet = copy(child = newChild, path = newPath)
+
+  override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId 
= Option(timeZoneId))
+}
+
+case object VariantGet {
+  /**
+   * Returns whether a data type can be cast into/from variant. For scalar 
types, we allow a subset
+   * of them. For nested types, we reject map types with a non-string key type.
+   */
+  def checkDataType(dataType: DataType): Boolean = dataType match {
+    case _: NumericType | BooleanType | StringType | BinaryType | 
TimestampType | DateType |
+        VariantType =>
+      true
+    case ArrayType(elementType, _) => checkDataType(elementType)
+    case MapType(StringType, valueType, _) => checkDataType(valueType)
+    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+    case _ => false
+  }
+
+  /** The actual implementation of the `VariantGet` expression. */
+  def variantGet(
+      input: VariantVal,
+      parsedPath: Array[VariantPathParser.PathSegment],
+      dataType: DataType,
+      failOnError: Boolean,
+      zoneId: Option[String]): Any = {
+    var v = new Variant(input.getValue, input.getMetadata)
+    for (path <- parsedPath) {
+      v = path match {
+        case scala.util.Left(key) if v.getType == Type.OBJECT => 
v.getFieldByKey(key)
+        case scala.util.Right(index) if v.getType == Type.ARRAY => 
v.getElementAtIndex(index)
+        case _ => null
+      }
+      if (v == null) return null
+    }
+    VariantGet.cast(v, dataType, failOnError, zoneId)
+  }
+
+  /**
+   * Cast a variant `v` into a target data type `dataType`. If the variant 
represents a variant
+   * null, the result is always a SQL NULL. The cast may fail due to an 
illegal type combination
+   * (e.g., cast a variant int to binary), or an invalid input valid (e.g, 
cast a variant string
+   * "hello" to int). If the cast fails, throw an exception when `failOnError` 
is true, or return a
+   * SQL NULL when it is false.
+   */
+  def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId: 
Option[String]): Any = {
+    def invalidCast(): Any =
+      if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson, 
dataType) else null
+
+    val variantType = v.getType
+    if (variantType == Type.NULL) return null
+    dataType match {
+      case VariantType => new VariantVal(v.getValue, v.getMetadata)
+      case _: AtomicType =>
+        variantType match {
+          case Type.OBJECT | Type.ARRAY =>
+            if (dataType == StringType) UTF8String.fromString(v.toJson) else 
invalidCast()
+          case _ =>
+            val input = variantType match {
+              case Type.BOOLEAN => v.getBoolean
+              case Type.LONG => v.getLong
+              case Type.STRING => UTF8String.fromString(v.getString)
+              case Type.DOUBLE => v.getDouble
+              case Type.DECIMAL => Decimal(v.getDecimal)
+              // We have handled other cases and should never reach here. This 
case is only intended
+              // to by pass the compiler exhaustiveness check.
+              case _ => throw QueryExecutionErrors.unreachableError()
+            }
+            // We mostly use the `Cast` expression to implement the cast. 
However, `Cast` silently
+            // ignores the overflow in the long/decimal -> timestamp cast, and 
we want to enforce
+            // strict overflow checks.
+            input match {
+              case l: Long if dataType == TimestampType =>
+                try Math.multiplyExact(l, MICROS_PER_SECOND)
+                catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case d: Decimal if dataType == TimestampType =>
+                try {
+                  d.toJavaBigDecimal
+                    .multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
+                    .toBigInteger
+                    .longValueExact()
+                } catch {
+                  case _: ArithmeticException => invalidCast()
+                }
+              case _ =>
+                val result = Cast(Literal(input), dataType, zoneId, 
EvalMode.TRY).eval()

Review Comment:
   It's risky to evaluate Cast on the fly, as we do not apply any analysis 
checks for this Cast. Can we define the allowed type mapping before creating 
`Cast`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to