This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 20df062d85e8 [SPARK-46327][PS][CONNECT][TESTS] Reorganize `SeriesStringTests` 20df062d85e8 is described below commit 20df062d85e80422a55afae80ddbf2060f26516c Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Dec 8 21:08:50 2023 +0900 [SPARK-46327][PS][CONNECT][TESTS] Reorganize `SeriesStringTests` ### What changes were proposed in this pull request? Reorganize `SeriesStringTests` ### Why are the changes needed? test code clean up ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #44257 from zhengruifeng/ps_test_ser_str. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- dev/sparktestsupport/modules.py | 6 +- .../test_parity_string_ops_adv.py} | 8 +- .../test_parity_string_ops_basic.py} | 8 +- .../test_string_ops_adv.py} | 125 +------------- .../pandas/tests/series/test_string_ops_basic.py | 184 +++++++++++++++++++++ 5 files changed, 199 insertions(+), 132 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 834b3bd235aa..e67cfce0f5c0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -752,7 +752,8 @@ pyspark_pandas = Module( "pyspark.pandas.tests.test_scalars", "pyspark.pandas.tests.test_series_conversion", "pyspark.pandas.tests.test_series_datetime", - "pyspark.pandas.tests.test_series_string", + "pyspark.pandas.tests.series.test_string_ops_adv", + "pyspark.pandas.tests.series.test_string_ops_basic", "pyspark.pandas.tests.test_spark_functions", "pyspark.pandas.tests.test_sql", "pyspark.pandas.tests.test_typedef", @@ -1005,7 +1006,8 @@ pyspark_pandas_connect_part0 = Module( "pyspark.pandas.tests.connect.test_parity_scalars", "pyspark.pandas.tests.connect.test_parity_series_conversion", "pyspark.pandas.tests.connect.test_parity_series_datetime", - "pyspark.pandas.tests.connect.test_parity_series_string", + "pyspark.pandas.tests.connect.series.test_parity_string_ops_adv", + "pyspark.pandas.tests.connect.series.test_parity_string_ops_basic", "pyspark.pandas.tests.connect.test_parity_spark_functions", "pyspark.pandas.tests.connect.test_parity_sql", "pyspark.pandas.tests.connect.test_parity_typedef", diff --git a/python/pyspark/pandas/tests/connect/test_parity_series_string.py b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_adv.py similarity index 80% copy from python/pyspark/pandas/tests/connect/test_parity_series_string.py copy to python/pyspark/pandas/tests/connect/series/test_parity_string_ops_adv.py index d7c0335f15c7..1213ae073cf5 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_series_string.py +++ b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_adv.py @@ -16,19 +16,19 @@ # import unittest -from pyspark.pandas.tests.test_series_string import SeriesStringTestsMixin +from pyspark.pandas.tests.series.test_string_ops_adv import SeriesStringOpsAdvMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class SeriesStringParityTests( - SeriesStringTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase +class SeriesParityStringOpsAdvTests( + SeriesStringOpsAdvMixin, PandasOnSparkTestUtils, ReusedConnectTestCase ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_series_string import * # noqa: F401 + from pyspark.pandas.tests.connect.series.test_parity_string_ops_adv import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_series_string.py b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_basic.py similarity index 81% rename from python/pyspark/pandas/tests/connect/test_parity_series_string.py rename to python/pyspark/pandas/tests/connect/series/test_parity_string_ops_basic.py index d7c0335f15c7..58f10fa505b0 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_series_string.py +++ b/python/pyspark/pandas/tests/connect/series/test_parity_string_ops_basic.py @@ -16,19 +16,19 @@ # import unittest -from pyspark.pandas.tests.test_series_string import SeriesStringTestsMixin +from pyspark.pandas.tests.series.test_string_ops_basic import SeriesStringOpsMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils -class SeriesStringParityTests( - SeriesStringTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase +class SeriesStringOpsParityTests( + SeriesStringOpsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase ): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_series_string import * # noqa: F401 + from pyspark.pandas.tests.connect.series.test_parity_string_ops_basic import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/series/test_string_ops_adv.py similarity index 67% rename from python/pyspark/pandas/tests/test_series_string.py rename to python/pyspark/pandas/tests/series/test_string_ops_adv.py index b8d35764f1bc..d1954cdf6dad 100644 --- a/python/pyspark/pandas/tests/test_series_string.py +++ b/python/pyspark/pandas/tests/series/test_string_ops_adv.py @@ -18,14 +18,13 @@ import pandas as pd import numpy as np import re -import unittest from pyspark import pandas as ps from pyspark.testing.pandasutils import PandasOnSparkTestCase from pyspark.testing.sqlutils import SQLTestUtils -class SeriesStringTestsMixin: +class SeriesStringOpsAdvMixin: @property def pser(self): return pd.Series( @@ -49,124 +48,6 @@ class SeriesStringTestsMixin: def check_func_on_series(self, func, pser, almost=False): self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost) - def test_string_add_str_num(self): - pdf = pd.DataFrame(dict(col1=["a"], col2=[1])) - psdf = ps.from_pandas(pdf) - with self.assertRaises(TypeError): - psdf["col1"] + psdf["col2"] - - def test_string_add_assign(self): - pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"])) - psdf = ps.from_pandas(pdf) - psdf["col1"] += psdf["col2"] - pdf["col1"] += pdf["col2"] - self.assert_eq(psdf["col1"], pdf["col1"]) - - def test_string_add_str_str(self): - pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"])) - psdf = ps.from_pandas(pdf) - - # TODO: Fix the Series names - self.assert_eq(psdf["col1"] + psdf["col2"], pdf["col1"] + pdf["col2"]) - self.assert_eq(psdf["col2"] + psdf["col1"], pdf["col2"] + pdf["col1"]) - - def test_string_add_str_lit(self): - pdf = pd.DataFrame(dict(col1=["a", "b", "c"])) - psdf = ps.from_pandas(pdf) - self.assert_eq(psdf["col1"] + "_lit", pdf["col1"] + "_lit") - self.assert_eq("_lit" + psdf["col1"], "_lit" + pdf["col1"]) - - def test_string_capitalize(self): - self.check_func(lambda x: x.str.capitalize()) - - def test_string_title(self): - self.check_func(lambda x: x.str.title()) - - def test_string_lower(self): - self.check_func(lambda x: x.str.lower()) - - def test_string_upper(self): - self.check_func(lambda x: x.str.upper()) - - def test_string_swapcase(self): - self.check_func(lambda x: x.str.swapcase()) - - def test_string_startswith(self): - pattern = "car" - self.check_func(lambda x: x.str.startswith(pattern)) - self.check_func(lambda x: x.str.startswith(pattern, na=False)) - - def test_string_endswith(self): - pattern = "s" - self.check_func(lambda x: x.str.endswith(pattern)) - self.check_func(lambda x: x.str.endswith(pattern, na=False)) - - def test_string_strip(self): - self.check_func(lambda x: x.str.strip()) - self.check_func(lambda x: x.str.strip("es\t")) - self.check_func(lambda x: x.str.strip("1")) - - def test_string_lstrip(self): - self.check_func(lambda x: x.str.lstrip()) - self.check_func(lambda x: x.str.lstrip("\n1le")) - self.check_func(lambda x: x.str.lstrip("s")) - - def test_string_rstrip(self): - self.check_func(lambda x: x.str.rstrip()) - self.check_func(lambda x: x.str.rstrip("\t ec")) - self.check_func(lambda x: x.str.rstrip("0")) - - def test_string_get(self): - self.check_func(lambda x: x.str.get(6)) - self.check_func(lambda x: x.str.get(-1)) - - def test_string_isalnum(self): - self.check_func(lambda x: x.str.isalnum()) - - def test_string_isalpha(self): - self.check_func(lambda x: x.str.isalpha()) - - def test_string_isdigit(self): - self.check_func(lambda x: x.str.isdigit()) - - def test_string_isspace(self): - self.check_func(lambda x: x.str.isspace()) - - def test_string_islower(self): - self.check_func(lambda x: x.str.islower()) - - def test_string_isupper(self): - self.check_func(lambda x: x.str.isupper()) - - def test_string_istitle(self): - self.check_func(lambda x: x.str.istitle()) - - def test_string_isnumeric(self): - self.check_func(lambda x: x.str.isnumeric()) - - def test_string_isdecimal(self): - self.check_func(lambda x: x.str.isdecimal()) - - def test_string_cat(self): - psser = ps.from_pandas(self.pser) - with self.assertRaises(NotImplementedError): - psser.str.cat() - - def test_string_center(self): - self.check_func(lambda x: x.str.center(0)) - self.check_func(lambda x: x.str.center(10)) - self.check_func(lambda x: x.str.center(10, "x")) - - def test_string_contains(self): - self.check_func(lambda x: x.str.contains("le", regex=False)) - self.check_func(lambda x: x.str.contains("White", case=True, regex=False)) - self.check_func(lambda x: x.str.contains("apples|carrots", regex=True)) - self.check_func(lambda x: x.str.contains("BANANAS", flags=re.IGNORECASE, na=False)) - - def test_string_count(self): - self.check_func(lambda x: x.str.count("wh|Wh")) - self.check_func(lambda x: x.str.count("WH", flags=re.IGNORECASE)) - def test_string_decode(self): psser = ps.from_pandas(self.pser) with self.assertRaises(NotImplementedError): @@ -332,13 +213,13 @@ class SeriesStringTestsMixin: self.check_func(lambda x: x.str.get_dummies()) -class SeriesStringTests(SeriesStringTestsMixin, PandasOnSparkTestCase, SQLTestUtils): +class SeriesStringOpsAdvTests(SeriesStringOpsAdvMixin, PandasOnSparkTestCase, SQLTestUtils): pass if __name__ == "__main__": import unittest - from pyspark.pandas.tests.test_series_string import * # noqa: F401 + from pyspark.pandas.tests.series.test_string_ops_adv import * # noqa: F401 try: import xmlrunner diff --git a/python/pyspark/pandas/tests/series/test_string_ops_basic.py b/python/pyspark/pandas/tests/series/test_string_ops_basic.py new file mode 100644 index 000000000000..bdb3bf74b098 --- /dev/null +++ b/python/pyspark/pandas/tests/series/test_string_ops_basic.py @@ -0,0 +1,184 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pandas as pd +import numpy as np +import re + +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils + + +class SeriesStringOpsMixin: + @property + def pser(self): + return pd.Series( + [ + "apples", + "Bananas", + "carrots", + "1", + "100", + "", + "\nleading-whitespace", + "trailing-Whitespace \t", + None, + np.NaN, + ] + ) + + def check_func(self, func, almost=False): + self.check_func_on_series(func, self.pser, almost=almost) + + def check_func_on_series(self, func, pser, almost=False): + self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost) + + def test_string_add_str_num(self): + pdf = pd.DataFrame(dict(col1=["a"], col2=[1])) + psdf = ps.from_pandas(pdf) + with self.assertRaises(TypeError): + psdf["col1"] + psdf["col2"] + + def test_string_add_assign(self): + pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"])) + psdf = ps.from_pandas(pdf) + psdf["col1"] += psdf["col2"] + pdf["col1"] += pdf["col2"] + self.assert_eq(psdf["col1"], pdf["col1"]) + + def test_string_add_str_str(self): + pdf = pd.DataFrame(dict(col1=["a", "b", "c"], col2=["1", "2", "3"])) + psdf = ps.from_pandas(pdf) + + # TODO: Fix the Series names + self.assert_eq(psdf["col1"] + psdf["col2"], pdf["col1"] + pdf["col2"]) + self.assert_eq(psdf["col2"] + psdf["col1"], pdf["col2"] + pdf["col1"]) + + def test_string_add_str_lit(self): + pdf = pd.DataFrame(dict(col1=["a", "b", "c"])) + psdf = ps.from_pandas(pdf) + self.assert_eq(psdf["col1"] + "_lit", pdf["col1"] + "_lit") + self.assert_eq("_lit" + psdf["col1"], "_lit" + pdf["col1"]) + + def test_string_capitalize(self): + self.check_func(lambda x: x.str.capitalize()) + + def test_string_title(self): + self.check_func(lambda x: x.str.title()) + + def test_string_lower(self): + self.check_func(lambda x: x.str.lower()) + + def test_string_upper(self): + self.check_func(lambda x: x.str.upper()) + + def test_string_swapcase(self): + self.check_func(lambda x: x.str.swapcase()) + + def test_string_startswith(self): + pattern = "car" + self.check_func(lambda x: x.str.startswith(pattern)) + self.check_func(lambda x: x.str.startswith(pattern, na=False)) + + def test_string_endswith(self): + pattern = "s" + self.check_func(lambda x: x.str.endswith(pattern)) + self.check_func(lambda x: x.str.endswith(pattern, na=False)) + + def test_string_strip(self): + self.check_func(lambda x: x.str.strip()) + self.check_func(lambda x: x.str.strip("es\t")) + self.check_func(lambda x: x.str.strip("1")) + + def test_string_lstrip(self): + self.check_func(lambda x: x.str.lstrip()) + self.check_func(lambda x: x.str.lstrip("\n1le")) + self.check_func(lambda x: x.str.lstrip("s")) + + def test_string_rstrip(self): + self.check_func(lambda x: x.str.rstrip()) + self.check_func(lambda x: x.str.rstrip("\t ec")) + self.check_func(lambda x: x.str.rstrip("0")) + + def test_string_get(self): + self.check_func(lambda x: x.str.get(6)) + self.check_func(lambda x: x.str.get(-1)) + + def test_string_isalnum(self): + self.check_func(lambda x: x.str.isalnum()) + + def test_string_isalpha(self): + self.check_func(lambda x: x.str.isalpha()) + + def test_string_isdigit(self): + self.check_func(lambda x: x.str.isdigit()) + + def test_string_isspace(self): + self.check_func(lambda x: x.str.isspace()) + + def test_string_islower(self): + self.check_func(lambda x: x.str.islower()) + + def test_string_isupper(self): + self.check_func(lambda x: x.str.isupper()) + + def test_string_istitle(self): + self.check_func(lambda x: x.str.istitle()) + + def test_string_isnumeric(self): + self.check_func(lambda x: x.str.isnumeric()) + + def test_string_isdecimal(self): + self.check_func(lambda x: x.str.isdecimal()) + + def test_string_cat(self): + psser = ps.from_pandas(self.pser) + with self.assertRaises(NotImplementedError): + psser.str.cat() + + def test_string_center(self): + self.check_func(lambda x: x.str.center(0)) + self.check_func(lambda x: x.str.center(10)) + self.check_func(lambda x: x.str.center(10, "x")) + + def test_string_contains(self): + self.check_func(lambda x: x.str.contains("le", regex=False)) + self.check_func(lambda x: x.str.contains("White", case=True, regex=False)) + self.check_func(lambda x: x.str.contains("apples|carrots", regex=True)) + self.check_func(lambda x: x.str.contains("BANANAS", flags=re.IGNORECASE, na=False)) + + def test_string_count(self): + self.check_func(lambda x: x.str.count("wh|Wh")) + self.check_func(lambda x: x.str.count("WH", flags=re.IGNORECASE)) + + +class SeriesStringOpsTests(SeriesStringOpsMixin, PandasOnSparkTestCase, SQLTestUtils): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.series.test_string_ops_basic import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org