Repository: spark
Updated Branches:
  refs/heads/branch-1.6 54685fa36 -> 80641c4fa


[SPARK-11618][ML] Minor refactoring of basic ML import/export

Refactoring:
* separated the overwrite logic from the param-save logic in DefaultParamsWriter: Writer.save() now handles overwriting and delegates persistence to a new protected saveImpl() method (see the sketch below)
* added sparkVersion to the metadata written by DefaultParamsWriter
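
A minimal sketch of the resulting contract (MyParamsWriter and its body are
hypothetical, for illustration only; only Writer.save()/saveImpl() and the
protected sc member come from this patch):

    import org.apache.spark.ml.param.Params
    import org.apache.spark.ml.util.Writer

    // Overwrite handling now lives entirely in Writer.save(), so a concrete
    // writer only implements the persistence step.
    private[ml] class MyParamsWriter(instance: Params) extends Writer {
      override protected def saveImpl(path: String): Unit = {
        // By the time this runs, Writer.save() has either deleted `path`
        // (overwrite() was requested) or thrown an IOException because the
        // path existed and overwrite was not enabled.
        sc.parallelize(Seq(instance.uid), 1).saveAsTextFile(path)
      }
    }

Callers then opt in to overwriting explicitly, e.g.
new MyParamsWriter(instance).overwrite().save("/tmp/out").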

CC: mengxr

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #9587 from jkbradley/logreg-io.

(cherry picked from commit 18350a57004eb87cafa9504ff73affab4b818e06)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80641c4f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80641c4f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80641c4f

Branch: refs/heads/branch-1.6
Commit: 80641c4faf9b208728f22c7ecac5b0c298ee0c6d
Parents: 54685fa
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Tue Nov 10 11:36:43 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Nov 10 11:36:50 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/util/ReadWrite.scala    | 57 ++++++++++----------
 1 file changed, 30 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/80641c4f/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ea790e0..cbdf913 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite {
   protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
     SQLContext.getOrCreate(SparkContext.getOrCreate())
   }
+
+  /** Returns the [[SparkContext]] underlying [[sqlContext]] */
+  protected final def sc: SparkContext = sqlContext.sparkContext
 }
 
 /**
@@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite {
  */
 @Experimental
 @Since("1.6.0")
-abstract class Writer extends BaseReadWrite {
+abstract class Writer extends BaseReadWrite with Logging {
 
   protected var shouldOverwrite: Boolean = false
 
@@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite {
    */
   @Since("1.6.0")
   @throws[IOException]("If the input path already exists but overwrite is not enabled.")
-  def save(path: String): Unit
+  def save(path: String): Unit = {
+    val hadoopConf = sc.hadoopConfiguration
+    val fs = FileSystem.get(hadoopConf)
+    val p = new Path(path)
+    if (fs.exists(p)) {
+      if (shouldOverwrite) {
+        logInfo(s"Path $path already exists. It will be overwritten.")
+        // TODO: Revert back to the original content if save is not successful.
+        fs.delete(p, true)
+      } else {
+        throw new IOException(
+          s"Path $path already exists. Please use write.overwrite().save(path) 
to overwrite it.")
+      }
+    }
+    saveImpl(path)
+  }
+
+  /**
+   * [[save()]] handles overwriting and then calls this method.  Subclasses should override this
+   * method to implement the actual saving of the instance.
+   */
+  @Since("1.6.0")
+  protected def saveImpl(path: String): Unit
 
   /**
    * Overwrites if the output path already exists.
@@ -147,28 +172,9 @@ trait Readable[T] {
  * data (e.g., models with coefficients).
  * @param instance object to save
  */
-private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
-
-  /**
-   * Saves the ML component to the input path.
-   */
-  override def save(path: String): Unit = {
-    val sc = sqlContext.sparkContext
-
-    val hadoopConf = sc.hadoopConfiguration
-    val fs = FileSystem.get(hadoopConf)
-    val p = new Path(path)
-    if (fs.exists(p)) {
-      if (shouldOverwrite) {
-        logInfo(s"Path $path already exists. It will be overwritten.")
-        // TODO: Revert back to the original content if save is not successful.
-        fs.delete(p, true)
-      } else {
-        throw new IOException(
-          s"Path $path already exists. Please use write.overwrite().save(path) 
to overwrite it.")
-      }
-    }
+private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
 
+  override protected def saveImpl(path: String): Unit = {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
     }.toList
     val metadata = ("class" -> cls) ~
       ("timestamp" -> System.currentTimeMillis()) ~
+      ("sparkVersion" -> sc.version) ~
       ("uid" -> uid) ~
       ("paramMap" -> jsonParams)
     val metadataPath = new Path(path, "metadata").toString
@@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
  */
 private[ml] class DefaultParamsReader[T] extends Reader[T] {
 
-  /**
-   * Loads the ML component from the input path.
-   */
   override def load(path: String): T = {
     implicit val format = DefaultFormats
-    val sc = sqlContext.sparkContext
     val metadataPath = new Path(path, "metadata").toString
     val metadataStr = sc.textFile(metadataPath, 1).first()
     val metadata = parse(metadataStr)


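For reference, the caller-visible behavior after this refactoring (the path is
illustrative, `instance` is any Params object, and DefaultParamsWriter is
private[ml], so this only runs from within Spark ML itself):

    val writer = new DefaultParamsWriter(instance)
    writer.save("/tmp/my-params")             // writes metadata/ under the path
    // writer.save("/tmp/my-params")          // now throws IOException: path exists
    writer.overwrite().save("/tmp/my-params") // deletes the existing path, then saves
    // The metadata JSON written under metadata/ now also records sc.version
    // as "sparkVersion".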