This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c30116  [SPARK-33857][SQL] Unify the default seed of random functions
9c30116 is described below

commit 9c30116fb428f87543155323617cf5fb700e84cd
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Thu Dec 24 14:30:34 2020 -0800

    [SPARK-33857][SQL] Unify the default seed of random functions
    
    ### What changes were proposed in this pull request?
    
    Unify the default seed of the random functions:
    1. Add a placeholder expression `UnresolvedSeed` as the default seed.
    2. Change the default seed of `Rand`, `Randn`, `Uuid`, and `Shuffle` to `UnresolvedSeed`.
    3. Replace `UnresolvedSeed` with a real seed in the `ResolveRandomSeed` rule (a simplified sketch of this pattern follows below).
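    
    Below is a minimal, self-contained Scala sketch of the placeholder pattern. The names here (`SeedExpr`, `RandExpr`, `resolveRandomSeed`) are simplified stand-ins for illustration only, not the actual Catalyst classes:
    
    ```scala
    import scala.util.Random
    
    // A seed argument is either still unresolved or a concrete literal.
    sealed trait SeedExpr
    case object UnresolvedSeed extends SeedExpr
    final case class LiteralSeed(value: Long) extends SeedExpr
    
    // A random expression whose seed defaults to the placeholder.
    final case class RandExpr(seed: SeedExpr = UnresolvedSeed) {
      def withNewSeed(newSeed: Long): RandExpr = copy(seed = LiteralSeed(newSeed))
    }
    
    // Analogue of the ResolveRandomSeed rule: swap the placeholder for a random literal.
    def resolveRandomSeed(e: RandExpr, random: Random): RandExpr = e.seed match {
      case UnresolvedSeed => e.withNewSeed(random.nextLong())
      case _ => e
    }
    
    // An expression built without a seed gets a concrete one after "analysis".
    val resolved = resolveRandomSeed(RandExpr(), new Random())
    assert(resolved.seed.isInstanceOf[LiteralSeed])
    ```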
    
    ### Why are the changes needed?
    
    `Uuid` and `Shuffle` use the `ResolveRandomSeed` rule to set the seed if the user doesn't provide one, while `Rand` and `Randn` do this at construction time.
    
    It's better to unify the default seed handling on the Analyzer side, since `ExpressionWithRandomSeed` is already used for streaming queries.
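    
    A hedged usage sketch (assumes a local `SparkSession`; none of this is part of the patch itself): with the default seed unified on the Analyzer side, all four functions get a seed injected when none is given, and explicit seeds keep working as before.
    
    ```scala
    import org.apache.spark.sql.SparkSession
    
    val spark = SparkSession.builder().master("local[*]").appName("seed-demo").getOrCreate()
    
    // No explicit seeds: the analyzer fills in the default seed for every function.
    spark.sql("SELECT rand(), randn(), uuid(), shuffle(array(1, 2, 3))").show()
    
    // Explicit seeds are left untouched.
    spark.sql("SELECT rand(42), randn(42)").show()
    
    spark.stop()
    ```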
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Pass existing tests and add a new test.
    
    Closes #30864 from ulysses-you/SPARK-33857.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  4 ++--
 .../spark/sql/catalyst/analysis/unresolved.scala   |  9 ++++++++
 .../expressions/collectionOperations.scala         |  4 +++-
 .../spark/sql/catalyst/expressions/misc.scala      |  3 +++
 .../catalyst/expressions/randomExpressions.scala   | 24 ++++++++++++----------
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 12 +++++++++++
 6 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ba24914..8af692d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3000,8 +3000,8 @@ class Analyzer(override val catalogManager: CatalogManager)
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if p.resolved => p
       case p => p transformExpressionsUp {
-        case Uuid(None) => Uuid(Some(random.nextLong()))
-        case Shuffle(child, None) => Shuffle(child, Some(random.nextLong()))
+        case e: ExpressionWithRandomSeed if e.seedExpression == UnresolvedSeed =>
+          e.withNewSeed(random.nextLong())
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 8a73208..8461488 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -561,3 +561,12 @@ case class UnresolvedHaving(
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
 }
+
+/**
+ * A placeholder expression used in random functions; it will be replaced after analysis.
+ */
+case object UnresolvedSeed extends LeafExpression with Unevaluable {
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override lazy val resolved = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 3379446..17b45bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 import scala.reflect.ClassTag
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -943,6 +943,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
 
   def this(child: Expression) = this(child, None)
 
+  override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)
+
   override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
 
   override lazy val resolved: Boolean =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 34a64dd..4ad4c4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
@@ -187,6 +188,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
 
   def this() = this(None)
 
+  override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)
+
   override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))
 
   override lazy val resolved: Boolean = randomSeed.isDefined
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 0fa4d6c..630c934 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
@@ -32,7 +32,8 @@ import org.apache.spark.util.random.XORShiftRandom
  *
  * Since this expression is stateful, it cannot be a case object.
  */
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful {
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
+  with ExpressionWithRandomSeed {
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
    * reset every time we serialize and deserialize and initialize it.
@@ -43,7 +44,9 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
     rng = new XORShiftRandom(seed + partitionIndex)
   }
 
-  @transient protected lazy val seed: Long = child match {
+  override def seedExpression: Expression = child
+
+  @transient protected lazy val seed: Long = seedExpression match {
     case Literal(s, IntegerType) => s.asInstanceOf[Int]
     case Literal(s, LongType) => s.asInstanceOf[Long]
     case _ => throw new AnalysisException(
@@ -62,6 +65,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
  * Usually the random seed needs to be renewed at each execution under streaming queries.
  */
 trait ExpressionWithRandomSeed {
+  def seedExpression: Expression
   def withNewSeed(seed: Long): Expression
 }
 
@@ -84,14 +88,13 @@ trait ExpressionWithRandomSeed {
   since = "1.5.0",
   group = "math_funcs")
 // scalastyle:on line.size.limit
-case class Rand(child: Expression, hideSeed: Boolean = false)
-  extends RDG with ExpressionWithRandomSeed {
+case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
 
-  def this() = this(Literal(Utils.random.nextLong(), LongType), true)
+  def this() = this(UnresolvedSeed, true)
 
   def this(child: Expression) = this(child, false)
 
-  override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType))
+  override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed)
 
   override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
@@ -136,14 +139,13 @@ object Rand {
   since = "1.5.0",
   group = "math_funcs")
 // scalastyle:on line.size.limit
-case class Randn(child: Expression, hideSeed: Boolean = false)
-  extends RDG with ExpressionWithRandomSeed {
+case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
 
-  def this() = this(Literal(Utils.random.nextLong(), LongType), true)
+  def this() = this(UnresolvedSeed, true)
 
   def this(child: Expression) = this(child, false)
 
-  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))
+  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed)
 
   override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index b206bc9..f66871e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1006,4 +1006,16 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(plan, expect)
     }
   }
+
+  test("SPARK-33857: Unify the default seed of random functions") {
+    Seq(new Rand(), new Randn(), Shuffle(Literal(Array(1))), Uuid()).foreach { r =>
+      assert(r.seedExpression == UnresolvedSeed)
+      val p = getAnalyzer.execute(Project(Seq(r.as("r")), testRelation))
+      assert(
+        p.asInstanceOf[Project].projectList.head.asInstanceOf[Alias]
+          .child.asInstanceOf[ExpressionWithRandomSeed]
+          .seedExpression.isInstanceOf[Literal]
+      )
+    }
+  }
 }

