GitHub user aljoscha commented on a diff in the pull request:

    https://github.com/apache/flink/pull/5230#discussion_r159223446
  
    --- Diff: flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java ---
    @@ -0,0 +1,734 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + * <p>
    + * http://www.apache.org/licenses/LICENSE-2.0
    + * <p>
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.streaming.api.operators.co;
    +
    +import org.apache.flink.api.common.state.ListState;
    +import org.apache.flink.api.common.state.ListStateDescriptor;
    +import org.apache.flink.api.common.state.MapStateDescriptor;
    +import org.apache.flink.api.common.state.ValueStateDescriptor;
    +import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
    +import org.apache.flink.api.common.typeinfo.TypeInformation;
    +import org.apache.flink.api.java.functions.KeySelector;
    +import org.apache.flink.runtime.state.KeyedStateFunction;
    +import 
org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
    +import org.apache.flink.streaming.api.watermark.Watermark;
    +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
    +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
    +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
    +import 
org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
    +import org.apache.flink.streaming.util.TestHarnessUtil;
    +import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness;
    +import org.apache.flink.util.Collector;
    +import org.apache.flink.util.Preconditions;
    +
    +import org.junit.Assert;
    +import org.junit.Test;
    +
    +import java.util.ArrayList;
    +import java.util.Comparator;
    +import java.util.HashMap;
    +import java.util.HashSet;
    +import java.util.Iterator;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Queue;
    +import java.util.Set;
    +import java.util.concurrent.ConcurrentLinkedQueue;
    +import java.util.function.Function;
    +
    +/**
    + * Tests for the {@link CoBroadcastWithKeyedOperator}.
    + */
    +public class CoBroadcastWithKeyedOperatorTest {
    +
    +   /** Test the iteration over the keyed state on the broadcast side. */
    +   @Test
    +   public void testAccessToKeyedStateIt() throws Exception {
    +           final List<String> test1content = new ArrayList<>();
    +           test1content.add("test1");
    +           test1content.add("test1");
    +
    +           final List<String> test2content = new ArrayList<>();
    +           test2content.add("test2");
    +           test2content.add("test2");
    +           test2content.add("test2");
    +           test2content.add("test2");
    +
    +           final List<String> test3content = new ArrayList<>();
    +           test3content.add("test3");
    +           test3content.add("test3");
    +           test3content.add("test3");
    +
    +           final Map<String, List<String>> expectedState = new HashMap<>();
    +           expectedState.put("test1", test1content);
    +           expectedState.put("test2", test2content);
    +           expectedState.put("test3", test3content);
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
StatefulFunctionWithKeyedStateAccessedOnBroadcast(expectedState),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO)
    +           ) {
    +
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness = autoTestHarness.getTestHarness();
    +
    +                   // send elements to the keyed state
    +                   testHarness.processElement2(new StreamRecord<>("test1", 
12L));
    +                   testHarness.processElement2(new StreamRecord<>("test1", 
12L));
    +
    +                   testHarness.processElement2(new StreamRecord<>("test2", 
13L));
    +                   testHarness.processElement2(new StreamRecord<>("test2", 
13L));
    +                   testHarness.processElement2(new StreamRecord<>("test2", 
13L));
    +
    +                   testHarness.processElement2(new StreamRecord<>("test3", 
14L));
    +                   testHarness.processElement2(new StreamRecord<>("test3", 
14L));
    +                   testHarness.processElement2(new StreamRecord<>("test3", 
14L));
    +
    +                   testHarness.processElement2(new StreamRecord<>("test2", 
13L));
    +
    +                   // this is the element on the broadcast side that will 
trigger the verification
    +                   // check the 
StatefulFunctionWithKeyedStateAccessedOnBroadcast#processElementOnBroadcastSide()
    +                   testHarness.processElement1(new StreamRecord<>(1, 13L));
    +           }
    +   }
    +
    +   /**
    +    * Simple {@link KeyedBroadcastProcessFunction} that adds all incoming 
elements in the non-broadcast
    +    * side to a listState and at the broadcast side it verifies if the 
stored data is the expected ones.
    +    */
    +   private static class StatefulFunctionWithKeyedStateAccessedOnBroadcast
    +                   extends KeyedBroadcastProcessFunction<String, Integer, 
