vacaly commented on code in PR #192:
URL: https://github.com/apache/flink-ml/pull/192#discussion_r1053881373


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/Swing.java:
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.runtime.typeutils.TypeCheckUtils;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/**
+ * An Estimator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of user-item graph usually 
can be described as user-item-user or
+ * item-user-item, which are like 'swing'. For example, if both user 
<em>u</em> and user <em>v</em> have purchased the
+ * same commodity <em>i</em> , they will form a relationship diagram similar 
to a swing. If <em>u</em> and <em>v</em>
+ * have purchased commodity <em>j</em> in addition to <em>i</em>, it is 
supposed <em>i</em> and <em>j</em> are
+ * similar.</p>
+ */
+public class Swing implements Estimator <Swing, SwingModel>, SwingParams 
<Swing> {
+       private final Map <Param <?>, Object> paramMap = new HashMap <>();
+
+       public Swing() {
+               ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+       }
+
+       @Override
+       public SwingModel fit(Table... inputs) {
+
+               final String userCol = getUserCol();
+               final String itemCol = getItemCol();
+               final int minUserItems = getMinUserItems();
+               final int maxUserItems = getMaxUserItems();
+               final ResolvedSchema schema = ((TableImpl) 
inputs[0]).getResolvedSchema();
+               final LogicalType userColType = 
schema.getColumn(userCol).get().getDataType().getLogicalType();
+               final LogicalType itemColType = 
schema.getColumn(itemCol).get().getDataType().getLogicalType();
+
+               if 
(!TypeCheckUtils.isCharacterString(InternalTypeInfo.of(userColType).toLogicalType())
 ||
+                       
!TypeCheckUtils.isCharacterString(InternalTypeInfo.of(itemColType).toLogicalType()))
 {
+                       throw new IllegalArgumentException("Type of user and 
item column must be string.");
+               }
+
+               StreamTableEnvironment tEnv =
+                       (StreamTableEnvironment) ((TableImpl) inputs[0])
+                               .getTableEnvironment();
+
+               SingleOutputStreamOperator <Tuple2 <String, String>> itemUsers =
+                       tEnv.toDataStream(inputs[0])
+                               .map(row -> Tuple2.of((String) 
row.getFieldAs(userCol), (String) row.getFieldAs(itemCol)))
+                               .returns(Types.TUPLE(Types.STRING, 
Types.STRING));
+
+               SingleOutputStreamOperator <Tuple3 <String, String, List 
<String>>> userAllItemsStream = itemUsers
+                       .keyBy(tuple -> tuple.f0)
+                       .transform("fillUserItemsTable",
+                               Types.TUPLE(Types.STRING, Types.STRING, 
Types.LIST(Types.STRING)),
+                               new BuildSwingData(minUserItems, maxUserItems));
+
+               SingleOutputStreamOperator <SwingModelData> similarity = 
userAllItemsStream
+                       .keyBy(tuple -> tuple.f1)
+                       .transform("calculateSimilarity",
+                               Types.ROW(Types.STRING, 
Types.LIST(Types.STRING), Types.LIST(Types.FLOAT)),
+                               new CalculateSimilarity(getTopN()))
+                       .map(new MapFunction <Row, SwingModelData>() {
+                               @Override
+                               public SwingModelData map(Row value) throws 
Exception {
+                                       return new 
SwingModelData(value.getFieldAs(0), value.getFieldAs(1), value.getFieldAs(2));
+                               }
+                       });
+
+               SwingModel model = new 
SwingModel().setModelData(tEnv.fromDataStream(similarity));
+               ReadWriteUtils.updateExistingParams(model, getParamMap());
+               return model;
+       }
+
+       @Override
+       public Map <Param <?>, Object> getParamMap() {
+               return paramMap;
+       }
+
+       @Override
+       public void save(String path) throws IOException {
+               ReadWriteUtils.saveMetadata(this, path);
+       }
+
+       public static Swing load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+               return ReadWriteUtils.loadStageParam(path);
+       }
+
+       /**
+        * Append one column that records all items the user has purchased to 
the input table.
+        *
+        * <p>During the process, this operator collect users and all items a 
user has purchased into a map of list.
+        * When the input is finished, this operator append the certain 
user-purchased-items list to each row. </p>
+        */
+       private static class BuildSwingData
+               extends AbstractStreamOperator <Tuple3 <String, String, List 
<String>>>
+               implements OneInputStreamOperator <
+               Tuple2 <String, String>,
+               Tuple3 <String, String, List <String>>>,
+               BoundedOneInput {
+               final int minUserItems;
+               final int maxUserItems;
+
+               private Map <String, List <String>> userItemsMap = new HashMap 
<>();
+
+               private ListState <Map <String, List <String>>> 
userItemsMapState;
+
+               private BuildSwingData(int minUserItems, int maxUserItems) {
+                       this.minUserItems = minUserItems;
+                       this.maxUserItems = maxUserItems;
+               }
+
+               @Override
+               public void endInput() {
+
+                       for (Entry <String, List <String>> entry :
+                               userItemsMap.entrySet()) {
+                               List <String> items = entry.getValue();
+                               String user = entry.getKey();
+                               if (items.size() < minUserItems || items.size() 
> maxUserItems) {
+                                       continue;
+                               }
+                               for (String item : items) {
+                                       output.collect(
+                                               new StreamRecord <>(
+                                                       new Tuple3 <>(
+                                                               user, item, 
items)));
+                               }
+                       }
+
+                       userItemsMapState.clear();
+               }
+
+               @Override
+               public void processElement(StreamRecord <Tuple2 <String, 
String>> element) {
+                       Tuple2 <String, String> userAndItem =
+                               element.getValue();
+                       String user = userAndItem.f0;
+                       String item = userAndItem.f1;
+                       List <String> items = userItemsMap.get(user);
+
+                       if (items == null) {
+                               ArrayList <String> value = new ArrayList <>();
+                               value.add(item);
+                               userItemsMap.put(user, value);
+                       } else {
+                               if (!items.contains(item)) {
+                                       items.add(item);
+                               }
+                       }
+               }
+
+               @Override
+               public void initializeState(StateInitializationContext context) 
throws Exception {
+                       super.initializeState(context);
+                       userItemsMapState =
+                               context.getOperatorStateStore()
+                                       .getListState(
+                                               new ListStateDescriptor <>(
+                                                       "userItemsMapState",
+                                                       Types.MAP(
+                                                               Types.STRING,
+                                                               Types.LIST(
+                                                                       
Types.STRING)))
+                                       );
+
+                       OperatorStateUtils.getUniqueElement(userItemsMapState, 
"userItemsMapState")
+                               .ifPresent(x -> userItemsMap = x);
+
+               }
+
+               @Override
+               public void snapshotState(StateSnapshotContext context) throws 
Exception {
+                       super.snapshotState(context);
+                       
userItemsMapState.update(Collections.singletonList(userItemsMap));
+               }
+       }
+
+       /**
+        * Calculate top N similar items of each item.
+        */
+       private static class CalculateSimilarity
+               extends AbstractStreamOperator <Row>
+               implements OneInputStreamOperator <
+               Tuple3 <String, String, List <String>>,
+               Row>,
+               BoundedOneInput {
+
+               private Map <String, HashSet <String>> userItemsMap = new 
HashMap <>();
+               private Map <String, HashSet <String>> itemUsersMap = new 
HashMap <>();
+               private ListState <Map <String, List <String>>> 
userItemsMapState;
+               private ListState <Map <String, List <String>>> 
itemUsersMapState;
+
+               final private double alpha = 1.0;
+               final private double userAlpha = 5.0;
+               final private double userBeta = -0.35;
+               final private int topN;
+
+               private CalculateSimilarity(int topN) {this.topN = topN;}
+
+               @Override
+               public void endInput() throws Exception {
+
+                       Map <String, Float> userWeights = new HashMap 
<>(userItemsMap.size());
+                       userItemsMap.forEach((k, v) -> {
+                               int count = v.size();
+                               userWeights.put(k, calculateWeight(count));
+                       });
+
+                       for (String mainItem : itemUsersMap.keySet()) {
+                               List <String> userList = new 
ArrayList(itemUsersMap.get(mainItem));
+                               HashMap <String, Float> id2swing = new HashMap 
<>();
+
+                               for (int i = 0; i < userList.size(); i++) {
+                                       String u = userList.get(i);
+                                       for (int j = i + 1; j < 
userList.size(); j++) {
+                                               String v = userList.get(j);
+                                               HashSet <String> interaction = 
(HashSet <String>) userItemsMap.get(u).clone();

Review Comment:
   You're right. I agree that we should memoize and reuse them. But I think
pre-computing `similarity` for every user pair is infeasible, especially
when the matrix is huge and sparse. Should I use a fixed-length array or a sparse
matrix to store the high-frequency pairs? Would building and searching this
structure cost too much time?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/Swing.java:
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.runtime.typeutils.TypeCheckUtils;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/**
+ * An Estimator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of user-item graph usually 
can be described as user-item-user or
+ * item-user-item, which are like 'swing'. For example, if both user 
<em>u</em> and user <em>v</em> have purchased the
+ * same commodity <em>i</em> , they will form a relationship diagram similar 
to a swing. If <em>u</em> and <em>v</em>
+ * have purchased commodity <em>j</em> in addition to <em>i</em>, it is 
supposed <em>i</em> and <em>j</em> are
+ * similar.</p>
+ */
+public class Swing implements Estimator <Swing, SwingModel>, SwingParams 
<Swing> {
+       private final Map <Param <?>, Object> paramMap = new HashMap <>();
+
+       public Swing() {
+               ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+       }
+
+       @Override
+       public SwingModel fit(Table... inputs) {
+
+               final String userCol = getUserCol();
+               final String itemCol = getItemCol();
+               final int minUserItems = getMinUserItems();
+               final int maxUserItems = getMaxUserItems();
+               final ResolvedSchema schema = ((TableImpl) 
inputs[0]).getResolvedSchema();
+               final LogicalType userColType = 
schema.getColumn(userCol).get().getDataType().getLogicalType();
+               final LogicalType itemColType = 
schema.getColumn(itemCol).get().getDataType().getLogicalType();
+
+               if 
(!TypeCheckUtils.isCharacterString(InternalTypeInfo.of(userColType).toLogicalType())
 ||
+                       
!TypeCheckUtils.isCharacterString(InternalTypeInfo.of(itemColType).toLogicalType()))
 {
+                       throw new IllegalArgumentException("Type of user and 
item column must be string.");
+               }
+
+               StreamTableEnvironment tEnv =
+                       (StreamTableEnvironment) ((TableImpl) inputs[0])
+                               .getTableEnvironment();
+
+               SingleOutputStreamOperator <Tuple2 <String, String>> itemUsers =
+                       tEnv.toDataStream(inputs[0])
+                               .map(row -> Tuple2.of((String) 
row.getFieldAs(userCol), (String) row.getFieldAs(itemCol)))
+                               .returns(Types.TUPLE(Types.STRING, 
Types.STRING));
+
+               SingleOutputStreamOperator <Tuple3 <String, String, List 
<String>>> userAllItemsStream = itemUsers
+                       .keyBy(tuple -> tuple.f0)
+                       .transform("fillUserItemsTable",
+                               Types.TUPLE(Types.STRING, Types.STRING, 
Types.LIST(Types.STRING)),
+                               new BuildSwingData(minUserItems, maxUserItems));
+
+               SingleOutputStreamOperator <SwingModelData> similarity = 
userAllItemsStream
+                       .keyBy(tuple -> tuple.f1)
+                       .transform("calculateSimilarity",
+                               Types.ROW(Types.STRING, 
Types.LIST(Types.STRING), Types.LIST(Types.FLOAT)),
+                               new CalculateSimilarity(getTopN()))
+                       .map(new MapFunction <Row, SwingModelData>() {
+                               @Override
+                               public SwingModelData map(Row value) throws 
Exception {
+                                       return new 
SwingModelData(value.getFieldAs(0), value.getFieldAs(1), value.getFieldAs(2));
+                               }
+                       });
+
+               SwingModel model = new 
SwingModel().setModelData(tEnv.fromDataStream(similarity));
+               ReadWriteUtils.updateExistingParams(model, getParamMap());
+               return model;
+       }
+
+       @Override
+       public Map <Param <?>, Object> getParamMap() {
+               return paramMap;
+       }
+
+       @Override
+       public void save(String path) throws IOException {
+               ReadWriteUtils.saveMetadata(this, path);
+       }
+
+       public static Swing load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+               return ReadWriteUtils.loadStageParam(path);
+       }
+
+       /**
+        * Append one column that records all items the user has purchased to 
the input table.
+        *
+        * <p>During the process, this operator collect users and all items a 
user has purchased into a map of list.
+        * When the input is finished, this operator append the certain 
user-purchased-items list to each row. </p>
+        */
+       private static class BuildSwingData
+               extends AbstractStreamOperator <Tuple3 <String, String, List 
<String>>>
+               implements OneInputStreamOperator <
+               Tuple2 <String, String>,
+               Tuple3 <String, String, List <String>>>,
+               BoundedOneInput {
+               final int minUserItems;
+               final int maxUserItems;
+
+               private Map <String, List <String>> userItemsMap = new HashMap 
<>();
+
+               private ListState <Map <String, List <String>>> 
userItemsMapState;
+
+               private BuildSwingData(int minUserItems, int maxUserItems) {
+                       this.minUserItems = minUserItems;
+                       this.maxUserItems = maxUserItems;
+               }
+
+               @Override
+               public void endInput() {
+
+                       for (Entry <String, List <String>> entry :
+                               userItemsMap.entrySet()) {
+                               List <String> items = entry.getValue();
+                               String user = entry.getKey();
+                               if (items.size() < minUserItems || items.size() 
> maxUserItems) {
+                                       continue;
+                               }
+                               for (String item : items) {
+                                       output.collect(
+                                               new StreamRecord <>(
+                                                       new Tuple3 <>(
+                                                               user, item, 
items)));
+                               }
+                       }
+
+                       userItemsMapState.clear();
+               }
+
+               @Override
+               public void processElement(StreamRecord <Tuple2 <String, 
String>> element) {
+                       Tuple2 <String, String> userAndItem =
+                               element.getValue();
+                       String user = userAndItem.f0;
+                       String item = userAndItem.f1;
+                       List <String> items = userItemsMap.get(user);
+
+                       if (items == null) {
+                               ArrayList <String> value = new ArrayList <>();
+                               value.add(item);
+                               userItemsMap.put(user, value);
+                       } else {
+                               if (!items.contains(item)) {
+                                       items.add(item);
+                               }
+                       }
+               }
+
+               @Override
+               public void initializeState(StateInitializationContext context) 
throws Exception {
+                       super.initializeState(context);
+                       userItemsMapState =
+                               context.getOperatorStateStore()
+                                       .getListState(
+                                               new ListStateDescriptor <>(
+                                                       "userItemsMapState",
+                                                       Types.MAP(
+                                                               Types.STRING,
+                                                               Types.LIST(
+                                                                       
Types.STRING)))
+                                       );
+
+                       OperatorStateUtils.getUniqueElement(userItemsMapState, 
"userItemsMapState")
+                               .ifPresent(x -> userItemsMap = x);
+
+               }
+
+               @Override
+               public void snapshotState(StateSnapshotContext context) throws 
Exception {
+                       super.snapshotState(context);
+                       
userItemsMapState.update(Collections.singletonList(userItemsMap));
+               }
+       }
+
+       /**
+        * Calculate top N similar items of each item.
+        */
+       private static class CalculateSimilarity
+               extends AbstractStreamOperator <Row>
+               implements OneInputStreamOperator <
+               Tuple3 <String, String, List <String>>,
+               Row>,
+               BoundedOneInput {
+
+               private Map <String, HashSet <String>> userItemsMap = new 
HashMap <>();
+               private Map <String, HashSet <String>> itemUsersMap = new 
HashMap <>();
+               private ListState <Map <String, List <String>>> 
userItemsMapState;
+               private ListState <Map <String, List <String>>> 
itemUsersMapState;
+
+               final private double alpha = 1.0;
+               final private double userAlpha = 5.0;
+               final private double userBeta = -0.35;
+               final private int topN;
+
+               private CalculateSimilarity(int topN) {this.topN = topN;}
+
+               @Override
+               public void endInput() throws Exception {
+
+                       Map <String, Float> userWeights = new HashMap 
<>(userItemsMap.size());
+                       userItemsMap.forEach((k, v) -> {
+                               int count = v.size();
+                               userWeights.put(k, calculateWeight(count));
+                       });
+
+                       for (String mainItem : itemUsersMap.keySet()) {
+                               List <String> userList = new 
ArrayList(itemUsersMap.get(mainItem));
+                               HashMap <String, Float> id2swing = new HashMap 
<>();
+
+                               for (int i = 0; i < userList.size(); i++) {
+                                       String u = userList.get(i);
+                                       for (int j = i + 1; j < 
userList.size(); j++) {
+                                               String v = userList.get(j);
+                                               HashSet <String> interaction = 
(HashSet <String>) userItemsMap.get(u).clone();

Review Comment:
   Thanks for your suggestion. I checked the code and I am confident that a shallow
copy works as intended in this calculation. The copy is only used to count the
number of items shared by two `HashSet` instances, and it has its own internal
`map` that points to a different `HashMap` instance from the original set's. The
items stored in that `map` (the `String` elements) are never modified, because the
method only counts the elements of `interaction`, so I think `clone` is safe in
this context.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to