Repository: spark
Updated Branches:
  refs/heads/branch-1.4 686c22e57 -> d1cae3206


[SPARK-10973] [ML] [PYTHON] Fix IndexError exception in SparseVector.__getitem__ when asked for an index after the last non-zero entry

See https://github.com/apache/spark/pull/9009 for details.

Author: zero323 <matthew.szymkiew...@gmail.com>

Closes #9063 from zero323/SPARK-10973_1.4.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d1cae320
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d1cae320
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d1cae320

Branch: refs/heads/branch-1.4
Commit: d1cae3206af71c54c13e785b17936c0ea7160545
Parents: 686c22e
Author: zero323 <matthew.szymkiew...@gmail.com>
Authored: Mon Oct 12 15:00:50 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Oct 12 15:00:50 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/linalg.py |  3 +++
 python/pyspark/mllib/tests.py  | 20 +++++++++++---------
 2 files changed, 14 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d1cae320/python/pyspark/mllib/linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 7702beb..ee1ad03 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -673,6 +673,9 @@ class SparseVector(Vector):
             raise ValueError("Index %d out of bounds." % index)
 
         insert_index = np.searchsorted(inds, index)
+        if insert_index >= inds.size:
+            return 0.0
+
         row_ind = inds[insert_index]
         if row_ind == index:
             return vals[insert_index]

http://git-wip-us.apache.org/repos/asf/spark/blob/d1cae320/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 4335143..d883f6f 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -137,15 +137,17 @@ class VectorTests(MLlibTestCase):
         self.assertTrue(dv.array.dtype == 'float64')
 
     def test_sparse_vector_indexing(self):
-        sv = SparseVector(4, {1: 1, 3: 2})
-        self.assertEquals(sv[0], 0.)
-        self.assertEquals(sv[3], 2.)
-        self.assertEquals(sv[1], 1.)
-        self.assertEquals(sv[2], 0.)
-        self.assertEquals(sv[-1], 2)
-        self.assertEquals(sv[-2], 0)
-        self.assertEquals(sv[-4], 0)
-        for ind in [4, -5]:
+        sv = SparseVector(5, {1: 1, 3: 2})
+        self.assertEqual(sv[0], 0.)
+        self.assertEqual(sv[3], 2.)
+        self.assertEqual(sv[1], 1.)
+        self.assertEqual(sv[2], 0.)
+        self.assertEqual(sv[4], 0.)
+        self.assertEqual(sv[-1], 0.)
+        self.assertEqual(sv[-2], 2.)
+        self.assertEqual(sv[-3], 0.)
+        self.assertEqual(sv[-5], 0.)
+        for ind in [5, -6]:
             self.assertRaises(ValueError, sv.__getitem__, ind)
         for ind in [7.8, '1']:
             self.assertRaises(TypeError, sv.__getitem__, ind)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to