String, String, Integer, String> {
    +
    +           private static final long serialVersionUID = 
7496674620398203933L;
    +
    +           private final ListStateDescriptor<String> listStateDesc =
    +                           new ListStateDescriptor<>("listStateTest", 
BasicTypeInfo.STRING_TYPE_INFO);
    +
    +           private final Map<String, List<String>> expectedKeyedStates;
    +
    +           StatefulFunctionWithKeyedStateAccessedOnBroadcast(Map<String, 
List<String>> expectedKeyedState) {
    +                   this.expectedKeyedStates = 
Preconditions.checkNotNull(expectedKeyedState);
    +           }
    +
    +           @Override
    +           public void processElementOnBroadcastSide(Integer value, 
KeyedReadWriteContext ctx, Collector<String> out) throws Exception {
    +                   // put an element in the broadcast state
    +                   ctx.applyToKeyedState(
    +                                   listStateDesc,
    +                                   new KeyedStateFunction<String, 
ListState<String>>() {
    +                                           @Override
    +                                           public void process(String key, 
ListState<String> state) throws Exception {
    +                                                   final Iterator<String> 
it = state.get().iterator();
    +
    +                                                   final List<String> list 
= new ArrayList<>();
    +                                                   while (it.hasNext()) {
    +                                                           
list.add(it.next());
    +                                                   }
    +                                                   
Assert.assertEquals(expectedKeyedStates.get(key), list);
    +                                           }
    +                                   });
    +           }
    +
    +           @Override
    +           public void processElement(String value, KeyedReadOnlyContext 
ctx, Collector<String> out) throws Exception {
    +                   
getRuntimeContext().getListState(listStateDesc).add(value);
    +           }
    +   }
    +
    +   @Test
    +   public void testFunctionWithTimer() throws Exception {
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
FunctionWithTimerOnKeyed(41L),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO)
    +           ) {
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness = autoTestHarness.getTestHarness();
    +
    +                   testHarness.processWatermark1(new Watermark(10L));
    +                   testHarness.processWatermark2(new Watermark(10L));
    +                   testHarness.processElement1(new StreamRecord<>(5, 12L));
    +
    +                   testHarness.processWatermark1(new Watermark(40L));
    +                   testHarness.processWatermark2(new Watermark(40L));
    +                   testHarness.processElement2(new StreamRecord<>("6", 
13L));
    +                   testHarness.processElement2(new StreamRecord<>("6", 
15L));
    +
    +                   testHarness.processWatermark1(new Watermark(50L));
    +                   testHarness.processWatermark2(new Watermark(50L));
    +
    +                   Queue<Object> expectedOutput = new 
ConcurrentLinkedQueue<>();
    +
    +                   expectedOutput.add(new Watermark(10L));
    +                   expectedOutput.add(new StreamRecord<>("BR:5 WM:10 
TS:12", 12L));
    +                   expectedOutput.add(new Watermark(40L));
    +                   expectedOutput.add(new StreamRecord<>("NON-BR:6 WM:40 
TS:13", 13L));
    +                   expectedOutput.add(new StreamRecord<>("NON-BR:6 WM:40 
TS:15", 15L));
    +                   expectedOutput.add(new StreamRecord<>("TIMER:41", 41L));
    +                   expectedOutput.add(new Watermark(50L));
    +
    +                   TestHarnessUtil.assertOutputEquals("Output was not 
correct.", expectedOutput, testHarness.getOutput());
    +           }
    +   }
    +
    +   /**
    +    * {@link KeyedBroadcastProcessFunction} that registers a timer and 
emits
    +    * for every element the watermark and the timestamp of the element.
    +    */
    +   private static class FunctionWithTimerOnKeyed extends 
KeyedBroadcastProcessFunction<String, Integer, String, String, Integer, String> 
{
    +
    +           private static final long serialVersionUID = 
7496674620398203933L;
    +
    +           private final long timerTS;
    +
    +           FunctionWithTimerOnKeyed(long timerTS) {
    +                   this.timerTS = timerTS;
    +           }
    +
    +           @Override
    +           public void processElementOnBroadcastSide(Integer value, 
KeyedReadWriteContext ctx, Collector<String> out) throws Exception {
    +                   out.collect("BR:" + value + " WM:" + 
ctx.currentWatermark() + " TS:" + ctx.timestamp());
    +           }
    +
    +           @Override
    +           public void processElement(String value, KeyedReadOnlyContext 
ctx, Collector<String> out) throws Exception {
    +                   ctx.timerService().registerEventTimeTimer(timerTS);
    +                   out.collect("NON-BR:" + value + " WM:" + 
ctx.currentWatermark() + " TS:" + ctx.timestamp());
    +           }
    +
    +           @Override
    +           public void onTimer(long timestamp, OnTimerContext ctx, 
Collector<String> out) throws Exception {
    +                   out.collect("TIMER:" + timestamp);
    +           }
    +   }
    +
    +   @Test
    +   public void testFunctionWithBroadcastState() throws Exception {
    +
    +           final Map<String, Integer> expectedBroadcastState = new 
HashMap<>();
    +           expectedBroadcastState.put("5.key", 5);
    +           expectedBroadcastState.put("34.key", 34);
    +           expectedBroadcastState.put("53.key", 53);
    +           expectedBroadcastState.put("12.key", 12);
    +           expectedBroadcastState.put("98.key", 98);
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
FunctionWithBroadcastState("key", expectedBroadcastState, 41L),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO)
    +           ) {
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness = autoTestHarness.getTestHarness();
    +
    +                   testHarness.processWatermark1(new Watermark(10L));
    +                   testHarness.processWatermark2(new Watermark(10L));
    +
    +                   testHarness.processElement1(new StreamRecord<>(5, 10L));
    +                   testHarness.processElement1(new StreamRecord<>(34, 
12L));
    +                   testHarness.processElement1(new StreamRecord<>(53, 
15L));
    +                   testHarness.processElement1(new StreamRecord<>(12, 
16L));
    +                   testHarness.processElement1(new StreamRecord<>(98, 
19L));
    +
    +                   testHarness.processElement2(new 
StreamRecord<>("trigger", 13L));
    +
    +                   testHarness.processElement1(new StreamRecord<>(51, 
21L));
    +
    +                   testHarness.processWatermark1(new Watermark(50L));
    +                   testHarness.processWatermark2(new Watermark(50L));
    +
    +                   Queue<Object> output = testHarness.getOutput();
    +                   Assert.assertEquals(3L, output.size());
    +
    +                   Object firstRawWm = output.poll();
    +                   Assert.assertTrue(firstRawWm instanceof Watermark);
    +                   Watermark firstWm = (Watermark) firstRawWm;
    +                   Assert.assertEquals(10L, firstWm.getTimestamp());
    +
    +                   Object rawOutputElem = output.poll();
    +                   Assert.assertTrue(rawOutputElem instanceof 
StreamRecord);
    +                   StreamRecord<?> outputRec = (StreamRecord<?>) 
rawOutputElem;
    +                   Assert.assertTrue(outputRec.getValue() instanceof 
String);
    +                   String outputElem = (String) outputRec.getValue();
    +
    +                   expectedBroadcastState.put("51.key", 51);
    +                   List<Map.Entry<String, Integer>> expectedEntries = new 
ArrayList<>();
    +                   
expectedEntries.addAll(expectedBroadcastState.entrySet());
    +                   String expected = "TS:41 " + 
mapToString(expectedEntries);
    +                   Assert.assertEquals(expected, outputElem);
    +
    +                   Object secondRawWm = output.poll();
    +                   Assert.assertTrue(secondRawWm instanceof Watermark);
    +                   Watermark secondWm = (Watermark) secondRawWm;
    +                   Assert.assertEquals(50L, secondWm.getTimestamp());
    +           }
    +   }
    +
    +   private static class FunctionWithBroadcastState extends 
KeyedBroadcastProcessFunction<String, Integer, String, String, Integer, String> 
{
    +
    +           private static final long serialVersionUID = 
7496674620398203933L;
    +
    +           private final String keyPostfix;
    +           private final Map<String, Integer> expectedBroadcastState;
    +           private final long timerTs;
    +
    +           FunctionWithBroadcastState(
    +                           final String keyPostfix,
    +                           final Map<String, Integer> 
expectedBroadcastState,
    +                           final long timerTs
    +           ) {
    +                   this.keyPostfix = 
Preconditions.checkNotNull(keyPostfix);
    +                   this.expectedBroadcastState = 
Preconditions.checkNotNull(expectedBroadcastState);
    +                   this.timerTs = timerTs;
    +           }
    +
    +           @Override
    +           public void processElementOnBroadcastSide(Integer value, 
KeyedReadWriteContext ctx, Collector<String> out) throws Exception {
    +                   // put an element in the broadcast state
    +                   final String key = value + "." + keyPostfix;
    +                   ctx.putToBroadcast(key, value);
    +           }
    +
    +           @Override
    +           public void processElement(String value, KeyedReadOnlyContext 
ctx, Collector<String> out) throws Exception {
    +                   Iterator<Map.Entry<String, Integer>> broadcastStateIt = 
ctx.readOnlyBroadcastIterable().iterator();
    +
    +                   for (int i = 0; i < expectedBroadcastState.size(); i++) 
{
    +                           Assert.assertTrue(broadcastStateIt.hasNext());
    +
    +                           Map.Entry<String, Integer> entry = 
broadcastStateIt.next();
    +                           
Assert.assertTrue(expectedBroadcastState.containsKey(entry.getKey()));
    +                           
Assert.assertEquals(expectedBroadcastState.get(entry.getKey()), 
entry.getValue());
    +                   }
    +
    +                   Assert.assertFalse(broadcastStateIt.hasNext());
    +
    +                   ctx.timerService().registerEventTimeTimer(timerTs);
    +           }
    +
    +           @Override
    +           public void onTimer(long timestamp, OnTimerContext ctx, 
Collector<String> out) throws Exception {
    +                   final Iterator<Map.Entry<String, Integer>> 
broadcastStateIt = ctx.readOnlyBroadcastIterable().iterator();
    +                   final List<Map.Entry<String, Integer>> map = new 
ArrayList<>();
    +                   while (broadcastStateIt.hasNext()) {
    +                           map.add(broadcastStateIt.next());
    +                   }
    +                   final String mapToStr = mapToString(map);
    +                   out.collect("TS:" + timestamp + " " + mapToStr);
    +           }
    +   }
    +
    +   @Test
    +   public void testScaleUp() throws Exception {
    +           final Set<String> keysToRegister = new HashSet<>();
    +           keysToRegister.add("test1");
    +           keysToRegister.add("test2");
    +           keysToRegister.add("test3");
    +
    +           final OperatorStateHandles mergedSnapshot;
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness1 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           2,
    +                                           0);
    +
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness2 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           2,
    +                                           1)
    +
    +           ) {
    +
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness1 = autoTestHarness1.getTestHarness();
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness2 = autoTestHarness2.getTestHarness();
    +
    +                   // make sure all operators have the same state
    +                   testHarness1.processElement1(new StreamRecord<>(3));
    +                   testHarness2.processElement1(new StreamRecord<>(3));
    +
    +                   mergedSnapshot = 
AbstractStreamOperatorTestHarness.repackageState(
    +                                   testHarness1.snapshot(0L, 0L),
    +                                   testHarness2.snapshot(0L, 0L)
    +                   );
    +           }
    +
    +           final Set<String> expected = new HashSet<>(3);
    +           expected.add("test1=3");
    +           expected.add("test2=3");
    +           expected.add("test3=3");
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness1 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           3,
    +                                           0,
    +                                           mergedSnapshot);
    +
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness2 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           3,
    +                                           1,
    +                                           mergedSnapshot);
    +
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness3 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           3,
    +                                           2,
    +                                           mergedSnapshot)
    +                   ) {
    +
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness1 = autoTestHarness1.getTestHarness();
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness2 = autoTestHarness2.getTestHarness();
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness3 = autoTestHarness3.getTestHarness();
    +
    +                   testHarness1.processElement2(new 
StreamRecord<>("trigger"));
    +                   testHarness2.processElement2(new 
StreamRecord<>("trigger"));
    +                   testHarness3.processElement2(new 
StreamRecord<>("trigger"));
    +
    +                   Queue<?> output1 = testHarness1.getOutput();
    +                   Queue<?> output2 = testHarness2.getOutput();
    +                   Queue<?> output3 = testHarness3.getOutput();
    +
    +                   Assert.assertEquals(expected.size(), output1.size());
    +                   for (Object o: output1) {
    +                           StreamRecord<String> rec = 
(StreamRecord<String>) o;
    +                           
Assert.assertTrue(expected.contains(rec.getValue()));
    +                   }
    +
    +                   Assert.assertEquals(expected.size(), output2.size());
    +                   for (Object o: output2) {
    +                           StreamRecord<String> rec = 
(StreamRecord<String>) o;
    +                           
Assert.assertTrue(expected.contains(rec.getValue()));
    +                   }
    +
    +                   Assert.assertEquals(expected.size(), output3.size());
    +                   for (Object o: output3) {
    +                           StreamRecord<String> rec = 
(StreamRecord<String>) o;
    +                           
Assert.assertTrue(expected.contains(rec.getValue()));
    +                   }
    +           }
    +   }
    +
    +   @Test
    +   public void testScaleDown() throws Exception {
    +           final Set<String> keysToRegister = new HashSet<>();
    +           keysToRegister.add("test1");
    +           keysToRegister.add("test2");
    +           keysToRegister.add("test3");
    +
    +           final OperatorStateHandles mergedSnapshot;
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness1 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           3,
    +                                           0);
    +
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness2 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           3,
    +                                           1);
    +
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness3 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           3,
    +                                           2)
    +           ) {
    +
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness1 = autoTestHarness1.getTestHarness();
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness2 = autoTestHarness2.getTestHarness();
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness3 = autoTestHarness3.getTestHarness();
    +
    +                   // make sure all operators have the same state
    +                   testHarness1.processElement1(new StreamRecord<>(3));
    +                   testHarness2.processElement1(new StreamRecord<>(3));
    +                   testHarness3.processElement1(new StreamRecord<>(3));
    +
    +                   mergedSnapshot = 
AbstractStreamOperatorTestHarness.repackageState(
    +                                   testHarness1.snapshot(0L, 0L),
    +                                   testHarness2.snapshot(0L, 0L),
    +                                   testHarness3.snapshot(0L, 0L)
    +                   );
    +           }
    +
    +           final Set<String> expected = new HashSet<>(3);
    +           expected.add("test1=3");
    +           expected.add("test2=3");
    +           expected.add("test3=3");
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness1 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           2,
    +                                           0,
    +                                           mergedSnapshot);
    +
    +                           AutoClosableTestHarness<String, Integer, 
