[FLINK-3849] [table] Add FilterableTableSource interface and rules for pushing it (1)
fix filterable test rebase and trying fix rexnode parsing create wrapper and update rules Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9f6cd2e7 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9f6cd2e7 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9f6cd2e7 Branch: refs/heads/master Commit: 9f6cd2e76a44f194d96db7a218e5032791c5925c Parents: ab014ef Author: tonycox <anton_solo...@epam.com> Authored: Wed Jan 11 13:15:49 2017 +0400 Committer: Kurt Young <k...@apache.org> Committed: Fri Mar 17 18:01:27 2017 +0800 ---------------------------------------------------------------------- .../flink/table/api/BatchTableEnvironment.scala | 2 +- .../table/api/StreamTableEnvironment.scala | 2 +- .../flink/table/api/TableEnvironment.scala | 12 + .../flink/table/calcite/RexNodeWrapper.scala | 106 +++++++ .../table/expressions/stringExpressions.scala | 2 +- .../flink/table/plan/nodes/CommonCalc.scala | 34 ++- .../nodes/dataset/BatchTableSourceScan.scala | 20 +- .../table/plan/nodes/dataset/DataSetCalc.scala | 24 +- .../plan/nodes/datastream/DataStreamCalc.scala | 15 +- .../datastream/StreamTableSourceScan.scala | 24 +- .../flink/table/plan/rules/FlinkRuleSets.scala | 11 +- ...PushFilterIntoBatchTableSourceScanRule.scala | 95 ++++++ ...ushProjectIntoBatchTableSourceScanRule.scala | 2 +- ...ushFilterIntoStreamTableSourceScanRule.scala | 95 ++++++ ...shProjectIntoStreamTableSourceScanRule.scala | 2 +- .../rules/util/RexProgramProjectExtractor.scala | 120 -------- .../table/plan/schema/TableSourceTable.scala | 1 + .../util/RexProgramExpressionExtractor.scala | 163 ++++++++++ .../plan/util/RexProgramProjectExtractor.scala | 120 ++++++++ .../table/sources/FilterableTableSource.scala | 38 +++ .../flink/table/sources/TableSource.scala | 1 - .../flink/table/validate/FunctionCatalog.scala | 24 ++ .../flink/table/TableEnvironmentTest.scala | 25 +- .../apache/flink/table/TableSourceTest.scala | 300 
+++++++++++++++++++ .../api/scala/batch/TableSourceITCase.scala | 16 + .../table/api/scala/batch/TableSourceTest.scala | 209 ------------- .../api/scala/stream/TableSourceITCase.scala | 19 ++ .../util/RexProgramProjectExtractorTest.scala | 121 -------- .../RexProgramExpressionExtractorTest.scala | 182 +++++++++++ .../util/RexProgramProjectExtractorTest.scala | 121 ++++++++ .../flink/table/utils/CommonTestData.scala | 128 +++++++- 31 files changed, 1511 insertions(+), 523 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala index b48e9f9..7f27357 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala @@ -95,7 +95,7 @@ abstract class BatchTableEnvironment( tableSource match { case batchTableSource: BatchTableSource[_] => - registerTableInternal(name, new TableSourceTable(batchTableSource)) + registerTableInternal(name, new TableSourceTable(batchTableSource, this)) case _ => throw new TableException("Only BatchTableSource can be registered in " + "BatchTableEnvironment") http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala index d927c3a..7e9f38f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala @@ -136,7 +136,7 @@ abstract class StreamTableEnvironment( tableSource match { case streamTableSource: StreamTableSource[_] => - registerTableInternal(name, new TableSourceTable(streamTableSource)) + registerTableInternal(name, new TableSourceTable(streamTableSource, this)) case _ => throw new TableException("Only StreamTableSource can be registered in " + "StreamTableEnvironment") http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala index 1dda3a8..291f49f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala @@ -842,6 +842,18 @@ object TableEnvironment { } /** + * Returns field types for a given [[TableSource]]. + * + * @param tableSource The TableSource to extract field types from. + * @tparam A The type of the TableSource. + * @return An array holding the field types. + */ + def getFieldTypes[A](tableSource: TableSource[A]): Array[TypeInformation[_]] = { + val returnType = tableSource.getReturnType + TableEnvironment.getFieldTypes(returnType) + } + + /** * Returns field names for a given [[TableSource]]. * * @param tableSource The TableSource to extract field names from. 
http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RexNodeWrapper.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RexNodeWrapper.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RexNodeWrapper.scala new file mode 100644 index 0000000..1926a67 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/RexNodeWrapper.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.calcite + +import org.apache.calcite.rex._ +import org.apache.calcite.sql._ +import org.apache.flink.table.api.TableException +import org.apache.flink.table.expressions.{Expression, Literal, ResolvedFieldReference} +import org.apache.flink.table.validate.FunctionCatalog +import org.apache.flink.table.calcite.RexNodeWrapper._ + +abstract class RexNodeWrapper(rex: RexNode) { + def get: RexNode = rex + def toExpression(names: Map[RexInputRef, String]): Expression +} + +case class RexLiteralWrapper(literal: RexLiteral) extends RexNodeWrapper(literal) { + override def toExpression(names: Map[RexInputRef, String]): Expression = { + val typeInfo = FlinkTypeFactory.toTypeInfo(literal.getType) + Literal(literal.getValue, typeInfo) + } +} + +case class RexInputWrapper(input: RexInputRef) extends RexNodeWrapper(input) { + override def toExpression(names: Map[RexInputRef, String]): Expression = { + val typeInfo = FlinkTypeFactory.toTypeInfo(input.getType) + ResolvedFieldReference(names(input), typeInfo) + } +} + +case class RexCallWrapper( + call: RexCall, + operands: Seq[RexNodeWrapper]) extends RexNodeWrapper(call) { + + override def toExpression(names: Map[RexInputRef, String]): Expression = { + val ops = operands.map(_.toExpression(names)) + call.op match { + case function: SqlFunction => + lookupFunction(replace(function.getName), ops) + case postfix: SqlPostfixOperator => + lookupFunction(replace(postfix.getName), ops) + case operator@_ => + val name = replace(s"${operator.kind}") + lookupFunction(name, ops) + } + } + + def replace(str: String): String = { + str.replaceAll("\\s|_", "") + } +} + +object RexNodeWrapper { + + private var catalog: Option[FunctionCatalog] = None + + def wrap(rex: RexNode, functionCatalog: FunctionCatalog): RexNodeWrapper = { + catalog = Option(functionCatalog) + rex.accept(new WrapperVisitor) + } + + private[table] def lookupFunction(name: String, operands: Seq[Expression]): Expression = { + 
catalog.getOrElse(throw TableException("FunctionCatalog was not defined")) + .lookupFunction(name, operands) + } +} + +class WrapperVisitor extends RexVisitorImpl[RexNodeWrapper](true) { + + override def visitInputRef(inputRef: RexInputRef): RexNodeWrapper = { + RexInputWrapper(inputRef) + } + + override def visitLiteral(literal: RexLiteral): RexNodeWrapper = { + RexLiteralWrapper(literal) + } + + override def visitLocalRef(localRef: RexLocalRef): RexNodeWrapper = { + localRef.accept(this) + } + + override def visitCall(call: RexCall): RexNodeWrapper = { + val operands = for { + x <- 0 until call.operands.size() + } yield { + call.operands.get(x).accept(this) + } + RexCallWrapper(call, operands) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala index f4b58cc..e8ae0d8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala @@ -111,7 +111,7 @@ case class Lower(child: Expression) extends UnaryExpression { } } - override def toString: String = s"($child).toLowerCase()" + override def toString: String = s"($child).lowerCase()" override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { relBuilder.call(SqlStdOperatorTable.LOWER, child.toRexNode) http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala index 3f46258..8b07aac 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala @@ -18,8 +18,9 @@ package org.apache.flink.table.plan.nodes +import org.apache.calcite.plan.{RelOptCost, RelOptPlanner} import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rex.{RexNode, RexProgram} +import org.apache.calcite.rex._ import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.TableConfig @@ -149,4 +150,35 @@ trait CommonCalc { val name = calcOpName(calcProgram, expression) s"Calc($name)" } + + private[flink] def computeSelfCost( + calcProgram: RexProgram, + planner: RelOptPlanner, + rowCnt: Double): RelOptCost = { + + // compute number of expressions that do not access a field or literal, i.e. computations, + // conditions, etc. We only want to account for computations, not for simple projections. + // CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule + // in normalization stage. So we should ignore CASTs here in optimization stage. 
+ val compCnt = calcProgram.getExprList.asScala.toList.count { + case i: RexInputRef => false + case l: RexLiteral => false + case c: RexCall if c.getOperator.getName.equals("CAST") => false + case _ => true + } + + planner.getCostFactory.makeCost(rowCnt, rowCnt * compCnt, 0) + } + + private[flink] def estimateRowCount( + calcProgram: RexProgram, + rowCnt: Double): Double = { + + if (calcProgram.getCondition != null) { + // we reduce the result card to push filters down + (rowCnt * 0.75).min(1.0) + } else { + rowCnt + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala index 9b8e1ea..11f595c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala @@ -21,20 +21,21 @@ package org.apache.flink.table.plan.nodes.dataset import org.apache.calcite.plan._ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.calcite.rex.RexNode import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.{BatchTableEnvironment, TableEnvironment} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.schema.TableSourceTable -import org.apache.flink.table.sources.BatchTableSource import org.apache.flink.types.Row +import org.apache.flink.table.sources.BatchTableSource /** Flink RelNode to read data 
from an external source defined by a [[BatchTableSource]]. */ class BatchTableSourceScan( cluster: RelOptCluster, traitSet: RelTraitSet, table: RelOptTable, - val tableSource: BatchTableSource[_]) + val tableSource: BatchTableSource[_], + filterCondition: RexNode = null) extends BatchScan(cluster, traitSet, table) { override def deriveRowType() = { @@ -54,13 +55,20 @@ class BatchTableSourceScan( cluster, traitSet, getTable, - tableSource + tableSource, + filterCondition ) } override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) + val terms = super.explainTerms(pw) .item("fields", TableEnvironment.getFieldNames(tableSource).mkString(", ")) + if (filterCondition != null) { + import scala.collection.JavaConverters._ + val fieldNames = getTable.getRowType.getFieldNames.asScala.toList + terms.item("filter", getExpressionString(filterCondition, fieldNames, None)) + } + terms } override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = { @@ -68,6 +76,6 @@ class BatchTableSourceScan( val config = tableEnv.getConfig val inputDataSet = tableSource.getDataSet(tableEnv.execEnv).asInstanceOf[DataSet[Any]] - convertToInternalRow(inputDataSet, new TableSourceTable(tableSource), config) + convertToInternalRow(inputDataSet, new TableSourceTable(tableSource, tableEnv), config) } } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala index 9b3ff63..972e45b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala @@ -31,8 +31,6 @@ import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.plan.nodes.CommonCalc import org.apache.flink.types.Row -import scala.collection.JavaConverters._ - /** * Flink RelNode which matches along with LogicalCalc. * @@ -71,34 +69,17 @@ class DataSetCalc( } override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { - val child = this.getInput val rowCnt = metadata.getRowCount(child) - // compute number of expressions that do not access a field or literal, i.e. computations, - // conditions, etc. We only want to account for computations, not for simple projections. - // CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule - // in normalization stage. So we should ignore CASTs here in optimization stage. - val compCnt = calcProgram.getExprList.asScala.toList.count { - case i: RexInputRef => false - case l: RexLiteral => false - case c: RexCall if c.getOperator.getName.equals("CAST") => false - case _ => true - } - - planner.getCostFactory.makeCost(rowCnt, rowCnt * compCnt, 0) + computeSelfCost(calcProgram, planner, rowCnt) } override def estimateRowCount(metadata: RelMetadataQuery): Double = { val child = this.getInput val rowCnt = metadata.getRowCount(child) - if (calcProgram.getCondition != null) { - // we reduce the result card to push filters down - (rowCnt * 0.75).min(1.0) - } else { - rowCnt - } + estimateRowCount(calcProgram, rowCnt) } override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = { @@ -127,5 +108,4 @@ class DataSetCalc( val mapFunc = calcMapFunction(genFunction) inputDS.flatMap(mapFunc).name(calcOpName(calcProgram, getExpressionString)) } - } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala 
---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala index b39ae4a..26778d7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCalc.scala @@ -18,8 +18,9 @@ package org.apache.flink.table.plan.nodes.datastream -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.RexProgram import org.apache.flink.api.common.functions.FlatMapFunction @@ -68,6 +69,18 @@ class DataStreamCalc( calcProgram.getCondition != null) } + override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { + val child = this.getInput + val rowCnt = metadata.getRowCount(child) + computeSelfCost(calcProgram, planner, rowCnt) + } + + override def estimateRowCount(metadata: RelMetadataQuery): Double = { + val child = this.getInput + val rowCnt = metadata.getRowCount(child) + estimateRowCount(calcProgram, rowCnt) + } + override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { val config = tableEnv.getConfig http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala index 73d0291..b808d8d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala @@ -21,19 +21,21 @@ package org.apache.flink.table.plan.nodes.datastream import org.apache.calcite.plan._ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.flink.streaming.api.datastream.DataStream -import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment} +import org.apache.calcite.rex.RexNode import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.schema.TableSourceTable -import org.apache.flink.table.sources.StreamTableSource import org.apache.flink.types.Row +import org.apache.flink.table.sources.StreamTableSource +import org.apache.flink.streaming.api.datastream.DataStream +import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment} /** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. 
*/ class StreamTableSourceScan( cluster: RelOptCluster, traitSet: RelTraitSet, table: RelOptTable, - val tableSource: StreamTableSource[_]) + val tableSource: StreamTableSource[_], + filterCondition: RexNode = null) extends StreamScan(cluster, traitSet, table) { override def deriveRowType() = { @@ -53,13 +55,20 @@ class StreamTableSourceScan( cluster, traitSet, getTable, - tableSource + tableSource, + filterCondition ) } override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) + val terms = super.explainTerms(pw) .item("fields", TableEnvironment.getFieldNames(tableSource).mkString(", ")) + if (filterCondition != null) { + import scala.collection.JavaConverters._ + val fieldNames = getTable.getRowType.getFieldNames.asScala.toList + terms.item("filter", getExpressionString(filterCondition, fieldNames, None)) + } + terms } override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { @@ -68,7 +77,6 @@ class StreamTableSourceScan( val inputDataStream: DataStream[Any] = tableSource .getDataStream(tableEnv.execEnv).asInstanceOf[DataStream[Any]] - convertToInternalRow(inputDataStream, new TableSourceTable(tableSource), config) + convertToInternalRow(inputDataStream, new TableSourceTable(tableSource, tableEnv), config) } - } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 3b20236..952ee34 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -122,8 +122,10 @@ object FlinkRuleSets { 
DataSetValuesRule.INSTANCE, DataSetCorrelateRule.INSTANCE, BatchTableSourceScanRule.INSTANCE, - // project pushdown optimization - PushProjectIntoBatchTableSourceScanRule.INSTANCE + + // scan optimization + PushProjectIntoBatchTableSourceScanRule.INSTANCE, + PushFilterIntoBatchTableSourceScanRule.INSTANCE ) /** @@ -178,7 +180,10 @@ object FlinkRuleSets { DataStreamValuesRule.INSTANCE, DataStreamCorrelateRule.INSTANCE, StreamTableSourceScanRule.INSTANCE, - PushProjectIntoStreamTableSourceScanRule.INSTANCE + + // scan optimization + PushProjectIntoStreamTableSourceScanRule.INSTANCE, + PushFilterIntoStreamTableSourceScanRule.INSTANCE ) } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushFilterIntoBatchTableSourceScanRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushFilterIntoBatchTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushFilterIntoBatchTableSourceScanRule.scala new file mode 100644 index 0000000..f95e34e --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushFilterIntoBatchTableSourceScanRule.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.rules.dataSet + +import org.apache.calcite.plan.RelOptRule._ +import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} +import org.apache.calcite.rex.RexProgram +import org.apache.flink.table.plan.nodes.dataset.{BatchTableSourceScan, DataSetCalc} +import org.apache.flink.table.plan.util.RexProgramExpressionExtractor._ +import org.apache.flink.table.plan.schema.TableSourceTable +import org.apache.flink.table.sources.FilterableTableSource + +class PushFilterIntoBatchTableSourceScanRule extends RelOptRule( + operand(classOf[DataSetCalc], + operand(classOf[BatchTableSourceScan], none)), + "PushFilterIntoBatchTableSourceScanRule") { + + override def matches(call: RelOptRuleCall) = { + val calc: DataSetCalc = call.rel(0).asInstanceOf[DataSetCalc] + val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan] + scan.tableSource match { + case _: FilterableTableSource => + calc.calcProgram.getCondition != null + case _ => false + } + } + + override def onMatch(call: RelOptRuleCall): Unit = { + val calc: DataSetCalc = call.rel(0).asInstanceOf[DataSetCalc] + val scan: BatchTableSourceScan = call.rel(1).asInstanceOf[BatchTableSourceScan] + + val filterableSource = scan.tableSource.asInstanceOf[FilterableTableSource] + + val program: RexProgram = calc.calcProgram + val tst = scan.getTable.unwrap(classOf[TableSourceTable[_]]) + val predicate = extractPredicateExpressions( + program, + call.builder().getRexBuilder, + tst.tableEnv.getFunctionCatalog) + + if (predicate.length != 0) { + val 
remainingPredicate = filterableSource.setPredicate(predicate) + + if (verifyExpressions(predicate, remainingPredicate)) { + + val filterRexNode = getFilterExpressionAsRexNode( + program.getInputRowType, + scan, + predicate.diff(remainingPredicate))(call.builder()) + + val newScan = new BatchTableSourceScan( + scan.getCluster, + scan.getTraitSet, + scan.getTable, + scan.tableSource, + filterRexNode) + + val newCalcProgram = rewriteRexProgram( + program, + newScan, + remainingPredicate)(call.builder()) + + val newCalc = new DataSetCalc( + calc.getCluster, + calc.getTraitSet, + newScan, + calc.getRowType, + newCalcProgram, + description) + + call.transformTo(newCalc) + } + } + } +} + +object PushFilterIntoBatchTableSourceScanRule { + val INSTANCE: RelOptRule = new PushFilterIntoBatchTableSourceScanRule +} http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala index 70639b7..53f5fff 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala @@ -22,7 +22,7 @@ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.plan.RelOptRule.{none, operand} import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.plan.nodes.dataset.{BatchTableSourceScan, DataSetCalc} -import 
org.apache.flink.table.plan.rules.util.RexProgramProjectExtractor._ +import org.apache.flink.table.plan.util.RexProgramProjectExtractor._ import org.apache.flink.table.sources.{BatchTableSource, ProjectableTableSource} /** http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/PushFilterIntoStreamTableSourceScanRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/PushFilterIntoStreamTableSourceScanRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/PushFilterIntoStreamTableSourceScanRule.scala new file mode 100644 index 0000000..9c02dd7 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/PushFilterIntoStreamTableSourceScanRule.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
  * Planner rule that matches a [[DataStreamCalc]] sitting directly on top of a
  * [[StreamTableSourceScan]] and offers the calc's filter condition to the scanned
  * table source when that source is a [[FilterableTableSource]].
  */
