Repository: ignite Updated Branches: refs/heads/master 523900a0c -> c0cc7d78e
http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java new file mode 100644 index 0000000..40a416f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.ignite.ml.composition.boosting; + + import java.util.HashMap; + import java.util.Map; + import org.apache.ignite.ml.Model; + import org.apache.ignite.ml.composition.ModelsComposition; + import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator; + import org.apache.ignite.ml.math.Vector; + import org.apache.ignite.ml.math.VectorUtils; + import org.apache.ignite.ml.trainers.DatasetTrainer; + import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; + import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer; + import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer; + import org.junit.Test; + + import static org.junit.Assert.assertEquals; + import static org.junit.Assert.assertTrue; + + /** Tests for the gradient-boosting trainers on decision trees: {@link GDBRegressionOnTreesTrainer} and {@link GDBBinaryClassifierOnTreesTrainer}. */ + public class GDBTrainerTest { + /** Fits {@link GDBRegressionOnTreesTrainer} to y = 2x sampled at 100 evenly spaced points in [-5, 5) and checks that the training MSE is approximately 0, that the result is a {@link ModelsComposition} of 2000 {@link DecisionTreeConditionalNode} models, and that it uses a {@link WeightedPredictionsAggregator}. */ + @Test public void testFitRegression() { + int size = 100; + double[] xs = new double[size]; + double[] ys = new double[size]; + double from = -5.0; + double to = 5.0; + double step = Math.abs(from - to) / size; + + Map<Integer, double[]> learningSample = new HashMap<>(); + for (int i = 0; i < size; i++) { + xs[i] = from + step * i; + ys[i] = 2 * xs[i]; + learningSample.put(i, new double[] {xs[i], ys[i]}); + } + + DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0); + Model<Vector, Double> model = trainer.fit( + learningSample, 1, + (k, v) -> new double[] {v[0]}, + (k, v) -> v[1] + ); + + double mse = 0.0; + for (int j = 0; j < size; j++) { + double x = xs[j]; + double y = ys[j]; + double p = model.apply(VectorUtils.of(x)); + mse += Math.pow(y - p, 2); + } + mse /= size; + + assertEquals(0.0, mse, 0.0001); + + assertTrue(model instanceof ModelsComposition); + ModelsComposition composition = (ModelsComposition) model; + composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode)); + + assertEquals(2000, composition.getModels().size()); + 
assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator); + } + + /** Fits {@link GDBBinaryClassifierOnTreesTrainer} to x = 0..99 with labels alternating between -1.0 and 1.0 every 10 consecutive points, and checks that every training point is predicted exactly, that the result is a {@link ModelsComposition} of 500 {@link DecisionTreeConditionalNode} models, and that it uses a {@link WeightedPredictionsAggregator}. */ + @Test public void testFitClassifier() { + int sampleSize = 100; + double[] xs = new double[sampleSize]; + double[] ys = new double[sampleSize]; + + for (int i = 0; i < sampleSize; i++) { + xs[i] = i; + ys[i] = ((int)(xs[i] / 10.0) % 2) == 0 ? -1.0 : 1.0; + } + + Map<Integer, double[]> learningSample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) + learningSample.put(i, new double[] {xs[i], ys[i]}); + + DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0); + Model<Vector, Double> model = trainer.fit( + learningSample, 1, + (k, v) -> new double[] {v[0]}, + (k, v) -> v[1] + ); + + int errorsCount = 0; + for (int j = 0; j < sampleSize; j++) { + double x = xs[j]; + double y = ys[j]; + double p = model.apply(VectorUtils.of(x)); + /* Exact comparison is intentional: the test expects the classifier to output exactly the -1.0 / 1.0 label values. */ if(p != y) + errorsCount++; + } + + assertEquals(0, errorsCount); + + assertTrue(model instanceof ModelsComposition); + ModelsComposition composition = (ModelsComposition) model; + composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode)); + + assertEquals(500, composition.getModels().size()); + assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java new file mode 100644 index 0000000..7fda6b6 --- /dev/null +++
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + package org.apache.ignite.ml.composition.predictionsaggregator; + + import org.junit.Test; + + import static org.junit.Assert.assertEquals; + + /** Tests for {@link WeightedPredictionsAggregator}. */ + public class WeightedPredictionsAggregatorTest { + /** Empty weights combined with empty predictions are expected to aggregate to 0.0. */ + @Test public void testApply1() { + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(new double[] {}); + assertEquals(0.0, aggregator.apply(new double[] {}), 0.001); + } + + /** With weights 1.0, 0.5, 0.25 and predictions 1.0, 2.0, 4.0 the expected result is 3.0, i.e. the weighted sum 1.0*1.0 + 0.5*2.0 + 0.25*4.0. */ + @Test public void testApply2() { + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(new double[] {1.0, 0.5, 0.25}); + assertEquals(3.0, aggregator.apply(new double[] {1.0, 2.0, 4.0}), 0.001); + } + + /** Expects {@link IllegalArgumentException} when the number of predictions (0) differs from the number of weights (3). */ + @Test(expected = IllegalArgumentException.class) + public void testIllegalArguments() { + WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(new double[] {1.0, 0.5, 0.25}); + aggregator.apply(new double[] { }); + } +} 
http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java new file mode 100644 index 0000000..6479276 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.ignite.ml.math; + + import org.junit.Test; + + import static org.junit.Assert.assertEquals; + + /** Tests for the {@link VectorUtils} vector factory methods. */ public class VectorUtilsTest { + /** Builds a vector from a primitive {@code double[]} and checks its size, its non-zero element count and every element value. */ + @Test + public void testOf1() { + double[] values = {1.0, 2.0, 3.0}; + Vector vector = VectorUtils.of(values); + + assertEquals(3, vector.size()); + assertEquals(3, vector.nonZeroElements()); + for (int i = 0; i < values.length; i++) + assertEquals(values[i], vector.get(i), 0.001); + } + + /** Builds a vector from a boxed {@code Double[]} containing a null and checks that the null entry is read back as 0.0 and is not counted as a non-zero element. */ + @Test + public void testOf2() { + Double[] values = {1.0, null, 3.0}; + Vector vector = VectorUtils.of(values); + + assertEquals(3, vector.size()); + assertEquals(2, vector.nonZeroElements()); + for (int i = 0; i < values.length; i++) { + if (values[i] == null) + assertEquals(0.0, vector.get(i), 0.001); + else + assertEquals(values[i], vector.get(i), 0.001); + } + } + + /** Passing a null {@code double[]} is expected to throw {@link NullPointerException}. */ + @Test(expected = NullPointerException.class) + public void testFails1() { + double[] values = null; + VectorUtils.of(values); + } + + /** Passing a null {@code Double[]} is expected to throw {@link NullPointerException}. */ + @Test(expected = NullPointerException.class) + public void testFails2() { + Double[] values = null; + VectorUtils.of(values); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index 0494249..2b95d10 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import
org.apache.ignite.ml.composition.ModelOnFeaturesSubspace; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; @@ -68,13 +69,12 @@ public class RandomForestClassifierTrainerTest { RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(4, 3, 5, 0.3, 4, 0.1); ModelsComposition model = trainer.fit(sample, parts, (k, v) -> k, (k, v) -> v); + model.getModels().forEach(m -> { + assertTrue(m instanceof ModelOnFeaturesSubspace); + assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof DecisionTreeConditionalNode); + }); assertTrue(model.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator); assertEquals(5, model.getModels().size()); - - for (ModelsComposition.ModelOnFeaturesSubspace tree : model.getModels()) { - assertTrue(tree.getMdl() instanceof DecisionTreeConditionalNode); - assertEquals(3, tree.getFeaturesMapping().size()); - } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java index 418a98c..e837c65 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace; import org.apache.ignite.ml.composition.ModelsComposition; import 
org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; import org.apache.ignite.ml.tree.DecisionTreeConditionalNode; @@ -68,13 +69,12 @@ public class RandomForestRegressionTrainerTest { RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1); ModelsComposition model = trainer.fit(sample, parts, (k, v) -> v, (k, v) -> k); + model.getModels().forEach(m -> { + assertTrue(m instanceof ModelOnFeaturesSubspace); + assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof DecisionTreeConditionalNode); + }); assertTrue(model.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator); assertEquals(5, model.getModels().size()); - - for (ModelsComposition.ModelOnFeaturesSubspace tree : model.getModels()) { - assertTrue(tree.getMdl() instanceof DecisionTreeConditionalNode); - assertEquals(3, tree.getFeaturesMapping().size()); - } } }