Repository: spark
Updated Branches:
  refs/heads/master cc567b663 -> 21fac5434
[SPARK-11766][MLLIB] add toJson/fromJson to Vector/Vectors This is to support JSON serialization of Param[Vector] in the pipeline API. It could be used for other purposes too. The schema is the same as `VectorUDT`. jkbradley Author: Xiangrui Meng <m...@databricks.com> Closes #9751 from mengxr/SPARK-11766. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21fac543 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21fac543 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21fac543 Branch: refs/heads/master Commit: 21fac5434174389e8b83a2f11341fa7c9e360bfd Parents: cc567b6 Author: Xiangrui Meng <m...@databricks.com> Authored: Tue Nov 17 10:17:16 2015 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Tue Nov 17 10:17:16 2015 -0800 ---------------------------------------------------------------------- .../org/apache/spark/mllib/linalg/Vectors.scala | 45 ++++++++++++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 17 ++++++++ project/MimaExcludes.scala | 4 ++ 3 files changed, 66 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/21fac543/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index bd9badc..4dcf351 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -24,6 +24,9 @@ import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson} 
import org.apache.spark.SparkException import org.apache.spark.annotation.{AlphaComponent, Since} @@ -171,6 +174,12 @@ sealed trait Vector extends Serializable { */ @Since("1.5.0") def argmax: Int + + /** + * Converts the vector to a JSON string. + */ + @Since("1.6.0") + def toJson: String } /** @@ -339,6 +348,27 @@ object Vectors { parseNumeric(NumericParser.parse(s)) } + /** + * Parses the JSON representation of a vector into a [[Vector]]. + */ + @Since("1.6.0") + def fromJson(json: String): Vector = { + implicit val formats = DefaultFormats + val jValue = parseJson(json) + (jValue \ "type").extract[Int] match { + case 0 => // sparse + val size = (jValue \ "size").extract[Int] + val indices = (jValue \ "indices").extract[Seq[Int]].toArray + val values = (jValue \ "values").extract[Seq[Double]].toArray + sparse(size, indices, values) + case 1 => // dense + val values = (jValue \ "values").extract[Seq[Double]].toArray + dense(values) + case _ => + throw new IllegalArgumentException(s"Cannot parse $json into a vector.") + } + } + private[mllib] def parseNumeric(any: Any): Vector = { any match { case values: Array[Double] => @@ -650,6 +680,12 @@ class DenseVector @Since("1.0.0") ( maxIdx } } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 1) ~ ("values" -> values.toSeq) + compact(render(jValue)) + } } @Since("1.3.0") @@ -837,6 +873,15 @@ class SparseVector @Since("1.0.0") ( }.unzip new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray) } + + @Since("1.6.0") + override def toJson: String = { + val jValue = ("type" -> 0) ~ + ("size" -> size) ~ + ("indices" -> indices.toSeq) ~ + ("values" -> values.toSeq) + compact(render(jValue)) + } } @Since("1.3.0") http://git-wip-us.apache.org/repos/asf/spark/blob/21fac543/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 6508dde..f895e2a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} +import org.json4s.jackson.JsonMethods.{parse => parseJson} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ @@ -374,4 +375,20 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("toJson/fromJson") { + val sv0 = Vectors.sparse(0, Array.empty, Array.empty) + val sv1 = Vectors.sparse(1, Array.empty, Array.empty) + val sv2 = Vectors.sparse(2, Array(1), Array(2.0)) + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0) + val dv2 = Vectors.dense(0.0, 2.0) + for (v <- Seq(sv0, sv1, sv2, dv0, dv1, dv2)) { + val json = v.toJson + parseJson(json) // `json` should be a valid JSON string + val u = Vectors.fromJson(json) + assert(u.getClass === v.getClass, "toJson/fromJson should preserve vector types.") + assert(u === v, "toJson/fromJson should preserve vector values.") + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/21fac543/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5022079..8159518 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -137,6 +137,10 @@ object MimaExcludes { ) ++ Seq ( ProblemFilters.exclude[MissingMethodProblem]( 
"org.apache.spark.status.api.v1.ApplicationInfo.this") + ) ++ Seq( + // SPARK-11766 add toJson to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toJson") ) case v if v.startsWith("1.5") => Seq( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org