This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit de22d7c0d5afd3233ab8e174ed4d837e08438ab3 Author: JingsongLi <lzljs3620...@aliyun.com> AuthorDate: Thu Aug 22 12:46:52 2019 +0200 [FLINK-13774][table-planner-blink] Modify filterable table source accept ResolvedExpression --- .../planner/plan/utils/RexNodeExtractor.scala | 52 ++++++++++++---------- .../table/planner/utils/testTableSources.scala | 12 ++--- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala index f938c79..b4535bf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala @@ -28,6 +28,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.utils.Logging import org.apache.flink.table.runtime.functions.SqlDateTimeUtils.unixTimestampToLocalDateTime import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType +import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.util.Preconditions @@ -294,9 +295,9 @@ class RexNodeToExpressionConverter( inputNames: Array[String], functionCatalog: FunctionCatalog, timeZone: TimeZone) - extends RexVisitor[Option[Expression]] { + extends RexVisitor[Option[ResolvedExpression]] { - override def visitInputRef(inputRef: RexInputRef): Option[Expression] = { + override def visitInputRef(inputRef: RexInputRef): Option[ResolvedExpression] = { Preconditions.checkArgument(inputRef.getIndex < inputNames.length) Some(new FieldReferenceExpression( inputNames(inputRef.getIndex), @@ -306,14 +307,14 @@ class RexNodeToExpressionConverter( )) } - override def visitTableInputRef(rexTableInputRef: RexTableInputRef): Option[Expression] = + override def visitTableInputRef(rexTableInputRef: RexTableInputRef): Option[ResolvedExpression] = visitInputRef(rexTableInputRef) - override def visitLocalRef(localRef: RexLocalRef): Option[Expression] = { + override def visitLocalRef(localRef: RexLocalRef): Option[ResolvedExpression] = { throw new TableException("Bug: RexLocalRef should have been expanded") } - override def visitLiteral(literal: RexLiteral): Option[Expression] = { + override def visitLiteral(literal: RexLiteral): Option[ResolvedExpression] = { // TODO support SqlTrimFunction.Flag literal.getValue match { case _: SqlTrimFunction.Flag => return None @@ -384,53 +385,58 @@ class RexNodeToExpressionConverter( fromLogicalTypeToDataType(literalType))) } - override def visitCall(rexCall: RexCall): Option[Expression] = { + override def visitCall(rexCall: RexCall): Option[ResolvedExpression] = { val operands = rexCall.getOperands.map( operand => operand.accept(this).orNull ) + val outputType = fromLogicalTypeToDataType(FlinkTypeFactory.toLogicalType(rexCall.getType)) + // return null if we cannot translate all the operands of the call if (operands.contains(null)) { None } else { rexCall.getOperator match { case SqlStdOperatorTable.OR => - Option(operands.reduceLeft { (l, r) => unresolvedCall(OR, l, r) }) + Option(operands.reduceLeft((l, r) => new CallExpression(OR, Seq(l, r), outputType))) case SqlStdOperatorTable.AND => - Option(operands.reduceLeft { (l, r) => unresolvedCall(AND, l, r) }) + Option(operands.reduceLeft((l, r) => new CallExpression(AND, Seq(l, r), outputType))) case SqlStdOperatorTable.CAST => - Option(unresolvedCall(CAST, operands.head, - typeLiteral(fromLogicalTypeToDataType( - FlinkTypeFactory.toLogicalType(rexCall.getType))))) + Option(new CallExpression(CAST, Seq(operands.head, typeLiteral(outputType)), outputType)) case function: SqlFunction => - lookupFunction(replace(function.getName), operands) + lookupFunction(replace(function.getName), operands, outputType) case postfix: SqlPostfixOperator => - lookupFunction(replace(postfix.getName), operands) + lookupFunction(replace(postfix.getName), operands, outputType) case operator@_ => - lookupFunction(replace(s"${operator.getKind}"), operands) + lookupFunction(replace(s"${operator.getKind}"), operands, outputType) } } } - override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[Expression] = None + override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = None - override def visitCorrelVariable(correlVariable: RexCorrelVariable): Option[Expression] = None + override def visitCorrelVariable( + correlVariable: RexCorrelVariable): Option[ResolvedExpression] = None - override def visitRangeRef(rangeRef: RexRangeRef): Option[Expression] = None + override def visitRangeRef(rangeRef: RexRangeRef): Option[ResolvedExpression] = None - override def visitSubQuery(subQuery: RexSubQuery): Option[Expression] = None + override def visitSubQuery(subQuery: RexSubQuery): Option[ResolvedExpression] = None - override def visitDynamicParam(dynamicParam: RexDynamicParam): Option[Expression] = None + override def visitDynamicParam(dynamicParam: RexDynamicParam): Option[ResolvedExpression] = None - override def visitOver(over: RexOver): Option[Expression] = None + override def visitOver(over: RexOver): Option[ResolvedExpression] = None - override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): Option[Expression] = None + override def visitPatternFieldRef( + fieldRef: RexPatternFieldRef): Option[ResolvedExpression] = None - private def lookupFunction(name: String, operands: Seq[Expression]): Option[Expression] = { + private def lookupFunction( + name: String, + operands: Seq[ResolvedExpression], + outputType: DataType): Option[ResolvedExpression] = { Try(functionCatalog.lookupFunction(name)) match { case Success(f: java.util.Optional[FunctionLookup.Result]) => if (f.isPresent) { - Some(unresolvedCall(f.get().getFunctionDefinition, operands: _*)) + Some(new CallExpression(f.get().getFunctionDefinition, operands, outputType)) } else { None } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala index 44cb4eb..24fab42 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/utils/testTableSources.scala @@ -29,7 +29,7 @@ import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment import org.apache.flink.table.api.{TableSchema, Types} import org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedCall -import org.apache.flink.table.expressions.{Expression, FieldReferenceExpression, UnresolvedCallExpression, ValueLiteralExpression} +import org.apache.flink.table.expressions.{CallExpression, Expression, FieldReferenceExpression, ValueLiteralExpression} import org.apache.flink.table.functions.BuiltInFunctionDefinitions import org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row @@ -398,12 +398,12 @@ class TestFilterableTableSource( private def shouldPushDown(expr: Expression): Boolean = { expr match { - case expr: UnresolvedCallExpression if expr.getChildren.size() == 2 => shouldPushDown(expr) + case expr: CallExpression if expr.getChildren.size() == 2 => shouldPushDown(expr) case _ => false } } - private def shouldPushDown(binExpr: UnresolvedCallExpression): Boolean = { + private def shouldPushDown(binExpr: CallExpression): Boolean = { val children = binExpr.getChildren require(children.size() == 2) (children.head, children.last) match { @@ -419,13 +419,13 @@ class TestFilterableTableSource( private def shouldKeep(row: Row): Boolean = { filterPredicates.isEmpty || filterPredicates.forall { - case expr: UnresolvedCallExpression if expr.getChildren.size() == 2 => + case expr: CallExpression if expr.getChildren.size() == 2 => binaryFilterApplies(expr, row) case expr => throw new RuntimeException(expr + " not supported!") } } - private def binaryFilterApplies(binExpr: UnresolvedCallExpression, row: Row): Boolean = { + private def binaryFilterApplies(binExpr: CallExpression, row: Row): Boolean = { val children = binExpr.getChildren require(children.size() == 2) val (lhsValue, rhsValue) = extractValues(binExpr, row) @@ -447,7 +447,7 @@ class TestFilterableTableSource( } private def extractValues( - binExpr: UnresolvedCallExpression, + binExpr: CallExpression, row: Row): (Comparable[Any], Comparable[Any]) = { val children = binExpr.getChildren require(children.size() == 2)