This is an automated email from the ASF dual-hosted git repository. huaxingao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b0548c6 [SPARK-37165][SQL] Add REPEATABLE in TABLESAMPLE to specify seed b0548c6 is described below commit b0548c66672b54dfa82914c09caa581e0ff33947 Author: Huaxin Gao <huaxin_...@apple.com> AuthorDate: Sat Oct 30 09:27:11 2021 -0700 [SPARK-37165][SQL] Add REPEATABLE in TABLESAMPLE to specify seed ### What changes were proposed in this pull request? Add REPEATABLE to the SQL syntax for TABLESAMPLE so users can specify a seed. ### Why are the changes needed? Current syntax for TABLESAMPLE: - TABLESAMPLE(x PERCENT) - TABLESAMPLE(BUCKET x OUT OF y) `Dataset.sample` has a param to specify a seed, so we should allow SQL to have a way to specify the seed too. ``` def sample(fraction: Double, seed: Long): Dataset[T] = { sample(withReplacement = false, fraction = fraction, seed = seed) } ``` Most DBMSes use REPEATABLE to let users specify a seed (e.g. DB2); we will follow the same approach. <img width="1032" alt="Screen Shot 2021-10-29 at 8 46 04 AM" src="https://user-images.githubusercontent.com/13592258/139465718-285ab5fb-a9cf-4bef-bc32-88301745b12b.png"> ### Does this PR introduce _any_ user-facing change? Yes, new SQL syntax: - TABLESAMPLE(x PERCENT) [REPEATABLE (seed)] - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (seed)] ### How was this patch tested? New unit tests. Closes #34442 from huaxingao/sample_syntax. 
Authored-by: Huaxin Gao <huaxin_...@apple.com> Signed-off-by: Huaxin Gao <huaxin_...@apple.com> --- docs/sql-ref-ansi-compliance.md | 1 + .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 5 ++++- .../spark/sql/catalyst/parser/AstBuilder.scala | 21 ++++++++++++++------- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++++++++++ 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 4527faa..2de8ba7 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -493,6 +493,7 @@ Below is a list of all the keywords in Spark SQL. |REGEXP|non-reserved|non-reserved|not a keyword| |RENAME|non-reserved|non-reserved|non-reserved| |REPAIR|non-reserved|non-reserved|non-reserved| +|REPEATABLE|non-reserved|non-reserved|non-reserved| |REPLACE|non-reserved|non-reserved|non-reserved| |RESET|non-reserved|non-reserved|non-reserved| |RESPECT|non-reserved|non-reserved|non-reserved| diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 32b080c..e1bccf6 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -674,7 +674,7 @@ joinCriteria ; sample - : TABLESAMPLE '(' sampleMethod? ')' + : TABLESAMPLE '(' sampleMethod? ')' (REPEATABLE '('seed=INTEGER_VALUE')')? 
; sampleMethod @@ -1194,6 +1194,7 @@ ansiNonReserved | REFRESH | RENAME | REPAIR + | REPEATABLE | REPLACE | RESET | RESPECT @@ -1460,6 +1461,7 @@ nonReserved | REFRESH | RENAME | REPAIR + | REPEATABLE | REPLACE | RESET | RESPECT @@ -1726,6 +1728,7 @@ REFERENCES: 'REFERENCES'; REFRESH: 'REFRESH'; RENAME: 'RENAME'; REPAIR: 'REPAIR'; +REPEATABLE: 'REPEATABLE'; REPLACE: 'REPLACE'; RESET: 'RESET'; RESPECT: 'RESPECT'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 768d406..722a055 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1174,13 +1174,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * * This currently supports the following sampling methods: * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. - * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages - * are defined as a number between 0 and 100. - * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + * - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with + * seed 'y'. Note that percentages are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to a 'x' divided by + * 'y' fraction with seed 'z'. */ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { // Create a sampled plan if we need one. - def sample(fraction: Double): Sample = { + def sample(fraction: Double, seed: Long): Sample = { // The range of fraction accepted by Sample is [0, 1]. 
Because Hive's block sampling // function takes X PERCENT as the input and the range of X is [0, 100], we need to // adjust the fraction. @@ -1188,13 +1189,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + Sample(0.0, fraction, withReplacement = false, seed, query) } if (ctx.sampleMethod() == null) { throw QueryParsingErrors.emptyInputForTableSampleError(ctx) } + val seed = if (ctx.seed != null) { + ctx.seed.getText.toLong + } else { + (math.random * 1000).toLong + } + ctx.sampleMethod() match { case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query) @@ -1202,7 +1209,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case ctx: SampleByPercentileContext => val fraction = ctx.percentage.getText.toDouble val sign = if (ctx.negativeSign == null) 1 else -1 - sample(sign * fraction / 100.0d) + sample(sign * fraction / 100.0d, seed) case ctx: SampleByBytesContext => val bytesStr = ctx.bytes.getText @@ -1222,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } case ctx: SampleByBucketContext => - sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 11b7ee6..0b8e10c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4211,6 +4211,22 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("SELECT * FROM testData, LATERAL (SELECT * FROM 
testData)").collect() } } + + test("TABLE SAMPLE") { + withTable("test") { + sql("CREATE TABLE test(c int) USING PARQUET") + for (i <- 0 to 20) { + sql(s"INSERT INTO test VALUES ($i)") + } + val df1 = sql("SELECT * FROM test TABLESAMPLE (20 PERCENT) REPEATABLE (12345)") + val df2 = sql("SELECT * FROM test TABLESAMPLE (20 PERCENT) REPEATABLE (12345)") + checkAnswer(df1, df2) + + val df3 = sql("SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10) REPEATABLE (6789)") + val df4 = sql("SELECT * FROM test TABLESAMPLE (BUCKET 4 OUT OF 10) REPEATABLE (6789)") + checkAnswer(df3, df4) + } + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org