class PushFilterIntoStreamTableSourceScanRule extends RelOptRule(
  operand(classOf[DataStreamCalc],
    operand(classOf[StreamTableSourceScan], none)),
  "PushFilterIntoStreamTableSourceScanRule") {

  /**
    * The rule only applies when the scanned source is filterable and the calc
    * actually carries a condition.
    */
  override def matches(call: RelOptRuleCall): Boolean = {
    val streamCalc: DataStreamCalc = call.rel(0).asInstanceOf[DataStreamCalc]
    val sourceScan: StreamTableSourceScan = call.rel(1).asInstanceOf[StreamTableSourceScan]
    sourceScan.tableSource match {
      case _: FilterableTableSource => Option(streamCalc.calcProgram.getCondition).isDefined
      case _ => false
    }
  }

  /**
    * Extracts the calc condition as CNF predicates, hands them to the source and,
    * if the source's answer passes `verifyExpressions`, replaces the matched
    * calc/scan pair by a new pair that carries the pushed-down filter on the scan.
    */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val streamCalc: DataStreamCalc = call.rel(0).asInstanceOf[DataStreamCalc]
    val sourceScan: StreamTableSourceScan = call.rel(1).asInstanceOf[StreamTableSourceScan]

    val source = sourceScan.tableSource.asInstanceOf[FilterableTableSource]

    val calcProgram = streamCalc.calcProgram
    val sourceTable = sourceScan.getTable.unwrap(classOf[TableSourceTable[_]])
    val candidates = extractPredicateExpressions(
      calcProgram,
      call.builder().getRexBuilder,
      sourceTable.tableEnv.getFunctionCatalog)

    if (candidates.nonEmpty) {
      // the source reports back the predicates it could not take over
      val leftOver = source.setPredicate(candidates)
      if (verifyExpressions(candidates, leftOver)) {
        call.transformTo(pushedDownCalc(call, streamCalc, sourceScan, candidates, leftOver))
      }
    }
  }

  /** Builds the replacement calc on top of a scan that carries the accepted predicates. */
  private def pushedDownCalc(
      call: RelOptRuleCall,
      streamCalc: DataStreamCalc,
      sourceScan: StreamTableSourceScan,
      candidates: Array[Expression],
      leftOver: Array[Expression]): DataStreamCalc = {

    val calcProgram = streamCalc.calcProgram

    // everything the source did not hand back becomes the scan's filter
    val scanFilter = getFilterExpressionAsRexNode(
      calcProgram.getInputRowType,
      sourceScan,
      candidates.diff(leftOver))(call.builder())

    val filteredScan = new StreamTableSourceScan(
      sourceScan.getCluster,
      sourceScan.getTraitSet,
      sourceScan.getTable,
      sourceScan.tableSource,
      scanFilter)

    // the calc keeps only the predicates the source handed back
    val reducedProgram = rewriteRexProgram(
      calcProgram,
      filteredScan,
      leftOver)(call.builder())

    new DataStreamCalc(
      streamCalc.getCluster,
      streamCalc.getTraitSet,
      filteredScan,
      streamCalc.getRowType,
      reducedProgram,
      description)
  }

}

