This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new a7ad5f011 [SEDONA-543] Fixes RS_Union_aggr throwing referenceRaster is null error when run on cluster (#1364)
a7ad5f011 is described below

commit a7ad5f01158996fc047b80c1baf87353842d932c
Author: Pranav Toggi <[email protected]>
AuthorDate: Fri Apr 26 16:19:25 2024 -0400

    [SEDONA-543] Fixes RS_Union_aggr throwing referenceRaster is null error when run on cluster (#1364)
    
    * Init: move class level members to data buffer
    
    * move sampleDimension serde to Serde.java
    
    * update serde for sampleDimensions
    
    * Add checks for index
    
    * Undo typo
    
    * add custom GridSampleDimensionSerializer
---
 .../apache/sedona/common/raster/serde/Serde.java   |  20 ++++
 .../expressions/raster/AggregateFunctions.scala    | 126 ++++++++++++---------
 2 files changed, 90 insertions(+), 56 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
index 616ded015..848c00b3f 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -32,6 +32,8 @@ import org.opengis.referencing.operation.MathTransform;
 
 import javax.media.jai.RenderedImageAdapter;
 import java.awt.image.RenderedImage;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URI;
@@ -176,4 +178,37 @@ public class Serde {
             return state.restore();
         }
     }
+
+    /**
+     * Encodes a {@link GridSampleDimension} as a standalone byte array using
+     * {@link GridSampleDimensionSerializer}.
+     *
+     * @param sampleDimension the sample dimension to encode
+     * @return the encoded bytes, readable by {@link #deserializeGridSampleDimension(byte[])}
+     */
+    public static byte[] serializeGridSampleDimension(GridSampleDimension sampleDimension) {
+        Kryo kryo = kryos.get();
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        // Closing the Output (end of the try block) flushes what it buffered into baos,
+        // so the bytes are read only after the try block has completed.
+        try (Output output = new Output(baos)) {
+            new GridSampleDimensionSerializer().write(kryo, output, sampleDimension);
+        }
+        return baos.toByteArray();
+    }
+
+    /**
+     * Decodes a {@link GridSampleDimension} produced by
+     * {@link #serializeGridSampleDimension(GridSampleDimension)}.
+     *
+     * @param data the encoded bytes
+     * @return the decoded sample dimension
+     */
+    public static GridSampleDimension deserializeGridSampleDimension(byte[] data) {
+        Kryo kryo = kryos.get();
+        try (Input input = new Input(new ByteArrayInputStream(data))) {
+            return new GridSampleDimensionSerializer().read(kryo, input, GridSampleDimension.class);
+        }
+    }
+
 }
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index 1fa1cb6e7..b76841638 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -19,6 +19,7 @@
 
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
+import org.apache.sedona.common.raster.serde.Serde
 import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
 import org.apache.sedona.common.utils.RasterUtils
 import org.apache.spark.sql.Encoder
@@ -29,93 +30,150 @@ import org.geotools.coverage.grid.GridCoverage2D
 
 import java.awt.image.WritableRaster
 import javax.media.jai.RasterFactory
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-case class BandData(var bandInt: Array[Int], var bandDouble: Array[Double], var index: Int, var isIntegral: Boolean)
+/**
+ * Per-raster state buffered by [[RS_Union_Aggr]].
+ *
+ * All state is kept in the aggregation buffer rather than in fields of the aggregator, because
+ * aggregator fields are not shared between the instances that run `reduce`, `merge` and
+ * `finish` on a cluster.
+ *
+ * @param bandInt                   samples of band 0 when `isIntegral`, otherwise null
+ * @param bandDouble                samples of band 0 when not `isIntegral`, otherwise null
+ * @param index                     band index supplied together with the raster
+ * @param isIntegral                whether the raster's data buffer holds an integral type
+ * @param serializedRaster          the whole input raster, encoded with `Serde.serialize`
+ * @param serializedSampleDimension band 0's sample dimension, encoded with
+ *                                  `Serde.serializeGridSampleDimension`
+ */
+case class BandData(
+                     var bandInt: Array[Int],
+                     var bandDouble: Array[Double],
+                     var index: Int,
+                     var isIntegral: Boolean,
+                     var serializedRaster: Array[Byte],
+                     var serializedSampleDimension: Array[Byte]
+                   )
 
 /**
  * Return a raster containing bands at given indexes from all rasters in a given column
  */
 class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D]  {
 
-  var width: Int = -1
-
-  var height: Int = -1
-
-  var referenceRaster: GridCoverage2D = _
-
-  var gridSampleDimension: mutable.Map[Int, GridSampleDimension] = new mutable.HashMap()
-
   def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()
 
-  /**
-   * Valid raster shape to be the same in the given column
-   */
-  def checkRasterShape(raster: GridCoverage2D): Boolean = {
-    // first iteration
-    if (width == -1 && height == -1) {
-      width = RasterAccessors.getWidth(raster)
-      height = RasterAccessors.getHeight(raster)
-      referenceRaster = raster
-      true
-    } else {
-      val widthNewRaster = RasterAccessors.getWidth(raster)
-      val heightNewRaster = RasterAccessors.getHeight(raster)
-
-      width == widthNewRaster && height == heightNewRaster
-    }
-  }
-
+  /**
+   * Width and height of the raster serialized in `band`. Decodes the whole raster to get them.
+   */
+  private def dimensionsOf(band: BandData): (Int, Int) = {
+    val raster = Serde.deserialize(band.serializedRaster)
+    (RasterAccessors.getWidth(raster), RasterAccessors.getHeight(raster))
+  }
+
+  /**
+   * Buffers band 0 of `input._1` under the band index `input._2`.
+   *
+   * @throws IllegalArgumentException if the index is already buffered, or the raster's
+   *                                  dimensions differ from those of the buffered rasters
+   */
   def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
     val raster = input._1
-    if (!checkRasterShape(raster)) {
-      throw new IllegalArgumentException("Rasters provides should be of the same shape.")
-    }
-    if (gridSampleDimension.contains(input._2)) {
-      throw new IllegalArgumentException("Indexes shouldn't be repeated. Index should be in an arithmetic sequence.")
-    }
-
     val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
     val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)
 
