Repository: spark
Updated Branches:
  refs/heads/branch-2.0 182991edd -> ddbff011e


[SPARK-16875][SQL] Add args checking for Dataset randomSplit and sample

## What changes were proposed in this pull request?

Add the missing argument checks for `randomSplit` and `sample`: `fraction` must be nonnegative, and split weights must be nonnegative with a positive sum.

## How was this patch tested?
unit tests
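
For illustration only (not part of this commit), a minimal ScalaTest-style sketch of what such unit tests verify; the suite name and the local SparkContext fixture are assumptions, but the error messages match the new `require` calls in the diff below:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite

// Hypothetical test sketch, not the commit's actual suite.
class SampleArgsCheckSuite extends FunSuite {

  private lazy val sc =
    new SparkContext(new SparkConf().setAppName("args-check").setMaster("local[2]"))

  test("sample rejects a negative fraction") {
    val e = intercept[IllegalArgumentException] {
      sc.parallelize(1 to 10).sample(withReplacement = false, fraction = -1.0)
    }
    assert(e.getMessage.contains("Fraction must be nonnegative"))
  }

  test("randomSplit rejects negative weights and a zero weight sum") {
    val rdd = sc.parallelize(1 to 10)
    intercept[IllegalArgumentException] { rdd.randomSplit(Array(1.0, -1.0)) } // negative weight
    intercept[IllegalArgumentException] { rdd.randomSplit(Array(0.0, 0.0)) }  // zero sum
  }
}
```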

Author: Zheng RuiFeng <ruife...@foxmail.com>

Closes #14478 from zhengruifeng/fix_randomSplit.

(cherry picked from commit be8ea4b2f7ddf1196111acb61fe1a79866376003)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ddbff011
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ddbff011
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ddbff011

Branch: refs/heads/branch-2.0
Commit: ddbff011eaa79aa4c96184bfd15682fbb220f8e7
Parents: 182991e
Author: Zheng RuiFeng <ruife...@foxmail.com>
Authored: Thu Aug 4 21:39:45 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Aug 4 21:39:55 2016 +0100

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 37 +++++++++++++-------
 .../scala/org/apache/spark/sql/Dataset.scala    | 14 ++++++--
 2 files changed, 37 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ddbff011/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a4905dd..2ee13dc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -474,12 +474,17 @@ abstract class RDD[T: ClassTag](
   def sample(
       withReplacement: Boolean,
       fraction: Double,
-      seed: Long = Utils.random.nextLong): RDD[T] = withScope {
-    require(fraction >= 0.0, "Negative fraction value: " + fraction)
-    if (withReplacement) {
-      new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
-    } else {
-      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
+      seed: Long = Utils.random.nextLong): RDD[T] = {
+    require(fraction >= 0,
+      s"Fraction must be nonnegative, but got ${fraction}")
+
+    withScope {
+      require(fraction >= 0.0, "Negative fraction value: " + fraction)
+      if (withReplacement) {
+        new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
+      } else {
+        new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
+      }
     }
   }
 
@@ -493,14 +498,22 @@ abstract class RDD[T: ClassTag](
    */
   def randomSplit(
       weights: Array[Double],
-      seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope {
-    val sum = weights.sum
-    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
-    normalizedCumWeights.sliding(2).map { x =>
-      randomSampleWithRange(x(0), x(1), seed)
-    }.toArray
+      seed: Long = Utils.random.nextLong): Array[RDD[T]] = {
+    require(weights.forall(_ >= 0),
+      s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+    require(weights.sum > 0,
+      s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
+    withScope {
+      val sum = weights.sum
+      val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+      normalizedCumWeights.sliding(2).map { x =>
+        randomSampleWithRange(x(0), x(1), seed)
+      }.toArray
+    }
   }
 
+
   /**
   * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
    * range.
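
Two things are worth noting in this hunk. First, the new `require`s sit outside `withScope`, so invalid arguments fail at the call site with an `IllegalArgumentException`; the pre-existing `require(fraction >= 0.0, ...)` kept inside `withScope` in `sample` is now redundant with the hoisted check. Second, the `randomSplit` body is unchanged: it normalizes the weights into cumulative probability ranges, one per split. A standalone sketch of that arithmetic (plain Scala, no Spark required; the object name is made up):

```scala
// Standalone sketch of randomSplit's weight normalization
// (the scanLeft/sliding(2) logic in the hunk above).
object WeightRanges extends App {
  val weights = Array(1.0, 2.0, 7.0)
  val sum = weights.sum // 10.0

  // Cumulative normalized weights, approximately
  // Array(0.0, 0.1, 0.3, 1.0) modulo floating-point rounding.
  val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)

  // Each adjacent pair is the [lo, hi) range one split samples from:
  // (0.0,0.1), (0.1,0.3), (0.3,1.0)
  normalizedCumWeights.sliding(2).foreach(x => println((x(0), x(1))))
}
```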

http://git-wip-us.apache.org/repos/asf/spark/blob/ddbff011/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 067cbec..6ca0138 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1500,8 +1500,13 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 1.6.0
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan {
-    Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
+    require(fraction >= 0,
+      s"Fraction must be nonnegative, but got ${fraction}")
+
+    withTypedPlan {
+      Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
+    }
   }
 
   /**
@@ -1529,6 +1534,11 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
+    require(weights.forall(_ >= 0),
+      s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+    require(weights.sum > 0,
+      s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
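
On the Dataset side the checks mirror the RDD ones. An illustrative driver sketch (assumes a local SparkSession; not part of the commit) showing how they surface to callers:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo, not part of the commit.
object DatasetArgsCheckDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("args-check").master("local[2]").getOrCreate()
    val ds = spark.range(100)

    // A negative fraction now fails when sample() is called,
    // before any plan is executed:
    // requirement failed: Fraction must be nonnegative, but got -0.1
    try ds.sample(withReplacement = false, fraction = -0.1, seed = 42L)
    catch { case e: IllegalArgumentException => println(e.getMessage) }

    // A negative weight is rejected up front:
    // requirement failed: Weights must be nonnegative, but got [-1.0,2.0]
    try ds.randomSplit(Array(-1.0, 2.0), seed = 42L)
    catch { case e: IllegalArgumentException => println(e.getMessage) }

    spark.stop()
  }
}
```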

