beliefer commented on a change in pull request #30981: URL: https://github.com/apache/spark/pull/30981#discussion_r551131846
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala ########## @@ -751,267 +751,348 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } -trait String2TrimExpression extends Expression with ImplicitCastInputTypes { +trait TrimExpression extends Expression with ImplicitCastInputTypes { - protected def srcStr: Expression - protected def trimStr: Option[Expression] + protected def srcExpr: Expression + protected def trimExprOpt: Option[Expression] protected def direction: String - override def children: Seq[Expression] = srcStr +: trimStr.toSeq - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + override def children: Seq[Expression] = srcExpr +: trimExprOpt.toSeq + override def dataType: DataType = srcExpr.dataType + override def inputTypes: Seq[AbstractDataType] = + Seq.fill(children.size)(TypeCollection(StringType, BinaryType)) + + override def checkInputDataTypes(): TypeCheckResult = { + val inputTypeCheck = super.checkInputDataTypes() + if (inputTypeCheck.isSuccess) { + TypeUtils.checkForSameTypeInputExpr( + children.map(_.dataType), s"function $prettyName") + } else { + inputTypeCheck + } + } override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) protected def doEval(srcString: UTF8String): UTF8String + protected def doEval(srcBytes: Array[Byte]): Array[Byte] protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String + protected def doEval(srcBytes: Array[Byte], trimBytes: Array[Byte]): Array[Byte] + + private lazy val evalFunc = srcExpr.dataType match { + case StringType => + (input: InternalRow) => { + val srcString = srcExpr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else if (trimExprOpt.isDefined) { + doEval(srcString, trimExprOpt.get.eval(input).asInstanceOf[UTF8String]) + } else { + doEval(srcString) + } + } + case BinaryType => + (input: InternalRow) => { + val srcBytes = srcExpr.eval (input).asInstanceOf[Array[Byte]] + if (srcBytes == null) { + null + } else if (trimExprOpt.isDefined) { + doEval(srcBytes, trimExprOpt.get.eval(input).asInstanceOf[Array[Byte]]) + } else { + doEval(srcBytes) + } + } + } override def eval(input: InternalRow): Any = { - val srcString = srcStr.eval(input).asInstanceOf[UTF8String] - if (srcString == null) { - null - } else if (trimStr.isDefined) { - doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String]) - } else { - doEval(srcString) - } + evalFunc(input) } protected val trimMethod: String + private lazy val resultType = srcExpr.dataType match { Review comment: OK ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org