Repository: spark
Updated Branches:
  refs/heads/master 207067ead -> 2a0bc867a


[SPARK-17495][SQL] Support Decimal type in Hive-hash

## What changes were proposed in this pull request?

Extends Hive hash to support the Decimal data type. [Hive internally normalises decimals](https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/storage-api/src/java/org/apache/hadoop/hive/common/type/HiveDecimalV1.java#L307) and I have ported that logic as-is to HiveHash.
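
To see why normalization matters before hashing: `java.math.BigDecimal.hashCode` is derived from both the unscaled value and the scale, so numerically equal decimals with different scales would hash differently unless trailing zeros are stripped first. A minimal illustration (not part of this patch):

```scala
import java.math.{BigDecimal => JBigDecimal}

val a = new JBigDecimal("123.4500")  // unscaled value 1234500, scale 4
val b = new JBigDecimal("123.45")    // unscaled value 12345,   scale 2
assert(a.compareTo(b) == 0)          // numerically equal
assert(a.hashCode != b.hashCode)     // but their raw hash codes differ
// Stripping trailing zeros (one step of Hive's normalization) aligns them:
assert(a.stripTrailingZeros.hashCode == b.stripTrailingZeros.hashCode)
```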

## How was this patch tested?

Added unit tests
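
For reference, a hedged sketch of exercising the new interpreted path directly (it mirrors the `checkHiveHashForDecimal` helper added in `HashExpressionsSuite`; the expected value below is taken from the new tests and was computed with Hive 1.2.1):

```scala
import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
import org.apache.spark.sql.types.{DataTypes, Decimal}

val decimalType = DataTypes.createDecimalType(38, 10)
val decimal = Decimal(new java.math.BigDecimal("123456.1234567890"))
assert(decimal.changePrecision(38, 10))  // fit the value into (precision, scale)
assert(HiveHashFunction.hash(decimal, decimalType, seed = 0) == 1871500252L)
```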

Author: Tejas Patil <tej...@fb.com>

Closes #17056 from tejasapatil/SPARK-17495_decimal.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a0bc867
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a0bc867
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a0bc867

Branch: refs/heads/master
Commit: 2a0bc867a4a1dad4ecac47701199e540d345ff4f
Parents: 207067e
Author: Tejas Patil <tej...@fb.com>
Authored: Mon Mar 6 10:16:20 2017 -0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Mon Mar 6 10:16:20 2017 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/hash.scala   | 56 +++++++++++++++++++-
 .../expressions/HashExpressionsSuite.scala      | 46 +++++++++++++++-
 2 files changed, 99 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a0bc867/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 2d9c2e4..03101b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.{BigDecimal, RoundingMode}
 import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32
 
@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
 * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
  *
- * TODO: Support Decimal and date related types
+ * TODO: Support date related types
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def genHashBytes(b: String, result: String): String =
     s"$result = $hasherClassName.hashUnsafeBytes($b, 
Platform.BYTE_ARRAY_OFFSET, $b.length);"
 
+  override protected def genHashDecimal(
+      ctx: CodegenContext,
+      d: DecimalType,
+      input: String,
+      result: String): String = {
+    s"""
+      $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
+        $input.toJavaBigDecimal()).hashCode();"""
+  }
+
   override protected def genHashCalendarInterval(input: String, result: String): String = {
     s"""
         $result = (31 * $hasherClassName.hashInt($input.months)) +
@@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction {
     HiveHasher.hashUnsafeBytes(base, offset, len)
   }
 
+  private val HIVE_DECIMAL_MAX_PRECISION = 38
+  private val HIVE_DECIMAL_MAX_SCALE = 38
+
+  // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
+  def normalizeDecimal(input: BigDecimal): BigDecimal = {
+    if (input == null) return null
+
+    def trimDecimal(input: BigDecimal) = {
+      var result = input
+      if (result.compareTo(BigDecimal.ZERO) == 0) {
+        // Special case for 0, because java doesn't strip zeros correctly on that number.
+        result = BigDecimal.ZERO
+      } else {
+        result = result.stripTrailingZeros
+        if (result.scale < 0) {
+          // no negative scale decimals
+          result = result.setScale(0)
+        }
+      }
+      result
+    }
+
+    var result = trimDecimal(input)
+    val intDigits = result.precision - result.scale
+    if (intDigits > HIVE_DECIMAL_MAX_PRECISION) {
+      return null
+    }
+
+    val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE,
+      Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
+    if (result.scale > maxScale) {
+      result = result.setScale(maxScale, RoundingMode.HALF_UP)
+      // Trimming is again necessary, because rounding may introduce new trailing 0's.
+      result = trimDecimal(result)
+    }
+    result
+  }
+
   override def hash(value: Any, dataType: DataType, seed: Long): Long = {
     value match {
       case null => 0
@@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction {
         }
         result
 
-      case _ => super.hash(value, dataType, 0)
+      case d: Decimal =>
+        normalizeDecimal(d.toJavaBigDecimal).hashCode()
+
+      case _ => super.hash(value, dataType, seed)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2a0bc867/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0cb3a79..0c77dc2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }
 
-
   def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
     // Note : All expected hashes need to be computed using Hive 1.2.1
     val actual = HiveHashFunction.hash(input, dataType, seed = 0)
@@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))
 
+  test("hive-hash for decimal") {
+    def checkHiveHashForDecimal(
+        input: String,
+        precision: Int,
+        scale: Int,
+        expected: Long): Unit = {
+      val decimalType = DataTypes.createDecimalType(precision, scale)
+      val decimal = {
+        val value = Decimal.apply(new java.math.BigDecimal(input))
+        if (value.changePrecision(precision, scale)) value else null
+      }
+
+      checkHiveHash(decimal, decimalType, expected)
+    }
+
+    checkHiveHashForDecimal("18", 38, 0, 558)
+    checkHiveHashForDecimal("-18", 38, 0, -558)
+    checkHiveHashForDecimal("-18", 38, 12, -558)
+    checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
+    checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
+    checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
+    checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
+    checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
+    checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
+    checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
+    checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+    checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
+    checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
+    checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 
3827136)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, 
-3827136)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 
1871500252)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, 
-1871500234)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 
236317582)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, 
-236317544)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 
1728235666)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, 
-1728235608)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 
1728235666)
+  }
+
   test("SPARK-18207: Compute hash for a lot of expressions") {
     val N = 1000
     val wideRow = new GenericInternalRow(

