GitHub user fhueske commented on a diff in the pull request:
https://github.com/apache/flink/pull/1567#discussion_r51732524
--- Diff:
flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataset/DataSetJoinRule.scala
---
@@ -46,12 +57,93 @@ class DataSetJoinRule
convRight,
rel.getRowType,
join.toString,
- Array[Int](),
- Array[Int](),
- JoinType.INNER,
+ joinKeys._1,
+ joinKeys._2,
+ sqlJoinTypeToFlinkJoinType(join.getJoinType),
null,
null)
}
+
+ private def getJoinKeys(join: FlinkJoin): (Array[Int], Array[Int]) = {
+ val joinKeys = ArrayBuffer.empty[(Int, Int)]
+ parseJoinRexNode(join.getCondition.asInstanceOf[RexCall], joinKeys)
+
+ val joinedRowType= join.getRowType
+ val leftRowType = join.getLeft.getRowType
+ val rightRowType = join.getRight.getRowType
+
+ // The fetched join key index from Calcite is based on joined row
type, we need
+ // the join key index based on left/right input row type.
+ val joinKeyPairs: ArrayBuffer[(Int, Int)] = joinKeys.map {
+ case (first, second) =>
+ var leftIndex = findIndexInSingleInput(first, joinedRowType,
leftRowType)
+ if (leftIndex == -1) {
+ leftIndex = findIndexInSingleInput(second, joinedRowType,
leftRowType)
+ if (leftIndex == -1) {
+ throw new PlanGenException("Invalid join condition, could not
find " +
+ joinedRowType.getFieldNames.get(first) + " and " +
+ joinedRowType.getFieldNames.get(second) + " in left table")
+ }
+ val rightIndex = findIndexInSingleInput(first, joinedRowType,
rightRowType)
+ if (rightIndex == -1) {
+ throw new PlanGenException("Invalid join condition could not
find " +
+ joinedRowType.getFieldNames.get(first) + " in right table")
+ }
+ (leftIndex, rightIndex)
+ } else {
+ val rightIndex = findIndexInSingleInput(second, joinedRowType,
rightRowType)
+ if (rightIndex == -1) {
+ throw new PlanGenException("Invalid join condition could not
find " +
+ joinedRowType.getFieldNames.get(second) + " in right table")
+ }
+ (leftIndex, rightIndex)
+ }
+ }
+
+ val joinKeysPair = joinKeyPairs.unzip
+
+ (joinKeysPair._1.toArray, joinKeysPair._2.toArray)
+ }
+
+ // Parse the join condition recursively, find all the join keys' index.
+ private def parseJoinRexNode(condition: RexCall, joinKeys:
ArrayBuffer[(Int, Int)]): Unit = {
--- End diff --
We should extract all conjunctive equality conditions and use only those as
join keys, ignoring all other conditions when determining the keys. If there are
no conjunctive equality conditions, we should generate a DataSet cross (Cartesian
product) instead. All remaining non-equality conditions still need to be
evaluated, in the join or cross function.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---