Repository: spark Updated Branches: refs/heads/master 7d2ed8cc0 -> f9d578eaa
[SPARK-13783][ML] Model export/import for spark.ml: GBTs ## What changes were proposed in this pull request? * Added save/load for ```GBTClassifier/GBTClassificationModel/GBTRegressor/GBTRegressionModel```. * Meanwhile, I modified ```EnsembleModelReadWrite.saveImpl/loadImpl``` to support save/load ```treeWeights```. ## How was this patch tested? Adds standard unit tests for GBT save/load. cc jkbradley GayathriMurali Author: Yanbo Liang <yblia...@gmail.com> Closes #12230 from yanboliang/spark-13783. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9d578ea Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9d578ea Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9d578ea Branch: refs/heads/master Commit: f9d578eaa107d8e8503c1563a2b3990c85104298 Parents: 7d2ed8c Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Apr 13 11:31:10 2016 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Wed Apr 13 11:31:10 2016 -0700 ---------------------------------------------------------------------- .../spark/ml/classification/GBTClassifier.scala | 110 +++++++++++------- .../classification/RandomForestClassifier.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 114 ++++++++++++------- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeModels.scala | 25 ++-- .../org/apache/spark/ml/tree/treeParams.scala | 73 +++++++++++- .../ml/classification/GBTClassifierSuite.scala | 37 +++--- .../spark/ml/regression/GBTRegressorSuite.scala | 36 +++--- 8 files changed, 262 insertions(+), 137 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 46e8b89..39a698a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -18,19 +18,21 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +60,7 @@ import org.apache.spark.sql.functions._ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] - with GBTParams with TreeClassifierParams with Logging { + with GBTClassifierParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) @@ -115,40 +117,12 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTClassifier: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "logistic" - * (default = logistic) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "logistic") + // Parameters from GBTClassifierParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "logistic" => OldLogLoss - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") - } - } - override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -175,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental -object GBTClassifier { - // The losses below should be lowercase. +object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { + /** Accessor for supported loss settings: logistic */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTClassifier = super.load(path) } /** @@ -199,7 +176,8 @@ final class GBTClassificationModel private[ml]( private val _treeWeights: Array[Double], @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + @@ -267,12 +245,62 @@ final class GBTClassificationModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } -private[ml] object GBTClassificationModel { +@Since("2.0.0") +object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader + + @Since("2.0.0") + override def load(path: String): GBTClassificationModel = super.load(path) + + private[GBTClassificationModel] + class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTClassificationModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 9d80b8e..dfa711b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -294,7 +294,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 0b52fe2..741724d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -18,19 +18,20 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, - SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +59,7 @@ import org.apache.spark.sql.functions._ @Experimental final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] - with GBTParams with TreeRegressorParams with Logging { + with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) @@ -112,41 +113,12 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTRegressor: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "squared") + // Parameters from GBTRegressorParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") - } - } - override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -164,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") @Experimental -object GBTRegressor { - // The losses below should be lowercase. +object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTRegressor = super.load(path) } /** @@ -188,7 +163,8 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + @@ -255,12 +231,64 @@ final class GBTRegressionModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } -private[ml] object GBTRegressionModel { +@Since("2.0.0") +object GBTRegressionModel extends MLReadable[GBTRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GBTRegressionModel = super.load(path) + + private[GBTRegressionModel] + class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + + require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, categoricalFeatures: Map[Int, Int], http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index bee13c2..4c4ff27 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -249,7 +249,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index c4ab673..f38e1ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite { sql: SQLContext, extraMetadata: JObject): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) - val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map { + val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { case (tree, treeID) => - treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext) + (treeID, + DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext), + instance.treeWeights(treeID)) } val treesMetadataPath = new Path(path, "treesMetadata").toString - sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata") + sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights") .write.parquet(treesMetadataPath) val dataPath = new Path(path, "data").toString val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { @@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SQLContext, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)]) = { + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite { } val treesMetadataPath = new Path(path, "treesMetadata").toString - val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath) - .select("treeID", "metadata").as[(Int, String)].rdd.map { - case (treeID: Int, json: String) => - treeID -> DefaultParamsReader.parseMetadata(json, treeClassName) + val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { + case (treeID: Int, json: String, weights: Double) => + treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights) } - val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect() + + val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadata = treesMetadataWeights.map(_._1) + val treesWeights = treesMetadataWeights.map(_._2) val dataPath = new Path(path, "data").toString val nodeData: Dataset[EnsembleNodeData] = @@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite { treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() - (metadata, treesMetadata.zip(rootNodes)) + (metadata, treesMetadata.zip(rootNodes), treesWeights) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 0767dc1..b678391 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -462,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** Get old Gradient Boosting Loss type */ private[ml] def getOldLossType: OldLoss } + +private[ml] object GBTClassifierParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } +} + +private[ml] object GBTRegressorParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 76d8c93..7e6aec6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTClassifierSuite.compareAPIs @@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) - val newModel = GBTClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTClassificationModel, + model2: GBTClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTClassifierSuite extends SparkFunSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/f9d578ea/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 3c11631..2163779 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -32,7 +32,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs @@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) - val newModel = GBTRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTRegressionModel, + model2: GBTRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTRegressorSuite extends SparkFunSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org