zhipeng93 commented on code in PR #191: URL: https://github.com/apache/flink-ml/pull/191#discussion_r1053903196
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java: ########## @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Base class for estimators which implement LSH (Locality-sensitive hashing) algorithms. + * + * <p>The basic idea of LSH algorithms is to use to a family of hash functions to map data samples + * to buckets, where closer samples are expected to be in same buckets with higher probabilities, + * and vice versa. 
AND-amplification and OR-amplification are utilized to increase the recall and + * precision when searching close samples. + * + * <p>An LSH algorithm is specified by its mapping function and corresponding distance metric (see + * {@link LSHScheme}). + * + * <p>See: <a + * href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">Locality-sensitive_hashing</a>. + * + * @param <E> class type of the Estimator implementation itself. + * @param <M> class type of the Model this Estimator produces. + */ +abstract class LSH<E extends Estimator<E, M>, M extends LSHModel<M>> + implements Estimator<E, M>, LSHParams<E> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public LSH() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public M fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Integer> inputDim = getVectorSize(tEnv.toDataStream(inputs[0]), getInputCol()); + return createModel(inputDim, tEnv); + } + + protected abstract M createModel(DataStream<Integer> inputDim, StreamTableEnvironment tEnv); + + private static DataStream<Integer> getVectorSize(DataStream<Row> input, String vectorCol) { + DataStream<Integer> vecSizeDataStream = Review Comment: nit: `vecSizeDataStream` --> `vectorSizes` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHScheme.java: ########## @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; + +/** + * Interface for an LSH scheme. An LSH scheme should implement how to map a feature vector to + * multiple hash vectors, and how to calculate corresponding distance between two feature vectors. + */ +interface LSHScheme { Review Comment: How about renaming it to `LSHModelData` and making it an abstract class, since the different LSH model data classes are in fact implementations of this interface? That way, the class hierarchy would be more consistent with the LSH estimators and models. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java: ########## @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.lsh.MinHashLSH; +import org.apache.flink.ml.feature.lsh.MinHashLSHModel; +import org.apache.flink.ml.feature.lsh.MinHashLSHModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.table.api.Expressions.$; + +/** Tests {@link MinHashLSH} and {@link MinHashLSHModel}. */ +public class MinHashLSHTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + /** + * Default case for most tests. 
+ * + * @return a tuple including the estimator, input data table, and output rows. + */ + private Tuple3<MinHashLSH, Table, List<Row>> getDefaultCase() { + MinHashLSH lsh = + new MinHashLSH() + .setInputCol("vec") + .setOutputCol("hashes") + .setSeed(2022L) + .setNumHashTables(5) + .setNumHashFunctionsPerTable(3); + + List<Row> inputRows = + Arrays.asList( + Row.of( + 0, + Vectors.sparse(6, new int[] {0, 1, 2}, new double[] {1., 1., 1.})), + Row.of( + 1, + Vectors.sparse(6, new int[] {2, 3, 4}, new double[] {1., 1., 1.})), + Row.of( + 2, + Vectors.sparse(6, new int[] {0, 2, 4}, new double[] {1., 1., 1.}))); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(SparseVector.class)) + .build(); + DataStream<Row> dataStream = env.fromCollection(inputRows); + Table inputTable = tEnv.fromDataStream(dataStream, schema).as("id", "vec"); + + List<Row> outputRows = + convertToOutputFormat( + Arrays.asList( + new double[][] { + {1.73046954E8, 1.57275425E8, 6.90717571E8}, + {5.02301169E8, 7.967141E8, 4.06089319E8}, + {2.83652171E8, 1.97714719E8, 6.04731316E8}, + {5.2181506E8, 6.36933726E8, 6.13894128E8}, + {3.04301769E8, 1.113672955E9, 6.1388711E8} + }, + new double[][] { + {1.73046954E8, 1.57275425E8, 6.7798584E7}, + {6.38582806E8, 1.78703694E8, 4.06089319E8}, + {6.232638E8, 9.28867E7, 9.92010642E8}, + {2.461064E8, 1.12787481E8, 1.92180297E8}, + {2.38162496E8, 1.552933319E9, 2.77995137E8} + }, + new double[][] { + {1.73046954E8, 1.57275425E8, 6.90717571E8}, + {1.453197722E9, 7.967141E8, 4.06089319E8}, + {6.232638E8, 1.97714719E8, 6.04731316E8}, + {2.461064E8, 1.12787481E8, 1.92180297E8}, + {1.224130231E9, 1.113672955E9, 2.77995137E8} + })); + + return Tuple3.of(lsh, inputTable, outputRows); + } + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = 
StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + } + + /** + * Convert a list of 2d double arrays to a list of rows with each of which containing a + * DenseVector array. + */ + private static List<Row> convertToOutputFormat(List<double[][]> arrays) { + return arrays.stream() + .map( + array -> { + DenseVector[] denseVectors = + Arrays.stream(array) + .map(Vectors::dense) + .toArray(DenseVector[]::new); + return Row.of((Object) denseVectors); + }) + .collect(Collectors.toList()); + } + + private static class DenseVectorComparator implements Comparator<DenseVector> { Review Comment: Please check out `org.apache.flink.ml.util.TestUtils.compare`; it might be reusable here instead of defining a new comparator. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java: ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.util.Preconditions; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Random; + +/** + * Model data of {@link MinHashLSHModel}. + * + * <p>This class also provides classes to save/load model data. 
+ */ +public class MinHashLSHModelData implements LSHScheme { + + // A large prime smaller than sqrt(2^63 − 1) + private static final int HASH_PRIME = 2038074743; + + public int numHashTables; + public int numHashFunctionsPerTable; + public int[] randCoeffA; + public int[] randCoeffB; + + public MinHashLSHModelData() {} + + public MinHashLSHModelData( + int numHashTables, int numHashFunctionsPerTable, int[] randCoeffA, int[] randCoeffB) { + this.numHashTables = numHashTables; + this.numHashFunctionsPerTable = numHashFunctionsPerTable; + this.randCoeffA = randCoeffA; + this.randCoeffB = randCoeffB; + } + + public static MinHashLSHModelData generateModelData( + int numHashTables, int numHashFunctionsPerTable, int dim, long seed) { + Preconditions.checkArgument( + dim <= HASH_PRIME, + "The input vector dimension %d exceeds the threshold %s.", + dim, + HASH_PRIME); + + Random random = new Random(seed); + int numHashFunctions = numHashTables * numHashFunctionsPerTable; + int[] randCoeffA = new int[numHashFunctions]; + int[] randCoeffB = new int[numHashFunctions]; + for (int i = 0; i < numHashFunctions; i += 1) { + randCoeffA[i] = 1 + random.nextInt(HASH_PRIME - 1); + randCoeffB[i] = random.nextInt(HASH_PRIME - 1); + } + return new MinHashLSHModelData( + numHashTables, numHashFunctionsPerTable, randCoeffA, randCoeffB); + } + + static class ModelDataDecoder extends SimpleStreamFormat<MinHashLSHModelData> { + @Override + public Reader<MinHashLSHModelData> createReader( + Configuration configuration, FSDataInputStream fsDataInputStream) + throws IOException { + return new Reader<MinHashLSHModelData>() { + @Override + public MinHashLSHModelData read() throws IOException { + try { + DataInputViewStreamWrapper source = + new DataInputViewStreamWrapper(fsDataInputStream); + int numHashTables = IntSerializer.INSTANCE.deserialize(source); + int numHashFunctionsPerTable = IntSerializer.INSTANCE.deserialize(source); + int[] randCoeffA = 
IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + int[] randCoeffB = IntPrimitiveArraySerializer.INSTANCE.deserialize(source); + return new MinHashLSHModelData( + numHashTables, numHashFunctionsPerTable, randCoeffA, randCoeffB); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + fsDataInputStream.close(); + } + }; + } + + @Override + public TypeInformation<MinHashLSHModelData> getProducedType() { + return TypeInformation.of(MinHashLSHModelData.class); + } + } + + @Override + public DenseVector[] hashFunction(Vector vec) { + int[] indices = vec.toSparse().indices; + Preconditions.checkArgument(indices.length > 0, "Must have at least 1 non zero entry."); + double[][] hashValues = new double[numHashTables][numHashFunctionsPerTable]; + for (int i = 0; i < numHashTables; i += 1) { + for (int j = 0; j < numHashFunctionsPerTable; j += 1) { + // for each hash function, the hash value is computed by + // min(((1 + index) * randCoefficientA + randCoefficientB) % HASH_PRIME) Review Comment: nit: Let's rename the class variables `randCoeffA` and `randCoeffB` to `randCoefficientA` and `randCoefficientB`. These two names seem more meaningful. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java: ########## @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.lsh.MinHashLSH; +import org.apache.flink.ml.feature.lsh.MinHashLSHModel; +import org.apache.flink.ml.feature.lsh.MinHashLSHModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +import static 
org.apache.flink.table.api.Expressions.$; + +/** Tests {@link MinHashLSH} and {@link MinHashLSHModel}. */ +public class MinHashLSHTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + /** + * Default case for most tests. + * + * @return a tuple including the estimator, input data table, and output rows. + */ + private Tuple3<MinHashLSH, Table, List<Row>> getDefaultCase() { Review Comment: nit: We can remove this function and declare the returned values as class variables. ########## docs/content/docs/operators/feature/minhashlsh.md: ########## @@ -0,0 +1,276 @@ +--- +title: "MinHash LSH" +weight: 1 +type: docs +aliases: +- /operators/feature/minhashlsh.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## MinHash LSH + +MinHash LSH is a Locality Sensitive Hashing (LSH) scheme for Jaccard distance metric. +The input features are sets of natural numbers represented as non-zero indices of vectors, +either dense vectors or sparse vectors. Typically, sparse vectors are more efficient. 
+ +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:-------|:----------|:-----------------------| +| inputCol | Vector | `"input"` | Features to be mapped. | + +### Output Columns + +| Param name | Type | Default | Description | +|:-----------|:--------------|:-----------|:-------------| +| outputCol | DenseVector[] | `"output"` | Hash values. | + +### Parameters + +| Key | Default | Type | Required | Description | +|-------------------------|-----------------------------------------------------------|---------|----------|--------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| seed | `"org.apache.flink.ml.feature.lsh.MinHashLSH".hashCode()` | Long | no | The random seed. | +| numHashTables | `1` | Integer | no | Default number of hash tables, for OR-amplification. | +| numHashFunctionPerTable | `1` | Integer | no | Default number of hash functions per table, for AND-amplification. 
| + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.ml.feature.lsh.MinHashLSH; +import org.apache.flink.ml.feature.lsh.MinHashLSHModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.Arrays; +import java.util.List; + +import static org.apache.flink.table.api.Expressions.$; + +/** + * Simple program that trains a MinHashLSH model and uses it for approximate nearest neighbors and + * similarity join. + */ +public class MinHashLSHExample { + public static void main(String[] args) throws Exception { + + // create a new StreamExecutionEnvironment Review Comment: Let's capitalize the first letter of the comments in the Java/Python examples and write them in the third person, e.g. `// Creates a new StreamExecutionEnvironment`. The same applies to the other docs. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java: ########## @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import 
java.util.PriorityQueue; + +/** + * Base class for LSH model. + * + * <p>In addition to transforming input feature vectors to multiple hash values, it also supports + * approximate nearest neighbors search within a dataset regarding a key vector and approximate + * similarity join between two datasets. + * + * @param <T> class type of the LSHModel implementation itself. + */ +abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHModelParams<T> { + private static final String MODEL_DATA_BC_KEY = "modelData"; + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + /** Stores the corresponding model data class of T. */ + private final Class<? extends LSHScheme> modelDataClass; + + protected Table modelDataTable; + + public LSHModel(Class<? extends LSHScheme> modelDataClass) { + this.modelDataClass = modelDataClass; + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public T setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return (T) this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<? 
extends LSHScheme> modelData = + tEnv.toDataStream(modelDataTable, modelDataClass); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + TypeInformation<?> outputType = TypeInformation.of(DenseVector[].class); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(tEnv.toDataStream(inputs[0])), + Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), + inputList -> { + //noinspection unchecked + DataStream<Row> data = (DataStream<Row>) inputList.get(0); + return data.map( + new PredictOutputMapFunction(getInputCol()), outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** + * Approximately finds at most k items from a dataset which have the closest distance to a given + * item . If the `outputCol` is missing in the given dataset, this method transforms the dataset + * with the model at first. + * + * @param dataset The dataset in which to to search for nearest neighbors. + * @param key The item to search for. + * @param k The maximum number of nearest neighbors. + * @param distCol The output column storing the distance between each neighbor and the key. + * @return A dataset containing at most k items closest to the key with a column named `distCol` + * appended. + */ + public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment(); + Table transformedTable = + (dataset.getResolvedSchema().getColumnNames().contains(getOutputCol())) + ? dataset + : transform(dataset)[0]; + + DataStream<? 
extends LSHScheme> modelData = + tEnv.toDataStream(modelDataTable, modelDataClass); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol)); + + // Fetch items in the same bucket with key's, and calculate their distances to key. + DataStream<Row> filteredData = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(tEnv.toDataStream(transformedTable)), + Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), + inputList -> { + //noinspection unchecked + DataStream<Row> data = (DataStream<Row>) inputList.get(0); + return data.flatMap( + new FilterBySameBucketsFlatMapFunction( + getInputCol(), getOutputCol(), key), + outputTypeInfo); + }); + DataStream<Row> partitionedTopKData = + DataStreamUtils.mapPartition( + filteredData, new TopKMapPartitionFunction(distCol, k)); + DataStream<Row> topKData = + DataStreamUtils.mapPartition( + partitionedTopKData, new TopKMapPartitionFunction(distCol, k)); + topKData.getTransformation().setOutputType(outputTypeInfo); + topKData.getTransformation().setParallelism(1); + return tEnv.fromDataStream(topKData); + } + + /** + * An overloaded version of `approxNearestNeighbors` with "distCol" as default value of + * `distCol`. + */ + public Table approxNearestNeighbors(Table dataset, Vector key, int k) { + return approxNearestNeighbors(dataset, key, k, "distCol"); + } + + /** + * Joins two datasets to approximately find all pairs of rows whose distance are smaller than or + * equal to the threshold. If the `outputCol` is missing in either dataset, this method + * transforms the dataset at first. + * + * @param datasetA One dataset. + * @param datasetB The other dataset. + * @param threshold The distance threshold. + * @param idCol A column in the two datasets to identify each row. 
+ * @param distCol The output column storing the distance between each pair of rows. + * @return A joined dataset containing pairs of rows. The original rows are in columns + * "datasetA" and "datasetB", and a column "distCol" is added to show the distance between + * each pair. + */ + public Table approxSimilarityJoin( + Table datasetA, Table datasetB, double threshold, String idCol, String distCol) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) datasetA).getTableEnvironment(); + + DataStream<Row> explodedA = preprocessData(datasetA, idCol); + DataStream<Row> explodedB = preprocessData(datasetB, idCol); + + DataStream<? extends LSHScheme> modelData = + tEnv.toDataStream(modelDataTable, modelDataClass); + DataStream<Row> sameBucketPairs = + explodedA + .join(explodedB) + .where(new IndexHashValueKeySelector()) + .equalTo(new IndexHashValueKeySelector()) + .window(EndOfStreamWindows.get()) + .apply( + (r0, r1) -> + Row.of( + r0.getField(0), + r1.getField(0), + r0.getField(1), + r1.getField(1))); + DataStream<Row> distinctSameBucketPairs = + DataStreamUtils.reduce( + sameBucketPairs.keyBy( + new KeySelector<Row, Tuple2<Integer, Integer>>() { + @Override + public Tuple2<Integer, Integer> getKey(Row r) { + return Tuple2.of(r.getFieldAs(0), r.getFieldAs(1)); + } + }), + (r0, r1) -> r0); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(datasetA.getResolvedSchema()); + TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol); + DataStream<Row> pairsWithDists = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(distinctSameBucketPairs), + Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), + inputList -> { + DataStream<Row> data = (DataStream<Row>) inputList.get(0); + return data.flatMap( + new FilterByDistanceFlatMapFunction(threshold), + new RowTypeInfo( + new TypeInformation[] { + idColType, idColType, Types.DOUBLE + }, + new String[] {"datasetA.id", "datasetB.id", distCol})); + }); + return 
tEnv.fromDataStream(pairsWithDists); + } + + /** + * An overloaded version of `approxNearestNeighbors` with "distCol" as default value of + * `distCol`. + */ + public Table approxSimilarityJoin( + Table datasetA, Table datasetB, double threshold, String idCol) { + return approxSimilarityJoin(datasetA, datasetB, threshold, idCol, "distCol"); + } + + private DataStream<Row> preprocessData(Table dataTable, String idCol) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment(); + + dataTable = + (dataTable.getResolvedSchema().getColumnNames().contains(getOutputCol())) + ? dataTable + : transform(dataTable)[0]; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema()); + TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol); + final String indexCol = "index"; + final String hashValueCol = "hashValue"; + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] { + idColType, + TypeInformation.of(Vector.class), + Types.INT, + TypeInformation.of(DenseVector.class) + }, + new String[] {idCol, getInputCol(), indexCol, hashValueCol}); + + return tEnv.toDataStream(dataTable) + .flatMap( + new ExplodeHashValuesFlatMapFunction(idCol, getInputCol(), getOutputCol()), + outputTypeInfo); + } + + private static class PredictOutputMapFunction extends RichMapFunction<Row, Row> { + private final String inputCol; + + private LSHScheme modelData; + + public PredictOutputMapFunction(String inputCol) { + this.inputCol = inputCol; + } + + @Override + public Row map(Row value) throws Exception { + if (null == modelData) { + modelData = + (LSHScheme) + getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0); + } + Vector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol)); + return Row.join(value, Row.of((Object) hashValues)); + } + } + + private static class FilterBySameBucketsFlatMapFunction extends RichFlatMapFunction<Row, Row> { Review Comment: Is 
`FilterByBucketFunction` a simpler name? Same for other XXXMapFunction, XXXFlatMapFunction, etc. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHParams.java: ########## @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.ml.common.param.HasInputCol; +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.common.param.HasSeed; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params for {@link LSH}. + * + * @param <T> The class type of this instance. + */ +public interface LSHParams<T> extends HasInputCol<T>, HasOutputCol<T>, HasSeed<T> { + Param<Integer> NUM_HASH_TABLES = Review Comment: Let's add some guidelines in the Javadoc to tell users how to set this parameter. Are `NUM_HASH_TABLES` and `NUM_HASH_FUNCTIONS_PER_TABLE` related to OR-amplification and AND-amplification, respectively? 
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java: ########## @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import 
org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import java.util.PriorityQueue; + +/** + * Base class for LSH model. + * + * <p>In addition to transforming input feature vectors to multiple hash values, it also supports + * approximate nearest neighbors search within a dataset regarding a key vector and approximate + * similarity join between two datasets. + * + * @param <T> class type of the LSHModel implementation itself. + */ +abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHModelParams<T> { + private static final String MODEL_DATA_BC_KEY = "modelData"; + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + /** Stores the corresponding model data class of T. */ + private final Class<? extends LSHScheme> modelDataClass; + + protected Table modelDataTable; + + public LSHModel(Class<? extends LSHScheme> modelDataClass) { + this.modelDataClass = modelDataClass; + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public T setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return (T) this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<? 
extends LSHScheme> modelData = + tEnv.toDataStream(modelDataTable, modelDataClass); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + TypeInformation<?> outputType = TypeInformation.of(DenseVector[].class); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(tEnv.toDataStream(inputs[0])), + Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), + inputList -> { + //noinspection unchecked + DataStream<Row> data = (DataStream<Row>) inputList.get(0); + return data.map( + new PredictOutputMapFunction(getInputCol()), outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** + * Approximately finds at most k items from a dataset which have the closest distance to a given + * item . If the `outputCol` is missing in the given dataset, this method transforms the dataset + * with the model at first. + * + * @param dataset The dataset in which to to search for nearest neighbors. + * @param key The item to search for. + * @param k The maximum number of nearest neighbors. + * @param distCol The output column storing the distance between each neighbor and the key. + * @return A dataset containing at most k items closest to the key with a column named `distCol` + * appended. + */ + public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment(); + Table transformedTable = + (dataset.getResolvedSchema().getColumnNames().contains(getOutputCol())) + ? dataset + : transform(dataset)[0]; + + DataStream<? 
extends LSHScheme> modelData = + tEnv.toDataStream(modelDataTable, modelDataClass); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol)); + + // Fetch items in the same bucket with key's, and calculate their distances to key. + DataStream<Row> filteredData = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(tEnv.toDataStream(transformedTable)), + Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), + inputList -> { + //noinspection unchecked + DataStream<Row> data = (DataStream<Row>) inputList.get(0); + return data.flatMap( + new FilterBySameBucketsFlatMapFunction( + getInputCol(), getOutputCol(), key), + outputTypeInfo); + }); + DataStream<Row> partitionedTopKData = Review Comment: Could you please checkout `DataStreamUtils.aggregate()`? `DataStreamUtils.mapPartition()` may not be that efficient since it always caches the data. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java: ########## @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.util.Preconditions; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.Random; + +/** + * Model data of {@link MinHashLSHModel}. + * + * <p>This class also provides classes to save/load model data. 
+ */ +public class MinHashLSHModelData implements LSHScheme { + + // A large prime smaller than sqrt(2^63 − 1) + private static final int HASH_PRIME = 2038074743; + + public int numHashTables; + public int numHashFunctionsPerTable; + public int[] randCoeffA; + public int[] randCoeffB; + + public MinHashLSHModelData() {} + + public MinHashLSHModelData( + int numHashTables, int numHashFunctionsPerTable, int[] randCoeffA, int[] randCoeffB) { + this.numHashTables = numHashTables; + this.numHashFunctionsPerTable = numHashFunctionsPerTable; + this.randCoeffA = randCoeffA; + this.randCoeffB = randCoeffB; + } + + public static MinHashLSHModelData generateModelData( + int numHashTables, int numHashFunctionsPerTable, int dim, long seed) { + Preconditions.checkArgument( + dim <= HASH_PRIME, + "The input vector dimension %d exceeds the threshold %s.", Review Comment: It seems that the dimension of the input vector should always be smaller than `HASH_PRIME`. Let's add this constraint in the java doc of LSH. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java: ########## @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Base class for estimators which implement LSH (Locality-sensitive hashing) algorithms. + * + * <p>The basic idea of LSH algorithms is to use to a family of hash functions to map data samples Review Comment: How about using the following Javadoc? Moreover, do we have a parameter for enabling `AND-amplification` or `OR-amplification`? ``` /** * Base class for estimators that support LSH (Locality-sensitive hashing) algorithms for * different metrics (e.g., Jaccard distance). * * <p>The basic idea of LSH is to use a set of hash functions to map input vectors * into different buckets, where closer vectors are expected to be in the same bucket with * higher probabilities. In detail, each input vector is hashed by all functions. * To decide whether two input vectors are mapped into the same bucket, two mechanisms * for assigning buckets are proposed as follows. * * <ul> * <li>AND-amplification: The two input vectors are defined to be in the same * bucket only if ALL of the hash values match. * <li>OR-amplification: The two input vectors are defined to be in the same * bucket if ANY of the hash values match. * </ul> * * <p>See: <a * href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">Locality-sensitive_hashing</a>. 
* * @param <E> class type of the Estimator implementation. * @param <M> class type of the Model this Estimator produces. */ ``` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java: ########## @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; +import java.util.PriorityQueue; + +/** + * Base class for LSH model. + * + * <p>In addition to transforming input feature vectors to multiple hash values, it also supports + * approximate nearest neighbors search within a dataset regarding a key vector and approximate + * similarity join between two datasets. + * + * @param <T> class type of the LSHModel implementation itself. 
+ */ +abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHModelParams<T> { + private static final String MODEL_DATA_BC_KEY = "modelData"; + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + /** Stores the corresponding model data class of T. */ + private final Class<? extends LSHScheme> modelDataClass; + + protected Table modelDataTable; + + public LSHModel(Class<? extends LSHScheme> modelDataClass) { + this.modelDataClass = modelDataClass; + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public T setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return (T) this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<? 
extends LSHScheme> modelData = + tEnv.toDataStream(modelDataTable, modelDataClass); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + TypeInformation<?> outputType = TypeInformation.of(DenseVector[].class); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(tEnv.toDataStream(inputs[0])), + Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), + inputList -> { + //noinspection unchecked + DataStream<Row> data = (DataStream<Row>) inputList.get(0); + return data.map( + new PredictOutputMapFunction(getInputCol()), outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** + * Approximately finds at most k items from a dataset which have the closest distance to a given Review Comment: How about rephrasing the Javadoc as follows? ``` /** * Approximately finds at most k nearest neighbors for the given input vector. If the `outputCol` is * missing in the given input table, this method transforms the table with the model first. * * @param inputTable The table in which to search for nearest neighbors. * @param inputVector The input vector to search for nearest neighbors. * @param k The maximum number of nearest neighbors. * @param distCol The output column storing the distance between each neighbor and the inputVector. * @return A table containing at most k closest neighbors to the inputVector with a column named `distCol` * appended. */ ``` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java: ########## @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.lsh; + +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; + +import java.io.IOException; + +/** + * An Estimator which implements the MinHash LSH algorithm. Review Comment: How about using the following Javadoc? ``` /** * An Estimator that implements the MinHash LSH algorithm, which supports LSH for Jaccard distance. * * <p>The input can be dense or sparse vectors. Each input vector must have at least one non-zero index * and all non-zero values are treated as binary "1" values. * * <p>See: <a href="https://en.wikipedia.org/wiki/MinHash">MinHash</a>. */ ``` ########## flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java: ########## @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.lsh.MinHashLSH; +import org.apache.flink.ml.feature.lsh.MinHashLSHModel; +import org.apache.flink.ml.feature.lsh.MinHashLSHModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +import static 
org.apache.flink.table.api.Expressions.$; + +/** Tests {@link MinHashLSH} and {@link MinHashLSHModel}. */ +public class MinHashLSHTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + /** + * Default case for most tests. + * + * @return a tuple including the estimator, input data table, and output rows. + */ + private Tuple3<MinHashLSH, Table, List<Row>> getDefaultCase() { + MinHashLSH lsh = + new MinHashLSH() + .setInputCol("vec") + .setOutputCol("hashes") + .setSeed(2022L) + .setNumHashTables(5) + .setNumHashFunctionsPerTable(3); + + List<Row> inputRows = + Arrays.asList( + Row.of( + 0, + Vectors.sparse(6, new int[] {0, 1, 2}, new double[] {1., 1., 1.})), + Row.of( + 1, + Vectors.sparse(6, new int[] {2, 3, 4}, new double[] {1., 1., 1.})), + Row.of( + 2, + Vectors.sparse(6, new int[] {0, 2, 4}, new double[] {1., 1., 1.}))); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(SparseVector.class)) + .build(); + DataStream<Row> dataStream = env.fromCollection(inputRows); + Table inputTable = tEnv.fromDataStream(dataStream, schema).as("id", "vec"); + + List<Row> outputRows = + convertToOutputFormat( + Arrays.asList( + new double[][] { + {1.73046954E8, 1.57275425E8, 6.90717571E8}, + {5.02301169E8, 7.967141E8, 4.06089319E8}, + {2.83652171E8, 1.97714719E8, 6.04731316E8}, + {5.2181506E8, 6.36933726E8, 6.13894128E8}, + {3.04301769E8, 1.113672955E9, 6.1388711E8} + }, + new double[][] { + {1.73046954E8, 1.57275425E8, 6.7798584E7}, + {6.38582806E8, 1.78703694E8, 4.06089319E8}, + {6.232638E8, 9.28867E7, 9.92010642E8}, + {2.461064E8, 1.12787481E8, 1.92180297E8}, + {2.38162496E8, 1.552933319E9, 2.77995137E8} + }, + new double[][] { + {1.73046954E8, 1.57275425E8, 6.90717571E8}, + {1.453197722E9, 7.967141E8, 4.06089319E8}, + {6.232638E8, 1.97714719E8, 6.04731316E8}, + {2.461064E8, 1.12787481E8, 
1.92180297E8}, + {1.224130231E9, 1.113672955E9, 2.77995137E8} + })); + + return Tuple3.of(lsh, inputTable, outputRows); + } + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.getConfig().enableObjectReuse(); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + } + + /** + * Convert a list of 2d double arrays to a list of rows with each of which containing a + * DenseVector array. + */ + private static List<Row> convertToOutputFormat(List<double[][]> arrays) { + return arrays.stream() + .map( + array -> { + DenseVector[] denseVectors = + Arrays.stream(array) + .map(Vectors::dense) + .toArray(DenseVector[]::new); + return Row.of((Object) denseVectors); + }) + .collect(Collectors.toList()); + } + + private static class DenseVectorComparator implements Comparator<DenseVector> { + @Override + public int compare(DenseVector o1, DenseVector o2) { + if (o1.size() != o2.size()) { + return Integer.compare(o1.size(), o2.size()); + } + for (int i = 0; i < o1.values.length; i += 1) { + int cmp = Double.compare(o1.values[i], o2.values[i]); + if (0 != cmp) { + return cmp; + } + } + return 0; + } + } + + private static class DenseVectorArrayComparator implements Comparator<DenseVector[]> { + @Override + public int compare(DenseVector[] o1, DenseVector[] o2) { + if (o1.length != o2.length) { + return o1.length - o2.length; + } + DenseVectorComparator denseVectorComparator = new DenseVectorComparator(); + for (int i = 0; i < o1.length; i += 1) { + int cmp = denseVectorComparator.compare(o1[i], o2[i]); + if (0 != cmp) { + return cmp; + } + } + return 0; + } + } + + private static void verifyPredictionResult(Table output, List<Row> expected) throws Exception { + StreamTableEnvironment 
tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + List<Row> results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + compareResultCollections( + expected, + results, + (d0, d1) -> { + DenseVectorArrayComparator denseVectorArrayComparator = + new DenseVectorArrayComparator(); + return denseVectorArrayComparator.compare(d0.getFieldAs(0), d1.getFieldAs(0)); + }); + } + + @Test + public void testHashFunction() { + MinHashLSHModelData lshModelData = + new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0}); + Vector vec = Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.}); + DenseVector[] result = lshModelData.hashFunction(vec); + Assert.assertEquals(3, result.length); + Assert.assertEquals(Vectors.dense(1.), result[0]); + Assert.assertEquals(Vectors.dense(5.), result[1]); + Assert.assertEquals(Vectors.dense(9.), result[2]); + } + + @Test + public void testHashFunctionEqualWithSparseDenseVector() { + // Use randomly generate coefficients, so that the hash values are not always from the least + // non-zero index. 
+ MinHashLSHModelData lshModelData = MinHashLSHModelData.generateModelData(3, 1, 10, 2022L); + new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0}); + Vector vec = Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.}); + DenseVector[] denseResults = lshModelData.hashFunction(vec.toDense()); + DenseVector[] sparseResults = lshModelData.hashFunction(vec.toSparse()); + Assert.assertArrayEquals(denseResults, sparseResults); + } + + @Test(expected = IllegalArgumentException.class) + public void testHashFunctionWithEmptyVector() { + MinHashLSHModelData lshModelData = + new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0}); + Vector vec = Vectors.sparse(10, new int[] {}, new double[] {}); + lshModelData.hashFunction(vec); + } + + @Test + public void testParam() { + MinHashLSH lsh = new MinHashLSH(); + Assert.assertEquals("input", lsh.getInputCol()); + Assert.assertEquals("output", lsh.getOutputCol()); + Assert.assertEquals(MinHashLSH.class.getName().hashCode(), lsh.getSeed()); + Assert.assertEquals(1, (int) lsh.getNumHashTables()); + Assert.assertEquals(1, (int) lsh.getNumHashFunctionsPerTable()); + lsh.setInputCol("vec") + .setOutputCol("hashes") + .setSeed(2022L) + .setNumHashTables(3) + .setNumHashFunctionsPerTable(4); + Assert.assertEquals("vec", lsh.getInputCol()); + Assert.assertEquals("hashes", lsh.getOutputCol()); + Assert.assertEquals(2022L, lsh.getSeed()); + Assert.assertEquals(3, (int) lsh.getNumHashTables()); Review Comment: We can avoid this cast by letting `getNumHashTables()` return an `int`. Same for `getNumHashFunctionsPerTable()`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org