Fanoid commented on code in PR #191:
URL: https://github.com/apache/flink-ml/pull/191#discussion_r1059262975


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java:
##########
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * Base class for LSH model.
+ *
+ * <p>In addition to transforming input feature vectors to multiple hash 
values, it also supports
+ * approximate nearest neighbors search within a dataset regarding a key 
vector and approximate
+ * similarity join between two datasets.
+ *
+ * @param <T> class type of the LSHModel implementation itself.
+ */
+abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, 
LSHModelParams<T> {
+    private static final String MODEL_DATA_BC_KEY = "modelData";
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Stores the corresponding model data class of T. */
+    private final Class<? extends LSHScheme> modelDataClass;
+
+    protected Table modelDataTable;
+
+    public LSHModel(Class<? extends LSHScheme> modelDataClass) {
+        this.modelDataClass = modelDataClass;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public T setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return (T) this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<? extends LSHScheme> modelData =
+                tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?> outputType = 
TypeInformation.of(DenseVector[].class);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
outputType),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        
Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) 
inputList.get(0);
+                            return data.map(
+                                    new 
PredictOutputMapFunction(getInputCol()), outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Approximately finds at most k items from a dataset which have the 
closest distance to a given
+     * item. If the `outputCol` is missing in the given dataset, this method 
transforms the dataset
+     * with the model at first.
+     *
+     * @param dataset The dataset in which to search for nearest neighbors.
+     * @param key The item to search for.
+     * @param k The maximum number of nearest neighbors.
+     * @param distCol The output column storing the distance between each 
neighbor and the key.
+     * @return A dataset containing at most k items closest to the key with a 
column named `distCol`
+     *     appended.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k, 
String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
dataset).getTableEnvironment();
+        Table transformedTable =
+                
(dataset.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataset
+                        : transform(dataset)[0];
+
+        DataStream<? extends LSHScheme> modelData =
+                tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
Types.DOUBLE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
distCol));
+
+        // Fetch items in the same bucket with key's, and calculate their 
distances to key.
+        DataStream<Row> filteredData =
+                BroadcastUtils.withBroadcastStream(
+                        
Collections.singletonList(tEnv.toDataStream(transformedTable)),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) 
inputList.get(0);
+                            return data.flatMap(
+                                    new FilterBySameBucketsFlatMapFunction(
+                                            getInputCol(), getOutputCol(), 
key),
+                                    outputTypeInfo);
+                        });
+        DataStream<Row> partitionedTopKData =

Review Comment:
   I've rewritten this part with `DataStreamUtils.aggregate()`. I think it's more 
concise and efficient.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to