This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new c23245d78e2  [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB
c23245d78e2 is described below

commit c23245d78e25497ac6e8848ca400a920fed62144
Author:     Ziqi Liu <ziqi....@databricks.com>
AuthorDate: Tue Nov 15 20:54:20 2022 -0600

    [SPARK-40622][SQL][CORE] Remove the limitation that single task result must fit in 2GB

    ### What changes were proposed in this pull request?
    A single task result currently must fit in 2GB, because we use a byte array or a `ByteBuffer` (which is backed by a byte array as well), and are therefore bound by the Java array size limit on `byte[]`.

    This PR fixes that by replacing the byte array with `ChunkedByteBuffer`.

    ### Why are the changes needed?
    To overcome the 2GB limit for a single task result.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Unit test

    Closes #38064 from liuzqt/SPARK-40622.

    Authored-by: Ziqi Liu <ziqi....@databricks.com>
    Signed-off-by: Mridul <mridul<at>gmail.com>
---
 .../scala/org/apache/spark/executor/Executor.scala | 19 ++++---
 .../org/apache/spark/internal/config/package.scala |  2 +
 .../org/apache/spark/scheduler/TaskResult.scala    | 27 ++++++----
 .../apache/spark/scheduler/TaskResultGetter.scala  | 14 ++---
 .../apache/spark/serializer/SerializerHelper.scala | 54 +++++++++++++++++++
 .../main/scala/org/apache/spark/util/Utils.scala   | 45 ++++++++++------
 .../apache/spark/util/io/ChunkedByteBuffer.scala   | 62 ++++++++++++++++++++--
 .../apache/spark/io/ChunkedByteBufferSuite.scala   | 50 +++++++++++++++++
 .../scheduler/SchedulerIntegrationSuite.scala      |  3 +-
 .../spark/scheduler/TaskResultGetterSuite.scala    |  2 +-
 .../spark/scheduler/TaskSchedulerImplSuite.scala   |  8 +--
 .../spark/scheduler/TaskSetManagerSuite.scala      |  2 +-
 .../KryoSerializerResizableOutputSuite.scala       | 16 +++---
 project/SparkBuild.scala                           |  1 +
 .../spark/sql/catalyst/expressions/Cast.scala      |  6 +--
 .../org/apache/spark/sql/execution/SparkPlan.scala | 22 ++++----
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 30 ++++++++++-
 17 files changed, 289 insertions(+), 74 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index db507bd176b..8d8a4592a3e 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -48,10 +48,10 @@ import org.apache.spark.metrics.source.JVMCPUSource
 import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.scheduler._
+import org.apache.spark.serializer.SerializerHelper
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._
-import org.apache.spark.util.io.ChunkedByteBuffer

 /**
  * Spark executor, backed by a threadpool to run tasks.
@@ -172,7 +172,7 @@ private[spark] class Executor(
   env.serializerManager.setDefaultClassLoader(replClassLoader)

   // Max size of direct result. If task result is bigger than this, we use the block manager
-  // to send the result back.
+  // to send the result back.
This is guaranteed to be smaller than array bytes limit (2GB) private val maxDirectResultSize = Math.min( conf.get(TASK_MAX_DIRECT_RESULT_SIZE), RpcUtils.maxMessageSizeBytes(conf)) @@ -596,7 +596,7 @@ private[spark] class Executor( val resultSer = env.serializer.newInstance() val beforeSerializationNs = System.nanoTime() - val valueBytes = resultSer.serialize(value) + val valueByteBuffer = SerializerHelper.serializeToChunkedBuffer(resultSer, value) val afterSerializationNs = System.nanoTime() // Deserialization happens in two parts: first, we deserialize a Task object, which @@ -659,9 +659,11 @@ private[spark] class Executor( val accumUpdates = task.collectAccumulatorUpdates() val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId) // TODO: do not serialize value twice - val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks) - val serializedDirectResult = ser.serialize(directResult) - val resultSize = serializedDirectResult.limit() + val directResult = new DirectTaskResult(valueByteBuffer, accumUpdates, metricPeaks) + // try to estimate a reasonable upper bound of DirectTaskResult serialization + val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult, + valueByteBuffer.size + accumUpdates.size * 32 + metricPeaks.length * 8) + val resultSize = serializedDirectResult.size // directSend = sending directly back to the driver val serializedResult: ByteBuffer = { @@ -674,13 +676,14 @@ private[spark] class Executor( val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, - new ChunkedByteBuffer(serializedDirectResult.duplicate()), + serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) logInfo(s"Finished $taskName. $resultSize bytes result sent via BlockManager)") ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) } else { logInfo(s"Finished $taskName. $resultSize bytes result sent to driver") - serializedDirectResult + // toByteBuffer is safe here, guarded by maxDirectResultSize + serializedDirectResult.toByteBuffer } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 64801712c5f..ad899d7dfd6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -802,6 +802,8 @@ package object config { ConfigBuilder("spark.task.maxDirectResultSize") .version("2.0.0") .bytesConf(ByteUnit.BYTE) + .checkValue(_ < ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toLong, + "The max direct result size is 2GB") .createWithDefault(1L << 20) private[spark] val TASK_MAX_FAILURES = diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 11d969e1aba..e5ab74f544e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -24,20 +24,21 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv import org.apache.spark.metrics.ExecutorMetricType -import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.serializer.{SerializerHelper, SerializerInstance} import org.apache.spark.storage.BlockId import org.apache.spark.util.{AccumulatorV2, Utils} +import org.apache.spark.util.io.ChunkedByteBuffer // Task result. Also contains updates to accumulator variables and executor metric peaks. 
private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ -private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) +private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Long) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value, accumulator updates and metric peaks. */ private[spark] class DirectTaskResult[T]( - var valueBytes: ByteBuffer, + var valueByteBuffer: ChunkedByteBuffer, var accumUpdates: Seq[AccumulatorV2[_, _]], var metricPeaks: Array[Long]) extends TaskResult[T] with Externalizable { @@ -45,12 +46,18 @@ private[spark] class DirectTaskResult[T]( private var valueObjectDeserialized = false private var valueObject: T = _ - def this() = this(null.asInstanceOf[ByteBuffer], null, + def this( + valueByteBuffer: ByteBuffer, + accumUpdates: Seq[AccumulatorV2[_, _]], + metricPeaks: Array[Long]) = { + this(new ChunkedByteBuffer(Array(valueByteBuffer)), accumUpdates, metricPeaks) + } + + def this() = this(null.asInstanceOf[ChunkedByteBuffer], Seq(), new Array[Long](ExecutorMetricType.numMetrics)) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(valueBytes.remaining) - Utils.writeByteBuffer(valueBytes, out) + valueByteBuffer.writeExternal(out) out.writeInt(accumUpdates.size) accumUpdates.foreach(out.writeObject) out.writeInt(metricPeaks.length) @@ -58,10 +65,8 @@ private[spark] class DirectTaskResult[T]( } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - val blen = in.readInt() - val byteVal = new Array[Byte](blen) - in.readFully(byteVal) - valueBytes = ByteBuffer.wrap(byteVal) + valueByteBuffer = new ChunkedByteBuffer() + valueByteBuffer.readExternal(in) val numUpdates = in.readInt if (numUpdates == 0) { @@ -100,7 +105,7 @@ private[spark] class DirectTaskResult[T]( // This should not run when holding a lock because it may cost dozens of seconds for a large // value val ser = if (resultSer == null) SparkEnv.get.serializer.newInstance() else resultSer - valueObject = ser.deserialize(valueBytes) + valueObject = SerializerHelper.deserializeFromChunkedBuffer(ser, valueByteBuffer) valueObjectDeserialized = true valueObject } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index cfc1f79fab2..a4f29395095 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging -import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.serializer.{SerializerHelper, SerializerInstance} import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils} /** @@ -63,7 +63,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { case directResult: DirectTaskResult[_] => - if (!taskSetManager.canFetchMoreResults(directResult.valueBytes.limit())) { + if (!taskSetManager.canFetchMoreResults(directResult.valueByteBuffer.size)) { // kill the task so that it will not become zombie task scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled( "Tasks result size has 
exceeded maxResultSize")) @@ -73,7 +73,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul // We should call it here, so that when it's called again in // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. directResult.value(taskResultSerializer.get()) - (directResult, serializedData.limit()) + (directResult, serializedData.limit().toLong) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { // dropped by executor if size is larger than maxResultSize @@ -94,8 +94,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul taskSetManager, tid, TaskState.FINISHED, TaskResultLost) return } - val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( - serializedTaskResult.get.toByteBuffer) + val deserializedResult = SerializerHelper + .deserializeFromChunkedBuffer[DirectTaskResult[_]]( + serializer.get(), + serializedTaskResult.get) // force deserialization of referenced value deserializedResult.value(taskResultSerializer.get()) sparkEnv.blockManager.master.removeBlock(blockId) @@ -109,7 +111,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (a.name == Some(InternalAccumulator.RESULT_SIZE)) { val acc = a.asInstanceOf[LongAccumulator] assert(acc.sum == 0L, "task result size should not have been set on the executors") - acc.setValue(size.toLong) + acc.setValue(size) acc } else { a diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala new file mode 100644 index 00000000000..2cff87990a4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerHelper.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import org.apache.spark.internal.Logging +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +private[spark] object SerializerHelper extends Logging { + + /** + * + * @param serializerInstance instance of SerializerInstance + * @param objectToSerialize the object to serialize, of type `T` + * @param estimatedSize estimated size of `t`, used as a hint to choose proper chunk size + */ + def serializeToChunkedBuffer[T: ClassTag]( + serializerInstance: SerializerInstance, + objectToSerialize: T, + estimatedSize: Long = -1): ChunkedByteBuffer = { + val chunkSize = ChunkedByteBuffer.estimateBufferChunkSize(estimatedSize) + val cbbos = new ChunkedByteBufferOutputStream(chunkSize, ByteBuffer.allocate) + val out = serializerInstance.serializeStream(cbbos) + out.writeObject(objectToSerialize) + out.close() + cbbos.close() + cbbos.toChunkedByteBuffer + } + + def deserializeFromChunkedBuffer[T: ClassTag]( + serializerInstance: SerializerInstance, + bytes: ChunkedByteBuffer): T = { + val in = serializerInstance.deserializeStream(bytes.toInputStream()) + in.readObject() + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f963727e79f..70477a5c9c0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -111,6 +111,12 @@ private[spark] object Utils extends Logging { private val PATTERN_FOR_COMMAND_LINE_ARG = "-D(.+?)=(.+)".r + private val COPY_BUFFER_LEN = 1024 + + private val copyBuffer = ThreadLocal.withInitial[Array[Byte]](() => { + new Array[Byte](COPY_BUFFER_LEN) + }) + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -237,34 +243,39 @@ private[spark] object Utils extends Logging { } } - /** - * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] - */ - def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + private def writeByteBufferImpl(bb: ByteBuffer, writer: (Array[Byte], Int, Int) => Unit): Unit = { if (bb.hasArray) { - out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + // Avoid extra copy if the bytebuffer is backed by bytes array + writer(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { + // Fallback to copy approach + val buffer = { + // reuse the copy buffer from thread local + copyBuffer.get() + } val originalPosition = bb.position() - val bbval = new Array[Byte](bb.remaining()) - bb.get(bbval) - out.write(bbval) + var bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN) + while (bytesToCopy > 0) { + bb.get(buffer, 0, bytesToCopy) + writer(buffer, 0, bytesToCopy) + bytesToCopy = Math.min(bb.remaining(), COPY_BUFFER_LEN) + } bb.position(originalPosition) } } + /** + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] + */ + def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = { + writeByteBufferImpl(bb, out.write) + } + /** * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]] */ def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = { - if (bb.hasArray) { - out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) - } else { - val originalPosition = bb.position() - val bbval = new Array[Byte](bb.remaining()) - bb.get(bbval) - 
out.write(bbval) - bb.position(originalPosition) - } + writeByteBufferImpl(bb, out.write) } /** diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 8635f1a3d70..73e4e72cc5b 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.io -import java.io.{File, FileInputStream, InputStream} +import java.io.{Externalizable, File, FileInputStream, InputStream, ObjectInput, ObjectOutput} import java.nio.ByteBuffer import java.nio.channels.WritableByteChannel @@ -42,8 +42,9 @@ import org.apache.spark.util.Utils * buffers may also be used elsewhere then the caller is responsible for copying * them as needed. */ -private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { +private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) extends Externalizable { require(chunks != null, "chunks must not be null") + require(!chunks.contains(null), "chunks must not contain null") require(chunks.forall(_.position() == 0), "chunks' positions must be 0") // Chunk size in bytes @@ -54,9 +55,16 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { private[this] var disposed: Boolean = false /** - * This size of this buffer, in bytes. + * This size of this buffer, in bytes. Using var here for serialization purpose (need to set a + * object after default construction) */ - val size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum + private var _size: Long = chunks.map(_.limit().asInstanceOf[Long]).sum + + def size: Long = _size + + def this() = { + this(Array.empty[ByteBuffer]) + } def this(byteBuffer: ByteBuffer) = { this(Array(byteBuffer)) @@ -84,6 +92,38 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } } + /** + * Writes to the provided ObjectOutput with zero copy if possible. + */ + override def writeExternal(out: ObjectOutput): Unit = { + // We want to keep the chunks layout + out.writeInt(chunks.length) + val chunksCopy = getChunks() + chunksCopy.foreach(buffer => out.writeInt(buffer.limit())) + chunksCopy.foreach(Utils.writeByteBuffer(_, out)) + } + + override def readExternal(in: ObjectInput): Unit = { + val chunksNum = in.readInt() + val indices = 0 until chunksNum + val chunksSize = indices.map(_ => in.readInt()) + val chunks = new Array[ByteBuffer](chunksNum) + + // We deserialize all chunks into on-heap buffer by default. If we have use case in the future + // where we want to preserve the on-heap/off-heap nature of chunks, then we need to record the + // `isDirect` property of each chunk during serialization + indices.foreach { i => + val chunkSize = chunksSize(i) + chunks(i) = { + val arr = new Array[Byte](chunkSize) + in.readFully(arr, 0, chunkSize) + ByteBuffer.wrap(arr) + } + } + this.chunks = chunks + this._size = chunks.map(_.limit().toLong).sum + } + /** * Wrap this in a custom "FileRegion" which allows us to transfer over 2 GB. 
*/ @@ -171,6 +211,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } private[spark] object ChunkedByteBuffer { + private val CHUNK_BUFFER_SIZE: Int = 1024 * 1024 + private val MINIMUM_CHUNK_BUFFER_SIZE: Int = 1024 def fromManagedBuffer(data: ManagedBuffer): ChunkedByteBuffer = { data match { @@ -207,6 +249,18 @@ private[spark] object ChunkedByteBuffer { } out.toChunkedByteBuffer } + + /** + * Try to estimate appropriate chunk size so that it's not too large (waste memory) or too + * small (too many segments) + */ + def estimateBufferChunkSize(estimatedSize: Long = -1): Int = { + if (estimatedSize < 0) { + CHUNK_BUFFER_SIZE + } else { + Math.max(Math.min(estimatedSize, CHUNK_BUFFER_SIZE).toInt, MINIMUM_CHUNK_BUFFER_SIZE) + } + } } /** diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 083c5e696b7..68b181de292 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.io +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} import java.nio.ByteBuffer import com.google.common.io.ByteStreams @@ -28,6 +29,18 @@ import org.apache.spark.util.io.ChunkedByteBuffer class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { + /** + * compare two ChunkedByteBuffer: + * - chunks nums equal + * - each chunk's content + */ + def assertBufferEqual(buffer1: ChunkedByteBuffer, buffer2: ChunkedByteBuffer): Unit = { + assert(buffer1.chunks.length == buffer2.chunks.length) + assert(buffer1.chunks.zip(buffer2.chunks).forall { + case (chunk1, chunk2) => chunk1 == chunk2 + }) + } + test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) assert(emptyChunkedByteBuffer.size === 0) @@ -69,6 +82,43 @@ class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { } } + test("Externalizable: writeExternal() and readExternal()") { + // intentionally generate arrays of different len, in order to verify the chunks layout + // is preserved after ser/deser + val byteArrays = (1 to 15).map(i => (0 until i).map(_.toByte).toArray) + val chunkedByteBuffer = new ChunkedByteBuffer(byteArrays.map(ByteBuffer.wrap).toArray) + val baos = new ByteArrayOutputStream() + val objOut = new ObjectOutputStream(baos) + chunkedByteBuffer.writeExternal(objOut) + objOut.close() + assert(chunkedByteBuffer.chunks.forall(_.position() == 0)) + + val chunkedByteBuffer2 = { + val tmp = new ChunkedByteBuffer + tmp.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray))) + tmp + } + assertBufferEqual(chunkedByteBuffer, chunkedByteBuffer2) + } + + test( + "Externalizable: writeExternal() and readExternal() should handle off-heap buffer properly") { + val chunkedByteBuffer = new ChunkedByteBuffer( + (0 until 10).map(_ => ByteBuffer.allocateDirect(10)).toArray) + val baos = new ByteArrayOutputStream() + val objOut = new ObjectOutputStream(baos) + chunkedByteBuffer.writeExternal(objOut) + objOut.close() + + val chunkedByteBuffer2 = { + val tmp = new ChunkedByteBuffer + tmp.readExternal(new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray))) + tmp + } + + assertBufferEqual(chunkedByteBuffer, chunkedByteBuffer2) + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = 
ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 9ed26e71256..dac675fd738 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -321,7 +321,8 @@ private[spark] abstract class MockBackend( def taskSuccess(task: TaskDescription, result: Any): Unit = { val ser = env.serializer.newInstance() val resultBytes = ser.serialize(result) - val directResult = new DirectTaskResult(resultBytes, Seq(), Array()) // no accumulator updates + // no accumulator updates + val directResult = new DirectTaskResult(resultBytes, Seq(), Array[Long]()) taskUpdate(task, TaskState.FINISHED, directResult) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 1583d3b96ee..1f61fab3e07 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -153,7 +153,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local override def canFetchMoreResults(size: Long): Boolean = false } val indirectTaskResult = IndirectTaskResult(TaskResultBlockId(0), 0) - val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array()) + val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array[Long]()) val ser = sc.env.closureSerializer.newInstance() val serializedIndirect = ser.serialize(indirectTaskResult) val serializedDirect = ser.serialize(directTaskResult) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 4e9e9755e85..b81f85bd1d7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -761,11 +761,13 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext } // End the other task of the taskset, doesn't matter whether it succeeds or fails. 
val otherTask = tasks(1) - val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq(), Array()) + val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq(), + Array[Long]()) tsm.handleSuccessfulTask(otherTask.taskId, result) } else { tasks.foreach { task => - val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq(), Array()) + val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq(), + Array[Long]()) tsm.handleSuccessfulTask(task.taskId, result) } } @@ -2131,7 +2133,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext assert(2 === taskDescriptions.length) val ser = sc.env.serializer.newInstance() - val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty) + val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty[Long]) val resultBytes = ser.serialize(directResult) val busyTask = new Runnable { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 32a43b093ee..2dc7f0d0dfa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -882,7 +882,7 @@ class TaskSetManagerSuite assert(manager.runningTasks === 2) assert(manager.isZombie === false) - val directTaskResult = new DirectTaskResult[String](null, Seq(), Array()) { + val directTaskResult = new DirectTaskResult[String]() { override def value(resultSer: SerializerInstance): String = "" } // Complete one copy of the task, which should result in the task set manager diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 25f0b19c980..41c1131a280 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.LocalSparkContext._ -import org.apache.spark.SparkContext import org.apache.spark.SparkException import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Kryo._ @@ -34,9 +32,10 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") conf.set(KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") conf.set(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "1m") - withSpark(new SparkContext("local", "test", conf)) { sc => - intercept[SparkException](sc.parallelize(x).collect()) - } + + val ser = new KryoSerializer(conf) + val serInstance = ser.newInstance() + intercept[SparkException](serInstance.serialize(x)) } test("kryo with resizable output buffer should succeed on large array") { @@ -44,8 +43,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite { conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") conf.set(KRYO_SERIALIZER_BUFFER_SIZE.key, "1m") conf.set(KRYO_SERIALIZER_MAX_BUFFER_SIZE.key, "2m") - withSpark(new SparkContext("local", "test", conf)) { sc => - assert(sc.parallelize(x).collect() === x) - } + + val ser = new KryoSerializer(conf) + val serInstance = ser.newInstance() + serInstance.serialize(x) } } diff --git a/project/SparkBuild.scala 
b/project/SparkBuild.scala index 18667d1efea..a63f52e5430 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1175,6 +1175,7 @@ object Unidoc { !f.getCanonicalPath.contains("org/apache/spark/unsafe/types/CalendarInterval"))) .map(_.filterNot(_.getCanonicalPath.contains("python"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/io"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect"))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 549bc70bac7..a302298d99c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -46,9 +46,9 @@ object Cast extends QueryErrorsBase { * As per section 6.13 "cast specification" in "Information technology — Database languages " + * "- SQL — Part 2: Foundation (SQL/Foundation)": * If the <cast operand> is a <value expression>, then the valid combinations of TD and SD - * in a <cast specification> are given by the following table. “Y” indicates that the - * combination is syntactically valid without restriction; “M” indicates that the combination - * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and “N” indicates + * in a <cast specification> are given by the following table. "Y" indicates that the + * combination is syntactically valid without restriction; "M" indicates that the combination + * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and "N" indicates * that the combination is not valid: * SD TD * EN AN C D T TS YM DT BO UDT B RT CT RW diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a56732fdc12..4aca67a17cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{DataInputStream, DataOutputStream} +import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, ListBuffer} @@ -38,6 +39,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} object SparkPlan { /** The original [[LogicalPlan]] from which this [[SparkPlan]] is converted. */ @@ -336,13 +338,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * compressed. 
*/ private def getByteArrayRdd( - n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, Array[Byte])] = { + n: Int = -1, takeFromEnd: Boolean = false): RDD[(Long, ChunkedByteBuffer)] = { execute().mapPartitionsInternal { iter => var count = 0 val buffer = new Array[Byte](4 << 10) // 4K val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val bos = new ByteArrayOutputStream() - val out = new DataOutputStream(codec.compressedOutputStream(bos)) + val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) + val out = new DataOutputStream(codec.compressedOutputStream(cbbos)) if (takeFromEnd && n > 0) { // To collect n from the last, we should anyway read everything with keeping the n. @@ -371,19 +373,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ out.writeInt(-1) out.flush() out.close() - Iterator((count, bos.toByteArray)) + Iterator((count, cbbos.toChunkedByteBuffer)) } } /** * Decodes the byte arrays back to UnsafeRows and put them into buffer. */ - private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = { + private def decodeUnsafeRows(bytes: ChunkedByteBuffer): Iterator[InternalRow] = { val nFields = schema.length val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val bis = new ByteArrayInputStream(bytes) - val ins = new DataInputStream(codec.compressedInputStream(bis)) + val cbbis = bytes.toInputStream() + val ins = new DataInputStream(codec.compressedInputStream(cbbis)) new NextIterator[InternalRow] { private var sizeOfNextRow = ins.readInt() @@ -503,8 +505,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ parts } val sc = sparkContext - val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) => - if (it.hasNext) it.next() else (0L, Array.emptyByteArray), partsToScan) + val res = sc.runJob(childRDD, (it: Iterator[(Long, ChunkedByteBuffer)]) => + if (it.hasNext) it.next() else (0L, new ChunkedByteBuffer()), partsToScan) var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8f5740e65ed..370e5ca546b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -20,13 +20,16 @@ package org.apache.spark.sql import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} +import scala.util.Random + import org.apache.hadoop.fs.{Path, PathFilter} import org.scalatest.Assertions._ import org.scalatest.exceptions.TestFailedException import org.scalatest.prop.TableDrivenPropertyChecks._ -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.{SparkConf, SparkException, TaskContext} import org.apache.spark.TestUtils.withListener +import org.apache.spark.internal.config.MAX_RESULT_SIZE import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} @@ -2228,6 +2231,31 @@ class DatasetSuite extends QueryTest } } +class DatasetLargeResultCollectingSuite extends QueryTest + with SharedSparkSession { + + override protected def sparkConf: SparkConf = super.sparkConf.set(MAX_RESULT_SIZE.key, "4g") + test("collect data with single partition larger than 2GB bytes array limit") { + // This test requires large memory and leads to OOM in Github Action so we skip 
it. Developer
+    // should verify it in local build.
+    assume(!sys.env.contains("GITHUB_ACTIONS"))
+    import org.apache.spark.sql.functions.udf
+
+    val genData = udf((id: Long, bytesSize: Int) => {
+      val rand = new Random(id)
+      val arr = new Array[Byte](bytesSize)
+      rand.nextBytes(arr)
+      arr
+    })
+
+    spark.udf.register("genData", genData.asNondeterministic())
+    // create data of size >2GB in single partition, which exceeds the byte array limit
+    // random gen to make sure it's poorly compressed
+    val df = spark.range(0, 2100, 1, 1).selectExpr("id", s"genData(id, 1000000) as data")
+    val res = df.queryExecution.executedPlan.executeCollect()
+  }
+}
+
 case class Bar(a: Int)

 object AssertExecutionId {
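The core idea of the change above is that the serialized task result is streamed into a sequence of fixed-size chunks instead of one contiguous `byte[]`, so the total payload is no longer capped by the ~2GB Java array limit. The sketch below is a simplified, hypothetical illustration of that chunking idea using only JDK streams and Java serialization; `ChunkedOutputStream` and `ChunkedRoundTrip` are invented names for this sketch, and Spark's actual `ChunkedByteBuffer`/`SerializerHelper` path additionally handles off-heap buffers, chunk-size estimation and `Externalizable` support.

```scala
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, OutputStream, SequenceInputStream}

import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

// Minimal chunked output stream: bytes are spread over fixed-size Array[Byte]
// chunks, so no single array ever has to hold the whole (possibly > 2GB) payload.
// A production version would also override write(Array[Byte], Int, Int) for speed.
class ChunkedOutputStream(chunkSize: Int) extends OutputStream {
  private val chunks = ArrayBuffer.empty[Array[Byte]]
  private var current: Array[Byte] = _
  private var pos = 0

  override def write(b: Int): Unit = {
    if (current == null || pos == chunkSize) {
      current = new Array[Byte](chunkSize)
      chunks += current
      pos = 0
    }
    current(pos) = b.toByte
    pos += 1
  }

  // Return the chunks written so far, trimming the partially filled last chunk.
  def toChunks: Seq[Array[Byte]] = {
    if (current == null) chunks.toSeq
    else chunks.init.toSeq :+ current.take(pos)
  }
}

object ChunkedRoundTrip {
  def main(args: Array[String]): Unit = {
    val value = Seq.fill(1000)("some task result row")

    // Serialize through the chunked stream instead of a ByteArrayOutputStream.
    val cos = new ChunkedOutputStream(chunkSize = 1024 * 1024)
    val out = new ObjectOutputStream(cos)
    out.writeObject(value)
    out.close()

    val chunks = cos.toChunks
    val totalSize = chunks.map(_.length.toLong).sum // tracked as Long, may exceed Int.MaxValue
    println(s"serialized into ${chunks.size} chunk(s), $totalSize bytes")

    // Deserialize by stitching the chunks back into one InputStream, no big array needed.
    val in = new ObjectInputStream(
      new SequenceInputStream(
        chunks.iterator.map(new ByteArrayInputStream(_)).asJavaEnumeration))
    val restored = in.readObject().asInstanceOf[Seq[String]]
    assert(restored == value)
  }
}
```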
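The `Utils.writeByteBufferImpl` change in this commit is also worth calling out: when a `ByteBuffer` is not array-backed (e.g. a direct buffer), it is drained through a small thread-local scratch array instead of allocating a `remaining()`-sized copy, and the buffer's position is restored afterwards. Below is a self-contained sketch of that pattern against a plain `OutputStream`; the buffer size and control flow mirror the diff, but the object itself is illustrative and not Spark code.

```scala
import java.io.{ByteArrayOutputStream, OutputStream}
import java.nio.ByteBuffer

object WriteByteBufferSketch {
  private val BufferLen = 1024

  // One scratch array per thread: concurrent writers never share state, and we
  // avoid allocating a remaining()-sized temporary array on every call.
  private val copyBuffer =
    ThreadLocal.withInitial[Array[Byte]](() => new Array[Byte](BufferLen))

  def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
    if (bb.hasArray) {
      // Heap buffer: write the backing array directly, no copy needed.
      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
    } else {
      // Direct (off-heap) buffer: drain it through the small scratch array.
      val buffer = copyBuffer.get()
      val originalPosition = bb.position()
      var bytesToCopy = math.min(bb.remaining(), BufferLen)
      while (bytesToCopy > 0) {
        bb.get(buffer, 0, bytesToCopy)
        out.write(buffer, 0, bytesToCopy)
        bytesToCopy = math.min(bb.remaining(), BufferLen)
      }
      // Leave the buffer as we found it for the caller.
      bb.position(originalPosition)
    }
  }

  def main(args: Array[String]): Unit = {
    val direct = ByteBuffer.allocateDirect(4096)
    (0 until 4096).foreach(i => direct.put(i.toByte))
    direct.flip()

    val out = new ByteArrayOutputStream()
    writeByteBuffer(direct, out)
    assert(out.size() == 4096 && direct.position() == 0)
  }
}
```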