felipepessoto commented on code in PR #11461: URL: https://github.com/apache/gluten/pull/11461#discussion_r3054692961
########## backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/perf/GlutenDeltaOptimizedWriterExec.scala: ########## @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.delta.perf + +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.backendsapi.velox.VeloxBatchType +import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.execution.{ValidatablePlan, ValidationResult} +import org.apache.gluten.extension.columnar.transition.Convention +import org.apache.gluten.vectorized.ColumnarBatchSerializerInstance + +// scalastyle:off import.ordering.noEmptyLine +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.ColumnarShuffleManager +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.BinPackingUtils +import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, ColumnarShuffleExchangeExec, GenerateTransformStageId} +import org.apache.spark.sql.execution.{ShuffledColumnarBatchRDD, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.storage._ +import org.apache.spark.util.ThreadUtils + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration.Duration + +/** Gluten's vectorized version of [[DeltaOptimizedWriterExec]]. */ +case class GlutenDeltaOptimizedWriterExec( + child: SparkPlan, + partitionColumns: Seq[String], + @transient deltaLog: DeltaLog +) extends ValidatablePlan + with UnaryExecNode + with DeltaLogging { + + override def output: Seq[Attribute] = child.output + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics: Map[String, SQLMetric] = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") + ) ++ readMetrics ++ writeMetrics + + private lazy val childNumPartitions = child.executeColumnar().getNumPartitions + + private lazy val numPartitions: Int = { + val targetShuffleBlocks = getConf(DeltaSQLConf.DELTA_OPTIMIZE_WRITE_SHUFFLE_BLOCKS) + math.min( + math.max(targetShuffleBlocks / childNumPartitions, 1), + getConf(DeltaSQLConf.DELTA_OPTIMIZE_WRITE_MAX_SHUFFLE_PARTITIONS)) + } + + @transient private var cachedShuffleRDD: ShuffledColumnarBatchRDD = _ + + @transient private lazy val mapTracker = SparkEnv.get.mapOutputTracker + + private lazy val columnarShufflePlan = { + val resolver = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + val saltedPartitioning = HashPartitioning( + partitionColumns.map( + p => + output + .find(o => resolver(p, o.name)) + .getOrElse(throw DeltaErrors.failedFindPartitionColumnInOutputPlan(p))), + numPartitions) + val shuffle = + ShuffleExchangeExec(saltedPartitioning, child) + val columnarShuffle = + BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(shuffle) + val columnarShuffleWithWst = + GenerateTransformStageId()( + ColumnarCollapseTransformStages(new GlutenConfig(conf))(columnarShuffle)) + columnarShuffleWithWst.asInstanceOf[ColumnarShuffleExchangeExec] + } + + /** Creates a ShuffledRowRDD for facilitating the shuffle in the map side. */ + private def getShuffleRDD: ShuffledColumnarBatchRDD = { + if (cachedShuffleRDD == null) { + val columnarShuffleRdd = + columnarShufflePlan.executeColumnar().asInstanceOf[ShuffledColumnarBatchRDD] + cachedShuffleRDD = columnarShuffleRdd + } + cachedShuffleRDD + } + + private def computeBins(): Array[List[(BlockManagerId, ArrayBuffer[(BlockId, Long, Int)])]] = { + // Get all shuffle information + val shuffleStats = getShuffleStats() + + // Group by blockId instead of block manager + val blockInfo = shuffleStats.flatMap { + case (bmId, blocks) => + blocks.map { + case (blockId, size, index) => + (blockId, (bmId, size, index)) + } + }.toMap + + val maxBinSize = + ByteUnit.BYTE.convertFrom(getConf(DeltaSQLConf.DELTA_OPTIMIZE_WRITE_BIN_SIZE), ByteUnit.MiB) + + val bins = shuffleStats.toSeq + .flatMap(_._2) + .groupBy(_._1.asInstanceOf[ShuffleBlockId].reduceId) + .flatMap { + case (_, blocks) => + BinPackingUtils.binPackBySize[(BlockId, Long, Int), BlockId]( + blocks, + _._2, // size + _._1, // blockId + maxBinSize) + } + + bins + .map { + bin => + var binSize = 0L + val blockLocations = + new mutable.HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long, Int)]]() + for (blockId <- bin) { + val (bmId, size, index) = blockInfo(blockId) + binSize += size + val blocksAtBM = + blockLocations.getOrElseUpdate(bmId, new ArrayBuffer[(BlockId, Long, Int)]()) + blocksAtBM.append((blockId, size, index)) + } + (binSize, blockLocations.toList) + } + .toArray + .sortBy(_._1)(Ordering[Long].reverse) // submit largest blocks first + .map(_._2) + } + + /** Performs the shuffle before the write, so that we can bin-pack output data. */ + private def getShuffleStats(): Array[(BlockManagerId, collection.Seq[(BlockId, Long, Int)])] = { + val dep = getShuffleRDD.dependency + // Gets the shuffle output stats + def getStats() = + mapTracker.getMapSizesByExecutorId(dep.shuffleId, 0, Int.MaxValue, 0, numPartitions).toArray + + // Executes the shuffle map stage in case we are missing output stats + def awaitShuffleMapStage(): Unit = { + assert(dep != null, "Shuffle dependency should not be null") + // hack to materialize the shuffle files in a fault tolerant way + ThreadUtils.awaitResult(sparkContext.submitMapStage(dep), Duration.Inf) + } + + try { + val res = getStats() + if (res.isEmpty) awaitShuffleMapStage() + getStats() + } catch { + case e: FetchFailedException => + logWarning(log"Failed to fetch shuffle blocks for the optimized writer. Retrying", e) + awaitShuffleMapStage() + getStats() + throw e Review Comment: @zhztheplayer is this `throw e` expected? It seems we have a retry, but after it, we rethrow the exception -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
