This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 883fd8ed9 [SEDONA-478] Sedona 1.5.1 context initialization fails without GeoTools coverage (#1394)
883fd8ed9 is described below

commit 883fd8ed96fb50fb45e1d1794d27fe876290279e
Author: Jia Yu <[email protected]>
AuthorDate: Tue Apr 30 21:44:55 2024 -0700

    [SEDONA-478] Sedona 1.5.1 context initialization fails without GeoTools coverage (#1394)
    
    * Make Raster optional
    
    * UDT registrator is not accessible in Spark 3.1 and before
---
 .../org/apache/sedona/spark/SedonaContext.scala    |  2 +
 .../org/apache/sedona/sql/RasterRegistrator.scala  | 55 ++++++++++++++++++++++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  6 +--
 .../apache/sedona/sql/UDF/RasterUdafCatalog.scala  | 29 ++++++++++++
 .../org/apache/sedona/sql/UDF/UdfRegistrator.scala |  2 -
 .../sedona/sql/utils/SedonaSQLRegistrator.scala    |  2 +
 ...per.scala => RasterUdtRegistratorWrapper.scala} | 19 +-------
 .../sql/sedona_sql/UDT/UdtRegistratorWrapper.scala | 12 -----
 8 files changed, 91 insertions(+), 36 deletions(-)

diff --git 
a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala 
b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index 6b262ed16..e16c5b766 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -20,6 +20,7 @@ package org.apache.sedona.spark
 
 import org.apache.sedona.common.utils.TelemetryCollector
 import org.apache.sedona.core.serde.SedonaKryoRegistrator
+import org.apache.sedona.sql.RasterRegistrator
 import org.apache.sedona.sql.UDF.UdfRegistrator
 import org.apache.sedona.sql.UDT.UdtRegistrator
 import org.apache.spark.serializer.KryoSerializer
@@ -57,6 +58,7 @@ object SedonaContext {
       sparkSession.experimental.extraOptimizations ++= Seq(new 
SpatialFilterPushDownForGeoParquet(sparkSession))
     }
     addGeoParquetToSupportNestedFilterSources(sparkSession)
+    RasterRegistrator.registerAll(sparkSession)
     UdtRegistrator.registerAll()
     UdfRegistrator.registerAll(sparkSession)
     sparkSession
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
new file mode 100644
index 000000000..e3152e40d
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.sedona.sql.UDF.RasterUdafCatalog
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
+import org.apache.spark.sql.{SparkSession, functions}
+import org.slf4j.{Logger, LoggerFactory}
+
+object RasterRegistrator {
+  val logger: Logger = LoggerFactory.getLogger(getClass)
+  private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
+
+  // Helper method to check if GridCoverage2D is available
+  private def isGeoToolsAvailable: Boolean = {
+    try {
+      Class.forName(gridClassName, true, 
Thread.currentThread().getContextClassLoader)
+      true
+    } catch {
+      case _: ClassNotFoundException =>
+        logger.warn("Geotools was not found on the classpath. Raster 
operations will not be available.")
+        false
+    }
+  }
+
+  def registerAll(sparkSession: SparkSession): Unit = {
+    if (isGeoToolsAvailable) {
+      RasterUdtRegistratorWrapper.registerAll(gridClassName)
+      
sparkSession.udf.register(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName,
 functions.udaf(RasterUdafCatalog.rasterAggregateExpression))
+    }
+  }
+
+  def dropAll(sparkSession: SparkSession): Unit = {
+    if (isGeoToolsAvailable) {
+      
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName))
+    }
+  }
+}
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 011e89b07..25467f3cd 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -22,14 +22,12 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, ExpressionInfo, Literal}
 import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.sedona_sql.expressions._
 import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
 import org.apache.spark.sql.sedona_sql.expressions.raster._
-import org.apache.spark.sql.sedona_sql.expressions._
-import org.geotools.coverage.grid.GridCoverage2D
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.operation.buffer.BufferParameters
 
-import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 object Catalog {
@@ -285,8 +283,6 @@ object Catalog {
     function[RS_NetCDFInfo]()
   )
 
-  val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), 
ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr
-
   val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = 
