This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9c30116 [SPARK-33857][SQL] Unify the default seed of random functions 9c30116 is described below commit 9c30116fb428f87543155323617cf5fb700e84cd Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Thu Dec 24 14:30:34 2020 -0800 [SPARK-33857][SQL] Unify the default seed of random functions ### What changes were proposed in this pull request? Unify the seed of random functions 1. Add a placeholder expression `UnresolvedSeed` as the default seed. 2. Change the default seed of `Rand`, `Randn`, `Uuid`, and `Shuffle` to `UnresolvedSeed`. 3. Replace `UnresolvedSeed` with a real seed in the `ResolveRandomSeed` rule. ### Why are the changes needed? `Uuid` and `Shuffle` use the `ResolveRandomSeed` rule to set the seed if the user doesn't give a seed value. `Rand` and `Randn` do this at construction time. It's better to unify the default seed on the Analyzer side, since we already use `ExpressionWithRandomSeed` for streaming queries. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Passes existing tests; a new test is added. Closes #30864 from ulysses-you/SPARK-33857. 
Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../spark/sql/catalyst/analysis/unresolved.scala | 9 ++++++++ .../expressions/collectionOperations.scala | 4 +++- .../spark/sql/catalyst/expressions/misc.scala | 3 +++ .../catalyst/expressions/randomExpressions.scala | 24 ++++++++++++---------- .../sql/catalyst/analysis/AnalysisSuite.scala | 12 +++++++++++ 6 files changed, 42 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ba24914..8af692d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3000,8 +3000,8 @@ class Analyzer(override val catalogManager: CatalogManager) override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case p if p.resolved => p case p => p transformExpressionsUp { - case Uuid(None) => Uuid(Some(random.nextLong())) - case Shuffle(child, None) => Shuffle(child, Some(random.nextLong())) + case e: ExpressionWithRandomSeed if e.seedExpression == UnresolvedSeed => + e.withNewSeed(random.nextLong()) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 8a73208..8461488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -561,3 +561,12 @@ case class UnresolvedHaving( override lazy val resolved: Boolean = false override def output: Seq[Attribute] = child.output } + +/** + * A place holder expression used in random functions, will be replaced after analyze. 
+ */ +case object UnresolvedSeed extends LeafExpression with Unevaluable { + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3379446..17b45bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedSeed} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -943,6 +943,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) def this(child: Expression) = this(child, None) + override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed) + override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) override lazy val resolved: Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 34a64dd..4ad4c4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,6 +19,7 @@ package 
org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator @@ -187,6 +188,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta def this() = this(None) + override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed) + override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) override lazy val resolved: Boolean = randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 0fa4d6c..630c934 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom /** @@ -32,7 +32,8 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. 
*/ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful { +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful + with ExpressionWithRandomSeed { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -43,7 +44,9 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful rng = new XORShiftRandom(seed + partitionIndex) } - @transient protected lazy val seed: Long = child match { + override def seedExpression: Expression = child + + @transient protected lazy val seed: Long = seedExpression match { case Literal(s, IntegerType) => s.asInstanceOf[Int] case Literal(s, LongType) => s.asInstanceOf[Long] case _ => throw new AnalysisException( @@ -62,6 +65,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful * Usually the random seed needs to be renewed at each execution under streaming queries. 
*/ trait ExpressionWithRandomSeed { + def seedExpression: Expression def withNewSeed(seed: Long): Expression } @@ -84,14 +88,13 @@ trait ExpressionWithRandomSeed { since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Rand(child: Expression, hideSeed: Boolean = false) - extends RDG with ExpressionWithRandomSeed { +case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG { - def this() = this(Literal(Utils.random.nextLong(), LongType), true) + def this() = this(UnresolvedSeed, true) def this(child: Expression) = this(child, false) - override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType)) + override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed) override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -136,14 +139,13 @@ object Rand { since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Randn(child: Expression, hideSeed: Boolean = false) - extends RDG with ExpressionWithRandomSeed { +case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { - def this() = this(Literal(Utils.random.nextLong(), LongType), true) + def this() = this(UnresolvedSeed, true) def this(child: Expression) = this(child, false) - override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType)) + override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed) override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index b206bc9..f66871e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1006,4 +1006,16 @@ class AnalysisSuite extends 
AnalysisTest with Matchers { checkAnalysis(plan, expect) } } + + test("SPARK-33857: Unify the default seed of random functions") { + Seq(new Rand(), new Randn(), Shuffle(Literal(Array(1))), Uuid()).foreach { r => + assert(r.seedExpression == UnresolvedSeed) + val p = getAnalyzer.execute(Project(Seq(r.as("r")), testRelation)) + assert( + p.asInstanceOf[Project].projectList.head.asInstanceOf[Alias] + .child.asInstanceOf[ExpressionWithRandomSeed] + .seedExpression.isInstanceOf[Literal] + ) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org