Repository: spark
Updated Branches:
  refs/heads/branch-2.1 c42301f1e -> dcbf3fd4b
[SPARK-17854][SQL] rand/randn allows null/long as input seed

## What changes were proposed in this pull request?

This PR proposes that `rand`/`randn` accept `null` as input in Scala/SQL and `LongType` as input in SQL. In these cases, the value is treated as `0`. So, this PR includes both changes below:

- `null` support

MySQL also accepts this:

``` sql
mysql> select rand(0);
+---------------------+
| rand(0)             |
+---------------------+
| 0.15522042769493574 |
+---------------------+
1 row in set (0.00 sec)

mysql> select rand(NULL);
+---------------------+
| rand(NULL)          |
+---------------------+
| 0.15522042769493574 |
+---------------------+
1 row in set (0.00 sec)
```

and so does Hive, according to [HIVE-14694](https://issues.apache.org/jira/browse/HIVE-14694).

So the code below:

``` scala
spark.range(1).selectExpr("rand(null)").show()
```

prints:

**Before**

```
Input argument to rand must be an integer literal.;; line 1 pos 0
org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer literal.;; line 1 pos 0
	at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:465)
	at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:444)
```

**After**

```
+-----------------------+
|rand(CAST(NULL AS INT))|
+-----------------------+
|    0.13385709732307427|
+-----------------------+
```

- `LongType` support in SQL

In addition, this makes the function accept `LongType` consistently across Scala/SQL. In more detail, the code below:

``` scala
spark.range(1).select(rand(1), rand(1L)).show()
spark.range(1).selectExpr("rand(1)", "rand(1L)").show()
```

prints:

**Before**

```
+------------------+------------------+
|           rand(1)|           rand(1)|
+------------------+------------------+
|0.2630967864682161|0.2630967864682161|
+------------------+------------------+

Input argument to rand must be an integer literal.;; line 1 pos 0
org.apache.spark.sql.AnalysisException: Input argument to rand must be an integer literal.;; line 1 pos 0
	at org.apache.spark.sql.catalyst.analysis.FunctionRegistry$$anonfun$5.apply(FunctionRegistry.scala:465)
	at
```

**After**

```
+------------------+------------------+
|           rand(1)|           rand(1)|
+------------------+------------------+
|0.2630967864682161|0.2630967864682161|
+------------------+------------------+

+------------------+------------------+
|           rand(1)|           rand(1)|
+------------------+------------------+
|0.2630967864682161|0.2630967864682161|
+------------------+------------------+
```

## How was this patch tested?

Unit tests in `DataFrameSuite.scala` and `RandomSuite.scala`.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #15432 from HyukjinKwon/SPARK-17854.
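The heart of the change is resolving the seed from a literal child expression rather than a raw `Long` constructor parameter, with a `null` literal falling back to `0` (in Scala, unboxing `null` to a `Long` yields `0L`). A minimal standalone sketch of that normalization, using a hypothetical `Lit` ADT rather than Spark's actual `Literal` class:

``` scala
// Simplified sketch (hypothetical Lit ADT, not Spark's Literal): fold an
// Int, Long, or null literal seed into a single Long, treating null as 0 --
// mirroring how null.asInstanceOf[Long] unboxes to 0L in the patch.
sealed trait Lit
case class IntLit(value: java.lang.Integer) extends Lit
case class LongLit(value: java.lang.Long) extends Lit

def extractSeed(lit: Lit): Long = lit match {
  case IntLit(v)  => if (v == null) 0L else v.toLong
  case LongLit(v) => if (v == null) 0L else v.longValue
}

println(extractSeed(IntLit(1)))      // 1
println(extractSeed(LongLit(null)))  // 0: a null seed behaves like seed 0
```

This is why `rand(NULL)` and `rand(0)` produce the same value in the examples above.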
(cherry picked from commit 340f09d100cb669bc6795f085aac6fa05630a076)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dcbf3fd4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dcbf3fd4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dcbf3fd4

Branch: refs/heads/branch-2.1
Commit: dcbf3fd4bd42059aed9c966d4f0cdf58815eb802
Parents: c42301f
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Sun Nov 6 14:11:37 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Nov 6 14:11:47 2016 +0000

----------------------------------------------------------------------
 .../expressions/randomExpressions.scala        | 50 +++++++-----
 .../sql/catalyst/expressions/RandomSuite.scala |  6 ++
 .../test/resources/sql-tests/inputs/random.sql | 17 ++++
 .../resources/sql-tests/results/random.sql.out | 84 ++++++++++++++++++++
 4 files changed, 135 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dcbf3fd4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index a331a55..1d7a3c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -32,10 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom
  *
  * Since this expression is stateful, it cannot be a case object.
  */
-abstract class RDG extends LeafExpression with Nondeterministic {
-
-  protected def seed: Long
-
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
    * reset every time we serialize and deserialize and initialize it.
@@ -46,12 +42,18 @@ abstract class RDG extends LeafExpression with Nondeterministic {
     rng = new XORShiftRandom(seed + partitionIndex)
   }
 
+  @transient protected lazy val seed: Long = child match {
+    case Literal(s, IntegerType) => s.asInstanceOf[Int]
+    case Literal(s, LongType) => s.asInstanceOf[Long]
+    case _ => throw new AnalysisException(
+      s"Input argument to $prettyName must be an integer, long or null literal.")
+  }
+
   override def nullable: Boolean = false
 
   override def dataType: DataType = DoubleType
 
-  // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed.
-  override def sql: String = s"$prettyName($seed)"
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
 }
 
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
@@ -64,17 +66,15 @@ abstract class RDG extends LeafExpression with Nondeterministic {
       0.9629742951434543
       > SELECT _FUNC_(0);
       0.8446490682263027
+      > SELECT _FUNC_(null);
+      0.8446490682263027
   """)
 // scalastyle:on line.size.limit
-case class Rand(seed: Long) extends RDG {
-  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
+case class Rand(child: Expression) extends RDG {
 
-  def this() = this(Utils.random.nextLong())
+  def this() = this(Literal(Utils.random.nextLong(), LongType))
 
-  def this(seed: Expression) = this(seed match {
-    case IntegerLiteral(s) => s
-    case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
-  })
+  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
@@ -87,6 +87,10 @@ case class Rand(seed: Long) extends RDG {
   }
 }
 
+object Rand {
+  def apply(seed: Long): Rand = Rand(Literal(seed, LongType))
+}
+
 /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -97,17 +101,15 @@ case class Rand(seed: Long) extends RDG {
       -0.3254147983080288
       > SELECT _FUNC_(0);
       1.1164209726833079
+      > SELECT _FUNC_(null);
+      1.1164209726833079
   """)
 // scalastyle:on line.size.limit
-case class Randn(seed: Long) extends RDG {
-  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
+case class Randn(child: Expression) extends RDG {
 
-  def this() = this(Utils.random.nextLong())
+  def this() = this(Literal(Utils.random.nextLong(), LongType))
 
-  def this(seed: Expression) = this(seed match {
-    case IntegerLiteral(s) => s
-    case _ => throw new AnalysisException("Input argument to randn must be an integer literal.")
-  })
+  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
@@ -119,3 +121,7 @@ case class Randn(seed: Long) extends RDG {
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
   }
 }
+
+object Randn {
+  def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
+}


http://git-wip-us.apache.org/repos/asf/spark/blob/dcbf3fd4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index b7a0d44..752c9d5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -20,12 +20,18 @@ package org.apache.spark.sql.catalyst.expressions
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("random") {
     checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
     checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)
+
+    checkDoubleEvaluation(
+      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
+    checkDoubleEvaluation(
+      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
   }
 
   test("SPARK-9127 codegen with long seed") {


http://git-wip-us.apache.org/repos/asf/spark/blob/dcbf3fd4/sql/core/src/test/resources/sql-tests/inputs/random.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql
new file mode 100644
index 0000000..a1aae7b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql
@@ -0,0 +1,17 @@
+-- rand with the seed 0
+SELECT rand(0);
+SELECT rand(cast(3 / 7 AS int));
+SELECT rand(NULL);
+SELECT rand(cast(NULL AS int));
+
+-- rand unsupported data type
+SELECT rand(1.0);
+
+-- randn with the seed 0
+SELECT randn(0L);
+SELECT randn(cast(3 / 7 AS long));
+SELECT randn(NULL);
+SELECT randn(cast(NULL AS long));
+
+-- randn unsupported data type
+SELECT rand('1')


http://git-wip-us.apache.org/repos/asf/spark/blob/dcbf3fd4/sql/core/src/test/resources/sql-tests/results/random.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out
new file mode 100644
index 0000000..bca6732
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out
@@ -0,0 +1,84 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query 0
+SELECT rand(0)
+-- !query 0 schema
+struct<rand(0):double>
+-- !query 0 output
+0.8446490682263027
+
+
+-- !query 1
+SELECT rand(cast(3 / 7 AS int))
+-- !query 1 schema
+struct<rand(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS INT)):double>
+-- !query 1 output
+0.8446490682263027
+
+
+-- !query 2
+SELECT rand(NULL)
+-- !query 2 schema
+struct<rand(CAST(NULL AS INT)):double>
+-- !query 2 output
+0.8446490682263027
+
+
+-- !query 3
+SELECT rand(cast(NULL AS int))
+-- !query 3 schema
+struct<rand(CAST(NULL AS INT)):double>
+-- !query 3 output
+0.8446490682263027
+
+
+-- !query 4
+SELECT rand(1.0)
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand(1.0BD)' due to data type mismatch: argument 1 requires (int or bigint) type, however, '1.0BD' is of decimal(2,1) type.; line 1 pos 7
+
+
+-- !query 5
+SELECT randn(0L)
+-- !query 5 schema
+struct<randn(0):double>
+-- !query 5 output
+1.1164209726833079
+
+
+-- !query 6
+SELECT randn(cast(3 / 7 AS long))
+-- !query 6 schema
+struct<randn(CAST((CAST(3 AS DOUBLE) / CAST(7 AS DOUBLE)) AS BIGINT)):double>
+-- !query 6 output
+1.1164209726833079
+
+
+-- !query 7
+SELECT randn(NULL)
+-- !query 7 schema
+struct<randn(CAST(NULL AS INT)):double>
+-- !query 7 output
+1.1164209726833079
+
+
+-- !query 8
+SELECT randn(cast(NULL AS long))
+-- !query 8 schema
+struct<randn(CAST(NULL AS BIGINT)):double>
+-- !query 8 output
+1.1164209726833079
+
+
+-- !query 9
+SELECT rand('1')
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'rand('1')' due to data type mismatch: argument 1 requires (int or bigint) type, however, ''1'' is of string type.; line 1 pos 7
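As background for the unchanged `rng = new XORShiftRandom(seed + partitionIndex)` line in the diff above: seeding each partition's generator with the user seed plus the partition index makes every partition's stream reproducible for a fixed seed yet distinct across partitions. A rough sketch of that idea with a toy xorshift64 generator (illustrative only, not Spark's `XORShiftRandom`):

``` scala
// Toy xorshift64 generator (not Spark's XORShiftRandom). Seeding with
// (seed + partitionIndex) yields per-partition streams that are
// deterministic for a given seed but differ between partitions.
class XorShift64(seed: Long) {
  private var state: Long = if (seed == 0L) -1L else seed // state must be nonzero
  def nextLong(): Long = {
    state ^= state << 13
    state ^= state >>> 7
    state ^= state << 17
    state
  }
}

val seed = 42L
val part0First  = new XorShift64(seed + 0).nextLong()
val part0Replay = new XorShift64(seed + 0).nextLong()
val part1First  = new XorShift64(seed + 1).nextLong()
println(part0First == part0Replay) // true: same seed, same stream
println(part0First == part1First)  // false: one xorshift step is injective
```

This per-partition offset is also why `rand(0)` returns the same value every time it is evaluated on the same partition layout, as the generated `random.sql.out` results above rely on.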