Repository: spark Updated Branches: refs/heads/master 678e1aca6 -> bdf32847b
[SPARK-26021][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/23043 There are 4 places we need to deal with NaN and -0.0: 1. comparison expressions. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. 2. Join keys. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. 3. grouping keys. `-0.0` and `0.0` should be assigned to the same group. Different NaNs should be assigned to the same group. 4. window partition keys. `-0.0` and `0.0` should be treated as same. Different NaNs should be treated as same. The case 1 is OK. Our comparison already handles NaN and -0.0, and for struct/array/map, we will recursively compare the fields/elements. Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. To fix it, a simple solution is: normalize float/double when building unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't need to worry about it anymore. Following this direction, this PR moves the handling of NaN and -0.0 from `Platform` to `UnsafeWriter`, so that places like `UnsafeRow.setFloat` will not handle them, which reduces the perf overhead. It's also easier to add comments explaining why we do it in `UnsafeWriter`. ## How was this patch tested? existing tests Closes #23239 from cloud-fan/minor. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Dongjoon Hyun <dongj...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bdf32847 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bdf32847 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bdf32847 Branch: refs/heads/master Commit: bdf32847b1ffcb3aa4d0bef058f86e65656e99fb Parents: 678e1ac Author: Wenchen Fan <wenc...@databricks.com> Authored: Sat Dec 8 11:18:09 2018 -0800 Committer: Dongjoon Hyun <dongj...@apache.org> Committed: Sat Dec 8 11:18:09 2018 -0800 ---------------------------------------------------------------------- .../java/org/apache/spark/unsafe/Platform.java | 10 ------ .../apache/spark/unsafe/PlatformUtilSuite.java | 18 ---------- .../expressions/codegen/UnsafeWriter.java | 35 ++++++++++++++++++++ .../codegen/UnsafeRowWriterSuite.scala | 20 +++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 12 +++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 14 ++++++++ 6 files changed, 81 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java ---------------------------------------------------------------------- diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 4563efc..076b693 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -174,11 +174,6 @@ public final class Platform { } public static void putFloat(Object object, long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } else if (value == -0.0f) { - value = 0.0f; - } _UNSAFE.putFloat(object, offset, value); } @@ -187,11 +182,6 @@ public final class Platform { } public static void putDouble(Object object, long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } else if (value == -0.0d) { - value = 0.0d; - } _UNSAFE.putDouble(object, offset, value); } http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java ---------------------------------------------------------------------- diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 2474081..3ad9ac7 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,22 +157,4 @@ public class PlatformUtilSuite { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } - - @Test - // SPARK-26021 - public void writeMinusZeroIsReplacedWithZero() { - byte[] doubleBytes = new byte[Double.BYTES]; - byte[] floatBytes = new byte[Float.BYTES]; - Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); - Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); - - byte[] doubleBytes2 = new byte[Double.BYTES]; - byte[] floatBytes2 = new byte[Float.BYTES]; - Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d); - Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f); - - // Make sure the bytes we write from 0.0 and -0.0 are same. - Assert.assertArrayEquals(doubleBytes, doubleBytes2); - Assert.assertArrayEquals(floatBytes, floatBytes2); - } } http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 95263a0..7553ab8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -198,11 +198,46 @@ public abstract class UnsafeWriter { Platform.putLong(getBuffer(), offset, value); } + // We need to take care of NaN and -0.0 in several places: + // 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be + // treated as same. + // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong + // to the same group. + // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be + // treated as same. + // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` + // should be treated as same. + // + // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we + // recursively compare the fields/elements, so it's also fine. + // + // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different + // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. + // + // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing + // float/double columns and nested fields to `UnsafeRow`. + // + // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract + // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex + // types, so nested float/double may not be normalized. We need to make sure that all the unsafe + // data(`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have flat/double normalized during + // creation. protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } Platform.putFloat(getBuffer(), offset, value); } + // See comments for `writeFloat`. protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } Platform.putDouble(getBuffer(), offset, value); } } http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index fb651b7..22e1fa6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite { assert(res1 == res2) } + test("SPARK-26021: normalize float/double NaN and -0.0") { + val unsafeRowWriter1 = new UnsafeRowWriter(4) + unsafeRowWriter1.resetRowWriter() + unsafeRowWriter1.write(0, Float.NaN) + unsafeRowWriter1.write(1, Double.NaN) + unsafeRowWriter1.write(2, 0.0f) + unsafeRowWriter1.write(3, 0.0) + val res1 = unsafeRowWriter1.getRow + + val unsafeRowWriter2 = new UnsafeRowWriter(4) + unsafeRowWriter2.resetRowWriter() + unsafeRowWriter2.write(0, 0.0f/0.0f) + unsafeRowWriter2.write(1, 0.0/0.0) + unsafeRowWriter2.write(2, -0.0f) + unsafeRowWriter2.write(3, -0.0) + val res2 = unsafeRowWriter2.getRow + + // The two rows should be the equal + assert(res1 == res2) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e6b30f9..c9f41ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan } } + + test("NaN and -0.0 in join keys") { + val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") + val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") + val joined = df1.join(df2, Seq("f", "d")) + checkAnswer(joined, Seq( + Row(Float.NaN, Double.NaN), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(0.0f, 0.0))) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 78277d7..9a5d5a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -681,4 +681,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("S2", "P2", 300, 300, 500))) } + + test("NaN and -0.0 in window partition keys") { + val df = Seq( + (Float.NaN, Double.NaN, 1), + (0.0f/0.0f, 0.0/0.0, 1), + (0.0f, 0.0, 1), + (-0.0f, -0.0, 1)).toDF("f", "d", "i") + val result = df.select($"f", count("i").over(Window.partitionBy("f", "d"))) + checkAnswer(result, Seq( + Row(Float.NaN, 2), + Row(Float.NaN, 2), + Row(0.0f, 2), + Row(0.0f, 2))) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org