This is an automated email from the ASF dual-hosted git repository. shengkai pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 202eacb2ee9 [FLINK-26361][hive] Create LogicalFilter with CorrelationId to fix failed to rewrite subquery in hive dialect (#18920) 202eacb2ee9 is described below commit 202eacb2ee96607be8c7c8c569db62a296539e3a Author: yuxia Luo <luoyu...@alumni.sjtu.edu.cn> AuthorDate: Fri Jul 1 10:29:26 2022 +0800 [FLINK-26361][hive] Create LogicalFilter with CorrelationId to fix failed to rewrite subquery in hive dialect (#18920) --- .../delegation/hive/HiveParserCalcitePlanner.java | 15 ++++- .../planner/delegation/hive/HiveParserUtils.java | 37 ++++++++++++ .../hive/copy/HiveParserBaseSemanticAnalyzer.java | 70 ++++++++++++++++++++++ .../src/test/resources/query-test/sub_query.q | 17 ++++++ 4 files changed, 136 insertions(+), 3 deletions(-) diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java index 60f5cda3283..69f4a8d7a23 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserCalcitePlanner.java @@ -75,7 +75,6 @@ import org.apache.calcite.rel.core.SetOp; import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalCorrelate; -import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalIntersect; import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.logical.LogicalMinus; @@ -910,7 +909,12 @@ public class HiveParserCalcitePlanner { RexNode factoredFilterExpr = RexUtil.pullFactors(cluster.getRexBuilder(), convertedFilterExpr) .accept(funcConverter); - RelNode filterRel = LogicalFilter.create(srcRel, factoredFilterExpr); + RelNode filterRel = + HiveParserUtils.genFilterRelNode( + srcRel, + factoredFilterExpr, + HiveParserBaseSemanticAnalyzer.getVariablesSetForFilter( + factoredFilterExpr)); relToRowResolver.put(filterRel, relToRowResolver.get(srcRel)); relToHiveColNameCalcitePosMap.put(filterRel, hiveColNameToCalcitePos); @@ -1070,7 +1074,12 @@ public class HiveParserCalcitePlanner { .convert(subQueryExpr) .accept(funcConverter); - RelNode filterRel = LogicalFilter.create(srcRel, convertedFilterLHS); + RelNode filterRel = + HiveParserUtils.genFilterRelNode( + srcRel, + convertedFilterLHS, + HiveParserBaseSemanticAnalyzer.getVariablesSetForFilter( + convertedFilterLHS)); relToHiveColNameCalcitePosMap.put(filterRel, relToHiveColNameCalcitePosMap.get(srcRel)); relToRowResolver.put(filterRel, relToRowResolver.get(srcRel)); diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java index ae682206abe..4504572afbd 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserUtils.java @@ -63,6 +63,7 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalValues; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -162,6 +163,13 @@ public class HiveParserUtils { "org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList"); private static final boolean useShadedImmutableList = shadedImmutableListClz != null; + private static final Class immutableSetClz = + HiveReflectionUtils.tryGetClass("com.google.common.collect.ImmutableSet"); + private static final Class shadedImmutableSetClz = + HiveReflectionUtils.tryGetClass( + "org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableSet"); + private static final boolean useShadedImmutableSet = shadedImmutableSetClz != null; + private HiveParserUtils() {} public static void removeASTChild(HiveParserASTNode node) { @@ -317,6 +325,17 @@ public class HiveParserUtils { } } + // converts a collection to guava ImmutableSet + private static Object toImmutableSet(Collection collection) { + try { + Class clz = useShadedImmutableSet ? shadedImmutableSetClz : immutableSetClz; + return HiveReflectionUtils.invokeMethod( + clz, null, "copyOf", new Class[] {Collection.class}, new Object[] {collection}); + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new FlinkHiveException("Failed to create immutable set", e); + } + } + // creates LogicalValues node public static RelNode genValuesRelNode( RelOptCluster cluster, RelDataType rowType, List<List<RexLiteral>> rows) { @@ -339,6 +358,24 @@ public class HiveParserUtils { } } + // creates LogicFilter node + public static RelNode genFilterRelNode( + RelNode relNode, RexNode rexNode, Collection<CorrelationId> variables) { + Class[] argTypes = + new Class[] { + RelNode.class, + RexNode.class, + useShadedImmutableSet ? shadedImmutableSetClz : immutableSetClz + }; + Method method = HiveReflectionUtils.tryGetMethod(LogicalFilter.class, "create", argTypes); + Preconditions.checkState(method != null, "Cannot get the method to create a LogicalFilter"); + try { + return (LogicalFilter) method.invoke(null, relNode, rexNode, toImmutableSet(variables)); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new FlinkHiveException("Failed to create LogicalFilter", e); + } + } + /** Proxy to {@link RexSubQuery#in(RelNode, com.google.common.collect.ImmutableList)}. */ public static RexSubQuery rexSubQueryIn(RelNode relNode, Collection<RexNode> rexNodes) { Class[] argTypes = new Class[] {RelNode.class, null}; diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java index b4e6d4d4672..4a9fd748312 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/copy/HiveParserBaseSemanticAnalyzer.java @@ -40,12 +40,17 @@ import org.antlr.runtime.tree.TreeVisitorAction; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexWindowBound; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlKind; @@ -74,6 +79,7 @@ import org.apache.hadoop.hive.ql.metadata.InvalidTableException; import org.apache.hadoop.hive.ql.metadata.Partition; import org.apache.hadoop.hive.ql.metadata.Table; import org.apache.hadoop.hive.ql.metadata.VirtualColumn; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter; import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.parse.WindowingSpec; @@ -1862,6 +1868,70 @@ public class HiveParserBaseSemanticAnalyzer { rows); } + /** + * traverse the given node to find all correlated variables, the main logic is from {@link + * HiveFilter#getVariablesSet()}. + */ + public static Set<CorrelationId> getVariablesSetForFilter(RexNode rexNode) { + Set<CorrelationId> correlationVariables = new HashSet<>(); + if (rexNode instanceof RexSubQuery) { + RexSubQuery rexSubQuery = (RexSubQuery) rexNode; + // we expect correlated variables in Filter only for now. + // also check case where operator has 0 inputs .e.g TableScan + if (rexSubQuery.rel.getInputs().isEmpty()) { + return correlationVariables; + } + RelNode input = rexSubQuery.rel.getInput(0); + while (input != null + && !(input instanceof LogicalFilter) + && input.getInputs().size() >= 1) { + // we don't expect corr vars within UNION for now + if (input.getInputs().size() > 1) { + if (input instanceof LogicalJoin) { + correlationVariables.addAll( + findCorrelatedVar(((LogicalJoin) input).getCondition())); + } + // todo: throw Unsupported exception when the input isn't LogicalJoin and + // contains correlate variables in FLINK-28317 + return correlationVariables; + } + input = input.getInput(0); + } + if (input instanceof LogicalFilter) { + correlationVariables.addAll( + findCorrelatedVar(((LogicalFilter) input).getCondition())); + } + return correlationVariables; + } + // AND, NOT etc + if (rexNode instanceof RexCall) { + int numOperands = ((RexCall) rexNode).getOperands().size(); + for (int i = 0; i < numOperands; i++) { + RexNode op = ((RexCall) rexNode).getOperands().get(i); + correlationVariables.addAll(getVariablesSetForFilter(op)); + } + } + return correlationVariables; + } + + private static Set<CorrelationId> findCorrelatedVar(RexNode node) { + Set<CorrelationId> allVars = new HashSet<>(); + if (node instanceof RexCall) { + RexCall nd = (RexCall) node; + for (RexNode rn : nd.getOperands()) { + if (rn instanceof RexFieldAccess) { + final RexNode ref = ((RexFieldAccess) rn).getReferenceExpr(); + if (ref instanceof RexCorrelVariable) { + allVars.add(((RexCorrelVariable) ref).id); + } + } else { + allVars.addAll(findCorrelatedVar(rn)); + } + } + } + return allVars; + } + private static void validatePartColumnType( Table tbl, Map<String, String> partSpec, diff --git a/flink-connectors/flink-connector-hive/src/test/resources/query-test/sub_query.q b/flink-connectors/flink-connector-hive/src/test/resources/query-test/sub_query.q new file mode 100644 index 00000000000..26a55cb0446 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/resources/query-test/sub_query.q @@ -0,0 +1,17 @@ +-- SORT_QUERY_RESULTS + +select * from src where src.key in (select c.key from (select * from src b where exists (select a.key from src a where b.value = a.value)) c); + +[+I[1, val1], +I[2, val2], +I[3, val3]] + +select * from src x where x.key in (select y.key from src y where exists (select z.key from src z where y.key = z.key)); + +[+I[1, val1], +I[2, val2], +I[3, val3]] + +select * from src x join src y on x.key = y.key where exists (select * from src z where z.value = x.value and z.value = y.value); + +[+I[1, val1, 1, val1], +I[2, val2, 2, val2], +I[3, val3, 3, val3]] + +select * from (select x.key from src x); + +[+I[1], +I[2], +I[3]]