This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 56a3d0ff7 [SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when
working with non-double band data (#1402)
56a3d0ff7 is described below
commit 56a3d0ff7f29a01d4ccc3e2b359f0518c7325de7
Author: Jia Yu <[email protected]>
AuthorDate: Tue May 7 19:43:19 2024 -0700
[SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when working with
non-double band data (#1402)
Co-authored-by: Kristin Cowalcijk <[email protected]>
---
.../apache/sedona/common/raster/MapAlgebra.java | 10 +-
.../sedona/common/raster/RasterBandEditors.java | 8 +-
.../apache/sedona/common/raster/serde/Serde.java | 20 ----
.../apache/sedona/common/utils/RasterUtils.java | 28 +++---
.../expressions/raster/AggregateFunctions.scala | 105 +++++++++++----------
5 files changed, 79 insertions(+), 92 deletions(-)
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
index 564077a36..b5fb469cd 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
@@ -78,12 +78,11 @@ public class MapAlgebra
throw new IllegalArgumentException("Band index is out of bounds.
Must be between 1 and " + (numBands + 1) + ")");
}
- Double[] bandValuesClass =
Arrays.stream(bandValues).boxed().toArray(Double[]::new);
if (bandIndex == numBands + 1) {
- return RasterUtils.copyRasterAndAppendBand(rasterGeom,
bandValuesClass, noDataValue);
+ return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues,
noDataValue);
}
else {
- return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex,
bandValuesClass, noDataValue, true);
+ return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex,
bandValues, noDataValue, true);
}
}
@@ -94,12 +93,11 @@ public class MapAlgebra
throw new IllegalArgumentException("Band index is out of bounds.
Must be between 1 and " + (numBands + 1) + ")");
}
- Double[] bandValuesClass =
Arrays.stream(bandValues).boxed().toArray(Double[]::new);
if (bandIndex == numBands + 1) {
- return RasterUtils.copyRasterAndAppendBand(rasterGeom,
bandValuesClass);
+ return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues);
}
else {
- return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex,
bandValuesClass);
+ return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex,
bandValues);
}
}
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
index 4ffac4d30..41b159a5d 100644
---
a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
+++
b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
@@ -135,16 +135,16 @@ public class RasterBandEditors {
if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
int[] bandValues = rasterData.getSamples(0, 0, width, height,
fromBand - 1, (int[]) null);
if (numBands + 1 == toRasterIndex) {
- return RasterUtils.copyRasterAndAppendBand(toRaster,
Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue);
+ return RasterUtils.copyRasterAndAppendBand(toRaster,
bandValues, noDataValue);
} else {
- return RasterUtils.copyRasterAndReplaceBand(toRaster,
fromBand, Arrays.stream(bandValues).boxed().toArray(Integer[]::new),
noDataValue, false);
+ return RasterUtils.copyRasterAndReplaceBand(toRaster,
fromBand, bandValues, noDataValue, false);
}
} else {
double[] bandValues = rasterData.getSamples(0, 0, width, height,
fromBand - 1, (double[]) null);
if (numBands + 1 == toRasterIndex) {
- return RasterUtils.copyRasterAndAppendBand(toRaster,
Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue);
+ return RasterUtils.copyRasterAndAppendBand(toRaster,
bandValues, noDataValue);
} else {
- return RasterUtils.copyRasterAndReplaceBand(toRaster,
fromBand, Arrays.stream(bandValues).boxed().toArray(Double[]::new),
noDataValue, false);
+ return RasterUtils.copyRasterAndReplaceBand(toRaster,
fromBand, bandValues, noDataValue, false);
}
}
}
diff --git
a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
index e775ee46a..7f67708b6 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -32,8 +32,6 @@ import org.geotools.api.referencing.operation.MathTransform;
import javax.media.jai.RenderedImageAdapter;
import java.awt.image.RenderedImage;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
@@ -178,22 +176,4 @@ public class Serde {
return state.restore();
}
}
-
- public static byte[] serializeGridSampleDimension(GridSampleDimension
sampleDimension) {
- Kryo kryo = kryos.get();
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- Output output = new Output(baos);
- GridSampleDimensionSerializer serializer = new
GridSampleDimensionSerializer();
- serializer.write(kryo, output, sampleDimension);
- output.close();
- return baos.toByteArray();
- }
-
- public static GridSampleDimension deserializeGridSampleDimension(byte[]
data) {
- Kryo kryo = kryos.get();
- Input input = new Input(new ByteArrayInputStream(data));
- GridSampleDimensionSerializer serializer = new
GridSampleDimensionSerializer();
- return serializer.read(kryo, input, GridSampleDimension.class);
- }
-
}
diff --git
a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
index 7cc6c5080..ce6c3b604 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
@@ -556,7 +556,7 @@ public class RasterUtils {
* @param bandValues
* @return
*/
- public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D
gridCoverage2D, Number[] bandValues, Double noDataValue) {
+ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D
gridCoverage2D, Object bandValues, Double noDataValue) {
// Get the original image and its properties
RenderedImage originalImage = gridCoverage2D.getRenderedImage();
Raster raster = getRaster(originalImage);
@@ -565,17 +565,19 @@ public class RasterUtils {
// Copy the raster data and append the new band values
for (int i = 0; i < raster.getWidth(); i++) {
for (int j = 0; j < raster.getHeight(); j++) {
- if (bandValues instanceof Double[]) {
+ if (bandValues instanceof double[]) {
+ double[] values = (double[]) bandValues;
double[] pixels = raster.getPixel(i, j, (double[]) null);
double[] copiedPixels = new double[pixels.length + 1];
System.arraycopy(pixels, 0, copiedPixels, 0,
pixels.length);
- copiedPixels[pixels.length] = (double) bandValues[j *
raster.getWidth() + i];
+ copiedPixels[pixels.length] = values[j * raster.getWidth()
+ i];
wr.setPixel(i, j, copiedPixels);
- } else if (bandValues instanceof Integer[]) {
+ } else if (bandValues instanceof int[]) {
+ int[] values = (int[]) bandValues;
int[] pixels = raster.getPixel(i, j, (int[]) null);
int[] copiedPixels = new int[pixels.length + 1];
System.arraycopy(pixels, 0, copiedPixels, 0,
pixels.length);
- copiedPixels[pixels.length] = (int) bandValues[j *
raster.getWidth() + i];
+ copiedPixels[pixels.length] = values[j * raster.getWidth()
+ i];
wr.setPixel(i, j, copiedPixels);
}
}
@@ -594,11 +596,11 @@ public class RasterUtils {
return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions,
gridCoverage2D, null, true);
}
- public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D
gridCoverage2D, Number[] bandValues) {
+ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D
gridCoverage2D, Object bandValues) {
return copyRasterAndAppendBand(gridCoverage2D, bandValues, null);
}
- public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D
gridCoverage2D, int bandIndex, Number[] bandValues, Double noDataValue, boolean
removeNoDataIfNull) {
+ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D
gridCoverage2D, int bandIndex, Object bandValues, Double noDataValue, boolean
removeNoDataIfNull) {
// Do not allow the band index to be out of bounds
ensureBand(gridCoverage2D, bandIndex);
// Get the original image and its properties
@@ -608,13 +610,15 @@ public class RasterUtils {
// Copy the raster data and replace the band values
for (int i = 0; i < raster.getWidth(); i++) {
for (int j = 0; j < raster.getHeight(); j++) {
- if (bandValues instanceof Double[]) {
+ if (bandValues instanceof double[]) {
+ double[] values = (double[]) bandValues;
double[] bands = raster.getPixel(i, j, (double[]) null);
- bands[bandIndex - 1] = (double) bandValues[j *
raster.getWidth() + i];
+ bands[bandIndex - 1] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, bands);
- } else if (bandValues instanceof Integer[]) {
+ } else if (bandValues instanceof int[]) {
+ int[] values = (int[]) bandValues;
int[] bands = raster.getPixel(i, j, (int[]) null);
- bands[bandIndex - 1] = (int) bandValues[j *
raster.getWidth() + i];
+ bands[bandIndex - 1] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, bands);
}
}
@@ -629,7 +633,7 @@ public class RasterUtils {
return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions,
gridCoverage2D, null, true);
}
- public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D
gridCoverage2D, int bandIndex, Number[] bandValues) {
+ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D
gridCoverage2D, int bandIndex, Object bandValues) {
return copyRasterAndReplaceBand(gridCoverage2D, bandIndex, bandValues,
null, false);
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index 3bf132629..d8a6f4be4 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -19,13 +19,12 @@
package org.apache.spark.sql.sedona_sql.expressions.raster
-import org.apache.sedona.common.raster.serde.Serde
import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
+import org.apache.sedona.sql.utils.RasterSerializer
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
-import org.geotools.coverage.GridSampleDimension
import org.geotools.coverage.grid.GridCoverage2D
import java.awt.image.WritableRaster
@@ -33,12 +32,10 @@ import javax.media.jai.RasterFactory
import scala.collection.mutable.ArrayBuffer
case class BandData(
- var bandsData: Array[Array[Double]],
- var index: Int,
- var serializedRaster: Array[Byte],
- var serializedSampleDimensions: Array[Array[Byte]]
- )
-
+ index: Int,
+ width: Int,
+ height: Int,
+ serializedRaster: Array[Byte])
/**
* Return a raster containing bands at given indexes from all rasters in a
given column
@@ -48,37 +45,32 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D,
Int), ArrayBuffer[BandDa
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()
def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)):
ArrayBuffer[BandData] = {
- val raster = input._1
- val renderedImage = raster.getRenderedImage
- val numBands = renderedImage.getSampleModel.getNumBands
- val width = renderedImage.getWidth
- val height = renderedImage.getHeight
-
- // First check if this is the first raster to set dimensions or validate
against existing dimensions
- if (buffer.nonEmpty) {
- val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
- val refWidth = RasterAccessors.getWidth(referenceRaster)
- val refHeight = RasterAccessors.getHeight(referenceRaster)
- if (width != refWidth || height != refHeight) {
- throw new IllegalArgumentException("All rasters must have the same
dimensions")
- }
+ val (raster, index) = input
+ val renderedImage = raster.getRenderedImage
+ val width = renderedImage.getWidth
+ val height = renderedImage.getHeight
+ val serializedRaster = RasterSerializer.serialize(raster)
+ raster.dispose(true)
+
+ // First check if this is the first raster to set dimensions or validate
against existing dimensions
+ if (buffer.nonEmpty) {
+ val refWidth = buffer.head.width
+ val refHeight = buffer.head.height
+ if (width != refWidth || height != refHeight) {
+ throw new IllegalArgumentException("All rasters must have the same
dimensions")
}
+ }
- // Extract data for each band
- val rasterData = renderedImage.getData
- val bandsData = Array.ofDim[Double](numBands, width * height)
- val serializedSampleDimensions = new Array[Array[Byte]](numBands)
+ buffer += BandData(index, width, height, serializedRaster)
+ buffer
+ }
- for (band <- 0 until numBands) {
- bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new
Array[Double](width * height))
- serializedSampleDimensions(band) =
Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
+ def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]):
ArrayBuffer[BandData] = {
+ if (buffer1.nonEmpty && buffer2.nonEmpty) {
+ if (buffer1.head.width != buffer2.head.width || buffer1.head.height !=
buffer2.head.height) {
+ throw new IllegalArgumentException("All rasters must have the same
dimensions")
}
-
- buffer += BandData(bandsData, input._2, Serde.serialize(raster),
serializedSampleDimensions)
- buffer
}
-
- def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]):
ArrayBuffer[BandData] = {
val combined = ArrayBuffer.concat(buffer1, buffer2)
if (combined.map(_.index).distinct.length != combined.length) {
throw new IllegalArgumentException("Indexes shouldn't be repeated.")
@@ -95,24 +87,37 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D,
Int), ArrayBuffer[BandDa
throw new IllegalArgumentException("Index should be in an arithmetic
sequence.")
}
- val totalBands = sortedMerged.map(_.bandsData.length).sum
- val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
- val width = RasterAccessors.getWidth(referenceRaster)
- val height = RasterAccessors.getHeight(referenceRaster)
- val dataTypeCode =
RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
- val resultRaster: WritableRaster =
RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
- val gridSampleDimensions =
sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray
-
- var currentBand = 0
- sortedMerged.foreach { bandData =>
- bandData.bandsData.foreach { band =>
- resultRaster.setSamples(0, 0, width, height, currentBand, band)
- currentBand += 1
+ val rasters = sortedMerged.map(d =>
RasterSerializer.deserialize(d.serializedRaster))
+ try {
+ val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray
+ val totalBands = rasters.map(_.getNumSampleDimensions).sum
+ val referenceRaster = rasters.head
+ val width = RasterAccessors.getWidth(referenceRaster)
+ val height = RasterAccessors.getHeight(referenceRaster)
+ val dataTypeCode =
RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
+ val resultRaster: WritableRaster =
RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
+
+ var currentBand = 0
+ rasters.foreach { raster =>
+ var bandIndex = 0
+ while (bandIndex < raster.getNumSampleDimensions) {
+ if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
+ val band =
RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height,
bandIndex, new Array[Int](width * height))
+ resultRaster.setSamples(0, 0, width, height, currentBand, band)
+ } else {
+ val band =
RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height,
bandIndex, new Array[Double](width * height))
+ resultRaster.setSamples(0, 0, width, height, currentBand, band)
+ }
+ currentBand += 1
+ bandIndex += 1
+ }
}
- }
- val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
- RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry,
gridSampleDimensions, referenceRaster, noDataValue, true)
+ val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
+ RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry,
gridSampleDimensions, referenceRaster, noDataValue, false)
+ } finally {
+ rasters.foreach(_.dispose(true))
+ }
}
val serde = ExpressionEncoder[GridCoverage2D]