Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/12646#discussion_r131534214 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala --- @@ -502,69 +503,311 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } +trait String2TrimExpression extends Expression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def sql: String = { + if (children.size == 1) { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } else { + val trimSQL = children(0).map(_.sql).mkString(", ") + val tarSQL = children(1).map(_.sql).mkString(", ") + s"$prettyName($trimSQL, $tarSQL)" + } + } +} + +object StringTrim { + def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr)) + def apply(str: Expression) : StringTrim = StringTrim(str, None) +} + /** - * A function that trim the spaces from both ends for the specified string. - */ + * A function that takes a character string, removes the leading and trailing characters matching with the characters + * in the trim string, returns the new string. + * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim + * function will have one argument, which contains the source string. + * If BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will have + * two arguments, the first argument contains trimStr, the second argument contains the source string. 
+ * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function + * searches for each character in the source string, removes the characters from the source string until it + * encounters the first non-match character. + * BOTH: removes any characters from both ends of the source string that matches characters in the trim string. + */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. + _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` + """, extended = """ + Arguments: + str - a string expression + trimString - the trim string + BOTH, FROM - these are keyword to specify for trim string from both ends of the string Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); + parkSQ """) -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrim( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { + + def this (trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) - def convert(v: UTF8String): UTF8String = v.trim() + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "trim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trim() + } + } + null + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trim()") + val evals = 
children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(); + } + """.stripMargin) + } else { + val trimString = evals(1) + val getTrimFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimFunction + } + """.stripMargin) + } } } +object StringTrimLeft { + def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr)) + def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None) +} + /** - * A function that trim the spaces from left end for given string. + * A function that trims the characters from left end for a given string. + * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim + * function will have one argument, which contains the source string. + * If LEADING and trimStr keywords are not specified, it trims the characters from left end. The ltrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any characters from the left end of the source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. 
*/ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading space characters from `str`. + _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string + """, extended = """ + Arguments: + str - a string expression + trimStr - the trim string Examples: - > SELECT _FUNC_(' SparkSQL'); + > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_('Sp', 'SSparkSQLS'); + arkSQLS """) -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimLeft( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { + + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) - def convert(v: UTF8String): UTF8String = v.trimLeft() + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trimLeft() + } + } + null + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(); + }""".stripMargin) + } else { + val trimString = 
evals(1) + val getTrimLeftFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimLeftFunction + } + """.stripMargin ) + } + } +} + +object StringTrimRight { + def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr)) + def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None) } /** - * A function that trim the spaces from right end for given string. + * A function that trims the characters from right end for a given string. + * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The + * rtrim function will have one argument, which contains the source string. + * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any characters from the right end of source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the trailing space characters from `str`. + _FUNC_(trimStr, str) - Removes the trailing string which contains the character from the trim string from the `str` + """, extended = """ --- End diff -- The `ExpressionDescription` annotation was recently updated, so this usage needs to be updated to match the new format.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org