/** Holds the shared instance of the rule. */
object PushFilterIntoStreamTableSourceScanRule {
  val INSTANCE: RelOptRule = new PushFilterIntoStreamTableSourceScanRule
}
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractor.scala deleted file mode 100644 index 129cfd1..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractor.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.rules.util - -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rex._ - -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.collection.JavaConverters._ - -object RexProgramProjectExtractor { - - /** - * Extracts the indexes of input fields accessed by the RexProgram. 
- * - * @param rexProgram The RexProgram to analyze - * @return The indexes of accessed input fields - */ - def extractRefInputFields(rexProgram: RexProgram): Array[Int] = { - val visitor = new RefFieldsVisitor - // extract input fields from project expressions - rexProgram.getProjectList.foreach(exp => rexProgram.expandLocalRef(exp).accept(visitor)) - val condition = rexProgram.getCondition - // extract input fields from condition expression - if (condition != null) { - rexProgram.expandLocalRef(condition).accept(visitor) - } - visitor.getFields - } - - /** - * Generates a new RexProgram based on mapped input fields. - * - * @param rexProgram original RexProgram - * @param inputRowType input row type - * @param usedInputFields indexes of used input fields - * @param rexBuilder builder for Rex expressions - * - * @return A RexProgram with mapped input field expressions. - */ - def rewriteRexProgram( - rexProgram: RexProgram, - inputRowType: RelDataType, - usedInputFields: Array[Int], - rexBuilder: RexBuilder): RexProgram = { - - val inputRewriter = new InputRewriter(usedInputFields) - val newProjectExpressions = rexProgram.getProjectList.map( - exp => rexProgram.expandLocalRef(exp).accept(inputRewriter) - ).toList.asJava - - val oldCondition = rexProgram.getCondition - val newConditionExpression = { - oldCondition match { - case ref: RexLocalRef => rexProgram.expandLocalRef(ref).accept(inputRewriter) - case _ => null // null does not match any type - } - } - RexProgram.create( - inputRowType, - newProjectExpressions, - newConditionExpression, - rexProgram.getOutputRowType, - rexBuilder - ) - } -} - -/** - * A RexVisitor to extract used input fields - */ -class RefFieldsVisitor extends RexVisitorImpl[Unit](true) { - private var fields = mutable.LinkedHashSet[Int]() - - def getFields: Array[Int] = fields.toArray - - override def visitInputRef(inputRef: RexInputRef): Unit = fields += inputRef.getIndex - - override def visitCall(call: RexCall): Unit = - 
call.operands.foreach(operand => operand.accept(this)) -} - -/** - * A RexShuttle to rewrite field accesses of a RexProgram. - * - * @param fields fields mapping - */ -class InputRewriter(fields: Array[Int]) extends RexShuttle { - - /** old input fields ref index -> new input fields ref index mappings */ - private val fieldMap: Map[Int, Int] = - fields.zipWithIndex.toMap - - override def visitInputRef(inputRef: RexInputRef): RexNode = - new RexInputRef(relNodeIndex(inputRef), inputRef.getType) - - override def visitLocalRef(localRef: RexLocalRef): RexNode = - new RexInputRef(relNodeIndex(localRef), localRef.getType) - - private def relNodeIndex(ref: RexSlot): Int = - fieldMap.getOrElse(ref.getIndex, - throw new IllegalArgumentException("input field contains invalid index")) -} http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala index a3851e3..faf5efc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala @@ -25,6 +25,7 @@ import org.apache.flink.table.sources.TableSource /** Table which defines an external table via a [[TableSource]] */ class TableSourceTable[T]( val tableSource: TableSource[T], + val tableEnv: TableEnvironment, override val statistic: FlinkStatistic = FlinkStatistic.UNKNOWN) extends FlinkTable[T]( typeInfo = tableSource.getReturnType, 
http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractor.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractor.scala new file mode 100644 index 0000000..337b3de --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractor.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
  * Helpers that translate the filter condition of a [[RexProgram]] into Table API
  * [[Expression]]s (for handing them to a table source) and back into Calcite
  * expressions (for the rewritten scan and calc).
  */
object RexProgramExpressionExtractor {

  /**
    * Converts the condition of a RexProgram into independent CNF expressions.
    *
    * @param rexProgram The RexProgram to analyze
    * @param rexBuilder builder used to bring the condition into conjunctive normal form
    * @param catalog    function catalog used when wrapping the Calcite expressions
    * @return one expression per conjunct of the condition; empty if the program
    *         has no condition
    */
  private[flink] def extractPredicateExpressions(
      rexProgram: RexProgram,
      rexBuilder: RexBuilder,
      catalog: FunctionCatalog): Array[Expression] = {

    Option(rexProgram.getCondition) match {
      case None =>
        Array.empty[Expression]
      case Some(condition) =>
        // only resolve input names when there is something to convert
        val fieldNames = getInputsWithNames(rexProgram)
        val cnf = RexUtil.toCnf(rexBuilder, rexProgram.expandLocalRef(condition))
        RelOptUtil.conjunctions(cnf).asScala
          .map(node => RexNodeWrapper.wrap(node, catalog).toExpression(fieldNames))
          .toArray
    }
  }

  /**
    * Checks that every remaining expression is one of the original expressions.
    * The push-down rule only splits the condition when this holds.
    *
    * @param original initial expressions
    * @param remained expressions reported as remaining
    * @return true if `remained` contains nothing that is not in `original`
    */
  private[flink] def verifyExpressions(
      original: Array[Expression],
      remained: Array[Expression]): Boolean =
    remained.forall(original.contains)

  /**
    * Generates a new RexProgram that keeps the projections of the given program
    * and uses the conjunction of `predicate` as its condition.
    *
    * An empty `predicate` yields a program without a condition. Previously this
    * case failed with a NoSuchElementException from `Option.get`.
    *
    * Note: `scan` is pushed onto `relBuilder` so that field references can be
    * converted; it is left on the builder.
    *
    * @param rexProgram original RexProgram
    * @param scan       input source
    * @param predicate  filter condition; field references are resolved against the
    *                   input row type of `rexProgram`
    * @param relBuilder builder for converting expression to Rex
    */
  private[flink] def rewriteRexProgram(
      rexProgram: RexProgram,
      scan: TableScan,
      predicate: Array[Expression])(implicit relBuilder: RelBuilder): RexProgram = {

    relBuilder.push(scan)

    val inType = rexProgram.getInputRowType
    val projs = rexProgram.getProjectList.asScala.map(rexProgram.expandLocalRef).asJava
    // a null condition means "no filter", same as in RexProgramProjectExtractor
    val condition: RexNode = conjunct(resolveFields(predicate, inType)).map(_.toRexNode).orNull

    RexProgram.create(
      inType,
      projs,
      condition,
      rexProgram.getOutputRowType,
      relBuilder.getRexBuilder)
  }

  /**
    * Converts the conjunction of the given expressions into a RexNode.
    *
    * Note: `scan` is pushed onto `relBuilder` and left there.
    *
    * @param inputTpe   row type used to resolve field references by name
    * @param scan       input source
    * @param exps       expressions to AND together
    * @param relBuilder builder for converting expression to Rex
    * @return the combined RexNode, or null if `exps` is empty
    */
  private[flink] def getFilterExpressionAsRexNode(
      inputTpe: RelDataType,
      scan: TableScan,
      exps: Array[Expression])(implicit relBuilder: RelBuilder): RexNode = {
    relBuilder.push(scan)
    conjunct(resolveFields(exps, inputTpe)).map(_.toRexNode).orNull
  }

  /**
    * Replaces every [[UnresolvedFieldReference]] by a [[ResolvedFieldReference]]
    * typed according to the field of the same name in `inType`. A name that is
    * not part of `inType` fails the map lookup.
    */
  private def resolveFields(
      predicate: Array[Expression],
      inType: RelDataType): Array[Expression] = {
    val fieldTypes: Map[String, TypeInformation[_]] = inType.getFieldList.asScala
      .map(f => f.getName -> FlinkTypeFactory.toTypeInfo(f.getType))
      .toMap
    val rule: PartialFunction[Expression, Expression] = {
      case UnresolvedFieldReference(name) =>
        ResolvedFieldReference(name, fieldTypes(name))
    }
    predicate.map(_.postOrderTransform(rule))
  }

  /**
    * Combines the expressions into a single conjunction by repeatedly AND-ing
    * neighbouring pairs; an odd trailing element is carried over unchanged.
    *
    * @return None for an empty input, otherwise the combined expression
    */
  private def conjunct(exps: Array[Expression]): Option[Expression] = {
    exps.length match {
      case 0 =>
        None
      case 1 =>
        Option(exps(0))
      case _ =>
        val paired: Array[Expression] = exps.grouped(2).map {
          case Array(left, right) => And(left, right)
          case Array(single) => single
        }.toArray
        conjunct(paired)
    }
  }

  /** Maps every input reference in the program's expression list to its field name. */
  private def getInputsWithNames(rexProgram: RexProgram): Map[RexInputRef, String] = {
    val names = rexProgram.getInputRowType.getFieldNames
    rexProgram.getExprList.asScala.collect {
      case ref: RexInputRef => ref -> names.get(ref.getIndex)
    }.toMap
  }
}
/** Helpers for pushing a projection below a [[RexProgram]]. */
object RexProgramProjectExtractor {

  /**
    * Collects the indexes of all input fields the program reads, looking first at
    * the projections and then at the condition (if there is one).
    *
    * @param rexProgram The RexProgram to analyze
    * @return The indexes of accessed input fields, in order of first access
    */
  def extractRefInputFields(rexProgram: RexProgram): Array[Int] = {
    val collector = new RefFieldsVisitor
    val projections = rexProgram.getProjectList.asScala.map(rexProgram.expandLocalRef)
    val conditionExpr = Option(rexProgram.getCondition).map(rexProgram.expandLocalRef)

    projections.foreach(_.accept(collector))
    conditionExpr.foreach(_.accept(collector))
    collector.getFields
  }

  /**
    * Builds a RexProgram whose projections and condition address the input through
    * the remapped field indexes.
    *
    * @param rexProgram      original RexProgram
    * @param inputRowType    input row type
    * @param usedInputFields indexes of used input fields
    * @param rexBuilder      builder for Rex expressions
    *
    * @return A RexProgram with mapped input field expressions.
    */
  def rewriteRexProgram(
      rexProgram: RexProgram,
      inputRowType: RelDataType,
      usedInputFields: Array[Int],
      rexBuilder: RexBuilder): RexProgram = {

    val remapper = new InputRewriter(usedInputFields)
    def remap(ref: RexLocalRef): RexNode = rexProgram.expandLocalRef(ref).accept(remapper)

    val remappedProjections = rexProgram.getProjectList.asScala.map(remap).asJava
    // a program without condition keeps having none (null)
    val remappedCondition: RexNode = Option(rexProgram.getCondition).map(remap).orNull

    RexProgram.create(
      inputRowType,
      remappedProjections,
      remappedCondition,
      rexProgram.getOutputRowType,
      rexBuilder
    )
  }
}

/**
  * A RexVisitor that records the index of every input field it comes across.
  */
class RefFieldsVisitor extends RexVisitorImpl[Unit](true) {
  // insertion-ordered, so indexes come out in order of first access
  private val fields = mutable.LinkedHashSet.empty[Int]

  def getFields: Array[Int] = fields.toArray

  override def visitInputRef(inputRef: RexInputRef): Unit = fields += inputRef.getIndex

  override def visitCall(call: RexCall): Unit =
    call.operands.asScala.foreach(_.accept(this))
}

/**
  * A RexShuttle that replaces field accesses by references to the remapped index.
  *
  * @param fields fields mapping
  */
class InputRewriter(fields: Array[Int]) extends RexShuttle {

  /** old input field index -> position of that index in `fields` */
  private val newIndexOf: Map[Int, Int] =
    fields.zipWithIndex.toMap

  override def visitInputRef(inputRef: RexInputRef): RexNode =
    new RexInputRef(remappedIndex(inputRef), inputRef.getType)

  override def visitLocalRef(localRef: RexLocalRef): RexNode =
    new RexInputRef(remappedIndex(localRef), localRef.getType)

  private def remappedIndex(slot: RexSlot): Int =
    newIndexOf.getOrElse(slot.getIndex,
      throw new IllegalArgumentException("input field contains invalid index"))
}
/**
  * Adds support for filter push-down to a [[TableSource]].
  * A [[TableSource]] extending this interface can be handed predicate expressions
  * and reports back the ones it does not support.
  */
trait FilterableTableSource {

  /** Returns the predicate expressions that were set on this source. */
  def getPredicate: Array[Expression]

  /**
    * Offers filter predicates to this source.
    *
    * @param predicate the filter expressions to apply to the returned records.
    * @return the predicate expressions this source does not support.
    */
  def setPredicate(predicate: Array[Expression]): Array[Expression]
}
[[org.apache.flink.streaming.api.scala.DataStream]]. http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 3c89ec4..2c08d8d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -139,7 +139,21 @@ class FunctionCatalog { object FunctionCatalog { val builtInFunctions: Map[String, Class[_]] = Map( + +// SqlStdOperatorTable.AS, +// SqlStdOperatorTable.DIVIDE_INTEGER, +// SqlStdOperatorTable.DOT, + // logic + "and" -> classOf[And], + "or" -> classOf[Or], + "not" -> classOf[Not], + "equals" -> classOf[EqualTo], + "greaterThan" -> classOf[GreaterThan], + "greaterThanOrEqual" -> classOf[GreaterThanOrEqual], + "lessThan" -> classOf[LessThan], + "lessThanOrEqual" -> classOf[LessThanOrEqual], + "notEquals" -> classOf[NotEqualTo], "isNull" -> classOf[IsNull], "isNotNull" -> classOf[IsNotNull], "isTrue" -> classOf[IsTrue], @@ -158,15 +172,23 @@ object FunctionCatalog { "charLength" -> classOf[CharLength], "initCap" -> classOf[InitCap], "like" -> classOf[Like], + "concat" -> classOf[Plus], + "lower" -> classOf[Lower], "lowerCase" -> classOf[Lower], "similar" -> classOf[Similar], "substring" -> classOf[Substring], "trim" -> classOf[Trim], + // duplicate functions for calcite + "upper" -> classOf[Upper], "upperCase" -> classOf[Upper], "position" -> classOf[Position], "overlay" -> classOf[Overlay], // math functions + "plus" -> classOf[Plus], + "minus" -> classOf[Minus], + "divide" -> classOf[Div], + "times" -> classOf[Mul], "abs" -> 
classOf[Abs], "ceil" -> classOf[Ceil], "exp" -> classOf[Exp], @@ -176,6 +198,7 @@ object FunctionCatalog { "power" -> classOf[Power], "mod" -> classOf[Mod], "sqrt" -> classOf[Sqrt], + "minusPrefix" -> classOf[UnaryMinus], // temporal functions "extract" -> classOf[Extract], @@ -186,6 +209,7 @@ object FunctionCatalog { "localTimestamp" -> classOf[LocalTimestamp], "quarter" -> classOf[Quarter], "temporalOverlaps" -> classOf[TemporalOverlaps], + "dateTimePlus" -> classOf[Plus], // array "cardinality" -> classOf[ArrayCardinality], http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala index 05c2a49..767e83f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala @@ -18,18 +18,16 @@ package org.apache.flink.table -import org.apache.calcite.tools.RuleSet import org.apache.flink.api.scala._ -import org.apache.flink.table.api.scala._ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TypeExtractor} -import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment, TableException} +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.TableException import org.apache.flink.table.expressions.{Alias, UnresolvedFieldReference} -import org.apache.flink.table.sinks.TableSink -import org.apache.flink.table.sources.TableSource -import org.apache.flink.table.utils.TableTestBase -import 
org.apache.flink.table.utils.TableTestUtil.{batchTableNode, term, unaryNode, binaryNode, streamTableNode} +import org.apache.flink.table.utils.{MockTableEnvironment, TableTestBase} +import org.apache.flink.table.utils.TableTestUtil._ + import org.junit.Test import org.junit.Assert.assertEquals @@ -350,19 +348,6 @@ class TableEnvironmentTest extends TableTestBase { } -class MockTableEnvironment extends TableEnvironment(new TableConfig) { - - override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ??? - - override protected def checkValidTableName(name: String): Unit = ??? - - override protected def getBuiltInNormRuleSet: RuleSet = ??? - - override protected def getBuiltInOptRuleSet: RuleSet = ??? - - override def registerTableSource(name: String, tableSource: TableSource[_]) = ??? -} - case class CClass(cf1: Int, cf2: String, cf3: Double) class PojoClass(var pf1: Int, var pf2: String, var pf3: Double) { http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala new file mode 100644 index 0000000..058eca7 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableSourceTest.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table + +import org.apache.flink.table.api.Types +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.sources.{CsvTableSource, TableSource} +import org.apache.flink.table.utils.TableTestUtil._ +import org.junit.{Assert, Test} +import org.apache.flink.table.utils.{CommonTestData, TableTestBase} + +class TableSourceTest extends TableTestBase { + + private val projectedFields: Array[String] = Array("last", "id", "score") + private val noCalcFields: Array[String] = Array("id", "score", "first") + + // batch plan + + @Test + def testBatchProjectableSourceScanPlanTableApi(): Unit = { + val (tableSource, tableName) = csvTable + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('last.upperCase(), 'id.floor(), 'score * 2) + + val expected = unaryNode( + "DataSetCalc", + projectableSourceBatchTableNode(tableName, projectedFields), + term("select", "UPPER(last) AS _c0", "FLOOR(id) AS _c1", "*(score, 2) AS _c2") + ) + + util.verifyTable(result, expected) + } + + @Test + def testBatchProjectableSourceScanPlanSQL(): Unit = { + val (tableSource, tableName) = csvTable + val util = batchTestUtil() + + util.tEnv.registerTableSource(tableName, tableSource) + + val sqlQuery = s"SELECT last, floor(id), score * 2 FROM $tableName" + + val expected = unaryNode( + "DataSetCalc", + projectableSourceBatchTableNode(tableName, projectedFields), + term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2") + ) + + 
util.verifySql(sqlQuery, expected) + } + + @Test + def testBatchProjectableSourceScanNoIdentityCalc(): Unit = { + val (tableSource, tableName) = csvTable + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('id, 'score, 'first) + + val expected = projectableSourceBatchTableNode(tableName, noCalcFields) + util.verifyTable(result, expected) + } + + @Test + def testBatchFilterableSourceScanPlanTableApi(): Unit = { + val (tableSource, tableName) = filterableTableSource + val util = batchTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('price, 'id, 'amount) + .where("amount > 2 && price * 2 < 32") + + val expected = unaryNode( + "DataSetCalc", + filterableSourceBatchTableNode( + tableName, + Array("name", "id", "amount", "price"), + ">(amount, 2)"), + term("select", "price", "id", "amount"), + term("where", "<(*(price, 2), 32)") + ) + + util.verifyTable(result, expected) + } + + // stream plan + + @Test + def testStreamProjectableSourceScanPlanTableApi(): Unit = { + val (tableSource, tableName) = csvTable + val util = streamTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('last, 'id.floor(), 'score * 2) + + val expected = unaryNode( + "DataStreamCalc", + projectableSourceStreamTableNode(tableName, projectedFields), + term("select", "last", "FLOOR(id) AS _c1", "*(score, 2) AS _c2") + ) + + util.verifyTable(result, expected) + } + + @Test + def testStreamProjectableSourceScanPlanSQL(): Unit = { + val (tableSource, tableName) = csvTable + val util = streamTestUtil() + + util.tEnv.registerTableSource(tableName, tableSource) + + val sqlQuery = s"SELECT last, floor(id), score * 2 FROM $tableName" + + val expected = unaryNode( + "DataStreamCalc", + projectableSourceStreamTableNode(tableName, 
projectedFields), + term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2") + ) + + util.verifySql(sqlQuery, expected) + } + + @Test + def testStreamProjectableSourceScanNoIdentityCalc(): Unit = { + val (tableSource, tableName) = csvTable + val util = streamTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('id, 'score, 'first) + + val expected = projectableSourceStreamTableNode(tableName, noCalcFields) + util.verifyTable(result, expected) + } + + @Test + def testStreamFilterableSourceScanPlanTableApi(): Unit = { + val (tableSource, tableName) = filterableTableSource + val util = streamTestUtil() + val tEnv = util.tEnv + + tEnv.registerTableSource(tableName, tableSource) + + val result = tEnv + .scan(tableName) + .select('price, 'id, 'amount) + .where("amount > 2 && price * 2 < 32") + + val expected = unaryNode( + "DataStreamCalc", + filterableSourceStreamTableNode( + tableName, + Array("name", "id", "amount", "price"), + ">(amount, 2)"), + term("select", "price", "id", "amount"), + term("where", "<(*(price, 2), 32)") + ) + + util.verifyTable(result, expected) + } + + // csv builder + + @Test + def testCsvTableSourceBuilder(): Unit = { + val source1 = CsvTableSource.builder() + .path("/path/to/csv") + .field("myfield", Types.STRING) + .field("myfield2", Types.INT) + .quoteCharacter(';') + .fieldDelimiter("#") + .lineDelimiter("\r\n") + .commentPrefix("%%") + .ignoreFirstLine() + .ignoreParseErrors() + .build() + + val source2 = new CsvTableSource( + "/path/to/csv", + Array("myfield", "myfield2"), + Array(Types.STRING, Types.INT), + "#", + "\r\n", + ';', + true, + "%%", + true) + + Assert.assertEquals(source1, source2) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testCsvTableSourceBuilderWithNullPath(): Unit = { + CsvTableSource.builder() + .field("myfield", Types.STRING) + // should fail, path is not defined + .build() + } + + @Test(expected = 
classOf[IllegalArgumentException]) + def testCsvTableSourceBuilderWithDuplicateFieldName(): Unit = { + CsvTableSource.builder() + .path("/path/to/csv") + .field("myfield", Types.STRING) + // should fail, field name must not be duplicated + .field("myfield", Types.INT) + } + + @Test(expected = classOf[IllegalArgumentException]) + def testCsvTableSourceBuilderWithEmptyField(): Unit = { + CsvTableSource.builder() + .path("/path/to/csv") + // should fail, fields cannot be empty + .build() + } + + // utils + + def filterableTableSource:(TableSource[_], String) = { + val tableSource = CommonTestData.getFilterableTableSource + (tableSource, "filterableTable") + } + + def csvTable: (CsvTableSource, String) = { + val csvTable = CommonTestData.getCsvTableSource + val tableName = "csvTable" + (csvTable, tableName) + } + + def projectableSourceBatchTableNode( + sourceName: String, + fields: Array[String]): String = { + + "BatchTableSourceScan(" + + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}])" + } + + def projectableSourceStreamTableNode( + sourceName: String, + fields: Array[String]): String = { + + "StreamTableSourceScan(" + + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}])" + } + + def filterableSourceBatchTableNode( + sourceName: String, + fields: Array[String], + exp: String): String = { + + "BatchTableSourceScan(" + + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], filter=[$exp])" + } + + def filterableSourceStreamTableNode( + sourceName: String, + fields: Array[String], + exp: String): String = { + + "StreamTableSourceScan(" + + s"table=[[$sourceName]], fields=[${fields.mkString(", ")}], filter=[$exp])" + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala index 70f4345..ca7cd8a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala @@ -102,4 +102,20 @@ class TableSourceITCase( TestBaseUtils.compareResultAsText(result.asJava, expected) } + @Test + def testTableSourceWithFilterable(): Unit = { + val tableName = "MyTable" + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + tableEnv.registerTableSource(tableName, CommonTestData.getFilterableTableSource) + val results = tableEnv + .scan(tableName) + .where("amount > 4 && price < 9") + .select("id, name") + .collect() + + val expected = Seq( + "5,Record_5", "6,Record_6", "7,Record_7", "8,Record_8").mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceTest.scala deleted file mode 100644 index 670e268..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceTest.scala +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.api.scala.batch - -import org.apache.flink.table.api.Types -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.sources.CsvTableSource -import org.apache.flink.table.utils.{CommonTestData, TableTestBase} -import org.apache.flink.table.utils.TableTestUtil._ -import org.junit.{Assert, Test} - -class TableSourceTest extends TableTestBase { - - private val projectedFields: Array[String] = Array("last", "id", "score") - private val noCalcFields: Array[String] = Array("id", "score", "first") - - @Test - def testBatchProjectableSourceScanPlanTableApi(): Unit = { - val (csvTable, tableName) = tableSource - val util = batchTestUtil() - val tEnv = util.tEnv - - tEnv.registerTableSource(tableName, csvTable) - - val result = tEnv - .scan(tableName) - .select('last.upperCase(), 'id.floor(), 'score * 2) - - val expected = unaryNode( - "DataSetCalc", - sourceBatchTableNode(tableName, projectedFields), - term("select", "UPPER(last) AS _c0", "FLOOR(id) AS _c1", "*(score, 2) AS _c2") - ) - - util.verifyTable(result, expected) - } - - @Test - def testBatchProjectableSourceScanPlanSQL(): Unit = { - val (csvTable, tableName) = tableSource - val util = batchTestUtil() - - util.tEnv.registerTableSource(tableName, csvTable) - - val sqlQuery = 
s"SELECT last, floor(id), score * 2 FROM $tableName" - - val expected = unaryNode( - "DataSetCalc", - sourceBatchTableNode(tableName, projectedFields), - term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2") - ) - - util.verifySql(sqlQuery, expected) - } - - @Test - def testBatchProjectableSourceScanNoIdentityCalc(): Unit = { - val (csvTable, tableName) = tableSource - val util = batchTestUtil() - val tEnv = util.tEnv - - tEnv.registerTableSource(tableName, csvTable) - - val result = tEnv - .scan(tableName) - .select('id, 'score, 'first) - - val expected = sourceBatchTableNode(tableName, noCalcFields) - util.verifyTable(result, expected) - } - - @Test - def testStreamProjectableSourceScanPlanTableApi(): Unit = { - val (csvTable, tableName) = tableSource - val util = streamTestUtil() - val tEnv = util.tEnv - - tEnv.registerTableSource(tableName, csvTable) - - val result = tEnv - .scan(tableName) - .select('last, 'id.floor(), 'score * 2) - - val expected = unaryNode( - "DataStreamCalc", - sourceStreamTableNode(tableName, projectedFields), - term("select", "last", "FLOOR(id) AS _c1", "*(score, 2) AS _c2") - ) - - util.verifyTable(result, expected) - } - - @Test - def testStreamProjectableSourceScanPlanSQL(): Unit = { - val (csvTable, tableName) = tableSource - val util = streamTestUtil() - - util.tEnv.registerTableSource(tableName, csvTable) - - val sqlQuery = s"SELECT last, floor(id), score * 2 FROM $tableName" - - val expected = unaryNode( - "DataStreamCalc", - sourceStreamTableNode(tableName, projectedFields), - term("select", "last", "FLOOR(id) AS EXPR$1", "*(score, 2) AS EXPR$2") - ) - - util.verifySql(sqlQuery, expected) - } - - @Test - def testStreamProjectableSourceScanNoIdentityCalc(): Unit = { - val (csvTable, tableName) = tableSource - val util = streamTestUtil() - val tEnv = util.tEnv - - tEnv.registerTableSource(tableName, csvTable) - - val result = tEnv - .scan(tableName) - .select('id, 'score, 'first) - - val expected = 
sourceStreamTableNode(tableName, noCalcFields) - util.verifyTable(result, expected) - } - - @Test - def testCsvTableSourceBuilder(): Unit = { - val source1 = CsvTableSource.builder() - .path("/path/to/csv") - .field("myfield", Types.STRING) - .field("myfield2", Types.INT) - .quoteCharacter(';') - .fieldDelimiter("#") - .lineDelimiter("\r\n") - .commentPrefix("%%") - .ignoreFirstLine() - .ignoreParseErrors() - .build() - - val source2 = new CsvTableSource( - "/path/to/csv", - Array("myfield", "myfield2"), - Array(Types.STRING, Types.INT), - "#", - "\r\n", - ';', - true, - "%%", - true) - - Assert.assertEquals(source1, source2) - } - - @Test(expected = classOf[IllegalArgumentException]) - def testCsvTableSourceBuilderWithNullPath(): Unit = { - CsvTableSource.builder() - .field("myfield", Types.STRING) - // should fail, path is not defined - .build() - } - - @Test(expected = classOf[IllegalArgumentException]) - def testCsvTableSourceBuilderWithDuplicateFieldName(): Unit = { - CsvTableSource.builder() - .path("/path/to/csv") - .field("myfield", Types.STRING) - // should fail, field name must no be duplicate - .field("myfield", Types.INT) - } - - @Test(expected = classOf[IllegalArgumentException]) - def testCsvTableSourceBuilderWithEmptyField(): Unit = { - CsvTableSource.builder() - .path("/path/to/csv") - // should fail, field can be empty - .build() - } - - def tableSource: (CsvTableSource, String) = { - val csvTable = CommonTestData.getCsvTableSource - val tableName = "csvTable" - (csvTable, tableName) - } - - def sourceBatchTableNode(sourceName: String, fields: Array[String]): String = { - s"BatchTableSourceScan(table=[[$sourceName]], fields=[${fields.mkString(", ")}])" - } - - def sourceStreamTableNode(sourceName: String, fields: Array[String] ): String = { - s"StreamTableSourceScan(table=[[$sourceName]], fields=[${fields.mkString(", ")}])" - } -} 
http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala index 06d94aa..973c2f3 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/TableSourceITCase.scala @@ -83,4 +83,23 @@ class TableSourceITCase extends StreamingMultipleProgramsTestBase { "Williams,4.68") assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + + @Test + def testCsvTableSourceWithFilterable(): Unit = { + StreamITCase.testResults = mutable.MutableList() + val tableName = "MyTable" + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerTableSource(tableName, CommonTestData.getFilterableTableSource) + tEnv.scan(tableName) + .where("amount > 4 && price < 9") + .select("id, name") + .addSink(new StreamITCase.StringSink) + + env.execute() + + val expected = mutable.MutableList( + "5,Record_5", "6,Record_6", "7,Record_7", "8,Record_8") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala deleted file mode 100644 index fe19c8e..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/rules/util/RexProgramProjectExtractorTest.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.rules.util - -import java.math.BigDecimal - -import org.apache.calcite.adapter.java.JavaTypeFactory -import org.apache.calcite.jdbc.JavaTypeFactoryImpl -import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem} -import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder} -import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR} -import org.apache.calcite.sql.fun.SqlStdOperatorTable -import org.apache.flink.table.plan.rules.util.RexProgramProjectExtractor._ -import org.junit.Assert.{assertArrayEquals, assertTrue} -import org.junit.{Before, Test} - -import scala.collection.JavaConverters._ - -/** - * This class is responsible for testing RexProgramProjectExtractor. 
- */ -class RexProgramProjectExtractorTest { - private var typeFactory: JavaTypeFactory = _ - private var rexBuilder: RexBuilder = _ - private var allFieldTypes: Seq[RelDataType] = _ - private val allFieldNames = List("name", "id", "amount", "price") - - @Before - def setUp(): Unit = { - typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT) - rexBuilder = new RexBuilder(typeFactory) - allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_)) - } - - @Test - def testExtractRefInputFields(): Unit = { - val usedFields = extractRefInputFields(buildRexProgram()) - assertArrayEquals(usedFields, Array(2, 3, 1)) - } - - @Test - def testRewriteRexProgram(): Unit = { - val originRexProgram = buildRexProgram() - assertTrue(extractExprStrList(originRexProgram).sameElements(Array( - "$0", - "$1", - "$2", - "$3", - "*($t2, $t3)", - "100", - "<($t4, $t5)", - "6", - ">($t1, $t7)", - "AND($t6, $t8)"))) - // use amount, id, price fields to create a new RexProgram - val usedFields = Array(2, 3, 1) - val types = usedFields.map(allFieldTypes(_)).toList.asJava - val names = usedFields.map(allFieldNames(_)).toList.asJava - val inputRowType = typeFactory.createStructType(types, names) - val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder) - assertTrue(extractExprStrList(newRexProgram).sameElements(Array( - "$0", - "$1", - "$2", - "*($t0, $t1)", - "100", - "<($t3, $t4)", - "6", - ">($t2, $t6)", - "AND($t5, $t7)"))) - } - - private def buildRexProgram(): RexProgram = { - val types = allFieldTypes.asJava - val names = allFieldNames.asJava - val inputRowType = typeFactory.createStructType(types, names) - val builder = new RexProgramBuilder(inputRowType, rexBuilder) - val t0 = rexBuilder.makeInputRef(types.get(2), 2) - val t1 = rexBuilder.makeInputRef(types.get(1), 1) - val t2 = rexBuilder.makeInputRef(types.get(3), 3) - val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2)) - val t4 
= rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L)) - val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L)) - // project: amount, amount * price - builder.addProject(t0, "amount") - builder.addProject(t3, "total") - // condition: amount * price < 100 and id > 6 - val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4)) - val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5)) - val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava)) - builder.addCondition(t8) - builder.getProgram - } - - /** - * extract all expression string list from input RexProgram expression lists - * - * @param rexProgram input RexProgram instance to analyze - * @return all expression string list of input RexProgram expression lists - */ - private def extractExprStrList(rexProgram: RexProgram) = { - rexProgram.getExprList.asScala.map(_.toString) - } - -}