GitHub user chenghao-intel commented on a diff in the pull request:

    https://github.com/apache/spark/pull/6762#discussion_r33738979
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala ---
    @@ -220,6 +222,404 @@ case class EndsWith(left: Expression, right: Expression)
     }
     
     /**
    + * Removes the spaces from both ends of the given string (delegates to `UTF8String.trim`).
    + */
    +case class StringTrim(child: Expression)
    +  extends UnaryExpression with String2StringExpression {
    +
    +  override def toString: String = s"TRIM($child)"
    +
    +  override def convert(v: UTF8String): UTF8String = v.trim()
    +
    +  // Generated code applies the same UTF8String method as the interpreted path.
    +  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
    +    defineCodeGen(ctx, ev, value => s"($value).trim()")
    +}
    +
    +/**
    + * Removes the spaces from the left (leading) end of the given string
    + * (delegates to `UTF8String.trimLeft`).
    + */
    +case class StringTrimLeft(child: Expression)
    +  extends UnaryExpression with String2StringExpression {
    +
    +  override def toString: String = s"LTRIM($child)"
    +
    +  override def convert(v: UTF8String): UTF8String = v.trimLeft()
    +
    +  // Generated code applies the same UTF8String method as the interpreted path.
    +  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
    +    defineCodeGen(ctx, ev, value => s"($value).trimLeft()")
    +}
    +
    +/**
    + * Removes the spaces from the right (trailing) end of the given string
    + * (delegates to `UTF8String.trimRight`).
    + */
    +case class StringTrimRight(child: Expression)
    +  extends UnaryExpression with String2StringExpression {
    +
    +  override def toString: String = s"RTRIM($child)"
    +
    +  override def convert(v: UTF8String): UTF8String = v.trimRight()
    +
    +  // Generated code applies the same UTF8String method as the interpreted path.
    +  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
    +    defineCodeGen(ctx, ev, value => s"($value).trimRight()")
    +}
    +
    +/**
    + * Returns the position of the first occurrence of `substr` in `str`.
    + *
    + * Returns null if either argument is null, and 0 if `substr` does not occur in `str`.
    + * The result is not zero based: the first character of `str` is at position 1.
    + */
    +case class StringInstr(str: Expression, substr: Expression)
    +  extends Expression with ExpectsInputTypes {
    +
    +  override def children: Seq[Expression] = Seq(str, substr)
    +  override def dataType: DataType = IntegerType
    +  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
    +  override def nullable: Boolean = str.nullable || substr.nullable
    +  override def foldable: Boolean = str.foldable && substr.foldable
    +
    +  // substr is only evaluated once str is known to be non-null.
    +  override def eval(input: InternalRow): Any = str.eval(input) match {
    +    case null => null
    +    case haystack =>
    +      substr.eval(input) match {
    +        case null => null
    +        case needle =>
    +          // Searching from offset 0; the +1 shifts the result to the 1-based contract
    +          // above (presumably instr yields a 0-based index, or -1 when absent).
    +          haystack.asInstanceOf[UTF8String].instr(needle.asInstanceOf[UTF8String], 0) + 1
    +      }
    +  }
    +
    +  override def toString: String = s"INSTR($str, $substr)"
    +}
    +
    +/**
    + * A function that returns the position of the first occurrence of substr in str,
    + * searching from the 1-based position start (1, i.e. the whole string, by default).
    + *
    + * Like INSTR, the result is 1-based and 0 means "not found". Returns null if substr
    + * or str is null, and 0 if start is null or less than 1.
    + */
    +case class StringLocate(substr: Expression, str: Expression, start: Expression)
    +  extends Expression with ExpectsInputTypes {
    +
    +  def this(substr: Expression, str: Expression) = {
    +    // start is 1-based: begin the search at the first character.
    +    this(substr, str, Literal(1))
    +  }
    +
    +  override def children: Seq[Expression] = substr :: str :: start :: Nil
    +  override def foldable: Boolean = children.forall(_.foldable)
    +  // A null start yields 0 rather than null, so start does not affect nullability.
    +  override def nullable: Boolean = substr.nullable || str.nullable
    +  override def dataType: DataType = IntegerType
    +  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
    +
    +  override def eval(input: InternalRow): Any = {
    +    val s = start.eval(input)
    +    if (s == null) {
    +      // if the start position is null, we need to return 0, (keep it conform to Hive)
    +      0
    +    } else {
    +      val r = substr.eval(input)
    +      if (r == null) {
    +        null
    +      } else {
    +        val l = str.eval(input)
    +        if (l == null) {
    +          null
    +        } else {
    +          val startPos = s.asInstanceOf[Int]
    +          if (startPos < 1) {
    +            // Positions are 1-based; there is nothing before the first character.
    +            0
    +          } else {
    +            // instr takes a 0-based offset (INSTR passes 0 to search the whole string),
    +            // so convert the 1-based start before searching.
    +            l.asInstanceOf[UTF8String].instr(r.asInstanceOf[UTF8String], startPos - 1) + 1
    +          }
    +        }
    +      }
    +    }
    +  }
    +
    +  override def toString: String = s"LOCATE($substr, $str[, $start])"
    +}
    +
    +/**
    + * Returns str, left-padded with pad to a length of len.
    + *
    + * The pad is repeated as often as needed, with the last copy truncated. If str is
    + * already longer than len, it is truncated to len. A negative len yields an empty
    + * string. An empty pad adds no padding, so str is only truncated to len.
    + * Returns null if any argument is null.
    + */
    +case class StringLPad(str: Expression, len: Expression, pad: Expression)
    +  extends Expression with ExpectsInputTypes {
    +
    +  override def children: Seq[Expression] = str :: len :: pad :: Nil
    +  override def foldable: Boolean = children.forall(_.foldable)
    +  override def nullable: Boolean = children.exists(_.nullable)
    +  override def dataType: DataType = StringType
    +  override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
    +
    +  override def eval(input: InternalRow): Any = {
    +    val s = str.eval(input)
    +    if (s == null) {
    +      null
    +    } else {
    +      val l = len.eval(input)
    +      if (l == null) {
    +        null
    +      } else {
    +        val p = pad.eval(input)
    +        if (p == null) {
    +          null
    +        } else {
    +          // Clamp a negative target length to 0; allocating a negative-sized array
    +          // would otherwise throw NegativeArraySizeException.
    +          val targetLen = math.max(l.asInstanceOf[Int], 0)
    +          val strValue = s.asInstanceOf[UTF8String]
    +          val padValue = p.asInstanceOf[UTF8String]
    +          val bytes = new Array[Byte](targetLen)
    +
    +          performOp(bytes, strValue.getBytes, padValue.getBytes, targetLen, strValue, padValue)
    +        }
    +      }
    +    }
    +  }
    +
    +  // Adapted from org.apache.hadoop.hive.ql.udf.generic.GenericUDFLpad
    +  // ('hi', 5, '??') => '???hi'
    +  // ('hi', 1, '??') => 'h'
    +  // ('', 3, '??')   => '???'
    +  // NOTE(review): str.length() / pad.length() are used as offsets into the byte arrays
    +  // txt / padTxt. If UTF8String.length() counts characters rather than bytes,
    +  // multi-byte input is mis-padded -- confirm.
    +  protected def performOp(
    +    data: Array[Byte],
    +    txt: Array[Byte],
    +    padTxt: Array[Byte],
    +    len: Int,
    +    str: UTF8String,
    +    pad: UTF8String): UTF8String = {
    +
    +    val padLength = pad.length()
    +    // Number of leading positions to fill with the pad. An empty pad cannot fill
    +    // anything (and the loop below would never advance), so it disables padding.
    +    val pos: Int = if (padLength == 0) 0 else Math.max(len - str.length, 0)
    +
    +    // Copy the padding. The `j < pos - i` bound truncates the last copy so it never
    +    // writes past pos (nor, for a short str, past the end of data).
    +    var i = 0
    +    while (i < pos) {
    +      var j = 0
    +      while (j < padLength && j < pos - i) {
    +        data(i + j) = padTxt(j)
    +        j += 1
    +      }
    +      i += padLength
    +    }
    +
    +    // Copy the text
    +    i = 0
    +    while (pos + i < len && i < str.length()) {
    +      data(pos + i) = txt(i)
    +      i += 1
    +    }
    +
    +    // filled < data.length only when padding was disabled and str is shorter than len:
    +    // drop the unfilled tail instead of emitting zero bytes.
    +    val filled = pos + i
    +    UTF8String.fromBytes(
    +      if (filled == data.length) data else java.util.Arrays.copyOf(data, filled))
    +  }
    +
    +  override def toString: String = s"LPAD($str, $len, $pad)"
    +}
    +
    +/**
    + * Returns str, right-padded with pad to a length of len.
    + *
    + * The pad is repeated as often as needed, with the last copy truncated at len. If str
    + * is already longer than len, it is truncated to len. A negative len yields an empty
    + * string. An empty pad adds no padding, so str is only truncated to len.
    + * Returns null if any argument is null.
    + */
    +case class StringRPad(str: Expression, len: Expression, pad: Expression)
    +  extends Expression with ExpectsInputTypes {
    +
    +  override def children: Seq[Expression] = str :: len :: pad :: Nil
    +  override def foldable: Boolean = children.forall(_.foldable)
    +  override def nullable: Boolean = children.exists(_.nullable)
    +  override def dataType: DataType = StringType
    +  override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
    +
    +  override def eval(input: InternalRow): Any = {
    +    val s = str.eval(input)
    +    if (s == null) {
    +      null
    +    } else {
    +      val l = len.eval(input)
    +      if (l == null) {
    +        null
    +      } else {
    +        val p = pad.eval(input)
    +        if (p == null) {
    +          null
    +        } else {
    +          // Clamp a negative target length to 0; allocating a negative-sized array
    +          // would otherwise throw NegativeArraySizeException.
    +          val targetLen = math.max(l.asInstanceOf[Int], 0)
    +          val strValue = s.asInstanceOf[UTF8String]
    +          val padValue = p.asInstanceOf[UTF8String]
    +          val bytes = new Array[Byte](targetLen)
    +
    +          performOp(bytes, strValue.getBytes, padValue.getBytes, targetLen, strValue, padValue)
    +        }
    +      }
    +    }
    +  }
    +
    +  // Adapted from org.apache.hadoop.hive.ql.udf.generic.GenericUDFRpad
    +  // ('hi', 5, '??') => 'hi???'
    +  // ('hi', 1, '??') => 'h'
    +  // NOTE(review): str.length() / pad.length() are used as offsets into the byte arrays
    +  // txt / padTxt. If UTF8String.length() counts characters rather than bytes,
    +  // multi-byte input is mis-padded -- confirm.
    +  protected def performOp(
    +    data: Array[Byte],
    +    txt: Array[Byte],
    +    padTxt: Array[Byte],
    +    len: Int,
    +    str: UTF8String,
    +    pad: UTF8String): UTF8String = {
    +
    +    // Copy the text
    +    var pos = 0
    +    while (pos < str.length() && pos < len) {
    +      data(pos) = txt(pos)
    +      pos += 1
    +    }
    +
    +    val padLength = pad.length()
    +    if (padLength == 0) {
    +      // An empty pad cannot fill the remainder (the loop below would never advance),
    +      // so return only the copied text, without the unfilled zero bytes.
    +      UTF8String.fromBytes(
    +        if (pos == data.length) data else java.util.Arrays.copyOf(data, pos))
    +    } else {
    +      // Copy the padding, truncating the last copy at len
    +      while (pos < len) {
    +        var i = 0
    +        while (i < padLength && i < len - pos) {
    +          data(pos + i) = padTxt(i)
    +          i += 1
    +        }
    +        pos += padLength
    +      }
    +      UTF8String.fromBytes(data)
    +    }
    +  }
    +
    +  override def toString: String = s"RPAD($str, $len, $pad)"
    +}
    +
    +/**
    + * Returns the input formatted according to printf-style format strings.
    + *
    + * The first child is the format pattern and the remaining children are its arguments.
    + * Returns null only if the pattern is null; argument values (including nulls) are
    + * passed to java.util.Formatter as-is. Formatting uses Locale.US.
    + */
    +case class StringFormat(children: Expression*) extends Expression {
    +
    +  require(children.length >= 1, "printf() should take at least 1 argument")
    +
    +  override def foldable: Boolean = children.forall(_.foldable)
    +  override def nullable: Boolean = children(0).nullable
    +  override def dataType: DataType = StringType
    +  private def format: Expression = children(0)
    +  private def args: Seq[Expression] = children.tail
    +
    +  override def eval(input: InternalRow): Any = {
    +    val pattern = format.eval(input)
    +    if (pattern == null) {
    +      null
    +    } else {
    +      // The buffer is local to this call, so StringBuffer's synchronization is unneeded.
    +      val sb = new java.lang.StringBuilder()
    +      val formatter = new java.util.Formatter(sb, Locale.US)
    +
    +      val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
    +      formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*)
    +
    +      UTF8String.fromString(sb.toString)
    +    }
    +  }
    +
    +  // List the arguments individually rather than through Seq.toString ("List(...)").
    +  override def toString: String = s"printf(${children.mkString(", ")})"
    +}
    +
    +/**
    + * Returns the given string repeated `times` times; null if either argument is null.
    + */
    +case class StringRepeat(str: Expression, times: Expression)
    +  extends Expression with ExpectsInputTypes {
    +
    +  override def children: Seq[Expression] = Seq(str, times)
    +  override def dataType: DataType = StringType
    +  override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType)
    +  override def foldable: Boolean = str.foldable && times.foldable
    +  override def nullable: Boolean = str.nullable || times.nullable
    +
    +  // times is only evaluated once str is known to be non-null.
    +  override def eval(input: InternalRow): Any = str.eval(input) match {
    +    case null => null
    +    case text =>
    +      times.eval(input) match {
    +        case null => null
    +        case count =>
    +          val repeated = text.asInstanceOf[UTF8String].toString() * count.asInstanceOf[Integer]
    +          UTF8String.fromString(repeated)
    +      }
    +  }
    +
    +  override def toString: String = s"repeat($str, $times)"
    +}
    +
    +/**
    + * Returns the given string reversed (delegates to `UTF8String.reverse`).
    + */
    +case class StringReverse(child: Expression)
    +  extends UnaryExpression with String2StringExpression {
    +
    +  override def toString: String = s"reverse($child)"
    +
    +  override def convert(v: UTF8String): UTF8String = v.reverse()
    +
    +  // Generated code applies the same UTF8String method as the interpreted path.
    +  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
    +    defineCodeGen(ctx, ev, value => s"($value).reverse()")
    +}
    +
    +/**
    + * Returns a string of n spaces. A negative n yields an empty string; a null n
    + * yields null.
    + */
    +case class StringSpace(child: Expression)
    +  extends UnaryExpression with ExpectsInputTypes {
    +
    +  override def dataType: DataType = StringType
    +  override def expectedChildTypes: Seq[DataType] = Seq(IntegerType)
    +
    +  override def eval(input: InternalRow): Any = child.eval(input) match {
    +    case null => null
    +    case n =>
    +      val count = math.max(n.asInstanceOf[Integer].intValue, 0)
    +      UTF8String.fromBytes(Array.fill[Byte](count)(' '.toByte))
    +  }
    +
    +  override def toString: String = s"space($child)"
    +}
    +
    +/**
    + * Splits str around pat (pattern is a regular expression).
    + */
    +case class StringSplit(str: Expression, pattern: Expression)
    +  extends Expression with ExpectsInputTypes {
    +
    +  override def foldable: Boolean = str.foldable && pattern.foldable
    +  override def nullable: Boolean = str.nullable || pattern.nullable
    +  override def dataType: DataType = ArrayType(StringType)
    +  override def expectedChildTypes: Seq[DataType] = StringType :: 
StringType :: Nil
    +  override def children: Seq[Expression] = str :: pattern :: Nil
    +
    +  override def eval(input: InternalRow): Any = {
    +    val v = str.eval(input)
    +    if (v == null) {
    +      null
    +    } else {
    +      val p = pattern.eval(input)
    +      if (p == null) {
    +        null
    +      } else {
    +        val splits = split(v.asInstanceOf[UTF8String], 
p.asInstanceOf[UTF8String], -1)
    --- End diff --
    
    Yes, thanks for the suggestion, and also the benchmarking. :)


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to