http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java index d465e82..4b7fa33 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java @@ -17,17 +17,15 @@ package org.apache.ignite.ml.preprocessing.binarization; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -35,26 +33,7 @@ import static org.junit.Assert.assertEquals; /** * Tests for {@link BinarizationTrainer}. */ -@RunWith(Parameterized.class) -public class BinarizationTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class BinarizationTrainerTest extends TrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() {
http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java index 6d01901..23afd30 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java @@ -17,15 +17,13 @@ package org.apache.ignite.ml.preprocessing.encoding; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialFeatureValue; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static junit.framework.TestCase.fail; import static org.junit.Assert.assertArrayEquals; @@ -33,26 +31,7 @@ import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link EncoderTrainer}. */ -@RunWith(Parameterized.class) -public class EncoderTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[]{1}, - new Integer[]{2}, - new Integer[]{3}, - new Integer[]{5}, - new Integer[]{7}, - new Integer[]{100}, - new Integer[]{1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class EncoderTrainerTest extends TrainerTest { /** Tests {@code fit()} method. */ @Test public void testFitOnStringCategorialFeatures() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java index 006ac29..9c11d13 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/imputing/ImputerTrainerTest.java @@ -17,42 +17,21 @@ package org.apache.ignite.ml.preprocessing.imputing; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link ImputerTrainer}. */ -@RunWith(Parameterized.class) -public class ImputerTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class ImputerTrainerTest extends TrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerPreprocessorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerPreprocessorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerPreprocessorTest.java index 3c30f3e..91562da 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerPreprocessorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerPreprocessorTest.java @@ -42,7 +42,7 @@ public class MaxAbsScalerPreprocessorTest { (k, v) -> v ); - double[][] expectedData = new double[][] { + double[][] expData = new double[][] { {.5, 4. / 22, 1. / 300}, {.25, 8. / 22, 22. / 300}, {-1., 10. / 22, 100. / 300}, @@ -50,6 +50,6 @@ public class MaxAbsScalerPreprocessorTest { }; for (int i = 0; i < data.length; i++) - assertArrayEquals(expectedData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8); + assertArrayEquals(expData[i], preprocessor.apply(i, VectorUtils.of(data[i])).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java index 5711660..844468e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/maxabsscaling/MaxAbsScalerTrainerTest.java @@ -17,42 +17,21 @@ package org.apache.ignite.ml.preprocessing.maxabsscaling; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link MaxAbsScalerTrainer}. */ -@RunWith(Parameterized.class) -public class MaxAbsScalerTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class MaxAbsScalerTrainerTest extends TrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java index 451f5e9..4c0a99f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/minmaxscaling/MinMaxScalerTrainerTest.java @@ -17,42 +17,21 @@ package org.apache.ignite.ml.preprocessing.minmaxscaling; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; /** * Tests for {@link MinMaxScalerTrainer}. */ -@RunWith(Parameterized.class) -public class MinMaxScalerTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class MinMaxScalerTrainerTest extends TrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java index 7b02f20..9d39354 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java @@ -17,16 +17,14 @@ package org.apache.ignite.ml.preprocessing.normalization; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainer; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -34,26 +32,7 @@ import static org.junit.Assert.assertEquals; /** * Tests for {@link BinarizationTrainer}. */ -@RunWith(Parameterized.class) -public class NormalizationTrainerTest { - /** Parameters. */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - return Arrays.asList( - new Integer[] {1}, - new Integer[] {2}, - new Integer[] {3}, - new Integer[] {5}, - new Integer[] {7}, - new Integer[] {100}, - new Integer[] {1000} - ); - } - - /** Number of partitions. */ - @Parameterized.Parameter - public int parts; - +public class NormalizationTrainerTest extends TrainerTest { /** Tests {@code fit()} method. */ @Test public void testFit() { http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java index 9c35ac7..3ca1a07 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java @@ -123,7 +123,7 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest { LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); - LinearRegressionModel originalModel = trainer.fit( + LinearRegressionModel originalMdl = trainer.fit( data, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), @@ -131,7 +131,7 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest { ); LinearRegressionModel updatedOnSameDS = trainer.update( - originalModel, + originalMdl, data, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), @@ -139,17 +139,17 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest { ); LinearRegressionModel updatedOnEmpyDS = trainer.update( - originalModel, + originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), (k, v) -> v[coef.length] ); - assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6); - assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6); + assertArrayEquals(originalMdl.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6); + assertEquals(originalMdl.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6); - assertArrayEquals(originalModel.getWeights().getStorage().data(), updatedOnEmpyDS.getWeights().getStorage().data(), 1e-6); - assertEquals(originalModel.getIntercept(), updatedOnEmpyDS.getIntercept(), 1e-6); + assertArrayEquals(originalMdl.getWeights().getStorage().data(), updatedOnEmpyDS.getWeights().getStorage().data(), 1e-6); + assertEquals(originalMdl.getIntercept(), updatedOnEmpyDS.getIntercept(), 1e-6); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java index 86b0f27..1af9109 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java @@ -94,7 +94,7 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest { RPropParameterUpdate::avg ), 100000, 10, 100, 0L); - LinearRegressionModel originalModel = trainer.withSeed(0).fit( + LinearRegressionModel originalMdl = trainer.withSeed(0).fit( data, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), @@ -103,7 +103,7 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest { LinearRegressionModel updatedOnSameDS = trainer.withSeed(0).update( - originalModel, + originalMdl, data, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), @@ -111,7 +111,7 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest { ); LinearRegressionModel updatedOnEmptyDS = trainer.withSeed(0).update( - originalModel, + originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), @@ -119,19 +119,19 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest { ); assertArrayEquals( - originalModel.getWeights().getStorage().data(), + originalMdl.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1.0 ); - assertEquals(originalModel.getIntercept(), updatedOnSameDS.getIntercept(), 1.0); + assertEquals(originalMdl.getIntercept(), updatedOnSameDS.getIntercept(), 1.0); assertArrayEquals( - originalModel.getWeights().getStorage().data(), + originalMdl.getWeights().getStorage().data(), updatedOnEmptyDS.getWeights().getStorage().data(), 1e-1 ); - assertEquals(originalModel.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1); + assertEquals(originalMdl.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java index 73c8842..78cd08d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -103,7 +103,7 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { .withBatchSize(100) .withSeed(123L); - LogRegressionMultiClassModel originalModel = trainer.fit( + LogRegressionMultiClassModel originalMdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -111,7 +111,7 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { ); LogRegressionMultiClassModel updatedOnSameDS = trainer.update( - originalModel, + originalMdl, cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -119,7 +119,7 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { ); LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update( - originalModel, + originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -135,8 +135,8 @@ public class LogRegMultiClassTrainerTest extends TrainerTest { for (Vector vec : vectors) { - TestUtils.assertEquals(originalModel.apply(vec), updatedOnSameDS.apply(vec), PRECISION); - TestUtils.assertEquals(originalModel.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); + TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION); + TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); } } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index 1da0d1a..723677c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -76,7 +76,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { SimpleGDParameterUpdate::avg ), 100000, 10, 100, 123L); - LogisticRegressionModel originalModel = trainer.fit( + LogisticRegressionModel originalMdl = trainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -84,7 +84,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { ); LogisticRegressionModel updatedOnSameDS = trainer.update( - originalModel, + originalMdl, cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -92,7 +92,7 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { ); LogisticRegressionModel updatedOnEmptyDS = trainer.update( - originalModel, + originalMdl, new HashMap<Integer, double[]>(), parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -101,9 +101,9 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest { Vector v1 = VectorUtils.of(100, 10); Vector v2 = VectorUtils.of(10, 100); - TestUtils.assertEquals(originalModel.apply(v1), updatedOnSameDS.apply(v1), PRECISION); - TestUtils.assertEquals(originalModel.apply(v2), updatedOnSameDS.apply(v2), PRECISION); - TestUtils.assertEquals(originalModel.apply(v2), updatedOnEmptyDS.apply(v2), PRECISION); - TestUtils.assertEquals(originalModel.apply(v1), updatedOnEmptyDS.apply(v1), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v1), updatedOnSameDS.apply(v1), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v2), updatedOnSameDS.apply(v2), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v2), updatedOnEmptyDS.apply(v2), PRECISION); + TestUtils.assertEquals(originalMdl.apply(v1), updatedOnEmptyDS.apply(v1), PRECISION); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java index 84975a8..d89b9bf 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java @@ -45,7 +45,7 @@ public class DecisionTreeRegressionTrainerTest { /** Use index [= 1 if true]. */ @Parameterized.Parameter(1) - public int useIndex; + public int useIdx; /** Test parameters. */ @Parameterized.Parameters(name = "Data divided on {0} partitions. Use index = {1}.") @@ -73,7 +73,7 @@ public class DecisionTreeRegressionTrainerTest { } DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0) - .withUsingIdx(useIndex == 1); + .withUsingIdx(useIdx == 1); DecisionTreeNode tree = trainer.fit( data, http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java index 4ee717a..7405c16 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java @@ -40,7 +40,7 @@ public class DecisionTreeDataTest { /** Use index. */ @Parameterized.Parameter - public boolean useIndex; + public boolean useIdx; /** */ @Test @@ -48,7 +48,7 @@ public class DecisionTreeDataTest { double[][] features = new double[][]{{0}, {1}, {2}, {3}, {4}, {5}}; double[] labels = new double[]{0, 1, 2, 3, 4, 5}; - DecisionTreeData data = new DecisionTreeData(features, labels, useIndex); + DecisionTreeData data = new DecisionTreeData(features, labels, useIdx); DecisionTreeData filteredData = data.filter(obj -> obj[0] > 2); assertArrayEquals(new double[][]{{3}, {4}, {5}}, filteredData.getFeatures()); @@ -61,7 +61,7 @@ public class DecisionTreeDataTest { double[][] features = new double[][]{{4, 1}, {3, 3}, {2, 0}, {1, 4}, {0, 2}}; double[] labels = new double[]{0, 1, 2, 3, 4}; - DecisionTreeData data = new DecisionTreeData(features, labels, useIndex); + DecisionTreeData data = new DecisionTreeData(features, labels, useIdx); data.sort(0); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java index 78bdfdf..b8ad49a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/TreeDataIndexTest.java @@ -75,41 +75,41 @@ public class TreeDataIndexTest { }; /** */ - private TreeDataIndex index = new TreeDataIndex(features, labels); + private TreeDataIndex idx = new TreeDataIndex(features, labels); /** */ @Test public void labelInSortedOrderTest() { - assertEquals(features.length, index.rowsCount()); - assertEquals(features[0].length, index.columnsCount()); + assertEquals(features.length, idx.rowsCount()); + assertEquals(features[0].length, idx.columnsCount()); - for (int k = 0; k < index.rowsCount(); k++) { - for (int featureId = 0; featureId < index.columnsCount(); featureId++) - assertEquals(labelsInSortedOrder[k][featureId], index.labelInSortedOrder(k, featureId), 0.01); + for (int k = 0; k < idx.rowsCount(); k++) { + for (int featureId = 0; featureId < idx.columnsCount(); featureId++) + assertEquals(labelsInSortedOrder[k][featureId], idx.labelInSortedOrder(k, featureId), 0.01); } } /** */ @Test public void featuresInSortedOrderTest() { - assertEquals(features.length, index.rowsCount()); - assertEquals(features[0].length, index.columnsCount()); + assertEquals(features.length, idx.rowsCount()); + assertEquals(features[0].length, idx.columnsCount()); - for (int k = 0; k < index.rowsCount(); k++) { - for (int featureId = 0; featureId < index.columnsCount(); featureId++) - assertArrayEquals(featuresInSortedOrder[k][featureId], index.featuresInSortedOrder(k, featureId), 0.01); + for (int k = 0; k < idx.rowsCount(); k++) { + for (int featureId = 0; featureId < idx.columnsCount(); featureId++) + assertArrayEquals(featuresInSortedOrder[k][featureId], idx.featuresInSortedOrder(k, featureId), 0.01); } } /** */ @Test public void featureInSortedOrderTest() { - assertEquals(features.length, index.rowsCount()); - assertEquals(features[0].length, index.columnsCount()); + assertEquals(features.length, idx.rowsCount()); + assertEquals(features[0].length, idx.columnsCount()); - for (int k = 0; k < index.rowsCount(); k++) { - for (int featureId = 0; featureId < index.columnsCount(); featureId++) - assertEquals((double)k + 1, index.featureInSortedOrder(k, featureId), 0.01); + for (int k = 0; k < idx.rowsCount(); k++) { + for (int featureId = 0; featureId < idx.columnsCount(); featureId++) + assertEquals((double)k + 1, idx.featureInSortedOrder(k, featureId), 0.01); } } @@ -120,9 +120,9 @@ public class TreeDataIndexTest { TreeFilter filter2 = features -> features[1] > 2; TreeFilter filterAnd = filter1.and(features -> features[1] > 2); - TreeDataIndex filtered1 = index.filter(filter1); + TreeDataIndex filtered1 = idx.filter(filter1); TreeDataIndex filtered2 = filtered1.filter(filter2); - TreeDataIndex filtered3 = index.filter(filterAnd); + TreeDataIndex filtered3 = idx.filter(filterAnd); assertEquals(2, filtered1.rowsCount()); assertEquals(4, filtered1.columnsCount()); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java index a328bd7..0c77a2c 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java @@ -45,7 +45,7 @@ public class GiniImpurityMeasureCalculatorTest { /** Use index. */ @Parameterized.Parameter - public boolean useIndex; + public boolean useIdx; /** */ @Test @@ -56,9 +56,9 @@ public class GiniImpurityMeasureCalculatorTest { Map<Double, Integer> encoder = new HashMap<>(); encoder.put(0.0, 0); encoder.put(1.0, 1); - GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIdx); - StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0); + StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIdx), fs -> true, 0); assertEquals(2, impurity.length); @@ -88,9 +88,9 @@ public class GiniImpurityMeasureCalculatorTest { Map<Double, Integer> encoder = new HashMap<>(); encoder.put(0.0, 0); encoder.put(1.0, 1); - GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIdx); - StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0); + StepFunction<GiniImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIdx), fs -> true, 0); assertEquals(1, impurity.length); @@ -111,7 +111,7 @@ public class GiniImpurityMeasureCalculatorTest { encoder.put(1.0, 1); encoder.put(2.0, 2); - GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIndex); + GiniImpurityMeasureCalculator calculator = new GiniImpurityMeasureCalculator(encoder, useIdx); assertEquals(0, calculator.getLabelCode(0.0)); assertEquals(1, calculator.getLabelCode(1.0)); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java index 82b3805..ed1fce0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java @@ -43,7 +43,7 @@ public class MSEImpurityMeasureCalculatorTest { /** Use index. */ @Parameterized.Parameter - public boolean useIndex; + public boolean useIdx; /** */ @Test @@ -51,9 +51,9 @@ public class MSEImpurityMeasureCalculatorTest { double[][] data = new double[][]{{0, 2}, {1, 1}, {2, 0}, {3, 3}}; double[] labels = new double[]{1, 2, 2, 1}; - MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(useIndex); + MSEImpurityMeasureCalculator calculator = new MSEImpurityMeasureCalculator(useIdx); - StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIndex), fs -> true, 0); + StepFunction<MSEImpurityMeasure>[] impurity = calculator.calculate(new DecisionTreeData(data, labels, useIdx), fs -> true, 0); assertEquals(2, impurity.length); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index 087f4e8..3a038ff 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -19,16 +19,14 @@ package org.apache.ignite.ml.tree.randomforest; import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.dataset.feature.FeatureMeta; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -36,31 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Tests for {@link RandomForestClassifierTrainer}. */ -@RunWith(Parameterized.class) -public class RandomForestClassifierTrainerTest { - /** - * Number of parts to be tested. - */ - private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; - - /** - * Number of partitions. - */ - @Parameterized.Parameter - public int parts; - - /** - * Data iterator. - */ - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - List<Integer[]> res = new ArrayList<>(); - for (int part : partsToBeTested) - res.add(new Integer[] {part}); - - return res; - } - +public class RandomForestClassifierTrainerTest extends TrainerTest { /** */ @Test public void testFit() { @@ -79,7 +53,7 @@ public class RandomForestClassifierTrainerTest { for (int i = 0; i < 4; i++) meta.add(new FeatureMeta("", i, false)); RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta) - .withCountOfTrees(5) + .withAmountOfTrees(5) .withFeaturesCountSelectionStrgy(x -> 2); ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); @@ -106,15 +80,15 @@ public class RandomForestClassifierTrainerTest { for (int i = 0; i < 4; i++) meta.add(new FeatureMeta("", i, false)); RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta) - .withCountOfTrees(100) + .withAmountOfTrees(100) .withFeaturesCountSelectionStrgy(x -> 2); - ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); - ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); - ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition originalMdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005); - assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.01); - assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.01); + assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), 0.01); + assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), 0.01); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java index fcc20bd..08ff95d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java @@ -19,16 +19,14 @@ package org.apache.ignite.ml.tree.randomforest; import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; import org.apache.ignite.ml.dataset.feature.FeatureMeta; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -36,28 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Tests for {@link RandomForestRegressionTrainer}. */ -@RunWith(Parameterized.class) -public class RandomForestRegressionTrainerTest { - /** - * Number of parts to be tested. - */ - private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; - - /** - * Number of partitions. - */ - @Parameterized.Parameter - public int parts; - - @Parameterized.Parameters(name = "Data divided on {0} partitions") - public static Iterable<Integer[]> data() { - List<Integer[]> res = new ArrayList<>(); - for (int part : partsToBeTested) - res.add(new Integer[] {part}); - - return res; - } - +public class RandomForestRegressionTrainerTest extends TrainerTest { /** */ @Test public void testFit() { @@ -76,7 +53,7 @@ public class RandomForestRegressionTrainerTest { for(int i = 0; i < 4; i++) meta.add(new FeatureMeta("", i, false)); RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta) - .withCountOfTrees(5) + .withAmountOfTrees(5) .withFeaturesCountSelectionStrgy(x -> 2); ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(v), (k, v) -> k); @@ -102,15 +79,15 @@ public class RandomForestRegressionTrainerTest { for (int i = 0; i < 4; i++) meta.add(new FeatureMeta("", i, false)); RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(meta) - .withCountOfTrees(100) + .withAmountOfTrees(100) .withFeaturesCountSelectionStrgy(x -> 2); - ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); - ModelsComposition updatedOnSameDS = trainer.update(originalModel, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); - ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition originalMdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); + ModelsComposition updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005); - assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.1); - assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.1); + assertEquals(originalMdl.apply(v), updatedOnSameDS.apply(v), 0.1); + assertEquals(originalMdl.apply(v), updatedOnEmptyDS.apply(v), 0.1); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java index 9fa7f0e..eb81b36 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java @@ -34,7 +34,7 @@ public class RandomForestTest { private final long seed = 0; /** Count of trees. */ - private final int countOfTrees = 10; + private final int cntOfTrees = 10; /** Min imp delta. */ private final double minImpDelta = 1.0; @@ -55,7 +55,7 @@ public class RandomForestTest { /** Rf. */ private RandomForestClassifierTrainer rf = new RandomForestClassifierTrainer(meta) - .withCountOfTrees(countOfTrees) + .withAmountOfTrees(cntOfTrees) .withSeed(seed) .withFeaturesCountSelectionStrgy(x -> 4) .withMaxDepth(maxDepth) http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniFeatureHistogramTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniFeatureHistogramTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniFeatureHistogramTest.java index 7ca6411..a82bb95 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniFeatureHistogramTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniFeatureHistogramTest.java @@ -44,7 +44,7 @@ public class GiniFeatureHistogramTest extends ImpurityHistogramTest { /** */ @Before - public void setUp() throws Exception { + public void setUp() { feature2Meta.setMinVal(-5); feature2Meta.setBucketSize(1); } @@ -129,12 +129,13 @@ public class GiniFeatureHistogramTest extends ImpurityHistogramTest { NodeSplit catSplit = catFeatureSmpl1.findBestSplit().get(); NodeSplit contSplit = contFeatureSmpl1.findBestSplit().get(); - assertEquals(1.0, catSplit.getValue(), 0.01); - assertEquals(-0.5, contSplit.getValue(), 0.01); + assertEquals(1.0, catSplit.getVal(), 0.01); + assertEquals(-0.5, contSplit.getVal(), 0.01); assertFalse(emptyHist.findBestSplit().isPresent()); assertFalse(catFeatureSmpl2.findBestSplit().isPresent()); } + /** */ @Test public void testOfSums() { int sampleId = 0; @@ -148,22 +149,22 @@ public class GiniFeatureHistogramTest extends ImpurityHistogramTest { List<GiniHistogram> partitions1 = new ArrayList<>(); List<GiniHistogram> partitions2 = new ArrayList<>(); - int countOfPartitions = rnd.nextInt(1000); - for(int i = 0; i < countOfPartitions; i++) { + int cntOfPartitions = rnd.nextInt(1000); + for (int i = 0; i < cntOfPartitions; i++) { partitions1.add(new GiniHistogram(sampleId,lblMapping, bucketMeta1)); partitions2.add(new GiniHistogram(sampleId,lblMapping, bucketMeta2)); } int datasetSize = rnd.nextInt(10000); for(int i = 0; i < datasetSize; i++) { - BootstrappedVector vec = randomVector(2, 1, true); + BootstrappedVector vec = randomVector(true); vec.features().set(1, (vec.features().get(1) * 100) % 100); forAllHist1.addElement(vec); forAllHist2.addElement(vec); - int partitionId = rnd.nextInt(countOfPartitions); - partitions1.get(partitionId).addElement(vec); - partitions2.get(partitionId).addElement(vec); + int partId = rnd.nextInt(cntOfPartitions); + partitions1.get(partId).addElement(vec); + partitions2.get(partId).addElement(vec); } checkSums(forAllHist1, partitions1); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramTest.java index df4c154..54bd0df 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramTest.java @@ -32,9 +32,17 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertTrue; +/** + * Tests for {@link ImpurityHistogram}. + */ public class ImpurityHistogramTest { - protected static final int COUNT_OF_CLASSES = 3; - protected static final Map<Double, Integer> lblMapping = new HashMap<>(); + /** Count of classes. */ + private static final int COUNT_OF_CLASSES = 3; + + /** Lbl mapping. */ + static final Map<Double, Integer> lblMapping = new HashMap<>(); + + /** Random generator. */ protected Random rnd = new Random(); static { @@ -42,28 +50,41 @@ public class ImpurityHistogramTest { lblMapping.put((double)i, i); } - protected void checkBucketIds(Set<Integer> bucketIdsSet, Integer[] expected) { + /** */ + void checkBucketIds(Set<Integer> bucketIdsSet, Integer[] exp) { Integer[] bucketIds = new Integer[bucketIdsSet.size()]; bucketIdsSet.toArray(bucketIds); - assertArrayEquals(expected, bucketIds); + assertArrayEquals(exp, bucketIds); } - protected void checkCounters(ObjectHistogram<BootstrappedVector> hist, double[] expected) { + /** */ + void checkCounters(ObjectHistogram<BootstrappedVector> hist, double[] exp) { double[] counters = hist.buckets().stream().mapToDouble(x -> hist.getValue(x).get()).toArray(); - assertArrayEquals(expected, counters, 0.01); + assertArrayEquals(exp, counters, 0.01); } - protected BootstrappedVector randomVector(int countOfFeatures, int countOfSampes, boolean isClassification) { - double[] features = DoubleStream.generate(() -> rnd.nextDouble()).limit(countOfFeatures).toArray(); - int[] counters = IntStream.generate(() -> rnd.nextInt(10)).limit(countOfSampes).toArray(); + /** + * Generates random vector. + * + * @param isClassification Is classification. + */ + BootstrappedVector randomVector(boolean isClassification) { + double[] features = DoubleStream.generate(() -> rnd.nextDouble()).limit(2).toArray(); + int[] counters = IntStream.generate(() -> rnd.nextInt(10)).limit(1).toArray(); double lbl = isClassification ? Math.abs(rnd.nextInt() % COUNT_OF_CLASSES) : rnd.nextDouble(); return new BootstrappedVector(VectorUtils.of(features), lbl, counters); } - protected <T extends Histogram<BootstrappedVector, T>> void checkSums(T expected, List<T> partitions) { + /** + * Check sums. + * + * @param exp Expected value. + * @param partitions Partitions. + */ + <T extends Histogram<BootstrappedVector, T>> void checkSums(T exp, List<T> partitions) { T leftSum = partitions.stream().reduce((x,y) -> x.plus(y)).get(); T rightSum = partitions.stream().reduce((x,y) -> y.plus(x)).get(); - assertTrue(expected.isEqualTo(leftSum)); - assertTrue(expected.isEqualTo(rightSum)); + assertTrue(exp.isEqualTo(leftSum)); + assertTrue(exp.isEqualTo(rightSum)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogramTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogramTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogramTest.java index 41bd5ff..872ecec 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogramTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogramTest.java @@ -82,6 +82,7 @@ public class MSEHistogramTest extends ImpurityHistogramTest { checkCounters(contHist2.getSumOfSquaredLabels(), new double[]{ 2 * 5 * 5, 2 * 1 * 1, 1 * 4 * 4, 1 * 2 * 2, 0 * 3 * 3 }); } + /** */ @Test public void testOfSums() { int sampleId = 0; @@ -95,22 +96,24 @@ public class MSEHistogramTest extends ImpurityHistogramTest { List<MSEHistogram> partitions1 = new ArrayList<>(); List<MSEHistogram> partitions2 = new ArrayList<>(); - int countOfPartitions = rnd.nextInt(100); - for(int i = 0; i < countOfPartitions; i++) { + + int cntOfPartitions = rnd.nextInt(100); + + for (int i = 0; i < cntOfPartitions; i++) { partitions1.add(new MSEHistogram(sampleId, bucketMeta1)); partitions2.add(new MSEHistogram(sampleId, bucketMeta2)); } int datasetSize = rnd.nextInt(1000); for(int i = 0; i < datasetSize; i++) { - BootstrappedVector vec = randomVector(2, 1, false); + BootstrappedVector vec = randomVector(false); vec.features().set(1, (vec.features().get(1) * 100) % 100); forAllHist1.addElement(vec); forAllHist2.addElement(vec); - int partitionId = rnd.nextInt(countOfPartitions); - partitions1.get(partitionId).addElement(vec); - partitions2.get(partitionId).addElement(vec); + int partId = rnd.nextInt(cntOfPartitions); + partitions1.get(partId).addElement(vec); + partitions2.get(partId).addElement(vec); } checkSums(forAllHist1, partitions1); http://git-wip-us.apache.org/repos/asf/ignite/blob/eff5751e/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/statistics/NormalDistributionStatisticsComputerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/statistics/NormalDistributionStatisticsComputerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/statistics/NormalDistributionStatisticsComputerTest.java index 79ee3b6..c65a9ac 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/statistics/NormalDistributionStatisticsComputerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/statistics/NormalDistributionStatisticsComputerTest.java @@ -54,13 +54,14 @@ public class NormalDistributionStatisticsComputerTest { new BootstrappedVector(VectorUtils.of(9, 0, 11, 2, 13, 3, 15), 0., null), }); + /** Normal Distribution Statistics Computer. */ private NormalDistributionStatisticsComputer computer = new NormalDistributionStatisticsComputer(); /** */ @Test public void computeStatsOnPartitionTest() { - List<NormalDistributionStatistics> result = computer.computeStatsOnPartition(partition, meta); - NormalDistributionStatistics[] expected = new NormalDistributionStatistics[] { + List<NormalDistributionStatistics> res = computer.computeStatsOnPartition(partition, meta); + NormalDistributionStatistics[] exp = new NormalDistributionStatistics[] { new NormalDistributionStatistics(0, 9, 285, 45, 10), new NormalDistributionStatistics(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0, 0, 10), new NormalDistributionStatistics(2, 11, 505, 65, 10), @@ -70,15 +71,15 @@ public class NormalDistributionStatisticsComputerTest { new NormalDistributionStatistics(6, 15, 1185, 105, 10), }; - assertEquals(expected.length, result.size()); - for (int i = 0; i < expected.length; i++) { - NormalDistributionStatistics expectedStat = expected[i]; - NormalDistributionStatistics resultStat = result.get(i); - assertEquals(expectedStat.mean(), resultStat.mean(), 0.01); - assertEquals(expectedStat.variance(), resultStat.variance(), 0.01); - assertEquals(expectedStat.std(), resultStat.std(), 0.01); - assertEquals(expectedStat.min(), resultStat.min(), 0.01); - assertEquals(expectedStat.max(), resultStat.max(), 0.01); + assertEquals(exp.length, res.size()); + for (int i = 0; i < exp.length; i++) { + NormalDistributionStatistics expStat = exp[i]; + NormalDistributionStatistics resStat = res.get(i); + assertEquals(expStat.mean(), resStat.mean(), 0.01); + assertEquals(expStat.variance(), resStat.variance(), 0.01); + assertEquals(expStat.std(), resStat.std(), 0.01); + assertEquals(expStat.min(), resStat.min(), 0.01); + assertEquals(expStat.max(), resStat.max(), 0.01); } } @@ -105,8 +106,8 @@ public class NormalDistributionStatisticsComputerTest { new NormalDistributionStatistics(0, 9, 285, 45, 10) ); - List<NormalDistributionStatistics> result = computer.reduceStats(left, right, meta); - NormalDistributionStatistics[] expected = new NormalDistributionStatistics[] { + List<NormalDistributionStatistics> res = computer.reduceStats(left, right, meta); + NormalDistributionStatistics[] exp = new NormalDistributionStatistics[] { new NormalDistributionStatistics(0, 15, 1470, 150, 20), new NormalDistributionStatistics(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, 0, 0, 10), new NormalDistributionStatistics(2, 13, 1310, 150, 20), @@ -116,15 +117,15 @@ public class NormalDistributionStatisticsComputerTest { new NormalDistributionStatistics(0, 15, 1470, 150, 20) }; - assertEquals(expected.length, result.size()); - for (int i = 0; i < expected.length; i++) { - NormalDistributionStatistics expectedStat = expected[i]; - NormalDistributionStatistics resultStat = result.get(i); - assertEquals(expectedStat.mean(), resultStat.mean(), 0.01); - assertEquals(expectedStat.variance(), resultStat.variance(), 0.01); - assertEquals(expectedStat.std(), resultStat.std(), 0.01); - assertEquals(expectedStat.min(), resultStat.min(), 0.01); - assertEquals(expectedStat.max(), resultStat.max(), 0.01); + assertEquals(exp.length, res.size()); + for (int i = 0; i < exp.length; i++) { + NormalDistributionStatistics expStat = exp[i]; + NormalDistributionStatistics resStat = res.get(i); + assertEquals(expStat.mean(), resStat.mean(), 0.01); + assertEquals(expStat.variance(), resStat.variance(), 0.01); + assertEquals(expStat.std(), resStat.std(), 0.01); + assertEquals(expStat.min(), resStat.min(), 0.01); + assertEquals(expStat.max(), resStat.max(), 0.01); } } }