yunfengzhou-hub commented on code in PR #171:
URL: https://github.com/apache/flink-ml/pull/171#discussion_r1018560327


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -75,6 +76,9 @@
  * performs the hierarchical clustering on each mini-batch independently. The 
clustering result of
  * each element depends only on the elements in the same mini-batch.
  *
+ * <p>The implementation is based on the nearest-neighbor-chain method 
proposed in "Modern
+ * hierarchical, agglomerative clustering algorithms", by Daniel Mullner.

Review Comment:
   It might be better not to discuss implementation details in the public 
class-level JavaDoc. We could move this description into a comment on the 
internal method or class that implements the algorithm.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java:
##########
@@ -205,19 +206,19 @@ public void testTransform() throws Exception {
         Table[] outputs;
         AgglomerativeClustering agglomerativeClustering =
                 new AgglomerativeClustering()
-                        
.setLinkage(AgglomerativeClusteringParams.LINKAGE_AVERAGE)
+                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)

Review Comment:
   Could we add a new test case verifying that agglomerative clustering now 
works correctly with Euclidean distance and Ward linkage, instead of changing 
the configuration of several existing test cases?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -249,175 +264,223 @@ public void process(
             for (int i = 0; i < numDataPoints; i++) {
                 output.collect(Row.join(inputList.get(i), 
Row.of(clusterIds[i])));
             }
+
+            // Outputs the merge info.
+            if (computeFullTree) {
+                stoppedIdx = nnChain.size();
+            }
+            for (int i = 0; i < stoppedIdx; i++) {
+                Tuple4<Integer, Integer, Integer, Double> mergeItem = 
nnChain.get(i);
+                int cid1 = Math.min(mergeItem.f0, mergeItem.f1);
+                int cid2 = Math.max(mergeItem.f0, mergeItem.f1);
+                context.output(
+                        mergeInfoOutputTag,
+                        Tuple4.of(
+                                cid1,
+                                cid2,
+                                mergeItem.f3,
+                                nnChainAndSize.f1[cid1] + 
nnChainAndSize.f1[cid2]));
+            }
         }
 
