Fanoid commented on code in PR #192:
URL: https://github.com/apache/flink-ml/pull/192#discussion_r1053957894


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/Swing.java:
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.runtime.typeutils.TypeCheckUtils;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/**
+ * An Estimator which implements the Swing algorithm.
+ *
+ * <p>Swing is an item recall model. The topology of the user-item graph can 
usually be described as user-item-user or
+ * item-user-item, which resembles a 'swing'. For example, if both user 
<em>u</em> and user <em>v</em> have purchased the
+ * same commodity <em>i</em>, they will form a relationship diagram similar 
to a swing. If <em>u</em> and <em>v</em>
+ * have purchased commodity <em>j</em> in addition to <em>i</em>, it is 
supposed that <em>i</em> and <em>j</em> are
+ * similar.</p>
+ */
+public class Swing implements Estimator <Swing, SwingModel>, SwingParams 
<Swing> {
+       private final Map <Param <?>, Object> paramMap = new HashMap <>();
+
+       public Swing() {
+               ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+       }
+
+       @Override
+       public SwingModel fit(Table... inputs) {
+
+               final String userCol = getUserCol();
+               final String itemCol = getItemCol();
+               final int minUserItems = getMinUserItems();
+               final int maxUserItems = getMaxUserItems();
+               final ResolvedSchema schema = ((TableImpl) 
inputs[0]).getResolvedSchema();
+               final LogicalType userColType = 
schema.getColumn(userCol).get().getDataType().getLogicalType();
+               final LogicalType itemColType = 
schema.getColumn(itemCol).get().getDataType().getLogicalType();
+
+               if 
(!TypeCheckUtils.isCharacterString(InternalTypeInfo.of(userColType).toLogicalType())
 ||
+                       
!TypeCheckUtils.isCharacterString(InternalTypeInfo.of(itemColType).toLogicalType()))
 {
+                       throw new IllegalArgumentException("Type of user and 
item column must be string.");
+               }
+
+               StreamTableEnvironment tEnv =
+                       (StreamTableEnvironment) ((TableImpl) inputs[0])
+                               .getTableEnvironment();
+
+               SingleOutputStreamOperator <Tuple2 <String, String>> itemUsers =
+                       tEnv.toDataStream(inputs[0])
+                               .map(row -> Tuple2.of((String) 
row.getFieldAs(userCol), (String) row.getFieldAs(itemCol)))
+                               .returns(Types.TUPLE(Types.STRING, 
Types.STRING));
+
+               SingleOutputStreamOperator <Tuple3 <String, String, List 
<String>>> userAllItemsStream = itemUsers
+                       .keyBy(tuple -> tuple.f0)
+                       .transform("fillUserItemsTable",
+                               Types.TUPLE(Types.STRING, Types.STRING, 
Types.LIST(Types.STRING)),
+                               new BuildSwingData(minUserItems, maxUserItems));
+
+               SingleOutputStreamOperator <SwingModelData> similarity = 
userAllItemsStream
+                       .keyBy(tuple -> tuple.f1)
+                       .transform("calculateSimilarity",
+                               Types.ROW(Types.STRING, 
Types.LIST(Types.STRING), Types.LIST(Types.FLOAT)),
+                               new CalculateSimilarity(getTopN()))
+                       .map(new MapFunction <Row, SwingModelData>() {
+                               @Override
+                               public SwingModelData map(Row value) throws 
Exception {
+                                       return new 
SwingModelData(value.getFieldAs(0), value.getFieldAs(1), value.getFieldAs(2));
+                               }
+                       });
+
+               SwingModel model = new 
SwingModel().setModelData(tEnv.fromDataStream(similarity));
+               ReadWriteUtils.updateExistingParams(model, getParamMap());
+               return model;
+       }
+
+       @Override
+       public Map <Param <?>, Object> getParamMap() {
+               return paramMap;
+       }
+
+       @Override
+       public void save(String path) throws IOException {
+               ReadWriteUtils.saveMetadata(this, path);
+       }
+
+       public static Swing load(StreamTableEnvironment tEnv, String path) 
throws IOException {
+               return ReadWriteUtils.loadStageParam(path);
+       }
+
+       /**
+        * Append one column that records all items the user has purchased to 
the input table.
+        *
+        * <p>During the process, this operator collects users and all items a 
user has purchased into a map of lists.
+        * When the input is finished, this operator appends the corresponding 
user-purchased-items list to each row. </p>
+        */
+       private static class BuildSwingData
+               extends AbstractStreamOperator <Tuple3 <String, String, List 
<String>>>
+               implements OneInputStreamOperator <
+               Tuple2 <String, String>,
+               Tuple3 <String, String, List <String>>>,
+               BoundedOneInput {
+               final int minUserItems;
+               final int maxUserItems;
+
+               private Map <String, List <String>> userItemsMap = new HashMap 
<>();
+
+               private ListState <Map <String, List <String>>> 
userItemsMapState;
+
+               private BuildSwingData(int minUserItems, int maxUserItems) {
+                       this.minUserItems = minUserItems;
+                       this.maxUserItems = maxUserItems;
+               }
+
+               @Override
+               public void endInput() {
+
+                       for (Entry <String, List <String>> entry :
+                               userItemsMap.entrySet()) {
+                               List <String> items = entry.getValue();
+                               String user = entry.getKey();
+                               if (items.size() < minUserItems || items.size() 
> maxUserItems) {
+                                       continue;
+                               }
+                               for (String item : items) {
+                                       output.collect(
+                                               new StreamRecord <>(
+                                                       new Tuple3 <>(
+                                                               user, item, 
items)));
+                               }
+                       }
+
+                       userItemsMapState.clear();
+               }
+
+               @Override
+               public void processElement(StreamRecord <Tuple2 <String, 
String>> element) {
+                       Tuple2 <String, String> userAndItem =
+                               element.getValue();
+                       String user = userAndItem.f0;
+                       String item = userAndItem.f1;
+                       List <String> items = userItemsMap.get(user);
+
+                       if (items == null) {
+                               ArrayList <String> value = new ArrayList <>();
+                               value.add(item);
+                               userItemsMap.put(user, value);
+                       } else {
+                               if (!items.contains(item)) {
+                                       items.add(item);
+                               }
+                       }
+               }
+
+               @Override
+               public void initializeState(StateInitializationContext context) 
throws Exception {
+                       super.initializeState(context);
+                       userItemsMapState =
+                               context.getOperatorStateStore()
+                                       .getListState(
+                                               new ListStateDescriptor <>(
+                                                       "userItemsMapState",
+                                                       Types.MAP(
+                                                               Types.STRING,
+                                                               Types.LIST(
+                                                                       
Types.STRING)))
+                                       );
+
+                       OperatorStateUtils.getUniqueElement(userItemsMapState, 
"userItemsMapState")
+                               .ifPresent(x -> userItemsMap = x);
+
+               }
+
+               @Override
+               public void snapshotState(StateSnapshotContext context) throws 
Exception {
+                       super.snapshotState(context);
+                       
userItemsMapState.update(Collections.singletonList(userItemsMap));
+               }
+       }
+
+       /**
+        * Calculate top N similar items of each item.
+        */
+       private static class CalculateSimilarity
+               extends AbstractStreamOperator <Row>
+               implements OneInputStreamOperator <
+               Tuple3 <String, String, List <String>>,
+               Row>,
+               BoundedOneInput {
+
+               private Map <String, HashSet <String>> userItemsMap = new 
HashMap <>();
+               private Map <String, HashSet <String>> itemUsersMap = new 
HashMap <>();
+               private ListState <Map <String, List <String>>> 
userItemsMapState;
+               private ListState <Map <String, List <String>>> 
itemUsersMapState;
+
+               final private double alpha = 1.0;
+               final private double userAlpha = 5.0;
+               final private double userBeta = -0.35;
+               final private int topN;
+
+               private CalculateSimilarity(int topN) {this.topN = topN;}
+
+               @Override
+               public void endInput() throws Exception {
+
+                       Map <String, Float> userWeights = new HashMap 
<>(userItemsMap.size());
+                       userItemsMap.forEach((k, v) -> {
+                               int count = v.size();
+                               userWeights.put(k, calculateWeight(count));
+                       });
+
+                       for (String mainItem : itemUsersMap.keySet()) {
+                               List <String> userList = new 
ArrayList(itemUsersMap.get(mainItem));
+                               HashMap <String, Float> id2swing = new HashMap 
<>();
+
+                               for (int i = 0; i < userList.size(); i++) {
+                                       String u = userList.get(i);
+                                       for (int j = i + 1; j < 
userList.size(); j++) {
+                                               String v = userList.get(j);
+                                               HashSet <String> interaction = 
(HashSet <String>) userItemsMap.get(u).clone();
+                                               
interaction.retainAll(userItemsMap.get(v));
+                                               if (interaction.size() == 0) {
+                                                       continue;
+                                               }
+                                               float similarity =
+                                                       (float) 
(userWeights.get(u) * userWeights.get(v) / (alpha + interaction.size()));
+                                               for (String simItem : 
interaction) {
+                                                       if 
(simItem.equals(mainItem)) {
+                                                               continue;
+                                                       }
+                                                       float itemSimilarity = 
id2swing.getOrDefault(simItem, (float) 0) + similarity;
+                                                       id2swing.put(simItem, 
itemSimilarity);
+                                               }
+                                       }
+                               }
+
+                               ArrayList <Tuple2 <String, Float>> itemAndScore 
= new ArrayList <>();
+                               id2swing.forEach(
+                                       (key, value) -> 
itemAndScore.add(Tuple2.of(key, value))
+                               );
+
+                               itemAndScore.sort(new Comparator <Tuple2 
<String, Float>>() {
+                                       @Override
+                                       public int compare(Tuple2 <String, 
Float> o1, Tuple2 <String, Float> o2) {
+                                               return 0 - Float.compare(o1.f1, 
o2.f1);
+                                       }
+                               });
+
+                               if (itemAndScore.size() == 0) {
+                                       continue;
+                               }
+
+                               int itemNums = topN > itemAndScore.size() ? 
itemAndScore.size() : topN;
+                               String[] itemIds = new String[itemNums];
+                               Float[] itemScores = new Float[itemNums];
+                               for (int i = 0; i < itemNums; i++) {
+                                       itemIds[i] = itemAndScore.get(i).f0;
+                                       itemScores[i] = itemAndScore.get(i).f1;
+                               }
+
+                               output.collect(
+                                       new StreamRecord <>(
+                                               Row.of(
+                                                       mainItem, new ArrayList 
<String>(Arrays.asList(itemIds)),

Review Comment:
   I just pulled the code, and tried to modify and run it. 
   
   Actually, you ran into a `ClassCastException` here because `items` and `scores` 
in `SwingModelData` are declared as `ArrayList`. After declaring them as 
`ArrayList`, an additional cast is used in `ModelDataDecoder#read`.
   
   At the very least, the cast in `ModelDataDecoder#read` is inappropriate, as you 
generally cannot assume more information than the API provides.
   Elsewhere, using interfaces usually gives more flexibility than using 
concrete classes.
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to