zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340287



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);

Review comment:
       Hi @gaoyunhaii, I have made the changes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new 
CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            
transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, 
inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), 
inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = 
draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams 
are available at all
+     * parallel instances of the input operators. A broadcast data stream is 
registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first 
and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input 
are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is 
the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can 
access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> 
userDefinedFunction) {
+        Preconditions.checkState(inputList.size() > 0);
+        StreamExecutionEnvironment env = 
inputList.get(0).getExecutionEnvironment();
+        final String[] broadcastStreamNames = bcStreams.keySet().toArray(new 
String[0]);
+        DataStream<OUT> resultStream =
+                buildGraph(env, inputList, broadcastStreamNames, 
userDefinedFunction);
+
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = getCoLocationKey(broadcastStreamNames);
+        DataStream<OUT> cachedBroadcastInputs = cacheBroadcastVariables(env, 
bcStreams, outType);
+
+        for (int i = 0; i < inputList.size(); i++) {
+            
inputList.get(i).getTransformation().setCoLocationGroupKey(coLocationKey);

Review comment:
       Hi Yun, thanks for the feedback.
   
   I have updated the code based on my understanding. Please check line 93 of 
`BroadcastUtils` for details.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] 
inTypes) {

Review comment:
       Thanks, Yun. I have made the change.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] 
inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, 
boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] 
inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, 
boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);
+        this.broadcastStreamNames = broadcastStreamNames;

Review comment:
       Thanks, Yun. But could we use `Preconditions.checkNotNull` here instead?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new 
CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link 
BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>
+        implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, 
Serializable {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastNames;
+    /** input list of the multi-input operator. */
+    private final List<Input> inputList;
+    /** output types of input DataStreams. */
+    private final TypeInformation<?>[] inTypes;
+    /** caches of the broadcast inputs. */
+    private final List<?>[] caches;
+    /** state storage of the broadcast inputs. */
+    private ListState<?>[] cacheStates;
+    /** cacheReady state storage of the broadcast inputs. */
+    private ListState<Boolean>[] cacheReadyStates;
+
+    public CacheStreamOperator(
+            StreamOperatorParameters<OUT> parameters,
+            String[] broadcastNames,
+            TypeInformation<?>[] inTypes) {
+        super(parameters, broadcastNames.length);
+        this.broadcastNames = broadcastNames;
+        this.inTypes = inTypes;
+        this.caches = new List[inTypes.length];
+        for (int i = 0; i < inTypes.length; i++) {
+            caches[i] = new ArrayList<>();
+        }
+        this.cacheStates = new ListState[inTypes.length];
+        this.cacheReadyStates = new ListState[inTypes.length];
+
+        inputList = new ArrayList<>();
+        for (int i = 0; i < inTypes.length; i++) {
+            inputList.add(new ProxyInput(this, i + 1));
+        }
+    }
+
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
+    }
+
+    @Override
+    public void endInput(int i) {
+        BroadcastContext.markCacheFinished(
+                Tuple2.of(broadcastNames[i - 1], 
getRuntimeContext().getIndexOfThisSubtask()));
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i].clear();
+            cacheStates[i].addAll((List) caches[i]);
+            cacheReadyStates[i].clear();
+            boolean isCacheFinished =
+                    BroadcastContext.isCacheFinished(
+                            Tuple2.of(
+                                    broadcastNames[i],
+                                    
getRuntimeContext().getIndexOfThisSubtask()));
+            cacheReadyStates[i].add(isCacheFinished);
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws 
Exception {
+        super.initializeState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i] =

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import 
org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new 
String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new 
DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new 
TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new 
CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            
transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, 
OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, 
isBlocking);
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception 
{
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                dataCacheWriters[0].finishCurrentSegmentAndStartNewSegment();
+                
segmentLists[0].addAll(dataCacheWriters[0].getNewlyFinishedSegments());
+                if (segmentLists[0].size() != 0) {
+                    DataCacheReader dataCacheReader =
+                            new DataCacheReader<>(
+                                    inTypes[0].createSerializer(
+                                            
containingTask.getExecutionConfig()),
+                                    fileSystem,
+                                    segmentLists[0]);
+                    while (dataCacheReader.hasNext()) {
+                        wrappedOperator.processElement(new 
StreamRecord(dataCacheReader.next()));
+                    }
+                }
+                segmentLists[0].clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                dataCacheWriters[0].addRecord(streamRecord.getValue());
+            }
+
+        } else {
+            while (!areBroadcastVariablesReady()) {
+                mailboxExecutor.yield();
+            }
+            wrappedOperator.processElement(streamRecord);
+        }
+    }

Review comment:
       Thanks, Yun. I have finished the refactoring; please refer to 
`AbstractBroadcastWrapperOperator#processElementX()` and 
`AbstractBroadcastWrapperOperator#endInputX()`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to