This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d304a0313ff1 [SPARK-54166][GEO][PYTHON] Introduce type encoders for 
geospatial types in PySpark
d304a0313ff1 is described below

commit d304a0313ff14948c24506500f661f415f16f9bc
Author: Uros Bojanic <[email protected]>
AuthorDate: Sat Nov 8 07:35:47 2025 -0800

    [SPARK-54166][GEO][PYTHON] Introduce type encoders for geospatial types in 
PySpark
    
    ### What changes were proposed in this pull request?
    This PR introduces type encoders for `Geography` and `Geometry` in PySpark.
    
    Note that Scala-side encoders for geospatial types were added in: 
https://github.com/apache/spark/pull/52813.
    
    ### Why are the changes needed?
    Expanding client support for geospatial types in the PySpark API.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    New PySpark unit tests:
    - `test_types`
    - `test_parity_types`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52861 from uros-db/geo-pyspark-encoders.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/__init__.py                     |   4 +-
 python/pyspark/sql/pandas/types.py                 | 120 ++++++++
 .../pyspark/sql/tests/connect/test_parity_types.py |   4 +
 python/pyspark/sql/tests/test_types.py             | 322 +++++++++++++++++++++
 python/pyspark/sql/types.py                        |  32 ++
 .../sql/execution/python/EvaluatePython.scala      |  27 +-
 6 files changed, 505 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index a0a6e8ef70c8..eeeeddd00e3a 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -39,7 +39,7 @@ Important classes of Spark SQL and DataFrames:
     - :class:`pyspark.sql.Window`
       For working with window functions.
 """
-from pyspark.sql.types import Row, VariantVal
+from pyspark.sql.types import Geography, Geometry, Row, VariantVal
 from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, 
UDTFRegistration
 from pyspark.sql.session import SparkSession
 from pyspark.sql.column import Column
@@ -69,6 +69,8 @@ __all__ = [
     "DataFrameNaFunctions",
     "DataFrameStatFunctions",
     "VariantVal",
+    "Geography",
+    "Geometry",
     "Window",
     "WindowSpec",
     "DataFrameReader",
diff --git a/python/pyspark/sql/pandas/types.py 
b/python/pyspark/sql/pandas/types.py
index 327e3941d938..d8a45daa77e8 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -50,6 +50,10 @@ from pyspark.sql.types import (
     UserDefinedType,
     VariantType,
     VariantVal,
+    GeometryType,
+    Geometry,
+    GeographyType,
+    Geography,
     _create_row,
 )
 from pyspark.errors import PySparkTypeError, UnsupportedOperationException, 
PySparkValueError
@@ -202,6 +206,28 @@ def to_arrow_type(
             pa.field("metadata", pa.binary(), nullable=False, 
metadata={b"variant": b"true"}),
         ]
         arrow_type = pa.struct(fields)
+    elif type(dt) == GeometryType:
+        fields = [
+            pa.field("srid", pa.int32(), nullable=False),
+            pa.field(
+                "wkb",
+                pa.binary(),
+                nullable=False,
+                metadata={b"geometry": b"true", b"srid": str(dt.srid)},
+            ),
+        ]
+        arrow_type = pa.struct(fields)
+    elif type(dt) == GeographyType:
+        fields = [
+            pa.field("srid", pa.int32(), nullable=False),
+            pa.field(
+                "wkb",
+                pa.binary(),
+                nullable=False,
+                metadata={b"geography": b"true", b"srid": str(dt.srid)},
+            ),
+        ]
+        arrow_type = pa.struct(fields)
     else:
         raise PySparkTypeError(
             errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -272,6 +298,38 @@ def is_variant(at: "pa.DataType") -> bool:
     ) and any(field.name == "value" for field in at)
 
 
+def is_geometry(at: "pa.DataType") -> bool:
+    """Check if a PyArrow struct data type represents a geometry"""
+    import pyarrow.types as types
+
+    assert types.is_struct(at)
+
+    return any(
+        (
+            field.name == "wkb"
+            and b"geometry" in field.metadata
+            and field.metadata[b"geometry"] == b"true"
+        )
+        for field in at
+    ) and any(field.name == "srid" for field in at)
+
+
+def is_geography(at: "pa.DataType") -> bool:
+    """Check if a PyArrow struct data type represents a geography"""
+    import pyarrow.types as types
+
+    assert types.is_struct(at)
+
+    return any(
+        (
+            field.name == "wkb"
+            and b"geography" in field.metadata
+            and field.metadata[b"geography"] == b"true"
+        )
+        for field in at
+    ) and any(field.name == "srid" for field in at)
+
+
 def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> 
DataType:
     """Convert pyarrow type to Spark data type."""
     import pyarrow.types as types
@@ -337,6 +395,18 @@ def from_arrow_type(at: "pa.DataType", 
prefer_timestamp_ntz: bool = False) -> Da
     elif types.is_struct(at):
         if is_variant(at):
             return VariantType()
+        elif is_geometry(at):
+            srid = int(at.field("wkb").metadata.get(b"srid"))
+            if srid == GeometryType.MIXED_SRID:
+                return GeometryType("ANY")
+            else:
+                return GeometryType(srid)
+        elif is_geography(at):
+            srid = int(at.field("wkb").metadata.get(b"srid"))
+            if srid == GeographyType.MIXED_SRID:
+                return GeographyType("ANY")
+            else:
+                return GeographyType(srid)
         return StructType(
             [
                 StructField(
@@ -1098,6 +1168,40 @@ def _create_converter_to_pandas(
 
             return convert_variant
 
+        elif isinstance(dt, GeographyType):
+
+            def convert_geography(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["wkb", "srid"])
+                    and isinstance(value["wkb"], bytes)
+                    and isinstance(value["srid"], int)
+                ):
+                    return Geography.fromWKB(value["wkb"], value["srid"])
+                else:
+                    raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+            return convert_geography
+
+        elif isinstance(dt, GeometryType):
+
+            def convert_geometry(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif (
+                    isinstance(value, dict)
+                    and all(key in value for key in ["wkb", "srid"])
+                    and isinstance(value["wkb"], bytes)
+                    and isinstance(value["srid"], int)
+                ):
+                    return Geometry.fromWKB(value["wkb"], value["srid"])
+                else:
+                    raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+            return convert_geometry
+
         else:
             return None
 
@@ -1360,6 +1464,22 @@ def _create_converter_from_pandas(
 
             return convert_variant
 
+        elif isinstance(dt, GeographyType):
+
+            def convert_geography(value: Any) -> Any:
+                assert isinstance(value, Geography)
+                return {"srid": value.srid, "wkb": value.wkb}
+
+            return convert_geography
+
+        elif isinstance(dt, GeometryType):
+
+            def convert_geometry(value: Any) -> Any:
+                assert isinstance(value, Geometry)
+                return {"srid": value.srid, "wkb": value.wkb}
+
+            return convert_geometry
+
         return None
 
     conv = _converter(data_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py 
b/python/pyspark/sql/tests/connect/test_parity_types.py
index 6d06611def6a..a39e92233bc0 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -34,6 +34,10 @@ class TypesParityTests(TypesTestsMixin, 
ReusedConnectTestCase):
     def test_apply_schema_to_row(self):
         super().test_apply_schema_to_row()
 
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on 
them.")
+    def test_geospatial_create_dataframe_rdd(self):
+        super().test_geospatial_create_dataframe_rdd()
+
     @unittest.skip("Spark Connect does not support RDD but the tests depend on 
them.")
     def test_create_dataframe_schema_mismatch(self):
         super().test_create_dataframe_schema_mismatch()
diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 6979095acca8..4ff2ab3e5cd7 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -29,6 +29,8 @@ from pyspark.sql import Row
 from pyspark.sql import functions as F
 from pyspark.errors import (
     AnalysisException,
+    IllegalArgumentException,
+    SparkRuntimeException,
     ParseException,
     PySparkTypeError,
     PySparkValueError,
@@ -51,6 +53,8 @@ from pyspark.sql.types import (
     MapType,
     StringType,
     CharType,
+    Geography,
+    Geometry,
     VarcharType,
     StructType,
     StructField,
@@ -1365,6 +1369,7 @@ class TypesTestsMixin:
             NullType(),
             GeographyType(4326),
             GeographyType("ANY"),
+            GeometryType(0),
             GeometryType(4326),
             GeometryType("ANY"),
             VariantType(),
@@ -2447,6 +2452,323 @@ class TypesTestsMixin:
         with self.assertRaises(PySparkValueError, msg="Rows cannot be of type 
VariantVal"):
             self.spark.createDataFrame([VariantVal.parseJson("2")], "v 
variant")
 
+    def test_geospatial_encoding(self):
+        df = self.spark.createDataFrame(
+            [
+                (
+                    
bytes.fromhex("0101000000000000000000F03F0000000000000040"),
+                    4326,
+                )
+            ],
+            ["wkb", "srid"],
+        )
+        row = df.select(
+            F.st_geomfromwkb(df.wkb).alias("geom"),
+            F.st_geogfromwkb(df.wkb).alias("geog"),
+        ).collect()[0]
+
+        self.assertEqual(type(row["geom"]), Geometry)
+        self.assertEqual(type(row["geog"]), Geography)
+        self.assertEqual(
+            row["geom"].getBytes(), 
bytes.fromhex("0101000000000000000000F03F0000000000000040")
+        )
+        self.assertEqual(row["geom"].getSrid(), 0)
+        self.assertEqual(
+            row["geog"].getBytes(), 
bytes.fromhex("0101000000000000000000F03F0000000000000040")
+        )
+        self.assertEqual(row["geog"].getSrid(), 4326)
+
+    def test_geospatial_create_dataframe_rdd(self):
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField("geom", GeometryType(0), True),
+                StructField("geom4326", GeometryType(4326), True),
+                StructField("geog", GeographyType(4326), True),
+            ]
+        )
+        geospatial_data = [
+            (
+                1,
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
0),
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
4326),
+                Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+                ),
+            ),
+            (
+                2,
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 
0),
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 
4326),
+                Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+                ),
+            ),
+        ]
+        rdd_data = self.sc.parallelize(geospatial_data)
+        df = self.spark.createDataFrame(rdd_data, schema)
+        rows = df.select(
+            F.st_asbinary(df.geom).alias("geom_wkb"),
+            F.st_srid(df.geom).alias("geom_srid"),
+            F.st_asbinary(df.geom4326).alias("geom4326_wkb"),
+            F.st_srid(df.geom4326).alias("geom4326_srid"),
+            F.st_asbinary(df.geog).alias("geog_wkb"),
+            F.st_srid(df.geog).alias("geog_srid"),
+        ).collect()
+
+        point0_wkb = 
bytes.fromhex("010100000000000000000031400000000000001c40")
+        point1_wkb = 
bytes.fromhex("010100000000000000000014400000000000001440")
+        self.assertEqual(rows[0]["geom_wkb"], point0_wkb)
+        self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb)
+        self.assertEqual(rows[0]["geog_wkb"], point0_wkb)
+        self.assertEqual(rows[1]["geom_wkb"], point1_wkb)
+        self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb)
+        self.assertEqual(rows[1]["geog_wkb"], point1_wkb)
+        self.assertEqual(rows[0]["geom_srid"], 0)
+        self.assertEqual(rows[0]["geom4326_srid"], 4326)
+        self.assertEqual(rows[0]["geog_srid"], 4326)
+        self.assertEqual(rows[1]["geom_srid"], 0)
+        self.assertEqual(rows[1]["geom4326_srid"], 4326)
+        self.assertEqual(rows[1]["geog_srid"], 4326)
+        schema_df = self.spark.createDataFrame(rdd_data).select(
+            F.col("_1").alias("id"),
+            F.col("_2").alias("geom"),
+            F.col("_3").alias("geom4326"),
+            F.col("_4").alias("geog"),
+        )
+        self.assertEqual(df.collect(), schema_df.collect())
+
+    def test_geospatial_create_dataframe(self):
+        # Positive Test: Creating DataFrame from a list of tuples with 
explicit schema
+        geospatial_data = [
+            (
+                1,
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
0),
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
4326),
+                Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+                ),
+            ),
+            (
+                2,
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 
0),
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 
4326),
+                Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+                ),
+            ),
+        ]
+        named_geospatial_data = [
+            {
+                "id": 1,
+                "geom": Geometry.fromWKB(
+                    
bytes.fromhex("010100000000000000000031400000000000001c40"), 0
+                ),
+                "geom4326": Geometry.fromWKB(
+                    
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+                ),
+                "geog": Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+                ),
+            },
+            {
+                "id": 2,
+                "geom": Geometry.fromWKB(
+                    
bytes.fromhex("010100000000000000000014400000000000001440"), 0
+                ),
+                "geom4326": Geometry.fromWKB(
+                    
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+                ),
+                "geog": Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+                ),
+            },
+        ]
+        GeospatialRow = Row("id", "geom", "geom4326", "geog")
+        spark_row_data = [
+            GeospatialRow(
+                1,
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
0),
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
4326),
+                Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+                ),
+            ),
+            GeospatialRow(
+                2,
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 
0),
+                
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 
4326),
+                Geography.fromWKB(
+                    
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+                ),
+            ),
+        ]
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField("geom", GeometryType(0), True),
+                StructField("geom4326", GeometryType(4326), True),
+                StructField("geog", GeographyType(4326), True),
+            ]
+        )
+        # Negative Test: Schema mismatch
+        mismatched_schema = StructType(
+            [
+                StructField("id", IntegerType(), True),  # Should be 
GeometryType
+                StructField("geom", GeometryType(4326), True),  # Should be 
GeometryType
+                StructField("geom4326", GeometryType(4326), True),  # Should 
be GeometryType
+                StructField("geog", GeographyType(4326), True),  # Should be 
GeographyType
+            ]
+        )
+
+        # Explicitly validate single test case
+        # rest will be compared with this one.
+        df = self.spark.createDataFrame(geospatial_data, schema)
+        rows = df.select(
+            F.st_asbinary(df.geom).alias("geom_wkb"),
+            F.st_srid(df.geom).alias("geom_srid"),
+            F.st_asbinary(df.geom4326).alias("geom4326_wkb"),
+            F.st_srid(df.geom4326).alias("geom4326_srid"),
+            F.st_asbinary(df.geog).alias("geog_wkb"),
+            F.st_srid(df.geog).alias("geog_srid"),
+        ).collect()
+
+        point0_wkb = 
bytes.fromhex("010100000000000000000031400000000000001c40")
+        point1_wkb = 
bytes.fromhex("010100000000000000000014400000000000001440")
+        self.assertEqual(rows[0]["geom_wkb"], point0_wkb)
+        self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb)
+        self.assertEqual(rows[0]["geog_wkb"], point0_wkb)
+        self.assertEqual(rows[1]["geom_wkb"], point1_wkb)
+        self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb)
+        self.assertEqual(rows[1]["geog_wkb"], point1_wkb)
+        self.assertEqual(rows[0]["geom_srid"], 0)
+        self.assertEqual(rows[0]["geom4326_srid"], 4326)
+        self.assertEqual(rows[0]["geog_srid"], 4326)
+        self.assertEqual(rows[1]["geom_srid"], 0)
+        self.assertEqual(rows[1]["geom4326_srid"], 4326)
+        self.assertEqual(rows[1]["geog_srid"], 4326)
+
+        # Just the data set without parameters.
+        self.assertEqual(
+            self.spark.createDataFrame(named_geospatial_data)
+            .select("id", "geom", "geom4326", "geog")
+            .collect(),
+            df.collect(),
+        )
+        
self.assertEqual(self.spark.createDataFrame(geospatial_data).collect(), 
df.collect())
+        self.assertEqual(self.spark.createDataFrame(spark_row_data).collect(), 
df.collect())
+
+        # Define DataFrame creation methods
+        datasets = [named_geospatial_data, geospatial_data, spark_row_data]
+        schemas = [
+            schema,
+            "id INT, geom GEOMETRY(0), geom4326 GEOMETRY(4326), geog 
GEOGRAPHY(4326)",
+            ["id", "geom", "geom4326", "geog"],
+        ]
+
+        for dataset_to_check, schema_to_check in zip(datasets, schemas):
+            df_to_check = self.spark.createDataFrame(dataset_to_check, 
schema_to_check).select(
+                "id", "geom", "geom4326", "geog"
+            )
+            self.assertEqual(df_to_check.collect(), df.collect(), "DataFrame 
creation with schema")
+
+        # Negative Test: Schema mismatch
+        for dataset_to_check in datasets:
+            with self.assertRaises(SparkRuntimeException) as pe:
+                self.spark.createDataFrame(dataset_to_check, 
mismatched_schema).collect()
+
+            self.check_error(
+                exception=pe.exception,
+                errorClass="GEO_ENCODER_SRID_MISMATCH_ERROR",
+                messageParameters={"type": "GEOMETRY", "typeSrid": "4326", 
"valueSrid": "0"},
+            )
+
+    def test_geospatial_schema_inferrence(self):
+        # Mixed data with different SRIDs
+        wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+        geometry_dataset = [
+            (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), 
Geometry.fromWKB(wkb, 4326)),
+            (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), 
Geometry.fromWKB(wkb, 0)),
+            (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), 
Geometry.fromWKB(wkb, 4326)),
+            (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), 
Geometry.fromWKB(wkb, 0)),
+        ]
+        # Create DataFrame with mixed data types
+        df = self.spark.createDataFrame(geometry_dataset, ["geom0", 
"geom4326", "geom_any"])
+        expected_schema = StructType(
+            [
+                StructField("geom0", GeometryType(0), True),
+                StructField("geom4326", GeometryType(4326), True),
+                StructField("geom_any", GeometryType("ANY"), True),
+            ]
+        )
+        self.assertEqual(df.schema, expected_schema)
+
+        rows = df.select(
+            F.st_asbinary("geom0").alias("geom0_wkb"),
+            F.st_srid("geom0").alias("geom0_srid"),
+            F.st_asbinary("geom4326").alias("geom4326_wkb"),
+            F.st_srid("geom4326").alias("geom4326_srid"),
+            F.st_asbinary("geom_any").alias("geom_any_wkb"),
+            F.st_srid("geom_any").alias("geom_any_srid"),
+        ).collect()
+
+        point_wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+        self.assertEqual(rows[0]["geom0_wkb"], point_wkb)
+        self.assertEqual(rows[1]["geom0_wkb"], point_wkb)
+        self.assertEqual(rows[0]["geom4326_wkb"], point_wkb)
+        self.assertEqual(rows[1]["geom4326_wkb"], point_wkb)
+        self.assertEqual(rows[0]["geom_any_wkb"], point_wkb)
+        self.assertEqual(rows[1]["geom_any_wkb"], point_wkb)
+        self.assertEqual(rows[0]["geom0_srid"], 0)
+        self.assertEqual(rows[1]["geom0_srid"], 0)
+        self.assertEqual(rows[0]["geom4326_srid"], 4326)
+        self.assertEqual(rows[1]["geom4326_srid"], 4326)
+        self.assertEqual(rows[0]["geom_any_srid"], 4326)
+        self.assertEqual(rows[1]["geom_any_srid"], 0)
+
+    def test_geospatial_mixed_check_srid_validity(self):
+        geometry_mixed_invalid_data = [
+            (1, 
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
0)),
+            (2, 
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 
1)),
+        ]
+
+        with self.assertRaises(IllegalArgumentException) as pe:
+            self.spark.createDataFrame(geometry_mixed_invalid_data).collect()
+        self.check_error(
+            exception=pe.exception,
+            errorClass="ST_INVALID_SRID_VALUE",
+            messageParameters={"srid": "1"},
+        )
+
+        with self.assertRaises(IllegalArgumentException) as pe:
+            self.spark.createDataFrame(
+                geometry_mixed_invalid_data, "id INT, geom GEOMETRY(ANY)"
+            ).collect()
+        self.check_error(
+            exception=pe.exception,
+            errorClass="ST_INVALID_SRID_VALUE",
+            messageParameters={"srid": "1"},
+        )
+
+    def test_geospatial_result_encoding(self):
+        point_wkb = "010100000000000000000031400000000000001c40"
+        point_bytes = bytes.fromhex(point_wkb)
+        df = self.spark.sql(
+            f"""
+            SELECT ST_GeomFromWKB(X'{point_wkb}') AS geom,
+            ST_GeogFromWKB(X'{point_wkb}') AS geog"""
+        )
+        GeospatialRow = Row("geom", "geog")
+        self.assertEqual(
+            df.collect(),
+            [
+                GeospatialRow(
+                    Geometry.fromWKB(point_bytes, 0),
+                    Geography.fromWKB(point_bytes, 4326),
+                )
+            ],
+        )
+
     def test_to_ddl(self):
         schema = StructType().add("a", NullType()).add("b", 
BooleanType()).add("c", BinaryType())
         self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY")
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8aae39880072..95307ea3859c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -2517,6 +2517,8 @@ def _assert_valid_collation_provider(provider: str) -> 
None:
 # Mapping Python types to Spark SQL DataType
 _type_mappings = {
     type(None): NullType,
+    Geometry: GeometryType,
+    Geography: GeographyType,
     bool: BooleanType,
     int: LongType,
     float: DoubleType,
@@ -2648,6 +2650,12 @@ def _infer_type(
         return obj.__UDT__
 
     dataType = _type_mappings.get(type(obj))
+    if dataType is GeographyType:
+        assert isinstance(obj, Geography)
+        return GeographyType(obj.getSrid())
+    if dataType is GeometryType:
+        assert isinstance(obj, Geometry)
+        return GeometryType(obj.getSrid())
     if dataType is DecimalType:
         # the precision and scale of `obj` may be different from row to row.
         return DecimalType(38, 18)
@@ -2915,6 +2923,10 @@ def _merge_type(
         return a
     elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
         return b
+    elif isinstance(a, GeometryType) and isinstance(b, GeometryType) and 
a.srid != b.srid:
+        return GeometryType("ANY")
+    elif isinstance(a, GeographyType) and isinstance(b, GeographyType) and 
a.srid != b.srid:
+        return GeographyType("ANY")
     elif isinstance(a, AtomicType) and isinstance(b, StringType):
         return b
     elif isinstance(a, StringType) and isinstance(b, AtomicType):
@@ -3068,6 +3080,8 @@ _acceptable_types = {
     ArrayType: (list, tuple, array),
     MapType: (dict,),
     StructType: (tuple, list, dict),
+    GeometryType: (Geometry,),
+    GeographyType: (Geography,),
     VariantType: (
         bool,
         int,
@@ -3419,6 +3433,24 @@ def _make_type_verifier(
 
         verify_value = verify_variant
 
+    elif isinstance(dataType, GeometryType):
+
+        def verify_geometry(obj: Any) -> None:
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            assert isinstance(obj, Geometry)
+
+        verify_value = verify_geometry
+
+    elif isinstance(dataType, GeographyType):
+
+        def verify_geography(obj: Any) -> None:
+            assert_acceptable_types(obj)
+            verify_acceptable_types(obj)
+            assert isinstance(obj, Geography)
+
+        verify_value = verify_geography
+
     else:
 
         def verify_default(obj: Any) -> None:
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 212cc5db124c..33622ca7349a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -29,9 +29,9 @@ import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, 
GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, 
GenericArrayData, MapData, STUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, 
VariantVal}
 
 object EvaluatePython {
 
@@ -43,7 +43,7 @@ object EvaluatePython {
 
   def needConversionInPython(dt: DataType): Boolean = dt match {
     case DateType | TimestampType | TimestampNTZType | VariantType | _: 
DayTimeIntervalType
-         | _: TimeType => true
+         | _: TimeType | _: GeometryType | _: GeographyType => true
     case _: StructType => true
     case _: UserDefinedType[_] => true
     case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -92,6 +92,10 @@ object EvaluatePython {
 
       case (s: UTF8String, _: StringType) => s.toString
 
+      case (g: GeometryVal, gt: GeometryType) => STUtils.deserializeGeom(g, gt)
+
+      case (g: GeographyVal, gt: GeographyType) => STUtils.deserializeGeog(g, 
gt)
+
       case (bytes: Array[Byte], BinaryType) =>
         if (binaryAsBytes) {
           new BytesWrapper(bytes)
@@ -228,6 +232,23 @@ object EvaluatePython {
         )
     }
 
+    case g: GeographyType => (obj: Any) => nullSafeConvert(obj) {
+      case s: java.util.HashMap[_, _] =>
+        val geographySrid = s.get("srid").asInstanceOf[Int]
+        g.assertSridAllowedForType(geographySrid)
+        STUtils.stGeogFromWKB(
+          s.get("wkb").asInstanceOf[Array[Byte]])
+    }
+
+    case g: GeometryType => (obj: Any) => nullSafeConvert(obj) {
+      case s: java.util.HashMap[_, _] =>
+        val geometrySrid = s.get("srid").asInstanceOf[Int]
+        g.assertSridAllowedForType(geometrySrid)
+        STUtils.stGeomFromWKB(
+          s.get("wkb").asInstanceOf[Array[Byte]],
+          geometrySrid)
+    }
+
     case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to