This is an automated email from the ASF dual-hosted git repository. huaxingao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d5563f3 [SPARK-36956][MLLIB] model prediction in .mllib avoid conversion to breeze vector d5563f3 is described below commit d5563f3897b925fee7c2f7a9999ae6418d8e03a5 Author: Ruifeng Zheng <ruife...@foxmail.com> AuthorDate: Mon Oct 25 11:11:44 2021 -0700 [SPARK-36956][MLLIB] model prediction in .mllib avoid conversion to breeze vector ### What changes were proposed in this pull request? model prediction in .mllib avoid conversion to breeze vector ### Why are the changes needed? avoid unnecessary conversion ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing suites Closes #34221 from zhengruifeng/mllib_model_avoid_breeze_conversion. Authored-by: Ruifeng Zheng <ruife...@foxmail.com> Signed-off-by: Huaxin Gao <huaxin_...@apple.com> --- mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala | 4 ++-- .../scala/org/apache/spark/mllib/regression/LinearRegression.scala | 4 ++-- .../scala/org/apache/spark/mllib/regression/RidgeRegression.scala | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 90cc4fb..33ce0d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ @@ -72,7 +72,7 @@ class SVMModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { - val margin = weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept + val margin = BLAS.dot(weightMatrix, dataMatrix) + intercept threshold match { case Some(t) => if (margin > t) 1.0 else 0.0 case None => margin diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 47bb1fa9..13920b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel @@ -43,7 +43,7 @@ class LassoModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { - weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept + BLAS.dot(weightMatrix, dataMatrix) + intercept } @Since("1.3.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index f68ebc1..bd42d7b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel @@ -43,7 +43,7 @@ class LinearRegressionModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { - weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept + BLAS.dot(weightMatrix, dataMatrix) + intercept } @Since("1.3.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 1c3bdce..1f67536 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel @@ -43,7 +43,7 @@ class RidgeRegressionModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { - weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept + BLAS.dot(weightMatrix, dataMatrix) + intercept } @Since("1.3.0") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org