Repository: incubator-hivemall Updated Branches: refs/heads/master a8a97d6e8 -> b88e9f5e0
[HIVEMALL-210][BUGFIX] Fix a bug in lda_predict/plsa_predict ## What changes were proposed in this pull request? Fixed a bug in lda_predict/plsa_predict where topics with the same (duplicated) probability were [unexpectedly replaced](https://github.com/apache/incubator-hivemall/blame/a8a97d6e873d5a8a30b06f92ddc14d1ec95c2738/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java#L396). The topic distribution was collected into a `TreeMap` keyed by probability, so a topic whose probability equaled an earlier topic's overwrote that entry, and the earlier topic was dropped from the output. The fix sorts (probability, topic) pairs (`KeySortablePair`) instead, so every topic is returned. ## What type of PR is it? Bug Fix ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-210 ## How was this patch tested? unit tests and manual tests ## Checklist - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <m...@apache.org> Closes #154 from myui/HIVEMALL-210. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b88e9f5e Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b88e9f5e Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b88e9f5e Branch: refs/heads/master Commit: b88e9f5e0728633dc100dcbb6a701c4fee6f7268 Parents: a8a97d6 Author: Makoto Yui <m...@apache.org> Authored: Mon Aug 6 16:42:20 2018 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Mon Aug 6 16:42:20 2018 +0900 ---------------------------------------------------------------------- .../hivemall/topicmodel/LDAPredictUDAF.java | 18 ++-- .../hivemall/topicmodel/PLSAPredictUDAF.java | 22 ++--- .../hivemall/utils/struct/KeySortablePair.java | 89 ++++++++++++++++++++ .../utils/struct/ValueSortablePair.java | 85 +++++++++++++++++++ .../hivemall/topicmodel/LDAPredictUDAFTest.java | 47 +++++++++-- .../topicmodel/PLSAPredictUDAFTest.java | 47 +++++++++-- .../utils/struct/KeySortablePairTest.java | 71 ++++++++++++++++ .../utils/struct/ValueSortablePairTest.java | 71 ++++++++++++++++ 8 files changed, 422 insertions(+), 28 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java index 2befec1..687f20e 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java @@ -21,16 +21,16 @@ package hivemall.topicmodel; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.CommandLineUtils; import hivemall.utils.lang.Primitives; +import hivemall.utils.struct.KeySortablePair; import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; import javax.annotation.Nonnull; @@ -384,20 +384,22 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { myAggr.merge(wcList, lambdaMap); } + @SuppressWarnings("unchecked") @Override public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; - float[] topicDistr = myAggr.get(); - SortedMap<Float, Integer> sortedDistr = - new TreeMap<Float, Integer>(Collections.reverseOrder()); + final float[] topicDistr = myAggr.get(); + final KeySortablePair<Float, Integer>[] sorted = + new KeySortablePair[topicDistr.length]; for (int i = 0; i < topicDistr.length; i++) { - sortedDistr.put(topicDistr[i], i); + sorted[i] = new KeySortablePair<>(topicDistr[i], i); } + Arrays.sort(sorted, Collections.reverseOrder()); - List<Object[]> result = new ArrayList<Object[]>(); - for (Map.Entry<Float, Integer> e 
: sortedDistr.entrySet()) { + final List<Object[]> result = new ArrayList<Object[]>(sorted.length); + for (KeySortablePair<Float, Integer> e : sorted) { Object[] struct = new Object[2]; struct[0] = new IntWritable(e.getValue()); // label struct[1] = new FloatWritable(e.getKey()); // probability http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java index d9df347..414f980 100644 --- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java @@ -21,16 +21,16 @@ package hivemall.topicmodel; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.CommandLineUtils; import hivemall.utils.lang.Primitives; +import hivemall.utils.struct.KeySortablePair; import java.io.PrintWriter; import java.io.StringWriter; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -384,23 +384,25 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver { myAggr.merge(wcList, probMap); } + @SuppressWarnings("unchecked") @Override public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg; - float[] topicDistr = myAggr.get(); - SortedMap<Float, Integer> sortedDistr = - new TreeMap<Float, Integer>(Collections.reverseOrder()); + final float[] topicDistr = myAggr.get(); + final KeySortablePair<Float, Integer>[] sorted = + new KeySortablePair[topicDistr.length]; for (int i = 0; i < 
topicDistr.length; i++) { - sortedDistr.put(topicDistr[i], i); + sorted[i] = new KeySortablePair<>(topicDistr[i], i); } + Arrays.sort(sorted, Collections.reverseOrder()); - List<Object[]> result = new ArrayList<Object[]>(); - for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) { + final List<Object[]> result = new ArrayList<Object[]>(sorted.length); + for (KeySortablePair<Float, Integer> e : sorted) { Object[] struct = new Object[2]; - struct[0] = new IntWritable(e.getValue().intValue()); // label - struct[1] = new FloatWritable(e.getKey().floatValue()); // probability + struct[0] = new IntWritable(e.getValue()); // label + struct[1] = new FloatWritable(e.getKey()); // probability result.add(struct); } return result; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/utils/struct/KeySortablePair.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/struct/KeySortablePair.java b/core/src/main/java/hivemall/utils/struct/KeySortablePair.java new file mode 100644 index 0000000..f85e7e9 --- /dev/null +++ b/core/src/main/java/hivemall/utils/struct/KeySortablePair.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.struct; + +import hivemall.utils.lang.Preconditions; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +public final class KeySortablePair<K extends Comparable<K>, V> + implements Comparable<KeySortablePair<K, V>> { + + @Nonnull + private final K k; + @Nullable + private final V v; + + public KeySortablePair(@CheckForNull K k, @Nullable V v) { + this.k = Preconditions.checkNotNull(k); + this.v = v; + } + + @Nonnull + public K getKey() { + return k; + } + + @Nullable + public V getValue() { + return v; + } + + @Override + public int compareTo(KeySortablePair<K, V> o) { + return k.compareTo(o.k); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + k.hashCode(); + result = prime * result + ((v == null) ? 0 : v.hashCode()); + return result; + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + KeySortablePair<K, V> other = (KeySortablePair<K, V>) obj; + if (!k.equals(other.k)) + return false; + if (v == null) { + if (other.v != null) + return false; + } else if (!v.equals(other.v)) + return false; + return true; + } + + @Override + public String toString() { + return "k=" + k + ", v=" + v; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java b/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java new file mode 100644 index 0000000..891764e --- /dev/null +++ 
b/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.struct; + +import hivemall.utils.lang.Preconditions; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; + +public final class ValueSortablePair<K, V extends Comparable<V>> + implements Comparable<ValueSortablePair<K, V>> { + + @Nonnull + private final K k; + @Nonnull + private final V v; + + public ValueSortablePair(@CheckForNull K k, @Nonnull V v) { + this.k = Preconditions.checkNotNull(k); + this.v = Preconditions.checkNotNull(v); + } + + @Nonnull + public K getKey() { + return k; + } + + @Nonnull + public V getValue() { + return v; + } + + @Override + public int compareTo(ValueSortablePair<K, V> o) { + return v.compareTo(o.v); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + k.hashCode(); + result = prime * result + v.hashCode(); + return result; + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + 
ValueSortablePair<K, V> other = (ValueSortablePair<K, V>) obj; + if (!k.equals(other.k)) + return false; + if (!v.equals(other.v)) + return false; + return true; + } + + @Override + public String toString() { + return "k=" + k + ", v=" + v; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java index bf46485..d6a728f 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java @@ -20,6 +20,11 @@ package hivemall.topicmodel; import hivemall.utils.math.MathUtils; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -27,11 +32,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; - -import java.util.ArrayList; -import java.util.Map; -import java.util.HashMap; - import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -265,4 +265,41 @@ public class LDAPredictUDAFTest { Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc2Distr.length); Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d); } + + @SuppressWarnings("unchecked") + @Test + public void testTerminateWithSameTopicProbability() throws Exception { + udaf = new LDAPredictUDAF(); + 
+ inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; + + evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + + agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer(); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + // Assume that all words in a document are NOT in vocabulary that composes a LDA model. + // Hence, the document should be assigned to topic #1 (#2) with probability 0.5 (0.5). + for (int i = 0; i < 18; i++) { + evaluator.iterate(agg, new Object[] {words[i], 0.f, labels[i], lambdas[i]}); + } + + // Probability for each of the two topics should be same. 
+ List<Object[]> result = (List<Object[]>) evaluator.terminate(agg); + Assert.assertEquals(result.size(), 2); + Assert.assertEquals(result.get(0)[1], result.get(1)[1]); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java index 1d364ee..e61222a 100644 --- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java +++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java @@ -20,6 +20,11 @@ package hivemall.topicmodel; import hivemall.utils.math.MathUtils; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -27,11 +32,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; - -import java.util.ArrayList; -import java.util.Map; -import java.util.HashMap; - import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -264,4 +264,41 @@ public class PLSAPredictUDAFTest { Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc2Distr.length); Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d); } + + @SuppressWarnings("unchecked") + @Test + public void testTerminateWithSameTopicProbability() throws Exception { + udaf = new PLSAPredictUDAF(); + + inputOIs = new ObjectInspector[] { + 
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.STRING), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.INT), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.FLOAT), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; + + evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer(); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + // Assume that all words in a document are NOT in vocabulary that composes a LDA model. + // Hence, the document should be assigned to topic #1 (#2) with probability 0.5 (0.5). + for (int i = 0; i < words.length; i++) { + String word = words[i]; + evaluator.iterate(agg, new Object[] {word, 0.f, labels[i], probs[i]}); + } + + // Probability for each of the two topics should be same. 
+ List<Object[]> result = (List<Object[]>) evaluator.terminate(agg); + Assert.assertEquals(result.size(), 2); + Assert.assertEquals(result.get(0)[1], result.get(1)[1]); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java b/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java new file mode 100644 index 0000000..6f3fd70 --- /dev/null +++ b/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.struct; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.PriorityQueue; + +import org.junit.Test; + +public class KeySortablePairTest { + + @Test + public void testPriorityQueue() { + KeySortablePair<Float, Integer> v1 = new KeySortablePair<>(3.f, 1); + KeySortablePair<Float, Integer> v2 = new KeySortablePair<>(1.f, 2); + KeySortablePair<Float, Integer> v3 = new KeySortablePair<>(4.f, 3); + KeySortablePair<Float, Integer> v4 = new KeySortablePair<>(-1.f, 4); + + PriorityQueue<KeySortablePair<Float, Integer>> pq = + new PriorityQueue<>(11, Collections.reverseOrder()); + pq.add(v1); + pq.add(v2); + pq.add(v3); + pq.add(v4); + + assertEquals(Float.valueOf(4.f), pq.poll().getKey()); + assertEquals(Float.valueOf(3.f), pq.poll().getKey()); + assertEquals(Float.valueOf(1.f), pq.poll().getKey()); + assertEquals(Float.valueOf(-1.f), pq.poll().getKey()); + + assertTrue(pq.isEmpty()); + } + + @SuppressWarnings("unchecked") + @Test + public void testArraySort() { + KeySortablePair<Float, Integer> v1 = new KeySortablePair<>(3.f, 1); + KeySortablePair<Float, Integer> v2 = new KeySortablePair<>(1.f, 2); + KeySortablePair<Float, Integer> v3 = new KeySortablePair<>(4.f, 3); + KeySortablePair<Float, Integer> v4 = new KeySortablePair<>(-1.f, 4); + + KeySortablePair<Float, Integer>[] arr = new KeySortablePair[] {v1, v2, v3, v4}; + Arrays.sort(arr, Collections.reverseOrder()); + + assertEquals(Float.valueOf(4.f), arr[0].getKey()); + assertEquals(Float.valueOf(3.f), arr[1].getKey()); + assertEquals(Float.valueOf(1.f), arr[2].getKey()); + assertEquals(Float.valueOf(-1.f), arr[3].getKey()); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java ---------------------------------------------------------------------- diff --git 
a/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java b/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java new file mode 100644 index 0000000..4829279 --- /dev/null +++ b/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.struct; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.PriorityQueue; + +import org.junit.Test; + +public class ValueSortablePairTest { + + @Test + public void testPriorityQueue() { + ValueSortablePair<Float, Integer> v1 = new ValueSortablePair<>(1.f, -1); + ValueSortablePair<Float, Integer> v2 = new ValueSortablePair<>(2.f, 3); + ValueSortablePair<Float, Integer> v3 = new ValueSortablePair<>(3.f, 2); + ValueSortablePair<Float, Integer> v4 = new ValueSortablePair<>(4.f, 0); + + PriorityQueue<ValueSortablePair<Float, Integer>> pq = + new PriorityQueue<>(11, Collections.reverseOrder()); + pq.add(v1); + pq.add(v2); + pq.add(v3); + pq.add(v4); + + assertEquals(3, pq.poll().getValue().intValue()); + assertEquals(2, pq.poll().getValue().intValue()); + assertEquals(0, pq.poll().getValue().intValue()); + assertEquals(-1, pq.poll().getValue().intValue()); + + assertTrue(pq.isEmpty()); + } + + @SuppressWarnings("unchecked") + @Test + public void testArraySort() { + ValueSortablePair<Float, Integer> v1 = new ValueSortablePair<>(1.f, -1); + ValueSortablePair<Float, Integer> v2 = new ValueSortablePair<>(2.f, 3); + ValueSortablePair<Float, Integer> v3 = new ValueSortablePair<>(3.f, 2); + ValueSortablePair<Float, Integer> v4 = new ValueSortablePair<>(4.f, 0); + + ValueSortablePair<Float, Integer>[] arr = new ValueSortablePair[] {v1, v2, v3, v4}; + Arrays.sort(arr, Collections.reverseOrder()); + + assertEquals(3, arr[0].getValue().intValue()); + assertEquals(2, arr[1].getValue().intValue()); + assertEquals(0, arr[2].getValue().intValue()); + assertEquals(-1, arr[3].getValue().intValue()); + } + +}