Github user lianhuiwang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1499#discussion_r15352745
  
    --- Diff: 
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala ---
    @@ -0,0 +1,156 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.shuffle.sort
    +
    +import java.io.{BufferedOutputStream, File, FileOutputStream, 
DataOutputStream}
    +
    +import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
    +import org.apache.spark.executor.ShuffleWriteMetrics
    +import org.apache.spark.scheduler.MapStatus
    +import org.apache.spark.serializer.Serializer
    +import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
    +import org.apache.spark.storage.ShuffleBlockId
    +import org.apache.spark.util.collection.ExternalSorter
    +
    +private[spark] class SortShuffleWriter[K, V, C](
    +    handle: BaseShuffleHandle[K, V, C],
    +    mapId: Int,
    +    context: TaskContext)
    +  extends ShuffleWriter[K, V] with Logging {
    +
    +  private val dep = handle.dependency
    +  private val numPartitions = dep.partitioner.numPartitions
    +
    +  private val blockManager = SparkEnv.get.blockManager
    +  private val ser = 
Serializer.getSerializer(dep.serializer.getOrElse(null))
    +
    +  private val conf = SparkEnv.get.conf
    +  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 
100) * 1024
    +
    +  private var sorter: ExternalSorter[K, V, _] = null
    +  private var outputFile: File = null
    +
    +  private var stopping = false
    +  private var mapStatus: MapStatus = null
    +
    +  /** Write a bunch of records to this task's output */
    +  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    +    val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
    +      if (dep.mapSideCombine) {
    +        if (!dep.aggregator.isDefined) {
    +          throw new IllegalStateException("Aggregator is empty for 
map-side combine")
    +        }
    +        sorter = new ExternalSorter[K, V, C](
    +          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, 
dep.serializer)
    +        sorter.write(records)
    +        sorter.partitionedIterator
    +      } else {
    +        // In this case we pass neither an aggregator nor an ordering to 
the sorter, because we
    +        // don't care whether the keys get sorted in each partition; that 
will be done on the
    +        // reduce side if the operation being run is sortByKey.
    +        sorter = new ExternalSorter[K, V, V](
    +          None, Some(dep.partitioner), None, dep.serializer)
    +        sorter.write(records)
    +        sorter.partitionedIterator
    +      }
    +    }
    +
    +    // Create a single shuffle file with reduce ID 0 that we'll write all 
results to. We'll later
    +    // serve different ranges of this file using an index file that we 
create at the end.
    +    val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
    +    outputFile = blockManager.diskBlockManager.getFile(blockId)
    +
    +    // Track location of each range in the output file
    +    val offsets = new Array[Long](numPartitions + 1)
    +    val lengths = new Array[Long](numPartitions)
    +
    +    // Statistics
    +    var totalBytes = 0L
    +    var totalTime = 0L
    +
    +    for ((id, elements) <- partitions) {
    +      if (elements.hasNext) {
    +        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, 
fileBufferSize)
    +        for (elem <- elements) {
    +          writer.write(elem)
    +        }
    +        writer.commit()
    +        writer.close()
    +        val segment = writer.fileSegment()
    +        offsets(id + 1) = segment.offset + segment.length
    +        lengths(id) = segment.length
    +        totalTime += writer.timeWriting()
    +        totalBytes += segment.length
    +      } else {
    +        // Don't create a new writer to avoid writing any headers and 
things like that
    +        offsets(id + 1) = offsets(id)
    +      }
    +    }
    +
    +    val shuffleMetrics = new ShuffleWriteMetrics
    +    shuffleMetrics.shuffleBytesWritten = totalBytes
    +    shuffleMetrics.shuffleWriteTime = totalTime
    +    context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
    +    context.taskMetrics.memoryBytesSpilled = sorter.memoryBytesSpilled
    +    context.taskMetrics.diskBytesSpilled = sorter.diskBytesSpilled
    +
    +    // Write an index file with the offsets of each block, plus a final 
offset at the end for the
    +    // end of the output file. This will be used by 
SortShuffleManager.getBlockLocation to figure
    +    // out where each block begins and ends.
    +
    +    val diskBlockManager = blockManager.diskBlockManager
    +    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
    +    val out = new DataOutputStream(new BufferedOutputStream(new 
FileOutputStream(indexFile)))
    +    try {
    +      var i = 0
    +      while (i < numPartitions + 1) {
    +        out.writeLong(offsets(i))
    +        i += 1
    +      }
    +    } finally {
    +      out.close()
    +    }
    +
    +    mapStatus = new MapStatus(blockManager.blockManagerId,
    +      lengths.map(MapOutputTracker.compressSize))
    --- End diff --
    
    Maybe we could create a new function instead of using MapOutputTracker.compressSize, so 
that this code does not need to refer to MapOutputTracker?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it to be, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to