lindong28 commented on a change in pull request #32:
URL: https://github.com/apache/flink-ml/pull/32#discussion_r755192622



##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -158,10 +148,11 @@ public void testFeaturePredictionParam() throws Exception 
{
         assertEquals(
                 Arrays.asList("test_feature", "test_prediction"),
                 output.getResolvedSchema().getColumnNames());
-        Map<DenseVector, Integer> clusterIdByPoints =
-                executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
-        verifyClusteringResult(
-                clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2), 
Arrays.asList(3, 4, 5)));
+        assertTrue(
+                CollectionUtils.isEqualCollection(
+                        executeAndCollect(

Review comment:
       nits: this statement seems a bit too long. Maybe move 
`executeAndCollect` to a separate statement?
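
A minimal sketch of the suggested split, using only the JDK. `collectGroups` is a hypothetical stand-in for the test's `executeAndCollect`, and `sameGroups` is a simplified, hand-written substitute for the collection comparison in the diff, not the library method itself:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class SeparateStatementExample {
    // Small helper so the generic types stay explicit.
    static Set<Integer> setOf(Integer... values) {
        return new HashSet<>(Arrays.asList(values));
    }

    // Hypothetical stand-in for executeAndCollect(...): returns the clusters that were found.
    static List<Set<Integer>> collectGroups() {
        return Arrays.asList(setOf(3, 4, 5), setOf(0, 1, 2));
    }

    // Order-insensitive comparison; assumes neither list contains duplicate groups.
    static boolean sameGroups(List<Set<Integer>> expected, List<Set<Integer>> actual) {
        return expected.size() == actual.size()
                && new HashSet<>(expected).equals(new HashSet<>(actual));
    }

    public static void main(String[] args) {
        List<Set<Integer>> expectedGroups = Arrays.asList(setOf(0, 1, 2), setOf(3, 4, 5));

        // Collect first, then assert, so that each statement does one thing.
        List<Set<Integer>> actualGroups = collectGroups();
        if (!sameGroups(expectedGroups, actualGroups)) {
            throw new AssertionError("unexpected groups: " + actualGroups);
        }
    }
}
```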

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -83,43 +97,19 @@ public void before() {
         dataTable = tEnv.fromDataStream(env.fromCollection(DATA), 
schema).as("features");
     }
 
-    // Executes the graph and returns a map which maps points to clusterId.
-    private static Map<DenseVector, Integer> executeAndCollect(
-            Table output, String featureCol, String predictionCol) throws 
Exception {
-        StreamTableEnvironment tEnv =
-                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
-
-        DataStream<Tuple2<DenseVector, Integer>> stream =
-                tEnv.toDataStream(output)
-                        .map(
-                                new MapFunction<Row, Tuple2<DenseVector, 
Integer>>() {
-                                    @Override
-                                    public Tuple2<DenseVector, Integer> 
map(Row row) {
-                                        return Tuple2.of(
-                                                (DenseVector) 
row.getField(featureCol),
-                                                (Integer) 
row.getField(predictionCol));
-                                    }
-                                });
-
-        List<Tuple2<DenseVector, Integer>> pointsWithClusterId =
-                IteratorUtils.toList(stream.executeAndCollect());
-
-        Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>();
-        for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) {
-            clusterIdByPoints.put(entry.f0, entry.f1);
-        }
-        return clusterIdByPoints;
-    }
-
-    private static void verifyClusteringResult(
-            Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>> 
groups) {
-        for (List<Integer> group : groups) {
-            for (int i = 1; i < group.size(); i++) {
-                assertEquals(
-                        clusterIdByPoints.get(DATA.get(group.get(0))),
-                        clusterIdByPoints.get(DATA.get(group.get(i))));
+    private static List<Set<DenseVector>> executeAndCollect(

Review comment:
       nits: Can we add comments for this method?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Map;
+
+/** The model data of {@link NaiveBayesModel}. Also provides classes to 
save/load model data. */
+public class NaiveBayesModelData implements Serializable {
+    private static final long serialVersionUID = 3919917903722286395L;
+    public final Map<Double, Double>[][] theta;
+    public final double[] piArray;
+    public final double[] labels;
+
+    // Empty constructor is used when Kryo deserializes loaded model data.
+    public NaiveBayesModelData() {
+        this(null, null, null);
+    }
+
+    public NaiveBayesModelData(Map<Double, Double>[][] theta, double[] 
piArray, double[] labels) {
+        this.theta = theta;
+        this.piArray = piArray;
+        this.labels = labels;
+    }
+
+    public static Table fromDataStream(

Review comment:
       Can we add a Javadoc here, since this is a public method?
   
   Should we also update KMeans to use the same pattern?
   
   Can we use more informative names here, e.g. `getModelDataTable` and 
`getModelDataStream`?
   
   Same for `toDataStream`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+        implements Estimator<NaiveBayes, NaiveBayesModel>, 
NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public NaiveBayes() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final double smoothing = getSmoothing();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Vector, Double>> input =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<Vector, Double>>() 
{
+                                    @Override
+                                    public Tuple2<Vector, Double> map(Row row) 
throws Exception {
+                                        return new Tuple2<>(
+                                                (Vector) 
row.getField(featuresCol),
+                                                (Double) 
row.getField(labelCol));
+                                    }
+                                });
+
+        DataStream<NaiveBayesModelData> naiveBayesModel =
+                input.flatMap(new FlattenFunction())
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Integer, Double, 
Double>, Object>)
+                                        value -> new Tuple3<>(value.f0, 
value.f1, value.f2))
+                        .window(EndOfStreamWindows.get())
+                        .reduce(
+                                (ReduceFunction<Tuple4<Double, Integer, 
Double, Double>>)
+                                        (t0, t1) -> {
+                                            t0.f3 += t1.f3;
+                                            return t0;
+                                        })
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Integer, Double, 
Double>, Object>)
+                                        value -> new Tuple2<>(value.f0, 
value.f1))
+                        .window(EndOfStreamWindows.get())
+                        .aggregate(new ValueMapFunction())
+                        .keyBy(
+                                (KeySelector<
+                                                Tuple4<
+                                                        Double,
+                                                        Integer,
+                                                        Map<Double, Double>,
+                                                        Double>,
+                                                Object>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .aggregate(new MapArrayFunction())
+                        .windowAll(EndOfStreamWindows.get())
+                        .apply(new GenerateModelFunction(smoothing));
+
+        NaiveBayesModel model =
+                new NaiveBayesModel()
+                        .setModelData(NaiveBayesModelData.fromDataStream(tEnv, 
naiveBayesModel));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static NaiveBayes load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Function to convert each column into tuples of label, feature column 
index, feature value,
+     * weight.
+     */
+    private static class FlattenFunction
+            implements FlatMapFunction<
+                    Tuple2<Vector, Double>, Tuple4<Double, Integer, Double, 
Double>> {
+        @Override
+        public void flatMap(
+                Tuple2<Vector, Double> value,
+                Collector<Tuple4<Double, Integer, Double, Double>> collector) {
+            Preconditions.checkNotNull(value.f1);
+            for (int i = 0; i < value.f0.size(); i++) {
+                collector.collect(new Tuple4<>(value.f1, i, value.f0.get(i), 
1.0));
+            }
+        }
+    }
+
+    /**
+     * Function to aggregate feature value and weight into map from records 
with the same label and
+     * feature column index.
+     */
+    private static class ValueMapFunction
+            implements AggregateFunction<
+                    Tuple4<Double, Integer, Double, Double>,
+                    Tuple3<Double, Integer, Map<Double, Double>>,
+                    Tuple4<Double, Integer, Map<Double, Double>, Double>> {
+
+        @Override
+        public Tuple3<Double, Integer, Map<Double, Double>> 
createAccumulator() {
+            return new Tuple3<>(0., -1, new HashMap<>());
+        }
+
+        @Override
+        public Tuple3<Double, Integer, Map<Double, Double>> add(
+                Tuple4<Double, Integer, Double, Double> value,
+                Tuple3<Double, Integer, Map<Double, Double>> acc) {
+            acc.f0 = value.f0;
+            acc.f1 = value.f1;
+            acc.f2.put(value.f2, value.f3);
+            return acc;
+        }
+
+        @Override
+        public Tuple4<Double, Integer, Map<Double, Double>, Double> getResult(
+                Tuple3<Double, Integer, Map<Double, Double>> acc) {
+            double weightSum = acc.f2.values().stream().mapToDouble(f -> 
f).sum();
+            return new Tuple4<>(acc.f0, acc.f1, acc.f2, weightSum);
+        }
+
+        @Override
+        public Tuple3<Double, Integer, Map<Double, Double>> merge(
+                Tuple3<Double, Integer, Map<Double, Double>> acc0,
+                Tuple3<Double, Integer, Map<Double, Double>> acc1) {
+            Preconditions.checkArgument(acc0.f1 != -1);
+            acc0.f2.putAll(acc1.f2);
+            return acc0;
+        }
+    }
+
+    /** Function to aggregate maps under the same label into arrays. array len 
= featureSize. */
+    private static class MapArrayFunction

Review comment:
       It is not very easy to understand the purpose of this function.
   
   Could you add comments explaining the meaning of each field of the input 
and of the accumulator?
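
One way to make the field meanings explicit, sketched with plain JDK types instead of Flink's `Tuple4` and `AggregateFunction`. The field meanings are inferred from the Javadoc of `FlattenFunction` and `ValueMapFunction` in this diff. The class, field, and method names below are hypothetical, and this is not the PR's implementation:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class LabelAccumulatorExample {
    /** Input record: statistics of one feature column under one label. */
    static final class FeatureStats {
        final double label; // label shared by the aggregated training rows
        final int featureIndex; // index of the feature column in the feature vector
        final Map<Double, Double> weightByValue; // feature value -> summed weight
        final double weightSum; // sum of all weights in weightByValue

        FeatureStats(double label, int featureIndex, Map<Double, Double> weightByValue) {
            this.label = label;
            this.featureIndex = featureIndex;
            this.weightByValue = weightByValue;
            double sum = 0.0;
            for (double w : weightByValue.values()) {
                sum += w;
            }
            this.weightSum = sum;
        }
    }

    /** Accumulator for one label: collects the per-column maps, keyed by column index. */
    static final class LabelAccumulator {
        double label = Double.NaN; // unset until the first record arrives
        final Map<Integer, Map<Double, Double>> mapByFeatureIndex = new HashMap<>();

        void add(FeatureStats stats) {
            label = stats.label;
            mapByFeatureIndex.put(stats.featureIndex, stats.weightByValue);
        }

        /** Returns the maps ordered by column index; assumes indices 0..n-1 were all added. */
        List<Map<Double, Double>> toOrderedList() {
            List<Map<Double, Double>> result = new ArrayList<>();
            for (int i = 0; i < mapByFeatureIndex.size(); i++) {
                result.add(mapByFeatureIndex.get(i));
            }
            return result;
        }
    }
}
```

With named fields, the "array len = featureSize" remark in the existing Javadoc becomes a property of `toOrderedList()` rather than something the reader has to infer from tuple positions.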

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -67,6 +69,18 @@
                     Vectors.dense(9.6, 0.0));
     private StreamExecutionEnvironment env;
     private StreamTableEnvironment tEnv;
+    private static List<Set<DenseVector>> expectedGroup =

Review comment:
       nits: expectedGroup -> expectedGroups

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
##########
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */
+public class NaiveBayesTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Schema schema;
+    private List<Row> trainData;
+    private List<Row> predictData;
+    private List<Row> expectedOutput;
+    private boolean isSaveLoad;

Review comment:
       I believe the general rule of thumb is to make a variable a member 
variable only if it is written solely in `before`. Variables written in 
individual tests (e.g. `isSaveLoad`) are typically passed directly from the 
test method to the callee method.
   
   I could be wrong. Please feel free to verify this practice in the Flink 
codebase. If this is indeed the case for most tests, can we follow this 
practice here?
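
A rough illustration of that convention in plain Java, without JUnit. All names are hypothetical, and `clone()` merely stands in for a save/load round trip:

```java
class TestParameterExample {
    // Written only in before(), so a member variable is appropriate.
    private int[] trainData;

    void before() {
        trainData = new int[] {1, 2, 3};
    }

    // The flag differs per test, so it is passed as an argument instead of stored in a field.
    int runAndSum(boolean isSaveLoad) {
        // clone() is a placeholder for saving and reloading the model.
        int[] data = isSaveLoad ? trainData.clone() : trainData;
        int sum = 0;
        for (int value : data) {
            sum += value;
        }
        return sum;
    }

    void testFitAndPredict() {
        check(runAndSum(false) == 6);
    }

    void testSaveLoad() {
        check(runAndSum(true) == 6);
    }

    private static void check(boolean condition) {
        if (!condition) {
            throw new AssertionError("check failed");
        }
    }

    public static void main(String[] args) {
        TestParameterExample test = new TestParameterExample();
        test.before();
        test.testFitAndPredict();
        test.before();
        test.testSaveLoad();
    }
}
```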
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+        implements Estimator<NaiveBayes, NaiveBayesModel>, 
NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public NaiveBayes() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final double smoothing = getSmoothing();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Vector, Double>> input =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<Vector, Double>>() 
{
+                                    @Override
+                                    public Tuple2<Vector, Double> map(Row row) 
throws Exception {
+                                        return new Tuple2<>(
+                                                (Vector) 
row.getField(featuresCol),
+                                                (Double) 
row.getField(labelCol));
+                                    }
+                                });
+
+        DataStream<NaiveBayesModelData> naiveBayesModel =
+                input.flatMap(new FlattenFunction())

Review comment:
       Can we rename `FlattenFunction` to something more meaningful?
   
   Same for `ValueMapFunction` and `MapArrayFunction`.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -83,43 +97,19 @@ public void before() {
         dataTable = tEnv.fromDataStream(env.fromCollection(DATA), 
schema).as("features");
     }
 
-    // Executes the graph and returns a map which maps points to clusterId.
-    private static Map<DenseVector, Integer> executeAndCollect(
-            Table output, String featureCol, String predictionCol) throws 
Exception {
-        StreamTableEnvironment tEnv =
-                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
-
-        DataStream<Tuple2<DenseVector, Integer>> stream =
-                tEnv.toDataStream(output)
-                        .map(
-                                new MapFunction<Row, Tuple2<DenseVector, 
Integer>>() {
-                                    @Override
-                                    public Tuple2<DenseVector, Integer> 
map(Row row) {
-                                        return Tuple2.of(
-                                                (DenseVector) 
row.getField(featureCol),
-                                                (Integer) 
row.getField(predictionCol));
-                                    }
-                                });
-
-        List<Tuple2<DenseVector, Integer>> pointsWithClusterId =
-                IteratorUtils.toList(stream.executeAndCollect());
-
-        Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>();
-        for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) {
-            clusterIdByPoints.put(entry.f0, entry.f1);
-        }
-        return clusterIdByPoints;
-    }
-
-    private static void verifyClusteringResult(
-            Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>> 
groups) {
-        for (List<Integer> group : groups) {
-            for (int i = 1; i < group.size(); i++) {
-                assertEquals(
-                        clusterIdByPoints.get(DATA.get(group.get(0))),
-                        clusterIdByPoints.get(DATA.get(group.get(i))));
+    private static List<Set<DenseVector>> executeAndCollect(
+            Table output, String featureCol, String predictionCol) {
+        Map<Integer, Set<DenseVector>> map = new HashMap<>();
+        for (CloseableIterator<Row> it = output.execute().collect(); 
it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            int predict = (Integer) row.getField(predictionCol);
+            if (!map.containsKey(predict)) {
+                map.put(predict, new HashSet<>());

Review comment:
       nits: `map.putIfAbsent(predict, new HashSet<>())`
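
For reference, a small JDK-only sketch of the suggested form. Strings stand in for `DenseVector`, and the class and method names are hypothetical. The second method uses `computeIfAbsent`, which is an alternative not mentioned in the comment above: it returns the set directly and only allocates when the key is missing.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class GroupByPredictionExample {
    // Groups points by predicted cluster id; points[i] belongs to cluster predictions[i].
    static Map<Integer, Set<String>> groupWithPutIfAbsent(String[] points, int[] predictions) {
        Map<Integer, Set<String>> map = new HashMap<>();
        for (int i = 0; i < points.length; i++) {
            // Replaces the explicit containsKey check followed by put.
            map.putIfAbsent(predictions[i], new HashSet<>());
            map.get(predictions[i]).add(points[i]);
        }
        return map;
    }

    // Same grouping with computeIfAbsent, which creates the set lazily and returns it.
    static Map<Integer, Set<String>> groupWithComputeIfAbsent(String[] points, int[] predictions) {
        Map<Integer, Set<String>> map = new HashMap<>();
        for (int i = 0; i < points.length; i++) {
            map.computeIfAbsent(predictions[i], key -> new HashSet<>()).add(points[i]);
        }
        return map;
    }

    public static void main(String[] args) {
        String[] points = {"a", "b", "c"};
        int[] predictions = {0, 1, 0};
        System.out.println(groupWithPutIfAbsent(points, predictions));
        System.out.println(groupWithComputeIfAbsent(points, predictions));
    }
}
```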

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
##########
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */
+public class NaiveBayesTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Schema schema;
+    private List<Row> trainData;
+    private List<Row> predictData;
+    private List<Row> expectedOutput;
+    private boolean isSaveLoad;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.DOUBLE())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .column("f2", DataTypes.DOUBLE())
+                        .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
+                        .watermark("rowtime", "SOURCE_WATERMARK()")
+                        .build();
+
+        trainData =
+                Arrays.asList(
+                        Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0),
+                        Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0),
+                        Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0),
+                        Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0),
+                        Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0),
+                        Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0),
+                        Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0));
+
+        predictData = trainData;
+
+        expectedOutput =
+                Arrays.asList(
+                        Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0, 
11.0),
+                        Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0, 
11.0),
+                        Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0, 
11.0),
+                        Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0, 
11.0),
+                        Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0, 
10.0),
+                        Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0, 
10.0),
+                        Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0, 
10.0));
+
+        isSaveLoad = false;
+    }
+
+    @Test
+    public void testParam() {
+        NaiveBayes estimator = new NaiveBayes();
+
+        assertEquals("features", estimator.getFeaturesCol());
+        assertEquals("label", estimator.getLabelCol());
+        assertEquals("multinomial", estimator.getModelType());
+        assertEquals("prediction", estimator.getPredictionCol());
+        assertEquals(1.0, estimator.getSmoothing(), 1e-5);
+
+        estimator
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setPredictionCol("test_prediction")
+                .setSmoothing(2.0);
+
+        assertEquals("test_feature", estimator.getFeaturesCol());
+        assertEquals("test_label", estimator.getLabelCol());
+        assertEquals("test_prediction", estimator.getPredictionCol());
+        assertEquals(2.0, estimator.getSmoothing(), 1e-5);
+
+        NaiveBayesModel model = new NaiveBayesModel();
+
+        assertEquals("features", model.getFeaturesCol());
+        assertEquals("multinomial", model.getModelType());
+        assertEquals("prediction", model.getPredictionCol());
+
+        model.setFeaturesCol("test_feature").setPredictionCol("test_prediction");
+
+        assertEquals("test_feature", model.getFeaturesCol());
+        assertEquals("test_prediction", model.getPredictionCol());
+    }
+
+    @Test
+    public void testNaiveBayes() throws Exception {

Review comment:
       nits: This test method name is a bit broad, given that the class name is 
already `NaiveBayesTest` and every test in this class is supposed to "test 
NaiveBayes". How about changing this to `testFitAndPredict`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+        implements Estimator<NaiveBayes, NaiveBayesModel>, 
NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public NaiveBayes() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final double smoothing = getSmoothing();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Vector, Double>> input =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<Vector, Double>>() 
{
+                                    @Override
+                                    public Tuple2<Vector, Double> map(Row row) 
throws Exception {
+                                        return new Tuple2<>(
+                                                (Vector) 
row.getField(featuresCol),
+                                                (Double) 
row.getField(labelCol));
+                                    }
+                                });
+
+        DataStream<NaiveBayesModelData> naiveBayesModel =
+                input.flatMap(new FlattenFunction())
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Integer, Double, 
Double>, Object>)
+                                        value -> new Tuple3<>(value.f0, 
value.f1, value.f2))
+                        .window(EndOfStreamWindows.get())
+                        .reduce(
+                                (ReduceFunction<Tuple4<Double, Integer, 
Double, Double>>)
+                                        (t0, t1) -> {
+                                            t0.f3 += t1.f3;
+                                            return t0;
+                                        })
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Integer, Double, 
Double>, Object>)
+                                        value -> new Tuple2<>(value.f0, 
value.f1))
+                        .window(EndOfStreamWindows.get())
+                        .aggregate(new ValueMapFunction())
+                        .keyBy(
+                                (KeySelector<
+                                                Tuple4<
+                                                        Double,
+                                                        Integer,
+                                                        Map<Double, Double>,
+                                                        Double>,
+                                                Object>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .aggregate(new MapArrayFunction())
+                        .windowAll(EndOfStreamWindows.get())
+                        .apply(new GenerateModelFunction(smoothing));
+
+        NaiveBayesModel model =
+                new NaiveBayesModel()
+                        .setModelData(NaiveBayesModelData.fromDataStream(tEnv, 
naiveBayesModel));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static NaiveBayes load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Function to convert each column into tuples of label, feature column 
index, feature value,
+     * weight.
+     */
+    private static class FlattenFunction
+            implements FlatMapFunction<
+                    Tuple2<Vector, Double>, Tuple4<Double, Integer, Double, 
Double>> {

Review comment:
       Would it be better to cast `label` to `Integer` instead of `Double`, 
given that the label in `NaiveBayes` should always be an `Integer`?
   
   Note that the label value is used as the grouping key in the program. I am 
not sure whether precision errors could lead to misclassification.
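   
   To illustrate the concern, here is a minimal, self-contained sketch. It is 
not code from this PR: the class `LabelKeyDemo` and its methods are 
hypothetical, and `Math.nextUp(1.0)` only simulates a label that has drifted 
by one ulp. It shows that a `Double` grouping key splits such labels into 
separate groups, while an integer key does not.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical demo, not part of the PR: counts distinct grouping keys for labels. */
final class LabelKeyDemo {
    /** Groups by the boxed Double label, similar in spirit to keying on a Double field. */
    static int distinctDoubleKeys(double[] labels) {
        Map<Double, Integer> counts = new HashMap<>();
        for (double label : labels) {
            counts.merge(label, 1, Integer::sum);
        }
        return counts.size();
    }

    /** Groups by the label rounded to the nearest int before it is used as a key. */
    static int distinctIntKeys(double[] labels) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (double label : labels) {
            counts.merge((int) Math.round(label), 1, Integer::sum);
        }
        return counts.size();
    }

    public static void main(String[] args) {
        // Math.nextUp(1.0) is the smallest double greater than 1.0; it stands in
        // for a label that picked up a tiny precision error upstream.
        double[] labels = {1.0, Math.nextUp(1.0), 2.0};
        System.out.println("Double keys: " + distinctDoubleKeys(labels));
        System.out.println("int keys:    " + distinctIntKeys(labels));
    }
}
```

   Whether rounding or a strict check that rejects non-integral labels is the 
better choice is a design decision for the PR; the sketch only shows why the 
key type matters.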

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/** A Model which classifies data using the model data computed by {@link 
NaiveBayes}. */
+public class NaiveBayesModel
+        implements Model<NaiveBayesModel>, 
NaiveBayesModelParams<NaiveBayesModel> {
+    private static final long serialVersionUID = -4673084154965905629L;
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelTable;
+
+    public NaiveBayesModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String predictionCol = getPredictionCol();
+        final String featuresCol = getFeaturesCol();
+        final String broadcastModelKey = "NaiveBayesModelStream";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
TypeInformation.of(Object.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
predictionCol));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
modelTable).getTableEnvironment();
+        DataStream<NaiveBayesModelData> modelStream =
+                NaiveBayesModelData.toDataStream(tEnv, modelTable);
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(broadcastModelKey, modelStream);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            this.getClass().getSimpleName(),
+                            outputTypeInfo,
+                            new PredictLabelOperator(
+                                    new PredictLabelFunction(featuresCol, 
broadcastModelKey)));
+                };
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+
+        Table outputTable = tEnv.fromDataStream(output);
+
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
modelTable).getTableEnvironment();
+
+        String dataPath = ReadWriteUtils.getDataPath(path);
+        FileSink<NaiveBayesModelData> sink =
+                FileSink.forRowFormat(
+                                new Path(dataPath), new 
NaiveBayesModelData.ModelDataEncoder())
+                        .withRollingPolicy(OnCheckpointRollingPolicy.build())
+                        .withBucketAssigner(new BasePathBucketAssigner<>())
+                        .build();
+        NaiveBayesModelData.toDataStream(tEnv, modelTable).sinkTo(sink);
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static NaiveBayesModel load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        Source<NaiveBayesModelData, ?, ?> source =
+                FileSource.forRecordStreamFormat(
+                                new 
NaiveBayesModelData.ModelDataStreamFormat(),
+                                ReadWriteUtils.getDataPaths(path))
+                        .build();
+        NaiveBayesModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<NaiveBayesModelData> modelData =
+                env.fromSource(source, WatermarkStrategy.noWatermarks(), 
"modelData");
+        model.setModelData(NaiveBayesModelData.fromDataStream(tEnv, 
modelData));
+
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public NaiveBayesModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelTable};
+    }
+
+    private static class PredictLabelOperator
+            extends AbstractUdfStreamOperator<Row, PredictLabelFunction>
+            implements OneInputStreamOperator<Row, Row> {
+        public PredictLabelOperator(PredictLabelFunction userFunction) {
+            super(userFunction);
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) {

Review comment:
       Would it be simpler to remove `PredictLabelFunction` and move its logic 
into `PredictLabelOperator::processElement`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+        implements Estimator<NaiveBayes, NaiveBayesModel>, 
NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public NaiveBayes() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final double smoothing = getSmoothing();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Vector, Double>> input =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<Vector, Double>>() 
{
+                                    @Override
+                                    public Tuple2<Vector, Double> map(Row row) 
throws Exception {
+                                        return new Tuple2<>(
+                                                (Vector) 
row.getField(featuresCol),
+                                                (Double) 
row.getField(labelCol));
+                                    }
+                                });
+
+        DataStream<NaiveBayesModelData> naiveBayesModel =

Review comment:
       nits: the name `naiveBayesModel` may suggest that this is a `Model`, 
while it is actually model data. How about renaming this to `modelData`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.api.core.Stage;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+
+import java.lang.reflect.Method;
+import java.nio.file.Files;
+
+/** Utility methods for testing stages. */
+public class StageTestUtils {
+    /**
+     * Saves a stage to filesystem and reloads it with the static load() 
method a stage must
+     * implement.
+     */
+    public static <T extends Stage<T>> T 
saveAndReload(StreamExecutionEnvironment env, T stage)
+            throws Exception {
+        String tempDir = Files.createTempDirectory("").toString();

Review comment:
       Nice!
   
   Could you help make one more improvement by using 
`org.junit.rules.TemporaryFolder` to create the temporary directory in JUnit 
tests? This is the approach used in most (if not all) Flink tests (see 
`KafkaTestBase` for example). According to the Javadoc of `TemporaryFolder`, 
the folder is deleted after the test method finishes.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+        implements Estimator<NaiveBayes, NaiveBayesModel>, 
NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public NaiveBayes() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final double smoothing = getSmoothing();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Vector, Double>> input =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<Vector, Double>>() 
{
+                                    @Override
+                                    public Tuple2<Vector, Double> map(Row row) 
throws Exception {
+                                        return new Tuple2<>(
+                                                (Vector) 
row.getField(featuresCol),
+                                                (Double) 
row.getField(labelCol));
+                                    }
+                                });
+
+        DataStream<NaiveBayesModelData> naiveBayesModel =
+                input.flatMap(new FlattenFunction())
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Integer, Double, 
Double>, Object>)
+                                        value -> new Tuple3<>(value.f0, 
value.f1, value.f2))
+                        .window(EndOfStreamWindows.get())
+                        .reduce(
+                                (ReduceFunction<Tuple4<Double, Integer, 
Double, Double>>)
+                                        (t0, t1) -> {
+                                            t0.f3 += t1.f3;
+                                            return t0;
+                                        })
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Integer, Double, 
Double>, Object>)
+                                        value -> new Tuple2<>(value.f0, 
value.f1))
+                        .window(EndOfStreamWindows.get())
+                        .aggregate(new ValueMapFunction())
+                        .keyBy(
+                                (KeySelector<
+                                                Tuple4<
+                                                        Double,
+                                                        Integer,
+                                                        Map<Double, Double>,
+                                                        Double>,
+                                                Object>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .aggregate(new MapArrayFunction())
+                        .windowAll(EndOfStreamWindows.get())
+                        .apply(new GenerateModelFunction(smoothing));
+
+        NaiveBayesModel model =
+                new NaiveBayesModel()
+                        .setModelData(NaiveBayesModelData.fromDataStream(tEnv, 
naiveBayesModel));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static NaiveBayes load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Function to convert each column into tuples of label, feature column 
index, feature value,
+     * weight.
+     */
+    private static class FlattenFunction
+            implements FlatMapFunction<
+                    Tuple2<Vector, Double>, Tuple4<Double, Integer, Double, 
Double>> {
+        @Override
+        public void flatMap(
+                Tuple2<Vector, Double> value,
+                Collector<Tuple4<Double, Integer, Double, Double>> collector) {
+            Preconditions.checkNotNull(value.f1);
+            for (int i = 0; i < value.f0.size(); i++) {
+                collector.collect(new Tuple4<>(value.f1, i, value.f0.get(i), 
1.0));
+            }
+        }
+    }
+
+    /**
+     * Function to aggregate feature value and weight into map from records 
with the same label and
+     * feature column index.
+     */
+    private static class ValueMapFunction
+            implements AggregateFunction<
+                    Tuple4<Double, Integer, Double, Double>,
+                    Tuple3<Double, Integer, Map<Double, Double>>,
+                    Tuple4<Double, Integer, Map<Double, Double>, Double>> {
+
+        @Override
+        public Tuple3<Double, Integer, Map<Double, Double>> 
createAccumulator() {
+            return new Tuple3<>(0., -1, new HashMap<>());
+        }
+
+        @Override
+        public Tuple3<Double, Integer, Map<Double, Double>> add(
+                Tuple4<Double, Integer, Double, Double> value,
+                Tuple3<Double, Integer, Map<Double, Double>> acc) {
+            acc.f0 = value.f0;
+            acc.f1 = value.f1;
+            acc.f2.put(value.f2, value.f3);
+            return acc;
+        }
+
+        @Override
+        public Tuple4<Double, Integer, Map<Double, Double>, Double> getResult(
+                Tuple3<Double, Integer, Map<Double, Double>> acc) {
+            double weightSum = acc.f2.values().stream().mapToDouble(f -> f).sum();

Review comment:
       nits: it seems a bit better to use the approach explained in 
https://stackoverflow.com/questions/30125296/how-to-sum-a-list-of-integers-with-java-streams
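   
   A minimal sketch of what this could look like, assuming the suggestion 
refers to the method-reference form (`mapToInt(Integer::intValue)` for 
integers, so `mapToDouble(Double::doubleValue)` here). The class 
`WeightSumDemo` and the sample map are hypothetical and only serve to compare 
the two spellings:

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical demo, not part of the PR: two ways to sum the values of a map. */
final class WeightSumDemo {
    /** The form currently used in the diff: an identity lambda that relies on unboxing. */
    static double sumWithLambda(Map<Double, Double> weights) {
        return weights.values().stream().mapToDouble(f -> f).sum();
    }

    /** The method-reference form, which names the unboxing step explicitly. */
    static double sumWithMethodRef(Map<Double, Double> weights) {
        return weights.values().stream().mapToDouble(Double::doubleValue).sum();
    }

    public static void main(String[] args) {
        // Keys are feature values and values are weights, mirroring acc.f2 in the diff.
        Map<Double, Double> weights = new HashMap<>();
        weights.put(0.0, 1.0);
        weights.put(1.0, 2.0);
        weights.put(1.5, 3.5);
        System.out.println(sumWithLambda(weights));
        System.out.println(sumWithMethodRef(weights));
    }
}
```

   Both forms compute the same sum; the difference is purely one of 
readability.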

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Param of the name of the label column in the input table.

Review comment:
       nits: If we follow the pattern of other param subclasses, the comment 
could be `Interface for the shared labelCol param.`
   
   The param <T> is specified but not explained in this comment. We could just 
remove it for consistency with other params.



