Repository: spark
Updated Branches:
  refs/heads/master c6f01cade -> a36b78b0e


[SPARK-22450][CORE][MLLIB][FOLLOWUP] safely register classes for mllib - LabeledPoint/VectorWithNorm/TreePoint

## What changes were proposed in this pull request?
Register the following classes with Kryo:
`org.apache.spark.mllib.regression.LabeledPoint`
`org.apache.spark.mllib.clustering.VectorWithNorm`
`org.apache.spark.ml.feature.LabeledPoint`
`org.apache.spark.ml.tree.impl.TreePoint`

`org.apache.spark.ml.tree.impl.BaggedPoint` also seems to need to be registered, but I don't know how to register it in this safe (load-by-name, ignore-if-missing) way.
WeichenXu123 cloud-fan

## How was this patch tested?
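Added unit tests that round-trip instances of each newly registered class through `KryoSerializer` with `spark.kryo.registrationRequired` set to `true`, and renamed the misnamed `InstanceSuit` test file to `InstanceSuite`.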

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #19950 from zhengruifeng/labeled_kryo.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a36b78b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a36b78b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a36b78b0

Branch: refs/heads/master
Commit: a36b78b0e420b909bde0cec4349cdc2103853b91
Parents: c6f01ca
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Thu Dec 21 20:20:04 2017 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Dec 21 20:20:04 2017 -0600

----------------------------------------------------------------------
 .../spark/serializer/KryoSerializer.scala       | 27 ++++++-----
 .../apache/spark/ml/feature/InstanceSuit.scala  | 47 --------------------
 .../apache/spark/ml/feature/InstanceSuite.scala | 45 +++++++++++++++++++
 .../spark/ml/feature/LabeledPointSuite.scala    | 39 ++++++++++++++++
 .../spark/ml/tree/impl/TreePointSuite.scala     | 35 +++++++++++++++
 .../spark/mllib/clustering/KMeansSuite.scala    | 18 +++++++-
 .../mllib/regression/LabeledPointSuite.scala    | 18 +++++++-
 7 files changed, 169 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 2259d1a..538ae05 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -181,20 +181,25 @@ class KryoSerializer(conf: SparkConf)
 
     // We can't load those class directly in order to avoid unnecessary jar dependencies.
     // We load them safely, ignore it if the class not found.
-    Seq("org.apache.spark.mllib.linalg.Vector",
-      "org.apache.spark.mllib.linalg.DenseVector",
-      "org.apache.spark.mllib.linalg.SparseVector",
-      "org.apache.spark.mllib.linalg.Matrix",
-      "org.apache.spark.mllib.linalg.DenseMatrix",
-      "org.apache.spark.mllib.linalg.SparseMatrix",
-      "org.apache.spark.ml.linalg.Vector",
+    Seq(
+      "org.apache.spark.ml.feature.Instance",
+      "org.apache.spark.ml.feature.LabeledPoint",
+      "org.apache.spark.ml.feature.OffsetInstance",
+      "org.apache.spark.ml.linalg.DenseMatrix",
       "org.apache.spark.ml.linalg.DenseVector",
-      "org.apache.spark.ml.linalg.SparseVector",
       "org.apache.spark.ml.linalg.Matrix",
-      "org.apache.spark.ml.linalg.DenseMatrix",
       "org.apache.spark.ml.linalg.SparseMatrix",
-      "org.apache.spark.ml.feature.Instance",
-      "org.apache.spark.ml.feature.OffsetInstance"
+      "org.apache.spark.ml.linalg.SparseVector",
+      "org.apache.spark.ml.linalg.Vector",
+      "org.apache.spark.ml.tree.impl.TreePoint",
+      "org.apache.spark.mllib.clustering.VectorWithNorm",
+      "org.apache.spark.mllib.linalg.DenseMatrix",
+      "org.apache.spark.mllib.linalg.DenseVector",
+      "org.apache.spark.mllib.linalg.Matrix",
+      "org.apache.spark.mllib.linalg.SparseMatrix",
+      "org.apache.spark.mllib.linalg.SparseVector",
+      "org.apache.spark.mllib.linalg.Vector",
+      "org.apache.spark.mllib.regression.LabeledPoint"
     ).foreach { name =>
       try {
         val clazz = Utils.classForName(name)

http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala
deleted file mode 100644
index 88c85a9..0000000
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuit.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.feature
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.serializer.KryoSerializer
-
-class InstanceSuit extends SparkFunSuite{
-  test("Kryo class register") {
-    val conf = new SparkConf(false)
-    conf.set("spark.kryo.registrationRequired", "true")
-
-    val ser = new KryoSerializer(conf)
-    val serInstance = new KryoSerializer(conf).newInstance()
-
-    def check[T: ClassTag](t: T) {
-      assert(serInstance.deserialize[T](serInstance.serialize(t)) === t)
-    }
-
-    val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0))
-    val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse)
-    val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0))
-    val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse)
-    check(instance1)
-    check(instance2)
-    check(oInstance1)
-    check(oInstance2)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala
new file mode 100644
index 0000000..cca7399
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.serializer.KryoSerializer
+
+class InstanceSuite extends SparkFunSuite{
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val instance1 = Instance(19.0, 2.0, Vectors.dense(1.0, 7.0))
+    val instance2 = Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse)
+    Seq(instance1, instance2).foreach { i =>
+      val i2 = ser.deserialize[Instance](ser.serialize(i))
+      assert(i === i2)
+    }
+
+    val oInstance1 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0))
+    val oInstance2 = OffsetInstance(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0).toSparse)
+    Seq(oInstance1, oInstance2).foreach { o =>
+      val o2 = ser.deserialize[OffsetInstance](ser.serialize(o))
+      assert(o === o2)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala
new file mode 100644
index 0000000..05c7a58
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LabeledPointSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.serializer.KryoSerializer
+
+class LabeledPointSuite extends SparkFunSuite {
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0)))
+    val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0)))
+
+    Seq(labeled1, labeled2).foreach { l =>
+      val l2 = ser.deserialize[LabeledPoint](ser.serialize(l))
+      assert(l === l2)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala
new file mode 100644
index 0000000..f41abe4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.KryoSerializer
+
+class TreePointSuite extends SparkFunSuite {
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val point = new TreePoint(1.0, Array(1, 2, 3))
+    val point2 = ser.deserialize[TreePoint](ser.serialize(point))
+    assert(point.label === point2.label)
+    assert(point.binnedFeatures === point2.binnedFeatures)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 48bd41d..00d7e2f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.mllib.clustering
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.Utils
 
 class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -311,6 +312,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1))
   }
 
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val vec1 = new VectorWithNorm(Vectors.dense(Array(1.0, 2.0)))
+    val vec2 = new VectorWithNorm(Vectors.sparse(10, Array(5, 8), Array(1.0, 2.0)))
+
+    Seq(vec1, vec2).foreach { v =>
+      val v2 = ser.deserialize[VectorWithNorm](ser.serialize(v))
+      assert(v2.norm === v.norm)
+      assert(v2.vector === v.vector)
+    }
+  }
 }
 
 object KMeansSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/a36b78b0/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
index 252a068..c1449ec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.serializer.KryoSerializer
 
 class LabeledPointSuite extends SparkFunSuite {
 
@@ -53,4 +54,19 @@ class LabeledPointSuite extends SparkFunSuite {
       assert(p1 === LabeledPoint.fromML(p2))
     }
   }
+
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val labeled1 = LabeledPoint(1.0, Vectors.dense(Array(1.0, 2.0)))
+    val labeled2 = LabeledPoint(1.0, Vectors.sparse(10, Array(5, 7), Array(1.0, 2.0)))
+
+    Seq(labeled1, labeled2).foreach { l =>
+      val l2 = ser.deserialize[LabeledPoint](ser.serialize(l))
+      assert(l === l2)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to