String, String, Integer, String> autoTestHarness2 = new 
AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new 
TestFunctionWithOutput(keysToRegister),
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO,
    +                                           10,
    +                                           2,
    +                                           1,
    +                                           mergedSnapshot)
    +           ) {
    +
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness1 = autoTestHarness1.getTestHarness();
    +                   TwoInputStreamOperatorTestHarness<Integer, String, 
String> testHarness2 = autoTestHarness2.getTestHarness();
    +
    +                   testHarness1.processElement2(new 
StreamRecord<>("trigger"));
    +                   testHarness2.processElement2(new 
StreamRecord<>("trigger"));
    +
    +                   Queue<?> output1 = testHarness1.getOutput();
    +                   Queue<?> output2 = testHarness2.getOutput();
    +
    +                   Assert.assertEquals(expected.size(), output1.size());
    +                   for (Object o: output1) {
    +                           StreamRecord<String> rec = 
(StreamRecord<String>) o;
    +                           
Assert.assertTrue(expected.contains(rec.getValue()));
    +                   }
    +
    +                   Assert.assertEquals(expected.size(), output2.size());
    +                   for (Object o: output2) {
    +                           StreamRecord<String> rec = 
(StreamRecord<String>) o;
    +                           
Assert.assertTrue(expected.contains(rec.getValue()));
    +                   }
    +           }
    +   }
    +
    +   /**
    +    * Test function that, on the broadcast side, stores each incoming value under every
    +    * configured key and, on the keyed side, emits the {@code toString()} form of every
    +    * broadcast-state entry it can see.
    +    */
    +   private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction<String, Integer, String, String, Integer, String> {
    +
    +           private static final long serialVersionUID = 7496674620398203933L;
    +
    +           /** Keys under which every broadcast element gets stored. */
    +           private final Set<String> keysToRegister;
    +
    +           TestFunctionWithOutput(Set<String> keysToRegister) {
    +                   this.keysToRegister = Preconditions.checkNotNull(keysToRegister);
    +           }
    +
    +           @Override
    +           public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext ctx, Collector<String> out) throws Exception {
    +                   // the same broadcast value is registered once per configured key
    +                   for (String key : keysToRegister) {
    +                           ctx.putToBroadcast(key, value);
    +                   }
    +           }
    +
    +           @Override
    +           public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String> out) throws Exception {
    +                   // the incoming element only acts as a trigger; the output is the broadcast state contents
    +                   for (Map.Entry<String, Integer> broadcastEntry : ctx.readOnlyBroadcastIterable()) {
    +                           out.collect(broadcastEntry.toString());
    +                   }
    +           }
    +   }
    +
    +   /**
    +    * Verifies that accessing keyed state from the broadcast side fails, because no key is set there.
    +    * The failure is expected to surface as a {@link NullPointerException} with the "No key set" message.
    +    */
    +   @Test
    +   public void testNoKeyedStateOnBroadcastSide() throws Exception {
    +
    +           try (
    +                           AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness = new AutoClosableTestHarness<>(
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           new IdentityKeySelector<>(),
    +                                           new KeyedBroadcastProcessFunction<String, Integer, String, String, Integer, String>() {
    +
    +                                                   private static final long serialVersionUID = -1725365436500098384L;
    +
    +                                                   private final ValueStateDescriptor<String> valueState = new ValueStateDescriptor<>("any", BasicTypeInfo.STRING_TYPE_INFO);
    +
    +                                                   @Override
    +                                                   public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext ctx, Collector<String> out) throws Exception {
    +                                                           // keyed-state access on the broadcast side is expected to fail ("No key set")
    +                                                           getRuntimeContext().getState(valueState).value();
    +                                                   }
    +
    +                                                   @Override
    +                                                   public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String> out) throws Exception {
    +                                                           // do nothing
    +                                                   }
    +                                           },
    +                                           BasicTypeInfo.STRING_TYPE_INFO,
    +                                           BasicTypeInfo.INT_TYPE_INFO)
    +           ) {
    +                   TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness = autoTestHarness.getTestHarness();
    +
    +                   testHarness.processWatermark1(new Watermark(10L));
    +                   testHarness.processWatermark2(new Watermark(10L));
    +                   testHarness.processElement1(new StreamRecord<>(5, 12L));
    +
    +                   // reaching this line means the broadcast-side keyed-state access did not throw
    +                   Assert.fail("No exception thrown");
    +           } catch (NullPointerException e) {
    +                   Assert.assertEquals("No key set. This method should not be called outside of a keyed context.", e.getMessage());
    +           }
    +   }
    +
    +   /**
    +    * Key selector that uses every element as its own key.
    +    */
    +   private static class IdentityKeySelector<T> implements KeySelector<T, T> {
    +
    +           private static final long serialVersionUID = 1L;
    +
    +           @Override
    +           public T getKey(T element) throws Exception {
    +                   final T key = element;
    +                   return key;
    +           }
    +   }
    +
    +   /**
    +    * A wrapper of the test harness that makes sure to close it after the 
test finishes.
    +    */
    +   private static class AutoClosableTestHarness<KEY, IN1, IN2, K, V, OUT> 
implements AutoCloseable {
    --- End diff --
    
    The test harnesses themselves already implement `AutoCloseable` (a fairly recent addition), so they can be used directly in try-with-resources without this wrapper.


---

Reply via email to