Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/12612#discussion_r62410489 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala --- @@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat -import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext} +import org.apache.spark.{NewAccumulator, SparkContext} import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.util.Utils -/** - * Create a layer for specialized metric. We cannot add `@specialized` to - * `Accumulable/AccumulableParam` because it will break Java source compatibility. - * - * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. - */ -private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( - name: String, - val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) { - // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later - override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { - new AccumulableInfo(id, Some(name), update, value, true, countFailedValues, - Some(SQLMetrics.ACCUM_IDENTIFIER)) - } - - def reset(): Unit = { - this.value = param.zero - } -} - -/** - * Create a layer for specialized metric. We cannot add `@specialized` to - * `Accumulable/AccumulableParam` because it will break Java source compatibility. - */ -private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { - - /** - * A function that defines how we aggregate the final accumulator results among all tasks, - * and represent it in string for a SQL physical operator. - */ - val stringValue: Seq[T] => String - - def zero: R -} +class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] { + // This is a workaround for SPARK-11013. 
+ // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will + // update it at the end of task and the value will be at least 0. Then we can filter out the -1 + // values before calculate max, min, etc. + private[this] var _value = initValue -/** - * Create a layer for specialized metric. We cannot add `@specialized` to - * `Accumulable/AccumulableParam` because it will break Java source compatibility. - */ -private[sql] trait SQLMetricValue[T] extends Serializable { + override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue) - def value: T - - override def toString: String = value.toString -} - -/** - * A wrapper of Long to avoid boxing and unboxing when using Accumulator - */ -private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { - - def add(incr: Long): LongSQLMetricValue = { - _value += incr - this + override def merge(other: NewAccumulator[Long, Long]): Unit = other match { + case o: SQLMetric => _value += o.localValue + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - // Although there is a boxing here, it's fine because it's only called in SQLListener - override def value: Long = _value - - // Needed for SQLListenerSuite - override def equals(other: Any): Boolean = other match { - case o: LongSQLMetricValue => value == o.value - case _ => false - } + override def isZero(): Boolean = _value == initValue - override def hashCode(): Int = _value.hashCode() -} + override def add(v: Long): Unit = _value += v -/** - * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's - * `+=` and `add`. 
- */ -private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam) - extends SQLMetric[LongSQLMetricValue, Long](name, param) { + def +=(v: Long): Unit = _value += v - override def +=(term: Long): Unit = { - localValue.add(term) - } + override def localValue: Long = _value - override def add(term: Long): Unit = { - localValue.add(term) + // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later + private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER)) } -} - -private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long) - extends SQLMetricParam[LongSQLMetricValue, Long] { - - override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) - override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue = - r1.add(r2.value) - - override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero - - override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue) + def reset(): Unit = _value = initValue } -private object LongSQLMetricParam - extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L) - -private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam( - (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. 
- val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.bytesToString) - } - s"\n$sum ($min, $med, $max)" - }, -1L) - -private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam( - (values: Seq[Long]) => { - // This is a workaround for SPARK-11013. - // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update - // it at the end of task and the value will be at least 0. - val validValues = values.filter(_ >= 0) - val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { - Seq.fill(4)(0L) - } else { - val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) - } - metric.map(Utils.msDurationToString) - } - s"\n$sum ($min, $med, $max)" - }, -1L) private[sql] object SQLMetrics { - // Identifier for distinguishing SQL metrics from other accumulators private[sql] val ACCUM_IDENTIFIER = "sql" - private def createLongMetric( - sc: SparkContext, - name: String, - param: LongSQLMetricParam): LongSQLMetric = { - val acc = new LongSQLMetric(name, param) - // This is an internal accumulator so we need to register it explicitly. - Accumulators.register(acc) - sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) - acc - } + private[sql] val SUM_METRIC = "sum" + private[sql] val SIZE_METRIC = "size" + private[sql] val TIMING_METRIC = "timing" - def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { - createLongMetric(sc, name, LongSQLMetricParam) + def createMetric(sc: SparkContext, name: String): SQLMetric = { + val acc = new SQLMetric(SUM_METRIC) + acc.register(sc, name = Some(name), countFailedValues = true) --- End diff -- oh damn this is a mistake! 
I'm surprised our tests don't cover it; I'll fix it and add a regression test.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org