-    val bandData = if (isIntegral) {
-      val band = rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]])
-      BandData(band, null, input._2, isIntegral)
-    } else {
-      val band = rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]])
-      BandData(null, band, input._2, isIntegral)
-    }
-    gridSampleDimension = gridSampleDimension + (input._2 -> raster.getSampleDimension(0))
+    // Reject a repeated index as soon as it is seen instead of relying on merge/finish.
+    if (buffer.exists(_.index == input._2)) {
+      throw new IllegalArgumentException("Indexes shouldn't be repeated.")
+    }
+
+    val width = RasterAccessors.getWidth(raster)
+    val height = RasterAccessors.getHeight(raster)
+    // Every buffered raster was checked against the head when it was added, so comparing
+    // with the head is enough.
+    if (buffer.nonEmpty) {
+      val (refWidth, refHeight) = dimensionsOf(buffer.head)
+      if (refWidth != width || refHeight != height) {
+        throw new IllegalArgumentException("All rasters must have the same dimensions")
+      }
+    }
+
+    buffer += BandData(
+      if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
+      if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
+      input._2,
+      isIntegral,
+      Serde.serialize(raster),
+      Serde.serializeGridSampleDimension(raster.getSampleDimension(0))
+    )
 
-    buffer += bandData
+    buffer
   }
 
+  /**
+   * Concatenates two partial buffers.
+   *
+   * @throws IllegalArgumentException if an index occurs more than once in the combined buffer,
+   *                                  or the two buffers hold rasters of different dimensions
+   */
   def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
-    ArrayBuffer.concat(buffer1, buffer2)
+    val combined = ArrayBuffer.concat(buffer1, buffer2)
+    if (combined.map(_.index).distinct.length != combined.length) {
+      throw new IllegalArgumentException("Indexes shouldn't be repeated.")
+    }
+    // reduce only compares rasters within a single buffer; each buffer's head stands for all of
+    // its rasters, so comparing the two heads covers buffers that were built separately.
+    if (buffer1.nonEmpty && buffer2.nonEmpty && dimensionsOf(buffer1.head) != dimensionsOf(buffer2.head)) {
+      throw new IllegalArgumentException("All rasters must have the same dimensions")
+    }
+    combined
   }
 
+  /**
+   * Stacks the buffered bands, ordered by index, into one multi-band raster.
+   *
+   * @return the stacked raster, or null when the buffer is empty (no rows were aggregated)
+   * @throws IllegalArgumentException if the indices are repeated or not in an arithmetic sequence
+   */
   def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
+    // With no buffered raster there is no reference to take georeferencing from.
+    if (merged.isEmpty) null else stackBands(merged)
+  }
+
+  /**
+   * Builds the output raster from a non-empty buffer. Grid geometry, data type and no-data value
+   * come from the raster with the smallest index; output band `i` holds the buffered band with
+   * the i-th smallest index.
+   */
+  private def stackBands(merged: ArrayBuffer[BandData]): GridCoverage2D = {
     val sortedMerged = merged.sortBy(_.index)
+    // Any common, non-zero step is accepted (e.g. 1, 2, 3 or 2, 4, 6).
+    val indices = sortedMerged.map(_.index)
+    val steps = indices.zip(indices.drop(1)).map { case (prev, next) => next - prev }.distinct
+    if (steps.contains(0)) {
+      throw new IllegalArgumentException("Indexes shouldn't be repeated.")
+    }
+    if (steps.length > 1) {
+      throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
+    }
+
     val numBands = sortedMerged.length
-    val rasterData = RasterUtils.getRaster(referenceRaster.getRenderedImage)
-    val dataTypeCode = rasterData.getDataBuffer.getDataType
+    val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
+    val width = RasterAccessors.getWidth(referenceRaster)
+    val height = RasterAccessors.getHeight(referenceRaster)
+    val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
     val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
     val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)
-    var indexCheck = 1
 
-    for (bandData: BandData <- sortedMerged) {
-      if (bandData.index != indexCheck) {
-        throw new IllegalArgumentException("Indexes should be in a valid arithmetic sequence.")
-      }
-      indexCheck += 1
-      gridSampleDimensions(bandData.index - 1) = gridSampleDimension(bandData.index)
-      if(RasterUtils.isDataTypeIntegral(dataTypeCode))
-        resultRaster.setSamples(0, 0, width, height, (bandData.index - 1), bandData.bandInt)
-      else
-        resultRaster.setSamples(0, 0, width, height, bandData.index - 1, bandData.bandDouble)
+    for ((bandData, idx) <- sortedMerged.zipWithIndex) {
+      gridSampleDimensions(idx) = Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)
 
+      if (bandData.isIntegral)
+        resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
+      else
+        resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
     }
+
     val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
     RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true)
   }

Reply via email to