This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6ef15f9d075 [SPARK-38822][PYSPARK] Raise indexError when insert loc is out of bounds 6ef15f9d075 is described below commit 6ef15f9d075bf1735131f22b890043829c623f7f Author: Yikun Jiang <yikunk...@gmail.com> AuthorDate: Wed Apr 13 10:42:46 2022 +0900 [SPARK-38822][PYSPARK] Raise indexError when insert loc is out of bounds ### What changes were proposed in this pull request? Since pandas 1.4.0, pandas uses `numpy.insert`, and `numpy.insert` raises `IndexError` when the insert loc is out of bounds. Related changes: - pandas 1.4.0+ uses `numpy.insert`: https://github.com/pandas-dev/pandas/commit/c021d33ecf0e096a186edb731964767e9288a875 - Since numpy 1.8 (10 years ago: https://github.com/numpy/numpy/commit/908e06c3c465434023649b0ca522836580c5cfdc): [`out-of-bound indices will generate an error.`](https://numpy.org/devdocs/release/1.8.0-notes.html#changes) ### Why are the changes needed? To follow pandas behavior. ### Does this PR introduce _any_ user-facing change? Yes. `insert` now raises `IndexError` when the loc is out of bounds, following pandas behavior. ### How was this patch tested? Unit tests. Closes #36115 from Yikun/SPARK-38822. 
Authored-by: Yikun Jiang <yikunk...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/pandas/indexes/base.py | 7 +++--- python/pyspark/pandas/indexes/multi.py | 17 +++---------- python/pyspark/pandas/tests/indexes/test_base.py | 32 ++++++++++++++++++------ python/pyspark/pandas/tests/test_utils.py | 13 ++++++++++ python/pyspark/pandas/utils.py | 19 ++++++++++++++ 5 files changed, 62 insertions(+), 26 deletions(-) diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index fd1c2dff032..c8be0b436fa 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -73,6 +73,7 @@ from pyspark.pandas.utils import ( scol_for, verify_temp_column_name, validate_bool_kwarg, + validate_index_loc, ERROR_MESSAGE_CANNOT_COMBINE, log_advice, ) @@ -2544,10 +2545,8 @@ class Index(IndexOpsMixin): >>> psidx.insert(-3, 100) Int64Index([1, 2, 100, 3, 4, 5], dtype='int64') """ - if loc < 0: - length = len(self) - loc = loc + length - loc = 0 if loc < 0 else loc + validate_index_loc(self, loc) + loc = loc + len(self) if loc < 0 else loc index_name = self._internal.index_spark_column_names[0] sdf_before = self.to_frame(name=index_name)[:loc]._to_spark() diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index e3aea0d075d..43ced2f4541 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -38,6 +38,7 @@ from pyspark.pandas.utils import ( name_like_string, scol_for, verify_temp_column_name, + validate_index_loc, ) from pyspark.pandas.internal import ( InternalField, @@ -1108,20 +1109,8 @@ class MultiIndex(Index): ('c', 'z')], ) """ - length = len(self) - if loc < 0: - loc = loc + length - if loc < 0: - raise IndexError( - "index {} is out of bounds for axis 0 with size {}".format( - (loc - length), length - ) - ) - else: - if loc > length: - raise IndexError( - "index {} is out of bounds for axis 0 with 
size {}".format(loc, length) - ) + validate_index_loc(self, loc) + loc = loc + len(self) if loc < 0 else loc index_name: List[Label] = [(name,) for name in self._internal.index_spark_column_names] sdf_before = self.to_frame(name=index_name)[:loc]._to_spark() diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 3e03bbc028c..de138b58c68 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -2191,32 +2191,48 @@ class IndexesTest(ComparisonTestBase, TestUtils): psidx = ps.from_pandas(pidx) self.assert_eq(pidx.insert(1, 100), psidx.insert(1, 100)) self.assert_eq(pidx.insert(-1, 100), psidx.insert(-1, 100)) - self.assert_eq(pidx.insert(100, 100), psidx.insert(100, 100)) - self.assert_eq(pidx.insert(-100, 100), psidx.insert(-100, 100)) + err_msg = "index 100 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(100, 100) + err_msg = "index -100 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(-100, 100) # Floating pidx = pd.Index([1.0, 2.0, 3.0], name="Koalas") psidx = ps.from_pandas(pidx) self.assert_eq(pidx.insert(1, 100.0), psidx.insert(1, 100.0)) self.assert_eq(pidx.insert(-1, 100.0), psidx.insert(-1, 100.0)) - self.assert_eq(pidx.insert(100, 100.0), psidx.insert(100, 100.0)) - self.assert_eq(pidx.insert(-100, 100.0), psidx.insert(-100, 100.0)) + err_msg = "index 100 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(100, 100) + err_msg = "index -100 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(-100, 100) # String pidx = pd.Index(["a", "b", "c"], name="Koalas") psidx = ps.from_pandas(pidx) self.assert_eq(pidx.insert(1, "x"), psidx.insert(1, "x")) self.assert_eq(pidx.insert(-1, "x"), psidx.insert(-1, "x")) - 
self.assert_eq(pidx.insert(100, "x"), psidx.insert(100, "x")) - self.assert_eq(pidx.insert(-100, "x"), psidx.insert(-100, "x")) + err_msg = "index 100 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(100, "x") + err_msg = "index -100 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(-100, "x") # Boolean pidx = pd.Index([True, False, True, False], name="Koalas") psidx = ps.from_pandas(pidx) self.assert_eq(pidx.insert(1, True), psidx.insert(1, True)) self.assert_eq(pidx.insert(-1, True), psidx.insert(-1, True)) - self.assert_eq(pidx.insert(100, True), psidx.insert(100, True)) - self.assert_eq(pidx.insert(-100, True), psidx.insert(-100, True)) + err_msg = "index 100 is out of bounds for axis 0 with size 4" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(100, True) + err_msg = "index -100 is out of bounds for axis 0 with size 4" + with self.assertRaisesRegex(IndexError, err_msg): + psidx.insert(-100, True) # MultiIndex pmidx = pd.MultiIndex.from_tuples( diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index b601c695476..11f560c6f55 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -17,10 +17,12 @@ import pandas as pd +from pyspark.pandas.indexes.base import Index from pyspark.pandas.utils import ( lazy_property, validate_arguments_and_invoke_function, validate_bool_kwarg, + validate_index_loc, validate_mode, ) from pyspark.testing.pandasutils import PandasOnSparkTestCase @@ -92,6 +94,17 @@ class UtilsTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaises(ValueError): validate_mode("r") + def test_validate_index_loc(self): + psidx = Index([1, 2, 3]) + validate_index_loc(psidx, -1) + validate_index_loc(psidx, -3) + err_msg = "index 4 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, 
err_msg): + validate_index_loc(psidx, 4) + err_msg = "index -4 is out of bounds for axis 0 with size 3" + with self.assertRaisesRegex(IndexError, err_msg): + validate_index_loc(psidx, -4) + class TestClassForLazyProp: def __init__(self): diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index a61ea7d19b3..0a7831b94ff 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -49,6 +49,7 @@ from pyspark.pandas.spark import functions as SF from pyspark.pandas.typedef.typehints import as_spark_type if TYPE_CHECKING: + from pyspark.pandas.indexes.base import Index from pyspark.pandas.base import IndexOpsMixin from pyspark.pandas.frame import DataFrame from pyspark.pandas.internal import InternalFrame @@ -975,6 +976,24 @@ def log_advice(message: str) -> None: warnings.warn(message, PandasAPIOnSparkAdviceWarning) +def validate_index_loc(index: "Index", loc: int) -> None: + """ + Raises IndexError if index is out of bounds + """ + length = len(index) + if loc < 0: + loc = loc + length + if loc < 0: + raise IndexError( + "index {} is out of bounds for axis 0 with size {}".format((loc - length), length) + ) + else: + if loc > length: + raise IndexError( + "index {} is out of bounds for axis 0 with size {}".format(loc, length) + ) + + def _test() -> None: import os import doctest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org