Repository: spark
Updated Branches:
  refs/heads/master 678e1aca6 -> bdf32847b


[SPARK-26021][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter

## What changes were proposed in this pull request?

A followup of https://github.com/apache/spark/pull/23043

There are 4 places where we need to deal with NaN and -0.0:
1. Comparison expressions. `-0.0` and `0.0` should be treated as the same.
Different NaNs should be treated as the same.
2. Join keys. `-0.0` and `0.0` should be treated as the same. Different NaNs
should be treated as the same.
3. Grouping keys. `-0.0` and `0.0` should be assigned to the same group.
Different NaNs should be assigned to the same group.
4. Window partition keys. `-0.0` and `0.0` should be treated as the same.
Different NaNs should be treated as the same.

Case 1 is fine. Our comparison already handles NaN and -0.0, and for
struct/array/map, we recursively compare the fields/elements.

Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly:
different NaNs have different binary representations, and the same is true
for -0.0 and 0.0.

To fix it, a simple solution is to normalize float/double values when building
unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't
need to worry about it anymore.

Following this direction, this PR moves the handling of NaN and -0.0 from
`Platform` to `UnsafeWriter`, so that code paths like `UnsafeRow.setFloat` no
longer normalize these values, which reduces the performance overhead. It's also
easier to add comments explaining why we do it in `UnsafeWriter`.

## How was this patch tested?

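Existing tests, plus new tests in `UnsafeRowWriterSuite`, `DataFrameJoinSuite`
and `DataFrameWindowFunctionsSuite`.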

Closes #23239 from cloud-fan/minor.

Authored-by: Wenchen Fan <wenc...@databricks.com>
Signed-off-by: Dongjoon Hyun <dongj...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bdf32847
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bdf32847
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bdf32847

Branch: refs/heads/master
Commit: bdf32847b1ffcb3aa4d0bef058f86e65656e99fb
Parents: 678e1ac
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Sat Dec 8 11:18:09 2018 -0800
Committer: Dongjoon Hyun <dongj...@apache.org>
Committed: Sat Dec 8 11:18:09 2018 -0800

----------------------------------------------------------------------
 .../java/org/apache/spark/unsafe/Platform.java  | 10 ------
 .../apache/spark/unsafe/PlatformUtilSuite.java  | 18 ----------
 .../expressions/codegen/UnsafeWriter.java       | 35 ++++++++++++++++++++
 .../codegen/UnsafeRowWriterSuite.scala          | 20 +++++++++++
 .../apache/spark/sql/DataFrameJoinSuite.scala   | 12 +++++++
 .../sql/DataFrameWindowFunctionsSuite.scala     | 14 ++++++++
 6 files changed, 81 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index 4563efc..076b693 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -174,11 +174,6 @@ public final class Platform {
   }
 
   public static void putFloat(Object object, long offset, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    } else if (value == -0.0f) {
-      value = 0.0f;
-    }
     _UNSAFE.putFloat(object, offset, value);
   }
 
@@ -187,11 +182,6 @@ public final class Platform {
   }
 
   public static void putDouble(Object object, long offset, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    } else if (value == -0.0d) {
-      value = 0.0d;
-    }
     _UNSAFE.putDouble(object, offset, value);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 2474081..3ad9ac7 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,22 +157,4 @@ public class PlatformUtilSuite {
     Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
     Assert.assertEquals(obj3, onheap4.getBaseObject());
   }
-
-  @Test
-  // SPARK-26021
-  public void writeMinusZeroIsReplacedWithZero() {
-    byte[] doubleBytes = new byte[Double.BYTES];
-    byte[] floatBytes = new byte[Float.BYTES];
-    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
-    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
-
-    byte[] doubleBytes2 = new byte[Double.BYTES];
-    byte[] floatBytes2 = new byte[Float.BYTES];
-    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d);
-    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f);
-
-    // Make sure the bytes we write from 0.0 and -0.0 are same.
-    Assert.assertArrayEquals(doubleBytes, doubleBytes2);
-    Assert.assertArrayEquals(floatBytes, floatBytes2);
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 95263a0..7553ab8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -198,11 +198,46 @@ public abstract class UnsafeWriter {
     Platform.putLong(getBuffer(), offset, value);
   }
 
+  // We need to take care of NaN and -0.0 in several places:
+  //   1. When comparing values, different NaNs should be treated as the same, and `-0.0` and
+  //      `0.0` should be treated as the same.
+  //   2. In GROUP BY, different NaNs should belong to the same group, and `-0.0` and `0.0` should
+  //      belong to the same group.
+  //   3. As join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should
+  //      be treated as the same.
+  //   4. As window partition keys, different NaNs should be treated as the same, and `-0.0` and
+  //      `0.0` should be treated as the same.
+  //
+  // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+  // recursively compare the fields/elements, so it's also fine.
+  //
+  // Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
+  // NaNs have different binary representations, and the same thing happens for -0.0 and 0.0.
+  //
+  // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
+  // float/double columns and nested fields to `UnsafeRow`.
+  //
+  // Note that we must do this for all the `UnsafeProjection`s, not only the ones that extract
+  // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex
+  // types, so nested float/double may not be normalized. We need to make sure that all the unsafe
+  // data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have float/double normalized during
+  // creation.
   protected final void writeFloat(long offset, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    } else if (value == -0.0f) {
+      value = 0.0f;
+    }
     Platform.putFloat(getBuffer(), offset, value);
   }
 
+  // See comments for `writeFloat`.
   protected final void writeDouble(long offset, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    } else if (value == -0.0d) {
+      value = 0.0d;
+    }
     Platform.putDouble(getBuffer(), offset, value);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index fb651b7..22e1fa6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
     assert(res1 == res2)
   }
 
+  test("SPARK-26021: normalize float/double NaN and -0.0") {
+    val unsafeRowWriter1 = new UnsafeRowWriter(4)
+    unsafeRowWriter1.resetRowWriter()
+    unsafeRowWriter1.write(0, Float.NaN)
+    unsafeRowWriter1.write(1, Double.NaN)
+    unsafeRowWriter1.write(2, 0.0f)
+    unsafeRowWriter1.write(3, 0.0)
+    val res1 = unsafeRowWriter1.getRow
+
+    val unsafeRowWriter2 = new UnsafeRowWriter(4)
+    unsafeRowWriter2.resetRowWriter()
+    unsafeRowWriter2.write(0, 0.0f/0.0f)
+    unsafeRowWriter2.write(1, 0.0/0.0)
+    unsafeRowWriter2.write(2, -0.0f)
+    unsafeRowWriter2.write(3, -0.0)
+    val res2 = unsafeRowWriter2.getRow
+
+    // The two rows should be equal
+    assert(res1 == res2)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e6b30f9..c9f41ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
     }
   }
+
+  test("NaN and -0.0 in join keys") {
+    val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+    val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+    val joined = df1.join(df2, Seq("f", "d"))
+    checkAnswer(joined, Seq(
+      Row(Float.NaN, Double.NaN),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 78277d7..9a5d5a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -681,4 +681,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
         Row("S2", "P2", 300, 300, 500)))
 
   }
+
+  test("NaN and -0.0 in window partition keys") {
+    val df = Seq(
+      (Float.NaN, Double.NaN, 1),
+      (0.0f/0.0f, 0.0/0.0, 1),
+      (0.0f, 0.0, 1),
+      (-0.0f, -0.0, 1)).toDF("f", "d", "i")
+    val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
+    checkAnswer(result, Seq(
+      Row(Float.NaN, 2),
+      Row(Float.NaN, 2),
+      Row(0.0f, 2),
+      Row(0.0f, 2)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to