[FLINK-3934] [tableAPI] Check for equi-join predicates before translation.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3d6ce294 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3d6ce294 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3d6ce294 Branch: refs/heads/master Commit: 3d6ce294123d90ca7bb029d4041f29a1ae1ccd81 Parents: 960f4ac Author: Fabian Hueske <fhue...@apache.org> Authored: Thu May 19 01:47:29 2016 +0200 Committer: Fabian Hueske <fhue...@apache.org> Committed: Thu May 19 14:11:26 2016 +0200 ---------------------------------------------------------------------- .../table/plan/nodes/dataset/DataSetJoin.scala | 67 ++------------------ .../plan/rules/dataSet/DataSetJoinRule.scala | 58 +++++++++-------- .../flink/api/scala/batch/sql/JoinITCase.scala | 2 +- .../api/scala/batch/table/JoinITCase.scala | 7 +- 4 files changed, 44 insertions(+), 90 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala index 4f24f4e..cdf7461 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala @@ -23,7 +23,6 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.JoinInfo import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelWriter, BiRel, RelNode} -import org.apache.calcite.sql.fun.SqlStdOperatorTable import org.apache.calcite.util.mapping.IntPair import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint import org.apache.flink.api.common.typeinfo.TypeInformation @@ -31,12 +30,11 @@ import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.table.codegen.CodeGenerator import org.apache.flink.api.table.runtime.FlatJoinRunner -import org.apache.flink.api.table.typeutils.TypeConverter +import org.apache.flink.api.table.typeutils.TypeConverter.determineReturnType import org.apache.flink.api.table.{BatchTableEnvironment, TableException} import org.apache.flink.api.common.functions.FlatJoinFunction -import TypeConverter.determineReturnType import scala.collection.mutable.ArrayBuffer -import org.apache.calcite.rex.{RexInputRef, RexCall, RexNode} +import org.apache.calcite.rex.RexNode import scala.collection.JavaConverters._ import scala.collection.JavaConversions._ @@ -60,8 +58,6 @@ class DataSetJoin( extends BiRel(cluster, traitSet, left, right) with DataSetRel { - val translatable = canBeTranslated - override def deriveRowType() = rowType override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { @@ -92,19 +88,12 @@ class DataSetJoin( override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = { - if (!translatable) { - // join cannot be translated. Make huge costs - planner.getCostFactory.makeHugeCost() - } else { - // join can be translated. Compute cost estimate - val children = this.getInputs - children.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) => - val rowCnt = metadata.getRowCount(child) - val rowSize = this.estimateRowSize(child.getRowType) - cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize)) - } + val children = this.getInputs + children.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) => + val rowCnt = metadata.getRowCount(child) + val rowSize = this.estimateRowSize(child.getRowType) + cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize)) } - } override def translateToPlan( @@ -204,48 +193,6 @@ class DataSetJoin( .`with`(joinFun).name(joinOpName).asInstanceOf[DataSet[Any]] } - private def canBeTranslated: Boolean = { - - val equiCondition = - joinInfo.getEquiCondition(left, right, cluster.getRexBuilder) - - // joins require at least one equi-condition - if (equiCondition.isAlwaysTrue) { - false - } - else { - // check that all equality predicates refer to field refs only (not computed expressions) - // Note: Calcite treats equality predicates on expressions as non-equi predicates - joinCondition match { - - // conjunction of join predicates - case c: RexCall if c.getOperator.equals(SqlStdOperatorTable.AND) => - - c.getOperands.asScala - // look at equality predicates only - .filter { o => - o.isInstanceOf[RexCall] && - o.asInstanceOf[RexCall].getOperator.equals(SqlStdOperatorTable.EQUALS) - } - // check that both children are field references - .map { o => - o.asInstanceOf[RexCall].getOperands.get(0).isInstanceOf[RexInputRef] && - o.asInstanceOf[RexCall].getOperands.get(1).isInstanceOf[RexInputRef] - } - // any equality predicate that does not refer to a field reference? - .reduce( (a, b) => a && b) - - // single equi-join predicate - case c: RexCall if c.getOperator.equals(SqlStdOperatorTable.EQUALS) => - c.getOperands.get(0).isInstanceOf[RexInputRef] && - c.getOperands.get(1).isInstanceOf[RexInputRef] - case _ => - false - } - } - - } - private def joinSelectionToString: String = { rowType.getFieldNames.asScala.toList.mkString(", ") } http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala index 55100d2..f3bd402 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala @@ -25,6 +25,7 @@ import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.logical.LogicalJoin import org.apache.flink.api.java.operators.join.JoinType import org.apache.flink.api.table.plan.nodes.dataset.{DataSetJoin, DataSetConvention} + import scala.collection.JavaConversions._ class DataSetJoinRule @@ -32,40 +33,45 @@ class DataSetJoinRule classOf[LogicalJoin], Convention.NONE, DataSetConvention.INSTANCE, - "DataSetJoinRule") - { + "DataSetJoinRule") { - /** - * Only translate INNER joins for now - */ override def matches(call: RelOptRuleCall): Boolean = { val join: LogicalJoin = call.rel(0).asInstanceOf[LogicalJoin] - join.getJoinType.equals(JoinRelType.INNER) + + val joinInfo = join.analyzeCondition + + // joins require an equi-condition or a conjunctive predicate with at least one equi-condition + val hasValidCondition = !joinInfo.pairs().isEmpty + // only inner joins are supported at the moment + val isInnerJoin = join.getJoinType.equals(JoinRelType.INNER) + + // check that condition is valid and inner join + hasValidCondition && isInnerJoin } - def convert(rel: RelNode): RelNode = { + override def convert(rel: RelNode): RelNode = { - val join: LogicalJoin = rel.asInstanceOf[LogicalJoin] - val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE) - val convLeft: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE) - val convRight: RelNode = RelOptRule.convert(join.getInput(1), DataSetConvention.INSTANCE) - val joinInfo = join.analyzeCondition + val join: LogicalJoin = rel.asInstanceOf[LogicalJoin] + val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE) + val convLeft: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE) + val convRight: RelNode = RelOptRule.convert(join.getInput(1), DataSetConvention.INSTANCE) + val joinInfo = join.analyzeCondition - new DataSetJoin( - rel.getCluster, - traitSet, - convLeft, - convRight, - rel.getRowType, - join.getCondition, - join.getRowType, - joinInfo, - joinInfo.pairs.toList, - JoinType.INNER, - null, - description) - } + new DataSetJoin( + rel.getCluster, + traitSet, + convLeft, + convRight, + rel.getRowType, + join.getCondition, + join.getRowType, + joinInfo, + joinInfo.pairs.toList, + JoinType.INNER, + null, + description) } +} object DataSetJoinRule { val INSTANCE: RelOptRule = new DataSetJoinRule http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala index 40d7546..d388c33 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala @@ -190,7 +190,7 @@ class JoinITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } - @Test(expected = classOf[TableException]) + @Test(expected = classOf[PlanGenException]) def testJoinNoEqualityPredicate(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala index 1a9be93..ae76ace 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala @@ -21,8 +21,9 @@ package org.apache.flink.api.scala.batch.table import org.apache.flink.api.scala._ import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.table.plan.PlanGenException +import org.apache.flink.api.table.{ValidationException, Row, TableEnvironment} import org.apache.flink.api.table.expressions.Literal -import org.apache.flink.api.table.{Row, TableEnvironment, TableException, ValidationException} import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils} import org.junit._ @@ -139,7 +140,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) .select('c, 'g) } - @Test(expected = classOf[TableException]) + @Test(expected = classOf[PlanGenException]) def testNoEqualityJoinPredicate1(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) @@ -153,7 +154,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) .select('c, 'g).collect() } - @Test(expected = classOf[TableException]) + @Test(expected = classOf[PlanGenException]) def testNoEqualityJoinPredicate2(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env)