Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22365#discussion_r219034294
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala ---
    @@ -370,29 +370,76 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
        * @since 1.5.0
        */
       def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): 
DataFrame = {
    +    sampleBy(Column(col), fractions, seed)
    +  }
    +
    +  /**
    +   * Returns a stratified sample without replacement based on the fraction 
given on each stratum.
    +   * @param col column that defines strata
    +   * @param fractions sampling fraction for each stratum. If a stratum is 
not specified, we treat
    +   *                  its fraction as zero.
    +   * @param seed random seed
    +   * @tparam T stratum type
    +   * @return a new `DataFrame` that represents the stratified sample
    +   *
    +   * @since 1.5.0
    +   */
    +  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: 
Long): DataFrame = {
    +    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], 
seed)
    +  }
    +
    +  /**
    +   * Returns a stratified sample without replacement based on the fraction 
given on each stratum.
    +   * @param col column that defines strata
    +   * @param fractions sampling fraction for each stratum. If a stratum is 
not specified, we treat
    +   *                  its fraction as zero.
    +   * @param seed random seed
    +   * @tparam T stratum type
    +   * @return a new `DataFrame` that represents the stratified sample
    +   *
    +   * The stratified sample can be performed over multiple columns:
    +   * {{{
    +   *    import org.apache.spark.sql.Row
    +   *    import org.apache.spark.sql.functions.struct
    +   *
    +   *    val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), 
("Nico", 8), ("Bob", 17),
    +   *      ("Alice", 10))).toDF("name", "age")
    +   *    val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
    +   *    df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
    +   *    +-----+---+
    +   *    | name|age|
    +   *    +-----+---+
    +   *    | Nico|  8|
    +   *    |Alice| 10|
    +   *    +-----+---+
    +   * }}}
    +   *
    +   * @since 3.0.0
    --- End diff --
    
    The next release is 2.5.0, so this new overload should be tagged `@since 2.5.0` rather than `@since 3.0.0`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to