This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch raster-file-output in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 17e774c8d72ae045c33e4c6dcc29b0802d75ff9b Author: Jia Yu <[email protected]> AuthorDate: Fri Feb 9 14:12:24 2024 -0800 Fix --- .../spark/sql/sedona_sql/io/raster/RasterFileFormat.scala | 2 +- .../test/scala/org/apache/sedona/sql/TestBaseScala.scala | 13 +++++++++++++ .../src/test/scala/org/apache/sedona/sql/rasterIOTest.scala | 11 +++++++++-- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala index abf11c9ed..d7851b11d 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/raster/RasterFileFormat.scala @@ -104,7 +104,7 @@ private class RasterFileWriter(savePath: String, val rasterFilePath = getRasterFilePath(row, dataSchema, rasterOptions) // write the image to file try { - val out = hfs.create(new Path(Paths.get(savePath, new Path(rasterFilePath).getName).toString)) + val out = hfs.create(new Path(savePath, new Path(rasterFilePath).getName)) out.write(rasterRaw) out.close() } catch { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 8dd4f743f..fec235696 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -19,6 +19,8 @@ package org.apache.sedona.sql import com.google.common.math.DoubleMath +import org.apache.hadoop.fs.FileUtil +import org.apache.hadoop.hdfs.{HdfsConfiguration, MiniDFSCluster} import org.apache.log4j.{Level, Logger} import org.apache.sedona.common.Functions.{frechetDistance, hausdorffDistance} import org.apache.sedona.common.Predicates.dWithin @@ -28,6 +30,8 @@ import org.apache.spark.sql.DataFrame import org.locationtech.jts.geom._ import org.scalatest.{BeforeAndAfterAll, FunSpec} +import java.io.File + trait TestBaseScala extends FunSpec with BeforeAndAfterAll { Logger.getRootLogger.setLevel(Level.WARN) Logger.getLogger("org.apache").setLevel(Level.WARN) @@ -74,10 +78,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { val buildingDataLocation: String = resourceFolder + "813_buildings_test.csv" val smallRasterDataLocation: String = resourceFolder + "raster/test1.tiff" private val factory = new GeometryFactory() + var hdfsURI: String = _ override def beforeAll(): Unit = { SedonaContext.create(sparkSession) + // Set up HDFS minicluster + val baseDir = new File("./target/hdfs/").getAbsoluteFile + FileUtil.fullyDelete(baseDir) + val hdfsConf = new HdfsConfiguration + hdfsConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath) + val builder = new MiniDFSCluster.Builder(hdfsConf) + val hdfsCluster = builder.build + hdfsURI = "hdfs://127.0.0.1:" + hdfsCluster.getNameNodePort + "/" } override def afterAll(): Unit = { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala index 79d1c6dee..d5203e6a0 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala @@ -19,9 +19,7 @@ package org.apache.sedona.sql import org.apache.commons.io.FileUtils -import org.apache.sedona.common.raster.RasterAccessors import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.sedona_sql.expressions.raster.RS_Metadata import org.junit.Assert.assertEquals import org.scalatest.{BeforeAndAfter, GivenWhenThen} @@ -149,6 +147,15 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen rasterDf = df.selectExpr("RS_FromArcInfoAsciiGrid(content)") assert(rasterDf.count() == rasterCount) } + + it("should read geotiff using binary source and write geotiff back to hdfs using raster source") { + var rasterDf = sparkSession.read.format("binaryFile").load(rasterdatalocation) + val rasterCount = rasterDf.count() + rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(hdfsURI + "/raster-written") + rasterDf = sparkSession.read.format("binaryFile").load(hdfsURI + "/raster-written/*") + rasterDf = rasterDf.selectExpr("RS_FromGeoTiff(content)") + assert(rasterDf.count() == rasterCount) + } } override def afterAll(): Unit = FileUtils.deleteDirectory(new File(tempDir))
