Github user squito commented on a diff in the pull request:

    https://github.com/apache/spark/pull/3100#discussion_r20130288
  
    --- Diff: graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala ---
    @@ -285,50 +337,126 @@ class EdgePartition[
       }
     
       /**
    -   * Upgrade the given edge iterator into a triplet iterator.
    +   * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning
    +   * all edges sequentially and filtering them with `idPred`.
    +   *
    +   * @param sendMsg generates messages to neighboring vertices of an edge
    +   * @param mergeMsg the combiner applied to messages destined to the same vertex
    +   * @param sendMsgUsesSrcAttr whether or not `mapFunc` uses the edge's source vertex attribute
    +   * @param sendMsgUsesDstAttr whether or not `mapFunc` uses the edge's destination vertex attribute
    +   * @param idPred a predicate to filter edges based on their source and destination vertex ids
        *
    -   * Be careful not to keep references to the objects from this iterator. To improve GC performance
    -   * the same object is re-used in `next()`.
    +   * @return iterator aggregated messages keyed by the receiving vertex id
        */
    -  def upgradeIterator(
    -      edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true)
    -    : Iterator[EdgeTriplet[VD, ED]] = {
    -    new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst)
    +  def aggregateMessages[A: ClassTag](
    +      sendMsg: EdgeContext[VD, ED, A] => Unit,
    +      mergeMsg: (A, A) => A,
    +      tripletFields: TripletFields,
    +      idPred: (VertexId, VertexId) => Boolean): Iterator[(VertexId, A)] = {
    +    val aggregates = new Array[A](vertexAttrs.length)
    +    val bitset = new BitSet(vertexAttrs.length)
    +
    +    var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
    +    var i = 0
    +    while (i < size) {
    +      val localSrcId = localSrcIds(i)
    +      val srcId = local2global(localSrcId)
    +      val localDstId = localDstIds(i)
    +      val dstId = local2global(localDstId)
    +      if (idPred(srcId, dstId)) {
    +        ctx.localSrcId = localSrcId
    +        ctx.localDstId = localDstId
    +        ctx.srcId = srcId
    +        ctx.dstId = dstId
    +        ctx.attr = data(i)
    +        if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(localSrcId) }
    +        if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) }
    +        sendMsg(ctx)
    +      }
    +      i += 1
    +    }
    +
    +    bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
       }
     
       /**
    -   * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
    -   * iterator is generated using an index scan, so it is efficient at skipping edges that don't
    -   * match srcIdPred.
    +   * Send messages along edges and aggregate them at the receiving vertices. Implemented by
    +   * filtering the source vertex index with `srcIdPred`, then scanning edge clusters and filtering
    +   * with `dstIdPred`. Both `srcIdPred` and `dstIdPred` must match for an edge to run.
        *
    -   * Be careful not to keep references to the objects from this iterator. To improve GC performance
    -   * the same object is re-used in `next()`.
    -   */
    -  def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] =
    -    index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
    -
    -  /**
    -   * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
    -   * cluster must start at position `index`.
    +   * @param sendMsg generates messages to neighboring vertices of an edge
    +   * @param mergeMsg the combiner applied to messages destined to the same vertex
    +   * @param srcIdPred a predicate to filter edges based on their source vertex id
    +   * @param dstIdPred a predicate to filter edges based on their destination vertex id
        *
    -   * Be careful not to keep references to the objects from this iterator. To improve GC performance
    -   * the same object is re-used in `next()`.
    +   * @return iterator aggregated messages keyed by the receiving vertex id
        */
    -  private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] {
    -    private[this] val edge = new Edge[ED]
    -    private[this] var pos = index
    +  def aggregateMessagesWithIndex[A: ClassTag](
    +      sendMsg: EdgeContext[VD, ED, A] => Unit,
    +      mergeMsg: (A, A) => A,
    +      tripletFields: TripletFields,
    +      srcIdPred: VertexId => Boolean,
    +      dstIdPred: VertexId => Boolean): Iterator[(VertexId, A)] = {
    +    val aggregates = new Array[A](vertexAttrs.length)
    +    val bitset = new BitSet(vertexAttrs.length)
     
    -    override def hasNext: Boolean = {
    -      pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
    +    var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
    +    index.iterator.foreach { cluster =>
    +      val clusterSrcId = cluster._1
    +      val clusterPos = cluster._2
    +      val clusterLocalSrcId = localSrcIds(clusterPos)
    +      if (srcIdPred(clusterSrcId)) {
    +        var pos = clusterPos
    +        ctx.srcId = clusterSrcId
    +        ctx.localSrcId = clusterLocalSrcId
    +        if (tripletFields.useSrc) { ctx.srcAttr = vertexAttrs(clusterLocalSrcId) }
    +        while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
    +          val localDstId = localDstIds(pos)
    +          val dstId = local2global(localDstId)
    +          if (dstIdPred(dstId)) {
    +            ctx.dstId = dstId
    +            ctx.localDstId = localDstId
    +            ctx.attr = data(pos)
    +            if (tripletFields.useDst) { ctx.dstAttr = vertexAttrs(localDstId) }
    +            sendMsg(ctx)
    +          }
    +          pos += 1
    +        }
    +      }
         }
     
    -    override def next(): Edge[ED] = {
    -      assert(srcIds(pos) == srcId)
    -      edge.srcId = srcIds(pos)
    -      edge.dstId = dstIds(pos)
    -      edge.attr = data(pos)
    -      pos += 1
    -      edge
    +    bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
    +  }
    +}
    +
    +private class AggregatingEdgeContext[VD, ED, A](
    +    mergeMsg: (A, A) => A,
    +    aggregates: Array[A],
    +    bitset: BitSet)
    +  extends EdgeContext[VD, ED, A] {
    +
    +  var srcId: VertexId = _
    --- End diff --
    
    @rxin just curious -- do you know whether the JIT will take care of that for you? I had assumed it would, but I've been wrong plenty of times before ...


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to