Github user srowen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21561#discussion_r209256004

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---
    @@ -157,11 +157,15 @@ class NaiveBayes @Since("1.5.0") (
         instr.logNumFeatures(numFeatures)

         val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))

    +    val countAccum = dataset.sparkSession.sparkContext.longAccumulator
    +
         // Aggregates term frequencies per label.
         // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
         // TODO: similar to reduceByKeyLocally to save one stage.
         val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
    -      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
    +      .map { row =>
    +        countAccum.add(1L)
    --- End diff --

    Is this guaranteed to work correctly, given that this is in a map operation? Wondering whether this introduces a correctness issue, or whether this number is available elsewhere.
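    For context on the concern raised here: Spark only guarantees exactly-once
    application of accumulator updates for updates performed inside actions;
    updates made inside transformations such as map may be applied more than
    once when a task is retried, a stage is recomputed, or speculative
    execution runs duplicate tasks. A minimal standalone sketch of the failure
    mode (illustration only, not code from this PR; the object name and the
    local[2] master are hypothetical):

    import org.apache.spark.sql.SparkSession

    object AccumulatorInMapSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("accum-in-map-sketch")
          .getOrCreate()
        val sc = spark.sparkContext

        val acc = sc.longAccumulator("rowCount")
        // Accumulator updated inside a transformation: at-least-once semantics.
        val rdd = sc.parallelize(1 to 1000, 4).map { x => acc.add(1L); x }

        rdd.count()          // first action computes the map: acc.value == 1000
        rdd.count()          // rdd is not cached, so the map runs again: acc.value == 2000
        println(acc.value)   // prints 2000, not 1000 -- the count is no longer trustworthy

        spark.stop()
      }
    }

    Under those semantics, a count bumped inside the map above is only
    reliable if that lineage is computed exactly once with no retries. A safer
    pattern is to fold the count into the aggregateByKey result itself, so the
    total is derived from the same single action that produces the aggregates.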