Github user srowen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20572#discussion_r169490790
  
    --- Diff: external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala ---
    @@ -87,47 +89,63 @@ private[spark] class KafkaRDD[K, V](
         }.toArray
       }
     
    -  override def count(): Long = offsetRanges.map(_.count).sum
    +  override def count(): Long =
    +    if (compacted) {
    +      super.count()
    +    } else {
    +      offsetRanges.map(_.count).sum
    +    }
     
       override def countApprox(
           timeout: Long,
           confidence: Double = 0.95
    -  ): PartialResult[BoundedDouble] = {
    -    val c = count
    -    new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
    -  }
    +  ): PartialResult[BoundedDouble] =
    +    if (compacted) {
    +      super.countApprox(timeout, confidence)
    +    } else {
    +      val c = count
    +      new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
    +    }
     
    -  override def isEmpty(): Boolean = count == 0L
    +  override def isEmpty(): Boolean =
    +    if (compacted) {
    +      super.isEmpty()
    +    } else {
    +      count == 0L
    +    }
     
    -  override def take(num: Int): Array[ConsumerRecord[K, V]] = {
    -    val nonEmptyPartitions = this.partitions
    -      .map(_.asInstanceOf[KafkaRDDPartition])
    -      .filter(_.count > 0)
    +  override def take(num: Int): Array[ConsumerRecord[K, V]] =
    +    if (compacted) {
    +      super.take(num)
    +    } else {
    +      val nonEmptyPartitions = this.partitions
    +        .map(_.asInstanceOf[KafkaRDDPartition])
    +        .filter(_.count > 0)
     
    -    if (num < 1 || nonEmptyPartitions.isEmpty) {
    -      return new Array[ConsumerRecord[K, V]](0)
    -    }
    +      if (num < 1 || nonEmptyPartitions.isEmpty) {
    +        return new Array[ConsumerRecord[K, V]](0)
    +      }
     
    -    // Determine in advance how many messages need to be taken from each partition
    -    val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
    -      val remain = num - result.values.sum
    -      if (remain > 0) {
    -        val taken = Math.min(remain, part.count)
    -        result + (part.index -> taken.toInt)
    -      } else {
    -        result
    +      // Determine in advance how many messages need to be taken from each partition
    +      val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
    +        val remain = num - result.values.sum
    +        if (remain > 0) {
    +          val taken = Math.min(remain, part.count)
    +          result + (part.index -> taken.toInt)
    +        } else {
    +          result
    +        }
           }
    -    }
     
    -    val buf = new ArrayBuffer[ConsumerRecord[K, V]]
    -    val res = context.runJob(
    -      this,
    -      (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) =>
    -      it.take(parts(tc.partitionId)).toArray, parts.keys.toArray
    -    )
    -    res.foreach(buf ++= _)
    -    buf.toArray
    -  }
    +      val buf = new ArrayBuffer[ConsumerRecord[K, V]]
    +      val res = context.runJob(
    +        this,
    +        (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) =>
    +        it.take(parts(tc.partitionId)).toArray, parts.keys.toArray
    +      )
    +      res.foreach(buf ++= _)
    +      buf.toArray
    --- End diff ---
    
    I am not sure why this code doesn't just `.flatten` the result of `.runJob` to get an array of all of the results. Feel free to change it, or not. Maybe I'm missing something.
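    
    For reference, a rough sketch of the `.flatten` variant being suggested, reusing `parts` and the `runJob` call from the `take()` body above (illustrative only, not the code in the PR):
    
        // runJob returns one Array[ConsumerRecord[K, V]] per evaluated partition,
        // so the nested result can be flattened directly instead of being copied
        // into an intermediate ArrayBuffer.
        val res: Array[Array[ConsumerRecord[K, V]]] = context.runJob(
          this,
          (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) =>
            it.take(parts(tc.partitionId)).toArray,
          parts.keys.toArray
        )
        res.flatten
    
    Since `runJob` already materializes each partition's slice as an array, `flatten` collapses them in one pass and the result is still `Array[ConsumerRecord[K, V]]`.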


---
