This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5c96d64 [SPARK-35707][ML] optimize sparse GEMM by skipping bound checking 5c96d64 is described below commit 5c96d643eeb4ca1ad7e4e9cc711971203fcacc6c Author: Ruifeng Zheng <ruife...@foxmail.com> AuthorDate: Wed Jun 16 08:57:27 2021 +0800 [SPARK-35707][ML] optimize sparse GEMM by skipping bound checking ### What changes were proposed in this pull request? Sparse GEMM uses the method `DenseMatrix.apply` to access the values of `B`; this can be optimized by skipping the bounds checks and the `isTransposed` check that `apply` performs on every call: ``` override def apply(i: Int, j: Int): Double = values(index(i, j)) private[ml] def index(i: Int, j: Int): Int = { require(i >= 0 && i < numRows, s"Expected 0 <= i < $numRows, got i = $i.") require(j >= 0 && j < numCols, s"Expected 0 <= j < $numCols, got j = $j.") if (!isTransposed) i + numRows * j else j + numCols * i } ``` ### Why are the changes needed? To improve performance: about 15% faster in the designed case. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test suite and an additional performance test. Closes #32857 from zhengruifeng/gemm_opt_index. 
Authored-by: Ruifeng Zheng <ruife...@foxmail.com> Signed-off-by: Ruifeng Zheng <ruife...@foxmail.com> --- mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 0bc8b2f..d1255de 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -480,7 +480,7 @@ private[spark] object BLAS extends Serializable { val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(ArowIndices(i), colCounterForB) + sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -522,7 +522,7 @@ private[spark] object BLAS extends Serializable { while (colCounterForA < kA) { var i = AcolPtrs(colCounterForA) val indEnd = AcolPtrs(colCounterForA + 1) - val Bval = B(colCounterForA, colCounterForB) * alpha + val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha while (i < indEnd) { Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index e38cfe4..5cbec53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -462,7 +462,7 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(ArowIndices(i), colCounterForB) + sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -504,7 +504,7 @@ private[spark] object BLAS extends Serializable 
with Logging { while (colCounterForA < kA) { var i = AcolPtrs(colCounterForA) val indEnd = AcolPtrs(colCounterForA + 1) - val Bval = B(colCounterForA, colCounterForB) * alpha + val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha while (i < indEnd) { Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org