advancedxy commented on code in PR #455: URL: https://github.com/apache/datafusion-comet/pull/455#discussion_r1614532699
########## spark/src/test/scala/org/apache/comet/CometExpressionCoverageSuite.scala: ########## @@ -217,6 +325,25 @@ class CometExpressionCoverageSuite extends CometTestBase with AdaptiveSparkPlanH str shouldBe s"${getLicenseHeader()}\n# Supported Spark Expressions\n\n### group1\n - [x] f1\n - [ ] f2\n\n### group2\n - [x] f3\n - [ ] f4\n\n### group3\n - [x] f5" } + test("get sql function arguments") { + // getSqlFunctionArguments("SELECT unix_seconds(TIMESTAMP('1970-01-01 00:00:01Z'))") shouldBe Seq("TIMESTAMP('1970-01-01 00:00:01Z')") + // getSqlFunctionArguments("SELECT decode(unhex('537061726B2053514C'), 'UTF-8')") shouldBe Seq("unhex('537061726B2053514C')", "'UTF-8'") + // getSqlFunctionArguments("SELECT extract(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456')") shouldBe Seq("'YEAR'", "TIMESTAMP '2019-08-12 01:00:00.123456'") + // getSqlFunctionArguments("SELECT exists(array(1, 2, 3), x -> x % 2 == 0)") shouldBe Seq("array(1, 2, 3)") + getSqlFunctionArguments("select to_char(454, '999')") shouldBe Seq("array(1, 2, 3)") Review Comment: Is this test wrong? The expected arguments do not match the `to_char(454, '999')` call. 
########## spark/src/test/scala/org/apache/comet/CometExpressionCoverageSuite.scala: ########## @@ -54,16 +57,79 @@ class CometExpressionCoverageSuite extends CometTestBase with AdaptiveSparkPlanH private val valuesPattern = """(?i)FROM VALUES(.+?);""".r private val selectPattern = """(i?)SELECT(.+?)FROM""".r + // exclude funcs Comet has no plans to support streaming in near future + // like spark streaming functions, java calls + private val outofRoadmapFuncs = + List("window", "session_window", "window_time", "java_method", "reflect") + private val sqlConf = Seq( + "spark.comet.exec.shuffle.enabled" -> "true", + "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding", + "spark.sql.adaptive.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + + // Tests to run manually as its syntax is different from usual or nested + val manualTests: Map[String, (String, String)] = Map( + "!" -> ("select true a", "select ! true from tbl"), + "%" -> ("select 1 a, 2 b", "select a + b from tbl"), Review Comment: Shouldn't the mapped query be `select a % b from tbl`? 
########## spark/src/test/scala/org/apache/comet/CometExpressionCoverageSuite.scala: ########## @@ -54,16 +57,79 @@ class CometExpressionCoverageSuite extends CometTestBase with AdaptiveSparkPlanH private val valuesPattern = """(?i)FROM VALUES(.+?);""".r private val selectPattern = """(i?)SELECT(.+?)FROM""".r + // exclude funcs Comet has no plans to support streaming in near future + // like spark streaming functions, java calls + private val outofRoadmapFuncs = + List("window", "session_window", "window_time", "java_method", "reflect") + private val sqlConf = Seq( + "spark.comet.exec.shuffle.enabled" -> "true", + "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding", + "spark.sql.adaptive.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") + + // Tests to run manually as its syntax is different from usual or nested + val manualTests: Map[String, (String, String)] = Map( + "!" -> ("select true a", "select ! true from tbl"), + "%" -> ("select 1 a, 2 b", "select a + b from tbl"), Review Comment: Or maybe you could just generate the binary operators and their mappings programmatically? Such as: ```scala Seq("%", "&", ..., "|").map(x => x -> ("select 1 a, 2 b", s"select a $x b from tbl")) ``` ########## spark/src/test/scala/org/apache/comet/CometExpressionCoverageSuite.scala: ########## @@ -116,20 +182,62 @@ class CometExpressionCoverageSuite extends CometTestBase with AdaptiveSparkPlanH // ConstantFolding is a operator optimization rule in Catalyst that replaces expressions // that can be statically evaluated with their equivalent literal values. 
dfMessage = runDatafusionCli(q) - testSingleLineQuery( - "select 'dummy' x", - s"${q.dropRight(1)}, x from tbl", - excludedOptimizerRules = - Some("org.apache.spark.sql.catalyst.optimizer.ConstantFolding")) + + manualTests.get(func.name) match { + // the test is manual query + case Some(test) => testSingleLineQuery(test._1, test._2, sqlConf = sqlConf) + case None => + // extract function arguments as a sql text + // example: + // cos(0) -> 0 + // explode_outer(array(10, 20)) -> array(10, 20) + val args = getSqlFunctionArguments(q.dropRight(1)) + val (aliased, aliases) = + if (Seq( + "bround", + "rlike", + "round", + "to_binary", + "to_char", + "to_number", + "try_to_binary", + "try_to_number", + "xpath", + "xpath_boolean", + "xpath_double", + "xpath_double", + "xpath_float", + "xpath_int", + "xpath_long", + "xpath_number", + "xpath_short", + "xpath_string").contains(func.name.toLowerCase)) { Review Comment: Could we also extract this Seq into a constant field? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org