http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala new file mode 100644 index 0000000..a80d29c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. 
+ * + * Please note that some of the expression tests don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. + */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/MultiAlias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. 
+ assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in 
sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. 
+ assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. 
+ assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. 
+ intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. + assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. 
+ assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. 
+ // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. + val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. 
+ assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +}
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala new file mode 100644 index 0000000..23f05ce --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * 
from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", 
Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => 
p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" + val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 
5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are tested in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. 
+ intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. 
+ val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", + 
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, 
x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala new file mode 100644 index 0000000..297b193 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala deleted file mode 100644 index 1963fc3..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.SparkFunSuite - -/** - * Test various parser errors. - */ -class ErrorParserSuite extends SparkFunSuite { - def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { - val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) - - // Check position. - assert(e.line.isDefined) - assert(e.line.get === line) - assert(e.startPosition.isDefined) - assert(e.startPosition.get === startPosition) - - // Check messages. - val error = e.getMessage - messages.foreach { message => - assert(error.contains(message)) - } - } - - test("no viable input") { - intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") - intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") - intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") - } - - test("extraneous input") { - intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") - intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") - } - - test("mismatched input") { - intercept("select * from r order by q from t", 1, 27, - "mismatched input", - "---------------------------^^^") - intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") - } - - test("semantic errors") { - intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, - "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", - "^^^") - intercept("select * from r 
where a in (select * from t)", 1, 24, - "IN with a Sub-query is currently not supported", - "------------------------^^^") - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala deleted file mode 100644 index 32311a5..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala +++ /dev/null @@ -1,497 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * Test basic expression parsing. 
If a type of expression is supported it should be tested here. - * - * Please note that some of the expressions test don't have to be sound expressions, only their - * structure needs to be valid. Unsound expressions should be caught by the Analyzer or - * CheckAnalysis classes. - */ -class ExpressionParserSuite extends PlanTest { - import CatalystSqlParser._ - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - def assertEqual(sqlCommand: String, e: Expression): Unit = { - compareExpressions(parseExpression(sqlCommand), e) - } - - def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parseExpression(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } - - test("star expressions") { - // Global Star - assertEqual("*", UnresolvedStar(None)) - - // Targeted Star - assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) - } - - // NamedExpression (Alias/Multialias) - test("named expressions") { - // No Alias - val r0 = 'a - assertEqual("a", r0) - - // Single Alias. - val r1 = 'a as "b" - assertEqual("a as b", r1) - assertEqual("a b", r1) - - // Multi-Alias - assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) - assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) - - // Numeric literals without a space between the literal qualifier and the alias, should not be - // interpreted as such. An unresolved reference should be returned instead. - // TODO add the JIRA-ticket number. - assertEqual("1SL", Symbol("1SL")) - - // Aliased star is allowed. 
- assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) - } - - test("binary logical expressions") { - // And - assertEqual("a and b", 'a && 'b) - - // Or - assertEqual("a or b", 'a || 'b) - - // Combination And/Or check precedence - assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) - assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) - - // Multiple AND/OR get converted into a balanced tree - assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) - assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) - } - - test("long binary logical expressions") { - def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { - val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) - val e = parseExpression(sql) - assert(e.collect { case _: EqualTo => true }.size === 1000) - assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) - } - testVeryBinaryExpression(" AND ", classOf[And]) - testVeryBinaryExpression(" OR ", classOf[Or]) - } - - test("not expressions") { - assertEqual("not a", !'a) - assertEqual("!a", !'a) - assertEqual("not true > true", Not(GreaterThan(true, true))) - } - - test("exists expression") { - intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") - } - - test("comparison expressions") { - assertEqual("a = b", 'a === 'b) - assertEqual("a == b", 'a === 'b) - assertEqual("a <=> b", 'a <=> 'b) - assertEqual("a <> b", 'a =!= 'b) - assertEqual("a != b", 'a =!= 'b) - assertEqual("a < b", 'a < 'b) - assertEqual("a <= b", 'a <= 'b) - assertEqual("a > b", 'a > 'b) - assertEqual("a >= b", 'a >= 'b) - } - - test("between expressions") { - assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) - assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) - } - - test("in expressions") { - assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) - assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) - } - - test("in 
sub-query") { - intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") - } - - test("like expressions") { - assertEqual("a like 'pattern%'", 'a like "pattern%") - assertEqual("a not like 'pattern%'", !('a like "pattern%")) - assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") - assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) - assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") - assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) - } - - test("is null expressions") { - assertEqual("a is null", 'a.isNull) - assertEqual("a is not null", 'a.isNotNull) - assertEqual("a = b is null", ('a === 'b).isNull) - assertEqual("a = b is not null", ('a === 'b).isNotNull) - } - - test("binary arithmetic expressions") { - // Simple operations - assertEqual("a * b", 'a * 'b) - assertEqual("a / b", 'a / 'b) - assertEqual("a DIV b", ('a / 'b).cast(LongType)) - assertEqual("a % b", 'a % 'b) - assertEqual("a + b", 'a + 'b) - assertEqual("a - b", 'a - 'b) - assertEqual("a & b", 'a & 'b) - assertEqual("a ^ b", 'a ^ 'b) - assertEqual("a | b", 'a | 'b) - - // Check precedences - assertEqual( - "a * t | b ^ c & d - e + f % g DIV h / i * k", - 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) - } - - test("unary arithmetic expressions") { - assertEqual("+a", 'a) - assertEqual("-a", -'a) - assertEqual("~a", ~'a) - assertEqual("-+~~a", -(~(~'a))) - } - - test("cast expressions") { - // Note that DataType parsing is tested elsewhere. 
- assertEqual("cast(a as int)", 'a.cast(IntegerType)) - assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) - assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType))) - assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) - } - - test("function expressions") { - assertEqual("foo()", 'foo.function()) - assertEqual("foo.bar()", Symbol("foo.bar").function()) - assertEqual("foo(*)", 'foo.function(star())) - assertEqual("count(*)", 'count.function(1)) - assertEqual("foo(a, b)", 'foo.function('a, 'b)) - assertEqual("foo(all a, b)", 'foo.function('a, 'b)) - assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) - assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) - assertEqual("`select`(all a, b)", 'select.function('a, 'b)) - } - - test("window function expressions") { - val func = 'foo.function(star()) - def windowed( - partitioning: Seq[Expression] = Seq.empty, - ordering: Seq[SortOrder] = Seq.empty, - frame: WindowFrame = UnspecifiedFrame): Expression = { - WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) - } - - // Basic window testing. - assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) - assertEqual("foo(*) over ()", windowed()) - assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) - assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) - - // Test use of expressions in window functions. 
- assertEqual( - "sum(product + 1) over (partition by ((product) + (1)) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) - assertEqual( - "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) - - // Range/Row - val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) - val boundaries = Seq( - ("10 preceding", ValuePreceding(10), CurrentRow), - ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis - ("unbounded preceding", UnboundedPreceding, CurrentRow), - ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis - ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), - ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), - ("between current row and 5 following", CurrentRow, ValueFollowing(5)), - ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) - ) - frameTypes.foreach { - case (frameTypeSql, frameType) => - boundaries.foreach { - case (boundarySql, begin, end) => - val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" - val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) - assertEqual(query, expr) - } - } - - // We cannot use non integer constants. - intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", - "Frame bound value must be a constant integer.") - - // We cannot use an arbitrary expression. 
- intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", - "Frame bound value must be a constant integer.") - } - - test("row constructor") { - // Note that '(a)' will be interpreted as a nested expression. - assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) - assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) - } - - test("scalar sub-query") { - assertEqual( - "(select max(val) from tbl) > current", - ScalarSubquery(table("tbl").select('max.function('val))) > 'current) - assertEqual( - "a = (select b from s)", - 'a === ScalarSubquery(table("s").select('b))) - } - - test("case when") { - assertEqual("case a when 1 then b when 2 then c else d end", - CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) - assertEqual("case when a = 1 then b when a = 2 then c else d end", - CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) - } - - test("dereference") { - assertEqual("a.b", UnresolvedAttribute("a.b")) - assertEqual("`select`.b", UnresolvedAttribute("select.b")) - assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. - assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) - } - - test("reference") { - // Regular - assertEqual("a", 'a) - - // Starting with a digit. - assertEqual("1a", Symbol("1a")) - - // Quoted using a keyword. - assertEqual("`select`", 'select) - - // Unquoted using an unreserved keyword. - assertEqual("columns", 'columns) - } - - test("subscript") { - assertEqual("a[b]", 'a.getItem('b)) - assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) - assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) - } - - test("parenthesis") { - assertEqual("(a)", 'a) - assertEqual("r * (a + b)", 'r * ('a + 'b)) - } - - test("type constructors") { - // Dates. - assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) - intercept[IllegalArgumentException] { - parseExpression("DAtE 'mar 11 2016'") - } - - // Timestamps. 
- assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", - Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) - intercept[IllegalArgumentException] { - parseExpression("timestamP '2016-33-11 20:54:00.000'") - } - - // Unsupported datatype. - intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") - } - - test("literals") { - // NULL - assertEqual("null", Literal(null)) - - // Boolean - assertEqual("trUe", Literal(true)) - assertEqual("False", Literal(false)) - - // Integral should have the narrowest possible type - assertEqual("787324", Literal(787324)) - assertEqual("7873247234798249234", Literal(7873247234798249234L)) - assertEqual("78732472347982492793712334", - Literal(BigDecimal("78732472347982492793712334").underlying())) - - // Decimal - assertEqual("7873247234798249279371.2334", - Literal(BigDecimal("7873247234798249279371.2334").underlying())) - - // Scientific Decimal - assertEqual("9.0e1", 90d) - assertEqual(".9e+2", 90d) - assertEqual("0.9e+2", 90d) - assertEqual("900e-1", 90d) - assertEqual("900.0E-1", 90d) - assertEqual("9.e+1", 90d) - intercept(".e3") - - // Tiny Int Literal - assertEqual("10Y", Literal(10.toByte)) - intercept("-1000Y") - - // Small Int Literal - assertEqual("10S", Literal(10.toShort)) - intercept("40000S") - - // Long Int Literal - assertEqual("10L", Literal(10L)) - intercept("78732472347982492793712334L") - - // Double Literal - assertEqual("10.0D", Literal(10.0D)) - // TODO we need to figure out if we should throw an exception here! - assertEqual("1E309", Literal(Double.PositiveInfinity)) - } - - test("strings") { - // Single Strings. - assertEqual("\"hello\"", "hello") - assertEqual("'hello'", "hello") - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld") - assertEqual("'hello' \" \" 'world'", "hello world") - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. 
- // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%") - assertEqual("'no-pattern\\%'", "no-pattern\\%") - assertEqual("'pattern\\\\%'", "pattern\\%") - assertEqual("'pattern\\\\\\%'", "pattern\\\\%") - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') - assertEqual("'\\''", "\'") // Single quote - assertEqual("'\\\"'", "\"") // Double quote - assertEqual("'\\b'", "\b") // Backspace - assertEqual("'\\n'", "\n") // Newline - assertEqual("'\\r'", "\r") // Carriage return - assertEqual("'\\t'", "\t") // Tab character - assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") - - // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") - } - - test("intervals") { - def intervalLiteral(u: String, s: String): Literal = { - Literal(CalendarInterval.fromSingleUnitString(u, s)) - } - - // Empty interval statement - intercept("interval", "at least one time unit should be given for interval literal") - - // Single Intervals. - val units = Seq( - "year", - "month", - "week", - "day", - "hour", - "minute", - "second", - "millisecond", - "microsecond") - val forms = Seq("", "s") - val values = Seq("0", "10", "-7", "21") - units.foreach { unit => - forms.foreach { form => - values.foreach { value => - val expected = intervalLiteral(unit, value) - assertEqual(s"interval $value $unit$form", expected) - assertEqual(s"interval '$value' $unit$form", expected) - } - } - } - - // Hive nanosecond notation. 
- assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) - assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) - - // Non Existing unit - intercept("interval 10 nanoseconds", "No interval can be constructed") - - // Year-Month intervals. - val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") - yearMonthValues.foreach { value => - val result = Literal(CalendarInterval.fromYearMonthString(value)) - assertEqual(s"interval '$value' year to month", result) - } - - // Day-Time intervals. - val datTimeValues = Seq( - "99 11:22:33.123456789", - "-99 11:22:33.123456789", - "10 9:8:7.123456789", - "1 0:0:0", - "-1 0:0:0", - "1 0:0:1") - datTimeValues.foreach { value => - val result = Literal(CalendarInterval.fromDayTimeString(value)) - assertEqual(s"interval '$value' day to second", result) - } - - // Unknown FROM TO intervals - intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") - - // Composed intervals. 
- assertEqual( - "interval 3 months 22 seconds 1 millisecond", - Literal(new CalendarInterval(3, 22001000L))) - assertEqual( - "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", - Literal(new CalendarInterval(14, - 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) - } - - test("composed expressions") { - assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) - assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) - intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala deleted file mode 100644 index 4206d22..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala +++ /dev/null @@ -1,429 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{BooleanType, IntegerType} - -class PlanParserSuite extends PlanTest { - import CatalystSqlParser._ - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan) - } - - def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parsePlan(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } - - test("case insensitive") { - val plan = table("a").select(star()) - assertEqual("sELEct * FroM a", plan) - assertEqual("select * fRoM a", plan) - assertEqual("SELECT * FROM a", plan) - } - - test("show functions") { - assertEqual("show functions", ShowFunctions(None, None)) - assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) - assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) - assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) - intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") - } - - test("describe function") { - assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) - assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) - assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) - assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) - } - - test("set operations") { - val a = table("a").select(star()) - val b = 
table("b").select(star()) - - assertEqual("select * from a union select * from b", Distinct(a.union(b))) - assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) - assertEqual("select * from a union all select * from b", a.union(b)) - assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") - assertEqual("select * from a except distinct select * from b", a.except(b)) - assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") - assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) - } - - test("common table expressions") { - def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { - val ctes = namedPlans.map { - case (name, cte) => - name -> SubqueryAlias(name, cte) - }.toMap - With(plan, ctes) - } - assertEqual( - "with cte1 as (select * from a) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) - assertEqual( - "with cte1 (select 1) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) - assertEqual( - "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", - cte(table("cte2").select(star()), - "cte1" -> OneRowRelation.select(1), - "cte2" -> table("cte1").select(star()))) - intercept( - "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", - "Name 'cte1' is used for multiple common table expressions") - } - - test("simple select query") { - assertEqual("select 1", OneRowRelation.select(1)) - assertEqual("select a, b", OneRowRelation.select('a, 'b)) - assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) - assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) - assertEqual( - "select a, b from db.c having x < 
1", - table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) - assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) - assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) - } - - test("reverse select query") { - assertEqual("from a", table("a")) - assertEqual("from a select b, c", table("a").select('b, 'c)) - assertEqual( - "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) - assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) - assertEqual( - "from (from a union all from b) c select *", - table("a").union(table("b")).as("c").select(star())) - } - - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - - test("multi select query") { - assertEqual( - "from a select * select * where s < 10", - table("a").select(star()).union(table("a").where('s < 10).select(star()))) - intercept( - "from a select * select * from x where a.s < 10", - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") - assertEqual( - "from a insert into tbl1 select * insert into tbl2 select * where s < 10", - table("a").select(star()).insertInto("tbl1").union( - table("a").where('s < 10).select(star()).insertInto("tbl2"))) - } - - test("query organization") { - // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows - val baseSql = "select * from t" - val basePlan = table("t").select(star()) - - val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, 
UnspecifiedFrame)) - val limitWindowClauses = Seq( - ("", (p: LogicalPlan) => p), - (" limit 10", (p: LogicalPlan) => p.limit(10)), - (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), - (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) - ) - - val orderSortDistrClusterClauses = Seq( - ("", basePlan), - (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), - (" distribute by a, b", basePlan.distribute('a, 'b)), - (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), - (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) - ) - - orderSortDistrClusterClauses.foreach { - case (s1, p1) => - limitWindowClauses.foreach { - case (s2, pf2) => - assertEqual(baseSql + s1 + s2, pf2(p1)) - } - } - - val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" - intercept(s"$baseSql order by a sort by a", msg) - intercept(s"$baseSql cluster by a distribute by a", msg) - intercept(s"$baseSql order by a cluster by a", msg) - intercept(s"$baseSql order by a distribute by a", msg) - } - - test("insert into") { - val sql = "select * from t" - val plan = table("t").select(star()) - def insert( - partition: Map[String, Option[String]], - overwrite: Boolean = false, - ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) - - // Single inserts - assertEqual(s"insert overwrite table s $sql", - insert(Map.empty, overwrite = true)) - assertEqual(s"insert overwrite table s if not exists $sql", - insert(Map.empty, overwrite = true, ifNotExists = true)) - assertEqual(s"insert into s $sql", - insert(Map.empty)) - assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", - insert(Map("c" -> Option("d"), "e" -> Option("1")))) - assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", - insert(Map("c" -> 
Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) - - // Multi insert - val plan2 = table("t").where('x > 5).select(star()) - assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", - InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( - InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) - } - - test("aggregation") { - val sql = "select a, b, sum(c) as c from d group by a, b" - - // Normal - assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) - - // Cube - assertEqual(s"$sql with cube", - table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) - - // Rollup - assertEqual(s"$sql with rollup", - table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) - - // Grouping Sets - assertEqual(s"$sql grouping sets((a, b), (a), ())", - GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) - intercept(s"$sql grouping sets((a, b), (c), ())", - "c doesn't show up in the GROUP BY list") - } - - test("limit") { - val sql = "select * from t" - val plan = table("t").select(star()) - assertEqual(s"$sql limit 10", plan.limit(10)) - assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) - } - - test("window spec") { - // Note that WindowSpecs are testing in the ExpressionParserSuite - val sql = "select * from t" - val plan = table("t").select(star()) - val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), - SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) - - // Test window resolution. - val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) - assertEqual( - s"""$sql - |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), - | w2 as w1, - | w3 as w1""".stripMargin, - WithWindowDefinition(ws1, plan)) - - // Fail with no reference. 
- intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") - - // Fail when resolved reference is not a window spec. - intercept( - s"""$sql - |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), - | w2 as w1, - | w3 as w2""".stripMargin, - "Window reference 'w2' is not a window specification" - ) - } - - test("lateral view") { - // Single lateral view - assertEqual( - "select * from t lateral view explode(x) expl as x", - table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) - .select(star())) - - // Multiple lateral views - assertEqual( - """select * - |from t - |lateral view explode(x) expl - |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, - table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) - .select(star())) - - // Multi-Insert lateral views. - val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) - assertEqual( - """from t1 - |lateral view explode(x) expl as x - |insert into t2 - |select * - |lateral view json_tuple(x, y) jtup q, z - |insert into t3 - |select * - |where s < 10 - """.stripMargin, - Union(from - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) - .select(star()) - .insertInto("t2"), - from.where('s < 10).select(star()).insertInto("t3"))) - - // Unsupported generator. - intercept( - "select * from t lateral view posexplode(x) posexpl as x, y", - "Generator function 'posexplode' is not supported") - } - - test("joins") { - // Test single joins. 
- val testUnconditionalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t as tt $sql u", - table("t").as("tt").join(table("u"), jt, None).select(star())) - } - val testConditionalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t $sql u as uu on a = b", - table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) - } - val testNaturalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t tt natural $sql u as uu", - table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) - } - val testUsingJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t $sql u using(a, b)", - table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) - } - val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) - - def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { - tests.foreach(_(sql, jt)) - } - test("cross join", Inner, Seq(testUnconditionalJoin)) - test(",", Inner, Seq(testUnconditionalJoin)) - test("join", Inner, testAll) - test("inner join", Inner, testAll) - test("left join", LeftOuter, testAll) - test("left outer join", LeftOuter, testAll) - test("right join", RightOuter, testAll) - test("right outer join", RightOuter, testAll) - test("full join", FullOuter, testAll) - test("full outer join", FullOuter, testAll) - - // Test multiple consecutive joins - assertEqual( - "select * from a join b join c right join d", - table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) - } - - test("sampled relations") { - val sql = "select * from t" - assertEqual(s"$sql tablesample(100 rows)", - table("t").limit(100).select(star())) - assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) - assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", - 
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) - intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", - "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") - intercept(s"$sql tablesample(bucket 11 out of 10) as x", - s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") - } - - test("sub-query") { - val plan = table("t0").select('id) - assertEqual("select id from (t0)", plan) - assertEqual("select id from ((((((t0))))))", plan) - assertEqual( - "(select * from t1) union distinct (select * from t2)", - Distinct(table("t1").select(star()).union(table("t2").select(star())))) - assertEqual( - "select * from ((select * from t1) union (select * from t2)) t", - Distinct( - table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) - assertEqual( - """select id - |from (((select id from t0) - | union all - | (select id from t0)) - | union all - | (select id from t0)) as u_1 - """.stripMargin, - plan.union(plan).union(plan).as("u_1").select('id)) - } - - test("scalar sub-query") { - assertEqual( - "select (select max(b) from s) ss from t", - table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) - assertEqual( - "select * from t where a = (select b from s)", - table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) - assertEqual( - "select g from t group by g having a > (select b from s)", - table("t") - .groupBy('g)('g) - .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) - } - - test("table reference") { - assertEqual("table t", table("t")) - assertEqual("table d.t", table("d", "t")) - } - - test("inline table") { - assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( - Seq('col1.int), - Seq(1, 2, 3, 4).map(x => Row(x)))) - assertEqual( - "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", - LocalRelation.fromExternalRows( - Seq('a.int, 'b.string), - Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, 
x._2))).as("tbl")) - intercept("values (a, 'a'), (b, 'b')", - "All expressions in an inline table must be constants.") - intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", - "Number of aliases must match the number of fields in an inline table.") - intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala deleted file mode 100644 index 0874322..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier - -class TableIdentifierParserSuite extends SparkFunSuite { - import CatalystSqlParser._ - - test("table identifier") { - // Regular names. - assert(TableIdentifier("q") === parseTableIdentifier("q")) - assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) - - // Illegal names. - intercept[ParseException](parseTableIdentifier("")) - intercept[ParseException](parseTableIdentifier("d.q.g")) - - // SQL Keywords. - val keywords = Seq("select", "from", "where", "left", "right") - keywords.foreach { keyword => - intercept[ParseException](parseTableIdentifier(keyword)) - assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) - assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) - } - } -} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org