vinodkc commented on code in PR #38419:
URL: https://github.com/apache/spark/pull/38419#discussion_r1080715806


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,268 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+case class TruncNumber(child: Expression, scale: Expression)
+  extends BaseBinaryExpression with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this expression. The default
+   * behavior is to call the eval method of the expression. Concrete expression implementations
+   * should override this to do actual code generation.
+   *
+   * @param ctx
+   *   a [[CodegenContext]]
+   * @param ev
+   *   an [[ExprCode]] with unique terms.
+   * @return
+   *   an [[ExprCode]] containing the Java source code to generate the given expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        dataType match {
+          case ByteType if (_scale <= 0) =>
+            s"""(byte)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |(long)$input, ${_scale}))""".stripMargin
+          case ShortType if (_scale <= 0) =>
+            s"""(short)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |(long)$input, ${_scale}))""".stripMargin
+          case IntegerType if (_scale <= 0) =>
+            s"""(int)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |(long)$input, ${_scale}))""".stripMargin
+          case LongType if (_scale <= 0) =>
+            s"""(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |$input, ${_scale}))""".stripMargin
+          case FloatType if (_scale <= 0) =>
+            s"""org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |$input, ${_scale}).floatValue()""".stripMargin
+          case DoubleType if (_scale <= 0) =>
+            s"""org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |$input, ${_scale}).doubleValue()""".stripMargin
+          case DecimalType.Fixed(_, _) =>
+            s"""Decimal.apply(
+             |org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+             |${input}.toJavaBigDecimal(), ${_scale}))""".stripMargin
+          case _ => s"$input"
+        }
+      })
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query
+   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override lazy val dataType: DataType = {
+    child.dataType match {
+      case DecimalType.Fixed(p, s) =>
+        val newPosition =
+          if (_scale > 0) {
+            if (_scale >= s) {
+              s
+            } else {
+              _scale
+            }
+          } else {
+            0
+          }
+        DecimalType(p - s + newPosition, newPosition)
+      case t => t
+    }
+  }
+
+  /**
+   * Called by the default [[eval]] implementation. If subclasses of BinaryExpression keep the
+   * default nullability, they can override this method to save null-check code. If we need full
+   * control of the evaluation process, we should override [[eval]].
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    dataType match {
+      case ByteType if (_scale <= 0) =>
+        TruncNumber.trunc(input1.asInstanceOf[Byte].toLong, _scale).toByte
+      case ShortType if (_scale <= 0) =>
+        TruncNumber.trunc(input1.asInstanceOf[Short].toLong, _scale).shortValue
+      case IntegerType if (_scale <= 0) =>
+        TruncNumber.trunc(input1.asInstanceOf[Int].toLong, _scale).intValue
+      case LongType if (_scale <= 0) =>
+        TruncNumber.trunc(input1.asInstanceOf[Long], _scale).longValue
+      case FloatType =>
+        TruncNumber.trunc(input1.asInstanceOf[Float], _scale).floatValue
+      case DoubleType =>
+        TruncNumber.trunc(input1.asInstanceOf[Double], _scale).doubleValue
+      case DecimalType.Fixed(p, s) =>
+        Decimal(TruncNumber.trunc(input1.asInstanceOf[Decimal].toJavaBigDecimal, _scale))
+      case _ => input1
+    }
+  }
+}
+
+object TruncNumber {
+  /**
+   * To truncate whole numbers: byte, short, int, and long types
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else {
+      // position is -ve, truncate the number by absolute value of position
+      // eg: input 123 , scale -2 , result 100
+      val pow = Math.pow(10, Math.abs(position)).toLong
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * To truncate double and float types
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * To truncate decimal type
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) {
+      input
+    } else {
+      val wholePart = input.toBigInteger
+      if (position > 0) {
+        // position is +ve, truncate only the decimal part by value of position

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to