Repository: spark
Updated Branches:
  refs/heads/branch-2.4 8705a9dd7 -> d63ab5a4f


[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float

GROUP BY treats -0.0 and 0.0 as different values, which is unlike Hive's 
behavior.
In addition current behavior with codegen is unpredictable (see example in JIRA 
ticket).

## What changes were proposed in this pull request?

In Platform.putDouble/Float(), check whether the value is -0.0 and, if so, 
replace it with 0.0.
This is used by UnsafeRow so it won't have -0.0 values.

## How was this patch tested?

Added tests

Closes #23043 from adoron/adoron-spark-26021-replace-minus-zero-with-zero.

Authored-by: Alon Doron <ado...@palantir.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
(cherry picked from commit 0ec7b99ea2b638453ed38bb092905bee4f907fe5)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d63ab5a4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d63ab5a4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d63ab5a4

Branch: refs/heads/branch-2.4
Commit: d63ab5a4f5aeecfa227edc84aa38e866446f5238
Parents: 8705a9d
Author: Alon Doron <ado...@palantir.com>
Authored: Fri Nov 23 08:55:00 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Nov 23 08:55:31 2018 +0800

----------------------------------------------------------------------
 .../main/java/org/apache/spark/unsafe/Platform.java   | 10 ++++++++++
 .../org/apache/spark/unsafe/PlatformUtilSuite.java    | 14 ++++++++++++++
 .../spark/sql/catalyst/expressions/UnsafeRow.java     |  6 ------
 .../catalyst/expressions/codegen/UnsafeWriter.java    |  6 ------
 .../apache/spark/sql/DataFrameAggregateSuite.scala    | 14 ++++++++++++++
 .../test/scala/org/apache/spark/sql/QueryTest.scala   |  5 ++++-
 6 files changed, 42 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index aca6fca..bc94f21 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -120,6 +120,11 @@ public final class Platform {
   }
 
   public static void putFloat(Object object, long offset, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    } else if (value == -0.0f) {
+      value = 0.0f;
+    }
     _UNSAFE.putFloat(object, offset, value);
   }
 
@@ -128,6 +133,11 @@ public final class Platform {
   }
 
   public static void putDouble(Object object, long offset, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    } else if (value == -0.0d) {
+      value = 0.0d;
+    }
     _UNSAFE.putDouble(object, offset, value);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 3ad9ac7..ab34324 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,4 +157,18 @@ public class PlatformUtilSuite {
     Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
     Assert.assertEquals(obj3, onheap4.getBaseObject());
   }
+
+  @Test
+  // SPARK-26021
+  public void writeMinusZeroIsReplacedWithZero() {
+    byte[] doubleBytes = new byte[Double.BYTES];
+    byte[] floatBytes = new byte[Float.BYTES];
+    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
+    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
+    double doubleFromPlatform = Platform.getDouble(doubleBytes, 
Platform.BYTE_ARRAY_OFFSET);
+    float floatFromPlatform = Platform.getFloat(floatBytes, 
Platform.BYTE_ARRAY_OFFSET);
+
+    Assert.assertEquals(Double.doubleToLongBits(0.0d), 
Double.doubleToLongBits(doubleFromPlatform));
+    Assert.assertEquals(Float.floatToIntBits(0.0f), 
Float.floatToIntBits(floatFromPlatform));
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index a76e6ef..9bf9452 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -224,9 +224,6 @@ public final class UnsafeRow extends InternalRow implements 
Externalizable, Kryo
   public void setDouble(int ordinal, double value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    }
     Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
   }
 
@@ -255,9 +252,6 @@ public final class UnsafeRow extends InternalRow implements 
Externalizable, Kryo
   public void setFloat(int ordinal, float value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    }
     Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 2781655..95263a0 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -199,16 +199,10 @@ public abstract class UnsafeWriter {
   }
 
   protected final void writeFloat(long offset, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    }
     Platform.putFloat(getBuffer(), offset, value);
   }
 
   protected final void writeDouble(long offset, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    }
     Platform.putDouble(getBuffer(), offset, value);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c4..41dc72d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -727,4 +727,18 @@ class DataFrameAggregateSuite extends QueryTest with 
SharedSQLContext {
       "grouping expressions: [current_date(None)], value: [key: int, value: 
string], " +
         "type: GroupBy]"))
   }
+
+  test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") 
{
+    val colName = "i"
+    val doubles = Seq(0.0d, -0.0d, 
0.0d).toDF(colName).groupBy(colName).count().collect()
+    val floats = Seq(0.0f, -0.0f, 
0.0f).toDF(colName).groupBy(colName).count().collect()
+
+    assert(doubles.length == 1)
+    assert(floats.length == 1)
+    // using compare since 0.0 == -0.0 is true
+    assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0)
+    assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0)
+    assert(doubles(0).getLong(1) == 3)
+    assert(floats(0).getLong(1) == 3)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d63ab5a4/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index baca9c1..8ba6723 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -289,7 +289,7 @@ object QueryTest {
   def prepareRow(row: Row): Row = {
     Row.fromSeq(row.toSeq.map {
       case null => null
-      case d: java.math.BigDecimal => BigDecimal(d)
+      case bd: java.math.BigDecimal => BigDecimal(bd)
       // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 
2.12.2+
       case seq: Seq[_] => seq.map {
         case b: java.lang.Byte => b.byteValue
@@ -303,6 +303,9 @@ object QueryTest {
       // Convert array to Seq for easy equality check.
       case b: Array[_] => b.toSeq
       case r: Row => prepareRow(r)
+      // spark treats -0.0 as 0.0
+      case d: Double if d == -0.0d => 0.0d
+      case f: Float if f == -0.0f => 0.0f
       case o => o
     })
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to