GitHub user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/8112#discussion_r40339926
  
    --- Diff: core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala ---
    @@ -263,6 +263,80 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       }
     
       /**
    +   * ::Experimental::
    +   * Return random, non-overlapping splits of this RDD sampled by key (via stratified sampling)
    +   * with each split containing exactly math.ceil(numItems * samplingRate) for each stratum.
    +   *
    +   * This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it provides random
    +   * splits (and their complements) instead of just a subsample of the data. This requires
    +   * segmenting random keys into ranges with upper and lower bounds instead of segmenting the keys
    +   * into a high/low bisection of the entire dataset.
    +   *
    +   * @param weights array of maps of (key -> samplingRate) pairs for each split, normed by key
    +   * @param exact boolean specifying whether to use exact subsampling
    +   * @param seed seed for the random number generator
    +   * @return array of tuples containing the subsample and complement RDDs for each split
    +   */
    +  @Experimental
    +  def randomSplitByKey(
    +     weights: Array[Map[K, Double]],
    +     exact: Boolean = false,
    +     seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope {
    +
    +    require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.")
    +    if (weights.length > 1) {
    +      require(weights.map(m => m.keys.toList).sliding(2).forall(t => t(0) == t(1)),
    +        "Inconsistent keys between splits.")
    +    }
    +
    +    // normalize and cumulative sum
    +    val baseFold = weights(0).map(x => (x._1, 0.0))
    +    val cumWeightsByKey = weights.scanLeft(baseFold){ case (accMap, iterMap) =>
    +      accMap.map { case (k, v) => (k, v + iterMap(k)) }
    +    }.drop(1)
    +
    +    val weightSumsByKey = cumWeightsByKey.last
    +    val normedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) =>
    +      (key, threshold / weightSumsByKey(key))
    +    })
    +
    +    // compute exact thresholds for each stratum if required
    +    val splits = if (exact) {
    --- End diff --
    
    Renamed some of the variables to make this more readable for other developers.
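
For readers following the "normalize and cumulative sum" step in the diff above, here is a minimal standalone sketch of that computation outside of Spark. It is an illustration under stated assumptions, not the code from the pull request: the object name `CumulativeWeightsSketch` and the helper `normedCumWeights` are hypothetical, and the sketch takes a `Seq` of maps rather than the `Array` used in the diff. It assumes every map has the same keys and that each key's weights sum to a positive value.

```scala
// Hypothetical, Spark-free sketch of the "normalize and cumulative sum" step.
object CumulativeWeightsSketch {

  // For each key, returns the normalized cumulative weight after each split
  // except the last. These values act as the boundaries between splits on [0, 1].
  def normedCumWeights[K](weights: Seq[Map[K, Double]]): Seq[Map[K, Double]] = {
    require(weights.nonEmpty, "At least one split is required.")

    // Start every key's running total at zero.
    val zero: Map[K, Double] = weights.head.map { case (k, _) => (k, 0.0) }

    // Running per-key sums, one map per split (the initial zero map is dropped).
    val cumulative = weights.scanLeft(zero) { (acc, w) =>
      acc.map { case (k, v) => (k, v + w(k)) }
    }.drop(1)

    // The last running sum is the per-key total used for normalization.
    val totals = cumulative.last

    // The final boundary would always be 1.0, so it is omitted.
    cumulative.dropRight(1).map(_.map { case (k, c) => (k, c / totals(k)) })
  }

  def main(args: Array[String]): Unit = {
    val weights = Seq(
      Map("a" -> 0.5, "b" -> 1.0),
      Map("a" -> 0.25, "b" -> 1.0),
      Map("a" -> 0.25, "b" -> 2.0))

    // Key "a" gets boundaries 0.5 and 0.75; key "b" gets 0.25 and 0.5.
    normedCumWeights(weights).foreach(println)
  }
}
```

With three splits there are two boundaries per key, so a record whose per-key random value falls below the first boundary would land in the first split, one between the two boundaries in the second, and the rest in the third.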

