This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 9268392  [SPARK-33945][SQL][3.1] Handles a random seed consisting of 
an expr tree
9268392 is described below

commit 9268392b957b263692e13fecaf9adec2136e1865
Author: Takeshi Yamamuro <yamam...@apache.org>
AuthorDate: Sun Jan 3 21:36:25 2021 -0800

    [SPARK-33945][SQL][3.1] Handles a random seed consisting of an expr tree
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a minor bug where an analysis exception is thrown when the 
seed parameter of `rand`/`randn` is an expression tree (e.g., `rand(1 + 1)`) and 
constant folding (`ConstantFolding` and `ReorderAssociativeOperator`) is disabled. 
A query to reproduce this issue is as follows:
    ```
    // v3.1.0, v3.0.2, and v2.4.8
    $./bin/spark-shell
    scala> sql("select rand(1 + 2)").show()
    +-------------------+
    |      rand((1 + 2))|
    +-------------------+
    |0.25738143505962285|
    +-------------------+
    
    $./bin/spark-shell --conf 
spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConstantFolding,org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator
    scala> sql("select rand(1 + 2)").show()
    org.apache.spark.sql.AnalysisException: Input argument to rand must be an 
integer, long or null literal.;
      at 
org.apache.spark.sql.catalyst.expressions.RDG.seed$lzycompute(randomExpressions.scala:49)
      at 
org.apache.spark.sql.catalyst.expressions.RDG.seed(randomExpressions.scala:46)
      at 
org.apache.spark.sql.catalyst.expressions.Rand.doGenCode(randomExpressions.scala:98)
      at 
org.apache.spark.sql.catalyst.expressions.Expression.$anonfun$genCode$3(Expression.scala:146)
      at scala.Option.getOrElse(Option.scala:189)
      ...
    ```
    
    The root cause is that the pattern match below only accepts `Literal` 
seeds, so it cannot handle a constant expression that has not been folded yet:
    
https://github.com/apache/spark/blob/42f5e62403469cec6da680b9fbedd0aa508dcbe5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala#L46-L51
    
    ### Why are the changes needed?
    
    Bugfix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a test to `SQLQuerySuite` and checked that GA/Jenkins passes.
    
    Closes #30977 from maropu/FixRandSeedIssue.
    
    Authored-by: Takeshi Yamamuro <yamam...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/randomExpressions.scala    |  6 +++---
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 17 ++++++++++++++++-
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 6a94517..a14b1fa 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -44,10 +44,10 @@ abstract class RDG extends UnaryExpression with 
ExpectsInputTypes with Stateful
   }
 
   @transient protected lazy val seed: Long = child match {
-    case Literal(s, IntegerType) => s.asInstanceOf[Int]
-    case Literal(s, LongType) => s.asInstanceOf[Long]
+    case e if child.foldable && e.dataType == IntegerType => 
e.eval().asInstanceOf[Int]
+    case e if child.foldable && e.dataType == LongType => 
e.eval().asInstanceOf[Long]
     case _ => throw new AnalysisException(
-      s"Input argument to $prettyName must be an integer, long or null 
literal.")
+      s"Input argument to $prettyName must be an integer, long, or null 
constant.")
   }
 
   override def nullable: Boolean = false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 237d2c3..a003275 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
-import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, 
NestedColumnAliasingSuite}
+import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, 
ConvertToLocalRelation, NestedColumnAliasingSuite, ReorderAssociativeOperator}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -3758,6 +3758,21 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
       }
     })
   }
+
+  test("SPARK-33945: handles a random seed consisting of an expr tree") {
+    val excludedRules = Seq(ConstantFolding, 
ReorderAssociativeOperator).map(_.ruleName)
+    withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> 
excludedRules.mkString(",")) {
+      Seq("rand", "randn").foreach { f =>
+        // Just checks if a query works correctly
+        sql(s"SELECT $f(1 + 1)").collect()
+
+        val msg = intercept[AnalysisException] {
+          sql(s"SELECT $f(id + 1) FROM range(0, 3)").collect()
+        }.getMessage
+        assert(msg.contains("must be an integer, long, or null constant"))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to