This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ec751177664 [SPARK-54940][PYTHON] Add tests for pa.scalar type inference
6ec751177664 is described below

commit 6ec751177664b6628b73d8d07d309daf832c2e41
Author: Fangchen Li <[email protected]>
AuthorDate: Fri Jan 9 19:17:33 2026 +0800

    [SPARK-54940][PYTHON] Add tests for pa.scalar type inference
    
    ### What changes were proposed in this pull request?
    
    Add tests for pa.scalar type inference.
    
    ### Why are the changes needed?
    
    We want to monitor PyArrow's pa.scalar type inference behavior so that
    any change across PyArrow versions is caught early.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests.
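    
    The new module can also be run on its own with the PySpark test runner,
    e.g. (assuming the usual Spark Python dev setup):
    
        python/run-tests --testnames 'pyspark.tests.upstream.pyarrow.test_pyarrow_scalar_type_inference'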
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Opus 4.5
    
    Closes #53727 from fangchenli/pa-scalar-inference-tests.
    
    Authored-by: Fangchen Li <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 .../pyarrow/test_pyarrow_scalar_type_inference.py  | 492 +++++++++++++++++++++
 2 files changed, 493 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2b97f51d5c7f..4e956314c3d8 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -491,6 +491,7 @@ pyspark_core = Module(
         "pyspark.tests.test_stage_sched",
         # unittests for upstream projects
         "pyspark.tests.upstream.pyarrow.test_pyarrow_ignore_timezone",
+        "pyspark.tests.upstream.pyarrow.test_pyarrow_scalar_type_inference",
     ],
 )
 
diff --git a/python/pyspark/tests/upstream/pyarrow/test_pyarrow_scalar_type_inference.py b/python/pyspark/tests/upstream/pyarrow/test_pyarrow_scalar_type_inference.py
new file mode 100644
index 000000000000..b2a259616ee1
--- /dev/null
+++ b/python/pyspark/tests/upstream/pyarrow/test_pyarrow_scalar_type_inference.py
@@ -0,0 +1,492 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Tests for PyArrow pa.scalar type inference behavior.
+
+This module tests how PyArrow infers types when creating scalars from various
+Python objects. This helps ensure PySpark's assumptions about PyArrow behavior
+remain valid across versions.
+"""
+
+import datetime
+import math
+import unittest
+from decimal import Decimal
+from zoneinfo import ZoneInfo
+
+from pyspark.testing.utils import (
+    have_numpy,
+    have_pandas,
+    have_pyarrow,
+    numpy_requirement_message,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+class PyArrowScalarTypeInferenceTests(unittest.TestCase):
+    """Test pa.scalar type inference from various Python types."""
+
+    def test_simple_types(self):
+        """Test type inference for simple Python types."""
+        import pyarrow as pa
+
+        # (value, expected_type, check_roundtrip)
+        # check_roundtrip=False for values that need special comparison (e.g., NaN)
+        test_cases = [
+            # None
+            (None, pa.null(), False),
+            # Boolean
+            (True, pa.bool_(), True),
+            (False, pa.bool_(), True),
+            # Integer - PyArrow infers int64 for Python ints
+            (0, pa.int64(), True),
+            (1, pa.int64(), True),
+            (-1, pa.int64(), True),
+            (42, pa.int64(), True),
+            (2**62, pa.int64(), True),
+            (-(2**62), pa.int64(), True),
+            (2**63 - 1, pa.int64(), True),  # max int64
+            (-(2**63), pa.int64(), True),  # min int64
+            # Float - PyArrow infers float64 for Python floats
+            (0.0, pa.float64(), True),
+            (1.5, pa.float64(), True),
+            (-1.5, pa.float64(), True),
+            (3.14159, pa.float64(), True),
+            (float("inf"), pa.float64(), True),
+            (float("-inf"), pa.float64(), True),
+            (float("nan"), pa.float64(), False),  # NaN needs special 
comparison
+            # String - PyArrow infers string (utf8) for Python str
+            ("", pa.string(), True),
+            ("hello", pa.string(), True),
+            ("日本語", pa.string(), True),
+            ("emoji: 🎉", pa.string(), True),
+            # Bytes - PyArrow infers binary for Python bytes
+            (b"", pa.binary(), True),
+            (b"hello", pa.binary(), True),
+            (b"\x00\x01\x02", pa.binary(), True),
+            # Bytearray and memoryview infer as binary
+            (bytearray(b"hello"), pa.binary(), False),  # as_py returns bytes
+            (memoryview(b"hello"), pa.binary(), False),  # as_py returns bytes
+        ]
+
+        for value, expected_type, check_roundtrip in test_cases:
+            scalar = pa.scalar(value)
+            self.assertEqual(
+                scalar.type,
+                expected_type,
+                f"Type mismatch for {type(value).__name__}({value!r}): "
+                f"expected {expected_type}, got {scalar.type}",
+            )
+            if check_roundtrip:
+                self.assertEqual(
+                    scalar.as_py(),
+                    value,
+                    f"Roundtrip failed for {type(value).__name__}({value!r})",
+                )
+
+        # Special case: NaN comparison
+        scalar = pa.scalar(float("nan"))
+        self.assertTrue(math.isnan(scalar.as_py()), "NaN roundtrip failed")
+
+        # Special case: None is_valid check
+        scalar = pa.scalar(None)
+        self.assertFalse(scalar.is_valid, "None should create invalid scalar")
+
+    def test_integer_overflow(self):
+        """Test that integers outside int64 range raise OverflowError."""
+        import pyarrow as pa
+
+        for value in [2**63, -(2**63) - 1]:
+            with self.assertRaises(OverflowError):
+                pa.scalar(value)
+
+    def test_decimal_types(self):
+        """Test Decimal type inference with precision and scale."""
+        import pyarrow as pa
+
+        # (value, expected_scale, expected_precision)
+        test_cases = [
+            (Decimal("123.45"), 2, 5),
+            (Decimal("12345678901234567890.123456789"), 9, 29),
+            (Decimal("-999.999"), 3, 6),
+            (Decimal("0"), 0, 1),
+            (Decimal("0.00"), 2, 2),
+        ]
+
+        for value, expected_scale, expected_precision in test_cases:
+            scalar = pa.scalar(value)
+            self.assertTrue(
+                pa.types.is_decimal(scalar.type),
+                f"Expected decimal type for {value}, got {scalar.type}",
+            )
+            self.assertEqual(
+                scalar.type.scale,
+                expected_scale,
+                f"Scale mismatch for {value}: expected {expected_scale}, got 
{scalar.type.scale}",
+            )
+            self.assertEqual(
+                scalar.type.precision,
+                expected_precision,
+                f"Precision mismatch for {value}: "
+                f"expected {expected_precision}, got {scalar.type.precision}",
+            )
+            self.assertEqual(
+                str(scalar.as_py()),
+                str(value),
+                f"Roundtrip failed for Decimal {value}",
+            )
+
+    def test_date_time_types(self):
+        """Test date, time, datetime, and timedelta type inference."""
+        import pyarrow as pa
+
+        # Date - infers date32
+        date_cases = [
+            datetime.date(2024, 1, 15),
+            datetime.date(1970, 1, 1),  # Unix epoch
+            datetime.date(1, 1, 1),  # Minimum
+            datetime.date(9999, 12, 31),  # Maximum
+        ]
+        for value in date_cases:
+            scalar = pa.scalar(value)
+            self.assertEqual(
+                scalar.type,
+                pa.date32(),
+                f"Type mismatch for date {value}: expected date32, got 
{scalar.type}",
+            )
+            self.assertEqual(scalar.as_py(), value, f"Roundtrip failed for date {value}")
+
+        # Time - infers time64[us]
+        time_cases = [
+            datetime.time(12, 30, 45, 123456),
+            datetime.time(12, 30, 45),
+            datetime.time(0, 0, 0),
+        ]
+        for value in time_cases:
+            scalar = pa.scalar(value)
+            self.assertEqual(
+                scalar.type,
+                pa.time64("us"),
+                f"Type mismatch for time {value}: expected time64[us], got 
{scalar.type}",
+            )
+            self.assertEqual(scalar.as_py(), value, f"Roundtrip failed for time {value}")
+
+        # Timedelta - infers duration[us]
+        timedelta_cases = [
+            datetime.timedelta(days=5, hours=3, minutes=30, seconds=15, microseconds=123456),
+            datetime.timedelta(0),
+            datetime.timedelta(days=-1),
+        ]
+        for value in timedelta_cases:
+            scalar = pa.scalar(value)
+            self.assertEqual(
+                scalar.type,
+                pa.duration("us"),
+                f"Type mismatch for timedelta {value}: expected duration[us], 
got {scalar.type}",
+            )
+            self.assertEqual(scalar.as_py(), value, f"Roundtrip failed for timedelta {value}")
+
+    def test_datetime_timezone(self):
+        """Test datetime type inference with and without timezone."""
+        import pyarrow as pa
+
+        # Timezone-naive datetime -> timestamp[us]
+        dt_naive = datetime.datetime(2024, 1, 15, 12, 30, 45, 123456)
+        scalar = pa.scalar(dt_naive)
+        self.assertEqual(scalar.type, pa.timestamp("us"))
+        self.assertIsNone(scalar.type.tz)
+        self.assertEqual(scalar.as_py(), dt_naive)
+
+        # Timezone-aware datetime -> timestamp[us, tz=...]
+        tz_cases = [
+            ("America/New_York", ZoneInfo("America/New_York")),
+            ("UTC", ZoneInfo("UTC")),
+        ]
+        for expected_tz, tzinfo in tz_cases:
+            dt_aware = datetime.datetime(2024, 1, 15, 12, 30, 45, 123456, tzinfo=tzinfo)
+            scalar = pa.scalar(dt_aware)
+            self.assertEqual(
+                scalar.type.unit,
+                "us",
+                f"Unit mismatch for tz={expected_tz}",
+            )
+            self.assertEqual(
+                scalar.type.tz,
+                expected_tz,
+                f"Timezone mismatch: expected {expected_tz}, got 
{scalar.type.tz}",
+            )
+            self.assertEqual(
+                scalar.as_py().timestamp(),
+                dt_aware.timestamp(),
+                f"Timestamp mismatch for tz={expected_tz}",
+            )
+
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)
+    def test_pandas_types(self):
+        """Test Pandas Timestamp and Timedelta type inference."""
+        import pandas as pd
+        import pyarrow as pa
+
+        # Timezone-naive Timestamp
+        ts = pd.Timestamp("2024-01-15 12:30:45.123456")
+        scalar = pa.scalar(ts)
+        self.assertEqual(scalar.type.unit, "us")
+        self.assertIsNone(scalar.type.tz)
+        self.assertEqual(scalar.as_py(), ts.to_pydatetime())
+
+        # Timezone-aware Timestamp
+        ts_tz = pd.Timestamp("2024-01-15 12:30:45.123456", tz="America/New_York")
+        scalar = pa.scalar(ts_tz)
+        self.assertEqual(scalar.type.unit, "us")
+        self.assertEqual(scalar.type.tz, "America/New_York")
+        self.assertEqual(scalar.as_py().timestamp(), ts_tz.timestamp())
+
+        # Timedelta
+        td = pd.Timedelta(days=5, hours=3, minutes=30, seconds=15, microseconds=123456)
+        scalar = pa.scalar(td)
+        self.assertEqual(scalar.type, pa.duration("us"))
+        self.assertEqual(scalar.as_py(), td.to_pytimedelta())
+
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)
+    def test_pandas_nat_and_na(self):
+        """Test that pd.NaT and pd.NA raise errors in pa.scalar."""
+        import pandas as pd
+        import pyarrow as pa
+
+        # pd.NaT raises ValueError
+        with self.assertRaises(ValueError):
+            pa.scalar(pd.NaT)
+
+        # pd.NA raises ArrowInvalid
+        with self.assertRaises(pa.ArrowInvalid):
+            pa.scalar(pd.NA)
+
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_numpy_scalar_types(self):
+        """Test NumPy scalar type inference."""
+        import numpy as np
+        import pyarrow as pa
+
+        # (numpy_value, expected_pyarrow_type)
+        test_cases = [
+            # Integer types
+            (np.int8(42), pa.int8()),
+            (np.int16(42), pa.int16()),
+            (np.int32(42), pa.int32()),
+            (np.int64(42), pa.int64()),
+            (np.uint8(42), pa.uint8()),
+            (np.uint16(42), pa.uint16()),
+            (np.uint32(42), pa.uint32()),
+            (np.uint64(42), pa.uint64()),
+            (np.uint64(2**64 - 1), pa.uint64()),  # max uint64
+            # Float types
+            (np.float16(1.5), pa.float16()),
+            (np.float32(1.5), pa.float32()),
+            (np.float64(1.5), pa.float64()),
+            # Boolean
+            (np.bool_(True), pa.bool_()),
+            (np.bool_(False), pa.bool_()),
+            # String and bytes
+            (np.str_("hello"), pa.string()),
+            (np.bytes_(b"hello"), pa.binary()),
+        ]
+
+        for np_val, expected_type in test_cases:
+            scalar = pa.scalar(np_val)
+            self.assertEqual(
+                scalar.type,
+                expected_type,
+                f"Type mismatch for {type(np_val).__name__}: "
+                f"expected {expected_type}, got {scalar.type}",
+            )
+
+        # Float NaN - needs special comparison
+        nan_cases = [
+            (np.float32("nan"), pa.float32()),
+            (np.float64("nan"), pa.float64()),
+        ]
+        for np_val, expected_type in nan_cases:
+            scalar = pa.scalar(np_val)
+            self.assertEqual(
+                scalar.type,
+                expected_type,
+                f"Type mismatch for {type(np_val).__name__} NaN: "
+                f"expected {expected_type}, got {scalar.type}",
+            )
+            self.assertTrue(
+                math.isnan(scalar.as_py()),
+                f"NaN roundtrip failed for {type(np_val).__name__}",
+            )
+
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_numpy_datetime64(self):
+        """Test NumPy datetime64 type inference with different units."""
+        import numpy as np
+        import pyarrow as pa
+
+        # Time-based units work directly
+        unit_cases = [
+            (np.datetime64("2024-01-15T12:30", "s"), "s"),
+            (np.datetime64("2024-01-15T12:30:45", "ms"), "ms"),
+            (np.datetime64("2024-01-15T12:30:45.123456", "us"), "us"),
+            (np.datetime64("2024-01-15T12:30:45.123456789", "ns"), "ns"),
+        ]
+        for np_val, expected_unit in unit_cases:
+            scalar = pa.scalar(np_val)
+            self.assertTrue(pa.types.is_timestamp(scalar.type))
+            self.assertEqual(
+                scalar.type.unit,
+                expected_unit,
+                f"Unit mismatch: expected {expected_unit}, got 
{scalar.type.unit}",
+            )
+
+        # Date-based units (D) raise TypeError
+        with self.assertRaises(TypeError):
+            pa.scalar(np.datetime64("2024-01-15", "D"))
+
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_numpy_timedelta64(self):
+        """Test NumPy timedelta64 type inference with different units."""
+        import numpy as np
+        import pyarrow as pa
+
+        # Time-based units work directly
+        unit_cases = [
+            (np.timedelta64(3600, "s"), "s"),
+            (np.timedelta64(1000, "ms"), "ms"),
+            (np.timedelta64(1000000, "us"), "us"),
+            (np.timedelta64(1000000000, "ns"), "ns"),
+        ]
+        for np_val, expected_unit in unit_cases:
+            scalar = pa.scalar(np_val)
+            self.assertTrue(pa.types.is_duration(scalar.type))
+            self.assertEqual(
+                scalar.type.unit,
+                expected_unit,
+                f"Unit mismatch: expected {expected_unit}, got 
{scalar.type.unit}",
+            )
+
+        # Day unit (D) raises ArrowNotImplementedError
+        with self.assertRaises(pa.ArrowNotImplementedError):
+            pa.scalar(np.timedelta64(5, "D"))
+
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_numpy_nat(self):
+        """Test NumPy NaT handling - generic NaT raises, NaT with unit 
works."""
+        import numpy as np
+        import pyarrow as pa
+
+        # Generic NaT without unit raises ArrowNotImplementedError
+        with self.assertRaises(pa.ArrowNotImplementedError):
+            pa.scalar(np.datetime64("NaT"))
+
+        with self.assertRaises(pa.ArrowNotImplementedError):
+            pa.scalar(np.timedelta64("NaT"))
+
+        # NaT with explicit unit creates null scalar
+        scalar = pa.scalar(np.datetime64("NaT", "ns"))
+        self.assertTrue(pa.types.is_timestamp(scalar.type))
+        self.assertFalse(scalar.is_valid)
+
+        scalar = pa.scalar(np.timedelta64("NaT", "ns"))
+        self.assertTrue(pa.types.is_duration(scalar.type))
+        self.assertFalse(scalar.is_valid)
+
+    def test_list_types(self):
+        """Test list type inference."""
+        import pyarrow as pa
+
+        # Homogeneous lists
+        list_cases = [
+            ([1, 2, 3], pa.int64()),
+            ([1.0, 2.0, 3.0], pa.float64()),
+            (["a", "b", "c"], pa.string()),
+        ]
+        for value, expected_element_type in list_cases:
+            scalar = pa.scalar(value)
+            self.assertTrue(pa.types.is_list(scalar.type))
+            self.assertEqual(
+                scalar.type.value_type,
+                expected_element_type,
+                f"Element type mismatch for {value}",
+            )
+            self.assertEqual(scalar.as_py(), value)
+
+        # Empty list infers as list<null>
+        scalar = pa.scalar([])
+        self.assertTrue(pa.types.is_list(scalar.type))
+        self.assertEqual(scalar.type.value_type, pa.null())
+
+        # Nested list
+        scalar = pa.scalar([[1, 2], [3, 4]])
+        self.assertTrue(pa.types.is_list(scalar.type))
+        self.assertTrue(pa.types.is_list(scalar.type.value_type))
+
+        # List with None elements
+        scalar = pa.scalar([1, None, 3])
+        self.assertTrue(pa.types.is_list(scalar.type))
+        self.assertEqual(scalar.type.value_type, pa.int64())
+        self.assertEqual(scalar.as_py(), [1, None, 3])
+
+        # Mixed int and float promotes to float64
+        scalar = pa.scalar([1, 2.0])
+        self.assertTrue(pa.types.is_list(scalar.type))
+        self.assertEqual(scalar.type.value_type, pa.float64())
+
+        # Mixed incompatible types raise error
+        with self.assertRaises(pa.ArrowInvalid):
+            pa.scalar([1, "a"])
+
+    def test_dict_types(self):
+        """Test dict/struct type inference."""
+        import pyarrow as pa
+
+        # Simple dict creates struct
+        scalar = pa.scalar({"a": 1, "b": 2})
+        self.assertTrue(pa.types.is_struct(scalar.type))
+        result = scalar.as_py()
+        self.assertEqual(result["a"], 1)
+        self.assertEqual(result["b"], 2)
+
+        # Dict with mixed value types
+        scalar = pa.scalar({"int_val": 42, "str_val": "hello", "float_val": 
3.14})
+        self.assertTrue(pa.types.is_struct(scalar.type))
+
+        # Dict with None value
+        scalar = pa.scalar({"a": None, "b": 1})
+        self.assertTrue(pa.types.is_struct(scalar.type))
+        result = scalar.as_py()
+        self.assertIsNone(result["a"])
+        self.assertEqual(result["b"], 1)
+
+    def test_tuple_types(self):
+        """Test tuple type inference - tuples are treated as lists."""
+        import pyarrow as pa
+
+        scalar = pa.scalar((1, 2, 3))
+        self.assertTrue(pa.types.is_list(scalar.type))
+        self.assertEqual(scalar.type.value_type, pa.int64())
+        self.assertEqual(scalar.as_py(), [1, 2, 3])
+
+
+if __name__ == "__main__":
+    from pyspark.testing import main
+
+    main()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
