This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d304a0313ff1 [SPARK-54166][GEO][PYTHON] Introduce type encoders for geospatial types in PySpark
d304a0313ff1 is described below
commit d304a0313ff14948c24506500f661f415f16f9bc
Author: Uros Bojanic <[email protected]>
AuthorDate: Sat Nov 8 07:35:47 2025 -0800
[SPARK-54166][GEO][PYTHON] Introduce type encoders for geospatial types in PySpark
### What changes were proposed in this pull request?
This PR introduces type encoders for `Geography` and `Geometry` in PySpark.
Note that Scala-side encoders for geospatial types were added in:
https://github.com/apache/spark/pull/52813.
### Why are the changes needed?
Expanding client support for geospatial types in PySpark API.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New PySpark unit tests:
- `test_types`
- `test_parity_types`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52861 from uros-db/geo-pyspark-encoders.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/__init__.py | 4 +-
python/pyspark/sql/pandas/types.py | 120 ++++++++
.../pyspark/sql/tests/connect/test_parity_types.py | 4 +
python/pyspark/sql/tests/test_types.py | 322 +++++++++++++++++++++
python/pyspark/sql/types.py | 32 ++
.../sql/execution/python/EvaluatePython.scala | 27 +-
6 files changed, 505 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index a0a6e8ef70c8..eeeeddd00e3a 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -39,7 +39,7 @@ Important classes of Spark SQL and DataFrames:
- :class:`pyspark.sql.Window`
For working with window functions.
"""
-from pyspark.sql.types import Row, VariantVal
+from pyspark.sql.types import Geography, Geometry, Row, VariantVal
from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration,
UDTFRegistration
from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
@@ -69,6 +69,8 @@ __all__ = [
"DataFrameNaFunctions",
"DataFrameStatFunctions",
"VariantVal",
+ "Geography",
+ "Geometry",
"Window",
"WindowSpec",
"DataFrameReader",
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 327e3941d938..d8a45daa77e8 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -50,6 +50,10 @@ from pyspark.sql.types import (
UserDefinedType,
VariantType,
VariantVal,
+ GeometryType,
+ Geometry,
+ GeographyType,
+ Geography,
_create_row,
)
from pyspark.errors import PySparkTypeError, UnsupportedOperationException,
PySparkValueError
@@ -202,6 +206,28 @@ def to_arrow_type(
pa.field("metadata", pa.binary(), nullable=False,
metadata={b"variant": b"true"}),
]
arrow_type = pa.struct(fields)
+ elif type(dt) == GeometryType:
+ fields = [
+ pa.field("srid", pa.int32(), nullable=False),
+ pa.field(
+ "wkb",
+ pa.binary(),
+ nullable=False,
+ metadata={b"geometry": b"true", b"srid": str(dt.srid)},
+ ),
+ ]
+ arrow_type = pa.struct(fields)
+ elif type(dt) == GeographyType:
+ fields = [
+ pa.field("srid", pa.int32(), nullable=False),
+ pa.field(
+ "wkb",
+ pa.binary(),
+ nullable=False,
+ metadata={b"geography": b"true", b"srid": str(dt.srid)},
+ ),
+ ]
+ arrow_type = pa.struct(fields)
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -272,6 +298,38 @@ def is_variant(at: "pa.DataType") -> bool:
) and any(field.name == "value" for field in at)
+def is_geometry(at: "pa.DataType") -> bool:
+ """Check if a PyArrow struct data type represents a geometry"""
+ import pyarrow.types as types
+
+ assert types.is_struct(at)
+
+ return any(
+ (
+ field.name == "wkb"
+ and b"geometry" in field.metadata
+ and field.metadata[b"geometry"] == b"true"
+ )
+ for field in at
+ ) and any(field.name == "srid" for field in at)
+
+
+def is_geography(at: "pa.DataType") -> bool:
+ """Check if a PyArrow struct data type represents a geography"""
+ import pyarrow.types as types
+
+ assert types.is_struct(at)
+
+ return any(
+ (
+ field.name == "wkb"
+ and b"geography" in field.metadata
+ and field.metadata[b"geography"] == b"true"
+ )
+ for field in at
+ ) and any(field.name == "srid" for field in at)
+
+
def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
@@ -337,6 +395,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
elif types.is_struct(at):
if is_variant(at):
return VariantType()
+ elif is_geometry(at):
+ srid = int(at.field("wkb").metadata.get(b"srid"))
+ if srid == GeometryType.MIXED_SRID:
+ return GeometryType("ANY")
+ else:
+ return GeometryType(srid)
+ elif is_geography(at):
+ srid = int(at.field("wkb").metadata.get(b"srid"))
+ if srid == GeographyType.MIXED_SRID:
+ return GeographyType("ANY")
+ else:
+ return GeographyType(srid)
return StructType(
[
StructField(
@@ -1098,6 +1168,40 @@ def _create_converter_to_pandas(
return convert_variant
+ elif isinstance(dt, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geography.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+ return convert_geography
+
+ elif isinstance(dt, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geometry.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+ return convert_geometry
+
else:
return None
@@ -1360,6 +1464,22 @@ def _create_converter_from_pandas(
return convert_variant
+ elif isinstance(dt, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ assert isinstance(value, Geography)
+ return {"srid": value.srid, "wkb": value.wkb}
+
+ return convert_geography
+
+ elif isinstance(dt, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ assert isinstance(value, Geometry)
+ return {"srid": value.srid, "wkb": value.wkb}
+
+ return convert_geometry
+
return None
conv = _converter(data_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 6d06611def6a..a39e92233bc0 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -34,6 +34,10 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
def test_apply_schema_to_row(self):
super().test_apply_schema_to_row()
+ @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+ def test_geospatial_create_dataframe_rdd(self):
+ super().test_geospatial_create_dataframe_rdd()
+
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_create_dataframe_schema_mismatch(self):
super().test_create_dataframe_schema_mismatch()
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 6979095acca8..4ff2ab3e5cd7 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -29,6 +29,8 @@ from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.errors import (
AnalysisException,
+ IllegalArgumentException,
+ SparkRuntimeException,
ParseException,
PySparkTypeError,
PySparkValueError,
@@ -51,6 +53,8 @@ from pyspark.sql.types import (
MapType,
StringType,
CharType,
+ Geography,
+ Geometry,
VarcharType,
StructType,
StructField,
@@ -1365,6 +1369,7 @@ class TypesTestsMixin:
NullType(),
GeographyType(4326),
GeographyType("ANY"),
+ GeometryType(0),
GeometryType(4326),
GeometryType("ANY"),
VariantType(),
@@ -2447,6 +2452,323 @@ class TypesTestsMixin:
with self.assertRaises(PySparkValueError, msg="Rows cannot be of type VariantVal"):
self.spark.createDataFrame([VariantVal.parseJson("2")], "v variant")
+ def test_geospatial_encoding(self):
+ df = self.spark.createDataFrame(
+ [
+ (
+
bytes.fromhex("0101000000000000000000F03F0000000000000040"),
+ 4326,
+ )
+ ],
+ ["wkb", "srid"],
+ )
+ row = df.select(
+ F.st_geomfromwkb(df.wkb).alias("geom"),
+ F.st_geogfromwkb(df.wkb).alias("geog"),
+ ).collect()[0]
+
+ self.assertEqual(type(row["geom"]), Geometry)
+ self.assertEqual(type(row["geog"]), Geography)
+ self.assertEqual(
+ row["geom"].getBytes(),
bytes.fromhex("0101000000000000000000F03F0000000000000040")
+ )
+ self.assertEqual(row["geom"].getSrid(), 0)
+ self.assertEqual(
+ row["geog"].getBytes(),
bytes.fromhex("0101000000000000000000F03F0000000000000040")
+ )
+ self.assertEqual(row["geog"].getSrid(), 4326)
+
+ def test_geospatial_create_dataframe_rdd(self):
+ schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("geom", GeometryType(0), True),
+ StructField("geom4326", GeometryType(4326), True),
+ StructField("geog", GeographyType(4326), True),
+ ]
+ )
+ geospatial_data = [
+ (
+ 1,
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
0),
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
4326),
+ Geography.fromWKB(
+
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ ),
+ (
+ 2,
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"),
0),
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"),
4326),
+ Geography.fromWKB(
+
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ ),
+ ]
+ rdd_data = self.sc.parallelize(geospatial_data)
+ df = self.spark.createDataFrame(rdd_data, schema)
+ rows = df.select(
+ F.st_asbinary(df.geom).alias("geom_wkb"),
+ F.st_srid(df.geom).alias("geom_srid"),
+ F.st_asbinary(df.geom4326).alias("geom4326_wkb"),
+ F.st_srid(df.geom4326).alias("geom4326_srid"),
+ F.st_asbinary(df.geog).alias("geog_wkb"),
+ F.st_srid(df.geog).alias("geog_srid"),
+ ).collect()
+
+ point0_wkb =
bytes.fromhex("010100000000000000000031400000000000001c40")
+ point1_wkb =
bytes.fromhex("010100000000000000000014400000000000001440")
+ self.assertEqual(rows[0]["geom_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geog_wkb"], point0_wkb)
+ self.assertEqual(rows[1]["geom_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geog_wkb"], point1_wkb)
+ self.assertEqual(rows[0]["geom_srid"], 0)
+ self.assertEqual(rows[0]["geom4326_srid"], 4326)
+ self.assertEqual(rows[0]["geog_srid"], 4326)
+ self.assertEqual(rows[1]["geom_srid"], 0)
+ self.assertEqual(rows[1]["geom4326_srid"], 4326)
+ self.assertEqual(rows[1]["geog_srid"], 4326)
+ schema_df = self.spark.createDataFrame(rdd_data).select(
+ F.col("_1").alias("id"),
+ F.col("_2").alias("geom"),
+ F.col("_3").alias("geom4326"),
+ F.col("_4").alias("geog"),
+ )
+ self.assertEqual(df.collect(), schema_df.collect())
+
+ def test_geospatial_create_dataframe(self):
+ # Positive Test: Creating DataFrame from a list of tuples with explicit schema
+ geospatial_data = [
+ (
+ 1,
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
0),
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
4326),
+ Geography.fromWKB(
+
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ ),
+ (
+ 2,
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"),
0),
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"),
4326),
+ Geography.fromWKB(
+
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ ),
+ ]
+ named_geospatial_data = [
+ {
+ "id": 1,
+ "geom": Geometry.fromWKB(
+
bytes.fromhex("010100000000000000000031400000000000001c40"), 0
+ ),
+ "geom4326": Geometry.fromWKB(
+
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ "geog": Geography.fromWKB(
+
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ },
+ {
+ "id": 2,
+ "geom": Geometry.fromWKB(
+
bytes.fromhex("010100000000000000000014400000000000001440"), 0
+ ),
+ "geom4326": Geometry.fromWKB(
+
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ "geog": Geography.fromWKB(
+
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ },
+ ]
+ GeospatialRow = Row("id", "geom", "geom4326", "geog")
+ spark_row_data = [
+ GeospatialRow(
+ 1,
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
0),
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
4326),
+ Geography.fromWKB(
+
bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ ),
+ GeospatialRow(
+ 2,
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"),
0),
+
Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"),
4326),
+ Geography.fromWKB(
+
bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ ),
+ ]
+ schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("geom", GeometryType(0), True),
+ StructField("geom4326", GeometryType(4326), True),
+ StructField("geog", GeographyType(4326), True),
+ ]
+ )
+ # Negative Test: Schema mismatch
+ mismatched_schema = StructType(
+ [
+ StructField("id", IntegerType(), True), # Should be GeometryType
+ StructField("geom", GeometryType(4326), True), # Should be GeometryType
+ StructField("geom4326", GeometryType(4326), True), # Should be GeometryType
+ StructField("geog", GeographyType(4326), True), # Should be GeographyType
+ ]
+ )
+
+ # Explicitly validate single test case
+ # rest will be compared with this one.
+ df = self.spark.createDataFrame(geospatial_data, schema)
+ rows = df.select(
+ F.st_asbinary(df.geom).alias("geom_wkb"),
+ F.st_srid(df.geom).alias("geom_srid"),
+ F.st_asbinary(df.geom4326).alias("geom4326_wkb"),
+ F.st_srid(df.geom4326).alias("geom4326_srid"),
+ F.st_asbinary(df.geog).alias("geog_wkb"),
+ F.st_srid(df.geog).alias("geog_srid"),
+ ).collect()
+
+ point0_wkb =
bytes.fromhex("010100000000000000000031400000000000001c40")
+ point1_wkb =
bytes.fromhex("010100000000000000000014400000000000001440")
+ self.assertEqual(rows[0]["geom_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geog_wkb"], point0_wkb)
+ self.assertEqual(rows[1]["geom_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geog_wkb"], point1_wkb)
+ self.assertEqual(rows[0]["geom_srid"], 0)
+ self.assertEqual(rows[0]["geom4326_srid"], 4326)
+ self.assertEqual(rows[0]["geog_srid"], 4326)
+ self.assertEqual(rows[1]["geom_srid"], 0)
+ self.assertEqual(rows[1]["geom4326_srid"], 4326)
+ self.assertEqual(rows[1]["geog_srid"], 4326)
+
+ # Just the data set without parameters.
+ self.assertEqual(
+ self.spark.createDataFrame(named_geospatial_data)
+ .select("id", "geom", "geom4326", "geog")
+ .collect(),
+ df.collect(),
+ )
+
self.assertEqual(self.spark.createDataFrame(geospatial_data).collect(),
df.collect())
+ self.assertEqual(self.spark.createDataFrame(spark_row_data).collect(),
df.collect())
+
+ # Define DataFrame creation methods
+ datasets = [named_geospatial_data, geospatial_data, spark_row_data]
+ schemas = [
+ schema,
+ "id INT, geom GEOMETRY(0), geom4326 GEOMETRY(4326), geog GEOGRAPHY(4326)",
+ ["id", "geom", "geom4326", "geog"],
+ ]
+
+ for dataset_to_check, schema_to_check in zip(datasets, schemas):
+ df_to_check = self.spark.createDataFrame(dataset_to_check,
schema_to_check).select(
+ "id", "geom", "geom4326", "geog"
+ )
+ self.assertEqual(df_to_check.collect(), df.collect(), "DataFrame creation with schema")
+
+ # Negative Test: Schema mismatch
+ for dataset_to_check in datasets:
+ with self.assertRaises(SparkRuntimeException) as pe:
+ self.spark.createDataFrame(dataset_to_check,
mismatched_schema).collect()
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="GEO_ENCODER_SRID_MISMATCH_ERROR",
+ messageParameters={"type": "GEOMETRY", "typeSrid": "4326",
"valueSrid": "0"},
+ )
+
+ def test_geospatial_schema_inferrence(self):
+ # Mixed data with different SRIDs
+ wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+ geometry_dataset = [
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326),
Geometry.fromWKB(wkb, 4326)),
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326),
Geometry.fromWKB(wkb, 0)),
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326),
Geometry.fromWKB(wkb, 4326)),
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326),
Geometry.fromWKB(wkb, 0)),
+ ]
+ # Create DataFrame with mixed data types
+ df = self.spark.createDataFrame(geometry_dataset, ["geom0",
"geom4326", "geom_any"])
+ expected_schema = StructType(
+ [
+ StructField("geom0", GeometryType(0), True),
+ StructField("geom4326", GeometryType(4326), True),
+ StructField("geom_any", GeometryType("ANY"), True),
+ ]
+ )
+ self.assertEqual(df.schema, expected_schema)
+
+ rows = df.select(
+ F.st_asbinary("geom0").alias("geom0_wkb"),
+ F.st_srid("geom0").alias("geom0_srid"),
+ F.st_asbinary("geom4326").alias("geom4326_wkb"),
+ F.st_srid("geom4326").alias("geom4326_srid"),
+ F.st_asbinary("geom_any").alias("geom_any_wkb"),
+ F.st_srid("geom_any").alias("geom_any_srid"),
+ ).collect()
+
+ point_wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+ self.assertEqual(rows[0]["geom0_wkb"], point_wkb)
+ self.assertEqual(rows[1]["geom0_wkb"], point_wkb)
+ self.assertEqual(rows[0]["geom4326_wkb"], point_wkb)
+ self.assertEqual(rows[1]["geom4326_wkb"], point_wkb)
+ self.assertEqual(rows[0]["geom_any_wkb"], point_wkb)
+ self.assertEqual(rows[1]["geom_any_wkb"], point_wkb)
+ self.assertEqual(rows[0]["geom0_srid"], 0)
+ self.assertEqual(rows[1]["geom0_srid"], 0)
+ self.assertEqual(rows[0]["geom4326_srid"], 4326)
+ self.assertEqual(rows[1]["geom4326_srid"], 4326)
+ self.assertEqual(rows[0]["geom_any_srid"], 4326)
+ self.assertEqual(rows[1]["geom_any_srid"], 0)
+
+ def test_geospatial_mixed_check_srid_validity(self):
+ geometry_mixed_invalid_data = [
+ (1,
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
0)),
+ (2,
Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"),
1)),
+ ]
+
+ with self.assertRaises(IllegalArgumentException) as pe:
+ self.spark.createDataFrame(geometry_mixed_invalid_data).collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="ST_INVALID_SRID_VALUE",
+ messageParameters={"srid": "1"},
+ )
+
+ with self.assertRaises(IllegalArgumentException) as pe:
+ self.spark.createDataFrame(
+ geometry_mixed_invalid_data, "id INT, geom GEOMETRY(ANY)"
+ ).collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="ST_INVALID_SRID_VALUE",
+ messageParameters={"srid": "1"},
+ )
+
+ def test_geospatial_result_encoding(self):
+ point_wkb = "010100000000000000000031400000000000001c40"
+ point_bytes = bytes.fromhex(point_wkb)
+ df = self.spark.sql(
+ f"""
+ SELECT ST_GeomFromWKB(X'{point_wkb}') AS geom,
+ ST_GeogFromWKB(X'{point_wkb}') AS geog"""
+ )
+ GeospatialRow = Row("geom", "geog")
+ self.assertEqual(
+ df.collect(),
+ [
+ GeospatialRow(
+ Geometry.fromWKB(point_bytes, 0),
+ Geography.fromWKB(point_bytes, 4326),
+ )
+ ],
+ )
+
def test_to_ddl(self):
schema = StructType().add("a", NullType()).add("b",
BooleanType()).add("c", BinaryType())
self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY")
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8aae39880072..95307ea3859c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -2517,6 +2517,8 @@ def _assert_valid_collation_provider(provider: str) -> None:
# Mapping Python types to Spark SQL DataType
_type_mappings = {
type(None): NullType,
+ Geometry: GeometryType,
+ Geography: GeographyType,
bool: BooleanType,
int: LongType,
float: DoubleType,
@@ -2648,6 +2650,12 @@ def _infer_type(
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
+ if dataType is GeographyType:
+ assert isinstance(obj, Geography)
+ return GeographyType(obj.getSrid())
+ if dataType is GeometryType:
+ assert isinstance(obj, Geometry)
+ return GeometryType(obj.getSrid())
if dataType is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
@@ -2915,6 +2923,10 @@ def _merge_type(
return a
elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
return b
+ elif isinstance(a, GeometryType) and isinstance(b, GeometryType) and
a.srid != b.srid:
+ return GeometryType("ANY")
+ elif isinstance(a, GeographyType) and isinstance(b, GeographyType) and
a.srid != b.srid:
+ return GeographyType("ANY")
elif isinstance(a, AtomicType) and isinstance(b, StringType):
return b
elif isinstance(a, StringType) and isinstance(b, AtomicType):
@@ -3068,6 +3080,8 @@ _acceptable_types = {
ArrayType: (list, tuple, array),
MapType: (dict,),
StructType: (tuple, list, dict),
+ GeometryType: (Geometry,),
+ GeographyType: (Geography,),
VariantType: (
bool,
int,
@@ -3419,6 +3433,24 @@ def _make_type_verifier(
verify_value = verify_variant
+ elif isinstance(dataType, GeometryType):
+
+ def verify_geometry(obj: Any) -> None:
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ assert isinstance(obj, Geometry)
+
+ verify_value = verify_geometry
+
+ elif isinstance(dataType, GeographyType):
+
+ def verify_geography(obj: Any) -> None:
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ assert isinstance(obj, Geography)
+
+ verify_value = verify_geography
+
else:
def verify_default(obj: Any) -> None:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 212cc5db124c..33622ca7349a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -29,9 +29,9 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData,
GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData,
GenericArrayData, MapData, STUtils}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String,
VariantVal}
object EvaluatePython {
@@ -43,7 +43,7 @@ object EvaluatePython {
def needConversionInPython(dt: DataType): Boolean = dt match {
case DateType | TimestampType | TimestampNTZType | VariantType | _:
DayTimeIntervalType
- | _: TimeType => true
+ | _: TimeType | _: GeometryType | _: GeographyType => true
case _: StructType => true
case _: UserDefinedType[_] => true
case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -92,6 +92,10 @@ object EvaluatePython {
case (s: UTF8String, _: StringType) => s.toString
+ case (g: GeometryVal, gt: GeometryType) => STUtils.deserializeGeom(g, gt)
+
+ case (g: GeographyVal, gt: GeographyType) => STUtils.deserializeGeog(g,
gt)
+
case (bytes: Array[Byte], BinaryType) =>
if (binaryAsBytes) {
new BytesWrapper(bytes)
@@ -228,6 +232,23 @@ object EvaluatePython {
)
}
+ case g: GeographyType => (obj: Any) => nullSafeConvert(obj) {
+ case s: java.util.HashMap[_, _] =>
+ val geographySrid = s.get("srid").asInstanceOf[Int]
+ g.assertSridAllowedForType(geographySrid)
+ STUtils.stGeogFromWKB(
+ s.get("wkb").asInstanceOf[Array[Byte]])
+ }
+
+ case g: GeometryType => (obj: Any) => nullSafeConvert(obj) {
+ case s: java.util.HashMap[_, _] =>
+ val geometrySrid = s.get("srid").asInstanceOf[Int]
+ g.assertSridAllowedForType(geometrySrid)
+ STUtils.stGeomFromWKB(
+ s.get("wkb").asInstanceOf[Array[Byte]],
+ geometrySrid)
+ }
+
case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]