Repository: ignite
Updated Branches:
  refs/heads/master 523900a0c -> c0cc7d78e


http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
new file mode 100644
index 0000000..40a416f
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import 
org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
+import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
+import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/** Tests for the {@link GDBRegressionOnTreesTrainer} and {@link GDBBinaryClassifierOnTreesTrainer} boosting trainers. */
+public class GDBTrainerTest {
+    /** Fits y = 2x on 100 points in [-5, 5); expects MSE within 1e-4 of 0, 2000 tree models and weighted aggregation. */
+    @Test public void testFitRegression() {
+        int size = 100;
+        double[] xs = new double[size];
+        double[] ys = new double[size];
+        double from = -5.0;
+        double to = 5.0;
+        double step = Math.abs(from - to) / size;
+
+        Map<Integer, double[]> learningSample = new HashMap<>();
+        for (int i = 0; i < size; i++) {
+            xs[i] = from + step * i;
+            ys[i] = 2 * xs[i];
+            learningSample.put(i, new double[] {xs[i], ys[i]});
+        }
+
+        DatasetTrainer<Model<Vector, Double>, Double> trainer = new 
GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0);
+        Model<Vector, Double> model = trainer.fit(
+            learningSample, 1,
+            (k, v) -> new double[] {v[0]},
+            (k, v) -> v[1]
+        );
+
+        double mse = 0.0;
+        for (int j = 0; j < size; j++) {
+            double x = xs[j];
+            double y = ys[j];
+            double p = model.apply(VectorUtils.of(x));
+            mse += Math.pow(y - p, 2);
+        }
+        mse /= size;
+
+        assertEquals(0.0, mse, 0.0001);
+
+        assertTrue(model instanceof ModelsComposition);
+        ModelsComposition composition = (ModelsComposition) model;
+        composition.getModels().forEach(m -> assertTrue(m instanceof 
DecisionTreeConditionalNode));
+
+        assertEquals(2000, composition.getModels().size());
+        assertTrue(composition.getPredictionsAggregator() instanceof 
WeightedPredictionsAggregator);
+    }
+
+    /** Fits -1/1 labels that flip every 10 points of x = 0..99; expects no training errors and 500 tree models. */
+    @Test public void testFitClassifier() {
+        int sampleSize = 100;
+        double[] xs = new double[sampleSize];
+        double[] ys = new double[sampleSize];
+
+        for (int i = 0; i < sampleSize; i++) {
+            xs[i] = i;
+            ys[i] = ((int)(xs[i] / 10.0) % 2) == 0 ? -1.0 : 1.0;
+        }
+
+        Map<Integer, double[]> learningSample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++)
+            learningSample.put(i, new double[] {xs[i], ys[i]});
+
+        DatasetTrainer<Model<Vector, Double>, Double> trainer = new 
GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 0.0);
+        Model<Vector, Double> model = trainer.fit(
+            learningSample, 1,
+            (k, v) -> new double[] {v[0]},
+            (k, v) -> v[1]
+        );
+
+        int errorsCount = 0;
+        for (int j = 0; j < sampleSize; j++) {
+            double x = xs[j];
+            double y = ys[j];
+            double p = model.apply(VectorUtils.of(x));
+            if(p != y) // exact double compare: assumes the classifier outputs exactly -1.0 or 1.0
+                errorsCount++;
+        }
+
+        assertEquals(0, errorsCount);
+
+        assertTrue(model instanceof ModelsComposition);
+        ModelsComposition composition = (ModelsComposition) model;
+        composition.getModels().forEach(m -> assertTrue(m instanceof 
DecisionTreeConditionalNode));
+
+        assertEquals(500, composition.getModels().size());
+        assertTrue(composition.getPredictionsAggregator() instanceof 
WeightedPredictionsAggregator);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java
new file mode 100644
index 0000000..7fda6b6
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregatorTest.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.predictionsaggregator;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests for {@link WeightedPredictionsAggregator}. */
+public class WeightedPredictionsAggregatorTest {
+    /** Empty weights applied to empty predictions aggregate to 0. */
+    @Test public void testApply1() {
+        WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(new double[] {});
+        assertEquals(0.0, aggregator.apply(new double[] {}), 0.001);
+    }
+
+    /** Weights {1.0, 0.5, 0.25} applied to predictions {1.0, 2.0, 4.0} give 1 + 1 + 1 = 3. */
+    @Test public void testApply2() {
+        WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(new double[] {1.0, 0.5, 0.25});
+        assertEquals(3.0, aggregator.apply(new double[] {1.0, 2.0, 4.0}), 
0.001);
+    }
+
+    /** A predictions array whose length differs from the weights array is rejected with IllegalArgumentException. */
+    @Test(expected = IllegalArgumentException.class)
+    public void testIllegalArguments() {
+        WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(new double[] {1.0, 0.5, 0.25});
+        aggregator.apply(new double[] { });
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java
new file mode 100644
index 0000000..6479276
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/VectorUtilsTest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.math;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class VectorUtilsTest {
+    /** {@code VectorUtils.of(double[])} yields a vector holding every value; size and non-zero count are both 3. */
+    @Test
+    public void testOf1() {
+        double[] values = {1.0, 2.0, 3.0};
+        Vector vector = VectorUtils.of(values);
+
+        assertEquals(3, vector.size());
+        assertEquals(3, vector.nonZeroElements());
+        for (int i = 0; i < values.length; i++)
+            assertEquals(values[i], vector.get(i), 0.001);
+    }
+
+    /** {@code VectorUtils.of(Double[])} stores null entries as 0.0, so only 2 of the 3 elements are non-zero. */
+    @Test
+    public void testOf2() {
+        Double[] values = {1.0, null, 3.0};
+        Vector vector = VectorUtils.of(values);
+
+        assertEquals(3, vector.size());
+        assertEquals(2, vector.nonZeroElements());
+        for (int i = 0; i < values.length; i++) {
+            if (values[i] == null)
+                assertEquals(0.0, vector.get(i), 0.001);
+            else
+                assertEquals(values[i], vector.get(i), 0.001);
+        }
+    }
+
+    /** A null {@code double[]} argument is rejected with NullPointerException. */
+    @Test(expected = NullPointerException.class)
+    public void testFails1() {
+        double[] values = null;
+        VectorUtils.of(values);
+    }
+
+    /** A null {@code Double[]} argument is rejected with NullPointerException. */
+    @Test(expected = NullPointerException.class)
+    public void testFails2() {
+        Double[] values = null;
+        VectorUtils.of(values);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
index 0494249..2b95d10 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
@@ -21,6 +21,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
@@ -68,13 +69,12 @@ public class RandomForestClassifierTrainerTest {
 
         RandomForestClassifierTrainer trainer = new 
RandomForestClassifierTrainer(4, 3, 5, 0.3, 4, 0.1);
         ModelsComposition model = trainer.fit(sample, parts, (k, v) -> k, (k, 
v) -> v);
+        model.getModels().forEach(m -> {
+            assertTrue(m instanceof ModelOnFeaturesSubspace);
+            assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof 
DecisionTreeConditionalNode);
+        });
 
         assertTrue(model.getPredictionsAggregator() instanceof 
OnMajorityPredictionsAggregator);
         assertEquals(5, model.getModels().size());
-
-        for (ModelsComposition.ModelOnFeaturesSubspace tree : 
model.getModels()) {
-            assertTrue(tree.getMdl() instanceof DecisionTreeConditionalNode);
-            assertEquals(3, tree.getFeaturesMapping().size());
-        }
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/c0cc7d78/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
index 418a98c..e837c65 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
@@ -21,6 +21,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.apache.ignite.ml.composition.ModelOnFeaturesSubspace;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
 import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
@@ -68,13 +69,12 @@ public class RandomForestRegressionTrainerTest {
 
         RandomForestRegressionTrainer trainer = new 
RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1);
         ModelsComposition model = trainer.fit(sample, parts, (k, v) -> v, (k, 
v) -> k);
+        model.getModels().forEach(m -> {
+            assertTrue(m instanceof ModelOnFeaturesSubspace);
+            assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof 
DecisionTreeConditionalNode);
+        });
 
         assertTrue(model.getPredictionsAggregator() instanceof 
MeanValuePredictionsAggregator);
         assertEquals(5, model.getModels().size());
-
-        for (ModelsComposition.ModelOnFeaturesSubspace tree : 
model.getModels()) {
-            assertTrue(tree.getMdl() instanceof DecisionTreeConditionalNode);
-            assertEquals(3, tree.getFeaturesMapping().size());
-        }
     }
 }

Reply via email to