http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala new file mode 100644 index 0000000..a98c258 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/AggregationsITCase.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.scala.batch.table + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.types.Row +import org.apache.flink.table.api.{TableEnvironment, ValidationException} +import org.apache.flink.table.examples.scala.WordCountTable.{WC => MyWC} +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class AggregationsITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + @Test + def testAggregationTypes(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1.sum, '_1.min, '_1.max, '_1.count, '_1.avg) + + val results = t.toDataSet[Row].collect() + val expected = "231,1,21,21,11" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test(expected = classOf[ValidationException]) + def testAggregationOnNonExistingField(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + // Must fail. Field 'foo does not exist. 
+ .select('foo.avg) + } + + @Test + def testWorkingAggregationDataTypes(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements( + (1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, "Hello"), + (2: Byte, 2: Short, 2, 2L, 2.0f, 2.0d, "Ciao")).toTable(tEnv) + .select('_1.avg, '_2.avg, '_3.avg, '_4.avg, '_5.avg, '_6.avg, '_7.count) + + val expected = "1,1,1,1,1.5,1.5,2" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testProjection(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements( + (1: Byte, 1: Short), + (2: Byte, 2: Short)).toTable(tEnv) + .select('_1.avg, '_1.sum, '_1.count, '_2.avg, '_2.sum) + + val expected = "1,3,2,1,3" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testAggregationWithArithmetic(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements((1f, "Hello"), (2f, "Ciao")).toTable(tEnv) + .select(('_1 + 2).avg + 2, '_2.count + 5) + + val expected = "5.5,7" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testAggregationWithTwoCount(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements((1f, "Hello"), (2f, "Ciao")).toTable(tEnv) + .select('_1.count, '_2.count) + + val expected = "2,2" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testAggregationAfterProjection(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val 
tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements( + (1: Byte, 1: Short, 1, 1L, 1.0f, 1.0d, "Hello"), + (2: Byte, 2: Short, 2, 2L, 2.0f, 2.0d, "Ciao")).toTable(tEnv) + .select('_1, '_2, '_3) + .select('_1.avg, '_2.sum, '_3.count) + + val expected = "1,3,2" + val result = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(result.asJava, expected) + } + + @Test(expected = classOf[ValidationException]) + def testNonWorkingAggregationDataTypes(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements(("Hello", 1)).toTable(tEnv) + // Must fail. Field '_1 is not a numeric type. + .select('_1.sum) + + t.collect() + } + + @Test(expected = classOf[ValidationException]) + def testNoNestedAggregations(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env.fromElements(("Hello", 1)).toTable(tEnv) + // Must fail. Sum aggregation can not be chained. 
+ .select('_2.sum.sum) + } + + @Test + def testSQLStyleAggregations(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .select( + """Sum( a) as a1, a.sum as a2, + |Min (a) as b1, a.min as b2, + |Max (a ) as c1, a.max as c2, + |Avg ( a ) as d1, a.avg as d2, + |Count(a) as e1, a.count as e2 + """.stripMargin) + + val expected = "231,231,1,1,21,21,11,11,21,21" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testPojoAggregation(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val input = env.fromElements( + MyWC("hello", 1), + MyWC("hello", 1), + MyWC("ciao", 1), + MyWC("hola", 1), + MyWC("hola", 1)) + val expr = input.toTable(tEnv) + val result = expr + .groupBy('word) + .select('word, 'frequency.sum as 'frequency) + .filter('frequency === 2) + .toDataSet[MyWC] + + val mappedResult = result.map(w => (w.word, w.frequency * 10)).collect() + val expected = "(hello,20)\n" + "(hola,20)" + TestBaseUtils.compareResultAsText(mappedResult.asJava, expected) + } + + @Test + def testDistinct(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val distinct = ds.select('b).distinct() + + val expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + val results = distinct.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testDistinctAfterAggregate(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = 
CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + val distinct = ds.groupBy('a, 'e).select('e).distinct() + + val expected = "1\n" + "2\n" + "3\n" + val results = distinct.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test(expected = classOf[ValidationException]) + def testGroupingOnNonExistentField(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + // must fail. '_foo not a valid field + .groupBy('_foo) + .select('a.avg) + } + + @Test(expected = classOf[ValidationException]) + def testGroupingInvalidSelection(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('a, 'b) + // must fail. 
'c is not a grouping key or aggregation + .select('c) + } + + @Test + def testGroupedAggregate(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, 'a.sum) + + val expected = "1,1\n" + "2,5\n" + "3,15\n" + "4,34\n" + "5,65\n" + "6,111\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupingKeyForwardIfNotUsed(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('a.sum) + + val expected = "1\n" + "5\n" + "15\n" + "34\n" + "65\n" + "111\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupNoAggregation(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env) + .toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('a.sum as 'd, 'b) + .groupBy('b, 'd) + .select('b) + + val expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupedAggregateWithLongKeys(): Unit = { + // This uses very long keys to force serialized comparison. + // With short keys, the normalized key is sufficient. 
+ val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = env.fromElements( + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhaa", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2), + ("hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhab", 1, 2)) + .rebalance().setParallelism(2).toTable(tEnv, 'a, 'b, 'c) + .groupBy('a, 'b) + .select('c.sum) + + val expected = "10\n" + "8\n" + val results = ds.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupedAggregateWithConstant1(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .select('a, 4 as 'four, 'b) + .groupBy('four, 'a) + .select('four, 'b.sum) + + val expected = "4,2\n" + "4,3\n" + "4,5\n" + "4,5\n" + "4,5\n" + "4,6\n" + + "4,6\n" + "4,6\n" + "4,3\n" + "4,4\n" + "4,6\n" + "4,1\n" + "4,4\n" + + "4,4\n" + "4,5\n" + "4,6\n" + "4,2\n" + "4,3\n" + "4,4\n" + "4,5\n" + "4,6\n" + val results = t.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupedAggregateWithConstant2(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = 
CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .select('b, 4 as 'four, 'a) + .groupBy('b, 'four) + .select('four, 'a.sum) + + val expected = "4,1\n" + "4,5\n" + "4,15\n" + "4,34\n" + "4,65\n" + "4,111\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupedAggregateWithExpression(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .groupBy('e, 'b % 3) + .select('c.min, 'e, 'a.avg, 'd.count) + + val expected = "0,1,1,1\n" + "3,2,3,3\n" + "7,1,4,2\n" + "14,2,5,1\n" + + "5,3,4,2\n" + "2,1,3,2\n" + "1,2,3,3\n" + "12,3,5,1" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testGroupedAggregateWithFilter(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, 'a.sum) + .where('b === 2) + + val expected = "2,5\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + +} +
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala new file mode 100644 index 0000000..bc4f4bd --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.scala.batch.table + +import java.sql.{Date, Time, Timestamp} +import java.util + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.types.Row +import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException} +import org.apache.flink.table.expressions.Literal +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit.Assert._ +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class CalcITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + @Test + def testSimpleSelectAll(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv).select('_1, '_2, '_3) + + val expected = "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + + "4,3,Hello world, how are you?\n" + "5,3,I am fine.\n" + "6,3,Luke Skywalker\n" + + "7,4,Comment#1\n" + "8,4,Comment#2\n" + "9,4,Comment#3\n" + "10,4,Comment#4\n" + + "11,5,Comment#5\n" + "12,5,Comment#6\n" + "13,5,Comment#7\n" + "14,5,Comment#8\n" + + "15,5,Comment#9\n" + "16,6,Comment#10\n" + "17,6,Comment#11\n" + "18,6,Comment#12\n" + + "19,6,Comment#13\n" + "20,6,Comment#14\n" + "21,6,Comment#15\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testSimpleSelectAllWithAs(): Unit = { + val env = 
ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c).select('a, 'b, 'c) + + val expected = "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + + "4,3,Hello world, how are you?\n" + "5,3,I am fine.\n" + "6,3,Luke Skywalker\n" + + "7,4,Comment#1\n" + "8,4,Comment#2\n" + "9,4,Comment#3\n" + "10,4,Comment#4\n" + + "11,5,Comment#5\n" + "12,5,Comment#6\n" + "13,5,Comment#7\n" + "14,5,Comment#8\n" + + "15,5,Comment#9\n" + "16,6,Comment#10\n" + "17,6,Comment#11\n" + "18,6,Comment#12\n" + + "19,6,Comment#13\n" + "20,6,Comment#14\n" + "21,6,Comment#15\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testSimpleSelectWithNaming(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1 as 'a, '_2 as 'b, '_1 as 'c) + .select('a, 'b) + + val expected = "1,1\n" + "2,2\n" + "3,2\n" + "4,3\n" + "5,3\n" + "6,3\n" + "7,4\n" + + "8,4\n" + "9,4\n" + "10,4\n" + "11,5\n" + "12,5\n" + "13,5\n" + "14,5\n" + "15,5\n" + + "16,6\n" + "17,6\n" + "18,6\n" + "19,6\n" + "20,6\n" + "21,6\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testSimpleSelectRenameAll(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1 as 'a, '_2 as 'b, '_3 as 'c) + .select('a, 'b) + + val expected = "1,1\n" + "2,2\n" + "3,2\n" + "4,3\n" + "5,3\n" + "6,3\n" + "7,4\n" + + "8,4\n" + "9,4\n" + "10,4\n" + "11,5\n" + "12,5\n" + "13,5\n" + "14,5\n" + "15,5\n" + + "16,6\n" + "17,6\n" + "18,6\n" + "19,6\n" + "20,6\n" + "21,6\n" + val 
results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test(expected = classOf[ValidationException]) + def testSelectInvalidFieldFields(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + // must fail. Field 'foo does not exist + .select('a, 'foo) + } + + @Test(expected = classOf[ValidationException]) + def testSelectAmbiguousRenaming(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + // must fail. 'a and 'b are both renamed to 'foo + .select('a + 1 as 'foo, 'b + 2 as 'foo).toDataSet[Row].print() + } + + @Test(expected = classOf[ValidationException]) + def testSelectAmbiguousRenaming2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + // must fail. 
'a and 'b are both renamed to 'a + .select('a, 'b as 'a).toDataSet[Row].print() + } + + @Test + def testSelectStar(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c).select('*) + + val expected = "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + + "4,3,Hello world, how are you?\n" + "5,3,I am fine.\n" + "6,3,Luke Skywalker\n" + + "7,4,Comment#1\n" + "8,4,Comment#2\n" + "9,4,Comment#3\n" + "10,4,Comment#4\n" + + "11,5,Comment#5\n" + "12,5,Comment#6\n" + "13,5,Comment#7\n" + "14,5,Comment#8\n" + + "15,5,Comment#9\n" + "16,6,Comment#10\n" + "17,6,Comment#11\n" + "18,6,Comment#12\n" + + "19,6,Comment#13\n" + "20,6,Comment#14\n" + "21,6,Comment#15\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testAliasStarException(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + try { + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, '*, 'b, 'c) + fail("TableException expected") + } catch { + case _: TableException => //ignore + } + + try { + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1 as '*, '_2 as 'b, '_1 as 'c) + fail("ValidationException expected") + } catch { + case _: ValidationException => //ignore + } + + try { + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv).as('*, 'b, 'c) + fail("ValidationException expected") + } catch { + case _: ValidationException => //ignore + } + + try { + CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c).select('*, 'b) + fail("ValidationException expected") + } catch { + case _: ValidationException => //ignore + } + } + + @Test + def testAllRejectingFilter(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = 
TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val filterDs = ds.filter( Literal(false) ) + + val expected = "\n" + val results = filterDs.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testAllPassingFilter(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val filterDs = ds.filter( Literal(true) ) + val expected = "1,1,Hi\n" + "2,2,Hello\n" + "3,2,Hello world\n" + "4,3,Hello world, " + + "how are you?\n" + "5,3,I am fine.\n" + "6,3,Luke Skywalker\n" + "7,4," + + "Comment#1\n" + "8,4,Comment#2\n" + "9,4,Comment#3\n" + "10,4,Comment#4\n" + "11,5," + + "Comment#5\n" + "12,5,Comment#6\n" + "13,5,Comment#7\n" + "14,5,Comment#8\n" + "15,5," + + "Comment#9\n" + "16,6,Comment#10\n" + "17,6,Comment#11\n" + "18,6,Comment#12\n" + "19," + + "6,Comment#13\n" + "20,6,Comment#14\n" + "21,6,Comment#15\n" + val results = filterDs.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testFilterOnStringTupleField(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val filterDs = ds.filter( 'c.like("%world%") ) + + val expected = "3,2,Hello world\n" + "4,3,Hello world, how are you?\n" + val results = filterDs.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testFilterOnIntegerTupleField(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val 
filterDs = ds.filter( 'a % 2 === 0 ) + + val expected = "2,2,Hello\n" + "4,3,Hello world, how are you?\n" + + "6,3,Luke Skywalker\n" + "8,4," + "Comment#2\n" + "10,4,Comment#4\n" + + "12,5,Comment#6\n" + "14,5,Comment#8\n" + "16,6," + + "Comment#10\n" + "18,6,Comment#12\n" + "20,6,Comment#14\n" + val results = filterDs.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testNotEquals(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val filterDs = ds.filter( 'a % 2 !== 0) + val expected = "1,1,Hi\n" + "3,2,Hello world\n" + + "5,3,I am fine.\n" + "7,4,Comment#1\n" + "9,4,Comment#3\n" + + "11,5,Comment#5\n" + "13,5,Comment#7\n" + "15,5,Comment#9\n" + + "17,6,Comment#11\n" + "19,6,Comment#13\n" + "21,6,Comment#15\n" + val results = filterDs.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testDisjunctivePredicate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val filterDs = ds.filter( 'a < 2 || 'a > 20) + val expected = "1,1,Hi\n" + "21,6,Comment#15\n" + val results = filterDs.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testConsecutiveFilters(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val filterDs = ds.filter('a % 2 !== 0).filter('b % 2 === 0) + val expected = "3,2,Hello world\n" + "7,4,Comment#1\n" + + "9,4,Comment#3\n" + "17,6,Comment#11\n" + + "19,6,Comment#13\n" + "21,6,Comment#15\n" + val results = filterDs.collect() + 
TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testFilterBasicType(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.getStringDataSet(env) + + val filterDs = ds.toTable(tEnv, 'a).filter( 'a.like("H%") ) + + val expected = "Hi\n" + "Hello\n" + "Hello world\n" + "Hello world, how are you?\n" + val results = filterDs.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testFilterOnCustomType(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.getCustomTypeDataSet(env) + val filterDs = ds.toTable(tEnv, 'myInt as 'i, 'myLong as 'l, 'myString as 's) + .filter( 's.like("%a%") ) + + val expected = "3,3,Hello world, how are you?\n" + "3,4,I am fine.\n" + "3,5,Luke Skywalker\n" + val results = filterDs.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test(expected = classOf[ValidationException]) + def testFilterInvalidFieldName(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + // must fail. 
Field 'foo does not exist + ds.filter( 'foo === 2 ) + } + + @Test + def testSimpleCalc(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1, '_2, '_3) + .where('_1 < 7) + .select('_1, '_3) + + val expected = "1,Hi\n" + "2,Hello\n" + "3,Hello world\n" + + "4,Hello world, how are you?\n" + "5,I am fine.\n" + "6,Luke Skywalker\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testCalcWithTwoFilters(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1, '_2, '_3) + .where('_1 < 7 && '_2 === 3) + .select('_1, '_3) + .where('_1 === 4) + .select('_1) + + val expected = "4\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testCalcWithAggregation(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv) + .select('_1, '_2, '_3) + .where('_1 < 15) + .groupBy('_2) + .select('_1.min, '_2.count as 'cnt) + .where('cnt > 3) + + val expected = "7,4\n" + "11,4\n" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testCalcJoin(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h) + + val joinT = ds1.select('a, 'b).join(ds2).where('b === 
'e).select('a, 'b, 'd, 'e, 'f) + .where('b > 1).select('a, 'd).where('d === 2) + + val expected = "2,2\n" + "3,2\n" + val results = joinT.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testAdvancedDataTypes(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val t = env + .fromElements(( + BigDecimal("78.454654654654654").bigDecimal, + BigDecimal("4E+9999").bigDecimal, + Date.valueOf("1984-07-12"), + Time.valueOf("14:34:24"), + Timestamp.valueOf("1984-07-12 14:34:24"))) + .toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .select('a, 'b, 'c, 'd, 'e, BigDecimal("11.2"), BigDecimal("11.2").bigDecimal, + Date.valueOf("1984-07-12"), Time.valueOf("14:34:24"), + Timestamp.valueOf("1984-07-12 14:34:24")) + + val expected = "78.454654654654654,4E+9999,1984-07-12,14:34:24,1984-07-12 14:34:24.0," + + "11.2,11.2,1984-07-12,14:34:24,1984-07-12 14:34:24.0" + val results = t.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } +} + +object CalcITCase { + + @Parameterized.Parameters(name = "Execution mode = {0}, Table config = {1}") + def parameters(): util.Collection[Array[java.lang.Object]] = { + Seq[Array[AnyRef]]( + Array(TestExecutionMode.COLLECTION, TableProgramsTestBase.DEFAULT), + Array(TestExecutionMode.COLLECTION, TableProgramsTestBase.NO_NULL)).asJava + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala new file mode 100644 index 0000000..a80e0cb --- /dev/null +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.scala.batch.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.expressions.{RowtimeAttribute, Upper, WindowReference}
+import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.api.scala.batch.table.FieldProjectionTest._
+import org.apache.flink.table.plan.logical.EventTimeTumblingGroupWindow
+import org.apache.flink.table.utils._
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+/**
+  * Tests for the situations in which fields can be projected, e.g. when only a few fields are
+  * selected from a source with many fields.
+  *
+  * Each test builds a Table API query and checks its plan with `verifyTable` against an expected
+  * plan assembled from the `TableTestUtil` helpers (`unaryNode`, `binaryNode`, `term`, ...).
+  */
+class FieldProjectionTest extends TableTestBase {
+
+  // Batch (DataSet) test utility: registers source tables and verifies their plans.
+  val util: BatchTableTestUtil = batchTestUtil()
+
+  // Streaming (DataStream) test utility, used by the streaming window tests.
+  val streamUtil: StreamTableTestUtil = streamTestUtil()
+
+  @Test
+  def testSimpleSelect(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.select('a, 'b)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", "a", "b")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectAllFields(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable1 = sourceTable.select('*)
+    val resultTable2 = sourceTable.select('a, 'b, 'c, 'd)
+
+    val expected = batchTableNode(0)
+
+    util.verifyTable(resultTable1, expected)
+    util.verifyTable(resultTable2, expected)
+  }
+
+  @Test
+  def testSelectAggregation(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.select('a.sum, 'b.max)
+
+    // The input is unioned with one all-null row (DataSetValues), presumably so that the
+    // non-grouped aggregate also produces a result row for an empty input.
+    val expected = unaryNode(
+      "DataSetAggregate",
+      binaryNode(
+        "DataSetUnion",
+        values(
+          "DataSetValues",
+          tuples(List(null, null)),
+          term("values", "a", "b")
+        ),
+        unaryNode(
+          "DataSetCalc",
+          batchTableNode(0),
+          term("select", "a", "b")
+        ),
+        term("union", "a", "b")
+      ),
+      term("select", "SUM(a) AS TMP_0", "MAX(b) AS TMP_1")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFunction(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+    util.tEnv.registerFunction("hashCode", MyHashCode)
+
+    val resultTable = sourceTable.select("hashCode(c), b")
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", s"${MyHashCode.getClass.getCanonicalName}(c) AS _c0", "b")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromGroupedTable(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy('a, 'c).select('a)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetAggregate",
+        unaryNode(
+          "DataSetCalc",
+          batchTableNode(0),
+          term("select", "a", "c")
+        ),
+        term("groupBy", "a", "c"),
+        term("select", "a", "c")
+      ),
+      term("select", "a")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectAllFieldsFromGroupedTable(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy('a, 'c).select('a, 'c)
+
+    val expected = unaryNode(
+      "DataSetAggregate",
+      unaryNode(
+        "DataSetCalc",
+        batchTableNode(0),
+        term("select", "a", "c")
+      ),
+      term("groupBy", "a", "c"),
+      term("select", "a", "c")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectAggregationFromGroupedTable(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy('c).select('a.sum)
+
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(0),
+            term("select", "a", "c")
+          ),
+          term("groupBy", "c"),
+          term("select", "c", "SUM(a) AS TMP_0")
+        ),
+        term("select", "TMP_0 AS TMP_1")
+      )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromGroupedTableWithNonTrivialKey(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy(Upper('c) as 'k).select('a.sum)
+
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(0),
+            term("select", "a", "c", "UPPER(c) AS k")
+          ),
+          term("groupBy", "k"),
+          term("select", "k", "SUM(a) AS TMP_0")
+        ),
+        term("select", "TMP_0 AS TMP_1")
+      )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromGroupedTableWithFunctionKey(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.groupBy(MyHashCode('c) as 'k).select('a.sum)
+
+    val expected =
+      unaryNode(
+        "DataSetCalc",
+        unaryNode(
+          "DataSetAggregate",
+          unaryNode(
+            "DataSetCalc",
+            batchTableNode(0),
+            term("select", "a", "c", s"${MyHashCode.getClass.getCanonicalName}(c) AS k")
+          ),
+          term("groupBy", "k"),
+          term("select", "k", "SUM(a) AS TMP_0")
+        ),
+        term("select", "TMP_0 AS TMP_1")
+      )
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromStreamingWindow(): Unit = {
+    val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable
+      .window(Tumble over 5.millis on 'rowtime as 'w)
+      .select(Upper('c).count, 'a.sum)
+
+    val expected =
+      unaryNode(
+        "DataStreamAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "c", "a", "UPPER(c) AS $f2")
+        ),
+        term("window",
+          EventTimeTumblingGroupWindow(
+            Some(WindowReference("w")),
+            RowtimeAttribute(),
+            5.millis)),
+        term("select", "COUNT($f2) AS TMP_0", "SUM(a) AS TMP_1")
+      )
+
+    streamUtil.verifyTable(resultTable, expected)
+  }
+
+  @Test
+  def testSelectFromStreamingGroupedWindow(): Unit = {
+    val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable
+      .groupBy('b)
+      .window(Tumble over 5.millis on 'rowtime as 'w)
+      .select(Upper('c).count, 'a.sum, 'b)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamAggregate",
+        unaryNode(
+          "DataStreamCalc",
+          streamTableNode(0),
+          term("select", "c", "a", "b", "UPPER(c) AS $f3")
+        ),
+        term("groupBy", "b"),
+        term("window",
+          EventTimeTumblingGroupWindow(
+            Some(WindowReference("w")),
+            RowtimeAttribute(),
+            5.millis)),
+        term("select", "b", "COUNT($f3) AS TMP_0", "SUM(a) AS TMP_1")
+      ),
+      term("select", "TMP_0 AS TMP_2", "TMP_1 AS TMP_3", "b")
+    )
+
+    streamUtil.verifyTable(resultTable, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testSelectFromBatchWindow1(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+    // time field is selected
+    val resultTable = sourceTable
+      .window(Tumble over 5.millis on 'a as 'w)
+      .select('a.sum, 'c.count)
+
+    // Placeholder plan: the test can only pass if a ValidationException is raised before
+    // the plans are compared.
+    val expected = "TODO"
+
+    util.verifyTable(resultTable, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testSelectFromBatchWindow2(): Unit = {
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+
+    // time field is not selected
+    val resultTable = sourceTable
+      .window(Tumble over 5.millis on 'a as 'w)
+      .select('c.count)
+
+    // Placeholder plan: the test can only pass if a ValidationException is raised before
+    // the plans are compared.
+    val expected = "TODO"
+
+    util.verifyTable(resultTable, expected)
+  }
+}
+
+object FieldProjectionTest {
+
+  /** Scalar UDF that returns the `hashCode` of its string argument. */
+  object MyHashCode extends ScalarFunction {
+    def eval(s: String): Int = s.hashCode()
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/JoinITCase.scala
new file mode 100644
index 0000000..ce16ada
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/JoinITCase.scala
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.scala.batch.table
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.types.Row
+import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
+import org.apache.flink.table.expressions.Literal
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.apache.flink.test.util.TestBaseUtils
+import org.junit._
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+
+@RunWith(classOf[Parameterized]) // Table API join ITCases; one run per (mode, configMode) pair
+class JoinITCase(
+    mode: TestExecutionMode,
+    configMode: TableConfigMode)
+  extends TableProgramsTestBase(mode, configMode) {
+
+  @Test
+  def testJoin(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).where('b === 'e).select('c, 'g)
+
+    val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinWithFilter(): Unit = {
+
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('d, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).where('b === 'e && 'b < 2).select('c, 'g)
+
+    val expected = "Hi,Hallo\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinWithJoinFilter(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).where('b === 'e && 'a < 6 && 'h < 'b).select('c, 'g)
+
+    val expected = "Hello world, how are you?,Hallo Welt wie\n" +
+      "I am fine.,Hallo Welt wie\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinWithMultipleKeys(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).filter('a === 'd && 'b === 'h).select('c, 'g)
+
+    val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt wie gehts?\n" +
+      "Hello world,ABC\n" + "I am fine.,HIJ\n" + "I am fine.,IJK\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testJoinNonExistingKey(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    ds1.join(ds2)
+      // must fail. Field 'foo does not exist
+      .where('foo === 'e)
+      .select('c, 'g)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testJoinWithNonMatchingKeyTypes(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    ds1.join(ds2)
+      // must fail. Field 'a is Int, and 'g is String
+      .where('a === 'g)
+      .select('c, 'g).collect()
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testJoinWithAmbiguousFields(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'c)
+
+    ds1.join(ds2)
+      // must fail. Both inputs share the same field 'c
+      .where('a === 'd)
+      .select('c, 'g)
+  }
+
+  @Test(expected = classOf[TableException])
+  def testNoEqualityJoinPredicate1(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    ds1.join(ds2)
+      // must fail. No equality join predicate
+      .where('d === 'f)
+      .select('c, 'g).collect() // presumably only raised when the plan is translated
+  }
+
+  @Test(expected = classOf[TableException])
+  def testNoEqualityJoinPredicate2(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    ds1.join(ds2)
+      // must fail. No equality join predicate
+      .where('a < 'd)
+      .select('c, 'g).collect() // presumably only raised when the plan is translated
+  }
+
+  @Test
+  def testJoinWithAggregation(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).where('a === 'd).select('g.count)
+
+    val expected = "6"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinWithGroupedAggregation(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2)
+      .where('a === 'd)
+      .groupBy('a, 'd)
+      .select('b.sum, 'g.count)
+
+    val expected = "6,3\n" + "4,2\n" + "1,1"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinPushThroughJoin(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+    val ds3 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'j, 'k, 'l)
+
+    val joinT = ds1.join(ds2)
+      .where(Literal(true)) // no equality predicate here; 'a === 'd below must be pushed into it
+      .join(ds3)
+      .where('a === 'd && 'e === 'k)
+      .select('a, 'f, 'l)
+
+    val expected = "2,1,Hello\n" + "2,1,Hello world\n" + "1,0,Hi"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinWithDisjunctivePred(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).filter('a === 'd && ('b === 'e || 'b === 'e - 10)).select('c, 'g)
+
+    val expected = "Hi,Hallo\n" +
+      "Hello,Hallo Welt\n" +
+      "I am fine.,IJK"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testJoinWithExpressionPreds(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.join(ds2).filter('b === 'h + 1 && 'a - 1 === 'd + 2).select('c, 'g)
+
+    val expected = "I am fine.,Hallo Welt\n" +
+      "Luke Skywalker,Hallo Welt wie gehts?\n" +
+      "Luke Skywalker,ABC\n" +
+      "Comment#2,HIJ\n" +
+      "Comment#2,IJK"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testJoinTablesFromDifferentEnvs(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv1 = TableEnvironment.getTableEnvironment(env, config)
+    val tEnv2 = TableEnvironment.getTableEnvironment(env, config)
+
+    val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv1, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv2, 'd, 'e, 'f, 'g, 'h)
+
+    // Must fail. Tables are bound to different TableEnvironments.
+    ds1.join(ds2).where('b === 'e).select('c, 'g)
+  }
+
+  @Test
+  def testLeftJoinWithMultipleKeys(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setNullCheck(true) // outer joins emit null fields for unmatched rows
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.leftOuterJoin(ds2, 'a === 'd && 'b === 'h).select('c, 'g)
+
+    val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt wie gehts?\n" +
+      "Hello world,ABC\n" + "Hello world, how are you?,null\n" + "I am fine.,HIJ\n" +
+      "I am fine.,IJK\n" + "Luke Skywalker,null\n" + "Comment#1,null\n" + "Comment#2,null\n" +
+      "Comment#3,null\n" + "Comment#4,null\n" + "Comment#5,null\n" + "Comment#6,null\n" +
+      "Comment#7,null\n" + "Comment#8,null\n" + "Comment#9,null\n" + "Comment#10,null\n" +
+      "Comment#11,null\n" + "Comment#12,null\n" + "Comment#13,null\n" + "Comment#14,null\n" +
+      "Comment#15,null\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testNoJoinCondition(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setNullCheck(true)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    ds2.leftOuterJoin(ds1, 'b === 'd && 'b < 3).select('c, 'g) // must fail
+  }
+
+  @Test(expected = classOf[ValidationException])
+  def testNoEquiJoin(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setNullCheck(true)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    ds2.leftOuterJoin(ds1, 'b < 'd).select('c, 'g) // must fail. No equality join predicate
+  }
+
+  @Test
+  def testRightJoinWithMultipleKeys(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setNullCheck(true)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.rightOuterJoin(ds2, "a = d && b = h").select('c, 'g)
+
+    val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "null,Hallo Welt wie\n" +
+      "Hello world,Hallo Welt wie gehts?\n" + "Hello world,ABC\n" + "null,BCD\n" + "null,CDE\n" +
+      "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "I am fine.,HIJ\n" +
+      "I am fine.,IJK\n" + "null,JKL\n" + "null,KLM\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testRightJoinWithNotOnlyEquiJoin(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setNullCheck(true)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.rightOuterJoin(ds2, "a = d && b < h").select('c, 'g)
+
+    val expected = "Hello world,BCD\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testFullOuterJoinWithMultipleKeys(): Unit = {
+    val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setNullCheck(true)
+
+    val ds1 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+    val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f, 'g, 'h)
+
+    val joinT = ds1.fullOuterJoin(ds2, 'a === 'd && 'b === 'h).select('c, 'g)
+
+    val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "null,Hallo Welt wie\n" +
+      "Hello world,Hallo Welt wie gehts?\n" + "Hello world,ABC\n" + "null,BCD\n" + "null,CDE\n" +
+      "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "I am fine.,HIJ\n" +
+      "I am fine.,IJK\n" + "null,JKL\n" + "null,KLM\n" + "Luke Skywalker,null\n" +
+      "Comment#1,null\n" + "Comment#2,null\n" + "Comment#3,null\n" + "Comment#4,null\n" +
+      "Comment#5,null\n" + "Comment#6,null\n" + "Comment#7,null\n" + "Comment#8,null\n" +
+      "Comment#9,null\n" + "Comment#10,null\n" + "Comment#11,null\n" + "Comment#12,null\n" +
+      "Comment#13,null\n" + "Comment#14,null\n" + "Comment#15,null\n" +
+      "Hello world, how are you?,null\n"
+    val results = joinT.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SetOperatorsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SetOperatorsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SetOperatorsITCase.scala
new file mode 100644
index 0000000..e369250
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SetOperatorsITCase.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.batch.table + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.types.Row +import org.apache.flink.table.api.{TableEnvironment, ValidationException} +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Random + +@RunWith(classOf[Parameterized]) +class SetOperatorsITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + @Test + def testUnionAll(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f) + + val unionDs = ds1.unionAll(ds2).select('c) + + val results = 
unionDs.toDataSet[Row].collect() + val expected = "Hi\n" + "Hello\n" + "Hello world\n" + "Hi\n" + "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testUnion(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'd, 'e, 'f) + + val unionDs = ds1.union(ds2).select('c) + + val results = unionDs.toDataSet[Row].collect() + val expected = "Hi\n" + "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testTernaryUnionAll(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds3 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val unionDs = ds1.unionAll(ds2).unionAll(ds3).select('c) + + val results = unionDs.toDataSet[Row].collect() + val expected = "Hi\n" + "Hello\n" + "Hello world\n" + + "Hi\n" + "Hello\n" + "Hello world\n" + + "Hi\n" + "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testTernaryUnion(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds3 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + + val unionDs = 
ds1.union(ds2).union(ds3).select('c) + + val results = unionDs.toDataSet[Row].collect() + val expected = "Hi\n" + "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + /** unionAll of a 3-column table with a 5-column table must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testUnionDifferentColumnSize(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'a, 'b, 'd, 'c, 'e) + + // must fail. Union inputs have a different number of columns (3 vs. 5). + ds1.unionAll(ds2) + } + + /** unionAll of two 3-column tables must be rejected with a ValidationException: the columns projected from the 5-tuple data set do not have the same field types as the small 3-tuple columns. */ @Test(expected = classOf[ValidationException]) + def testUnionDifferentFieldTypes(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .select('a, 'b, 'c) + + // must fail. Union inputs have different field types. + ds1.unionAll(ds2) + } + + /** unionAll of tables bound to two distinct TableEnvironments (both created from the same ExecutionEnvironment) must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testUnionTablesFromDifferentEnvs(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv1 = TableEnvironment.getTableEnvironment(env, config) + val tEnv2 = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv1, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv2, 'a, 'b, 'c) + + // Must fail. Tables are bound to different TableEnvironments. 
+ ds1.unionAll(ds2).select('c) + } + + /** minusAll has multiset semantics: three copies of ds1 minus two copies of the (1, 1L, "Hi") row leaves one "Hi" and three each of "Hello" and "Hello world". */ @Test + def testMinusAll(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = env.fromElements((1, 1L, "Hi")).toTable(tEnv, 'a, 'b, 'c) + + val minusDs = ds1.unionAll(ds1).unionAll(ds1) + .minusAll(ds2.unionAll(ds2)).select('c) + + val results = minusDs.toDataSet[Row].collect() + val expected = "Hi\n" + + "Hello\n" + "Hello world\n" + + "Hello\n" + "Hello world\n" + + "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + /** minus has set semantics: duplicates collapse and the "Hi" row is removed entirely, leaving "Hello" and "Hello world" once each. */ @Test + def testMinus(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = env.fromElements((1, 1L, "Hi")).toTable(tEnv, 'a, 'b, 'c) + + val minusDs = ds1.unionAll(ds1).unionAll(ds1) + .minus(ds2.unionAll(ds2)).select('c) + + val results = minusDs.toDataSet[Row].collect() + val expected = "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + /** minus of inputs with different field types must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testMinusDifferentFieldTypes(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .select('a, 'b, 'c) + + // must fail. Minus inputs have different field types. 
+ ds1.minus(ds2) + } + + /** minus succeeds although the inputs use different field names ('a, 'b, 'c vs. 'd, 'e, 'f); the result is addressed via the left input's names (select('c)). */ @Test + def testMinusDifferentFieldNames(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = env.fromElements((1, 1L, "Hi")).toTable(tEnv, 'd, 'e, 'f) + + val minusDs = ds1.unionAll(ds1).unionAll(ds1) + .minus(ds2.unionAll(ds2)).select('c) + + val results = minusDs.toDataSet[Row].collect() + val expected = "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + /** minusAll of tables bound to two distinct TableEnvironments must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testMinusAllTablesFromDifferentEnvs(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv1 = TableEnvironment.getTableEnvironment(env, config) + val tEnv2 = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv1, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv2, 'a, 'b, 'c) + + // Must fail. Tables are bound to different TableEnvironments. 
+ ds1.minusAll(ds2).select('c) + } + + /** intersect has set semantics: ds2 is shuffled, repeats the "Hello" row and contains "Hello world!" (not "Hello world"), so only "Hi" and "Hello" are expected, once each. */ @Test + def testIntersect(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Hi")) + data.+=((2, 2L, "Hello")) + data.+=((2, 2L, "Hello")) + data.+=((3, 2L, "Hello world!")) + val ds2 = env.fromCollection(Random.shuffle(data)).toTable(tEnv, 'a, 'b, 'c) + + val intersectDS = ds1.intersect(ds2).select('c).toDataSet[Row] + + val results = intersectDS.collect() + + val expected = "Hi\n" + "Hello\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + /** intersectAll has multiset semantics: {1, 1, 1, 2, 2} and {1, 2, 2, 2, 3} share one 1 and two 2s. */ @Test + def testIntersectAll(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val data1 = new mutable.MutableList[Int] + data1 += (1, 1, 1, 2, 2) + val data2 = new mutable.MutableList[Int] + data2 += (1, 2, 2, 2, 3) + val ds1 = env.fromCollection(data1).toTable(tEnv, 'c) + val ds2 = env.fromCollection(data2).toTable(tEnv, 'c) + + val intersectDS = ds1.intersectAll(ds2).select('c).toDataSet[Row] + + val expected = "1\n2\n2" + val results = intersectDS.collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + /** intersect succeeds although the inputs use different field names ('a, 'b, 'c vs. 'e, 'f, 'g); the result is addressed via the left input's names (select('c)). */ @Test + def testIntersectWithDifferentFieldNames(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'e, 'f, 'g) + + val intersectDs = ds1.intersect(ds2).select('c) + + val results = intersectDs.toDataSet[Row].collect() + val expected = "Hi\n" + "Hello\n" + "Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, 
expected) + } + + /** intersect of inputs with different field types must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testIntersectWithDifferentFieldTypes(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + val ds2 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .select('a, 'b, 'c) + + // must fail. Intersect inputs have different field types. + ds1.intersect(ds2) + } + + /** intersect of tables bound to two distinct TableEnvironments must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testIntersectTablesFromDifferentEnvs(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv1 = TableEnvironment.getTableEnvironment(env, config) + val tEnv2 = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv1, 'a, 'b, 'c) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv2, 'a, 'b, 'c) + + // Must fail. Tables are bound to different TableEnvironments. 
+ ds1.intersect(ds2).select('c) + } + + /** intersect also works on computed columns: both inputs project 'a + 1, so the common rows appear with their first field incremented by one. */ @Test + def testIntersectWithScalarExpression(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .select('a + 1, 'b, 'c) + val ds2 = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c) + .select('a + 1, 'b, 'c) + + val intersectDs = ds1.intersect(ds2) + + val results = intersectDs.toDataSet[Row].collect() + val expected = "2,1,Hi\n" + "3,2,Hello\n" + "4,2,Hello world\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SortITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SortITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SortITCase.scala new file mode 100644 index 0000000..3cbc2c8 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/SortITCase.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.batch.table + +import org.apache.flink.table.api.scala.batch.utils.SortTestUtils._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.scala.{ExecutionEnvironment, _} +import org.apache.flink.types.Row +import org.apache.flink.table.api.{TableEnvironment, ValidationException} +import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.TestBaseUtils +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +/** Table API orderBy/limit integration tests on batch tables, run for every parameterized (TestExecutionMode, TableConfigMode) combination. */ @RunWith(classOf[Parameterized]) +class SortITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { + + /** Returns ExecutionEnvironment.getExecutionEnvironment with parallelism set to 4, presumably so that sorted results span several partitions. */ def getExecutionEnvironment = { + val env = ExecutionEnvironment.getExecutionEnvironment + env.setParallelism(4) + env + } + + /** orderBy('_1.desc) must yield rows in descending order of the first field. The implicit rowOrdering is presumably picked up by sortExpectedly (SortTestUtils, not shown) to build the expected output. */ @Test + def testOrderByDesc(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).orderBy('_1.desc) + implicit def rowOrdering[T <: Product] = Ordering.by((x : T) => + - x.productElement(0).asInstanceOf[Int] ) + + val expected = sortExpectedly(tupleDataSetStrings) + val results = 
t.toDataSet[Row].mapPartition(rows => Seq(rows.toSeq)).collect() + + /* Collect every partition as one Seq, drop empty partitions, order the partitions by the string form of their first row and concatenate them. NOTE(review): that ordering is lexicographic ("10,..." sorts before "9,..."), not numeric; confirm it reliably reproduces the global order, especially for descending sorts. */ val result = results + .filterNot(_.isEmpty) + .sortBy(_.head)(Ordering.by(f=> f.toString)) + .reduceLeft(_ ++ _) + + TestBaseUtils.compareOrderedResultAsText(result.asJava, expected) + } + + /** orderBy('_1.asc) must yield rows in ascending order of the first field. */ @Test + def testOrderByAsc(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).orderBy('_1.asc) + implicit def rowOrdering[T <: Product] = Ordering.by((x : T) => + x.productElement(0).asInstanceOf[Int] ) + + val expected = sortExpectedly(tupleDataSetStrings) + val results = t.toDataSet[Row].mapPartition(rows => Seq(rows.toSeq)).collect() + + val result = results + .filterNot(_.isEmpty) + .sortBy(_.head)(Ordering.by(f=> f.toString)) + .reduceLeft(_ ++ _) + + TestBaseUtils.compareOrderedResultAsText(result.asJava, expected) + } + + /** orderBy('_2.asc, '_1.desc) must sort by the second field ascending, breaking ties by the first field descending. */ @Test + def testOrderByMultipleFieldsDifferentDirections(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).orderBy('_2.asc, '_1.desc) + implicit def rowOrdering[T <: Product] = Ordering.by((x : T) => + (x.productElement(1).asInstanceOf[Long], - x.productElement(0).asInstanceOf[Int]) ) + + val expected = sortExpectedly(tupleDataSetStrings) + val results = t.toDataSet[Row].mapPartition(rows => Seq(rows.toSeq)).collect() + + val result = results + .filterNot(_.isEmpty) + .sortBy(_.head)(Ordering.by(f=> f.toString)) + .reduceLeft(_ ++ _) + + TestBaseUtils.compareOrderedResultAsText(result.asJava, expected) + } + + /** orderBy('_1.asc).limit(3) uses limit as an offset; the expected output is presumably the slice [3, 21) of the sorted strings. */ @Test + def testOrderByOffset(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).orderBy('_1.asc).limit(3) + implicit def rowOrdering[T <: Product] = 
Ordering.by((x : T) => + x.productElement(0).asInstanceOf[Int] ) + + val expected = sortExpectedly(tupleDataSetStrings, 3, 21) + val results = t.toDataSet[Row].mapPartition(rows => Seq(rows.toSeq)).collect() + + val result = results + .filterNot(_.isEmpty) + .sortBy(_.head)(Ordering.by(f=> f.toString)) + .reduceLeft(_ ++ _) + + TestBaseUtils.compareOrderedResultAsText(result.asJava, expected) + } + + /** orderBy('_1.desc).limit(3, 5) is used as offset 3 plus fetch 5; the expected output is presumably the slice [3, 8) of the sorted strings. */ @Test + def testOrderByOffsetAndFetch(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).orderBy('_1.desc).limit(3, 5) + implicit def rowOrdering[T <: Product] = Ordering.by((x : T) => + - x.productElement(0).asInstanceOf[Int] ) + + val expected = sortExpectedly(tupleDataSetStrings, 3, 8) + val results = t.toDataSet[Row].mapPartition(rows => Seq(rows.toSeq)).collect() + + val result = results + .filterNot(_.isEmpty) + .sortBy(_.head)(Ordering.by(f=> f.toString)) + .reduceLeft(_ ++ _) + + TestBaseUtils.compareOrderedResultAsText(result.asJava, expected) + } + + /** orderBy('_1.asc).limit(0, 5) is used as a fetch of the first five rows; the expected output is presumably the slice [0, 5) of the sorted strings. */ @Test + def testOrderByFetch(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).orderBy('_1.asc).limit(0, 5) + implicit def rowOrdering[T <: Product] = Ordering.by((x : T) => + x.productElement(0).asInstanceOf[Int] ) + + val expected = sortExpectedly(tupleDataSetStrings, 0, 5) + val results = t.toDataSet[Row].mapPartition(rows => Seq(rows.toSeq)).collect() + + val result = results + .filterNot(_.isEmpty) + .sortBy(_.head)(Ordering.by(f=> f.toString)) + .reduceLeft(_ ++ _) + + TestBaseUtils.compareOrderedResultAsText(result.asJava, expected) + } + + /** limit(0, 5) without a preceding orderBy must be rejected with a ValidationException. */ @Test(expected = classOf[ValidationException]) + def testFetchWithoutOrder(): Unit = { + val env = getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + 
    val ds = CollectionDataSets.get3TupleDataSet(env) + val t = ds.toTable(tEnv).limit(0, 5) + + t.toDataSet[Row].collect() + } + +}