Github user squito commented on a diff in the pull request: https://github.com/apache/spark/pull/11105#discussion_r86489315 --- Diff: core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala --- @@ -136,15 +181,92 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { def reset(): Unit /** + * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator. + * Developers should extend addImpl to customize the adding functionality. + */ + final def add(v: IN): Unit = { + if (metadata != null && metadata.dataProperty) { + dataPropertyAdd(v) + } else { + addImpl(v) + } + } + + private def dataPropertyAdd(v: IN): Unit = { + // To allow the user to be able to access the current accumulated value from their process + // worker side then we need to perform a "normal" add as well as the data property add. + addImpl(v) + // Add to the pending updates for data property + val updateInfo = TaskContext.get().getRDDPartitionInfo() + val base = pending.getOrElse(updateInfo, copyAndReset()) + // Since we may have constructed a new accumulator, set atDriverSide to false as the default + // new accumulators will have atDriverSide equal to true. + base.atDriverSide = false + base.addImpl(v) + pending(updateInfo) = base + } + + /** + * Mark a specific rdd/shuffle/partition as completely processed. This is a noop for + * non-data property accumuables. + */ + private[spark] def markFullyProcessed(taskOutputId: TaskOutputId): Unit = { + if (metadata.dataProperty) { + completed += taskOutputId + } + } + + /** + * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator. + * Developers should extend addImpl to customize the adding functionality. * Takes the inputs and accumulates. */ - def add(v: IN): Unit + protected[spark] def addImpl(v: IN) + + /** + * Merges another same-type accumulator into this one and update its state, i.e. this should be + * merge-in-place. 
Developers should extend mergeImpl to customize the merge functionality. + */ + final private[spark] lazy val merge: (AccumulatorV2[IN, OUT] => Unit) = { + assert(isAtDriverSide) + // Handle data property accumulators + if (metadata != null && metadata.dataProperty) { + dataPropertyMerge _ + } else { + mergeImpl _ + } + } + + final private[spark] def dataPropertyMerge(other: AccumulatorV2[IN, OUT]) = { + // Apply all foreach partitions regardless - they can only be fully evaluated + val unprocessed = other.pending.filter{ + case (ForeachOutputId(), v) => mergeImpl(v); false + case _ => true + } + val term = unprocessed.filter{case (k, v) => other.completed.contains(k)} + term.flatMap { + case (RDDOutputId(rddId, splitId), v) => + Some((rddProcessed, rddId, splitId, v)) + case (ShuffleMapOutputId(shuffleWriteId, splitId), v) => + Some((shuffleProcessed, shuffleWriteId, splitId, v)) + case _ => // We won't ever hit this case but avoid compiler warnings + None + }.foreach { + case (processed, id, splitId, v) => + val splits = processed.getOrElseUpdate(id, new mutable.BitSet()) + if (!splits.contains(splitId)) { + splits += splitId + mergeImpl(v) + } + } --- End diff -- this method feels more complicated than it needs to be -- what do you think of this version? there is a small amount of duplication but I think it's easier to follow: ```scala final private[spark] def dataPropertyMerge(other: AccumulatorV2[IN, OUT]) = { def mergeAccumUpdateAndMarkOutputAsProcessed( partitionsAlreadyMerged: mutable.BitSet, outputId: TaskOutputId, accumUpdate: AccumulatorV2[IN, OUT] ): Unit = { // we don't merge in accumulator updates from an incomplete accumulator update, e.g. a take() // which only partially reads an rdd partition if (other.completed.contains(outputId)) { // has this partition been processed before? 
if (!partitionsAlreadyMerged.contains(outputId.partition)) { partitionsAlreadyMerged += outputId.partition mergeImpl(accumUpdate) } } } other.pending.foreach { // Apply all foreach partitions regardless - they can only be fully evaluated case (ForeachOutputId, accumUpdate) => mergeImpl(accumUpdate) // For RDDs & shuffles, apply the accumulator updates as long as the output is complete // and it's the first time we're seeing it (just slightly different bookkeeping between // RDDs and shuffles). case (rddOutput: RDDOutputId, accumUpdate) => val processed = rddProcessed.getOrElseUpdate(rddOutput.rddId, new mutable.BitSet()) mergeAccumUpdateAndMarkOutputAsProcessed(processed, rddOutput, accumUpdate) case (shuffleOutput: ShuffleMapOutputId, accumUpdate) => val processed = rddProcessed.getOrElseUpdate(shuffleOutput.shuffleId, new mutable.BitSet()) mergeAccumUpdateAndMarkOutputAsProcessed(processed, shuffleOutput, accumUpdate) } } ``` (I added `def partitionId: Int` to `TaskOutputId`)
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org