Repository: spark
Updated Branches:
  refs/heads/master d27daa54b -> bb823ca4b


[SPARK-19959][SQL] Fix to throw NullPointerException in df[java.lang.Long].collect

## What changes were proposed in this pull request?

This PR fixes a `NullPointerException` in code generated by Catalyst. Running the following snippet throws the `NullPointerException` shown below:

```scala
sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF.collect
```

```java
Caused by: java.lang.NullPointerException
        at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(generated.java:37)
        at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
        at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:393)
...
```

This happens when the element type of a DataFrame is a nullable boxed primitive such as `java.lang.Long` and whole-stage codegen is used. Although the physical plan keeps `nullable = true` in `input[0, java.lang.Long, true].longValue`, `BoundReference.doGenCode` ignores it, so no null check is generated for `inputadapter_value` even though `java.lang.Long inputadapter_value` at line 30 of the generated code below may be `null`. This PR checks the nullability and generates the null check when it is needed.
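The fix (see the `object.scala` diff below) derives nullability from the encoder's runtime class. A minimal sketch of the distinction it relies on, not part of the patch itself: a boxed `java.lang.Long` is a reference type whose values may be `null`, while Scala's `Long` erases to the primitive `long`:

```scala
import scala.reflect.classTag

// Boxed java.lang.Long is a reference type, so a value may be null.
assert(!classTag[java.lang.Long].runtimeClass.isPrimitive) // => nullable = true

// Scala's Long erases to the primitive long, which can never be null.
assert(classTag[Long].runtimeClass.isPrimitive)            // => nullable = false
```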

Generated code without this PR (note that `inputadapter_value` is dereferenced at line 37 without a null check, since the guard at line 34 is the constant `if (!false)`):
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */
/* 013 */   public GeneratedIterator(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     inputadapter_input = inputs[0];
/* 021 */     serializefromobject_result = new UnsafeRow(1);
/* 022 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 023 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 024 */
/* 025 */   }
/* 026 */
/* 027 */   protected void processNext() throws java.io.IOException {
/* 028 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 029 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 030 */       java.lang.Long inputadapter_value = (java.lang.Long)inputadapter_row.get(0, null);
/* 031 */
/* 032 */       boolean serializefromobject_isNull = true;
/* 033 */       long serializefromobject_value = -1L;
/* 034 */       if (!false) {
/* 035 */         serializefromobject_isNull = false;
/* 036 */         if (!serializefromobject_isNull) {
/* 037 */           serializefromobject_value = inputadapter_value.longValue();
/* 038 */         }
/* 039 */
/* 040 */       }
/* 041 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 042 */
/* 043 */       if (serializefromobject_isNull) {
/* 044 */         serializefromobject_rowWriter.setNullAt(0);
/* 045 */       } else {
/* 046 */         serializefromobject_rowWriter.write(0, serializefromobject_value);
/* 047 */       }
/* 048 */       append(serializefromobject_result);
/* 049 */       if (shouldStop()) return;
/* 050 */     }
/* 051 */   }
/* 052 */ }
```

Generated code with this PR (a null check is now generated at lines 30 and 31 and consulted at line 35):

```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */
/* 013 */   public GeneratedIterator(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     inputadapter_input = inputs[0];
/* 021 */     serializefromobject_result = new UnsafeRow(1);
/* 022 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 023 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 024 */
/* 025 */   }
/* 026 */
/* 027 */   protected void processNext() throws java.io.IOException {
/* 028 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 029 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 030 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 031 */       java.lang.Long inputadapter_value = inputadapter_isNull ? null : ((java.lang.Long)inputadapter_row.get(0, null));
/* 032 */
/* 033 */       boolean serializefromobject_isNull = true;
/* 034 */       long serializefromobject_value = -1L;
/* 035 */       if (!inputadapter_isNull) {
/* 036 */         serializefromobject_isNull = false;
/* 037 */         if (!serializefromobject_isNull) {
/* 038 */           serializefromobject_value = inputadapter_value.longValue();
/* 039 */         }
/* 040 */
/* 041 */       }
/* 042 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 043 */
/* 044 */       if (serializefromobject_isNull) {
/* 045 */         serializefromobject_rowWriter.setNullAt(0);
/* 046 */       } else {
/* 047 */         serializefromobject_rowWriter.write(0, serializefromobject_value);
/* 048 */       }
/* 049 */       append(serializefromobject_result);
/* 050 */       if (shouldStop()) return;
/* 051 */     }
/* 052 */   }
/* 053 */ }
```

## How was this patch tested?

Added a new test to `DataFrameImplicitsSuite` covering `java.lang.Integer`, `java.lang.Long`, `java.lang.Float`, and `java.lang.Double`.
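As a quick sanity check (a sketch of the expected behavior, e.g. in `spark-shell`), the original repro now returns the rows, including the `null`, instead of throwing:

```scala
// With the fix, the null element survives collect() instead of raising an NPE.
sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF.collect
// => Array([0], [null], [2])
```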

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #17302 from kiszk/SPARK-19959.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bb823ca4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bb823ca4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bb823ca4

Branch: refs/heads/master
Commit: bb823ca4b479a00030c4919c2d857d254b2a44d8
Parents: d27daa5
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Fri Mar 24 12:57:56 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Mar 24 12:57:56 2017 +0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/plans/logical/object.scala |  5 ++++-
 .../org/apache/spark/sql/DataFrameImplicitsSuite.scala   | 11 +++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bb823ca4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 6225b3f..bfb70c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -41,7 +41,10 @@ object CatalystSerde {
   }
 
   def generateObjAttr[T : Encoder]: Attribute = {
-    AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
+    val enc = encoderFor[T]
+    val dataType = enc.deserializer.dataType
+    val nullable = !enc.clsTag.runtimeClass.isPrimitive
+    AttributeReference("obj", dataType, nullable)()
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bb823ca4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index 094efba..63094d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -51,4 +51,15 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
       sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
       (1 to 10).map(i => Row(i.toString)))
   }
+
+  test("SPARK-19959: df[java.lang.Long].collect includes null throws 
NullPointerException") {
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 
1).toDF,
+      Seq(Row(0), Row(null), Row(2)))
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 
1).toDF,
+      Seq(Row(0L), Row(null), Row(2L)))
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 
2.0F), 1).toDF,
+      Seq(Row(0.0F), Row(null), Row(2.0F)))
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 
2.0D), 1).toDF,
+      Seq(Row(0.0D), Row(null), Row(2.0D)))
+  }
 }

