This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c506b2fb4190 [SPARK-52701][PS] Fix float32 type widening in `mod` with 
bool under ANSI
c506b2fb4190 is described below

commit c506b2fb4190baab80baff429264496de1ab38af
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Jul 8 17:17:41 2025 -0700

    [SPARK-52701][PS] Fix float32 type widening in `mod` with bool under ANSI
    
    ### What changes were proposed in this pull request?
    Fix float32 type widening in `mod` with bool under ANSI.
    
    ### Why are the changes needed?
    Ensure pandas on Spark works well with ANSI mode on.
    Part of https://issues.apache.org/jira/browse/SPARK-52700.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `mod` under ANSI works as pandas does.
    ```py
    (dev3.11) spark (mod_dtype) % SPARK_ANSI_SQL_MODE=False  ./python/run-tests 
--python-executables=python3.11 --testnames 
"pyspark.pandas.tests.data_type_ops.test_num_mod NumModTests.test_mod"
    ...
    Tests passed in 8 seconds
    
    (dev3.11) spark (mod_dtype) % SPARK_ANSI_SQL_MODE=True  ./python/run-tests 
--python-executables=python3.11 --testnames 
"pyspark.pandas.tests.data_type_ops.test_num_mod NumModTests.test_mod"
    ...
    Tests passed in 7 seconds
    
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51394 from xinrong-meng/mod_dtype.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Xinrong Meng <[email protected]>
---
 python/pyspark/pandas/data_type_ops/num_ops.py     | 38 ++++++++++++++++++----
 .../pandas/tests/data_type_ops/test_num_mod.py     |  1 -
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py 
b/python/pyspark/pandas/data_type_ops/num_ops.py
index b4a14b84e87d..78e5751568e8 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -23,6 +23,7 @@ import pandas as pd
 from pandas.api.types import (  # type: ignore[attr-defined]
     is_bool_dtype,
     is_integer_dtype,
+    is_float_dtype,
     CategoricalDtype,
     is_list_like,
 )
@@ -42,7 +43,7 @@ from pyspark.pandas.data_type_ops.base import (
     _is_valid_for_logical_operator,
     _is_boolean_type,
 )
-from pyspark.pandas.typedef.typehints import extension_dtypes, 
pandas_on_spark_type
+from pyspark.pandas.typedef.typehints import extension_dtypes, 
pandas_on_spark_type, as_spark_type
 from pyspark.pandas.utils import is_ansi_mode_enabled
 from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import (
@@ -69,6 +70,26 @@ def _non_fractional_astype(
         return _as_other_type(index_ops, dtype, spark_type)
 
 
+def _cast_back_float(
+    expr: PySparkColumn, left_dtype: Union[str, type, Dtype], right: Any
+) -> PySparkColumn:
+    """
+    Cast the result expression back to the original float dtype if needed.
+
+    This function ensures pandas on Spark matches pandas behavior when 
performing
+    arithmetic operations involving float and boolean values. In such cases, 
under ANSI mode,
+    Spark implicitly widens float32 to float64, which deviates from pandas 
behavior where the
+    result retains float32.
+    """
+    is_left_float = is_float_dtype(left_dtype)
+    is_right_bool = isinstance(right, bool) or (
+        hasattr(right, "dtype") and is_bool_dtype(right.dtype)
+    )
+    if is_left_float and is_right_bool:
+        return expr.cast(as_spark_type(left_dtype))
+    return expr
+
+
 class NumericOps(DataTypeOps):
     """The class for binary operations of numeric pandas-on-Spark objects."""
 
@@ -98,16 +119,19 @@ class NumericOps(DataTypeOps):
             raise TypeError("Modulo can not be applied to given types.")
         spark_session = left._internal.spark_frame.sparkSession
 
-        def mod(left: PySparkColumn, right: Any) -> PySparkColumn:
+        def mod(left_op: PySparkColumn, right_op: Any) -> PySparkColumn:
             if is_ansi_mode_enabled(spark_session):
-                return F.when(F.lit(right == 0), F.lit(None)).otherwise(
-                    ((left % right) + right) % right
+                expr = F.when(F.lit(right_op == 0), F.lit(None)).otherwise(
+                    ((left_op % right_op) + right_op) % right_op
                 )
+                expr = _cast_back_float(expr, left.dtype, right)
             else:
-                return ((left % right) + right) % right
+                expr = ((left_op % right_op) + right_op) % right_op
+            return expr
 
-        right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
-        return column_op(mod)(left, right)
+        new_right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
+
+        return column_op(mod)(left, new_right)
 
     def pow(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_mod.py 
b/python/pyspark/pandas/tests/data_type_ops/test_num_mod.py
index 5e4b6f46f433..7809b5edf036 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_num_mod.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_num_mod.py
@@ -35,7 +35,6 @@ class NumModTestsMixin:
     def float_psser(self):
         return ps.from_pandas(self.float_pser)
 
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_mod(self):
         pdf, psdf = self.pdf, self.psdf
         for col in self.numeric_df_cols:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to