Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java?rev=1380428&view=auto ============================================================================== --- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java (added) +++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java Tue Sep 4 02:18:44 2012 @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.random; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import junit.framework.Assert; +import org.apache.mahout.common.RandomUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class MultinomialTest { + @Before + public void setUp() { + RandomUtils.useTestSeed(); + } + + @Test(expected = IllegalArgumentException.class) + public void testNoValues() { + Multiset<String> emptySet = HashMultiset.create(); + new Multinomial<String>(emptySet); + } + + @Test + public void testSingleton() { + Multiset<String> oneThing = HashMultiset.create(); + oneThing.add("one"); + Multinomial<String> s = new Multinomial<String>(oneThing); + assertEquals("one", s.sample(0)); + assertEquals("one", s.sample(0.1)); + assertEquals("one", s.sample(1)); + } + + @Test + public void testEvenSplit() { + Multiset<String> stuff = HashMultiset.create(); + for (int i = 0; i < 5; i++) { + stuff.add(i + ""); + } + Multinomial<String> s = new Multinomial<String>(stuff); + final double EPSILON = 1e-15; + + Multiset<String> cnt = HashMultiset.create(); + for (int i = 0; i < 5; i++) { + cnt.add(s.sample(i * 0.2)); + cnt.add(s.sample(i * 0.2 + EPSILON)); + cnt.add(s.sample((i + 1) * 0.2 - EPSILON)); + } + + assertEquals(5, cnt.elementSet().size()); + for (String v : cnt.elementSet()) { + assertEquals(3, cnt.count(v), 1.01); + } + assertTrue(cnt.contains(s.sample(1))); + assertEquals(s.sample(1 - EPSILON), s.sample(1)); + } + + @Test + public void testPrime() { + List<String> data = Lists.newArrayList(); + for (int i = 0; i < 17; i++) { + String s = "0"; + if ((i & 1) != 0) { + s = "1"; + } + if ((i & 2) != 0) { + s = "2"; + } + if ((i & 4) != 0) { + s = 
"3"; + } + if ((i & 8) != 0) { + s = "4"; + } + data.add(s); + } + + Multiset<String> stuff = HashMultiset.create(); + + for (String x : data) { + stuff.add(x); + } + + Multinomial<String> s0 = new Multinomial<String>(stuff); + Multinomial<String> s1 = new Multinomial<String>(stuff); + Multinomial<String> s2 = new Multinomial<String>(stuff); + final double EPSILON = 1e-15; + + Multiset<String> cnt = HashMultiset.create(); + for (int i = 0; i < 50; i++) { + final double p0 = i * 0.02; + final double p1 = (i + 1) * 0.02; + cnt.add(s0.sample(p0)); + cnt.add(s0.sample(p0 + EPSILON)); + cnt.add(s0.sample(p1 - EPSILON)); + + assertEquals(s0.sample(p0), s1.sample(p0)); + assertEquals(s0.sample(p0 + EPSILON), s1.sample(p0 + EPSILON)); + assertEquals(s0.sample(p1 - EPSILON), s1.sample(p1 - EPSILON)); + + assertEquals(s0.sample(p0), s2.sample(p0)); + assertEquals(s0.sample(p0 + EPSILON), s2.sample(p0 + EPSILON)); + assertEquals(s0.sample(p1 - EPSILON), s2.sample(p1 - EPSILON)); + } + + assertEquals(s0.sample(0), s1.sample(0)); + assertEquals(s0.sample(0 + EPSILON), s1.sample(0 + EPSILON)); + assertEquals(s0.sample(1 - EPSILON), s1.sample(1 - EPSILON)); + assertEquals(s0.sample(1), s1.sample(1)); + + assertEquals(s0.sample(0), s2.sample(0)); + assertEquals(s0.sample(0 + EPSILON), s2.sample(0 + EPSILON)); + assertEquals(s0.sample(1 - EPSILON), s2.sample(1 - EPSILON)); + assertEquals(s0.sample(1), s2.sample(1)); + + assertEquals(5, cnt.elementSet().size()); + // regression test, really. These values depend on the original seed and exact algorithm. 
+ // the actual values should be within about 2 of these, however, almost regardless of seed + Map<String, Integer> ref = ImmutableMap.of("3", 35, "2", 18, "1", 9, "0", 16, "4", 72); + for (String v : cnt.elementSet()) { + assertEquals(ref.get(v).intValue(), cnt.count(v)); + } + + assertTrue(cnt.contains(s0.sample(1))); + assertEquals(s0.sample(1 - EPSILON), s0.sample(1)); + } + + @Test + public void testInsert() { + Random rand = new Random(); + Multinomial<Integer> table = new Multinomial<Integer>(); + + double[] p = new double[10]; + for (int i = 0; i < 10; i++) { + p[i] = rand.nextDouble(); + table.add(i, p[i]); + } + + checkSelfConsistent(table); + + for (int i = 0; i < 10; i++) { + Assert.assertEquals(p[i], table.getWeight(i), 0); + } + } + + @Test + public void testDeleteAndUpdate() { + Random rand = new Random(); + Multinomial<Integer> table = new Multinomial<Integer>(); + assertEquals(0, table.getWeight(), 1e-9); + + double total = 0; + double[] p = new double[10]; + for (int i = 0; i < 10; i++) { + p[i] = rand.nextDouble(); + table.add(i, p[i]); + total += p[i]; + assertEquals(total, table.getWeight(), 1e-9); + } + + assertEquals(total, table.getWeight(), 1e-9); + + checkSelfConsistent(table); + + double delta = p[7] + p[8]; + table.delete(7); + p[7] = 0; + + table.set(8, 0); + p[8] = 0; + total -= delta; + + checkSelfConsistent(table); + + assertEquals(total, table.getWeight(), 1e-9); + for (int i = 0; i < 10; i++) { + assertEquals(p[i], table.getWeight(i), 0); + assertEquals(p[i] / total, table.getProbability(i), 1e-10); + } + + table.set(9, 5.1); + total -= p[9]; + p[9] = 5.1; + total += 5.1; + + assertEquals(total , table.getWeight(), 1e-9); + for (int i = 0; i < 10; i++) { + assertEquals(p[i], table.getWeight(i), 0); + assertEquals(p[i] / total, table.getProbability(i), 1e-10); + } + + checkSelfConsistent(table); + + for (int i = 0; i < 10; i++) { + Assert.assertEquals(p[i], table.getWeight(i), 0); + } + } + + private void 
checkSelfConsistent(Multinomial<Integer> table) { + final List<Double> weights = table.getWeights(); + + double totalWeight = table.getWeight(); + + double p = 0; + int[] k = new int[10]; + for (double weight : weights) { + if (weight > 0) { + if (p > 0) { + k[table.sample(p - 1e-9)]++; + } + k[table.sample(p + 1e-9)]++; + } + p += weight / totalWeight; + } + k[table.sample(p - 1e-9)]++; + Assert.assertEquals(1, p, 1e-9); + + for (int i = 0; i < 10; i++) { + if (table.getWeight(i) > 0) { + Assert.assertEquals(k[i], 2); + } + } + } +}
Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java?rev=1380428&view=auto ============================================================================== --- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java (added) +++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java Tue Sep 4 02:18:44 2012 @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.random; + +import org.apache.commons.math.MathException; +import org.apache.commons.math.distribution.NormalDistribution; +import org.apache.commons.math.distribution.NormalDistributionImpl; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.stats.OnlineSummarizer; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; + +public class NormalTest { + @Before + public void setUp() { + RandomUtils.useTestSeed(); + } + + @Test + public void testOffset() { + OnlineSummarizer s = new OnlineSummarizer(); + Sampler<Double> sampler = new Normal(2, 5); + for (int i = 0; i < 10001; i++) { + s.add(sampler.sample()); + } + + assertEquals(String.format("m = %.3f, sd = %.3f", s.getMean(), s.getSD()), 2, s.getMean(), 0.04 * s.getSD()); + assertEquals(5, s.getSD(), 0.12); + } + + @Test + public void testSample() throws MathException { + double[] data = new double[10001]; + Sampler<Double> sampler = new Normal(); + for (int i = 0; i < 10001; i++) { + data[i] = sampler.sample(); + } + Arrays.sort(data); + + NormalDistribution reference = new NormalDistributionImpl(); + + assertEquals("Median", reference.inverseCumulativeProbability(0.5), data[5000], 0.04); + } +} Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java?rev=1380428&view=auto ============================================================================== --- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java (added) +++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java Tue Sep 4 02:18:44 2012 @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.random;

import org.apache.commons.math.distribution.PoissonDistribution;
import org.apache.commons.math.distribution.PoissonDistributionImpl;
import org.apache.mahout.common.RandomUtils;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
 * Compares the empirical histogram produced by {@link PoissonSampler} with the probabilities
 * reported by the Commons Math reference distribution.
 */
public class PoissonSamplerTest {
  /** Number of draws taken for each value of alpha. */
  private static final int SAMPLES = 10000;

  /** Puts {@link RandomUtils} into test-seed mode before every test. */
  @Before
  public void setUp() {
    RandomUtils.useTestSeed();
  }

  /** Checks the sampler at alpha = 0.1, 1, 10 and 100. */
  @Test
  public void testBasics() {
    for (double alpha : new double[]{0.1, 1, 10, 100}) {
      checkDistribution(new PoissonSampler(alpha), alpha);
    }
  }

  /**
   * Draws {@link #SAMPLES} values, histograms those below {@code max(10, 5 * alpha)} and
   * requires every observed frequency to be within 0.02 of the reference probability.
   *
   * @param pd    sampler under test
   * @param alpha parameter the sampler was constructed with; also used to build the reference
   *              distribution and to size the histogram
   */
  private void checkDistribution(PoissonSampler pd, double alpha) {
    int[] count = new int[(int) Math.max(10, 5 * alpha)];
    for (int i = 0; i < SAMPLES; i++) {
      int k = pd.sample().intValue();
      // A legitimate far-tail draw may fall beyond the histogram. Skip it rather than abort the
      // test with an ArrayIndexOutOfBoundsException; the divisor below stays at SAMPLES, so the
      // in-range frequencies are still estimates of the true probabilities.
      if (k < count.length) {
        count[k]++;
      }
    }

    PoissonDistribution ref = new PoissonDistributionImpl(alpha);
    for (int i = 0; i < count.length; i++) {
      assertEquals(ref.probability(i), count[i] / (double) SAMPLES, 2e-2);
    }
  }
}
