Github user mengxr commented on a diff in the pull request: https://github.com/apache/spark/pull/2455#discussion_r17769391 --- Diff: core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala --- @@ -43,66 +46,218 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable throw new NotImplementedError("clone() is not implemented.") } + +object RandomSampler { + // Default random number generator used by random samplers + def rngDefault: Random = new XORShiftRandom + + // Default gap sampling maximum + // For sampling fractions <= this value, the gap sampling optimization will be applied. + // Above this value, it is assumed that "tradtional" bernoulli sampling is faster. The + // optimal value for this will depend on the RNG. More expensive RNGs will tend to make + // the optimal value higher. The most reliable way to determine this value for a given RNG + // is to experiment. I would expect a value of 0.5 to be close in most cases. + def gsmDefault: Double = 0.4 + + // Default gap sampling epsilon + // When sampling random floating point values the gap sampling logic requires value > 0. An + // optimal value for this parameter is at or near the minimum positive floating point value + // returned by nextDouble() for the RNG being used. + def epsDefault: Double = 5e-11 +} + + /** * :: DeveloperApi :: * A sampler based on Bernoulli trials. 
* - * @param lb lower bound of the acceptance range - * @param ub upper bound of the acceptance range - * @param complement whether to use the complement of the range specified, default to false + * @param fraction the sampling fraction, aka Bernoulli sampling probability * @tparam T item type */ @DeveloperApi -class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) - extends RandomSampler[T, T] { +class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { - private[random] var rng: Random = new XORShiftRandom + require(fraction >= 0.0 && fraction <= 1.0, "Sampling fraction must be on interval [0, 1]") - def this(ratio: Double) = this(0.0d, ratio) + def this(lb: Double, ub: Double, complement: Boolean = false) = + this(if (complement) (1.0 - (ub - lb)) else (ub - lb)) + + private val rng: Random = RandomSampler.rngDefault override def setSeed(seed: Long) = rng.setSeed(seed) override def sample(items: Iterator[T]): Iterator[T] = { - items.filter { item => - val x = rng.nextDouble() - (x >= lb && x < ub) ^ complement + fraction match { + case f if (f <= 0.0) => Iterator.empty + case f if (f >= 1.0) => items + case f if (f <= RandomSampler.gsmDefault) => + new GapSamplingIterator(items, f, rng, RandomSampler.epsDefault) + case _ => items.filter(_ => (rng.nextDouble() <= fraction)) --- End diff -- Did you test whether `rdd.randomSplit()` will produce non-overlapping subsets with this change?
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes to enable it, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org