This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 883fd8ed9 [SEDONA-478] Sedona 1.5.1 context initialization fails without GeoTools coverage (#1394)
883fd8ed9 is described below
commit 883fd8ed96fb50fb45e1d1794d27fe876290279e
Author: Jia Yu <[email protected]>
AuthorDate: Tue Apr 30 21:44:55 2024 -0700
[SEDONA-478] Sedona 1.5.1 context initialization fails without GeoTools coverage (#1394)
* Make Raster optional
* UDT registrator is not accessible in Spark 3.1 and before
---
.../org/apache/sedona/spark/SedonaContext.scala | 2 +
.../org/apache/sedona/sql/RasterRegistrator.scala | 55 ++++++++++++++++++++++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 6 +--
.../apache/sedona/sql/UDF/RasterUdafCatalog.scala | 29 ++++++++++++
.../org/apache/sedona/sql/UDF/UdfRegistrator.scala | 2 -
.../sedona/sql/utils/SedonaSQLRegistrator.scala | 2 +
...per.scala => RasterUdtRegistratorWrapper.scala} | 19 +-------
.../sql/sedona_sql/UDT/UdtRegistratorWrapper.scala | 12 -----
8 files changed, 91 insertions(+), 36 deletions(-)
diff --git
a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index 6b262ed16..e16c5b766 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -20,6 +20,7 @@ package org.apache.sedona.spark
import org.apache.sedona.common.utils.TelemetryCollector
import org.apache.sedona.core.serde.SedonaKryoRegistrator
+import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.sedona.sql.UDT.UdtRegistrator
import org.apache.spark.serializer.KryoSerializer
@@ -57,6 +58,7 @@ object SedonaContext {
sparkSession.experimental.extraOptimizations ++= Seq(new
SpatialFilterPushDownForGeoParquet(sparkSession))
}
addGeoParquetToSupportNestedFilterSources(sparkSession)
+ RasterRegistrator.registerAll(sparkSession)
UdtRegistrator.registerAll()
UdfRegistrator.registerAll(sparkSession)
sparkSession
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
new file mode 100644
index 000000000..e3152e40d
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.sedona.sql.UDF.RasterUdafCatalog
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
+import org.apache.spark.sql.{SparkSession, functions}
+import org.slf4j.{Logger, LoggerFactory}
+
+/**
+ * Registers (and drops) Sedona's raster-specific SQL types and functions.
+ *
+ * GeoTools is an optional dependency, so every entry point first checks that
+ * GeoTools' grid coverage class can be loaded. When it cannot, raster support is
+ * skipped with a warning instead of failing Sedona context initialization.
+ */
+object RasterRegistrator {
+  val logger: Logger = LoggerFactory.getLogger(getClass)
+  private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
+
+  /**
+   * Returns true when GridCoverage2D can be loaded and initialized from the current
+   * thread's context class loader.
+   *
+   * ClassNotFoundException means GeoTools is absent. A LinkageError (for example
+   * NoClassDefFoundError or ExceptionInInitializerError) means the class file was
+   * found but one of its dependencies is missing or failed to initialize, e.g. an
+   * incomplete set of GeoTools jars. Raster support is unusable in both cases, so
+   * both are reported as "not available" rather than propagated.
+   */
+  private def isGeoToolsAvailable: Boolean = {
+    try {
+      Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
+      true
+    } catch {
+      case _: ClassNotFoundException =>
+        logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
+        false
+      case e: LinkageError =>
+        logger.warn("Geotools classes could not be loaded. Raster operations will not be available.", e)
+        false
+    }
+  }
+
+  // Deliberately a def, not a val: a val would be evaluated when this object is
+  // initialized, touching RasterUdafCatalog (whose type refers to GridCoverage2D)
+  // even on classpaths without GeoTools. Only call it after isGeoToolsAvailable.
+  private def rasterAggregateName: String =
+    RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName
+
+  /**
+   * Registers the raster UDT and the raster aggregate UDAF on the given session when
+   * GeoTools is available; otherwise does nothing beyond logging a warning.
+   *
+   * @param sparkSession session whose UDF registry receives the raster aggregate
+   */
+  def registerAll(sparkSession: SparkSession): Unit = {
+    if (isGeoToolsAvailable) {
+      RasterUdtRegistratorWrapper.registerAll(gridClassName)
+      sparkSession.udf.register(
+        rasterAggregateName,
+        functions.udaf(RasterUdafCatalog.rasterAggregateExpression))
+    }
+  }
+
+  /**
+   * Drops the raster aggregate UDAF registered by [[registerAll]] when GeoTools is
+   * available; otherwise does nothing beyond logging a warning.
+   *
+   * @param sparkSession session whose function registry the aggregate is dropped from
+   */
+  def dropAll(sparkSession: SparkSession): Unit = {
+    if (isGeoToolsAvailable) {
+      sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(rasterAggregateName))
+    }
+  }
+}
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 011e89b07..25467f3cd 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -22,14 +22,12 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes,
Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
import org.apache.spark.sql.sedona_sql.expressions.raster._
-import org.apache.spark.sql.sedona_sql.expressions._
-import org.geotools.coverage.grid.GridCoverage2D
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters
-import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
object Catalog {
@@ -285,8 +283,6 @@ object Catalog {
function[RS_NetCDFInfo]()
)
- val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int),
ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr
-
val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
Seq(
new ST_Union_Aggr,
new ST_Envelope_Aggr,
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala
new file mode 100644
index 000000000..a3deb50f5
--- /dev/null
+++
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.UDF
+
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.sedona_sql.expressions.raster.{BandData,
RS_Union_Aggr}
+import org.geotools.coverage.grid.GridCoverage2D
+
+import scala.collection.mutable.ArrayBuffer
+
+object RasterUdafCatalog {
+  /**
+   * Aggregator instance (RS_Union_Aggr) that RasterRegistrator registers as a SQL
+   * UDAF under its simple class name.
+   *
+   * Kept in its own object, apart from Catalog, because its type refers to GeoTools'
+   * GridCoverage2D. RasterRegistrator only touches this object after confirming that
+   * GeoTools is on the classpath, so Sedona can start without GeoTools.
+   */
+ val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int),
ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr
+}
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
index c8d6590b3..547556848 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
@@ -37,7 +37,6 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) //
SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f =>
sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
-
sparkSession.udf.register(Catalog.rasterAggregateExpression.getClass.getSimpleName,
functions.udaf(Catalog.rasterAggregateExpression))
}
def dropAll(sparkSession: SparkSession): Unit = {
@@ -46,6 +45,5 @@ Catalog.aggregateExpressions.foreach(f =>
sparkSession.udf.register(f.getClass.g
}
Catalog.aggregateExpressions.foreach(f =>
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName)))
// SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f =>
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName)))
// SPARK2 anchor
-
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(Catalog.rasterAggregateExpression.getClass.getSimpleName))
}
}
diff --git
a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 91a712fed..52f7ceb1c 100644
---
a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++
b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -19,6 +19,7 @@
package org.apache.sedona.sql.utils
import org.apache.sedona.spark.SedonaContext
+import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.spark.sql.{SQLContext, SparkSession}
@@ -44,5 +45,6 @@ object SedonaSQLRegistrator {
+  /**
+   * Drops the SQL functions Sedona registered on the given session: those registered
+   * by UdfRegistrator and, when GeoTools is available, the raster aggregate registered
+   * by RasterRegistrator.
+   *
+   * @param sparkSession session to remove the functions from
+   */
def dropAll(sparkSession: SparkSession): Unit = {
UdfRegistrator.dropAll(sparkSession)
+ RasterRegistrator.dropAll(sparkSession)
}
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala
similarity index 53%
copy from
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
copy to
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala
index 127205faf..b4a4e258a 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala
@@ -18,26 +18,11 @@
*/
package org.apache.spark.sql.sedona_sql.UDT
-import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.sql.types.UDTRegistration
-import org.locationtech.jts.geom.Geometry
-import org.locationtech.jts.index.SpatialIndex
-object UdtRegistratorWrapper {
+/**
+ * Registers Sedona's raster UDT with Spark SQL.
+ *
+ * Kept separate from UdtRegistratorWrapper so that raster registration, which needs
+ * the optional GeoTools dependency, can be skipped when GeoTools is absent.
+ */
+object RasterUdtRegistratorWrapper {
- val logger: Logger = LoggerFactory.getLogger(getClass)
-
- def registerAll(): Unit = {
- UDTRegistration.register(classOf[Geometry].getName,
classOf[GeometryUDT].getName)
- UDTRegistration.register(classOf[SpatialIndex].getName,
classOf[IndexUDT].getName)
- // Rasters requires geotools which is optional.
- val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
- try {
- // Trigger an exception if geotools is not found.
- java.lang.Class.forName(gridClassName, true,
Thread.currentThread().getContextClassLoader)
+  /**
+   * Registers RasterUDT as the UDT for the given class name.
+   *
+   * Performs no availability check itself. RasterRegistrator calls it only after
+   * confirming that the class can be loaded.
+   *
+   * @param gridClassName fully qualified name of GeoTools' grid coverage class
+   */
+ def registerAll(gridClassName: String): Unit = {
UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
- } catch {
- case e: ClassNotFoundException => logger.warn("Geotools was not found on
the classpath. Raster type will not be registered.")
}
- }
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
index 127205faf..a96d15c00 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
@@ -18,26 +18,14 @@
*/
package org.apache.spark.sql.sedona_sql.UDT
-import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.sql.types.UDTRegistration
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.index.SpatialIndex
object UdtRegistratorWrapper {
- val logger: Logger = LoggerFactory.getLogger(getClass)
-
+  /**
+   * Registers GeometryUDT for JTS Geometry and IndexUDT for JTS SpatialIndex.
+   *
+   * The raster UDT is registered separately by RasterUdtRegistratorWrapper, which
+   * RasterRegistrator invokes only when GeoTools is on the classpath.
+   */
def registerAll(): Unit = {
UDTRegistration.register(classOf[Geometry].getName,
classOf[GeometryUDT].getName)
UDTRegistration.register(classOf[SpatialIndex].getName,
classOf[IndexUDT].getName)
- // Rasters requires geotools which is optional.
- val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
- try {
- // Trigger an exception if geotools is not found.
- java.lang.Class.forName(gridClassName, true,
Thread.currentThread().getContextClassLoader)
- UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
- } catch {
- case e: ClassNotFoundException => logger.warn("Geotools was not found on
the classpath. Raster type will not be registered.")
- }
}
}