Repository: spark
Updated Branches:
  refs/heads/master 7d2ed8cc0 -> f9d578eaa


[SPARK-13783][ML] Model export/import for spark.ml: GBTs

## What changes were proposed in this pull request?
* Added save/load for ```GBTClassifier/GBTClassificationModel/GBTRegressor/GBTRegressionModel``` (a usage sketch follows below).
* Meanwhile, I modified ```EnsembleModelReadWrite.saveImpl/loadImpl``` to support saving and loading ```treeWeights```.
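
For illustration, a minimal save/load round trip with the new API might look like the sketch below. This is not part of the patch: it assumes a Spark 2.0-era session where ```training``` is a ```DataFrame``` with the default ```label```/```features``` columns, and the paths are hypothetical. ```GBTRegressor```/```GBTRegressionModel``` work the same way.

```scala
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}

// Fit a small boosted-trees classifier (`training` is an assumed DataFrame).
val gbt = new GBTClassifier()
  .setMaxIter(10)
  .setLossType("logistic")
val model = gbt.fit(training)

// Persist both the estimator and the fitted model, then read them back.
gbt.write.overwrite().save("/tmp/spark-13783/gbtc")        // hypothetical path
model.write.overwrite().save("/tmp/spark-13783/gbtcModel") // hypothetical path
val sameGbt = GBTClassifier.load("/tmp/spark-13783/gbtc")
val sameModel = GBTClassificationModel.load("/tmp/spark-13783/gbtcModel")

// Params and tree weights now survive the round trip via EnsembleModelReadWrite.
assert(sameGbt.getMaxIter == gbt.getMaxIter)
assert(sameModel.treeWeights.sameElements(model.treeWeights))
```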

## How was this patch tested?
Adds standard unit tests for GBT save/load.

cc jkbradley GayathriMurali

Author: Yanbo Liang <yblia...@gmail.com>

Closes #12230 from yanboliang/spark-13783.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9d578ea
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9d578ea
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9d578ea

Branch: refs/heads/master
Commit: f9d578eaa107d8e8503c1563a2b3990c85104298
Parents: 7d2ed8c
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Apr 13 11:31:10 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Apr 13 11:31:10 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/GBTClassifier.scala | 110 +++++++++++-------
 .../classification/RandomForestClassifier.scala |   2 +-
 .../spark/ml/regression/GBTRegressor.scala      | 114 ++++++++++++-------
 .../ml/regression/RandomForestRegressor.scala   |   2 +-
 .../org/apache/spark/ml/tree/treeModels.scala   |  25 ++--
 .../org/apache/spark/ml/tree/treeParams.scala   |  73 +++++++++++-
 .../ml/classification/GBTClassifierSuite.scala  |  37 +++---
 .../spark/ml/regression/GBTRegressorSuite.scala |  36 +++---
 8 files changed, 262 insertions(+), 137 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 46e8b89..39a698a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -18,19 +18,21 @@
 package org.apache.spark.ml.classification
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -58,7 +60,7 @@ import org.apache.spark.sql.functions._
 final class GBTClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
   extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
-  with GBTParams with TreeClassifierParams with Logging {
+  with GBTClassifierParams with DefaultParamsWritable with Logging {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("gbtc"))
@@ -115,40 +117,12 @@ final class GBTClassifier @Since("1.4.0") (
   @Since("1.4.0")
   override def setStepSize(value: Double): this.type = super.setStepSize(value)
 
-  // Parameters for GBTClassifier:
-
-  /**
-   * Loss function which GBT tries to minimize. (case-insensitive)
-   * Supported: "logistic"
-   * (default = logistic)
-   * @group param
-   */
-  @Since("1.4.0")
-  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
-    " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
-    (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
-
-  setDefault(lossType -> "logistic")
+  // Parameters from GBTClassifierParams:
 
   /** @group setParam */
   @Since("1.4.0")
   def setLossType(value: String): this.type = set(lossType, value)
 
-  /** @group getParam */
-  @Since("1.4.0")
-  def getLossType: String = $(lossType).toLowerCase
-
-  /** (private[ml]) Convert new loss to old loss. */
-  override private[ml] def getOldLossType: OldLoss = {
-    getLossType match {
-      case "logistic" => OldLogLoss
-      case _ =>
-        // Should never happen because of check in setter method.
-        throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
-    }
-  }
-
   override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -175,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") (
 
 @Since("1.4.0")
 @Experimental
-object GBTClassifier {
-  // The losses below should be lowercase.
+object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
+
   /** Accessor for supported loss settings: logistic */
   @Since("1.4.0")
-  final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+  final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
+
+  @Since("2.0.0")
+  override def load(path: String): GBTClassifier = super.load(path)
 }
 
 /**
@@ -199,7 +176,8 @@ final class GBTClassificationModel private[ml](
     private val _treeWeights: Array[Double],
     @Since("1.6.0") override val numFeatures: Int)
   extends PredictionModel[Vector, GBTClassificationModel]
-  with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
+  with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+  with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
@@ -267,12 +245,62 @@ final class GBTClassificationModel private[ml](
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
 }
 
-private[ml] object GBTClassificationModel {
+@Since("2.0.0")
+object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): GBTClassificationModel = super.load(path)
+
+  private[GBTClassificationModel]
+  class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+
+      val extraMetadata: JObject = Map(
+        "numFeatures" -> instance.numFeatures,
+        "numTrees" -> instance.getNumTrees)
+      EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+    }
+  }
+
+  private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[GBTClassificationModel].getName
+    private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+    override def load(path: String): GBTClassificationModel = {
+      implicit val format = DefaultFormats
+      val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+        EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+      val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+      val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+      val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+        case (treeMetadata, root) =>
+          val tree =
+            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+          DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+          tree
+      }
+      require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
+        s" trees based on metadata but found ${trees.length} trees.")
+      val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 
-  /** (private[ml]) Convert a model from the old API */
-  def fromOld(
+  /** Convert a model from the old API */
+  private[ml] def fromOld(
       oldModel: OldGBTModel,
       parent: GBTClassifier,
       categoricalFeatures: Map[Int, Int],

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 9d80b8e..dfa711b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -294,7 +294,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
 
     override def load(path: String): RandomForestClassificationModel = {
       implicit val format = DefaultFormats
-      val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+      val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
         EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 0b52fe2..741724d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -18,19 +18,20 @@
 package org.apache.spark.ml.regression
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, 
TreeRegressorParams}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
-  SquaredError => OldSquaredError}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -58,7 +59,7 @@ import org.apache.spark.sql.functions._
 @Experimental
 final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
-  with GBTParams with TreeRegressorParams with Logging {
+  with GBTRegressorParams with DefaultParamsWritable with Logging {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("gbtr"))
@@ -112,41 +113,12 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
   @Since("1.4.0")
   override def setStepSize(value: Double): this.type = super.setStepSize(value)
 
-  // Parameters for GBTRegressor:
-
-  /**
-   * Loss function which GBT tries to minimize. (case-insensitive)
-   * Supported: "squared" (L2) and "absolute" (L1)
-   * (default = squared)
-   * @group param
-   */
-  @Since("1.4.0")
-  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
-    " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
-    (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
-
-  setDefault(lossType -> "squared")
+  // Parameters from GBTRegressorParams:
 
   /** @group setParam */
   @Since("1.4.0")
   def setLossType(value: String): this.type = set(lossType, value)
 
-  /** @group getParam */
-  @Since("1.4.0")
-  def getLossType: String = $(lossType).toLowerCase
-
-  /** (private[ml]) Convert new loss to old loss. */
-  override private[ml] def getOldLossType: OldLoss = {
-    getLossType match {
-      case "squared" => OldSquaredError
-      case "absolute" => OldAbsoluteError
-      case _ =>
-        // Should never happen because of check in setter method.
-        throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
-    }
-  }
-
   override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -164,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
 
 @Since("1.4.0")
 @Experimental
-object GBTRegressor {
-  // The losses below should be lowercase.
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
   /** Accessor for supported loss settings: squared (L2), absolute (L1) */
   @Since("1.4.0")
-  final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+  final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+  @Since("2.0.0")
+  override def load(path: String): GBTRegressor = super.load(path)
 }
 
 /**
@@ -188,7 +163,8 @@ final class GBTRegressionModel private[ml](
     private val _treeWeights: Array[Double],
     override val numFeatures: Int)
   extends PredictionModel[Vector, GBTRegressionModel]
-  with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
+  with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+  with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
   require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
@@ -255,12 +231,64 @@ final class GBTRegressionModel private[ml](
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
 }
 
-private[ml] object GBTRegressionModel {
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): GBTRegressionModel = super.load(path)
+
+  private[GBTRegressionModel]
+  class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val extraMetadata: JObject = Map(
+        "numFeatures" -> instance.numFeatures,
+        "numTrees" -> instance.getNumTrees)
+      EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+    }
+  }
+
+  private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[GBTRegressionModel].getName
+    private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+    override def load(path: String): GBTRegressionModel = {
+      implicit val format = DefaultFormats
+      val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+        EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+
+      val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+      val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+      val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+        case (treeMetadata, root) =>
+          val tree =
+            new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+          DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+          tree
+      }
+
+      require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" +
+        s" trees based on metadata but found ${trees.length} trees.")
+
+      val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 
-  /** (private[ml]) Convert a model from the old API */
-  def fromOld(
+  /** Convert a model from the old API */
+  private[ml] def fromOld(
       oldModel: OldGBTModel,
       parent: GBTRegressor,
       categoricalFeatures: Map[Int, Int],

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index bee13c2..4c4ff27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -249,7 +249,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
 
     override def load(path: String): RandomForestRegressionModel = {
       implicit val format = DefaultFormats
-      val (metadata: Metadata, treesData: Array[(Metadata, Node)]) =
+      val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index c4ab673..f38e1ec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite {
       sql: SQLContext,
       extraMetadata: JObject): Unit = {
     DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
-    val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map {
+    val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
       case (tree, treeID) =>
-        treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext)
+        (treeID,
+          DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
+          instance.treeWeights(treeID))
     }
     val treesMetadataPath = new Path(path, "treesMetadata").toString
-    sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata")
+    sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
       .write.parquet(treesMetadataPath)
     val dataPath = new Path(path, "data").toString
     val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
@@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite {
       path: String,
       sql: SQLContext,
       className: String,
-      treeClassName: String): (Metadata, Array[(Metadata, Node)]) = {
+      treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
     import sql.implicits._
     implicit val format = DefaultFormats
     val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
@@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite {
     }
 
     val treesMetadataPath = new Path(path, "treesMetadata").toString
-    val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath)
-      .select("treeID", "metadata").as[(Int, String)].rdd.map {
-      case (treeID: Int, json: String) =>
-        treeID -> DefaultParamsReader.parseMetadata(json, treeClassName)
+    val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
+      .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
+      case (treeID: Int, json: String, weights: Double) =>
+        treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights)
     }
-    val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect()
+
+    val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
+    val treesMetadata = treesMetadataWeights.map(_._1)
+    val treesWeights = treesMetadataWeights.map(_._2)
 
     val dataPath = new Path(path, "data").toString
     val nodeData: Dataset[EnsembleNodeData] =
@@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite {
           treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
-    (metadata, treesMetadata.zip(rootNodes))
+    (metadata, treesMetadata.zip(rootNodes), treesWeights)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 0767dc1..b678391 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
@@ -462,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
   /** Get old Gradient Boosting Loss type */
   private[ml] def getOldLossType: OldLoss
 }
+
+private[ml] object GBTClassifierParams {
+  // The losses below should be lowercase.
+  /** Accessor for supported loss settings: logistic */
+  final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+
+  /**
+   * Loss function which GBT tries to minimize. (case-insensitive)
+   * Supported: "logistic"
+   * (default = logistic)
+   * @group param
+   */
+  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+    " tries to minimize (case-insensitive). Supported options:" +
+    s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase))
+
+  setDefault(lossType -> "logistic")
+
+  /** @group getParam */
+  def getLossType: String = $(lossType).toLowerCase
+
+  /** (private[ml]) Convert new loss to old loss. */
+  override private[ml] def getOldLossType: OldLoss = {
+    getLossType match {
+      case "logistic" => OldLogLoss
+      case _ =>
+        // Should never happen because of check in setter method.
+        throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+    }
+  }
+}
+
+private[ml] object GBTRegressorParams {
+  // The losses below should be lowercase.
+  /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+  final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams {
+
+  /**
+   * Loss function which GBT tries to minimize. (case-insensitive)
+   * Supported: "squared" (L2) and "absolute" (L1)
+   * (default = squared)
+   * @group param
+   */
+  val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+    " tries to minimize (case-insensitive). Supported options:" +
+    s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase))
+
+  setDefault(lossType -> "squared")
+
+  /** @group getParam */
+  def getLossType: String = $(lossType).toLowerCase
+
+  /** (private[ml]) Convert new loss to old loss. */
+  override private[ml] def getOldLossType: OldLoss = {
+    getLossType match {
+      case "squared" => OldSquaredError
+      case "absolute" => OldAbsoluteError
+      case _ =>
+        // Should never happen because of check in setter method.
+        throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 76d8c93..7e6aec6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -34,7 +34,8 @@ import org.apache.spark.util.Utils
 /**
  * Test suite for [[GBTClassifier]].
  */
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
 
   import GBTClassifierSuite.compareAPIs
 
@@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: Reinstate test once save/load are implemented  SPARK-6725
-  /*
   test("model save/load") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-
-    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
-    val treeWeights = Array(0.1, 0.3, 1.1)
-    val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
-    val newModel = GBTClassificationModel.fromOld(oldModel)
-
-    // Save model, load it back, and compare.
-    try {
-      newModel.save(sc, path)
-      val sameNewModel = GBTClassificationModel.load(sc, path)
-      TreeTests.checkEqual(newModel, sameNewModel)
-    } finally {
-      Utils.deleteRecursively(tempDir)
+    def checkModelData(
+        model: GBTClassificationModel,
+        model2: GBTClassificationModel): Unit = {
+      TreeTests.checkEqual(model, model2)
+      assert(model.numFeatures === model2.numFeatures)
     }
+
+    val gbt = new GBTClassifier()
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+
+    val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic")
+
+    val continuousData: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
   }
-  */
 }
 
 private object GBTClassifierSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 3c11631..2163779 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -32,7 +32,8 @@ import org.apache.spark.util.Utils
 /**
  * Test suite for [[GBTRegressor]].
  */
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
 
   import GBTRegressorSuite.compareAPIs
 
@@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 
-  // TODO: Reinstate test once save/load are implemented  SPARK-6725
-  /*
   test("model save/load") {
-    val tempDir = Utils.createTempDir()
-    val path = tempDir.toURI.toString
-
-    val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
-    val treeWeights = Array(0.1, 0.3, 1.1)
-    val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
-    val newModel = GBTRegressionModel.fromOld(oldModel)
-
-    // Save model, load it back, and compare.
-    try {
-      newModel.save(sc, path)
-      val sameNewModel = GBTRegressionModel.load(sc, path)
-      TreeTests.checkEqual(newModel, sameNewModel)
-    } finally {
-      Utils.deleteRecursively(tempDir)
+    def checkModelData(
+        model: GBTRegressionModel,
+        model2: GBTRegressionModel): Unit = {
+      TreeTests.checkEqual(model, model2)
+      assert(model.numFeatures === model2.numFeatures)
     }
+
+    val gbt = new GBTRegressor()
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+
+    val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
+    val continuousData: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
   }
-  */
 }
 
 private object GBTRegressorSuite extends SparkFunSuite {