Seq(
     new ST_Union_Aggr,
     new ST_Envelope_Aggr,
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala
new file mode 100644
index 000000000..a3deb50f5
--- /dev/null
+++ 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.UDF
+
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.sedona_sql.expressions.raster.{BandData, 
RS_Union_Aggr}
+import org.geotools.coverage.grid.GridCoverage2D
+
+import scala.collection.mutable.ArrayBuffer
+
+object RasterUdafCatalog {
+  val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), 
ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr
+}
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
index c8d6590b3..547556848 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
@@ -37,7 +37,6 @@ object UdfRegistrator {
     }
 Catalog.aggregateExpressions.foreach(f => 
sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // 
SPARK3 anchor
 //Catalog.aggregateExpressions_UDAF.foreach(f => 
sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
-    
sparkSession.udf.register(Catalog.rasterAggregateExpression.getClass.getSimpleName,
 functions.udaf(Catalog.rasterAggregateExpression))
   }
 
   def dropAll(sparkSession: SparkSession): Unit = {
@@ -46,6 +45,5 @@ Catalog.aggregateExpressions.foreach(f => 
sparkSession.udf.register(f.getClass.g
     }
 Catalog.aggregateExpressions.foreach(f => 
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName)))
 // SPARK3 anchor
 //Catalog.aggregateExpressions_UDAF.foreach(f => 
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName)))
 // SPARK2 anchor
-    
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(Catalog.rasterAggregateExpression.getClass.getSimpleName))
   }
 }
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
 
b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 91a712fed..52f7ceb1c 100644
--- 
a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ 
b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -19,6 +19,7 @@
 package org.apache.sedona.sql.utils
 
 import org.apache.sedona.spark.SedonaContext
+import org.apache.sedona.sql.RasterRegistrator
 import org.apache.sedona.sql.UDF.UdfRegistrator
 import org.apache.spark.sql.{SQLContext, SparkSession}
 
@@ -44,5 +45,6 @@ object SedonaSQLRegistrator {
 
   def dropAll(sparkSession: SparkSession): Unit = {
     UdfRegistrator.dropAll(sparkSession)
+    RasterRegistrator.dropAll(sparkSession)
   }
 }
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala
similarity index 53%
copy from 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
copy to 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala
index 127205faf..b4a4e258a 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala
@@ -18,26 +18,11 @@
  */
 package org.apache.spark.sql.sedona_sql.UDT
 
-import org.slf4j.{Logger, LoggerFactory}
 import org.apache.spark.sql.types.UDTRegistration
-import org.locationtech.jts.geom.Geometry
-import org.locationtech.jts.index.SpatialIndex
 
-object UdtRegistratorWrapper {
+object RasterUdtRegistratorWrapper {
 
-  val logger: Logger = LoggerFactory.getLogger(getClass)
-
-  def registerAll(): Unit = {
-    UDTRegistration.register(classOf[Geometry].getName, 
classOf[GeometryUDT].getName)
-    UDTRegistration.register(classOf[SpatialIndex].getName, 
classOf[IndexUDT].getName)
-    // Rasters requires geotools which is optional.
-    val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
-    try {
-      // Trigger an exception if geotools is not found.
-      java.lang.Class.forName(gridClassName, true, 
Thread.currentThread().getContextClassLoader)
+    def registerAll(gridClassName: String): Unit = {
       UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
-    } catch {
-      case e: ClassNotFoundException => logger.warn("Geotools was not found on 
the classpath. Raster type will not be registered.")
     }
-  }
 }
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
index 127205faf..a96d15c00 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala
@@ -18,26 +18,14 @@
  */
 package org.apache.spark.sql.sedona_sql.UDT
 
-import org.slf4j.{Logger, LoggerFactory}
 import org.apache.spark.sql.types.UDTRegistration
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.index.SpatialIndex
 
 object UdtRegistratorWrapper {
 
-  val logger: Logger = LoggerFactory.getLogger(getClass)
-
   def registerAll(): Unit = {
     UDTRegistration.register(classOf[Geometry].getName, 
classOf[GeometryUDT].getName)
     UDTRegistration.register(classOf[SpatialIndex].getName, 
classOf[IndexUDT].getName)
-    // Rasters requires geotools which is optional.
-    val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
-    try {
-      // Trigger an exception if geotools is not found.
-      java.lang.Class.forName(gridClassName, true, 
Thread.currentThread().getContextClassLoader)
-      UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
-    } catch {
-      case e: ClassNotFoundException => logger.warn("Geotools was not found on 
the classpath. Raster type will not be registered.")
-    }
   }
 }

Reply via email to