Github user aljoscha commented on a diff in the pull request: https://github.com/apache/flink/pull/5230#discussion_r165395860 --- Diff: flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java --- @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.co; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ReadOnlyBroadcastState; +import org.apache.flink.api.common.state.ReadWriteBroadcastState; +import org.apache.flink.api.common.state.State; +import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.KeyedStateFunction; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; +import org.apache.flink.streaming.api.SimpleTimerService; +import org.apache.flink.streaming.api.TimeDomain; +import org.apache.flink.streaming.api.TimerService; +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.InternalTimer; +import org.apache.flink.streaming.api.operators.InternalTimerService; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.api.operators.Triggerable; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * A {@link TwoInputStreamOperator} for executing {@link KeyedBroadcastProcessFunction KeyedBroadcastProcessFunctions}. + * + * @param <KS> The key type of the input keyed stream. + * @param <IN1> The input type of the keyed (non-broadcast) side. + * @param <IN2> The input type of the broadcast side. + * @param <OUT> The output type of the operator. + */ +@Internal +public class CoBroadcastWithKeyedOperator<KS, IN1, IN2, OUT> + extends AbstractUdfStreamOperator<OUT, KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>> + implements TwoInputStreamOperator<IN1, IN2, OUT>, Triggerable<KS, VoidNamespace> { + + private static final long serialVersionUID = 5926499536290284870L; + + private final List<MapStateDescriptor<?, ?>> broadcastStateDescriptors; + + private transient TimestampedCollector<OUT> collector; + + private transient Map<MapStateDescriptor<?, ?>, ReadWriteBroadcastState<?, ?>> broadcastStates; + + private transient ReadWriteContextImpl rwContext; + + private transient ReadOnlyContextImpl rContext; + + private transient OnTimerContextImpl onTimerContext; + + public CoBroadcastWithKeyedOperator( + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, + final List<MapStateDescriptor<?, ?>> broadcastStateDescriptors) { + super(function); + this.broadcastStateDescriptors = Preconditions.checkNotNull(broadcastStateDescriptors); + } + + @Override + public void open() throws Exception { + super.open(); + + InternalTimerService<VoidNamespace> internalTimerService = + getInternalTimerService("user-timers", VoidNamespaceSerializer.INSTANCE, this); + + TimerService timerService = new SimpleTimerService(internalTimerService); + + collector = new TimestampedCollector<>(output); + + this.broadcastStates = new HashMap<>(broadcastStateDescriptors.size()); + for (MapStateDescriptor<?, ?> descriptor: broadcastStateDescriptors) { + broadcastStates.put(descriptor, getOperatorStateBackend().getBroadcastState(descriptor)); + } + + rwContext = new ReadWriteContextImpl(getKeyedStateBackend(), userFunction, broadcastStates, timerService); + rContext = new ReadOnlyContextImpl(userFunction, broadcastStates, timerService); + onTimerContext = new OnTimerContextImpl(userFunction, broadcastStates, timerService); + } + + @Override + public void processElement1(StreamRecord<IN1> element) throws Exception { + collector.setTimestamp(element); + rContext.setElement(element); + userFunction.processElement(element.getValue(), rContext, collector); + rContext.setElement(null); + } + + @Override + public void processElement2(StreamRecord<IN2> element) throws Exception { + collector.setTimestamp(element); + rwContext.setElement(element); + userFunction.processBroadcastElement(element.getValue(), rwContext, collector); + rwContext.setElement(null); + } + + @Override + public void onEventTime(InternalTimer<KS, VoidNamespace> timer) throws Exception { + collector.setAbsoluteTimestamp(timer.getTimestamp()); + onTimerContext.timeDomain = TimeDomain.EVENT_TIME; + onTimerContext.timer = timer; + userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + onTimerContext.timer = null; + } + + @Override + public void onProcessingTime(InternalTimer<KS, VoidNamespace> timer) throws Exception { + collector.eraseTimestamp(); + onTimerContext.timeDomain = TimeDomain.PROCESSING_TIME; + onTimerContext.timer = timer; + userFunction.onTimer(timer.getTimestamp(), onTimerContext, collector); + onTimerContext.timeDomain = null; + onTimerContext.timer = null; + } + + private class ReadWriteContextImpl extends KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.KeyedReadWriteContext { + + private final KeyedStateBackend<KS> keyedStateBackend; + + private final Map<MapStateDescriptor<?, ?>, ReadWriteBroadcastState<?, ?>> states; + + private final TimerService timerService; + + private StreamRecord<IN2> element; + + ReadWriteContextImpl ( + final KeyedStateBackend<KS> keyedStateBackend, + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, + final Map<MapStateDescriptor<?, ?>, ReadWriteBroadcastState<?, ?>> broadcastStates, + final TimerService timerService) { + + function.super(); + this.keyedStateBackend = Preconditions.checkNotNull(keyedStateBackend); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + void setElement(StreamRecord<IN2> e) { + this.element = e; + } + + @Override + public Long timestamp() { + checkState(element != null); + return element.getTimestamp(); + } + + @Override + public <K, V> ReadWriteBroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, V> stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + return (ReadWriteBroadcastState<K, V>) states.get(stateDescriptor); + } + + @Override + public <X> void output(OutputTag<X> outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + + @Override + public long currentProcessingTime() { + return timerService.currentProcessingTime(); + } + + @Override + public long currentWatermark() { + return timerService.currentWatermark(); + } + + @Override + public <VS, S extends State> void applyToKeyedState( + final StateDescriptor<S, VS> stateDescriptor, + final KeyedStateFunction<KS, S> function) throws Exception { + + keyedStateBackend.applyToAllKeys( + Preconditions.checkNotNull(stateDescriptor), + Preconditions.checkNotNull(function)); + } + } + + private class ReadOnlyContextImpl extends KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.KeyedReadOnlyContext { + + private final Map<MapStateDescriptor<?, ?>, ReadWriteBroadcastState<?, ?>> states; + + private final TimerService timerService; + + private StreamRecord<IN1> element; + + ReadOnlyContextImpl( + final KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT> function, + final Map<MapStateDescriptor<?, ?>, ReadWriteBroadcastState<?, ?>> broadcastStates, + final TimerService timerService) { + + function.super(); + this.states = Preconditions.checkNotNull(broadcastStates); + this.timerService = Preconditions.checkNotNull(timerService); + } + + void setElement(StreamRecord<IN1> e) { + this.element = e; + } + + @Override + public Long timestamp() { + checkState(element != null); + return element.hasTimestamp() ? element.getTimestamp() : null; + } + + @Override + public TimerService timerService() { + return timerService; + } + + @Override + public long currentProcessingTime() { + return timerService.currentProcessingTime(); + } + + @Override + public long currentWatermark() { + return timerService.currentWatermark(); + } + + @Override + public <X> void output(OutputTag<X> outputTag, X value) { + checkArgument(outputTag != null, "OutputTag must not be null."); + output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); + } + + @Override + public <K, V> ReadOnlyBroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, V> stateDescriptor) { + Preconditions.checkNotNull(stateDescriptor); + return (ReadOnlyBroadcastState<K, V>) states.get(stateDescriptor); --- End diff -- I think we should throw an exception here that explains what's going on if a state is not available.
---