This is an automated email from the ASF dual-hosted git repository. kurt pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 10ac65d [FLINK-12077][table-runtime-blink] Introduce HashJoinOperator and LongHashJoinGenerator to blink runtime (#8093) 10ac65d is described below commit 10ac65dfd1620c605f2214725ecb558ef10f603c Author: Jingsong Lee <lzljs3620...@aliyun.com> AuthorDate: Tue Apr 2 16:19:37 2019 +0800 [FLINK-12077][table-runtime-blink] Introduce HashJoinOperator and LongHashJoinGenerator to blink runtime (#8093) --- flink-table/flink-table-planner-blink/pom.xml | 16 + .../table/codegen/LongHashJoinGenerator.scala | 341 ++++++++++++++++ .../table/codegen/OperatorCodeGenerator.scala | 81 +++- .../table/codegen/LongHashJoinGeneratorTest.java | 110 +++++ flink-table/flink-table-runtime-blink/pom.xml | 12 + ...Projection.java => GeneratedJoinCondition.java} | 14 +- .../flink/table/generated/GeneratedProjection.java | 2 +- .../table/runtime/TwoInputOperatorWrapper.java | 169 ++++++++ .../flink/table/runtime/join/HashJoinOperator.java | 434 ++++++++++++++++++++ .../runtime/join/Int2HashJoinOperatorTest.java | 448 +++++++++++++++++++++ .../runtime/join/String2HashJoinOperatorTest.java | 335 +++++++++++++++ 11 files changed, 1953 insertions(+), 9 deletions(-) diff --git a/flink-table/flink-table-planner-blink/pom.xml b/flink-table/flink-table-planner-blink/pom.xml index da05406..f8ab62e 100644 --- a/flink-table/flink-table-planner-blink/pom.xml +++ b/flink-table/flink-table-planner-blink/pom.xml @@ -184,6 +184,22 @@ under the License. 
<version>${project.version}</version> <scope>test</scope> </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-table-runtime-blink</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-streaming-java_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> </dependencies> <build> diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/LongHashJoinGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/LongHashJoinGenerator.scala new file mode 100644 index 0000000..e69bf11 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/LongHashJoinGenerator.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.codegen + +import org.apache.flink.metrics.Gauge +import org.apache.flink.table.`type`.{DateType, InternalType, InternalTypes, RowType, TimestampType} +import org.apache.flink.table.api.TableConfig +import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, BINARY_ROW, baseRowFieldReadAccess, newName} +import org.apache.flink.table.codegen.OperatorCodeGenerator.generateCollect +import org.apache.flink.table.dataformat.{BaseRow, JoinedRow} +import org.apache.flink.table.generated.{GeneratedJoinCondition, GeneratedProjection} +import org.apache.flink.table.runtime.TwoInputOperatorWrapper +import org.apache.flink.table.runtime.hashtable.{LongHashPartition, LongHybridHashTable} +import org.apache.flink.table.runtime.join.HashJoinType +import org.apache.flink.table.typeutils.BinaryRowSerializer + +/** + * Generate a long key hash join operator using [[LongHybridHashTable]]. + */ +object LongHashJoinGenerator { + + def support( + joinType: HashJoinType, + keyType: RowType, + filterNulls: Array[Boolean]): Boolean = { + (joinType == HashJoinType.INNER || + joinType == HashJoinType.SEMI || + joinType == HashJoinType.ANTI || + joinType == HashJoinType.PROBE_OUTER) && + filterNulls.forall(b => b) && + keyType.getFieldTypes.length == 1 && { + val t = keyType.getTypeAt(0) + t == InternalTypes.LONG || t == InternalTypes.INT || t == InternalTypes.SHORT || + t == InternalTypes.BYTE || t == InternalTypes.FLOAT || t == InternalTypes.DOUBLE || + t.isInstanceOf[DateType] || t.isInstanceOf[TimestampType] || t == InternalTypes.TIME + // TODO decimal and multiKeys support. + // TODO All HashJoinType support. 
+ } + } + + private def genGetLongKey( + ctx: CodeGeneratorContext, + keyType: RowType, + keyMapping: Array[Int], + rowTerm: String): String = { + val singleType = keyType.getFieldTypes()(0) + val getCode = baseRowFieldReadAccess(ctx, keyMapping(0), rowTerm, singleType) + val term = singleType match { + case InternalTypes.FLOAT => s"Float.floatToIntBits($getCode)" + case InternalTypes.DOUBLE => s"Double.doubleToLongBits($getCode)" + case _ => getCode + } + s"return $term;" + } + + def genAnyNullsInKeys(keyMapping: Array[Int], rowTerm: String): (String, String) = { + val builder = new StringBuilder() + val anyNullTerm = newName("anyNull") + keyMapping.foreach(key => + builder.append(s"$anyNullTerm |= $rowTerm.isNullAt($key);") + ) + (s""" + |boolean $anyNullTerm = false; + |$builder + """.stripMargin, anyNullTerm) + } + + def genProjection(conf: TableConfig, types: Array[InternalType]): GeneratedProjection = { + val rowType = new RowType(types: _*) + ProjectionCodeGenerator.generateProjection( + CodeGeneratorContext.apply(conf), + "Projection", + rowType, + rowType, + types.indices.toArray) + } + + def gen( + conf: TableConfig, + hashJoinType: HashJoinType, + keyType: RowType, + buildType: RowType, + probeType: RowType, + buildKeyMapping: Array[Int], + probeKeyMapping: Array[Int], + managedMemorySize: Long, + preferredMemorySize: Long, + perRequestSize: Long, + buildRowSize: Int, + buildRowCount: Long, + reverseJoinFunction: Boolean, + condFunc: GeneratedJoinCondition): TwoInputOperatorWrapper[BaseRow, BaseRow, BaseRow] = { + + val buildSer = new BinaryRowSerializer(buildType.getArity) + val probeSer = new BinaryRowSerializer(probeType.getArity) + + val tableTerm = newName("LongHashTable") + val ctx = CodeGeneratorContext(conf) + val buildSerTerm = ctx.addReusableObject(buildSer, "buildSer") + val probeSerTerm = ctx.addReusableObject(probeSer, "probeSer") + + val bGenProj = genProjection(conf, buildType.getFieldTypes) + 
ctx.addReusableInnerClass(bGenProj.getClassName, bGenProj.getCode) + val pGenProj = genProjection(conf, probeType.getFieldTypes) + ctx.addReusableInnerClass(pGenProj.getClassName, pGenProj.getCode) + ctx.addReusableInnerClass(condFunc.getClassName, condFunc.getCode) + + ctx.addReusableMember(s"${bGenProj.getClassName} buildToBinaryRow;") + val buildProjRefs = ctx.addReusableObject(bGenProj.getReferences, "buildProjRefs") + ctx.addReusableInitStatement( + s"buildToBinaryRow = new ${bGenProj.getClassName}($buildProjRefs);") + + ctx.addReusableMember(s"${pGenProj.getClassName} probeToBinaryRow;") + val probeProjRefs = ctx.addReusableObject(pGenProj.getReferences, "probeProjRefs") + ctx.addReusableInitStatement( + s"probeToBinaryRow = new ${pGenProj.getClassName}($probeProjRefs);") + + ctx.addReusableMember(s"${condFunc.getClassName} condFunc;") + val condRefs = ctx.addReusableObject(condFunc.getReferences, "condRefs") + ctx.addReusableInitStatement(s"condFunc = new ${condFunc.getClassName}($condRefs);") + + val gauge = classOf[Gauge[_]].getCanonicalName + ctx.addReusableOpenStatement( + s""" + |getMetricGroup().gauge("memoryUsedSizeInBytes", new $gauge<Long>() { + | @Override + | public Long getValue() { + | return table.getUsedMemoryInBytes(); + | } + |}); + |getMetricGroup().gauge("numSpillFiles", new $gauge<Long>() { + | @Override + | public Long getValue() { + | return table.getNumSpillFiles(); + | } + |}); + |getMetricGroup().gauge("spillInBytes", new $gauge<Long>() { + | @Override + | public Long getValue() { + | return table.getSpillInBytes(); + | } + |}); + """.stripMargin) + + val tableCode = + s""" + |public class $tableTerm extends ${classOf[LongHybridHashTable].getCanonicalName} { + | + | public $tableTerm() { + | super(getContainingTask().getJobConfiguration(), getContainingTask(), + | $buildSerTerm, $probeSerTerm, + | getContainingTask().getEnvironment().getMemoryManager(), + | ${managedMemorySize}L, ${preferredMemorySize}L, ${perRequestSize}L, + | 
getContainingTask().getEnvironment().getIOManager(), + | $buildRowSize, + | ${buildRowCount}L / getRuntimeContext().getNumberOfParallelSubtasks()); + | } + | + | @Override + | public long getBuildLongKey($BASE_ROW row) { + | ${genGetLongKey(ctx, keyType, buildKeyMapping, "row")} + | } + | + | @Override + | public long getProbeLongKey($BASE_ROW row) { + | ${genGetLongKey(ctx, keyType, probeKeyMapping, "row")} + | } + | + | @Override + | public $BINARY_ROW probeToBinary($BASE_ROW row) { + | if (row instanceof $BINARY_ROW) { + | return ($BINARY_ROW) row; + | } else { + | return probeToBinaryRow.apply(row); + | } + | } + |} + """.stripMargin + ctx.addReusableInnerClass(tableTerm, tableCode) + + ctx.addReusableNullRow("buildSideNullRow", buildSer.getArity) + ctx.addReusableOutputRecord(new RowType(), classOf[JoinedRow], "joinedRow") + ctx.addReusableMember(s"$tableTerm table;") + ctx.addReusableOpenStatement(s"table = new $tableTerm();") + + val (nullCheckBuildCode, nullCheckBuildTerm) = genAnyNullsInKeys(buildKeyMapping, "row") + val (nullCheckProbeCode, nullCheckProbeTerm) = genAnyNullsInKeys(probeKeyMapping, "row") + + def collectCode(term1: String, term2: String) = + if (reverseJoinFunction) { + generateCollect(s"joinedRow.replace($term2, $term1)") + } else { + generateCollect(s"joinedRow.replace($term1, $term2)") + } + + val applyCond = + if (reverseJoinFunction) { + s"condFunc.apply(probeRow, buildIter.getRow())" + } else { + s"condFunc.apply(buildIter.getRow(), probeRow)" + } + + // innerJoin Now. 
+ val joinCode = hashJoinType match { + case HashJoinType.INNER => + s""" + |while (buildIter.advanceNext()) { + | if ($applyCond) { + | ${collectCode("buildIter.getRow()", "probeRow")} + | } + |} + """.stripMargin + case HashJoinType.SEMI => + s""" + |while (buildIter.advanceNext()) { + | if ($applyCond) { + | ${generateCollect("probeRow")} + | break; + | } + |} + """.stripMargin + case HashJoinType.ANTI => + s""" + |boolean matched = false; + |while (buildIter.advanceNext()) { + | if ($applyCond) { + | matched = true; + | break; + | } + |} + |if (!matched) { + | ${generateCollect("probeRow")} + |} + """.stripMargin + case HashJoinType.PROBE_OUTER => + s""" + |boolean matched = false; + |while (buildIter.advanceNext()) { + | if ($applyCond) { + | ${collectCode("buildIter.getRow()", "probeRow")} + | matched = true; + | } + |} + |if (!matched) { + | ${collectCode("buildSideNullRow", "probeRow")} + |} + """.stripMargin + } + + val nullOuterJoin = hashJoinType match { + case HashJoinType.ANTI => + s""" + |else { + | ${generateCollect("row")} + |} + """.stripMargin + case HashJoinType.PROBE_OUTER => + s""" + |else { + | ${collectCode("buildSideNullRow", "row")} + |} + """.stripMargin + case _ => "" + } + + ctx.addReusableMember( + s""" + |private void joinWithNextKey() throws Exception { + | ${classOf[LongHashPartition#MatchIterator].getCanonicalName} buildIter = + | table.getBuildSideIterator(); + | $BASE_ROW probeRow = table.getCurrentProbeRow(); + | if (probeRow == null) { + | throw new RuntimeException("ProbeRow should not be null"); + | } + | $joinCode + |} + """.stripMargin) + + ctx.addReusableCloseStatement( + s""" + |if (this.table != null) { + | this.table.close(); + | this.table.free(); + | this.table = null; + |} + """.stripMargin) + + val genOp = OperatorCodeGenerator.generateTwoInputStreamOperator[BaseRow, BaseRow, BaseRow]( + ctx, + "LongHashJoinOperator", + s""" + |$BASE_ROW row = ($BASE_ROW) element.getValue(); + |$nullCheckBuildCode + |if 
(!$nullCheckBuildTerm) { + | table.putBuildRow(row instanceof $BINARY_ROW ? + | ($BINARY_ROW) row : buildToBinaryRow.apply(row)); + |} + """.stripMargin, + s""" + |LOG.info("Finish build phase."); + |table.endBuild(); + """.stripMargin, + s""" + |$BASE_ROW row = ($BASE_ROW) element.getValue(); + |$nullCheckProbeCode + |if (!$nullCheckProbeTerm) { + | if (table.tryProbe(row)) { + | joinWithNextKey(); + | } + |} + |$nullOuterJoin + """.stripMargin, + s""" + |LOG.info("Finish probe phase."); + |while (this.table.nextMatching()) { + | joinWithNextKey(); + |} + |LOG.info("Finish rebuild phase."); + """.stripMargin, + buildType, + probeType) + + new TwoInputOperatorWrapper[BaseRow, BaseRow, BaseRow](genOp) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/OperatorCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/OperatorCodeGenerator.scala index cffcea5..8f7d088 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/OperatorCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/codegen/OperatorCodeGenerator.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.table.codegen -import org.apache.flink.streaming.api.operators.{OneInputStreamOperator, StreamOperator} +import org.apache.flink.streaming.api.operators.{OneInputStreamOperator, StreamOperator, TwoInputStreamOperator} import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.table.`type`.InternalType import org.apache.flink.table.api.TableConfig @@ -112,6 +112,85 @@ object OperatorCodeGenerator extends Logging { new GeneratedOperator(operatorName, operatorCode, ctx.references.toArray) } + def generateTwoInputStreamOperator[IN1 <: Any, IN2 <: Any, OUT <: Any]( + ctx: CodeGeneratorContext, + name: String, + processCode1: String, + endInputCode1: String, + processCode2: String, + 
endInputCode2: String, + input1Type: InternalType, + input2Type: InternalType, + input1Term: String = CodeGenUtils.DEFAULT_INPUT1_TERM, + input2Term: String = CodeGenUtils.DEFAULT_INPUT2_TERM, + useTimeCollect: Boolean = false) + : GeneratedOperator[TwoInputStreamOperator[IN1, IN2, OUT]] = { + addReuseOutElement(ctx) + val operatorName = newName(name) + val abstractBaseClass = ctx.getOperatorBaseClass + val baseClass = classOf[TwoInputStreamOperator[IN1, IN2, OUT]] + val inputTypeTerm1 = boxedTypeTermForType(input1Type) + val inputTypeTerm2 = boxedTypeTermForType(input2Type) + + val operatorCode = + j""" + public class $operatorName extends ${abstractBaseClass.getCanonicalName} + implements ${baseClass.getCanonicalName} { + + public static org.slf4j.Logger LOG = org.slf4j.LoggerFactory.getLogger("$operatorName"); + + private final Object[] references; + ${ctx.reuseMemberCode()} + + public $operatorName(Object[] references) throws Exception { + this.references = references; + ${ctx.reuseInitCode()} + } + + @Override + public void open() throws Exception { + super.open(); + ${ctx.reuseOpenCode()} + } + + @Override + public void processElement1($STREAM_RECORD $ELEMENT) + throws Exception { + ${ctx.reuseLocalVariableCode()} + $inputTypeTerm1 $input1Term = ${generateInputTerm(inputTypeTerm1)} + $processCode1 + } + + public void endInput1() throws Exception { + $endInputCode1 + } + + @Override + public void processElement2($STREAM_RECORD $ELEMENT) + throws Exception { + ${ctx.reuseLocalVariableCode()} + $inputTypeTerm2 $input2Term = ${generateInputTerm(inputTypeTerm2)} + $processCode2 + } + + public void endInput2() throws Exception { + $endInputCode2 + } + + @Override + public void close() throws Exception { + super.close(); + ${ctx.reuseCloseCode()} + } + + ${ctx.reuseInnerClassDefinitionCode()} + } + """.stripMargin + + LOG.debug(s"Compiling TwoInputStreamOperator Code:\n$name") + new GeneratedOperator(operatorName, operatorCode, ctx.references.toArray) + } + private 
def generateInputTerm(inputTypeTerm: String): String = { s"($inputTypeTerm) $ELEMENT.getValue();" } diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/codegen/LongHashJoinGeneratorTest.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/codegen/LongHashJoinGeneratorTest.java new file mode 100644 index 0000000..17453b2 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/codegen/LongHashJoinGeneratorTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.codegen; + +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.tasks.OperatorChain; +import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTaskTestHarness; +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.generated.GeneratedJoinCondition; +import org.apache.flink.table.generated.JoinCondition; +import org.apache.flink.table.runtime.TwoInputOperatorWrapper; +import org.apache.flink.table.runtime.join.HashJoinType; +import org.apache.flink.table.runtime.join.Int2HashJoinOperatorTest; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.type.RowType; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Test for {@link LongHashJoinGenerator}. + */ +public class LongHashJoinGeneratorTest extends Int2HashJoinOperatorTest { + + @Override + public StreamOperator newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) { + RowType keyType = new RowType(InternalTypes.INT); + Assert.assertTrue(LongHashJoinGenerator.support(type, keyType, new boolean[] {true})); + return LongHashJoinGenerator.gen( + new TableConfig(), type, + keyType, + new RowType(InternalTypes.INT, InternalTypes.INT), + new RowType(InternalTypes.INT, InternalTypes.INT), + new int[]{0}, + new int[]{0}, + memorySize, memorySize, 0, 20, 10000, + reverseJoinFunction, + new GeneratedJoinCondition(MyJoinCondition.class.getCanonicalName(), "", new Object[0]) + ); + } + + public void endInput1(TwoInputStreamTaskTestHarness harness) throws Exception { + TwoInputOperatorWrapper wrapper = (TwoInputOperatorWrapper) ((OperatorChain) harness.getTask().getStreamStatusMaintainer()) + .getHeadOperator(); + TwoInputStreamOperator op = wrapper.getOperator(); + op.getClass().getMethod("endInput1").invoke(op); + } + + 
public void endInput2(TwoInputStreamTaskTestHarness harness) throws Exception { + TwoInputOperatorWrapper wrapper = (TwoInputOperatorWrapper) ((OperatorChain) harness.getTask().getStreamStatusMaintainer()) + .getHeadOperator(); + TwoInputStreamOperator op = wrapper.getOperator(); + op.getClass().getMethod("endInput2").invoke(op); + } + + @Test + @Override + public void testBuildLeftSemiJoin() throws Exception {} + + @Test + @Override + public void testBuildSecondHashFullOutJoin() throws Exception {} + + @Test + @Override + public void testBuildSecondHashRightOutJoin() throws Exception {} + + @Test + @Override + public void testBuildLeftAntiJoin() throws Exception {} + + @Test + @Override + public void testBuildFirstHashLeftOutJoin() throws Exception {} + + @Test + @Override + public void testBuildFirstHashFullOutJoin() throws Exception {} + + /** + * Test cond. + */ + public static class MyJoinCondition implements JoinCondition { + + public MyJoinCondition(Object[] reference) {} + + @Override + public boolean apply(BaseRow in1, BaseRow in2) { + return true; + } + } +} diff --git a/flink-table/flink-table-runtime-blink/pom.xml b/flink-table/flink-table-runtime-blink/pom.xml index deb0849..1024fd1 100644 --- a/flink-table/flink-table-runtime-blink/pom.xml +++ b/flink-table/flink-table-runtime-blink/pom.xml @@ -161,6 +161,18 @@ under the License. 
</execution> </executions> </plugin> + + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + </executions> + </plugin> </plugins> </build> </project> diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedJoinCondition.java similarity index 68% copy from flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java copy to flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedJoinCondition.java index a0c2246..60ff809 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedJoinCondition.java @@ -19,20 +19,20 @@ package org.apache.flink.table.generated; /** - * Describes a generated {@link Projection}. + * Describes a generated {@link JoinCondition}. */ -public final class GeneratedProjection extends GeneratedClass<Projection> { +public class GeneratedJoinCondition extends GeneratedClass<JoinCondition> { private static final long serialVersionUID = 1L; /** - * Creates a GeneratedProjection. + * Creates a GeneratedJoinCondition. * - * @param className class name of the generated Function. - * @param code code of the generated Function. - * @param references referenced objects of the generated Function. + * @param className class name of the generated JoinCondition. + * @param code code of the generated JoinCondition. + * @param references referenced objects of the generated JoinCondition. 
*/ - public GeneratedProjection(String className, String code, Object[] references) { + public GeneratedJoinCondition(String className, String code, Object[] references) { super(className, code, references); } } diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java index a0c2246..d43b5f3 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/generated/GeneratedProjection.java @@ -21,7 +21,7 @@ package org.apache.flink.table.generated; /** * Describes a generated {@link Projection}. */ -public final class GeneratedProjection extends GeneratedClass<Projection> { +public class GeneratedProjection extends GeneratedClass<Projection> { private static final long serialVersionUID = 1L; diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/TwoInputOperatorWrapper.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/TwoInputOperatorWrapper.java new file mode 100644 index 0000000..2315585 --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/TwoInputOperatorWrapper.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.metrics.MetricGroup; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.CheckpointStreamFactory; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.StreamTask; +import org.apache.flink.table.generated.GeneratedClass; + +/** + * Wrapper for code gen operator. + * TODO Remove it after FLINK-11974. 
+ */ +public class TwoInputOperatorWrapper<IN1, IN2, OUT> + implements TwoInputStreamOperator<IN1, IN2, OUT> { + + private final GeneratedClass<TwoInputStreamOperator<IN1, IN2, OUT>> generatedClass; + + private transient TwoInputStreamOperator<IN1, IN2, OUT> operator; + + public TwoInputOperatorWrapper(GeneratedClass<TwoInputStreamOperator<IN1, IN2, OUT>> generatedClass) { + this.generatedClass = generatedClass; + } + + @Override + public void setup(StreamTask<?, ?> containingTask, StreamConfig config, + Output<StreamRecord<OUT>> output) { + operator = generatedClass.newInstance(containingTask.getUserCodeClassLoader()); + operator.setup(containingTask, config, output); + } + + @VisibleForTesting + public TwoInputStreamOperator<IN1, IN2, OUT> getOperator() { + return operator; + } + + @Override + public void open() throws Exception { + operator.open(); + } + + @Override + public void close() throws Exception { + operator.close(); + } + + @Override + public void dispose() throws Exception { + operator.dispose(); + } + + @Override + public void prepareSnapshotPreBarrier(long checkpointId) throws Exception { + operator.prepareSnapshotPreBarrier(checkpointId); + } + + @Override + public OperatorSnapshotFutures snapshotState(long checkpointId, long timestamp, + CheckpointOptions checkpointOptions, + CheckpointStreamFactory storageLocation) throws Exception { + return operator.snapshotState(checkpointId, timestamp, checkpointOptions, storageLocation); + } + + @Override + public void initializeState() throws Exception { + operator.initializeState(); + } + + @Override + public void setKeyContextElement1(StreamRecord<?> record) throws Exception { + operator.setKeyContextElement1(record); + } + + @Override + public void setKeyContextElement2(StreamRecord<?> record) throws Exception { + operator.setKeyContextElement2(record); + } + + @Override + public ChainingStrategy getChainingStrategy() { + return ChainingStrategy.ALWAYS; + } + + @Override + public void 
setChainingStrategy(ChainingStrategy strategy) { + operator.setChainingStrategy(strategy); + } + + @Override + public MetricGroup getMetricGroup() { + return operator.getMetricGroup(); + } + + @Override + public OperatorID getOperatorID() { + return operator.getOperatorID(); + } + + @Override + public void processElement1(StreamRecord<IN1> element) throws Exception { + operator.processElement1(element); + } + + @Override + public void processElement2(StreamRecord<IN2> element) throws Exception { + operator.processElement2(element); + } + + @Override + public void processWatermark1(Watermark mark) throws Exception { + operator.processWatermark1(mark); + } + + @Override + public void processWatermark2(Watermark mark) throws Exception { + operator.processWatermark2(mark); + } + + @Override + public void processLatencyMarker1(LatencyMarker latencyMarker) throws Exception { + operator.processLatencyMarker1(latencyMarker); + } + + @Override + public void processLatencyMarker2(LatencyMarker latencyMarker) throws Exception { + operator.processLatencyMarker2(latencyMarker); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + operator.notifyCheckpointComplete(checkpointId); + } + + @Override + public void setCurrentKey(Object key) { + operator.setCurrentKey(key); + } + + @Override + public Object getCurrentKey() { + return operator.getCurrentKey(); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/HashJoinOperator.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/HashJoinOperator.java new file mode 100644 index 0000000..932a1fa --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/HashJoinOperator.java @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join; + +import org.apache.flink.configuration.AlgorithmOptions; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.dataformat.GenericRow; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.generated.GeneratedJoinCondition; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.runtime.TableStreamOperator; +import org.apache.flink.table.runtime.hashtable.BinaryHashTable; +import org.apache.flink.table.runtime.util.RowIterator; +import org.apache.flink.table.runtime.util.StreamRecordCollector; +import org.apache.flink.table.type.RowType; +import org.apache.flink.table.typeutils.AbstractRowSerializer; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; + +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Hash join base operator. 
+ * + * <p>The join operator implements the logic of a join operator at runtime. It uses a + * hybrid-hash-join internally to match the records with equal keys. The build side of the hash + * is the first input of the match. It supports all join types in {@link HashJoinType}. + */ +public abstract class HashJoinOperator extends TableStreamOperator<BaseRow> + implements TwoInputStreamOperator<BaseRow, BaseRow, BaseRow> { + + private static final Logger LOG = LoggerFactory.getLogger(HashJoinOperator.class); + + private final HashJoinParameter parameter; + private final boolean reverseJoinFunction; + private final HashJoinType type; + + private transient BinaryHashTable table; + transient Collector<BaseRow> collector; + + transient BaseRow buildSideNullRow; + private transient BaseRow probeSideNullRow; + private transient JoinedRow joinedRow; + private transient boolean buildEnd; + + HashJoinOperator(HashJoinParameter parameter) { + this.parameter = parameter; + this.type = parameter.type; + this.reverseJoinFunction = parameter.reverseJoinFunction; + } + + @Override + public void open() throws Exception { + super.open(); + + ClassLoader cl = getContainingTask().getUserCodeClassLoader(); + + final AbstractRowSerializer buildSerializer = (AbstractRowSerializer) getOperatorConfig() + .getTypeSerializerIn1(getUserCodeClassloader()); + final AbstractRowSerializer probeSerializer = (AbstractRowSerializer) getOperatorConfig() + .getTypeSerializerIn2(getUserCodeClassloader()); + + boolean hashJoinUseBitMaps = getContainingTask().getEnvironment().getTaskConfiguration() + .getBoolean(AlgorithmOptions.HASH_JOIN_BLOOM_FILTERS); + + int parallel = getRuntimeContext().getNumberOfParallelSubtasks(); + + this.table = new BinaryHashTable( + getContainingTask().getJobConfiguration(), + getContainingTask(), + buildSerializer, probeSerializer, + parameter.buildProjectionCode.newInstance(cl), + parameter.probeProjectionCode.newInstance(cl), +
getContainingTask().getEnvironment().getMemoryManager(), + parameter.reservedMemorySize, + parameter.maxMemorySize, + parameter.perRequestMemorySize, + getContainingTask().getEnvironment().getIOManager(), + parameter.buildRowSize, + parameter.buildRowCount / parallel, + hashJoinUseBitMaps, + type, + parameter.condFuncCode.newInstance(cl), + reverseJoinFunction, + parameter.filterNullKeys, + parameter.tryDistinctBuildRow); + + this.collector = new StreamRecordCollector<>(output); + + this.buildSideNullRow = new GenericRow(buildSerializer.getArity()); + this.probeSideNullRow = new GenericRow(probeSerializer.getArity()); + this.joinedRow = new JoinedRow(); + this.buildEnd = false; + + getMetricGroup().gauge("memoryUsedSizeInBytes", table::getUsedMemoryInBytes); + getMetricGroup().gauge("numSpillFiles", table::getNumSpillFiles); + getMetricGroup().gauge("spillInBytes", table::getSpillInBytes); + + parameter.condFuncCode = null; + parameter.buildProjectionCode = null; + parameter.probeProjectionCode = null; + } + + @Override + public void processElement1(StreamRecord<BaseRow> element) throws Exception { + checkState(!buildEnd, "Should not build ended."); + this.table.putBuildRow(element.getValue()); + } + + @Override + public void processElement2(StreamRecord<BaseRow> element) throws Exception { + checkState(buildEnd, "Should build ended."); + if (this.table.tryProbe(element.getValue())) { + joinWithNextKey(); + } + } + + public void endInput1() throws Exception { + checkState(!buildEnd, "Should not build ended."); + LOG.info("Finish build phase."); + buildEnd = true; + this.table.endBuild(); + } + + public void endInput2() throws Exception { + checkState(buildEnd, "Should build ended."); + LOG.info("Finish probe phase."); + while (this.table.nextMatching()) { + joinWithNextKey(); + } + LOG.info("Finish rebuild phase."); + } + + private void joinWithNextKey() throws Exception { + // we have a next record, get the iterators to the probe and build side values + 
join(table.getBuildSideIterator(), table.getCurrentProbeRow()); + } + + public abstract void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception; + + void innerJoin(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + collect(buildIter.getRow(), probeRow); + while (buildIter.advanceNext()) { + collect(buildIter.getRow(), probeRow); + } + } + + void buildOuterJoin(RowIterator<BinaryRow> buildIter) throws Exception { + collect(buildIter.getRow(), probeSideNullRow); + while (buildIter.advanceNext()) { + collect(buildIter.getRow(), probeSideNullRow); + } + } + + void collect(BaseRow row1, BaseRow row2) throws Exception { + if (reverseJoinFunction) { + collector.collect(joinedRow.replace(row2, row1)); + } else { + collector.collect(joinedRow.replace(row1, row2)); + } + } + + @Override + public void close() throws Exception { + super.close(); + if (this.table != null) { + this.table.close(); + this.table.free(); + this.table = null; + } + } + + public static HashJoinOperator newHashJoinOperator( + long minMemorySize, + long maxMemorySize, + long eachRequestMemorySize, + HashJoinType type, + GeneratedJoinCondition condFuncCode, + boolean reverseJoinFunction, + boolean[] filterNullKeys, + GeneratedProjection buildProjectionCode, + GeneratedProjection probeProjectionCode, + boolean tryDistinctBuildRow, + int buildRowSize, + long buildRowCount, + long probeRowCount, + RowType keyType) { + HashJoinParameter parameter = new HashJoinParameter(minMemorySize, maxMemorySize, eachRequestMemorySize, + type, condFuncCode, reverseJoinFunction, filterNullKeys, buildProjectionCode, probeProjectionCode, + tryDistinctBuildRow, buildRowSize, buildRowCount, probeRowCount, keyType); + switch (type) { + case INNER: + return new InnerHashJoinOperator(parameter); + case BUILD_OUTER: + return new BuildOuterHashJoinOperator(parameter); + case PROBE_OUTER: + return new ProbeOuterHashJoinOperator(parameter); + case FULL_OUTER: + return new 
FullOuterHashJoinOperator(parameter); + case SEMI: + return new SemiHashJoinOperator(parameter); + case ANTI: + return new AntiHashJoinOperator(parameter); + case BUILD_LEFT_SEMI: + case BUILD_LEFT_ANTI: + return new BuildLeftSemiOrAntiHashJoinOperator(parameter); + default: + throw new IllegalArgumentException("invalid: " + type); + } + } + + static class HashJoinParameter implements Serializable { + long reservedMemorySize; + long maxMemorySize; + long perRequestMemorySize; + HashJoinType type; + GeneratedJoinCondition condFuncCode; + boolean reverseJoinFunction; + boolean[] filterNullKeys; + GeneratedProjection buildProjectionCode; + GeneratedProjection probeProjectionCode; + boolean tryDistinctBuildRow; + int buildRowSize; + long buildRowCount; + long probeRowCount; + RowType keyType; + + HashJoinParameter( + long reservedMemorySize, long maxMemorySize, long perRequestMemorySize, HashJoinType type, + GeneratedJoinCondition condFuncCode, boolean reverseJoinFunction, + boolean[] filterNullKeys, + GeneratedProjection buildProjectionCode, + GeneratedProjection probeProjectionCode, boolean tryDistinctBuildRow, + int buildRowSize, long buildRowCount, long probeRowCount, RowType keyType) { + this.reservedMemorySize = reservedMemorySize; + this.maxMemorySize = maxMemorySize; + this.perRequestMemorySize = perRequestMemorySize; + this.type = type; + this.condFuncCode = condFuncCode; + this.reverseJoinFunction = reverseJoinFunction; + this.filterNullKeys = filterNullKeys; + this.buildProjectionCode = buildProjectionCode; + this.probeProjectionCode = probeProjectionCode; + this.tryDistinctBuildRow = tryDistinctBuildRow; + this.buildRowSize = buildRowSize; + this.buildRowCount = buildRowCount; + this.probeRowCount = probeRowCount; + this.keyType = keyType; + } + } + + /** + * Inner join. + * Output assembled {@link JoinedRow} when build side row matched probe side row. 
+ */ + private static class InnerHashJoinOperator extends HashJoinOperator { + + InnerHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + if (buildIter.advanceNext()) { + if (probeRow != null) { + innerJoin(buildIter, probeRow); + } + } + } + } + + /** + * BuildOuter join. + * Output assembled {@link JoinedRow} when build side row matched probe side row. + * And if there is no match in the probe table, output {@link JoinedRow} assembled by + * build side row and nulls. + */ + private static class BuildOuterHashJoinOperator extends HashJoinOperator { + + BuildOuterHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + if (buildIter.advanceNext()) { + if (probeRow != null) { + innerJoin(buildIter, probeRow); + } else { + buildOuterJoin(buildIter); + } + } + } + } + + /** + * ProbeOuter join. + * Output assembled {@link JoinedRow} when probe side row matched build side row. + * And if there is no match in the build table, output {@link JoinedRow} assembled by + * nulls and probe side row. + */ + private static class ProbeOuterHashJoinOperator extends HashJoinOperator { + + ProbeOuterHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + if (buildIter.advanceNext()) { + if (probeRow != null) { + innerJoin(buildIter, probeRow); + } + } else if (probeRow != null) { + collect(buildSideNullRow, probeRow); + } + } + } + + /** + * FullOuter join. + * Output assembled {@link JoinedRow} when build side row matched probe side row. + * And if there is no match, output {@link JoinedRow} assembled by build side row and nulls or + * nulls and probe side row.
+ */ + private static class FullOuterHashJoinOperator extends HashJoinOperator { + + FullOuterHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + if (buildIter.advanceNext()) { + if (probeRow != null) { + innerJoin(buildIter, probeRow); + } else { + buildOuterJoin(buildIter); + } + } else if (probeRow != null) { + collect(buildSideNullRow, probeRow); + } + } + } + + /** + * Semi join. + * Output probe side row when probe side row matched build side row. + */ + private static class SemiHashJoinOperator extends HashJoinOperator { + + SemiHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + checkNotNull(probeRow); + if (buildIter.advanceNext()) { + collector.collect(probeRow); + } + } + } + + /** + * Anti join. + * Output probe side row when probe side row not matched build side row. + */ + private static class AntiHashJoinOperator extends HashJoinOperator { + + AntiHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + checkNotNull(probeRow); + if (!buildIter.advanceNext()) { + collector.collect(probeRow); + } + } + } + + /** + * BuildLeftSemiOrAnti join. + * BuildLeftSemiJoin: Output build side row when build side row matched probe side row. + * BuildLeftAntiJoin: Output build side row when build side row not matched probe side row. 
+ */ + private static class BuildLeftSemiOrAntiHashJoinOperator extends HashJoinOperator { + + BuildLeftSemiOrAntiHashJoinOperator(HashJoinParameter parameter) { + super(parameter); + } + + @Override + public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception { + if (buildIter.advanceNext()) { + if (probeRow != null) { //Probe phase + // we must iterator to set probedSet. + while (buildIter.advanceNext()) {} + } else { //End Probe phase, iterator build side elements. + collector.collect(buildIter.getRow()); + while (buildIter.advanceNext()) { + collector.collect(buildIter.getRow()); + } + } + } + } + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/Int2HashJoinOperatorTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/Int2HashJoinOperatorTest.java new file mode 100644 index 0000000..7d38b4d --- /dev/null +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/Int2HashJoinOperatorTest.java @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.join; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorChain; +import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask; +import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTaskTestHarness; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.dataformat.BinaryRowWriter; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.generated.GeneratedJoinCondition; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.generated.JoinCondition; +import org.apache.flink.table.generated.Projection; +import org.apache.flink.table.runtime.util.UniformBinaryRowGenerator; +import org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.type.RowType; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; +import org.apache.flink.util.MutableObjectIterator; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import java.util.Queue; + +import static java.lang.Long.valueOf; + +/** + * Random test for {@link HashJoinOperator}. 
+ */ +public class Int2HashJoinOperatorTest implements Serializable { + + //---------------------- build first inner join ----------------------------------------- + @Test + public void testBuildFirstHashInnerJoin() throws Exception { + + int numKeys = 100; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + + // create a build input that gives 300 pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false); + // create a probe input that gives 1000 pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, false, false, true, numKeys * buildValsPerKey * probeValsPerKey, + numKeys, 165); + } + + //---------------------- build first left out join ----------------------------------------- + @Test + public void testBuildFirstHashLeftOutJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, true, false, true, numKeys1 * buildValsPerKey * probeValsPerKey, + numKeys1, 165); + } + + //---------------------- build first right out join ----------------------------------------- + @Test + public void testBuildFirstHashRightOutJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + 
MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, false, true, true, 280, numKeys2, -1); + } + + //---------------------- build first full out join ----------------------------------------- + @Test + public void testBuildFirstHashFullOutJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, true, true, true, 280, numKeys2, -1); + } + + //---------------------- build second inner join ----------------------------------------- + @Test + public void testBuildSecondHashInnerJoin() throws Exception { + + int numKeys = 100; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 300 pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys, buildValsPerKey, false); + + // create a probe input that gives 1000 pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, false, false, false, numKeys * buildValsPerKey * probeValsPerKey, + numKeys, 165); + } + + //---------------------- build second left out join ----------------------------------------- + @Test + 
public void testBuildSecondHashLeftOutJoin() throws Exception { + + int numKeys1 = 10; + int numKeys2 = 9; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, true, false, false, numKeys2 * buildValsPerKey * probeValsPerKey, + numKeys2, 165); + } + + //---------------------- build second right out join ----------------------------------------- + @Test + public void testBuildSecondHashRightOutJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, false, true, false, + numKeys1 * buildValsPerKey * probeValsPerKey, numKeys2, -1); + } + + //---------------------- build second full out join ----------------------------------------- + @Test + public void testBuildSecondHashFullOutJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million 
pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + buildJoin(buildInput, probeInput, true, true, false, 280, numKeys2, -1); + } + + @Test + public void testSemiJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + HashJoinType type = HashJoinType.SEMI; + StreamOperator operator = newOperator(33 * 32 * 1024, type, false); + joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true); + } + + @Test + public void testAntiJoin() throws Exception { + + int numKeys1 = 9; + int numKeys2 = 10; + int buildValsPerKey = 3; + int probeValsPerKey = 10; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + HashJoinType type = HashJoinType.ANTI; + StreamOperator operator = newOperator(33 * 32 * 1024, type, false); + joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true); + } + + @Test + public void testBuildLeftSemiJoin() throws Exception { + + int numKeys1 = 10; + int numKeys2 = 9; + int buildValsPerKey = 10; + int probeValsPerKey = 3; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> 
buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + HashJoinType type = HashJoinType.BUILD_LEFT_SEMI; + StreamOperator operator = newOperator(33 * 32 * 1024, type, false); + joinAndAssert(operator, buildInput, probeInput, 90, 9, 45, true); + } + + @Test + public void testBuildLeftAntiJoin() throws Exception { + + int numKeys1 = 10; + int numKeys2 = 9; + int buildValsPerKey = 10; + int probeValsPerKey = 3; + // create a build input that gives 3 million pairs with 3 values sharing the same key + MutableObjectIterator<BinaryRow> buildInput = new UniformBinaryRowGenerator(numKeys1, buildValsPerKey, true); + + // create a probe input that gives 10 million pairs with 10 values sharing a key + MutableObjectIterator<BinaryRow> probeInput = new UniformBinaryRowGenerator(numKeys2, probeValsPerKey, true); + + HashJoinType type = HashJoinType.BUILD_LEFT_ANTI; + StreamOperator operator = newOperator(33 * 32 * 1024, type, false); + joinAndAssert(operator, buildInput, probeInput, 10, 1, 45, true); + } + + private void buildJoin( + MutableObjectIterator<BinaryRow> buildInput, + MutableObjectIterator<BinaryRow> probeInput, + boolean leftOut, boolean rightOut, boolean buildLeft, + int expectOutSize, int expectOutKeySize, int expectOutVal) throws Exception { + HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut); + StreamOperator operator = newOperator(33 * 32 * 1024, type, !buildLeft); + joinAndAssert(operator, buildInput, probeInput, expectOutSize, expectOutKeySize, expectOutVal, false); + } + + private void joinAndAssert( + StreamOperator operator, + MutableObjectIterator<BinaryRow> input1, + MutableObjectIterator<BinaryRow> input2, + int expectOutSize, + int expectOutKeySize, + int expectOutVal, + boolean semiJoin) throws Exception { + 
BaseRowTypeInfo typeInfo = new BaseRowTypeInfo(InternalTypes.INT, InternalTypes.INT); + BaseRowTypeInfo baseRowType = new BaseRowTypeInfo( + InternalTypes.INT, InternalTypes.INT, InternalTypes.INT, InternalTypes.INT); + TwoInputStreamTaskTestHarness<BinaryRow, BinaryRow, JoinedRow> testHarness = + new TwoInputStreamTaskTestHarness<>(TwoInputStreamTask::new, + 2, 1, new int[]{1, 2}, typeInfo, (TypeInformation) typeInfo, baseRowType); + testHarness.memorySize = 36 * 1024 * 1024; + testHarness.getExecutionConfig().enableObjectReuse(); + testHarness.setupOutputForSingletonOperatorChain(); + testHarness.getStreamConfig().setStreamOperator(operator); + testHarness.getStreamConfig().setOperatorID(new OperatorID()); + + testHarness.invoke(); + testHarness.waitForTaskRunning(); + + BinaryRow row1; + while ((row1 = input1.next()) != null) { + testHarness.processElement(new StreamRecord<>(row1), 0, 0); + } + testHarness.waitForInputProcessing(); + endInput1(testHarness); + + BinaryRow row2; + while ((row2 = input2.next()) != null) { + testHarness.processElement(new StreamRecord<>(row2), 1, 0); + } + testHarness.waitForInputProcessing(); + endInput2(testHarness); + + testHarness.endInput(); + testHarness.waitForInputProcessing(); + testHarness.waitForTaskCompletion(); + + Queue<Object> actual = testHarness.getOutput(); + + Assert.assertEquals("Output was not correct.", expectOutSize, actual.size()); + + // Don't verify the output value when experOutVal is -1 + if (expectOutVal != -1) { + if (semiJoin) { + HashMap<Integer, Long> map = new HashMap<>(expectOutKeySize); + + for (Object o : actual) { + StreamRecord<BaseRow> record = (StreamRecord<BaseRow>) o; + BaseRow row = record.getValue(); + int key = row.getInt(0); + int val = row.getInt(1); + Long contained = map.get(key); + if (contained == null) { + contained = (long) val; + } else { + contained = valueOf(contained + val); + } + map.put(key, contained); + } + + Assert.assertEquals("Wrong number of keys", expectOutKeySize, 
map.size()); + for (Map.Entry<Integer, Long> entry : map.entrySet()) { + long val = entry.getValue(); + int key = entry.getKey(); + + Assert.assertEquals("Wrong number of values in per-key cross product for key " + key, + expectOutVal, val); + } + } else { + // create the map for validating the results + HashMap<Integer, Long> map = new HashMap<>(expectOutKeySize); + + for (Object o : actual) { + StreamRecord<BaseRow> record = (StreamRecord<BaseRow>) o; + BaseRow row = record.getValue(); + int key = row.isNullAt(0) ? row.getInt(2) : row.getInt(0); + + int val1 = 0; + int val2 = 0; + if (!row.isNullAt(1)) { + val1 = row.getInt(1); + } + if (!row.isNullAt(3)) { + val2 = row.getInt(3); + } + int val = val1 + val2; + + Long contained = map.get(key); + if (contained == null) { + contained = (long) val; + } else { + contained = valueOf(contained + val); + } + map.put(key, contained); + } + + Assert.assertEquals("Wrong number of keys", expectOutKeySize, map.size()); + for (Map.Entry<Integer, Long> entry : map.entrySet()) { + long val = entry.getValue(); + int key = entry.getKey(); + + Assert.assertEquals("Wrong number of values in per-key cross product for key " + key, + expectOutVal, val); + } + } + } + } + + /** + * my projection. 
+ */ + public static final class MyProjection implements Projection<BaseRow, BinaryRow> { + + BinaryRow innerRow = new BinaryRow(1); + BinaryRowWriter writer = new BinaryRowWriter(innerRow); + + @Override + public BinaryRow apply(BaseRow row) { + writer.reset(); + if (row.isNullAt(0)) { + writer.setNullAt(0); + } else { + writer.writeInt(0, row.getInt(0)); + } + writer.complete(); + return innerRow; + } + } + + public void endInput1(TwoInputStreamTaskTestHarness harness) throws Exception { + HashJoinOperator op = (HashJoinOperator) ((OperatorChain) harness.getTask().getStreamStatusMaintainer()) + .getHeadOperator(); + op.endInput1(); + } + + public void endInput2(TwoInputStreamTaskTestHarness harness) throws Exception { + HashJoinOperator op = (HashJoinOperator) ((OperatorChain) harness.getTask().getStreamStatusMaintainer()) + .getHeadOperator(); + op.endInput2(); + } + + public StreamOperator newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) { + return HashJoinOperator.newHashJoinOperator( + memorySize, memorySize, 0, type, + new GeneratedJoinCondition("", "", new Object[0]) { + @Override + public JoinCondition newInstance(ClassLoader classLoader) { + return (in1, in2) -> true; + } + }, + reverseJoinFunction, new boolean[]{true}, + new GeneratedProjection("", "", new Object[0]) { + @Override + public Projection newInstance(ClassLoader classLoader) { + return new MyProjection(); + } + }, + new GeneratedProjection("", "", new Object[0]) { + @Override + public Projection newInstance(ClassLoader classLoader) { + return new MyProjection(); + } + }, + false, 20, 10000, + 10000, new RowType(InternalTypes.INT)); + } +} diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/String2HashJoinOperatorTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/String2HashJoinOperatorTest.java new file mode 100644 index 0000000..fcb71d6 --- /dev/null +++ 
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/join/String2HashJoinOperatorTest.java @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.join; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OperatorChain; +import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask; +import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTaskTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; +import org.apache.flink.table.dataformat.BaseRow; +import org.apache.flink.table.dataformat.BinaryRow; +import org.apache.flink.table.dataformat.BinaryRowWriter; +import org.apache.flink.table.dataformat.BinaryString; +import org.apache.flink.table.dataformat.JoinedRow; +import org.apache.flink.table.generated.GeneratedJoinCondition; +import org.apache.flink.table.generated.GeneratedProjection; +import org.apache.flink.table.generated.JoinCondition; +import org.apache.flink.table.generated.Projection; +import 
org.apache.flink.table.type.InternalTypes; +import org.apache.flink.table.type.RowType; +import org.apache.flink.table.typeutils.BaseRowTypeInfo; + +import org.junit.Test; + +import java.io.Serializable; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Test for {@link HashJoinOperator}. + */ +public class String2HashJoinOperatorTest implements Serializable { + + private BaseRowTypeInfo typeInfo = new BaseRowTypeInfo(InternalTypes.STRING, InternalTypes.STRING); + private BaseRowTypeInfo joinedInfo = new BaseRowTypeInfo( + InternalTypes.STRING, InternalTypes.STRING, InternalTypes.STRING, InternalTypes.STRING); + private transient TwoInputStreamTaskTestHarness<BinaryRow, BinaryRow, JoinedRow> testHarness; + private ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>(); + private long initialTime = 0L; + + public static LinkedBlockingQueue<Object> transformToBinary(LinkedBlockingQueue<Object> output) { + LinkedBlockingQueue<Object> ret = new LinkedBlockingQueue<>(); + for (Object o : output) { + BaseRow row = ((StreamRecord<BaseRow>) o).getValue(); + BinaryRow binaryRow; + if (row.isNullAt(0)) { + binaryRow = newRow(row.getString(2).toString(), row.getString(3) + "null"); + } else if (row.isNullAt(2)) { + binaryRow = newRow(row.getString(0).toString(), row.getString(1) + "null"); + } else { + String value1 = row.getString(1).toString(); + String value2 = row.getString(3).toString(); + binaryRow = newRow(row.getString(0).toString(), value1 + value2); + } + ret.add(new StreamRecord(binaryRow)); + } + return ret; + } + + private void init(boolean leftOut, boolean rightOut, boolean buildLeft) throws Exception { + HashJoinType type = HashJoinType.of(buildLeft, leftOut, rightOut); + HashJoinOperator operator = newOperator(33 * 32 * 1024, type, !buildLeft); + testHarness = new TwoInputStreamTaskTestHarness<>( + TwoInputStreamTask::new, 2, 2, new int[]{1, 2}, typeInfo, (TypeInformation) 
typeInfo, joinedInfo); + testHarness.memorySize = 36 * 1024 * 1024; + testHarness.getExecutionConfig().enableObjectReuse(); + testHarness.setupOutputForSingletonOperatorChain(); + testHarness.getStreamConfig().setStreamOperator(operator); + testHarness.getStreamConfig().setOperatorID(new OperatorID()); + + testHarness.invoke(); + testHarness.waitForTaskRunning(); + } + + private void endInput1() throws Exception { + endInput1(testHarness); + } + + private void endInput2() throws Exception { + endInput2(testHarness); + } + + static void endInput1(TwoInputStreamTaskTestHarness harness) throws Exception { + HashJoinOperator op = + (HashJoinOperator) ((OperatorChain) harness.getTask().getStreamStatusMaintainer()) + .getHeadOperator(); + op.endInput1(); + } + + static void endInput2(TwoInputStreamTaskTestHarness harness) throws Exception { + HashJoinOperator op = + (HashJoinOperator) ((OperatorChain) harness.getTask().getStreamStatusMaintainer()) + .getHeadOperator(); + op.endInput2(); + } + + @Test + public void testInnerHashJoin() throws Exception { + + init(false, false, true); + + testHarness.processElement(new StreamRecord<>(newRow("a", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("d", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("b", "1"), initialTime), 0, 1); + + testHarness.waitForInputProcessing(); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, + transformToBinary(testHarness.getOutput())); + + endInput1(); + testHarness.processElement(new StreamRecord<>(newRow("a", "2"), initialTime), 1, 1); + + testHarness.waitForInputProcessing(); + expectedOutput.add(new StreamRecord<>(newRow("a", "02"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, + transformToBinary(testHarness.getOutput())); + + testHarness.processElement(new StreamRecord<>(newRow("c", "2"), initialTime), 1, 1); + 
testHarness.processElement(new StreamRecord<>(newRow("b", "4"), initialTime), 1, 0); + expectedOutput.add(new StreamRecord<>(newRow("b", "14"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, + transformToBinary(testHarness.getOutput())); + + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, + transformToBinary(testHarness.getOutput())); + } + + @Test + public void testProbeOuterHashJoin() throws Exception { + + init(true, false, false); + + testHarness.processElement(new StreamRecord<>(newRow("a", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("d", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("b", "1"), initialTime), 0, 1); + + testHarness.waitForInputProcessing(); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + endInput1(); + testHarness.processElement(new StreamRecord<>(newRow("a", "2"), initialTime), 1, 1); + testHarness.waitForInputProcessing(); + + expectedOutput.add(new StreamRecord<>(newRow("a", "20"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + testHarness.processElement(new StreamRecord<>(newRow("c", "2"), initialTime), 1, 1); + testHarness.processElement(new StreamRecord<>(newRow("b", "4"), initialTime), 1, 0); + expectedOutput.add(new StreamRecord<>(newRow("c", "2null"))); + expectedOutput.add(new StreamRecord<>(newRow("b", "41"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + 
TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + } + + @Test + public void testBuildOuterHashJoin() throws Exception { + + init(false, true, false); + + testHarness.processElement(new StreamRecord<>(newRow("a", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("d", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("b", "1"), initialTime), 0, 1); + + testHarness.waitForInputProcessing(); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + endInput1(); + testHarness.processElement(new StreamRecord<>(newRow("a", "2"), initialTime), 1, 1); + testHarness.waitForInputProcessing(); + + expectedOutput.add(new StreamRecord<>(newRow("a", "20"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + testHarness.processElement(new StreamRecord<>(newRow("c", "2"), initialTime), 1, 1); + testHarness.processElement(new StreamRecord<>(newRow("b", "4"), initialTime), 1, 0); + expectedOutput.add(new StreamRecord<>(newRow("b", "41"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + endInput2(); + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + expectedOutput.add(new StreamRecord<>(newRow("d", "0null"))); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + } + + @Test + public void testFullOuterHashJoin() throws Exception { + + init(true, true, true); + + testHarness.processElement(new StreamRecord<>(newRow("a", "0"), initialTime), 0, 0); + testHarness.processElement(new StreamRecord<>(newRow("d", "0"), initialTime), 0, 0); + 
testHarness.processElement(new StreamRecord<>(newRow("b", "1"), initialTime), 0, 1); + + testHarness.waitForInputProcessing(); + + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + endInput1(); + testHarness.processElement(new StreamRecord<>(newRow("a", "2"), initialTime), 1, 1); + testHarness.waitForInputProcessing(); + + expectedOutput.add(new StreamRecord<>(newRow("a", "02"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + testHarness.processElement(new StreamRecord<>(newRow("c", "2"), initialTime), 1, 1); + testHarness.processElement(new StreamRecord<>(newRow("b", "4"), initialTime), 1, 0); + expectedOutput.add(new StreamRecord<>(newRow("c", "2null"))); + expectedOutput.add(new StreamRecord<>(newRow("b", "14"))); + testHarness.waitForInputProcessing(); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + + endInput2(); + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + expectedOutput.add(new StreamRecord<>(newRow("d", "0null"))); + TestHarnessUtil.assertOutputEquals("Output was not correct.", + expectedOutput, + transformToBinary(testHarness.getOutput())); + } + + /** + * my project. + */ + public static final class MyProjection implements Projection<BinaryRow, BinaryRow> { + + BinaryRow innerRow = new BinaryRow(1); + BinaryRowWriter writer = new BinaryRowWriter(innerRow); + + @Override + public BinaryRow apply(BinaryRow row) { + writer.reset(); + writer.writeString(0, row.getString(0)); + writer.complete(); + return innerRow; + } + } + + public static BinaryRow newRow(String... 
s) { + BinaryRow row = new BinaryRow(s.length); + BinaryRowWriter writer = new BinaryRowWriter(row); + + for (int i = 0; i < s.length; i++) { + if (s[i] == null) { + writer.setNullAt(i); + } else { + writer.writeString(i, BinaryString.fromString(s[i])); + } + } + + writer.complete(); + return row; + } + + private HashJoinOperator newOperator(long memorySize, HashJoinType type, boolean reverseJoinFunction) { + return HashJoinOperator.newHashJoinOperator( + memorySize, memorySize, 0, type, + new GeneratedJoinCondition("", "", new Object[0]) { + @Override + public JoinCondition newInstance(ClassLoader classLoader) { + return (in1, in2) -> true; + } + }, + reverseJoinFunction, new boolean[]{true}, + new GeneratedProjection("", "", new Object[0]) { + @Override + public Projection newInstance(ClassLoader classLoader) { + return new MyProjection(); + } + }, + new GeneratedProjection("", "", new Object[0]) { + @Override + public Projection newInstance(ClassLoader classLoader) { + return new MyProjection(); + } + }, + false, 20, 10000, + 10000, new RowType(InternalTypes.STRING)); + } +}