[ https://issues.apache.org/jira/browse/SPARK-20351?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16723078#comment-16723078 ]
ASF GitHub Bot commented on SPARK-20351: ---------------------------------------- srowen closed pull request #17654: [SPARK-20351] [ML] Add trait hasTrainingSummary to replace the duplicate code URL: https://github.com/apache/spark/pull/17654 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 27a7db0b2f5d4..f2a5c11a34867 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -934,8 +934,8 @@ class LogisticRegressionModel private[spark] ( @Since("2.1.0") val interceptVector: Vector, @Since("1.3.0") override val numClasses: Int, private val isMultinomial: Boolean) - extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with MLWritable { + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with MLWritable + with LogisticRegressionParams with HasTrainingSummary[LogisticRegressionTrainingSummary] { require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " + s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " + @@ -1018,20 +1018,16 @@ class LogisticRegressionModel private[spark] ( @Since("1.6.0") override val numFeatures: Int = coefficientMatrix.numCols - private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None - /** * Gets summary of model on training set. An exception is thrown - * if `trainingSummary == None`. + * if `hasSummary` is false. 
*/ @Since("1.5.0") - def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException("No training summary available for this LogisticRegressionModel") - } + override def summary: LogisticRegressionTrainingSummary = super.summary /** * Gets summary of model on training set. An exception is thrown - * if `trainingSummary == None` or it is a multiclass model. + * if `hasSummary` is false or it is a multiclass model. */ @Since("2.3.0") def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match { @@ -1062,16 +1058,6 @@ class LogisticRegressionModel private[spark] ( (model, model.getProbabilityCol, model.getPredictionCol) } - private[classification] - def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** Indicates whether a training summary exists for this model instance. */ - @Since("1.5.0") - def hasSummary: Boolean = trainingSummary.isDefined - /** * Evaluates the model on a test dataset. 
* diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1a94aefa3f563..49e9f51368131 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -87,8 +87,9 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter @Since("2.0.0") class BisectingKMeansModel private[ml] ( @Since("2.0.0") override val uid: String, - private val parentModel: MLlibBisectingKMeansModel - ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable { + private val parentModel: MLlibBisectingKMeansModel) + extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable + with HasTrainingSummary[BisectingKMeansSummary] { @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { @@ -143,28 +144,12 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) - private var trainingSummary: Option[BisectingKMeansSummary] = None - - private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.1.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. 
*/ @Since("2.1.0") - def summary: BisectingKMeansSummary = trainingSummary.getOrElse { - throw new SparkException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: BisectingKMeansSummary = super.summary } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88abc1605d69f..bb10b3228b93f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -85,7 +85,8 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override val uid: String, @Since("2.0.0") val weights: Array[Double], @Since("2.0.0") val gaussians: Array[MultivariateGaussian]) - extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable + with HasTrainingSummary[GaussianMixtureSummary] { /** @group setParam */ @Since("2.1.0") @@ -160,28 +161,13 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) - private var trainingSummary: Option[GaussianMixtureSummary] = None - - private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. 
*/ @Since("2.0.0") - def summary: GaussianMixtureSummary = trainingSummary.getOrElse { - throw new RuntimeException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: GaussianMixtureSummary = super.summary + } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 2eed84d51782a..319747d4a1930 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -107,7 +107,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, private[clustering] val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable + with HasTrainingSummary[KMeansSummary] { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -153,28 +154,12 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: GeneralMLWriter = new GeneralMLWriter(this) - private var trainingSummary: Option[KMeansSummary] = None - - private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** - * Return true if there exists summary of model. - */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - /** * Gets summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. 
*/ @Since("2.0.0") - def summary: KMeansSummary = trainingSummary.getOrElse { - throw new SparkException( - s"No training summary available for the ${this.getClass.getSimpleName}") - } + override def summary: KMeansSummary = super.summary } /** Helper class for storing model data */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index abb60ea205751..885b13bf8dac3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1001,7 +1001,8 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") val coefficients: Vector, @Since("2.0.0") val intercept: Double) extends RegressionModel[Vector, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase with MLWritable { + with GeneralizedLinearRegressionBase with MLWritable + with HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary] { /** * Sets the link prediction (linear predictor) column name. @@ -1054,29 +1055,12 @@ class GeneralizedLinearRegressionModel private[ml] ( output.toDF() } - private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None - /** * Gets R-like summary of model on training set. An exception is * thrown if there is no summary available. */ @Since("2.0.0") - def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException( - "No training summary available for this GeneralizedLinearRegressionModel") - } - - /** - * Indicates if [[summary]] is available. 
- */ - @Since("2.0.0") - def hasSummary: Boolean = trainingSummary.nonEmpty - - private[regression] - def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } + override def summary: GeneralizedLinearRegressionTrainingSummary = super.summary /** * Evaluate the model on the given dataset, returning a summary of the results. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ce6c12cc368dd..197828762d160 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -647,33 +647,20 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with GeneralMLWritable { + with LinearRegressionParams with GeneralMLWritable + with HasTrainingSummary[LinearRegressionTrainingSummary] { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) - private var trainingSummary: Option[LinearRegressionTrainingSummary] = None - override val numFeatures: Int = coefficients.size /** * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if `hasSummary` is false. */ @Since("1.5.0") - def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse { - throw new SparkException("No training summary available for this LinearRegressionModel") - } - - private[regression] - def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { - this.trainingSummary = summary - this - } - - /** Indicates whether a training summary exists for this model instance. 
*/ - @Since("1.5.0") - def hasSummary: Boolean = trainingSummary.isDefined + override def summary: LinearRegressionTrainingSummary = super.summary /** * Evaluates the model on a test dataset. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala new file mode 100644 index 0000000000000..edb0208144e10 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since + + +/** + * Trait for models that provides Training summary. + * + * @tparam T Summary instance type + */ +@Since("3.0.0") +private[ml] trait HasTrainingSummary[T] { + + private[ml] final var trainingSummary: Option[T] = None + + /** Indicates whether a training summary exists for this model instance. */ + @Since("3.0.0") + def hasSummary: Boolean = trainingSummary.isDefined + + /** + * Gets summary of model on training set. An exception is + * thrown if if `hasSummary` is false. 
+ */ + @Since("3.0.0") + def summary: T = trainingSummary.getOrElse { + throw new SparkException( + s"No training summary available for this ${this.getClass.getSimpleName}") + } + + private[ml] def setSummary(summary: Option[T]): this.type = { + this.trainingSummary = summary + this + } +} ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Add trait hasTrainingSummary to replace the duplicate code > ---------------------------------------------------------- > > Key: SPARK-20351 > URL: https://issues.apache.org/jira/browse/SPARK-20351 > Project: Spark > Issue Type: Improvement > Components: ML > Affects Versions: 2.2.0 > Reporter: yuhao yang > Assignee: yuhao yang > Priority: Minor > Fix For: 3.0.0 > > > Add a trait HasTrainingSummary to avoid code duplication related to training > summary. -- This message was sent by Atlassian JIRA (v7.6.3#76005) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org