lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r747189344



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/algo/batch/knn/KnnTrainBatchOp.java
##########
@@ -0,0 +1,230 @@
+package org.apache.flink.ml.algo.batch.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistance;
+import org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistanceData;
+import org.apache.flink.ml.algo.batch.knn.distance.FastDistanceMatrixData;
+import org.apache.flink.ml.algo.batch.knn.distance.FastDistanceSparseData;
+import org.apache.flink.ml.algo.batch.knn.distance.FastDistanceVectorData;
+import org.apache.flink.ml.common.BatchOperator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.linalg.DenseVector;
+import org.apache.flink.ml.common.linalg.VectorUtil;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.params.knn.HasKnnDistanceType;
+import org.apache.flink.ml.params.knn.KnnTrainParams;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static 
org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistanceData.pGson;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples. Note that though there is no ``training process`` in KNN, 
we create a ``fake
+ * one`` to use in pipeline model. In this operator, we do some preparation to 
speed up the
+ * inference process.
+ */
+public final class KnnTrainBatchOp extends BatchOperator<KnnTrainBatchOp>

Review comment:
       Could you help explain why we need to have both `KnnClassifier` and 
`KnnTrainBatchOp`? Would it be simpler to merge them into one class?
   
   

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/algo/batch/knn/KnnBatchOpTest.java
##########
@@ -0,0 +1,206 @@
+package org.apache.flink.ml.algo.batch.knn;
+
+import org.apache.flink.api.common.RuntimeExecutionMode;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.iteration.config.IterationOptions;
+import org.apache.flink.ml.api.core.Pipeline;
+import org.apache.flink.ml.api.core.Stage;
+import org.apache.flink.ml.common.BatchOperator;
+import org.apache.flink.ml.common.MLEnvironmentFactory;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class KnnBatchOpTest {
+    private BatchOperator<?> getSourceOp(List<Row> rows) {
+        DataStream<Row> dataStream =
+                MLEnvironmentFactory.getDefault()
+                        .getStreamExecutionEnvironment()
+                        .fromCollection(
+                                rows,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.INT, Types.STRING, 
Types.DOUBLE
+                                        },
+                                        new String[] {"re", "vec", "label"}));
+
+        Table out =
+                MLEnvironmentFactory.getDefault()
+                        .getStreamTableEnvironment()
+                        .fromDataStream(dataStream);
+        return new TableSourceBatchOp(out);
+    }
+
+    private Table getTable(List<Row> rows) {
+        DataStream<Row> dataStream =
+                MLEnvironmentFactory.getDefault()
+                        .getStreamExecutionEnvironment()
+                        .fromCollection(
+                                rows,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.INT, Types.STRING, 
Types.DOUBLE
+                                        },
+                                        new String[] {"re", "vec", "label"}));
+
+        Table out =
+                MLEnvironmentFactory.getDefault()
+                        .getStreamTableEnvironment()
+                        .fromDataStream(dataStream);
+        return out;
+    }
+
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
+    @Test
+    public void testKnnTrainBatchOp() throws Exception {
+
+        StreamExecutionEnvironment.setDefaultLocalParallelism(1);
+        org.apache.flink.streaming.api.environment.StreamExecutionEnvironment 
env =
+                
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
+        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+        Configuration configuration = new Configuration();
+        configuration.set(RestOptions.PORT, 18082);
+        configuration.set(
+                IterationOptions.DATA_CACHE_PATH,
+                "file://" + tempFolder.newFolder().getAbsolutePath());
+        configuration.set(
+                
ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env.getConfig().setGlobalJobParameters(configuration);
+
+        List<Row> rows =
+                Arrays.asList(
+                        Row.of(1, "1 2 3 4", 1.),
+                        Row.of(1, "1 2 3 4.2", 2.),
+                        Row.of(1, "1 2 3 4.3", 3.),
+                        Row.of(1, "1 2 3 4.4", 4.),
+                        Row.of(1, "1 2 3 4.5", 5.),
+                        Row.of(1, "3 2 3 4.6", 6.),
+                        Row.of(1, "1 2 3 4.7", 7.),
+                        Row.of(1, "3 2 3 4.9", 8.));
+
+        BatchOperator source = getSourceOp(rows);
+        BatchOperator<?> knn =
+                new 
KnnTrainBatchOp().setLabelCol("label").setVectorCol("vec").linkFrom(source);
+
+        BatchOperator result =
+                new KnnPredictBatchOp(null)
+                        .setK(2)
+                        .setReservedCols("re", "label")
+                        .setPredictionCol("pred")
+                        .setPredictionDetailCol("detail")
+                        .linkFrom(source, knn);
+
+        MLEnvironmentFactory.getDefault()
+                .getStreamTableEnvironment()
+                .toDataStream(result.getOutput())
+                .addSink(
+                        new SinkFunction<Row>() {
+                            @Override
+                            public void invoke(Row value, Context context) 
throws Exception {
+                                System.out.println("[Output]: " + 
value.toString());
+                            }
+                        });
+        
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().execute();
+    }
+
+    @Test
+    public void testKnnPipeline() throws Exception {
+
+        StreamExecutionEnvironment.setDefaultLocalParallelism(4);
+        org.apache.flink.streaming.api.environment.StreamExecutionEnvironment 
env =
+                
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
+        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+        Configuration configuration = new Configuration();
+        configuration.set(RestOptions.PORT, 18082);
+        configuration.set(
+                IterationOptions.DATA_CACHE_PATH,
+                "file://" + tempFolder.newFolder().getAbsolutePath());
+        configuration.set(
+                
ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env.getConfig().setGlobalJobParameters(configuration);
+
+        List<Row> rows =
+                Arrays.asList(
+                        Row.of(1, "1 2 3 4", 1.),
+                        Row.of(1, "1 2 3 4.2", 2.),
+                        Row.of(1, "1 2 3 4.3", 3.),
+                        Row.of(1, "1 2 3 4.4", 4.),
+                        Row.of(1, "1 2 3 4.5", 5.),
+                        Row.of(1, "3 2 3 4.6", 6.),
+                        Row.of(1, "1 2 3 4.7", 7.),
+                        Row.of(1, "3 2 3 4.9", 8.));
+
+        Table source = getTable(rows);
+        KnnClassifier knn =
+                new KnnClassifier()
+                        .setLabelCol("label")
+                        .setVectorCol("vec")
+                        .setReservedCols("label")
+                        .setK(2)
+                        .setPredictionCol("pred")
+                        .setPredictionDetailCol("detail");
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(knn);

Review comment:
       Since the test here focuses on KNN instead of Pipeline, maybe it is 
better to use fit/transform of knn directly without using Pipeline?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/algo/batch/knn/KnnPredictBatchOp.java
##########
@@ -0,0 +1,345 @@
+package org.apache.flink.ml.algo.batch.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistance;
+import org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistanceData;
+import org.apache.flink.ml.algo.batch.knn.distance.FastDistanceMatrixData;
+import org.apache.flink.ml.algo.batch.knn.distance.FastDistanceSparseData;
+import org.apache.flink.ml.algo.batch.knn.distance.FastDistanceVectorData;
+import org.apache.flink.ml.common.BatchOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.linalg.DenseVector;
+import org.apache.flink.ml.common.linalg.Vector;
+import org.apache.flink.ml.common.linalg.VectorUtil;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.params.knn.HasKnnDistanceType.DistanceType;
+import org.apache.flink.ml.params.knn.KnnPredictParams;
+import org.apache.flink.ml.params.knn.KnnTrainParams;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+import static 
org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistanceData.pGson;
+
+/**
+ * batch map batch operator. you can inherit this class to develop model-based 
prediction batch
+ * operator.
+ */
+public class KnnPredictBatchOp extends BatchOperator<KnnPredictBatchOp>

Review comment:
       Could you explain why we need both `KnnPredictBatchOp` and 
`KnnClassificationModel`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/params/knn/KnnClassifierParams.java
##########
@@ -0,0 +1,4 @@
+package org.apache.flink.ml.params.knn;
+
+/** knn train parameters. */
+public interface KnnClassifierParams<T> extends KnnTrainParams<T>, 
KnnPredictParams<T> {}

Review comment:
       It is probably better to make the class name intuitive and consistent 
with the abstract/interface classes we have introduced in the Flink ML infra.
   
   What is the difference between `Predict` and `Model`? If they are the same, 
how about naming it `KnnModelParams`?
   
   What is the difference between `Train` and `Estimator`? If they are the 
same, maybe we can either name it `KnnParams` (same as the Spark convention) 
or name it `KnnEstimatorParams`. If we do the latter, we probably want to 
make it a practice for all Estimator subclasses.
   
   And what is the difference between `Classifier` and `Estimator`? Note that 
Spark has an `abstract class Classifier` that provides extra semantics on top 
of `Estimator`, so it makes sense for some algorithm class names to have 
`Classifier` in their names. But we don't have such an abstract class in Flink yet.
   
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironmentFactory.java
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common;
+
+import org.apache.flink.util.Preconditions;
+
+import java.util.HashMap;
+
+/**
+ * Factory to get the MLEnvironment using a MLEnvironmentId.
+ *
+ * <p>The following code snippet shows how to interact with 
MLEnvironmentFactory.
+ *
+ * <pre>{@code
+ * long mlEnvId = MLEnvironmentFactory.getNewMLEnvironmentId();
+ * MLEnvironment mlEnv = MLEnvironmentFactory.get(mlEnvId);
+ * }</pre>
+ */
+public class MLEnvironmentFactory {

Review comment:
       Could you explain why we need `MLEnvironmentFactory`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/params/knn/HasKnnDistanceType.java
##########
@@ -0,0 +1,56 @@
+package org.apache.flink.ml.params.knn;
+
+import org.apache.flink.ml.algo.batch.knn.distance.BaseFastDistance;
+import org.apache.flink.ml.algo.batch.knn.distance.CosineDistance;
+import org.apache.flink.ml.algo.batch.knn.distance.EuclideanDistance;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+import org.apache.flink.ml.params.ParamUtil;
+
+import java.io.Serializable;
+
+/** Params: Distance type for clustering, support EUCLIDEAN and COSINE. */
+public interface HasKnnDistanceType<T> extends WithParams<T> {

Review comment:
       Can we use the same convention as Spark, i.e. a String-typed param 
instead of an enum-typed param? IMHO the Spark style is simpler and more readable. 
For example, we could remove the `enum DistanceType` entirely, and just keep 
the actual `DistanceMeasure` subclasses and the string-typed param.
   
   The logic for converting the string-typed param to the actual distance class 
could be handled in a static method. In Spark it is done in 
`DistanceMeasure::decodeFromString`. I did something similar in 
https://github.com/apache/flink-ml/pull/27.
   
   If we agree to do this, feel free to keep the code as is and update it after, 
e.g., KMeans and its infra have been committed.
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java
##########
@@ -0,0 +1,577 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.common.linalg;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+    /**
+     * Row dimension.
+     *
+     * <p>Package private to allow access from {@link MatVecOp} and {@link 
BLAS}.
+     */
+    int m;
+
+    /**
+     * Column dimension.
+     *
+     * <p>Package private to allow access from {@link MatVecOp} and {@link 
BLAS}.
+     */
+    int n;
+
+    /**
+     * Array for internal storage of elements.
+     *
+     * <p>Package private to allow access from {@link MatVecOp} and {@link 
BLAS}.
+     *
+     * <p>The matrix data is stored in column major format internally.
+     */
+    double[] data;
+
+    /**
+     * Construct an m-by-n matrix of zeros.
+     *
+     * @param m Number of rows.
+     * @param n Number of columns.
+     */
+    public DenseMatrix(int m, int n) {
+        this(m, n, new double[m * n], false);
+    }
+
+    /**
+     * Construct a matrix from a 1-D array. The data in the array should 
organize in column major.
+     *
+     * @param m Number of rows.
+     * @param n Number of cols.
+     * @param data One-dimensional array of doubles.
+     */
+    public DenseMatrix(int m, int n, double[] data) {
+        this(m, n, data, false);
+    }
+
+    /**
+     * Construct a matrix from a 1-D array. The data in the array is organized 
in column major or in
+     * row major, which is specified by parameter 'inRowMajor'
+     *
+     * @param m Number of rows.
+     * @param n Number of cols.
+     * @param data One-dimensional array of doubles.
+     * @param inRowMajor Whether the matrix in 'data' is in row major format.
+     */
+    public DenseMatrix(int m, int n, double[] data, boolean inRowMajor) {
+        assert (data.length == m * n);
+        this.m = m;
+        this.n = n;
+        if (inRowMajor) {
+            toColumnMajor(m, n, data);
+        }
+        this.data = data;
+    }
+
+    /**
+     * Construct a matrix from a 2-D array.
+     *
+     * @param data Two-dimensional array of doubles.
+     * @throws IllegalArgumentException All rows must have the same size
+     */
+    public DenseMatrix(double[][] data) {
+        this.m = data.length;
+        if (this.m == 0) {
+            this.n = 0;
+            this.data = new double[0];
+            return;
+        }
+        this.n = data[0].length;
+        for (int i = 0; i < m; i++) {
+            if (data[i].length != n) {
+                throw new IllegalArgumentException("All rows must have the 
same size.");
+            }
+        }
+        this.data = new double[m * n];
+        for (int i = 0; i < m; i++) {
+            for (int j = 0; j < n; j++) {
+                this.set(i, j, data[i][j]);
+            }
+        }
+    }
+
+    /**
+     * Create an identity matrix.
+     *
+     * @param n the dimension of the eye matrix.
+     * @return an identity matrix.
+     */
+    public static DenseMatrix eye(int n) {
+        return eye(n, n);
+    }
+
+    /**
+     * Create a m * n identity matrix.
+     *
+     * @param m the row dimension.
+     * @param n the column dimension.e
+     * @return the m * n identity matrix.
+     */
+    public static DenseMatrix eye(int m, int n) {
+        DenseMatrix mat = new DenseMatrix(m, n);
+        int k = Math.min(m, n);
+        for (int i = 0; i < k; i++) {
+            mat.data[i * m + i] = 1.0;
+        }
+        return mat;
+    }
+
+    /**
+     * Create a zero matrix.
+     *
+     * @param m the row dimension.
+     * @param n the column dimension.
+     * @return a m * n zero matrix.
+     */
+    public static DenseMatrix zeros(int m, int n) {
+        return new DenseMatrix(m, n);
+    }
+
+    /**
+     * Create a matrix with all elements set to 1.
+     *
+     * @param m the row dimension
+     * @param n the column dimension
+     * @return the m * n matrix with all elements set to 1.
+     */
+    public static DenseMatrix ones(int m, int n) {
+        DenseMatrix mat = new DenseMatrix(m, n);
+        Arrays.fill(mat.data, 1.);
+        return mat;
+    }
+
+    /**
+     * Create a random matrix.
+     *
+     * @param m the row dimension
+     * @param n the column dimension.
+     * @return a m * n random matrix.
+     */
+    public static DenseMatrix rand(int m, int n) {
+        DenseMatrix mat = new DenseMatrix(m, n);
+        for (int i = 0; i < mat.data.length; i++) {
+            mat.data[i] = Math.random();
+        }
+        return mat;
+    }
+
+    /**
+     * Create a random symmetric matrix.
+     *
+     * @param n the dimension of the symmetric matrix.
+     * @return a n * n random symmetric matrix.
+     */
+    public static DenseMatrix randSymmetric(int n) {
+        DenseMatrix mat = new DenseMatrix(n, n);
+        for (int i = 0; i < n; i++) {
+            for (int j = i; j < n; j++) {
+                double r = Math.random();
+                mat.set(i, j, r);
+                if (i != j) {
+                    mat.set(j, i, r);
+                }
+            }
+        }
+        return mat;
+    }
+
+    /**
+     * Get a single element.
+     *
+     * @param i Row index.
+     * @param j Column index.
+     * @return matA(i, j)
+     * @throws ArrayIndexOutOfBoundsException
+     */
+    public double get(int i, int j) {
+        return data[j * m + i];
+    }
+
+    /**
+     * Get the data array of this matrix.
+     *
+     * @return the data array of this matrix.
+     */
+    public double[] getData() {
+        return this.data;
+    }
+
+    /**
+     * Get all the matrix data, returned as a 2-D array.
+     *
+     * @return all matrix data, returned as a 2-D array.
+     */
+    public double[][] getArrayCopy2D() {
+        double[][] arrayData = new double[m][n];
+        for (int i = 0; i < m; i++) {
+            for (int j = 0; j < n; j++) {
+                arrayData[i][j] = this.get(i, j);
+            }
+        }
+        return arrayData;
+    }
+
+    /**
+     * Get all matrix data, returned as a 1-D array.
+     *
+     * @param inRowMajor Whether to return data in row major.
+     * @return all matrix data, returned as a 1-D array.
+     */
+    public double[] getArrayCopy1D(boolean inRowMajor) {
+        if (inRowMajor) {
+            double[] arrayData = new double[m * n];
+            for (int i = 0; i < m; i++) {
+                for (int j = 0; j < n; j++) {
+                    arrayData[i * n + j] = this.get(i, j);
+                }
+            }
+            return arrayData;
+        } else {
+            return this.data.clone();
+        }
+    }
+
+    /**
+     * Get one row.
+     *
+     * @param row the row index.
+     * @return the row with the given index.
+     */
+    public double[] getRow(int row) {
+        assert (row >= 0 && row < m) : "Invalid row index.";
+        double[] r = new double[n];
+        for (int i = 0; i < n; i++) {
+            r[i] = this.get(row, i);
+        }
+        return r;
+    }
+
+    /**
+     * Get one column.
+     *
+     * @param col the column index.
+     * @return the column with the given index.
+     */
+    public double[] getColumn(int col) {
+        assert (col >= 0 && col < n) : "Invalid column index.";
+        double[] columnData = new double[m];
+        System.arraycopy(this.data, col * m, columnData, 0, m);
+        return columnData;
+    }
+
+    /** Clone the Matrix object. */
+    @Override
+    public DenseMatrix clone() {
+        return new DenseMatrix(this.m, this.n, this.data.clone(), false);
+    }
+
+    /**
+     * Create a new matrix by selecting some of the rows.
+     *
+     * @param rows the array of row indexes to select.
+     * @return a new matrix by selecting some of the rows.
+     */
+    public DenseMatrix selectRows(int[] rows) {
+        DenseMatrix sub = new DenseMatrix(rows.length, this.n);
+        for (int i = 0; i < rows.length; i++) {
+            for (int j = 0; j < this.n; j++) {
+                sub.set(i, j, this.get(rows[i], j));
+            }
+        }
+        return sub;
+    }
+
+    /**
+     * Get sub matrix.
+     *
+     * @param m0 the starting row index (inclusive)
+     * @param m1 the ending row index (exclusive)
+     * @param n0 the starting column index (inclusive)
+     * @param n1 the ending column index (exclusive)
+     * @return the specified sub matrix.
+     */
+    public DenseMatrix getSubMatrix(int m0, int m1, int n0, int n1) {
+        assert (m0 >= 0 && m1 <= m) && (n0 >= 0 && n1 <= n) : "Invalid index 
range.";
+        DenseMatrix sub = new DenseMatrix(m1 - m0, n1 - n0);
+        for (int i = 0; i < sub.m; i++) {
+            for (int j = 0; j < sub.n; j++) {
+                sub.set(i, j, this.get(m0 + i, n0 + j));
+            }
+        }
+        return sub;
+    }
+
+    /**
+     * Set part of the matrix values from the values of another matrix.
+     *
+     * @param sub the matrix whose element values will be assigned to the sub 
matrix of this matrix.
+     * @param m0 the starting row index (inclusive)
+     * @param m1 the ending row index (exclusive)
+     * @param n0 the starting column index (inclusive)
+     * @param n1 the ending column index (exclusive)
+     */
+    public void setSubMatrix(DenseMatrix sub, int m0, int m1, int n0, int n1) {
+        assert (m0 >= 0 && m1 <= m) && (n0 >= 0 && n1 <= n) : "Invalid index 
range.";
+        for (int i = 0; i < sub.m; i++) {
+            for (int j = 0; j < sub.n; j++) {
+                this.set(m0 + i, n0 + j, sub.get(i, j));
+            }
+        }
+    }
+
+    /**
+     * Set a single element.
+     *
+     * @param i Row index.
+     * @param j Column index.
+     * @param s A(i,j).
+     * @throws ArrayIndexOutOfBoundsException
+     */
+    public void set(int i, int j, double s) {
+        data[j * m + i] = s;
+    }
+
+    /**
+     * Add the given value to a single element.
+     *
+     * @param i Row index.
+     * @param j Column index.
+     * @param s A(i,j).
+     * @throws ArrayIndexOutOfBoundsException
+     */
+    public void add(int i, int j, double s) {
+        data[j * m + i] += s;
+    }
+
+    /**
+     * Check whether the matrix is square matrix.
+     *
+     * @return <code>true</code> if this matrix is a square matrix, 
<code>false</code> otherwise.
+     */
+    public boolean isSquare() {
+        return m == n;
+    }
+
+    /**
+     * Check whether the matrix is symmetric matrix.
+     *
+     * @return <code>true</code> if this matrix is a symmetric matrix, 
<code>false</code> otherwise.
+     */
+    public boolean isSymmetric() {
+        if (m != n) {
+            return false;
+        }
+        for (int i = 0; i < n; i++) {
+            for (int j = i + 1; j < n; j++) {
+                if (this.get(i, j) != this.get(j, i)) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Get the number of rows.
+     *
+     * @return the number of rows.
+     */
+    public int numRows() {
+        return m;
+    }
+
+    /**
+     * Get the number of columns.
+     *
+     * @return the number of columns.
+     */
+    public int numCols() {
+        return n;
+    }
+
+    /** Sum of all elements of the matrix. */
+    public double sum() {
+        double s = 0.;
+        for (int i = 0; i < this.data.length; i++) {
+            s += this.data[i];
+        }
+        return s;
+    }
+
+    /** Scale the vector by value "v" and create a new matrix to store the 
result. */
+    public DenseMatrix scale(double v) {
+        DenseMatrix r = this.clone();
+        BLAS.scal(v, r);
+        return r;
+    }
+
+    /** Scale the matrix by value "v". */
+    public void scaleEqual(double v) {
+        BLAS.scal(v, this);
+    }
+
+    /** Create a new matrix by plussing another matrix. */
+    public DenseMatrix plus(DenseMatrix mat) {
+        DenseMatrix r = this.clone();
+        BLAS.axpy(1.0, mat, r);
+        return r;
+    }
+
+    /** Create a new matrix by plussing a constant. */
+    public DenseMatrix plus(double alpha) {
+        DenseMatrix r = this.clone();
+        for (int i = 0; i < r.data.length; i++) {
+            r.data[i] += alpha;
+        }
+        return r;
+    }
+
+    /** Plus with another matrix. */
+    public void plusEquals(DenseMatrix mat) {

Review comment:
       We will need to discuss the API of these Flink ML infra classes. There 
are in general two styles. The approach taken in this PR adds quite a few 
methods here. Another approach is to keep only the very basic operations on the 
DenseMatrix API and move computation methods to a dedicated class (e.g. BLAS). 
This is pretty much what Spark does.
   
   It is probably good to start simple and keep only the basic API needed by 
the algorithms we planned for FFA. We can discuss offline on this as well.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to