This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new dc186c5e6b6 [SPARK-43773][CONNECT][PYTHON] Implement 'levenshtein(str1, str2[, threshold])' functions in python client dc186c5e6b6 is described below commit dc186c5e6b6bdb63345081ee9f70b8c102792cdd Author: panbingkun <pbk1...@gmail.com> AuthorDate: Sun May 28 08:38:32 2023 +0800 [SPARK-43773][CONNECT][PYTHON] Implement 'levenshtein(str1, str2[, threshold])' functions in python client ### What changes were proposed in this pull request? This PR aims to implement the 'levenshtein(str1, str2[, threshold])' functions in the Python client. ### Why are the changes needed? After adding a max distance argument to the levenshtein() function, we have already implemented it on the Scala side, so we need to align it on `pyspark`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Manual testing python/run-tests --testnames 'python.pyspark.sql.tests.test_functions FunctionsTests.test_levenshtein_function' - Pass GA Closes #41296 from panbingkun/SPARK-43773. 
Lead-authored-by: panbingkun <pbk1...@gmail.com> Co-authored-by: panbingkun <84731...@qq.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/connect/functions.py | 9 +++++++-- python/pyspark/sql/functions.py | 19 +++++++++++++++++-- .../sql/tests/connect/test_connect_function.py | 5 +++++ python/pyspark/sql/tests/test_functions.py | 7 +++++++ 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index b7d7bc937cf..d3a05d6a1c6 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -1878,8 +1878,13 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column: substring_index.__doc__ = pysparkfuncs.substring_index.__doc__ -def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: - return _invoke_function_over_columns("levenshtein", left, right) +def levenshtein( + left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None +) -> Column: + if threshold is None: + return _invoke_function_over_columns("levenshtein", left, right) + else: + return _invoke_function("levenshtein", _to_col(left), _to_col(right), lit(threshold)) levenshtein.__doc__ = pysparkfuncs.levenshtein.__doc__ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e9b71f7d617..fe35f12c402 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -6594,7 +6594,9 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column: @try_remote_functions -def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: +def levenshtein( + left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None +) -> Column: """Computes the Levenshtein distance of the two given strings. .. 
versionadded:: 1.5.0 @@ -6608,6 +6610,12 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column: first column value. right : :class:`~pyspark.sql.Column` or str second column value. + threshold : int, optional + if set, when the levenshtein distance of the two given strings + is less than or equal to the given threshold, return the result distance; otherwise return -1 + + .. versionchanged:: 3.5.0 + Added ``threshold`` argument. Returns ------- @@ -6619,8 +6627,15 @@ >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) >>> df0.select(levenshtein('l', 'r').alias('d')).collect() [Row(d=3)] + >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect() + [Row(d=-1)] """ - return _invoke_function_over_columns("levenshtein", left, right) + if threshold is None: + return _invoke_function_over_columns("levenshtein", left, right) + else: + return _invoke_function( + "levenshtein", _to_java_column(left), _to_java_column(right), threshold + ) @try_remote_functions diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index e274635d3c6..3e3b4dd5b16 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -1924,6 +1924,11 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S cdf.select(CF.levenshtein(cdf.b, cdf.c)).toPandas(), sdf.select(SF.levenshtein(sdf.b, sdf.c)).toPandas(), ) + self.assert_eq( + cdf.select(CF.levenshtein(cdf.b, cdf.c, 1)).toPandas(), + sdf.select(SF.levenshtein(sdf.b, sdf.c, 1)).toPandas(), + ) + self.assert_eq( cdf.select(CF.locate("e", cdf.b)).toPandas(), sdf.select(SF.locate("e", sdf.b)).toPandas(), diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 9067de34633..72c6c365b80 --- 
a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -377,6 +377,13 @@ class FunctionsTestsMixin: actual = df.select(F.array_contains(df.data, "1").alias("b")).collect() self.assertEqual([Row(b=True), Row(b=False)], actual) + def test_levenshtein_function(self): + df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"]) + actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")).collect() + self.assertEqual([Row(b=3)], actual_without_threshold) + actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")).collect() + self.assertEqual([Row(b=-1)], actual_with_threshold) + def test_between_function(self): df = self.spark.createDataFrame( [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org