Repository: spark
Updated Branches:
  refs/heads/master aec502d91 -> 127a6678d
[SPARK-15489][SQL] Dataset kryo encoder won't load custom user settings

## What changes were proposed in this pull request?

Serializer instantiation now considers the existing SparkConf instead of always creating a new one, so user settings such as a custom Kryo registrator are honored.

## How was this patch tested?

Manually tested with Guava's `ImmutableList` and the `Immutable*Serializer` implementations from the `kryo-serializers` library. A test suite was also added.

Author: Sela <ans...@paypal.com>

Closes #13424 from amitsela/SPARK-15489.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/127a6678
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/127a6678
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/127a6678

Branch: refs/heads/master
Commit: 127a6678d7af6b5164a115be7c64525bb80001fe
Parents: aec502d
Author: Sela <ans...@paypal.com>
Authored: Fri Jun 10 14:36:51 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Jun 10 14:36:51 2016 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/objects/objects.scala  | 30 ++++++---
 .../sql/DatasetSerializerRegistratorSuite.scala | 68 ++++++++++++++++++++
 2 files changed, 89 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/127a6678/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 87c8a2e..c597a2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -22,7 +22,7 @@ import java.lang.reflect.Modifier
 import scala.language.existentials
 import scala.reflect.ClassTag
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
@@ -547,11 +547,17 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
         (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
       }
     }
+    // try conf from env, otherwise create a new one
+    val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    ctx.addMutableState(
-      serializerInstanceClass,
-      serializer,
-      s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
+    val serializerInit = s"""
+      if ($env == null) {
+        $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+      } else {
+        $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+      }
+     """
+    ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
 
     // Code to serialize.
     val input = child.genCode(ctx)
@@ -587,11 +593,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
         (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
       }
     }
+    // try conf from env, otherwise create a new one
+    val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    ctx.addMutableState(
-      serializerInstanceClass,
-      serializer,
-      s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();")
+    val serializerInit = s"""
+      if ($env == null) {
+        $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
+      } else {
+        $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
+      }
+     """
+    ctx.addMutableState(serializerInstanceClass, serializer, serializerInit)
 
     // Code to deserialize.
     val input = child.genCode(ctx)

http://git-wip-us.apache.org/repos/asf/spark/blob/127a6678/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
new file mode 100644
index 0000000..0f3d0ce
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import com.esotericsoftware.kryo.{Kryo, Serializer}
+import com.esotericsoftware.kryo.io.{Input, Output}
+
+import org.apache.spark.serializer.KryoRegistrator
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.TestSparkSession
+
+/**
+ * Test suite to test Kryo custom registrators.
+ */
+class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  /**
+   * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]].
+   */
+  protected override def beforeAll(): Unit = {
+    sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName)
+    super.beforeAll()
+  }
+
+  test("Kryo registrator") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = Seq(KryoData(1), KryoData(2)).toDS()
+    assert(ds.collect().toSet == Set(KryoData(0), KryoData(0)))
+  }
+
+}
+
+/** Used to test user provided registrator. */
+class TestRegistrator extends KryoRegistrator {
+  override def registerClasses(kryo: Kryo): Unit =
+    kryo.register(classOf[KryoData], new ZeroKryoDataSerializer())
+}
+
+object TestRegistrator {
+  def apply(): TestRegistrator = new TestRegistrator()
+}
+
+/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). */
+class ZeroKryoDataSerializer extends Serializer[KryoData] {
+  override def write(kryo: Kryo, output: Output, t: KryoData): Unit = {
+    output.writeInt(0)
+  }
+
+  override def read(kryo: Kryo, input: Input, aClass: Class[KryoData]): KryoData = {
+    KryoData(input.readInt())
+  }
+}
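----------------------------------------------------------------------

End-to-end usage sketch (not part of the commit): the snippet below mirrors the manual test described in the commit message and the registrator pattern from the new test suite, showing how a user-supplied Kryo registrator now reaches a kryo-encoded Dataset. It assumes Guava and the de.javakaffee:kryo-serializers library are on the classpath, and that ImmutableListSerializer.registerSerializers is that library's registration entry point; the class and app names here are illustrative, not part of the patch.

import com.esotericsoftware.kryo.Kryo
import com.google.common.collect.ImmutableList
import de.javakaffee.kryoserializers.guava.ImmutableListSerializer

import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

// Registers kryo-serializers' handler for Guava's ImmutableList, which
// plain Kryo cannot instantiate (no zero-arg constructor).
class ImmutableListRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit =
    ImmutableListSerializer.registerSerializers(kryo)
}

object KryoRegistratorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-15489-example")
      // Before this patch, the generated Encode/DecodeUsingSerializer code
      // built its serializer from `new SparkConf()`, silently dropping this
      // setting; it is now read from SparkEnv.get().conf() when available.
      .config("spark.kryo.registrator", classOf[ImmutableListRegistrator].getName)
      .getOrCreate()
    import spark.implicits._

    implicit val listEncoder: Encoder[ImmutableList[String]] =
      Encoders.kryo[ImmutableList[String]]
    val ds = Seq(ImmutableList.of("a", "b"), ImmutableList.of("c")).toDS()
    ds.collect().foreach(println)  // round-trips through the custom serializer

    spark.stop()
  }
}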