Repository: spark Updated Branches: refs/heads/master 41ee7c57a -> 6eba65525
[SPARK-12404][SQL] Ensure objects passed to StaticInvoke are Serializable Currently `StaticInvoke` receives its target as an object of type `Any`, and `StaticInvoke` itself can be serialized, but sometimes the object passed in is not serializable. For example, the following code raises an exception because `RowEncoder#extractorsFor`, invoked indirectly, creates a `StaticInvoke`. ``` case class TimestampContainer(timestamp: java.sql.Timestamp) val rdd = sc.parallelize(1 to 2).map(_ => TimestampContainer(System.currentTimeMillis)) val df = rdd.toDF val ds = df.as[TimestampContainer] val rdd2 = ds.rdd <----------------- invokes extractorsFor indirectly ``` I'll add test cases. Author: Kousuke Saruta <saru...@oss.nttdata.co.jp> Author: Michael Armbrust <mich...@databricks.com> Closes #10357 from sarutak/SPARK-12404. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6eba6552 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6eba6552 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6eba6552 Branch: refs/heads/master Commit: 6eba655259d2bcea27d0147b37d5d1e476e85422 Parents: 41ee7c5 Author: Kousuke Saruta <saru...@oss.nttdata.co.jp> Authored: Fri Dec 18 14:05:06 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Fri Dec 18 14:05:06 2015 -0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/JavaTypeInference.scala | 12 ++--- .../spark/sql/catalyst/ScalaReflection.scala | 16 +++--- .../sql/catalyst/encoders/RowEncoder.scala | 14 +++--- .../sql/catalyst/expressions/objects.scala | 8 ++- .../org/apache/spark/sql/JavaDatasetSuite.java | 52 ++++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++++ 6 files changed, 88 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index c8ee87e..f566d1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -194,7 +194,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaDate", getPath :: Nil, @@ -202,7 +202,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", getPath :: Nil, @@ -276,7 +276,7 @@ object JavaTypeInference { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", keyData :: valueData :: Nil) @@ -341,21 +341,21 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ecff860..c1b1d5c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -223,7 +223,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", getPath :: Nil, @@ -231,7 +231,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", getPath :: Nil, @@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) @@ -548,28 +548,28 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d34ec94..63bdf05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -61,21 +61,21 @@ object RowEncoder { case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case _: DecimalType => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) @@ -172,14 +172,14 @@ object RowEncoder { case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", input :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", input :: Nil) @@ -197,7 +197,7 @@ object RowEncoder { "array", ObjectType(classOf[Array[_]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -210,7 +210,7 @@ object RowEncoder { val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 10ec75e..492cc9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -42,16 +42,14 @@ import org.apache.spark.sql.types._ * of calling the function. */ case class StaticInvoke( - staticObject: Any, + staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, propagateNull: Boolean = true) extends Expression { - val objectName = staticObject match { - case c: Class[_] => c.getName - case other => other.getClass.getName.stripSuffix("$") - } + val objectName = staticObject.getName.stripSuffix("$") + override def nullable: Boolean = true override def children: Seq[Expression] = arguments http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 383a2d0..0dbaeb8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -39,6 +39,7 @@ import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; @@ -608,6 +609,44 @@ public class JavaDatasetSuite implements Serializable { } } + public class SimpleJavaBean2 implements 
Serializable { + private Timestamp a; + private Date b; + private java.math.BigDecimal c; + + public Timestamp getA() { return a; } + + public void setA(Timestamp a) { this.a = a; } + + public Date getB() { return b; } + + public void setB(Date b) { this.b = b; } + + public java.math.BigDecimal getC() { return c; } + + public void setC(java.math.BigDecimal c) { this.c = c; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (!a.equals(that.a)) return false; + if (!b.equals(that.b)) return false; + return c.equals(that.c); + } + + @Override + public int hashCode() { + int result = a.hashCode(); + result = 31 * result + b.hashCode(); + result = 31 * result + c.hashCode(); + return result; + } + } + public class NestedJavaBean implements Serializable { private SimpleJavaBean a; @@ -689,4 +728,17 @@ public class JavaDatasetSuite implements Serializable { .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } + + @Test + public void testJavaBeanEncoder2() { + // This is a regression test of SPARK-12404 + OuterScopes.addOuterScope(this); + SimpleJavaBean2 obj = new SimpleJavaBean2(); + obj.setA(new Timestamp(0)); + obj.setB(new Date(0)); + obj.setC(java.math.BigDecimal.valueOf(1)); + Dataset<SimpleJavaBean2> ds = + context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + ds.collect(); + } } http://git-wip-us.apache.org/repos/asf/spark/blob/6eba6552/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f1b6b98..de012a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.sql.{Date, Timestamp} import scala.language.postfixOps @@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } + + test("SPARK-12404: Datatype Helper Serializablity") { + val ds = sparkContext.parallelize(( + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() + + ds.collect() + } + test("collect, first, and take should use encoders for serialization") { val item = NonSerializableCaseClass("abcd") val ds = Seq(item).toDS() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org