[FLINK-3934] [tableAPI] Check for equi-join predicates before translation.

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3d6ce294
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3d6ce294
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3d6ce294

Branch: refs/heads/master
Commit: 3d6ce294123d90ca7bb029d4041f29a1ae1ccd81
Parents: 960f4ac
Author: Fabian Hueske <fhue...@apache.org>
Authored: Thu May 19 01:47:29 2016 +0200
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Thu May 19 14:11:26 2016 +0200

----------------------------------------------------------------------
 .../table/plan/nodes/dataset/DataSetJoin.scala  | 67 ++------------------
 .../plan/rules/dataSet/DataSetJoinRule.scala    | 58 +++++++++--------
 .../flink/api/scala/batch/sql/JoinITCase.scala  |  2 +-
 .../api/scala/batch/table/JoinITCase.scala      |  7 +-
 4 files changed, 44 insertions(+), 90 deletions(-)
----------------------------------------------------------------------
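
For context: the translatability check that previously lived inside DataSetJoin
(canBeTranslated plus huge-cost estimates) now happens up front in
DataSetJoinRule.matches(). Below is a minimal sketch of that check, using the
Calcite LogicalJoin/JoinInfo calls that appear in the patch; isTranslatableJoin
is a hypothetical helper name, not part of the commit.

    import org.apache.calcite.rel.core.JoinRelType
    import org.apache.calcite.rel.logical.LogicalJoin

    // JoinInfo splits the join condition into equi-join key pairs plus a
    // remaining non-equi part; an empty pair list means there is no usable
    // equi-join predicate, so the join is rejected before translation.
    def isTranslatableJoin(join: LogicalJoin): Boolean = {
      val joinInfo = join.analyzeCondition
      val hasEquiCondition = !joinInfo.pairs().isEmpty  // at least one left.x = right.y key pair
      val isInnerJoin = join.getJoinType == JoinRelType.INNER  // only inner joins supported so far
      hasEquiCondition && isInnerJoin
    }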


http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala
index 4f24f4e..cdf7461 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetJoin.scala
@@ -23,7 +23,6 @@ import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.JoinInfo
 import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelWriter, BiRel, RelNode}
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.util.mapping.IntPair
 import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
 import org.apache.flink.api.common.typeinfo.TypeInformation
@@ -31,12 +30,11 @@ import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.java.operators.join.JoinType
 import org.apache.flink.api.table.codegen.CodeGenerator
 import org.apache.flink.api.table.runtime.FlatJoinRunner
-import org.apache.flink.api.table.typeutils.TypeConverter
+import org.apache.flink.api.table.typeutils.TypeConverter.determineReturnType
 import org.apache.flink.api.table.{BatchTableEnvironment, TableException}
 import org.apache.flink.api.common.functions.FlatJoinFunction
-import TypeConverter.determineReturnType
 import scala.collection.mutable.ArrayBuffer
-import org.apache.calcite.rex.{RexInputRef, RexCall, RexNode}
+import org.apache.calcite.rex.RexNode
 
 import scala.collection.JavaConverters._
 import scala.collection.JavaConversions._
@@ -60,8 +58,6 @@ class DataSetJoin(
   extends BiRel(cluster, traitSet, left, right)
   with DataSetRel {
 
-  val translatable = canBeTranslated
-
   override def deriveRowType() = rowType
 
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
@@ -92,19 +88,12 @@ class DataSetJoin(
 
   override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
 
-    if (!translatable) {
-      // join cannot be translated. Make huge costs
-      planner.getCostFactory.makeHugeCost()
-    } else {
-      // join can be translated. Compute cost estimate
-      val children = this.getInputs
-      children.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) =>
-        val rowCnt = metadata.getRowCount(child)
-        val rowSize = this.estimateRowSize(child.getRowType)
-        cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize))
-      }
+    val children = this.getInputs
+    children.foldLeft(planner.getCostFactory.makeCost(0, 0, 0)) { (cost, child) =>
+      val rowCnt = metadata.getRowCount(child)
+      val rowSize = this.estimateRowSize(child.getRowType)
+      cost.plus(planner.getCostFactory.makeCost(rowCnt, rowCnt, rowCnt * rowSize))
     }
-
   }
 
   override def translateToPlan(
@@ -204,48 +193,6 @@ class DataSetJoin(
       .`with`(joinFun).name(joinOpName).asInstanceOf[DataSet[Any]]
   }
 
-  private def canBeTranslated: Boolean = {
-
-    val equiCondition =
-      joinInfo.getEquiCondition(left, right, cluster.getRexBuilder)
-
-    // joins require at least one equi-condition
-    if (equiCondition.isAlwaysTrue) {
-      false
-    }
-    else {
-      // check that all equality predicates refer to field refs only (not computed expressions)
-      //   Note: Calcite treats equality predicates on expressions as non-equi predicates
-      joinCondition match {
-
-        // conjunction of join predicates
-        case c: RexCall if c.getOperator.equals(SqlStdOperatorTable.AND) =>
-
-          c.getOperands.asScala
-            // look at equality predicates only
-            .filter { o =>
-            o.isInstanceOf[RexCall] &&
-              o.asInstanceOf[RexCall].getOperator.equals(SqlStdOperatorTable.EQUALS)
-          }
-            // check that both children are field references
-            .map { o =>
-            o.asInstanceOf[RexCall].getOperands.get(0).isInstanceOf[RexInputRef] &&
-              o.asInstanceOf[RexCall].getOperands.get(1).isInstanceOf[RexInputRef]
-          }
-            // any equality predicate that does not refer to a field reference?
-            .reduce( (a, b) => a && b)
-
-        // single equi-join predicate
-        case c: RexCall if c.getOperator.equals(SqlStdOperatorTable.EQUALS) =>
-          c.getOperands.get(0).isInstanceOf[RexInputRef] &&
-            c.getOperands.get(1).isInstanceOf[RexInputRef]
-        case _ =>
-          false
-      }
-    }
-
-  }
-
   private def joinSelectionToString: String = {
     rowType.getFieldNames.asScala.toList.mkString(", ")
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala
index 55100d2..f3bd402 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetJoinRule.scala
@@ -25,6 +25,7 @@ import org.apache.calcite.rel.core.JoinRelType
 import org.apache.calcite.rel.logical.LogicalJoin
 import org.apache.flink.api.java.operators.join.JoinType
 import org.apache.flink.api.table.plan.nodes.dataset.{DataSetJoin, DataSetConvention}
+
 import scala.collection.JavaConversions._
 
 class DataSetJoinRule
@@ -32,40 +33,45 @@ class DataSetJoinRule
       classOf[LogicalJoin],
       Convention.NONE,
       DataSetConvention.INSTANCE,
-      "DataSetJoinRule")
-  {
+      "DataSetJoinRule") {
 
-  /**
-   * Only translate INNER joins for now
-   */
   override def matches(call: RelOptRuleCall): Boolean = {
     val join: LogicalJoin = call.rel(0).asInstanceOf[LogicalJoin]
-    join.getJoinType.equals(JoinRelType.INNER)
+
+    val joinInfo = join.analyzeCondition
+
+    // joins require an equi-condition or a conjunctive predicate with at least one equi-condition
+    val hasValidCondition = !joinInfo.pairs().isEmpty
+    // only inner joins are supported at the moment
+    val isInnerJoin = join.getJoinType.equals(JoinRelType.INNER)
+
+    // check that condition is valid and inner join
+    hasValidCondition && isInnerJoin
   }
 
-    def convert(rel: RelNode): RelNode = {
+  override def convert(rel: RelNode): RelNode = {
 
-      val join: LogicalJoin = rel.asInstanceOf[LogicalJoin]
-      val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
-      val convLeft: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE)
-      val convRight: RelNode = RelOptRule.convert(join.getInput(1), DataSetConvention.INSTANCE)
-      val joinInfo = join.analyzeCondition
+    val join: LogicalJoin = rel.asInstanceOf[LogicalJoin]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
+    val convLeft: RelNode = RelOptRule.convert(join.getInput(0), DataSetConvention.INSTANCE)
+    val convRight: RelNode = RelOptRule.convert(join.getInput(1), DataSetConvention.INSTANCE)
+    val joinInfo = join.analyzeCondition
 
-        new DataSetJoin(
-          rel.getCluster,
-          traitSet,
-          convLeft,
-          convRight,
-          rel.getRowType,
-          join.getCondition,
-          join.getRowType,
-          joinInfo,
-          joinInfo.pairs.toList,
-          JoinType.INNER,
-          null,
-          description)
-    }
+    new DataSetJoin(
+      rel.getCluster,
+      traitSet,
+      convLeft,
+      convRight,
+      rel.getRowType,
+      join.getCondition,
+      join.getRowType,
+      joinInfo,
+      joinInfo.pairs.toList,
+      JoinType.INNER,
+      null,
+      description)
   }
+}
 
 object DataSetJoinRule {
   val INSTANCE: RelOptRule = new DataSetJoinRule

http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
index 40d7546..d388c33 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/JoinITCase.scala
@@ -190,7 +190,7 @@ class JoinITCase(
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
-  @Test(expected = classOf[TableException])
+  @Test(expected = classOf[PlanGenException])
   def testJoinNoEqualityPredicate(): Unit = {
 
     val env = ExecutionEnvironment.getExecutionEnvironment

http://git-wip-us.apache.org/repos/asf/flink/blob/3d6ce294/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala
index 1a9be93..ae76ace 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/table/JoinITCase.scala
@@ -21,8 +21,9 @@ package org.apache.flink.api.scala.batch.table
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.table._
 import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.api.table.plan.PlanGenException
+import org.apache.flink.api.table.{ValidationException, Row, TableEnvironment}
 import org.apache.flink.api.table.expressions.Literal
-import org.apache.flink.api.table.{Row, TableEnvironment, TableException, ValidationException}
 import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
 import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
 import org.junit._
@@ -139,7 +140,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
       .select('c, 'g)
   }
 
-  @Test(expected = classOf[TableException])
+  @Test(expected = classOf[PlanGenException])
   def testNoEqualityJoinPredicate1(): Unit = {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env)
@@ -153,7 +154,7 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
       .select('c, 'g).collect()
   }
 
-  @Test(expected = classOf[TableException])
+  @Test(expected = classOf[PlanGenException])
   def testNoEqualityJoinPredicate2(): Unit = {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env)
