http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java new file mode 100644 index 0000000..f2899c2 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.util.generators; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.UpstreamEntry; +import org.apache.ignite.ml.dataset.UpstreamTransformer; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; +import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironment; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.structures.LabeledVectorSet; +import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link DataStreamGenerator}. 
+ */ +public class DataStreamGeneratorTest { + /** */ + @Test + public void testUnlabeled() { + DataStreamGenerator generator = new DataStreamGenerator() { + @Override public Stream<LabeledVector<Vector, Double>> labeled() { + return Stream.generate(() -> new LabeledVector<>(VectorUtils.of(1., 2.), 100.)); + } + }; + + generator.unlabeled().limit(100).forEach(v -> { + assertArrayEquals(new double[] {1., 2.}, v.asArray(), 1e-7); + }); + } + + /** */ + @Test + public void testLabeled() { + DataStreamGenerator generator = new DataStreamGenerator() { + @Override public Stream<LabeledVector<Vector, Double>> labeled() { + return Stream.generate(() -> new LabeledVector<>(VectorUtils.of(1., 2.), 100.)); + } + }; + + generator.labeled(v -> -100.).limit(100).forEach(v -> { + assertArrayEquals(new double[] {1., 2.}, v.features().asArray(), 1e-7); + assertEquals(-100., v.label(), 1e-7); + }); + } + + /** */ + @Test + public void testMapVectors() { + DataStreamGenerator generator = new DataStreamGenerator() { + @Override public Stream<LabeledVector<Vector, Double>> labeled() { + return Stream.generate(() -> new LabeledVector<>(VectorUtils.of(1., 2.), 100.)); + } + }; + + generator.mapVectors(v -> VectorUtils.of(2., 1.)).labeled().limit(100).forEach(v -> { + assertArrayEquals(new double[] {2., 1.}, v.features().asArray(), 1e-7); + assertEquals(100., v.label(), 1e-7); + }); + } + + /** */ + @Test + public void testBlur() { + DataStreamGenerator generator = new DataStreamGenerator() { + @Override public Stream<LabeledVector<Vector, Double>> labeled() { + return Stream.generate(() -> new LabeledVector<>(VectorUtils.of(1., 2.), 100.)); + } + }; + + generator.blur(() -> 1.).labeled().limit(100).forEach(v -> { + assertArrayEquals(new double[] {2., 3.}, v.features().asArray(), 1e-7); + assertEquals(100., v.label(), 1e-7); + }); + } + + /** */ + @Test + public void testAsMap() { + DataStreamGenerator generator = new DataStreamGenerator() { + @Override public Stream<LabeledVector<Vector, 
Double>> labeled() { + return Stream.generate(() -> new LabeledVector<>(VectorUtils.of(1., 2.), 100.)); + } + }; + + int N = 100; + Map<Vector, Double> dataset = generator.asMap(N); + assertEquals(N, dataset.size()); + dataset.forEach(((vector, label) -> { + assertArrayEquals(new double[] {1., 2.}, vector.asArray(), 1e-7); + assertEquals(100., label, 1e-7); + })); + } + + /** */ + @Test + public void testAsDatasetBuilder() throws Exception { + AtomicInteger counter = new AtomicInteger(); + DataStreamGenerator generator = new DataStreamGenerator() { + @Override public Stream<LabeledVector<Vector, Double>> labeled() { + return Stream.generate(() -> { + int value = counter.getAndIncrement(); + return new LabeledVector<>(VectorUtils.of(value), (double)value % 2); + }); + } + }; + + int N = 100; + counter.set(0); + DatasetBuilder<Vector, Double> b1 = generator.asDatasetBuilder(N, 2); + counter.set(0); + DatasetBuilder<Vector, Double> b2 = generator.asDatasetBuilder(N, (v, l) -> l == 0, 2); + counter.set(0); + DatasetBuilder<Vector, Double> b3 = generator.asDatasetBuilder(N, (v, l) -> l == 1, 2, + new UpstreamTransformerBuilder<Vector, Double>() { + @Override public UpstreamTransformer<Vector, Double> build(LearningEnvironment env) { + return new UpstreamTransformerForTest(); + } + }); + + checkDataset(N, b1, v -> (Double)v.label() == 0 || (Double)v.label() == 1); + checkDataset(N / 2, b2, v -> (Double)v.label() == 0); + checkDataset(N / 2, b3, v -> (Double)v.label() < 0); + } + + /** */ + private void checkDataset(int sampleSize, DatasetBuilder<Vector, Double> datasetBuilder, + Predicate<LabeledVector> labelCheck) throws Exception { + + try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = buildDataset(datasetBuilder)) { + List<LabeledVector> res = dataset.compute(this::map, this::reduce); + assertEquals(sampleSize, res.size()); + + res.forEach(v -> assertTrue(labelCheck.test(v))); + } + } + + /** */ + private Dataset<EmptyContext, 
LabeledVectorSet<Double, LabeledVector>> buildDataset( + DatasetBuilder<Vector, Double> b1) { + return b1.build(LearningEnvironmentBuilder.defaultBuilder(), + new EmptyContextBuilder<>(), + new LabeledDatasetPartitionDataBuilderOnHeap<>((v, l) -> v, (v, l) -> l) + ); + } + + /** */ + private List<LabeledVector> map(LabeledVectorSet<Double, LabeledVector> d) { + return IntStream.range(0, d.rowSize()).mapToObj(d::getRow).collect(Collectors.toList()); + } + + /** */ + private List<LabeledVector> reduce(List<LabeledVector> l, List<LabeledVector> r) { + if (l == null) { + if (r == null) + return Collections.emptyList(); + else + return r; + } + else { + List<LabeledVector> res = new ArrayList<>(); + res.addAll(l); + res.addAll(r); + return res; + } + } + + /** */ + private static class UpstreamTransformerForTest implements UpstreamTransformer<Vector, Double> { + @Override public Stream<UpstreamEntry<Vector, Double>> transform( + Stream<UpstreamEntry<Vector, Double>> upstream) { + return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -entry.getValue())); + } + } +}
http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/DiscreteRandomProducerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/DiscreteRandomProducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/DiscreteRandomProducerTest.java new file mode 100644 index 0000000..83178ac --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/DiscreteRandomProducerTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.util.generators.primitives.scalar; + +import java.util.HashMap; +import java.util.Map; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link DiscreteRandomProducer}. 
+ */ +public class DiscreteRandomProducerTest { + /** */ + @Test + public void testGet() { + double[] probs = new double[] {0.1, 0.2, 0.3, 0.4}; + DiscreteRandomProducer producer = new DiscreteRandomProducer(0L, probs); + + Map<Integer, Double> counters = new HashMap<>(); + IntStream.range(0, probs.length).forEach(i -> counters.put(i, 0.0)); + + final int N = 500000; + Stream.generate(producer::getInt).limit(N).forEach(i -> counters.put(i, counters.get(i) + 1)); + IntStream.range(0, probs.length).forEach(i -> counters.put(i, counters.get(i) / N)); + + for (int i = 0; i < probs.length; i++) + assertEquals(probs[i], counters.get(i), 0.01); + + assertEquals(probs.length, producer.size()); + } + + /** */ + @Test + public void testSeedConsidering() { + DiscreteRandomProducer producer1 = new DiscreteRandomProducer(0L, 0.1, 0.2, 0.3, 0.4); + DiscreteRandomProducer producer2 = new DiscreteRandomProducer(0L, 0.1, 0.2, 0.3, 0.4); + + assertEquals(producer1.get(), producer2.get(), 0.0001); + } + + /** */ + @Test + public void testUniformGeneration() { + int N = 10; + DiscreteRandomProducer producer = DiscreteRandomProducer.uniform(N); + + Map<Integer, Double> counters = new HashMap<>(); + IntStream.range(0, N).forEach(i -> counters.put(i, 0.0)); + + final int sampleSize = 500000; + Stream.generate(producer::getInt).limit(sampleSize).forEach(i -> counters.put(i, counters.get(i) + 1)); + IntStream.range(0, N).forEach(i -> counters.put(i, counters.get(i) / sampleSize)); + + for (int i = 0; i < N; i++) + assertEquals(1.0 / N, counters.get(i), 0.01); + } + + /** */ + @Test + public void testDistributionGeneration() { + double[] probs = DiscreteRandomProducer.randomDistribution(5, 0L); + assertArrayEquals(new double[] {0.23, 0.27, 0.079, 0.19, 0.20}, probs, 0.01); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testInvalidDistribution1() { + new DiscreteRandomProducer(0L, 0.1, 0.2, 0.3, 0.0); + } + + /** */ + @Test(expected = 
IllegalArgumentException.class) + public void testInvalidDistribution2() { + new DiscreteRandomProducer(0L, 0.1, 0.2, 0.3, 1.0); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testInvalidDistribution3() { + new DiscreteRandomProducer(0L, 0.1, 0.2, 0.3, 1.0, -0.6); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/GaussRandomProducerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/GaussRandomProducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/GaussRandomProducerTest.java new file mode 100644 index 0000000..845c284 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/GaussRandomProducerTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.util.generators.primitives.scalar; + +import java.util.Random; +import java.util.stream.IntStream; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link GaussRandomProducer}. + */ +public class GaussRandomProducerTest { + /** */ + @Test + public void testGet() { + Random random = new Random(0L); + final double mean = random.nextInt(5) - 2.5; + final double variance = random.nextInt(5); + GaussRandomProducer producer = new GaussRandomProducer(mean, variance, 1L); + + final int N = 50000; + double meanStat = IntStream.range(0, N).mapToDouble(i -> producer.get()).sum() / N; + double varianceStat = IntStream.range(0, N).mapToDouble(i -> Math.pow(producer.get() - mean, 2)).sum() / N; + + assertEquals(mean, meanStat, 0.01); + assertEquals(variance, varianceStat, 0.1); + } + + /** */ + @Test + public void testSeedConsidering() { + GaussRandomProducer producer1 = new GaussRandomProducer(0L); + GaussRandomProducer producer2 = new GaussRandomProducer(0L); + + assertEquals(producer1.get(), producer2.get(), 0.0001); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testIllegalVariance1() { + new GaussRandomProducer(0, 0.); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testIllegalVariance2() { + new GaussRandomProducer(0, -1.); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/RandomProducerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/RandomProducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/RandomProducerTest.java new file mode 100644 index 0000000..34e44b3 --- /dev/null +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/RandomProducerTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.util.generators.primitives.scalar; + +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link RandomProducer}. 
+ */ +public class RandomProducerTest { + /** */ + @Test + public void testVectorize() { + RandomProducer p = () -> 1.0; + Vector vec = p.vectorize(3).get(); + + assertEquals(3, vec.size()); + assertArrayEquals(new double[] {1., 1., 1.}, vec.asArray(), 1e-7); + } + + /** */ + @Test + public void testVectorize2() { + Vector vec = RandomProducer.vectorize( + () -> 1.0, + () -> 2.0, + () -> 3.0 + ).get(); + + assertEquals(3, vec.size()); + assertArrayEquals(new double[] {1., 2., 3.}, vec.asArray(), 1e-7); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testVectorizeFail() { + RandomProducer.vectorize(); + } + + /** */ + @Test + public void testNoizify1() { + IgniteFunction<Double, Double> f = v -> 2 * v; + RandomProducer p = () -> 1.0; + + IgniteFunction<Double, Double> res = p.noizify(f); + + for (int i = 0; i < 10; i++) + assertEquals(2 * i + 1.0, res.apply((double)i), 1e-7); + } + + /** */ + @Test + public void testNoizify2() { + RandomProducer p = () -> 1.0; + assertArrayEquals(new double[] {1., 2.}, p.noizify(VectorUtils.of(0., 1.)).asArray(), 1e-7); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/UniformRandomProducerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/UniformRandomProducerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/UniformRandomProducerTest.java new file mode 100644 index 0000000..bc18c93 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/scalar/UniformRandomProducerTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.util.generators.primitives.scalar; + +import java.util.Arrays; +import java.util.Random; +import java.util.stream.IntStream; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link UniformRandomProducer}. + */ +public class UniformRandomProducerTest { + /** */ + @Test + public void testGet() { + Random random = new Random(0L); + double[] bounds = Arrays.asList(random.nextInt(10) - 5, random.nextInt(10) - 5) + .stream().sorted().mapToDouble(x -> x) + .toArray(); + + double min = Math.min(bounds[0], bounds[1]); + double max = Math.max(bounds[0], bounds[1]); + + double mean = (min + max) / 2; + double variance = Math.pow(min - max, 2) / 12; + UniformRandomProducer producer = new UniformRandomProducer(min, max, 0L); + + final int N = 500000; + double meanStat = IntStream.range(0, N).mapToDouble(i -> producer.get()).sum() / N; + double varianceStat = IntStream.range(0, N).mapToDouble(i -> Math.pow(producer.get() - mean, 2)).sum() / N; + + assertEquals(mean, meanStat, 0.01); + assertEquals(variance, varianceStat, 0.1); + } + + /** */ + @Test + public void testSeedConsidering() { + UniformRandomProducer producer1 = new UniformRandomProducer(0, 1, 0L); + UniformRandomProducer producer2 = new UniformRandomProducer(0, 1, 0L); + + assertEquals(producer1.get(), producer2.get(), 0.0001); + } + + /** */ + @Test(expected = 
IllegalArgumentException.class) + public void testFail() { + new UniformRandomProducer(1, 0, 0L); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/ParametricVectorGeneratorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/ParametricVectorGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/ParametricVectorGeneratorTest.java new file mode 100644 index 0000000..70ae237 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/ParametricVectorGeneratorTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.util.generators.primitives.vector; + +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link ParametricVectorGenerator}. 
+ */ +public class ParametricVectorGeneratorTest { + /** */ + @Test + public void testGet() { + Vector vec = new ParametricVectorGenerator( + () -> 2., + t -> t, + t -> 2 * t, + t -> 3 * t, + t -> 100. + ).get(); + + assertEquals(4, vec.size()); + assertArrayEquals(new double[] {2., 4., 6., 100.}, vec.asArray(), 1e-7); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testIllegalArguments() { + new ParametricVectorGenerator(() -> 2.).get(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorPrimitivesTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorPrimitivesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorPrimitivesTest.java new file mode 100644 index 0000000..85dd6df --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorPrimitivesTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.util.generators.primitives.vector; + +import java.util.stream.IntStream; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link VectorGeneratorPrimitives}. + */ +public class VectorGeneratorPrimitivesTest { + /** */ + @Test + public void testConstant() { + Vector vec = VectorUtils.of(1.0, 0.0); + assertArrayEquals(vec.copy().asArray(), VectorGeneratorPrimitives.constant(vec).get().asArray(), 1e-7); + } + + /** */ + @Test + public void testZero() { + assertArrayEquals(new double[] {0., 0.}, VectorGeneratorPrimitives.zero(2).get().asArray(), 1e-7); + } + + /** */ + @Test + public void testRing() { + VectorGeneratorPrimitives.ring(1., 0, 2 * Math.PI) + .asDataStream().unlabeled().limit(1000) + .forEach(v -> assertEquals(v.getLengthSquared(), 1., 1e-7)); + + VectorGeneratorPrimitives.ring(1., 0, Math.PI / 2) + .asDataStream().unlabeled().limit(1000) + .forEach(v -> { + assertTrue(v.get(0) >= 0.); + assertTrue(v.get(1) >= 0.); + }); + } + + /** */ + @Test + public void testCircle() { + VectorGeneratorPrimitives.circle(1.) 
+ .asDataStream().unlabeled().limit(1000) + .forEach(v -> assertTrue(Math.sqrt(v.getLengthSquared()) <= 1.)); + } + + /** */ + @Test + public void testParallelogram() { + VectorGeneratorPrimitives.parallelogram(VectorUtils.of(2., 100.)) + .asDataStream().unlabeled().limit(1000) + .forEach(v -> { + assertTrue(v.get(0) <= 2.); + assertTrue(v.get(0) >= -2.); + assertTrue(v.get(1) <= 100.); + assertTrue(v.get(1) >= -100.); + }); + } + + /** */ + @Test + public void testGauss() { + VectorGenerator gen = VectorGeneratorPrimitives.gauss(VectorUtils.of(2., 100.), VectorUtils.of(20., 1.), 10L); + + final double[] mean = new double[] {2., 100.}; + final double[] variance = new double[] {20., 1.}; + + final int N = 50000; + Vector meanStat = IntStream.range(0, N).mapToObj(i -> gen.get()).reduce(Vector::plus).get().times(1. / N); + Vector varianceStat = IntStream.range(0, N).mapToObj(i -> gen.get().minus(meanStat)) + .map(v -> v.times(v)).reduce(Vector::plus).get().times(1. / N); + + assertArrayEquals(mean, meanStat.asArray(), 0.1); + assertArrayEquals(variance, varianceStat.asArray(), 0.1); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testGaussFail1() { + VectorGeneratorPrimitives.gauss(VectorUtils.of(), VectorUtils.of()); + } + + /** */ + @Test(expected = IllegalArgumentException.class) + public void testGaussFail2() { + VectorGeneratorPrimitives.gauss(VectorUtils.of(0.5, -0.5), VectorUtils.of(1.0, -1.0)); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorTest.java new file mode 100644 index 0000000..19e42d5 --- /dev/null +++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorTest.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.util.generators.primitives.vector; + +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.util.generators.primitives.scalar.UniformRandomProducer; +import org.junit.Test; +import org.junit.internal.ArrayComparisonFailure; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link VectorGenerator}. 
+ */ +public class VectorGeneratorTest { + /** Checks that {@code map} transforms generated vectors: a constant generator mapped with {@code v.times(2.)} must yield the original vector times two (tolerance 1e-7). */ + @Test + public void testMap() { + Vector originalVec = new UniformRandomProducer(-1, 1).vectorize(2).get(); + Vector doubledVec = VectorGeneratorPrimitives.constant(originalVec).map(v -> v.times(2.)).get(); + assertArrayEquals(originalVec.times(2.).asArray(), doubledVec.asArray(), 1e-7); + } + + /** Checks that two chained {@code filter} calls both hold: each of the first 100 unlabeled vectors of the data stream has component 0 below 0.5 and component 1 above -0.5. */ + @Test + public void testFilter() { + new UniformRandomProducer(-1, 1).vectorize(2) + .filter(v -> v.get(0) < 0.5) + .filter(v -> v.get(1) > -0.5) + .asDataStream().unlabeled().limit(100) + .forEach(v -> assertTrue(v.get(0) < 0.5 && v.get(1) > -0.5)); + } + + /** Checks that concatenating two constant generators places the argument's components after the receiver's, for both operand orders. */ + @Test + public void concat1() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2.)); + VectorGenerator g2 = VectorGeneratorPrimitives.constant(VectorUtils.of(3., 4.)); + VectorGenerator g12 = g1.concat(g2); + VectorGenerator g21 = g2.concat(g1); + + assertArrayEquals(new double[] {1., 2., 3., 4.}, g12.get().asArray(), 1e-7); + assertArrayEquals(new double[] {3., 4., 1., 2.}, g21.get().asArray(), 1e-7); + } + + /** Checks that concatenating a generator with a scalar supplier appends the supplied value as one extra trailing component. */ + @Test + public void concat2() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2.)); + VectorGenerator g2 = g1.concat(() -> 1.0); + + assertArrayEquals(new double[] {1., 2., 1.}, g2.get().asArray(), 1e-7); + } + + /** Checks that {@code plus} adds generated vectors component-wise and that both operand orders give the same result. */ + @Test + public void plus() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2.)); + VectorGenerator g2 = VectorGeneratorPrimitives.constant(VectorUtils.of(3., 4.)); + VectorGenerator g12 = g1.plus(g2); + VectorGenerator g21 = g2.plus(g1); + + assertArrayEquals(new double[] {4., 6.}, g21.get().asArray(), 1e-7); + assertArrayEquals(g21.get().asArray(), g12.get().asArray(), 1e-7); + } + + /** Expects {@code CardinalityException} when a generator of 2-component vectors is summed with a generator of 1-component vectors. */ + @Test(expected = CardinalityException.class) + public void testPlusForDifferentSizes1() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2.)); + VectorGenerator g2 = 
VectorGeneratorPrimitives.constant(VectorUtils.of(3.)); + g1.plus(g2).get(); + } + + /** Expects {@code CardinalityException} for the reverse operand order: a generator of 1-component vectors summed with a generator of 2-component vectors. */ + @Test(expected = CardinalityException.class) + public void testPlusForDifferentSizes2() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2.)); + VectorGenerator g2 = VectorGeneratorPrimitives.constant(VectorUtils.of(3.)); + g2.plus(g1).get(); + } + + /** Checks that {@code shuffle(0L)} reorders the constant vector {1, 2, 3, 4} into {4, 1, 2, 3} and that two consecutive draws are reordered identically. */ + @Test + public void shuffle() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2., 3., 4.)) + .shuffle(0L); + + double[] exp = {4., 1., 2., 3.}; + Vector v1 = g1.get(); + Vector v2 = g1.get(); + assertArrayEquals(exp, v1.asArray(), 1e-7); + assertArrayEquals(v1.asArray(), v2.asArray(), 1e-7); + } + + /** Checks that {@code duplicateRandomFeatures(2, 1L)} on the constant vector {1, 2, 3, 4} yields {1, 2, 3, 4, 3, 1} on the first draw. The second draw is compared with the first, but an {@code ArrayComparisonFailure} from that comparison is caught and ignored. NOTE(review): the element comparison of the second draw is therefore non-binding; consider asserting only what must hold for every draw - confirm the intended semantics against the generator. */ + @Test + public void duplicateRandomFeatures() { + VectorGenerator g1 = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2., 3., 4.)) + .duplicateRandomFeatures(2, 1L); + + double[] exp = {1., 2., 3., 4., 3., 1.}; + Vector v1 = g1.get(); + Vector v2 = g1.get(); + + assertArrayEquals(exp, v1.asArray(), 1e-7); + + try { + assertArrayEquals(v1.asArray(), v2.asArray(), 1e-7); + } + catch (ArrayComparisonFailure e) { + //this is a valid situation - the duplicator is expected to pick different features + } + } + + /** Expects {@code IllegalArgumentException} when {@code duplicateRandomFeatures} is given a negative first argument (-2). */ + @Test(expected = IllegalArgumentException.class) + public void testWithNegativeIncreaseSize() { + VectorGeneratorPrimitives.constant(VectorUtils.of(1., 2., 3., 4.)) + .duplicateRandomFeatures(-2, 1L).get(); + } + + /** Checks that {@code move} adds the given vector to generated vectors: {1, 1} moved by {2, 4} gives {3, 5}. */ + @Test + public void move() { + Vector res = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 1.)) + .move(VectorUtils.of(2., 4.)) + .get(); + + assertArrayEquals(new double[] {3., 5.}, res.asArray(), 1e-7); + } + + /** Expects {@code CardinalityException} when 2-component vectors are moved by a 1-component vector. */ + @Test(expected = CardinalityException.class) + public void testMoveWithDifferentSizes1() { + VectorGeneratorPrimitives.constant(VectorUtils.of(1., 1.)) + .move(VectorUtils.of(2.)) + .get(); + } + + /** Expects {@code CardinalityException} when 1-component vectors are moved by a 2-component vector. */ + @Test(expected = CardinalityException.class) + public void testMoveWithDifferentSizes2() { + 
VectorGeneratorPrimitives.constant(VectorUtils.of(1.)) + .move(VectorUtils.of(2., 1.)) + .get(); + } + + /** Checks {@code rotate} on the constant vector {1, 0, 100} for six angles against tabulated results (tolerance 1e-3); in every expected vector the third component stays 100. */ + @Test + public void rotate() { + double[] angles = {0., Math.PI / 2, -Math.PI / 2, Math.PI, 2 * Math.PI, Math.PI / 4}; + Vector[] exp = new Vector[] { + VectorUtils.of(1., 0., 100.), + VectorUtils.of(0., -1., 100.), + VectorUtils.of(0., 1., 100.), + VectorUtils.of(-1., 0., 100.), + VectorUtils.of(1., 0., 100.), + VectorUtils.of(0.707, -0.707, 100.) + }; + + for (int i = 0; i < angles.length; i++) { + Vector res = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 0., 100.)) + .rotate(angles[i]).get(); + assertArrayEquals(exp[i].asArray(), res.asArray(), 1e-3); + } + } + + /** Checks that {@code noisify} adds the supplied value to every component: {1, 0} with a supplier returning 0.5 gives {1.5, 0.5}. */ + @Test + public void noisify() { + Vector res = VectorGeneratorPrimitives.constant(VectorUtils.of(1., 0.)) + .noisify(() -> 0.5).get(); + assertArrayEquals(new double[] {1.5, 0.5}, res.asArray(), 1e-7); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d0facb26/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorsFamilyTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorsFamilyTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorsFamilyTest.java new file mode 100644 index 0000000..5a16f12 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/primitives/vector/VectorGeneratorsFamilyTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.util.generators.primitives.vector; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link VectorGeneratorsFamily}. 
+ */ +public class VectorGeneratorsFamilyTest { + /** Draws 50000 vectors together with their distribution ids from a family of three generators added with second arguments 0.5, 0.25 and 0.25, sums the vectors per id and divides each sum by the number of draws. Each result must be within 1e-2 of that argument times the generator's constant vector. */ + @Test + public void testSelection() { + VectorGeneratorsFamily family = new VectorGeneratorsFamily.Builder() + .add(() -> VectorUtils.of(1., 2.), 0.5) + .add(() -> VectorUtils.of(1., 2.), 0.25) + .add(() -> VectorUtils.of(1., 4.), 0.25) + .build(0L); + + Map<Integer, Vector> counters = new HashMap<>(); + for (int i = 0; i < 3; i++) + counters.put(i, VectorUtils.zeroes(2)); + + int N = 50000; + IntStream.range(0, N).forEach(i -> { + VectorGeneratorsFamily.VectorWithDistributionId vector = family.getWithId(); + int id = vector.distributionId(); + counters.put(id, counters.get(id).plus(vector.vector())); + }); + + for (int i = 0; i < 3; i++) + counters.put(i, counters.get(i).divide(N)); + + assertArrayEquals(new double[] {0.5, 1.0}, counters.get(0).asArray(), 1e-2); + assertArrayEquals(new double[] {0.25, .5}, counters.get(1).asArray(), 1e-2); + assertArrayEquals(new double[] {0.25, 1.}, counters.get(2).asArray(), 1e-2); + } + + /** Expects {@code IllegalArgumentException} when {@code build()} is called on a builder to which no generator was added. */ + @Test(expected = IllegalArgumentException.class) + public void testInvalidParameters1() { + new VectorGeneratorsFamily.Builder().build(); + } + + /** Expects {@code IllegalArgumentException} when a generator is added with a negative second argument (-1.). */ + @Test(expected = IllegalArgumentException.class) + public void testInvalidParameters2() { + new VectorGeneratorsFamily.Builder().add(() -> VectorUtils.of(1.), -1.).build(); + } + + /** Checks that the builder's {@code map} is applied to the added generator: {1, 2} moved by {1, -1} gives {2, 1}. */ + @Test + public void testMap() { + VectorGeneratorsFamily family = new VectorGeneratorsFamily.Builder() + .add(() -> VectorUtils.of(1., 2.)) + .map(g -> g.move(VectorUtils.of(1, -1))) + .build(0L); + + assertArrayEquals(new double[] {2., 1.}, family.get().asArray(), 1e-7); + } + + /** Checks that each of 100 draws from a family of three constant generators (0, 1 and 2) has component 0 equal to one of those three values. */ + @Test + public void testGet() { + VectorGeneratorsFamily family = new VectorGeneratorsFamily.Builder() + .add(() -> VectorUtils.of(0.)) + .add(() -> VectorUtils.of(1.)) + .add(() -> VectorUtils.of(2.)) + .build(0L); + + Set<Double> validValues = DoubleStream.of(0., 1., 2.).boxed().collect(Collectors.toSet()); + for (int i = 0; i < 100; i++) { + Vector 
vector = family.get(); + assertTrue(validValues.contains(vector.get(0))); + } + } + + /** Checks that for the first 100 labeled vectors of the data stream the label equals feature 0 (tolerance 1e-7). The generators are added in the order 0, 1, 2, so the label presumably is the index of the generator that produced the vector - confirm against {@code asDataStream}. */ + @Test + public void testAsDataStream() { + VectorGeneratorsFamily family = new VectorGeneratorsFamily.Builder() + .add(() -> VectorUtils.of(0.)) + .add(() -> VectorUtils.of(1.)) + .add(() -> VectorUtils.of(2.)) + .build(0L); + + family.asDataStream().labeled().limit(100).forEach(v -> { + assertEquals(v.features().get(0), v.label(), 1e-7); + }); + } +}