Repository: spark Updated Branches: refs/heads/branch-2.2 ca3f7edba -> 4bbfad44e
[SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException ## What changes were proposed in this pull request? Added a check for for the number of defined values. Previously the argmax function assumed that at least one value was defined if the vector size was greater than zero. ## How was this patch tested? Tests were added to the existing VectorsSuite to cover this case. Author: Jon McLean <jon.mcl...@atsid.com> Closes #17877 from jonmclean/vectorArgmaxIndexBug. (cherry picked from commit be53a78352ae7c70d8a07d0df24574b3e3129b4a) Signed-off-by: Sean Owen <so...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4bbfad44 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4bbfad44 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4bbfad44 Branch: refs/heads/branch-2.2 Commit: 4bbfad44e426365ad9f4941d68c110523b17ea6d Parents: ca3f7ed Author: Jon McLean <jon.mcl...@atsid.com> Authored: Tue May 9 09:47:50 2017 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Tue May 9 09:47:58 2017 +0100 ---------------------------------------------------------------------- .../src/main/scala/org/apache/spark/ml/linalg/Vectors.scala | 2 ++ .../test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala | 7 +++++++ .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 7 +++++++ 4 files changed, 18 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4bbfad44/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala ---------------------------------------------------------------------- diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 8e166ba..3fbc095 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) http://git-wip-us.apache.org/repos/asf/spark/blob/4bbfad44/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index dfbdaf1..4cd91af 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { http://git-wip-us.apache.org/repos/asf/spark/blob/4bbfad44/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 723addc..f063420 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) http://git-wip-us.apache.org/repos/asf/spark/blob/4bbfad44/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 71a3cea..6172cff 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org