Repository: spark
Updated Branches:
  refs/heads/master dbbe14907 -> b0adb9f54


[SPARK-10386][MLLIB] PrefixSpanModel supports save/load

```PrefixSpanModel``` supports ```save/load```. It's similar to #9267.

cc jkbradley

Author: Yanbo Liang <yblia...@gmail.com>

Closes #10664 from yanboliang/spark-10386.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b0adb9f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b0adb9f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b0adb9f5

Branch: refs/heads/master
Commit: b0adb9f543fbac16ea14c64eef6ba032a9919039
Parents: dbbe149
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Apr 13 13:18:02 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Apr 13 13:18:02 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/mllib/fpm/PrefixSpan.scala | 96 +++++++++++++++++++-
 .../spark/mllib/fpm/JavaPrefixSpanSuite.java    | 37 ++++++++
 .../spark/mllib/fpm/PrefixSpanSuite.scala       | 31 +++++++
 3 files changed, 163 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b0adb9f5/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 4455681..4344ab1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
 
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -566,4 +576,88 @@ object PrefixSpan extends Logging {
 @Since("1.5.0")
 class PrefixSpanModel[Item] @Since("1.5.0") (
     @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
-  extends Serializable
+  extends Saveable with Serializable {
+
+  /**
+   * Save this model to the given path.
+   * It only works for Item datatypes supported by DataFrames.
+   *
+   * This saves:
+   *  - human-readable (JSON) model metadata to path/metadata/
+   *  - Parquet formatted data to path/data/
+   *
+   * The model may be loaded using [[PrefixSpanModel.load]].
+   *
+   * @param sc  Spark context used to save model data.
+   * @param path  Path specifying the directory in which to save this model.
+   *              If the directory already exists, this method throws an exception.
+   */
+  @Since("2.0.0")
+  override def save(sc: SparkContext, path: String): Unit = {
+    PrefixSpanModel.SaveLoadV1_0.save(this, path)
+  }
+
+  override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
+
+  @Since("2.0.0")
+  override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+    PrefixSpanModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[fpm] object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel"
+
+    def save(model: PrefixSpanModel[_], path: String): Unit = {
+      val sc = model.freqSequences.sparkContext
+      val sqlContext = SQLContext.getOrCreate(sc)
+
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      // Get the type of item class
+      val sample = model.freqSequences.first().sequence(0)(0)
+      val className = sample.getClass.getCanonicalName
+      val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+      val tpe = classSymbol.selfType
+
+      val itemType = ScalaReflection.schemaFor(tpe).dataType
+      val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))),
+        StructField("freq", LongType))
+      val schema = StructType(fields)
+      val rowDataRDD = model.freqSequences.map { x =>
+        Row(x.sequence, x.freq)
+      }
+      sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+      implicit val formats = DefaultFormats
+      val sqlContext = SQLContext.getOrCreate(sc)
+
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
+      val sample = freqSequences.select("sequence").head().get(0)
+      loadImpl(freqSequences, sample)
+    }
+
+    def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = {
+      val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x =>
+        val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray
+        val freq = x.getLong(1)
+        new PrefixSpan.FreqSequence(sequence, freq)
+      }
+      new PrefixSpanModel(freqSequencesRDD)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b0adb9f5/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 34daf5f..8a67793 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.fpm;
 
+import java.io.File;
 import java.util.Arrays;
 import java.util.List;
 
@@ -28,6 +29,7 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.util.Utils;
 
 public class JavaPrefixSpanSuite {
   private transient JavaSparkContext sc;
@@ -64,4 +66,39 @@ public class JavaPrefixSpanSuite {
       long freq = freqSeq.freq();
     }
   }
+
+  @Test
+  public void runPrefixSpanSaveLoad() {
+    JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+      Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
+      Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
+      Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
+      Arrays.asList(Arrays.asList(6))
+    ), 2);
+    PrefixSpan prefixSpan = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5);
+    PrefixSpanModel<Integer> model = prefixSpan.run(sequences);
+
+    File tempDir = Utils.createTempDir(
+      System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite");
+    String outputPath = tempDir.getPath();
+
+    try {
+      model.save(sc.sc(), outputPath);
+      PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
+      JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
+      List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
+      Assert.assertEquals(5, localFreqSeqs.size());
+      // Check that each frequent sequence could be materialized.
+      for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+        List<List<Integer>> seq = freqSeq.javaSequence();
+        long freq = freqSeq.freq();
+      }
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+
+
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b0adb9f5/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index a83e543..6d8c7b4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
 
 class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
     compareResults(expected, model.freqSequences.collect())
   }
 
+  test("model save/load") {
+    val sequences = Seq(
+      Array(Array(1, 2), Array(3)),
+      Array(Array(1), Array(3, 2), Array(1, 2)),
+      Array(Array(1, 2), Array(5)),
+      Array(Array(6)))
+    val rdd = sc.parallelize(sequences, 2).cache()
+
+    val prefixSpan = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+    val model = prefixSpan.run(rdd)
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model.save(sc, path)
+      val newModel = PrefixSpanModel.load(sc, path)
+      val originalSet = model.freqSequences.collect().map { x =>
+        (x.sequence.map(_.toSet).toSeq, x.freq)
+      }.toSet
+      val newSet = newModel.freqSequences.collect().map { x =>
+        (x.sequence.map(_.toSet).toSeq, x.freq)
+      }.toSet
+      assert(originalSet === newSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
   private def compareResults[Item](
       expectedValue: Array[(Array[Array[Item]], Long)],
       actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to