Github user zhengruifeng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21561#discussion_r209496789
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---
    @@ -157,11 +157,15 @@ class NaiveBayes @Since("1.5.0") (
         instr.logNumFeatures(numFeatures)
         val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     
    +    val countAccum = dataset.sparkSession.sparkContext.longAccumulator
    +
         // Aggregates term frequencies per label.
         // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
         // TODO: similar to reduceByKeyLocally to save one stage.
         val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
    -      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
    +      .map { row =>
    +        countAccum.add(1L)
    --- End diff --
    
    This should work correctly. However, to guarantee correctness, I have updated
    the PR to compute the count without an accumulator.
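
For context: Spark only guarantees that accumulator updates are applied exactly once when they happen inside actions. Updates made inside a transformation such as `map` may be applied more than once if a task or stage is re-executed. One way to avoid the accumulator is to carry the count inside the aggregate itself. The following is a minimal, dependency-free sketch of that idea, not the code from this PR: the object name, the `Agg` type, and the sample data are all hypothetical, and partitions are modelled as plain Scala sequences. In Spark, functions of this shape would be passed as `seqOp` and `combOp` to `aggregateByKey`.

```scala
// Hypothetical model of per-label aggregation that tracks the example count
// as part of the aggregate instead of as a side effect.
object CountInAggregate {
  // (weightSum, exampleCount) for one label; illustrative only.
  type Agg = (Double, Long)
  val zero: Agg = (0.0, 0L)

  // Folds one record's weight into a partial aggregate and bumps the count.
  def seqOp(acc: Agg, weight: Double): Agg = (acc._1 + weight, acc._2 + 1L)

  // Merges two partial aggregates that came from different partitions.
  def combOp(a: Agg, b: Agg): Agg = (a._1 + b._1, a._2 + b._2)

  // Each inner Seq stands in for one partition of (label, weight) records.
  def aggregate(partitions: Seq[Seq[(Double, Double)]]): Map[Double, Agg] = {
    // Per-partition pass: apply seqOp record by record.
    val partials: Seq[Map[Double, Agg]] = partitions.map { part =>
      part.foldLeft(Map.empty[Double, Agg]) { case (m, (label, weight)) =>
        m.updated(label, seqOp(m.getOrElse(label, zero), weight))
      }
    }
    // Merge pass: combine the per-partition results with combOp.
    partials.foldLeft(Map.empty[Double, Agg]) { (merged, partial) =>
      partial.foldLeft(merged) { case (m, (label, agg)) =>
        m.updated(label, m.get(label).map(combOp(_, agg)).getOrElse(agg))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(
      Seq((0.0, 1.0), (1.0, 2.0)),
      Seq((0.0, 0.5), (1.0, 1.0), (1.0, 1.0))
    )
    val result = aggregate(partitions)
    // The total number of examples is the sum of the per-label counts.
    val numExamples = result.values.map(_._2).sum
    println(s"per-label: $result, total examples: $numExamples")
  }
}
```

Because the count is part of the returned value, a retried task recomputes its partial result from scratch and contributes it once, so nothing is counted twice.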


