Repository: spark Updated Branches: refs/heads/master 207067ead -> 2a0bc867a
[SPARK-17495][SQL] Support Decimal type in Hive-hash ## What changes were proposed in this pull request? Adds support for the Decimal datatype to Hive hash. [Hive internally normalises decimals](https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/storage-api/src/java/org/apache/hadoop/hive/common/type/HiveDecimalV1.java#L307), and I have ported that logic as-is to HiveHash. ## How was this patch tested? Added unit tests to HashExpressionsSuite. Author: Tejas Patil <tej...@fb.com> Closes #17056 from tejasapatil/SPARK-17495_decimal. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a0bc867 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a0bc867 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a0bc867 Branch: refs/heads/master Commit: 2a0bc867a4a1dad4ecac47701199e540d345ff4f Parents: 207067e Author: Tejas Patil <tej...@fb.com> Authored: Mon Mar 6 10:16:20 2017 -0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Mon Mar 6 10:16:20 2017 -0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/expressions/hash.scala | 56 +++++++++++++++++++- .../expressions/HashExpressionsSuite.scala | 46 +++++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2a0bc867/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 2d9c2e4..03101b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.catalyst.expressions +import java.math.{BigDecimal, RoundingMode} import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 @@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction { * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution * - * TODO: Support Decimal and date related types + * TODO: Support date related types */ @ExpressionDescription( usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.") @@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashBytes(b: String, result: String): String = s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + override protected def genHashDecimal( + ctx: CodegenContext, + d: DecimalType, + input: String, + result: String): String = { + s""" + $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal( + $input.toJavaBigDecimal()).hashCode();""" + } + override protected def genHashCalendarInterval(input: String, result: String): String = { s""" $result = (31 * $hasherClassName.hashInt($input.months)) + @@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashUnsafeBytes(base, offset, len) } + private val HIVE_DECIMAL_MAX_PRECISION = 38 + private val HIVE_DECIMAL_MAX_SCALE = 38 + + // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize() + def normalizeDecimal(input: BigDecimal): BigDecimal = { + if (input == null) return null + + def trimDecimal(input: BigDecimal) = { + var result = input + if (result.compareTo(BigDecimal.ZERO) == 0) { + // Special case for 0, because java doesn't strip zeros correctly on that number. 
+ result = BigDecimal.ZERO + } else { + result = result.stripTrailingZeros + if (result.scale < 0) { + // no negative scale decimals + result = result.setScale(0) + } + } + result + } + + var result = trimDecimal(input) + val intDigits = result.precision - result.scale + if (intDigits > HIVE_DECIMAL_MAX_PRECISION) { + return null + } + + val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE, + Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale)) + if (result.scale > maxScale) { + result = result.setScale(maxScale, RoundingMode.HALF_UP) + // Trimming is again necessary, because rounding may introduce new trailing 0's. + result = trimDecimal(result) + } + result + } + override def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => 0 @@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction { } result - case _ => super.hash(value, dataType, 0) + case d: Decimal => + normalizeDecimal(d.toJavaBigDecimal).hashCode() + + case _ => super.hash(value, dataType, seed) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/2a0bc867/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 0cb3a79..0c77dc2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { // Note : All expected hashes need to be computed 
using Hive 1.2.1 val actual = HiveHashFunction.hash(input, dataType, seed = 0) @@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) + test("hive-hash for decimal") { + def checkHiveHashForDecimal( + input: String, + precision: Int, + scale: Int, + expected: Long): Unit = { + val decimalType = DataTypes.createDecimalType(precision, scale) + val decimal = { + val value = Decimal.apply(new java.math.BigDecimal(input)) + if (value.changePrecision(precision, scale)) value else null + } + + checkHiveHash(decimal, decimalType, expected) + } + + checkHiveHashForDecimal("18", 38, 0, 558) + checkHiveHashForDecimal("-18", 38, 0, -558) + checkHiveHashForDecimal("-18", 38, 12, -558) + checkHiveHashForDecimal("18446744073709001000", 38, 19, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057) + checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057) + checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656) + checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656) + checkHiveHashForDecimal("00000.00000000000", 38, 34, 0) + checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0) + checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136) + 
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666) + } + test("SPARK-18207: Compute hash for a lot of expressions") { val N = 1000 val wideRow = new GenericInternalRow( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org