This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 56a3d0ff7 [SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when working with non-double band data (#1402)
56a3d0ff7 is described below

commit 56a3d0ff7f29a01d4ccc3e2b359f0518c7325de7
Author: Jia Yu <[email protected]>
AuthorDate: Tue May 7 19:43:19 2024 -0700

    [SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when working with non-double band data (#1402)
    
    Co-authored-by: Kristin Cowalcijk <[email protected]>
---
 .../apache/sedona/common/raster/MapAlgebra.java    |  10 +-
 .../sedona/common/raster/RasterBandEditors.java    |   8 +-
 .../apache/sedona/common/raster/serde/Serde.java   |  20 ----
 .../apache/sedona/common/utils/RasterUtils.java    |  28 +++---
 .../expressions/raster/AggregateFunctions.scala    | 105 +++++++++++----------
 5 files changed, 79 insertions(+), 92 deletions(-)

diff --git 
a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java 
b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
index 564077a36..b5fb469cd 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
@@ -78,12 +78,11 @@ public class MapAlgebra
             throw new IllegalArgumentException("Band index is out of bounds. 
Must be between 1 and " + (numBands + 1) + ")");
         }
 
-        Double[] bandValuesClass = 
Arrays.stream(bandValues).boxed().toArray(Double[]::new);
         if (bandIndex == numBands + 1) {
-            return RasterUtils.copyRasterAndAppendBand(rasterGeom, 
bandValuesClass, noDataValue);
+            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues, 
noDataValue);
         }
         else {
-            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, 
bandValuesClass, noDataValue, true);
+            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, 
bandValues, noDataValue, true);
         }
     }
 
@@ -94,12 +93,11 @@ public class MapAlgebra
             throw new IllegalArgumentException("Band index is out of bounds. 
Must be between 1 and " + (numBands + 1) + ")");
         }
 
-        Double[] bandValuesClass = 
Arrays.stream(bandValues).boxed().toArray(Double[]::new);
         if (bandIndex == numBands + 1) {
-            return RasterUtils.copyRasterAndAppendBand(rasterGeom, 
bandValuesClass);
+            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues);
         }
         else {
-            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, 
bandValuesClass);
+            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, 
bandValues);
         }
     }
 
diff --git 
a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java 
b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
index 4ffac4d30..41b159a5d 100644
--- 
a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
+++ 
b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
@@ -135,16 +135,16 @@ public class RasterBandEditors {
         if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
             int[] bandValues = rasterData.getSamples(0, 0, width, height, 
fromBand - 1, (int[]) null);
             if (numBands + 1 == toRasterIndex) {
-                return RasterUtils.copyRasterAndAppendBand(toRaster, 
Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue);
+                return RasterUtils.copyRasterAndAppendBand(toRaster, 
bandValues, noDataValue);
             } else {
-                return RasterUtils.copyRasterAndReplaceBand(toRaster, 
fromBand, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), 
noDataValue, false);
+                return RasterUtils.copyRasterAndReplaceBand(toRaster, 
fromBand, bandValues, noDataValue, false);
             }
         } else {
             double[] bandValues = rasterData.getSamples(0, 0, width, height, 
fromBand - 1, (double[]) null);
             if (numBands + 1 == toRasterIndex) {
-                return RasterUtils.copyRasterAndAppendBand(toRaster, 
Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue);
+                return RasterUtils.copyRasterAndAppendBand(toRaster, 
bandValues, noDataValue);
             } else {
-                return RasterUtils.copyRasterAndReplaceBand(toRaster, 
fromBand, Arrays.stream(bandValues).boxed().toArray(Double[]::new), 
noDataValue, false);
+                return RasterUtils.copyRasterAndReplaceBand(toRaster, 
fromBand, bandValues, noDataValue, false);
             }
         }
     }
diff --git 
a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java 
b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
index e775ee46a..7f67708b6 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -32,8 +32,6 @@ import org.geotools.api.referencing.operation.MathTransform;
 
 import javax.media.jai.RenderedImageAdapter;
 import java.awt.image.RenderedImage;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URI;
@@ -178,22 +176,4 @@ public class Serde {
             return state.restore();
         }
     }
-
-    public static byte[] serializeGridSampleDimension(GridSampleDimension 
sampleDimension) {
-        Kryo kryo = kryos.get();
-        ByteArrayOutputStream baos = new ByteArrayOutputStream();
-        Output output = new Output(baos);
-        GridSampleDimensionSerializer serializer = new 
GridSampleDimensionSerializer();
-        serializer.write(kryo, output, sampleDimension);
-        output.close();
-        return baos.toByteArray();
-    }
-
-    public static GridSampleDimension deserializeGridSampleDimension(byte[] 
data) {
-        Kryo kryo = kryos.get();
-        Input input = new Input(new ByteArrayInputStream(data));
-        GridSampleDimensionSerializer serializer = new 
GridSampleDimensionSerializer();
-        return serializer.read(kryo, input, GridSampleDimension.class);
-    }
-
 }
