Repository: spark Updated Branches: refs/heads/branch-1.5 cd55fbcc2 -> be68a4bcb
[SPARK-10973] [ML] [PYTHON] Fix IndexError exception on SparseVector when asked for index after the last non-zero entry See https://github.com/apache/spark/pull/9009 for details. Author: zero323 <matthew.szymkiew...@gmail.com> Closes #9064 from zero323/SPARK-10973_1.5. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/be68a4bc Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/be68a4bc Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/be68a4bc Branch: refs/heads/branch-1.5 Commit: be68a4bcbfac3dad4e0279cf6ce099cd830a4507 Parents: cd55fbc Author: zero323 <matthew.szymkiew...@gmail.com> Authored: Mon Oct 12 12:09:06 2015 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Mon Oct 12 12:09:06 2015 -0700 ---------------------------------------------------------------------- python/pyspark/mllib/linalg/__init__.py | 3 +++ python/pyspark/mllib/tests.py | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/be68a4bc/python/pyspark/mllib/linalg/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 30d05d0..2db2dec 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -732,6 +732,9 @@ class SparseVector(Vector): raise ValueError("Index %d out of bounds." % index) insert_index = np.searchsorted(inds, index) + if insert_index >= inds.size: + return 0. + row_ind = inds[insert_index] if row_ind == index: return vals[insert_index] http://git-wip-us.apache.org/repos/asf/spark/blob/be68a4bc/python/pyspark/mllib/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5097c5e..4d47144 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -205,15 +205,17 @@ class VectorTests(MLlibTestCase): self.assertTrue(dv.array.dtype == 'float64') def test_sparse_vector_indexing(self): - sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) - for ind in [4, -5]: + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) + for ind in [5, -6]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org