Repository: spark
Updated Branches:
  refs/heads/master cc567b663 -> 21fac5434


[SPARK-11766][MLLIB] add toJson/fromJson to Vector/Vectors

This adds toJson/fromJson to support JSON serialization of Param[Vector] in the pipeline API; it could be used for other purposes as well. The schema is the same as `VectorUDT`.
jkbradley
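
Example usage, as a minimal round-trip sketch against the API added in this patch (method names are taken from the diff below; the snippet itself is illustrative and not part of the change):

  import org.apache.spark.mllib.linalg.{Vector, Vectors}

  val sv: Vector = Vectors.sparse(4, Array(1, 3), Array(2.0, 4.0))
  val json: String = sv.toJson                // VectorUDT-compatible JSON string
  val restored: Vector = Vectors.fromJson(json)
  assert(restored.getClass == sv.getClass)    // sparse stays sparse
  assert(restored == sv)                      // values are preserved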

Author: Xiangrui Meng <m...@databricks.com>

Closes #9751 from mengxr/SPARK-11766.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21fac543
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21fac543
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21fac543

Branch: refs/heads/master
Commit: 21fac5434174389e8b83a2f11341fa7c9e360bfd
Parents: cc567b6
Author: Xiangrui Meng <m...@databricks.com>
Authored: Tue Nov 17 10:17:16 2015 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Nov 17 10:17:16 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/mllib/linalg/Vectors.scala | 45 ++++++++++++++++++++
 .../spark/mllib/linalg/VectorsSuite.scala       | 17 ++++++++
 project/MimaExcludes.scala                      |  4 ++
 3 files changed, 66 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21fac543/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index bd9badc..4dcf351 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -24,6 +24,9 @@ import scala.annotation.varargs
 import scala.collection.JavaConverters._
 
 import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson}
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{AlphaComponent, Since}
@@ -171,6 +174,12 @@ sealed trait Vector extends Serializable {
    */
   @Since("1.5.0")
   def argmax: Int
+
+  /**
+   * Converts the vector to a JSON string.
+   */
+  @Since("1.6.0")
+  def toJson: String
 }
 
 /**
@@ -339,6 +348,27 @@ object Vectors {
     parseNumeric(NumericParser.parse(s))
   }
 
+  /**
+   * Parses the JSON representation of a vector into a [[Vector]].
+   */
+  @Since("1.6.0")
+  def fromJson(json: String): Vector = {
+    implicit val formats = DefaultFormats
+    val jValue = parseJson(json)
+    (jValue \ "type").extract[Int] match {
+      case 0 => // sparse
+        val size = (jValue \ "size").extract[Int]
+        val indices = (jValue \ "indices").extract[Seq[Int]].toArray
+        val values = (jValue \ "values").extract[Seq[Double]].toArray
+        sparse(size, indices, values)
+      case 1 => // dense
+        val values = (jValue \ "values").extract[Seq[Double]].toArray
+        dense(values)
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot parse $json into a vector.")
+    }
+  }
+
   private[mllib] def parseNumeric(any: Any): Vector = {
     any match {
       case values: Array[Double] =>
@@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") (
       maxIdx
     }
   }
+
+  @Since("1.6.0")
+  override def toJson: String = {
+    val jValue = ("type" -> 1) ~ ("values" -> values.toSeq)
+    compact(render(jValue))
+  }
 }
 
 @Since("1.3.0")
@@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") (
     }.unzip
     new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
   }
+
+  @Since("1.6.0")
+  override def toJson: String = {
+    val jValue = ("type" -> 0) ~
+      ("size" -> size) ~
+      ("indices" -> indices.toSeq) ~
+      ("values" -> values.toSeq)
+    compact(render(jValue))
+  }
 }
 
 @Since("1.3.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/21fac543/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 6508dde..f895e2a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg
 import scala.util.Random
 
 import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance}
+import org.json4s.jackson.JsonMethods.{parse => parseJson}
 
 import org.apache.spark.{Logging, SparkException, SparkFunSuite}
 import org.apache.spark.mllib.util.TestingUtils._
@@ -374,4 +375,20 @@ class VectorsSuite extends SparkFunSuite with Logging {
     assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
    assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
   }
+
+  test("toJson/fromJson") {
+    val sv0 = Vectors.sparse(0, Array.empty, Array.empty)
+    val sv1 = Vectors.sparse(1, Array.empty, Array.empty)
+    val sv2 = Vectors.sparse(2, Array(1), Array(2.0))
+    val dv0 = Vectors.dense(Array.empty[Double])
+    val dv1 = Vectors.dense(1.0)
+    val dv2 = Vectors.dense(0.0, 2.0)
+    for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) {
+      val json = v.toJson
+      parseJson(json) // `json` should be a valid JSON string
+      val u = Vectors.fromJson(json)
+      assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.")
+      assert(u === v, "toJson/fromJson should preserve vector values.")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/21fac543/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 5022079..8159518 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -137,6 +137,10 @@ object MimaExcludes {
       ) ++ Seq (
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.status.api.v1.ApplicationInfo.this")
+      ) ++ Seq(
+        // SPARK-11766 add toJson to Vector
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.mllib.linalg.Vector.toJson")
       )
     case v if v.startsWith("1.5") =>
       Seq(

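For reference, the toJson implementations above emit JSON in the VectorUDT schema; for example (field order as rendered by json4s, exact number formatting may vary):

  dense  : {"type":1,"values":[0.0,2.0]}
  sparse : {"type":0,"size":2,"indices":[1],"values":[2.0]}

fromJson dispatches on the "type" field (0 = sparse, 1 = dense), as shown in the Vectors.scala hunk.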
