Repository: incubator-hivemall
Updated Branches:
  refs/heads/master a8a97d6e8 -> b88e9f5e0


[HIVEMALL-210][BUGFIX] Fix a bug in lda_predict/plsa_predict

## What changes were proposed in this pull request?

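Fixed a bug in lda_predict/plsa_predict where topics sharing the same
probability were [unexpectedly
replaced](https://github.com/apache/incubator-hivemall/blame/a8a97d6e873d5a8a30b06f92ddc14d1ec95c2738/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java#L396).
`terminate()` collected the topic distribution into a `TreeMap` keyed by
probability. When two topics had equal probabilities, the later `put` replaced
the earlier topic's entry, so that topic was missing from the output. The
distribution is now sorted as an array of `KeySortablePair`s instead, which
keeps one entry per topic.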

## What type of PR is it?

Bug Fix

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-210

## How was this patch tested?

unit tests and manual tests

## Checklist

- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to
your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <m...@apache.org>

Closes #154 from myui/HIVEMALL-210.


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b88e9f5e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b88e9f5e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b88e9f5e

Branch: refs/heads/master
Commit: b88e9f5e0728633dc100dcbb6a701c4fee6f7268
Parents: a8a97d6
Author: Makoto Yui <m...@apache.org>
Authored: Mon Aug 6 16:42:20 2018 +0900
Committer: Makoto Yui <m...@apache.org>
Committed: Mon Aug 6 16:42:20 2018 +0900

----------------------------------------------------------------------
 .../hivemall/topicmodel/LDAPredictUDAF.java     | 18 ++--
 .../hivemall/topicmodel/PLSAPredictUDAF.java    | 22 ++---
 .../hivemall/utils/struct/KeySortablePair.java  | 89 ++++++++++++++++++++
 .../utils/struct/ValueSortablePair.java         | 85 +++++++++++++++++++
 .../hivemall/topicmodel/LDAPredictUDAFTest.java | 47 +++++++++--
 .../topicmodel/PLSAPredictUDAFTest.java         | 47 +++++++++--
 .../utils/struct/KeySortablePairTest.java       | 71 ++++++++++++++++
 .../utils/struct/ValueSortablePairTest.java     | 71 ++++++++++++++++
 8 files changed, 422 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java 
b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 2befec1..687f20e 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -21,16 +21,16 @@ package hivemall.topicmodel;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.lang.CommandLineUtils;
 import hivemall.utils.lang.Primitives;
+import hivemall.utils.struct.KeySortablePair;
 
 import java.io.PrintWriter;
 import java.io.StringWriter;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
 
 import javax.annotation.Nonnull;
 
@@ -384,20 +384,22 @@ public final class LDAPredictUDAF extends 
AbstractGenericUDAFResolver {
             myAggr.merge(wcList, lambdaMap);
         }
 
+        @SuppressWarnings("unchecked")
         @Override
         public Object terminate(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
                 throws HiveException {
             OnlineLDAPredictAggregationBuffer myAggr = 
(OnlineLDAPredictAggregationBuffer) agg;
-            float[] topicDistr = myAggr.get();
 
-            SortedMap<Float, Integer> sortedDistr =
-                    new TreeMap<Float, Integer>(Collections.reverseOrder());
+            final float[] topicDistr = myAggr.get();
+            final KeySortablePair<Float, Integer>[] sorted =
+                    new KeySortablePair[topicDistr.length];
             for (int i = 0; i < topicDistr.length; i++) {
-                sortedDistr.put(topicDistr[i], i);
+                sorted[i] = new KeySortablePair<>(topicDistr[i], i);
             }
+            Arrays.sort(sorted, Collections.reverseOrder());
 
-            List<Object[]> result = new ArrayList<Object[]>();
-            for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+            final List<Object[]> result = new 
ArrayList<Object[]>(sorted.length);
+            for (KeySortablePair<Float, Integer> e : sorted) {
                 Object[] struct = new Object[2];
                 struct[0] = new IntWritable(e.getValue()); // label
                 struct[1] = new FloatWritable(e.getKey()); // probability

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java 
b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index d9df347..414f980 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -21,16 +21,16 @@ package hivemall.topicmodel;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.lang.CommandLineUtils;
 import hivemall.utils.lang.Primitives;
+import hivemall.utils.struct.KeySortablePair;
 
 import java.io.PrintWriter;
 import java.io.StringWriter;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -384,23 +384,25 @@ public final class PLSAPredictUDAF extends 
AbstractGenericUDAFResolver {
             myAggr.merge(wcList, probMap);
         }
 
+        @SuppressWarnings("unchecked")
         @Override
         public Object terminate(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
                 throws HiveException {
             PLSAPredictAggregationBuffer myAggr = 
(PLSAPredictAggregationBuffer) agg;
-            float[] topicDistr = myAggr.get();
 
-            SortedMap<Float, Integer> sortedDistr =
-                    new TreeMap<Float, Integer>(Collections.reverseOrder());
+            final float[] topicDistr = myAggr.get();
+            final KeySortablePair<Float, Integer>[] sorted =
+                    new KeySortablePair[topicDistr.length];
             for (int i = 0; i < topicDistr.length; i++) {
-                sortedDistr.put(topicDistr[i], i);
+                sorted[i] = new KeySortablePair<>(topicDistr[i], i);
             }
+            Arrays.sort(sorted, Collections.reverseOrder());
 
-            List<Object[]> result = new ArrayList<Object[]>();
-            for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+            final List<Object[]> result = new 
ArrayList<Object[]>(sorted.length);
+            for (KeySortablePair<Float, Integer> e : sorted) {
                 Object[] struct = new Object[2];
-                struct[0] = new IntWritable(e.getValue().intValue()); // label
-                struct[1] = new FloatWritable(e.getKey().floatValue()); // 
probability
+                struct[0] = new IntWritable(e.getValue()); // label
+                struct[1] = new FloatWritable(e.getKey()); // probability
                 result.add(struct);
             }
             return result;

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/utils/struct/KeySortablePair.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/struct/KeySortablePair.java 
b/core/src/main/java/hivemall/utils/struct/KeySortablePair.java
new file mode 100644
index 0000000..f85e7e9
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/struct/KeySortablePair.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import hivemall.utils.lang.Preconditions;
+
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+public final class KeySortablePair<K extends Comparable<K>, V>
+        implements Comparable<KeySortablePair<K, V>> {
+
+    @Nonnull
+    private final K k;
+    @Nullable
+    private final V v;
+
+    public KeySortablePair(@CheckForNull K k, @Nullable V v) {
+        this.k = Preconditions.checkNotNull(k);
+        this.v = v;
+    }
+
+    @Nonnull
+    public K getKey() {
+        return k;
+    }
+
+    @Nullable
+    public V getValue() {
+        return v;
+    }
+
+    @Override
+    public int compareTo(KeySortablePair<K, V> o) {
+        return k.compareTo(o.k);
+    }
+
+    @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        result = prime * result + k.hashCode();
+        result = prime * result + ((v == null) ? 0 : v.hashCode());
+        return result;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj)
+            return true;
+        if (obj == null)
+            return false;
+        if (getClass() != obj.getClass())
+            return false;
+        KeySortablePair<K, V> other = (KeySortablePair<K, V>) obj;
+        if (!k.equals(other.k))
+            return false;
+        if (v == null) {
+            if (other.v != null)
+                return false;
+        } else if (!v.equals(other.v))
+            return false;
+        return true;
+    }
+
+    @Override
+    public String toString() {
+        return "k=" + k + ", v=" + v;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java 
b/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java
new file mode 100644
index 0000000..891764e
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import hivemall.utils.lang.Preconditions;
+
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+
+public final class ValueSortablePair<K, V extends Comparable<V>>
+        implements Comparable<ValueSortablePair<K, V>> {
+
+    @Nonnull
+    private final K k;
+    @Nonnull
+    private final V v;
+
+    public ValueSortablePair(@CheckForNull K k, @Nonnull V v) {
+        this.k = Preconditions.checkNotNull(k);
+        this.v = Preconditions.checkNotNull(v);
+    }
+
+    @Nonnull
+    public K getKey() {
+        return k;
+    }
+
+    @Nonnull
+    public V getValue() {
+        return v;
+    }
+
+    @Override
+    public int compareTo(ValueSortablePair<K, V> o) {
+        return v.compareTo(o.v);
+    }
+
+    @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        result = prime * result + k.hashCode();
+        result = prime * result + v.hashCode();
+        return result;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj)
+            return true;
+        if (obj == null)
+            return false;
+        if (getClass() != obj.getClass())
+            return false;
+        ValueSortablePair<K, V> other = (ValueSortablePair<K, V>) obj;
+        if (!k.equals(other.k))
+            return false;
+        if (!v.equals(other.v))
+            return false;
+        return true;
+    }
+
+    @Override
+    public String toString() {
+        return "k=" + k + ", v=" + v;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java 
b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index bf46485..d6a728f 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -20,6 +20,11 @@ package hivemall.topicmodel;
 
 import hivemall.utils.math.MathUtils;
 
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -27,11 +32,6 @@ import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.HashMap;
-
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -265,4 +265,41 @@ public class LDAPredictUDAFTest {
         Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc2Distr.length);
         Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
     }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testTerminateWithSameTopicProbability() throws Exception {
+        udaf = new LDAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-topics 2")};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        // Assume that all words in a document are NOT in vocabulary that 
composes a LDA model.
+        // Hence, the document should be assigned to topic #1 (#2) with 
probability 0.5 (0.5).
+        for (int i = 0; i < 18; i++) {
+            evaluator.iterate(agg, new Object[] {words[i], 0.f, labels[i], 
lambdas[i]});
+        }
+
+        // Probability for each of the two topics should be same.
+        List<Object[]> result = (List<Object[]>) evaluator.terminate(agg);
+        Assert.assertEquals(result.size(), 2);
+        Assert.assertEquals(result.get(0)[1], result.get(1)[1]);
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java 
b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
index 1d364ee..e61222a 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
@@ -20,6 +20,11 @@ package hivemall.topicmodel;
 
 import hivemall.utils.math.MathUtils;
 
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -27,11 +32,6 @@ import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.HashMap;
-
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -264,4 +264,41 @@ public class PLSAPredictUDAFTest {
         Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc2Distr.length);
         Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
     }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testTerminateWithSameTopicProbability() throws Exception {
+        udaf = new PLSAPredictUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.STRING),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.INT),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                    PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector, 
"-topics 2")};
+
+        evaluator = udaf.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+        agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        // Assume that all words in a document are NOT in vocabulary that 
composes a LDA model.
+        // Hence, the document should be assigned to topic #1 (#2) with 
probability 0.5 (0.5).
+        for (int i = 0; i < words.length; i++) {
+            String word = words[i];
+            evaluator.iterate(agg, new Object[] {word, 0.f, labels[i], 
probs[i]});
+        }
+
+        // Probability for each of the two topics should be same.
+        List<Object[]> result = (List<Object[]>) evaluator.terminate(agg);
+        Assert.assertEquals(result.size(), 2);
+        Assert.assertEquals(result.get(0)[1], result.get(1)[1]);
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java 
b/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java
new file mode 100644
index 0000000..6f3fd70
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.PriorityQueue;
+
+import org.junit.Test;
+
+public class KeySortablePairTest {
+
+    @Test
+    public void testPriorityQueue() {
+        KeySortablePair<Float, Integer> v1 = new KeySortablePair<>(3.f, 1);
+        KeySortablePair<Float, Integer> v2 = new KeySortablePair<>(1.f, 2);
+        KeySortablePair<Float, Integer> v3 = new KeySortablePair<>(4.f, 3);
+        KeySortablePair<Float, Integer> v4 = new KeySortablePair<>(-1.f, 4);
+
+        PriorityQueue<KeySortablePair<Float, Integer>> pq =
+                new PriorityQueue<>(11, Collections.reverseOrder());
+        pq.add(v1);
+        pq.add(v2);
+        pq.add(v3);
+        pq.add(v4);
+
+        assertEquals(Float.valueOf(4.f), pq.poll().getKey());
+        assertEquals(Float.valueOf(3.f), pq.poll().getKey());
+        assertEquals(Float.valueOf(1.f), pq.poll().getKey());
+        assertEquals(Float.valueOf(-1.f), pq.poll().getKey());
+
+        assertTrue(pq.isEmpty());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testArraySort() {
+        KeySortablePair<Float, Integer> v1 = new KeySortablePair<>(3.f, 1);
+        KeySortablePair<Float, Integer> v2 = new KeySortablePair<>(1.f, 2);
+        KeySortablePair<Float, Integer> v3 = new KeySortablePair<>(4.f, 3);
+        KeySortablePair<Float, Integer> v4 = new KeySortablePair<>(-1.f, 4);
+
+        KeySortablePair<Float, Integer>[] arr = new KeySortablePair[] {v1, v2, 
v3, v4};
+        Arrays.sort(arr, Collections.reverseOrder());
+
+        assertEquals(Float.valueOf(4.f), arr[0].getKey());
+        assertEquals(Float.valueOf(3.f), arr[1].getKey());
+        assertEquals(Float.valueOf(1.f), arr[2].getKey());
+        assertEquals(Float.valueOf(-1.f), arr[3].getKey());
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java 
b/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java
new file mode 100644
index 0000000..4829279
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.PriorityQueue;
+
+import org.junit.Test;
+
+public class ValueSortablePairTest {
+
+    @Test
+    public void testPriorityQueue() {
+        ValueSortablePair<Float, Integer> v1 = new ValueSortablePair<>(1.f, 
-1);
+        ValueSortablePair<Float, Integer> v2 = new ValueSortablePair<>(2.f, 3);
+        ValueSortablePair<Float, Integer> v3 = new ValueSortablePair<>(3.f, 2);
+        ValueSortablePair<Float, Integer> v4 = new ValueSortablePair<>(4.f, 0);
+
+        PriorityQueue<ValueSortablePair<Float, Integer>> pq =
+                new PriorityQueue<>(11, Collections.reverseOrder());
+        pq.add(v1);
+        pq.add(v2);
+        pq.add(v3);
+        pq.add(v4);
+
+        assertEquals(3, pq.poll().getValue().intValue());
+        assertEquals(2, pq.poll().getValue().intValue());
+        assertEquals(0, pq.poll().getValue().intValue());
+        assertEquals(-1, pq.poll().getValue().intValue());
+
+        assertTrue(pq.isEmpty());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testArraySort() {
+        ValueSortablePair<Float, Integer> v1 = new ValueSortablePair<>(1.f, 
-1);
+        ValueSortablePair<Float, Integer> v2 = new ValueSortablePair<>(2.f, 3);
+        ValueSortablePair<Float, Integer> v3 = new ValueSortablePair<>(3.f, 2);
+        ValueSortablePair<Float, Integer> v4 = new ValueSortablePair<>(4.f, 0);
+
+        ValueSortablePair<Float, Integer>[] arr = new ValueSortablePair[] {v1, 
v2, v3, v4};
+        Arrays.sort(arr, Collections.reverseOrder());
+
+        assertEquals(3, arr[0].getValue().intValue());
+        assertEquals(2, arr[1].getValue().intValue());
+        assertEquals(0, arr[2].getValue().intValue());
+        assertEquals(-1, arr[3].getValue().intValue());
+    }
+
+}

Reply via email to