Repository: spark Updated Branches: refs/heads/branch-2.4 8705a9dd7 -> d63ab5a4f
[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float GROUP BY treats -0.0 and 0.0 as different values, which is unlike Hive's behavior. In addition, the current behavior with codegen is unpredictable (see the example in the JIRA ticket). ## What changes were proposed in this pull request? Platform.putDouble/Float() now checks whether the value is -0.0 and, if so, replaces it with 0.0. These methods are used by UnsafeRow, so an UnsafeRow will no longer contain -0.0 values. ## How was this patch tested? Added tests. Closes #23043 from adoron/adoron-spark-26021-replace-minus-zero-with-zero. Authored-by: Alon Doron <ado...@palantir.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 0ec7b99ea2b638453ed38bb092905bee4f907fe5) Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d63ab5a4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d63ab5a4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d63ab5a4 Branch: refs/heads/branch-2.4 Commit: d63ab5a4f5aeecfa227edc84aa38e866446f5238 Parents: 8705a9d Author: Alon Doron <ado...@palantir.com> Authored: Fri Nov 23 08:55:00 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Nov 23 08:55:31 2018 +0800 ---------------------------------------------------------------------- .../main/java/org/apache/spark/unsafe/Platform.java | 10 ++++++++++ .../org/apache/spark/unsafe/PlatformUtilSuite.java | 14 ++++++++++++++ .../spark/sql/catalyst/expressions/UnsafeRow.java | 6 ------ .../catalyst/expressions/codegen/UnsafeWriter.java | 6 ------ .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 ++++++++++++++ .../test/scala/org/apache/spark/sql/QueryTest.scala | 5 ++++- 6 files changed, 42 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java ---------------------------------------------------------------------- diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca..bc94f21 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -120,6 +120,11 @@ public final class Platform { } public static void putFloat(Object object, long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } _UNSAFE.putFloat(object, offset, value); } @@ -128,6 +133,11 @@ public final class Platform { } public static void putDouble(Object object, long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } _UNSAFE.putDouble(object, offset, value); } http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java ---------------------------------------------------------------------- diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7..ab34324 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,4 +157,18 @@ public class PlatformUtilSuite { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + // SPARK-26021 + public void writeMinusZeroIsReplacedWithZero() { + byte[] doubleBytes = new byte[Double.BYTES]; + byte[] floatBytes = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, 
Platform.BYTE_ARRAY_OFFSET, -0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); + double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); + float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); + + Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); + Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); + } } http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef..9bf9452 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -224,9 +224,6 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -255,9 +252,6 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 2781655..95263a0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -199,16 +199,10 @@ public abstract class UnsafeWriter { } protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(getBuffer(), offset, value); } } http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c4..41dc72d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -727,4 +727,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { + val colName = "i" + val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() + val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() + + assert(doubles.length == 1) + assert(floats.length == 1) + // using compare since 0.0 == -0.0 is true + 
assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) + assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) + assert(doubles(0).getLong(1) == 3) + assert(floats(0).getLong(1) == 3) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index baca9c1..8ba6723 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -289,7 +289,7 @@ object QueryTest { def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null - case d: java.math.BigDecimal => BigDecimal(d) + case bd: java.math.BigDecimal => BigDecimal(bd) // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ case seq: Seq[_] => seq.map { case b: java.lang.Byte => b.byteValue @@ -303,6 +303,9 @@ object QueryTest { // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) + // spark treats -0.0 as 0.0 + case d: Double if d == -0.0d => 0.0d + case f: Float if f == -0.0f => 0.0f case o => o }) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org