diff --git 
a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java 
b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
index 7cc6c5080..ce6c3b604 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
@@ -556,7 +556,7 @@ public class RasterUtils {
      * @param bandValues
      * @return
      */
-    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D 
gridCoverage2D, Number[] bandValues, Double noDataValue) {
+    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D 
gridCoverage2D, Object bandValues, Double noDataValue) {
         // Get the original image and its properties
         RenderedImage originalImage = gridCoverage2D.getRenderedImage();
         Raster raster = getRaster(originalImage);
@@ -565,17 +565,19 @@ public class RasterUtils {
         // Copy the raster data and append the new band values
         for (int i = 0; i < raster.getWidth(); i++) {
             for (int j = 0; j < raster.getHeight(); j++) {
-                if (bandValues instanceof Double[]) {
+                if (bandValues instanceof double[]) {
+                    double[] values = (double[]) bandValues;
                     double[] pixels = raster.getPixel(i, j, (double[]) null);
                     double[] copiedPixels = new double[pixels.length + 1];
                     System.arraycopy(pixels, 0, copiedPixels, 0, 
pixels.length);
-                    copiedPixels[pixels.length] = (double) bandValues[j * 
raster.getWidth() + i];
+                    copiedPixels[pixels.length] = values[j * raster.getWidth() 
+ i];
                     wr.setPixel(i, j, copiedPixels);
-                } else if (bandValues instanceof Integer[]) {
+                } else if (bandValues instanceof int[]) {
+                    int[] values = (int[]) bandValues;
                     int[] pixels = raster.getPixel(i, j, (int[]) null);
                     int[] copiedPixels = new int[pixels.length + 1];
                     System.arraycopy(pixels, 0, copiedPixels, 0, 
pixels.length);
-                    copiedPixels[pixels.length] = (int) bandValues[j * 
raster.getWidth() + i];
+                    copiedPixels[pixels.length] = values[j * raster.getWidth() 
+ i];
                     wr.setPixel(i, j, copiedPixels);
                 }
             }
@@ -594,11 +596,11 @@ public class RasterUtils {
         return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, 
gridCoverage2D, null, true);
     }
 
-    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D 
gridCoverage2D, Number[] bandValues) {
+    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D 
gridCoverage2D, Object bandValues) {
         return copyRasterAndAppendBand(gridCoverage2D, bandValues, null);
     }
 
-    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D 
gridCoverage2D, int bandIndex, Number[] bandValues, Double noDataValue, boolean 
removeNoDataIfNull) {
+    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D 
gridCoverage2D, int bandIndex, Object bandValues, Double noDataValue, boolean 
removeNoDataIfNull) {
         // Do not allow the band index to be out of bounds
         ensureBand(gridCoverage2D, bandIndex);
         // Get the original image and its properties
@@ -608,13 +610,15 @@ public class RasterUtils {
         // Copy the raster data and replace the band values
         for (int i = 0; i < raster.getWidth(); i++) {
             for (int j = 0; j < raster.getHeight(); j++) {
-                if (bandValues instanceof Double[]) {
+                if (bandValues instanceof double[]) {
+                    double[] values = (double[]) bandValues;
                     double[] bands = raster.getPixel(i, j, (double[]) null);
-                    bands[bandIndex - 1] = (double) bandValues[j * 
raster.getWidth() + i];
+                    bands[bandIndex - 1] = values[j * raster.getWidth() + i];
                     wr.setPixel(i, j, bands);
-                } else if (bandValues instanceof Integer[]) {
+                } else if (bandValues instanceof int[]) {
+                    int[] values = (int[]) bandValues;
                     int[] bands = raster.getPixel(i, j, (int[]) null);
-                    bands[bandIndex - 1] = (int) bandValues[j * 
raster.getWidth() + i];
+                    bands[bandIndex - 1] = values[j * raster.getWidth() + i];
                     wr.setPixel(i, j, bands);
                 }
             }
@@ -629,7 +633,7 @@ public class RasterUtils {
         return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, 
gridCoverage2D, null, true);
     }
 
-    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D 
gridCoverage2D, int bandIndex, Number[] bandValues) {
+    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D 
gridCoverage2D, int bandIndex, Object bandValues) {
         return copyRasterAndReplaceBand(gridCoverage2D, bandIndex, bandValues, 
null, false);
     }
 
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index 3bf132629..d8a6f4be4 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -19,13 +19,12 @@
 
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
-import org.apache.sedona.common.raster.serde.Serde
 import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
 import org.apache.sedona.common.utils.RasterUtils
+import org.apache.sedona.sql.utils.RasterSerializer
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
-import org.geotools.coverage.GridSampleDimension
 import org.geotools.coverage.grid.GridCoverage2D
 
 import java.awt.image.WritableRaster
@@ -33,12 +32,10 @@ import javax.media.jai.RasterFactory
 import scala.collection.mutable.ArrayBuffer
 
 case class BandData(
-                     var bandsData: Array[Array[Double]],
-                     var index: Int,
-                     var serializedRaster: Array[Byte],
-                     var serializedSampleDimensions: Array[Array[Byte]]
-                   )
-
+  index: Int,
+  width: Int,
+  height: Int,
+  serializedRaster: Array[Byte])
 
 /**
  * Return a raster containing bands at given indexes from all rasters in a 
given column
@@ -48,37 +45,32 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, 
Int), ArrayBuffer[BandDa
   def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()
 
   def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): 
ArrayBuffer[BandData] = {
-      val raster = input._1
-      val renderedImage = raster.getRenderedImage
-      val numBands = renderedImage.getSampleModel.getNumBands
-      val width = renderedImage.getWidth
-      val height = renderedImage.getHeight
-
-      // First check if this is the first raster to set dimensions or validate 
against existing dimensions
-      if (buffer.nonEmpty) {
-        val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
-        val refWidth = RasterAccessors.getWidth(referenceRaster)
-        val refHeight = RasterAccessors.getHeight(referenceRaster)
-        if (width != refWidth || height != refHeight) {
-          throw new IllegalArgumentException("All rasters must have the same 
dimensions")
-        }
+    val (raster, index) = input
+    val renderedImage = raster.getRenderedImage
+    val width = renderedImage.getWidth
+    val height = renderedImage.getHeight
+    val serializedRaster = RasterSerializer.serialize(raster)
+    raster.dispose(true)
+
+    // First check if this is the first raster to set dimensions or validate 
against existing dimensions
+    if (buffer.nonEmpty) {
+      val refWidth = buffer.head.width
+      val refHeight = buffer.head.height
+      if (width != refWidth || height != refHeight) {
+        throw new IllegalArgumentException("All rasters must have the same 
dimensions")
       }
+    }
 
-      // Extract data for each band
-      val rasterData = renderedImage.getData
-      val bandsData = Array.ofDim[Double](numBands, width * height)
-      val serializedSampleDimensions = new Array[Array[Byte]](numBands)
+    buffer += BandData(index, width, height, serializedRaster)
+    buffer
+  }
 
-      for (band <- 0 until numBands) {
-        bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new 
Array[Double](width * height))
-        serializedSampleDimensions(band) = 
Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
+  def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): 
ArrayBuffer[BandData] = {
+    if (buffer1.nonEmpty && buffer2.nonEmpty) {
+      if (buffer1.head.width != buffer2.head.width || buffer1.head.height != 
buffer2.head.height) {
+        throw new IllegalArgumentException("All rasters must have the same 
dimensions")
       }
-
-      buffer += BandData(bandsData, input._2, Serde.serialize(raster), 
serializedSampleDimensions)
-      buffer
     }
-
-  def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): 
ArrayBuffer[BandData] = {
     val combined = ArrayBuffer.concat(buffer1, buffer2)
     if (combined.map(_.index).distinct.length != combined.length) {
       throw new IllegalArgumentException("Indexes shouldn't be repeated.")
@@ -95,24 +87,37 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, 
Int), ArrayBuffer[BandDa
       throw new IllegalArgumentException("Index should be in an arithmetic 
sequence.")
     }
 
-    val totalBands = sortedMerged.map(_.bandsData.length).sum
-    val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
-    val width = RasterAccessors.getWidth(referenceRaster)
-    val height = RasterAccessors.getHeight(referenceRaster)
-    val dataTypeCode = 
RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
-    val resultRaster: WritableRaster = 
RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
-    val gridSampleDimensions = 
sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray
-
-    var currentBand = 0
-    sortedMerged.foreach { bandData =>
-      bandData.bandsData.foreach { band =>
-        resultRaster.setSamples(0, 0, width, height, currentBand, band)
-        currentBand += 1
+    val rasters = sortedMerged.map(d => 
RasterSerializer.deserialize(d.serializedRaster))
+    try {
+      val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray
+      val totalBands = rasters.map(_.getNumSampleDimensions).sum
+      val referenceRaster = rasters.head
+      val width = RasterAccessors.getWidth(referenceRaster)
+      val height = RasterAccessors.getHeight(referenceRaster)
+      val dataTypeCode = 
RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
+      val resultRaster: WritableRaster = 
RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
+
+      var currentBand = 0
+      rasters.foreach { raster =>
+        var bandIndex = 0
+        while (bandIndex < raster.getNumSampleDimensions) {
+          if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
+            val band = 
RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, 
bandIndex, new Array[Int](width * height))
+            resultRaster.setSamples(0, 0, width, height, currentBand, band)
+          } else {
+            val band = 
RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, 
bandIndex, new Array[Double](width * height))
+            resultRaster.setSamples(0, 0, width, height, currentBand, band)
+          }
+          currentBand += 1
+          bandIndex += 1
+        }
       }
-    }
 
-    val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
-    RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, 
gridSampleDimensions, referenceRaster, noDataValue, true)
+      val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
+      RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, 
gridSampleDimensions, referenceRaster, noDataValue, false)
+    } finally {
+      rasters.foreach(_.dispose(true))
+    }
   }
 
   val serde = ExpressionEncoder[GridCoverage2D]

Reply via email to