Repository: spark Updated Branches: refs/heads/master 73e64f7d5 -> c7270a46f
[SPARK-17139][ML] Add model summary for MultinomialLogisticRegression ## What changes were proposed in this pull request? Add 4 traits, using the following hierarchy (where `A: B` means A extends B): LogisticRegressionSummary; LogisticRegressionTrainingSummary: LogisticRegressionSummary; BinaryLogisticRegressionSummary: LogisticRegressionSummary; BinaryLogisticRegressionTrainingSummary: LogisticRegressionTrainingSummary, BinaryLogisticRegressionSummary. Public methods such as `def summary` return only the trait types listed above. Then implement 4 concrete classes: LogisticRegressionSummaryImpl (multiclass case); LogisticRegressionTrainingSummaryImpl (multiclass case); BinaryLogisticRegressionSummaryImpl (binary case); BinaryLogisticRegressionTrainingSummaryImpl (binary case). ## How was this patch tested? Existing tests and newly added tests. Author: WeichenXu <weichenxu...@outlook.com> Closes #15435 from WeichenXu123/mlor_summary. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c7270a46 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c7270a46 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c7270a46 Branch: refs/heads/master Commit: c7270a46fc340db62c87ddfc6568603d0b832845 Parents: 73e64f7 Author: Weichen Xu <weichen...@databricks.com> Authored: Mon Aug 28 13:31:01 2017 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Mon Aug 28 13:31:01 2017 -0700 ---------------------------------------------------------------------- .../ml/classification/LogisticRegression.scala | 340 +++++++++++++++---- .../LogisticRegressionSuite.scala | 160 +++++++-- .../ml/regression/LinearRegressionSuite.scala | 2 +- project/MimaExcludes.scala | 21 +- 4 files changed, 412 insertions(+), 111 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21957d9..ffe4b52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -35,7 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import 
org.apache.spark.mllib.util.MLUtils @@ -882,21 +882,28 @@ class LogisticRegression @Since("1.2.0") ( val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, isMultinomial)) - // TODO: implement summary model for multinomial case - val m = if (!isMultinomial) { - val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() - val logRegSummary = new BinaryLogisticRegressionTrainingSummary( + + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val logRegSummary = if (numClasses <= 2) { + new BinaryLogisticRegressionTrainingSummaryImpl( summaryModel.transform(dataset), probabilityColName, + predictionColName, $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(Some(logRegSummary)) } else { - model + new LogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + objectiveHistory) } - instr.logSuccess(m) - m + model.setSummary(Some(logRegSummary)) + instr.logSuccess(model) + model } @Since("1.4.0") @@ -1010,8 +1017,8 @@ class LogisticRegressionModel private[spark] ( private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None /** - * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * Gets summary of model on training set. An exception is thrown + * if `trainingSummary == None`. */ @Since("1.5.0") def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse { @@ -1019,18 +1026,36 @@ class LogisticRegressionModel private[spark] ( } /** - * If the probability column is set returns the current model and probability column, - * otherwise generates a new column and sets it as the probability column on a new copy - * of the current model. + * Gets summary of model on training set. An exception is thrown + * if `trainingSummary == None` or it is a multiclass model. 
*/ - private[classification] def findSummaryModelAndProbabilityCol(): - (LogisticRegressionModel, String) = { - $(probabilityCol) match { - case "" => - val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString - (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) - case p => (this, p) + @Since("2.3.0") + def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match { + case b: BinaryLogisticRegressionTrainingSummary => b + case _ => + throw new RuntimeException("Cannot create a binary summary for a non-binary model" + + s"(numClasses=${numClasses}), use summary instead.") + } + + /** + * If the probability and prediction columns are set, this method returns the current model, + * otherwise it generates new columns for them and sets them as columns on a new copy of + * the current model + */ + private[classification] def findSummaryModel(): + (LogisticRegressionModel, String, String) = { + val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) { + copy(ParamMap.empty) + .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) + .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else if ($(probabilityCol).isEmpty) { + copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString) + } else if ($(predictionCol).isEmpty) { + copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString) + } else { + this } + (model, model.getProbabilityCol, model.getPredictionCol) } private[classification] @@ -1051,9 +1076,14 @@ class LogisticRegressionModel private[spark] ( @Since("2.0.0") def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = { // Handle possible missing or invalid prediction columns - val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() - new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), - probabilityColName, $(labelCol), $(featuresCol)) + val 
(summaryModel, probabilityColName, predictionColName) = findSummaryModel() + if (numClasses > 2) { + new LogisticRegressionSummaryImpl(summaryModel.transform(dataset), + probabilityColName, predictionColName, $(labelCol), $(featuresCol)) + } else { + new BinaryLogisticRegressionSummaryImpl(summaryModel.transform(dataset), + probabilityColName, predictionColName, $(labelCol), $(featuresCol)) + } } /** @@ -1324,90 +1354,154 @@ private[ml] class MultiClassSummarizer extends Serializable { } /** - * Abstraction for multinomial Logistic Regression Training results. - * Currently, the training summary ignores the training weights except - * for the objective trace. - */ -sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { - - /** objective function (scaled loss + regularization) at each iteration. */ - def objectiveHistory: Array[Double] - - /** Number of training iterations until termination */ - def totalIterations: Int = objectiveHistory.length - -} - -/** - * Abstraction for Logistic Regression Results for a given model. + * :: Experimental :: + * Abstraction for logistic regression results for a given model. */ +@Experimental sealed trait LogisticRegressionSummary extends Serializable { /** * Dataframe output by the model's `transform` method. */ + @Since("1.5.0") def predictions: DataFrame /** Field in "predictions" which gives the probability of each class as a vector. */ + @Since("1.5.0") def probabilityCol: String + /** Field in "predictions" which gives the prediction of each class. */ + @Since("2.3.0") + def predictionCol: String + /** Field in "predictions" which gives the true label of each instance (if available). */ + @Since("1.5.0") def labelCol: String /** Field in "predictions" which gives the features of each instance as a vector. 
*/ + @Since("1.6.0") def featuresCol: String + @transient private val multiclassMetrics = { + new MulticlassMetrics( + predictions.select( + col(predictionCol), + col(labelCol).cast(DoubleType)) + .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) }) + } + + /** + * Returns the sequence of labels in ascending order. This order matches the order used + * in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel. + * + * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}, However, if the + * training set is missing a label, then all of the arrays over labels + * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the + * expected numClasses. + */ + @Since("2.3.0") + def labels: Array[Double] = multiclassMetrics.labels + + /** Returns true positive rate for each label (category). */ + @Since("2.3.0") + def truePositiveRateByLabel: Array[Double] = recallByLabel + + /** Returns false positive rate for each label (category). */ + @Since("2.3.0") + def falsePositiveRateByLabel: Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label)) + } + + /** Returns precision for each label (category). */ + @Since("2.3.0") + def precisionByLabel: Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.precision(label)) + } + + /** Returns recall for each label (category). */ + @Since("2.3.0") + def recallByLabel: Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.recall(label)) + } + + /** Returns f-measure for each label (category). */ + @Since("2.3.0") + def fMeasureByLabel(beta: Double): Array[Double] = { + multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta)) + } + + /** Returns f1-measure for each label (category). */ + @Since("2.3.0") + def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0) + + /** + * Returns accuracy. 
+ * (equals to the total number of correctly classified instances + * out of the total number of instances.) + */ + @Since("2.3.0") + def accuracy: Double = multiclassMetrics.accuracy + + /** + * Returns weighted true positive rate. + * (equals to precision, recall and f-measure) + */ + @Since("2.3.0") + def weightedTruePositiveRate: Double = weightedRecall + + /** Returns weighted false positive rate. */ + @Since("2.3.0") + def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate + + /** + * Returns weighted averaged recall. + * (equals to precision, recall and f-measure) + */ + @Since("2.3.0") + def weightedRecall: Double = multiclassMetrics.weightedRecall + + /** Returns weighted averaged precision. */ + @Since("2.3.0") + def weightedPrecision: Double = multiclassMetrics.weightedPrecision + + /** Returns weighted averaged f-measure. */ + @Since("2.3.0") + def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta) + + /** Returns weighted averaged f1-measure. */ + @Since("2.3.0") + def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0) } /** * :: Experimental :: - * Logistic regression training results. - * - * @param predictions dataframe output by the model's `transform` method. - * @param probabilityCol field in "predictions" which gives the probability of - * each class as a vector. - * @param labelCol field in "predictions" which gives the true label of each instance. - * @param featuresCol field in "predictions" which gives the features of each instance as a vector. - * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + * Abstraction for multiclass logistic regression training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. 
*/ @Experimental -@Since("1.5.0") -class BinaryLogisticRegressionTrainingSummary private[classification] ( - predictions: DataFrame, - probabilityCol: String, - labelCol: String, - featuresCol: String, - @Since("1.5.0") val objectiveHistory: Array[Double]) - extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol) - with LogisticRegressionTrainingSummary { +sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { + + /** objective function (scaled loss + regularization) at each iteration. */ + @Since("1.5.0") + def objectiveHistory: Array[Double] + + /** Number of training iterations. */ + @Since("1.5.0") + def totalIterations: Int = objectiveHistory.length } /** * :: Experimental :: - * Binary Logistic regression results for a given model. - * - * @param predictions dataframe output by the model's `transform` method. - * @param probabilityCol field in "predictions" which gives the probability of - * each class as a vector. - * @param labelCol field in "predictions" which gives the true label of each instance. - * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * Abstraction for binary logistic regression results for a given model. */ @Experimental -@Since("1.5.0") -class BinaryLogisticRegressionSummary private[classification] ( - @Since("1.5.0") @transient override val predictions: DataFrame, - @Since("1.5.0") override val probabilityCol: String, - @Since("1.5.0") override val labelCol: String, - @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary { - +sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary { private val sparkSession = predictions.sparkSession import sparkSession.implicits._ - /** - * Returns a BinaryClassificationMetrics object. - */ // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. 
@transient private val binaryMetrics = new BinaryClassificationMetrics( @@ -1484,3 +1578,99 @@ class BinaryLogisticRegressionSummary private[classification] ( binaryMetrics.recallByThreshold().toDF("threshold", "recall") } } + +/** + * :: Experimental :: + * Abstraction for binary logistic regression training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. + */ +@Experimental +sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegressionSummary + with LogisticRegressionTrainingSummary + +/** + * Multiclass logistic regression training results. + * + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of + * each class as a vector. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +private class LogisticRegressionTrainingSummaryImpl( + predictions: DataFrame, + probabilityCol: String, + predictionCol: String, + labelCol: String, + featuresCol: String, + override val objectiveHistory: Array[Double]) + extends LogisticRegressionSummaryImpl( + predictions, probabilityCol, predictionCol, labelCol, featuresCol) + with LogisticRegressionTrainingSummary + +/** + * Multiclass logistic regression results for a given model. + * + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of + * each class as a vector. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. 
+ * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + */ +private class LogisticRegressionSummaryImpl( + @transient override val predictions: DataFrame, + override val probabilityCol: String, + override val predictionCol: String, + override val labelCol: String, + override val featuresCol: String) + extends LogisticRegressionSummary + +/** + * Binary logistic regression training results. + * + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of + * each class as a vector. + * @param predictionCol field in "predictions" which gives the prediction for a data instance as a + * double. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. + */ +private class BinaryLogisticRegressionTrainingSummaryImpl( + predictions: DataFrame, + probabilityCol: String, + predictionCol: String, + labelCol: String, + featuresCol: String, + override val objectiveHistory: Array[Double]) + extends BinaryLogisticRegressionSummaryImpl( + predictions, probabilityCol, predictionCol, labelCol, featuresCol) + with BinaryLogisticRegressionTrainingSummary + +/** + * Binary logistic regression results for a given model. + * + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of + * each class as a vector. + * @param predictionCol field in "predictions" which gives the prediction of + * each class as a double. + * @param labelCol field in "predictions" which gives the true label of each instance. 
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + */ +private class BinaryLogisticRegressionSummaryImpl( + predictions: DataFrame, + probabilityCol: String, + predictionCol: String, + labelCol: String, + featuresCol: String) + extends LogisticRegressionSummaryImpl( + predictions, probabilityCol, predictionCol, labelCol, featuresCol) + with BinaryLogisticRegressionSummary http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 542977a..6649fa4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -222,15 +222,58 @@ class LogisticRegressionSuite } } - test("empty probabilityCol") { - val lr = new LogisticRegression().setProbabilityCol("") - val model = lr.fit(smallBinaryDataset) - assert(model.hasSummary) - // Validate that we re-insert a probability column for evaluation - val fieldNames = model.summary.predictions.schema.fieldNames - assert(smallBinaryDataset.schema.fieldNames.toSet.subsetOf( - fieldNames.toSet)) - assert(fieldNames.exists(s => s.startsWith("probability_"))) + test("empty probabilityCol or predictionCol") { + val lr = new LogisticRegression().setMaxIter(1) + val datasetFieldNames = smallBinaryDataset.schema.fieldNames.toSet + def checkSummarySchema(model: LogisticRegressionModel, columns: Seq[String]): Unit = { + val fieldNames = model.summary.predictions.schema.fieldNames + assert(model.hasSummary) + assert(datasetFieldNames.subsetOf(fieldNames.toSet)) + columns.foreach { c => 
assert(fieldNames.exists(_.startsWith(c))) } + } + // check that the summary model adds the appropriate columns + Seq(("binomial", smallBinaryDataset), ("multinomial", smallMultinomialDataset)).foreach { + case (family, dataset) => + lr.setFamily(family) + lr.setProbabilityCol("").setPredictionCol("prediction") + val modelNoProb = lr.fit(dataset) + checkSummarySchema(modelNoProb, Seq("probability_")) + + lr.setProbabilityCol("probability").setPredictionCol("") + val modelNoPred = lr.fit(dataset) + checkSummarySchema(modelNoPred, Seq("prediction_")) + + lr.setProbabilityCol("").setPredictionCol("") + val modelNoPredNoProb = lr.fit(dataset) + checkSummarySchema(modelNoPredNoProb, Seq("prediction_", "probability_")) + } + } + + test("check summary types for binary and multiclass") { + val lr = new LogisticRegression() + .setFamily("binomial") + .setMaxIter(1) + + val blorModel = lr.fit(smallBinaryDataset) + assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + + val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset) + assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary]) + withClue("cannot get binary summary for multiclass model") { + intercept[RuntimeException] { + mlorModel.binarySummary + } + } + + val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset) + assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + + val blorSummary = blorModel.evaluate(smallBinaryDataset) + val mlorSummary = mlorModel.evaluate(smallMultinomialDataset) + assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary]) + assert(mlorSummary.isInstanceOf[LogisticRegressionSummary]) } test("setThreshold, getThreshold") { @@ -2341,51 +2384,98 @@ class LogisticRegressionSuite } test("evaluate on 
test set") { - // TODO: add for multiclass when model summary becomes available // Evaluate on test set should be same as that of the transformed training data. val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) .setThreshold(0.6) - val model = lr.fit(smallBinaryDataset) - val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary] - - val sameSummary = - model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] - assert(summary.areaUnderROC === sameSummary.areaUnderROC) - assert(summary.roc.collect() === sameSummary.roc.collect()) - assert(summary.pr.collect === sameSummary.pr.collect()) + .setFamily("binomial") + val blorModel = lr.fit(smallBinaryDataset) + val blorSummary = blorModel.binarySummary + + val sameBlorSummary = + blorModel.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] + assert(blorSummary.areaUnderROC === sameBlorSummary.areaUnderROC) + assert(blorSummary.roc.collect() === sameBlorSummary.roc.collect()) + assert(blorSummary.pr.collect === sameBlorSummary.pr.collect()) + assert( + blorSummary.fMeasureByThreshold.collect() === sameBlorSummary.fMeasureByThreshold.collect()) assert( - summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect()) - assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect()) + blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect()) assert( - summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) + blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect()) + + lr.setFamily("multinomial") + val mlorModel = lr.fit(smallMultinomialDataset) + val mlorSummary = mlorModel.summary + + val mlorSameSummary = mlorModel.evaluate(smallMultinomialDataset) + + assert(mlorSummary.truePositiveRateByLabel === mlorSameSummary.truePositiveRateByLabel) + assert(mlorSummary.falsePositiveRateByLabel === 
mlorSameSummary.falsePositiveRateByLabel) + assert(mlorSummary.precisionByLabel === mlorSameSummary.precisionByLabel) + assert(mlorSummary.recallByLabel === mlorSameSummary.recallByLabel) + assert(mlorSummary.fMeasureByLabel === mlorSameSummary.fMeasureByLabel) + assert(mlorSummary.accuracy === mlorSameSummary.accuracy) + assert(mlorSummary.weightedTruePositiveRate === mlorSameSummary.weightedTruePositiveRate) + assert(mlorSummary.weightedFalsePositiveRate === mlorSameSummary.weightedFalsePositiveRate) + assert(mlorSummary.weightedPrecision === mlorSameSummary.weightedPrecision) + assert(mlorSummary.weightedRecall === mlorSameSummary.weightedRecall) + assert(mlorSummary.weightedFMeasure === mlorSameSummary.weightedFMeasure) } test("evaluate with labels that are not doubles") { // Evaluate a test set with Label that is a numeric type other than Double - val lr = new LogisticRegression() + val blor = new LogisticRegression() .setMaxIter(1) .setRegParam(1.0) - val model = lr.fit(smallBinaryDataset) - val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary] + .setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + val blorSummary = blorModel.evaluate(smallBinaryDataset) + .asInstanceOf[BinaryLogisticRegressionSummary] + + val blorLongLabelData = smallBinaryDataset.select(col(blorModel.getLabelCol).cast(LongType), + col(blorModel.getFeaturesCol)) + val blorLongSummary = blorModel.evaluate(blorLongLabelData) + .asInstanceOf[BinaryLogisticRegressionSummary] + + assert(blorSummary.areaUnderROC === blorLongSummary.areaUnderROC) - val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType), - col(model.getFeaturesCol)) - val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary] + val mlor = new LogisticRegression() + .setMaxIter(1) + .setRegParam(1.0) + .setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) + val mlorSummary = 
mlorModel.evaluate(smallMultinomialDataset) + + val mlorLongLabelData = smallMultinomialDataset.select( + col(mlorModel.getLabelCol).cast(LongType), + col(mlorModel.getFeaturesCol)) + val mlorLongSummary = mlorModel.evaluate(mlorLongLabelData) - assert(summary.areaUnderROC === longSummary.areaUnderROC) + assert(mlorSummary.accuracy === mlorLongSummary.accuracy) } test("statistics on training data") { // Test that loss is monotonically decreasing. - val lr = new LogisticRegression() + val blor = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) - .setThreshold(0.6) - val model = lr.fit(smallBinaryDataset) + .setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + assert( + blorModel.summary + .objectiveHistory + .sliding(2) + .forall(x => x(0) >= x(1))) + + val mlor = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + .setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) assert( - model.summary + mlorModel.summary .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) @@ -2470,7 +2560,7 @@ class LogisticRegressionSuite predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) => assert(p1 === p2) } - // TODO: check that it converges in a single iteration when model summary is available + assert(model4.summary.totalIterations === 1) } test("binary logistic regression with all labels the same") { @@ -2531,6 +2621,7 @@ class LogisticRegressionSuite assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0))) assert(pred === 4.0) } + assert(model.summary.totalIterations === 0) // force the model to be trained with only one class val constantZeroData = Seq( @@ -2544,6 +2635,7 @@ class LogisticRegressionSuite assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) } + assert(modelZeroLabel.summary.totalIterations > 0) // ensure that the correct value is predicted when numClasses passed through metadata val labelMeta = 
NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata() @@ -2557,7 +2649,7 @@ class LogisticRegressionSuite assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))) assert(pred === 4.0) } - // TODO: check num iters is zero when it become available in the model + require(modelWithMetadata.summary.totalIterations === 0) } test("compressed storage for constant label") { http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index e7bd4eb..f470dca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -715,7 +715,7 @@ class LinearRegressionSuite assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) // Residuals in [[LinearRegressionResults]] should equal those manually computed - val expectedResiduals = datasetWithDenseFeature.select("features", "label") + datasetWithDenseFeature.select("features", "label") .rdd .map { case Row(features: DenseVector, label: Double) => val prediction = http://git-wip-us.apache.org/repos/asf/spark/blob/c7270a46/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9bda917..eecda26 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,7 +44,26 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this"), // [SPARK-21276] Update lz4-java to the latest (v1.4.0) - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream"), + + // [SPARK-17139] Add model summary for MultinomialLogisticRegression + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictionCol"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"), + 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=") ) // Exclude rules for 2.2.x --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org