beliefer commented on a change in pull request #30981:
URL: https://github.com/apache/spark/pull/30981#discussion_r551131846



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
##########
@@ -751,267 +751,348 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
   override def prettyName: String = "find_in_set"
 }
 
-trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
+trait TrimExpression extends Expression with ImplicitCastInputTypes {
 
-  protected def srcStr: Expression
-  protected def trimStr: Option[Expression]
+  protected def srcExpr: Expression
+  protected def trimExprOpt: Option[Expression]
   protected def direction: String
 
-  override def children: Seq[Expression] = srcStr +: trimStr.toSeq
-  override def dataType: DataType = StringType
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
+  override def children: Seq[Expression] = srcExpr +: trimExprOpt.toSeq
+  override def dataType: DataType = srcExpr.dataType
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq.fill(children.size)(TypeCollection(StringType, BinaryType))
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val inputTypeCheck = super.checkInputDataTypes()
+    if (inputTypeCheck.isSuccess) {
+      TypeUtils.checkForSameTypeInputExpr(
+        children.map(_.dataType), s"function $prettyName")
+    } else {
+      inputTypeCheck
+    }
+  }
 
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   protected def doEval(srcString: UTF8String): UTF8String
+  protected def doEval(srcBytes: Array[Byte]): Array[Byte]
   protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String
+  protected def doEval(srcBytes: Array[Byte], trimBytes: Array[Byte]): Array[Byte]
+
+  private lazy val evalFunc = srcExpr.dataType match {
+    case StringType =>
+      (input: InternalRow) => {
+        val srcString = srcExpr.eval(input).asInstanceOf[UTF8String]
+        if (srcString == null) {
+          null
+        } else if (trimExprOpt.isDefined) {
+          doEval(srcString, trimExprOpt.get.eval(input).asInstanceOf[UTF8String])
+        } else {
+          doEval(srcString)
+        }
+      }
+    case BinaryType =>
+      (input: InternalRow) => {
+        val srcBytes = srcExpr.eval (input).asInstanceOf[Array[Byte]]
+        if (srcBytes == null) {
+          null
+        } else if (trimExprOpt.isDefined) {
+          doEval(srcBytes, trimExprOpt.get.eval(input).asInstanceOf[Array[Byte]])
+        } else {
+          doEval(srcBytes)
+        }
+      }
+  }
 
   override def eval(input: InternalRow): Any = {
-    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
-    if (srcString == null) {
-      null
-    } else if (trimStr.isDefined) {
-      doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String])
-    } else {
-      doEval(srcString)
-    }
+    evalFunc(input)
   }
 
   protected val trimMethod: String
 
+  private lazy val resultType = srcExpr.dataType match {

Review comment:
       OK




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to