-        private int getNextClusterId() {
-            return nextClusterId++;
+        /** Reorders the nearest-neighbor-chain. */
+        private void reOrderNnChain(List<Tuple4<Integer, Integer, Integer, 
Double>> nnChain) {
+            int nextClusterId = nnChain.size() + 1;
+            HashMap<Integer, Integer> nodeMapping = new HashMap<>();
+            for (Tuple4<Integer, Integer, Integer, Double> t : nnChain) {
+                if (nodeMapping.containsKey(t.f0)) {
+                    t.f0 = nodeMapping.get(t.f0);
+                }
+                if (nodeMapping.containsKey(t.f1)) {
+                    t.f1 = nodeMapping.get(t.f1);
+                }
+                nodeMapping.put(t.f2, nextClusterId);
+                nextClusterId++;
+            }
         }
 
-        private void doClustering(
-                List<Cluster> activeClusters,
-                ProcessAllWindowFunction<Row, Row, ?>.Context context) {
-            boolean clusteringRunning =
-                    (numCluster != null && activeClusters.size() > numCluster)
-                            || (distanceThreshold != null);
-
-            while (clusteringRunning || (computeFullTree && 
activeClusters.size() > 1)) {
-                int clusterOffset1 = -1, clusterOffset2 = -1;
-                // Computes the distance between two clusters.
-                double minDistance = Double.MAX_VALUE;
-                for (int i = 0; i < activeClusters.size(); i++) {
-                    for (int j = i + 1; j < activeClusters.size(); j++) {
-                        double distance =
-                                computeDistanceBetweenClusters(
-                                        activeClusters.get(i), 
activeClusters.get(j));
-                        if (distance < minDistance) {
-                            minDistance = distance;
-                            clusterOffset1 = i;
-                            clusterOffset2 = j;
+        /** Converts the cluster Ids for each input data point. */
+        private int[] label(
+                List<Tuple4<Integer, Integer, Integer, Double>> nnChains, int 
numDataPoints) {
+            UnionFind unionFind = new UnionFind(numDataPoints);
+            for (Tuple4<Integer, Integer, Integer, Double> t : nnChains) {
+                unionFind.union(unionFind.find(t.f0), unionFind.find(t.f1));
+            }
+            int[] clusterIds = new int[numDataPoints];
+            for (int i = 0; i < clusterIds.length; i++) {
+                clusterIds[i] = unionFind.find(i);
+            }
+            return clusterIds;
+        }
+
+        /** The main logic of nearest-neighbor-chain algorithm. */
+        private Tuple2<List<Tuple4<Integer, Integer, Integer, Double>>, int[]> 
nnChainCore(
+                HashSet<Integer> nodeLabels, DistanceMatrix distanceMatrix, 
String linkage) {
+            int numDataPoints = nodeLabels.size();
+            int nextClusterId = numDataPoints;
+            List<Tuple4<Integer, Integer, Integer, Double>> nnChain =
+                    new ArrayList<>(numDataPoints);
+            List<Integer> chain = new ArrayList<>();
+            int[] size = new int[numDataPoints * 2 - 1];
+            for (int i = 0; i < numDataPoints; i++) {
+                size[i] = 1;
+            }
+
+            int a, b;
+            while (nodeLabels.size() > 1) {
+                if (chain.size() <= 3) {
+                    Iterator<Integer> iterator = nodeLabels.iterator();
+                    a = iterator.next();
+                    chain.clear();
+                    chain.add(a);
+                    b = iterator.next();
+                } else {
+                    int chainSize = chain.size();
+                    a = chain.get(chainSize - 4);
+                    b = chain.get(chainSize - 3);
+                    chain.remove(chainSize - 1);
+                    chain.remove(chainSize - 2);
+                    chain.remove(chainSize - 3);
+                }
+
+                while (chain.size() < 3 || chain.get(chain.size() - 3) != a) {
+                    double minDistance = Double.MAX_VALUE;
+                    int c = -1;
+                    for (int x : nodeLabels) {
+                        if (x == a) {
+                            continue;
+                        }
+                        double dax = distanceMatrix.get(a, x);
+                        if (dax < minDistance) {
+                            c = x;
+                            minDistance = dax;
                         }
                     }
+                    if (minDistance == distanceMatrix.get(a, b) && 
nodeLabels.contains(b)) {
+                        c = b;
+                    }
+                    b = a;
+                    a = c;
+                    chain.add(a);
                 }
 
-                // Outputs the merge info.
-                Cluster cluster1 = activeClusters.get(clusterOffset1);
-                Cluster cluster2 = activeClusters.get(clusterOffset2);
-                int clusterId1 = cluster1.clusterId;
-                int clusterId2 = cluster2.clusterId;
-                context.output(
-                        mergeInfoOutputTag,
-                        Tuple4.of(
-                                Math.min(clusterId1, clusterId2),
-                                Math.max(clusterId1, clusterId2),
-                                minDistance,
-                                cluster1.dataPointIds.size() + 
cluster2.dataPointIds.size()));
-
-                // Merges these two clusters.
-                Cluster mergedCluster =
-                        new Cluster(
-                                getNextClusterId(), cluster1.dataPointIds, 
cluster2.dataPointIds);
-                activeClusters.set(clusterOffset1, mergedCluster);
-                activeClusters.remove(clusterOffset2);
-
-                // Updates cluster Ids for each data point if clustering is 
still running.
-                if (clusteringRunning) {
-                    int mergedClusterId = mergedCluster.clusterId;
-                    for (int dataPointId : mergedCluster.dataPointIds) {
-                        clusterIds[dataPointId] = mergedClusterId;
-                    }
+                int mergedNodeLabel = nextClusterId;
+                nnChain.add(Tuple4.of(a, b, mergedNodeLabel, 
distanceMatrix.get(a, b)));
+                nodeLabels.remove(a);
+                nodeLabels.remove(b);
+                nextClusterId++;
+                size[mergedNodeLabel] = size[a] + size[b];
+
+                for (int x : nodeLabels) {
+                    double d =
+                            computeClusterDistances(
+                                    distanceMatrix.get(a, x),
+                                    distanceMatrix.get(b, x),
+                                    distanceMatrix.get(a, b),
+                                    size[a],
+                                    size[b],
+                                    size[x],
+                                    linkage);
+                    distanceMatrix.set(x, mergedNodeLabel, d);
                 }
 
-                clusteringRunning =
-                        (numCluster != null && activeClusters.size() > 
numCluster)
-                                || (distanceThreshold != null
-                                        && distanceThreshold > minDistance
-                                        && activeClusters.size() > 1);
+                nodeLabels.add(mergedNodeLabel);
+            }
+
+            return Tuple2.of(nnChain, size);
+        }
+
+        /** Utility class for finding labels for input data points. */
+        private static class UnionFind {
+            private final int[] parent;
+            private int nextLabel;
+
+            public UnionFind(int numDataPoints) {
+                parent = new int[2 * numDataPoints - 1];
+                Arrays.fill(parent, -1);
+                nextLabel = numDataPoints;
+            }
+
+            public void union(int m, int n) {
+                parent[m] = nextLabel;
+                parent[n] = nextLabel;
+                nextLabel++;
+            }
+
+            public int find(int n) {
+                int p = n;
+                while (parent[n] != -1) {
+                    n = parent[n];
+                }
+                while (parent[p] != n && parent[p] != -1) {
+                    p = parent[p];
+                    parent[p] = n;
+                }
+                return n;
             }
         }
 
-        private double computeDistanceBetweenClusters(Cluster cluster1, 
Cluster cluster2) {
-            double distance;
-            int size1 = cluster1.dataPointIds.size();
-            int size2 = cluster2.dataPointIds.size();
+        /** Utility class for storing distances between every two clusters. */
+        private static class DistanceMatrix {
+            /** The storage of distances between each two clusters. */
+            private final double[] distances;
+            /** Number of clusters. */
+            private final int n;
 
+            public DistanceMatrix(int n) {
+                distances = new double[n * (n - 1) / 2];
+                this.n = n;
+            }
+
+            public void set(int i, int j, double value) {
+                int smallIdx = Math.min(i, j);
+                int bigIdx = Math.max(i, j);
+                int offset = (n * 2 - 1 - smallIdx) * smallIdx / 2 + (bigIdx - 
smallIdx - 1);
+                distances[offset] = value;
+            }
+
+            public double get(int i, int j) {
+                int smallIdx = Math.min(i, j);
+                int bigIdx = Math.max(i, j);
+                int offset = (n * 2 - 1 - smallIdx) * smallIdx / 2 + (bigIdx - 
smallIdx - 1);
+                return distances[offset];
+            }
+        }
+
+        /**
+         * Computes the distance between cluster k and the new cluster merged 
by cluster i and k.

Review Comment:
   What is the relationship between cluster j and clusters i and k?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to