Repository: ignite Updated Branches: refs/heads/master 7ac2d4ddb -> 25f83819a
http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java index 8b10aaa..857d9bd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/binarization/BinarizationTrainerTest.java @@ -22,12 +22,15 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; /** * Tests for {@link BinarizationTrainer}. @@ -66,11 +69,38 @@ public class BinarizationTrainerTest { BinarizationTrainer<Integer, double[]> binarizationTrainer = new BinarizationTrainer<Integer, double[]>() .withThreshold(10); + assertEquals(10., binarizationTrainer.threshold(), 0); + BinarizationPreprocessor<Integer, double[]> preprocessor = binarizationTrainer.fit( datasetBuilder, (k, v) -> VectorUtils.of(v) ); + assertEquals(binarizationTrainer.threshold(), preprocessor.threshold(), 0); + + assertArrayEquals(new double[] {0, 0, 1}, preprocessor.apply(5, new double[] {1, 10, 100}).asArray(), 1e-8); + } + + /** Tests default implementation of {@code fit()} method. 
*/ + @Test + public void testFitDefault() { + Map<Integer, double[]> data = new HashMap<>(); + data.put(1, new double[] {2, 4, 1}); + data.put(2, new double[] {1, 8, 22}); + data.put(3, new double[] {4, 10, 100}); + data.put(4, new double[] {0, 22, 300}); + + BinarizationTrainer<Integer, double[]> binarizationTrainer = new BinarizationTrainer<Integer, double[]>() + .withThreshold(10); + + assertEquals(10., binarizationTrainer.threshold(), 0); + + IgniteBiFunction<Integer, double[], Vector> preprocessor = binarizationTrainer.fit( + data, + parts, + (k, v) -> VectorUtils.of(v) + ); + assertArrayEquals(new double[] {0, 0, 1}, preprocessor.apply(5, new double[] {1, 10, 100}).asArray(), 1e-8); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java index b962701..7b02f20 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java @@ -29,6 +29,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; /** * Tests for {@link BinarizationTrainer}. 
@@ -67,11 +68,15 @@ public class NormalizationTrainerTest { NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<Integer, double[]>() .withP(3); + assertEquals(3., normalizationTrainer.p(), 0); + NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit( datasetBuilder, (k, v) -> VectorUtils.of(v) ); + assertEquals(normalizationTrainer.p(), preprocessor.p(), 0); + assertArrayEquals(new double[] {0.125, 0.99, 0.125}, preprocessor.apply(5, new double[]{1., 8., 1.}).asArray(), 1e-2); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java index 71d831d..66871b0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java @@ -25,6 +25,8 @@ import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionMode import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; import org.junit.Test; +import static org.junit.Assert.assertTrue; + /** * Tests for {@link LinearRegressionModel}. 
*/ @@ -38,6 +40,10 @@ public class LinearRegressionModelTest { Vector weights = new DenseVector(new double[]{2.0, 3.0}); LinearRegressionModel mdl = new LinearRegressionModel(weights, 1.0); + assertTrue(mdl.toString().length() > 0); + assertTrue(mdl.toString(true).length() > 0); + assertTrue(mdl.toString(false).length() > 0); + Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 + 3.0 * 1.0, mdl.apply(observation), PRECISION); http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java index 1d25524..e0e6a71 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -93,6 +93,10 @@ public class LogRegMultiClassTrainerTest { (k, v) -> v[0] ); + Assert.assertTrue(mdl.toString().length() > 0); + Assert.assertTrue(mdl.toString(true).length() > 0); + Assert.assertTrue(mdl.toString(false).length() > 0); + TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 10})), PRECISION); TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 100})), PRECISION); } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java index bb6a77d..89c9cca 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java @@ -24,6 +24,10 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + /** * Tests for {@link LogisticRegressionModel}. */ @@ -35,22 +39,19 @@ public class LogisticRegressionModelTest { @Test public void testPredict() { Vector weights = new DenseVector(new double[]{2.0, 3.0}); - LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0).withRawLabels(true); - Vector observation = new DenseVector(new double[]{1.0, 1.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + assertFalse(new LogisticRegressionModel(weights, 1.0).isKeepingRawLabels()); - observation = new DenseVector(new double[]{2.0, 1.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + assertEquals(0.1, new LogisticRegressionModel(weights, 1.0).withThreshold(0.1).threshold(), 0); - observation = new DenseVector(new double[]{1.0, 2.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION); + assertTrue(new LogisticRegressionModel(weights, 1.0).toString().length() > 0); + assertTrue(new LogisticRegressionModel(weights, 1.0).toString(true).length() > 0); + assertTrue(new LogisticRegressionModel(weights, 1.0).toString(false).length() > 0); - observation = new DenseVector(new double[]{-2.0, 1.0}); - TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 
1.0), mdl.apply(observation), PRECISION); - - observation = new DenseVector(new double[]{1.0, -2.0}); - TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION); + verifyPredict(new LogisticRegressionModel(weights, 1.0).withRawLabels(true)); + verifyPredict(new LogisticRegressionModel(null, 1.0).withRawLabels(true).withWeights(weights)); + verifyPredict(new LogisticRegressionModel(weights, 1.0).withRawLabels(true).withThreshold(0.5)); + verifyPredict(new LogisticRegressionModel(weights, 0.0).withRawLabels(true).withIntercept(1.0)); } /** */ @@ -65,6 +66,24 @@ public class LogisticRegressionModelTest { mdl.apply(observation); } + /** */ + private void verifyPredict(LogisticRegressionModel mdl) { + Vector observation = new DenseVector(new double[]{1.0, 1.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + + observation = new DenseVector(new double[]{2.0, 1.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + + observation = new DenseVector(new double[]{1.0, 2.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION); + + observation = new DenseVector(new double[]{-2.0, 1.0}); + TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION); + + observation = new DenseVector(new double[]{1.0, -2.0}); + TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION); + } + /** * Sigmoid function. * @param z The regression value. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java index ad4aaf1..9dd35ef 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java @@ -34,7 +34,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; /** - * Tests for {@LogisticRegressionSGDTrainer}. + * Tests for {@link LogisticRegressionSGDTrainer}. */ @RunWith(Parameterized.class) public class LogisticRegressionSGDTrainerTest { @@ -85,7 +85,7 @@ public class LogisticRegressionSGDTrainerTest { } LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), + new SimpleGDUpdateCalculator().withLearningRate(0.2), SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg ), 100000, 10, 100, 123L); http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java index 3adae79..21c605b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java @@ -21,6 +21,7 @@ import 
org.apache.ignite.ml.selection.cv.CrossValidationTest; import org.apache.ignite.ml.selection.paramgrid.ParameterSetGeneratorTest; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTest; import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest; +import org.apache.ignite.ml.selection.scoring.evaluator.EvaluatorTest; import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest; import org.apache.ignite.ml.selection.scoring.metric.FmeasureTest; import org.apache.ignite.ml.selection.scoring.metric.PrecisionTest; @@ -36,6 +37,7 @@ import org.junit.runners.Suite; @RunWith(Suite.class) @Suite.SuiteClasses({ CrossValidationTest.class, + EvaluatorTest.class, ParameterSetGeneratorTest.class, CacheBasedLabelPairCursorTest.class, LocalLabelPairCursorTest.class, http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java new file mode 100644 index 0000000..c98231f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.evaluator; + +import java.text.NumberFormat; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import org.apache.ignite.ml.selection.cv.CrossValidation; +import org.apache.ignite.ml.selection.cv.CrossValidationResult; +import org.apache.ignite.ml.selection.paramgrid.ParamGrid; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; +import org.apache.ignite.ml.selection.split.TrainTestSplit; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import 
org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import org.apache.ignite.thread.IgniteThread; + +import static org.junit.Assert.assertArrayEquals; + +/** + * Tests for {@link Evaluator} that require to start the whole Ignite infrastructure. IMPL NOTE based on + * Step_8_CV_with_Param_Grid example. + */ +public class EvaluatorTest extends GridCommonAbstractTest { + /** Number of nodes in grid */ + private static final int NODE_COUNT = 3; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** + * {@inheritDoc} + */ + @Override protected void beforeTest() { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + public void testBasic() throws InterruptedException { + AtomicReference<Double> actualAccuracy = new AtomicReference<>(null); + AtomicReference<Double> actualAccuracy2 = new AtomicReference<>(null); + AtomicReference<CrossValidationResult> res = new AtomicReference<>(null); + List<double[]> actualScores = new ArrayList<>(); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + EvaluatorTest.class.getSimpleName(), () -> { + CacheConfiguration<Integer, Object[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName(UUID.randomUUID().toString()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, Object[]> dataCache = ignite.createCache(cacheConfiguration); + + readPassengers(dataCache); + + // Defines first preprocessor that extracts features from an upstream data. 
+ // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare" + IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]}; + + IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1]; + + TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() + .split(0.75); + + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) + .encodeFeature(1) + .encodeFeature(6) // <--- Changed index here + .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + strEncoderPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>() + .fit( + ignite, + dataCache, + imputingPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>() + .withP(2) + .fit( + ignite, + dataCache, + minMaxScalerPreprocessor + ); + + // Tune hyperparams with K-fold Cross-Validation on the split training set. 
+ + DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer(); + + CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator + = new CrossValidation<>(); + + ParamGrid paramGrid = new ParamGrid() + .addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0}) + .addHyperParam("minImpurityDecrease", new Double[] {0.0, 0.25, 0.5}); + + CrossValidationResult crossValidationRes = scoreCalculator.score( + trainerCV, + new Accuracy<>(), + ignite, + dataCache, + split.getTrainFilter(), + normalizationPreprocessor, + lbExtractor, + 3, + paramGrid + ); + + res.set(crossValidationRes); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer() + .withMaxDeep(crossValidationRes.getBest("maxDeep")) + .withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease")); + + crossValidationRes.getScoringBoard().forEach((hyperParams, score) -> actualScores.add(score)); + + // Train decision tree model. + DecisionTreeNode bestMdl = trainer.fit( + ignite, + dataCache, + split.getTrainFilter(), + normalizationPreprocessor, + lbExtractor + ); + + double accuracy = Evaluator.evaluate( + dataCache, + split.getTestFilter(), + bestMdl, + normalizationPreprocessor, + lbExtractor, + new Accuracy<>() + ); + + actualAccuracy.set(accuracy); + actualAccuracy2.set(Evaluator.evaluate( + dataCache, + bestMdl, + normalizationPreprocessor, + lbExtractor, + new Accuracy<>() + )); + }); + + igniteThread.start(); + + igniteThread.join(); + + assertResults(res.get(), actualScores, actualAccuracy.get(), actualAccuracy2.get()); + } + + /** */ + private void assertResults(CrossValidationResult res, List<double[]> scores, double accuracy, double accuracy2) { + assertTrue(res.toString().length() > 0); + assertEquals("Best maxDeep", 1.0, res.getBest("maxDeep")); + assertEquals("Best minImpurityDecrease", 0.0, res.getBest("minImpurityDecrease")); + assertArrayEquals("Best score", new double[] 
{0.6666666666666666, 0.4, 0}, res.getBestScore(), 0); + assertEquals("Best hyper params size", 2, res.getBestHyperParams().size()); + assertEquals("Best average score", 0.35555555555555557, res.getBestAvgScore()); + + assertEquals("Scores amount", 18, scores.size()); + + int idx = 0; + for (double[] actualScore : scores) + assertEquals("Score size at index " + idx++, 3, actualScore.length); + + assertEquals("Accuracy", 1.0, accuracy); + assertTrue("Accuracy without filter", accuracy2 > 0.); + } + + /** + * Read passengers data. + * + * @param cache The ignite cache. + */ + private void readPassengers(IgniteCache<Integer, Object[]> cache) { + // IMPL NOTE: pclass;survived;name;sex;age;sibsp;parch;ticket;fare;cabin;embarked;boat;body;homedest + List<String[]> passengers = Arrays.asList( + new String[] { + "1", "1", "Allen, Miss. Elisabeth Walton", "", + "29", "", "", "24160", "211,3375", "B5", "", "2", "", "St Louis, MO"}, + new String[] { + "1", "1", "Allison, Master. Hudson Trevor", "male", + "0,9167", "1", "2", "113781", "151,55", "C22 C26", "S", "11", "", "Montreal, PQ / Chesterville, ON"}, + new String[] { + "1", "0", "Allison, Miss. Helen Loraine", "female", + "2", "1", "2", "113781", "151,55", "C22 C26", "S", "", "", "Montreal, PQ / Chesterville, ON"}, + new String[] { + "1", "0", "Allison, Mr. Hudson Joshua Creighton", + "male", "30", "1", "2", "113781", "151,55", "C22 C26", "S", "", "135", "Montreal, PQ / Chesterville, ON"}, + new String[] { + "1", "0", "Allison, Mrs. Hudson J C (Bessie Waldo Daniels)", "female", + "25", "1", "2", "113781", "151,55", "C22 C26", "S", "", "", "Montreal, PQ / Chesterville, ON"}, + new String[] { + "1", "1", "Anderson, Mr. Harry", "male", + "48", "0", "0", "19952", "26,55", "E12", "S", "3", "", "New York, NY"}, + new String[] { + "1", "1", "Andrews, Miss. Kornelia Theodosia", "female", + "63", "1", "0", "13502", "77,9583", "D7", "S", "10", "", "Hudson, NY"}, + new String[] { + "1", "0", "Andrews, Mr. 
Thomas Jr", "male", + "39", "0", "0", "112050", "0", "A36", "S", "", "", "Belfast, NI"}, + new String[] { + "1", "1", "Appleton, Mrs. Edward Dale (Charlotte Lamson)", "female", + "53", "2", "0", "11769", "51,4792", "C101", "S", "D", "", "Bayside, Queens, NY"}, + new String[] { + "1", "0", "Artagaveytia, Mr. Ramon", "male", + "71", "0", "0", "PC 17609", "49,5042", "", "C", "", "22", "Montevideo, Uruguay"}); + + int cnt = 1; + for (String[] details : passengers) { + Object[] data = new Object[details.length]; + + for (int i = 0; i < details.length; i++) + data[i] = doubleOrString(details[i]); + + cache.put(cnt++, data); + } + } + + /** */ + private Object doubleOrString(String data) { + NumberFormat format = NumberFormat.getInstance(Locale.FRANCE); + try { + return data.equals("") ? Double.NaN : Double.valueOf(data); + } + catch (java.lang.NumberFormatException e) { + + try { + return format.parse(data).doubleValue(); + } + catch (ParseException e1) { + return data; + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java index 4f13816..835d08d 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/FmeasureTest.java @@ -31,7 +31,7 @@ public class FmeasureTest { /** */ @Test public void testScore() { - Metric<Integer> scoreCalculator = new Fmeasure<>(1); + Fmeasure<Integer> scoreCalculator = new Fmeasure<>(1); LabelPairCursor<Integer> cursor = new TestLabelPairCursor<>( Arrays.asList(1, 0, 1, 0, 1, 0), 
http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java index 72f4cd7..d7821d5 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/PrecisionTest.java @@ -31,7 +31,7 @@ public class PrecisionTest { /** */ @Test public void testScore() { - Metric<Integer> scoreCalculator = new Precision<>(0); + Precision<Integer> scoreCalculator = new Precision<>(0); LabelPairCursor<Integer> cursor = new TestLabelPairCursor<>( Arrays.asList(1, 0, 1, 0, 1, 0), http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java index 5df465b..8c92acd 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/RecallTest.java @@ -31,7 +31,7 @@ public class RecallTest { /** */ @Test public void testScore() { - Metric<Integer> scoreCalculator = new Recall<>(1); + Recall<Integer> scoreCalculator = new Recall<>(1); LabelPairCursor<Integer> cursor = new TestLabelPairCursor<>( Arrays.asList(1, 0, 1, 0, 1, 0), http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/structures/DatasetStructureTest.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/structures/DatasetStructureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/structures/DatasetStructureTest.java new file mode 100644 index 0000000..79e7a16 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/structures/DatasetStructureTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for {@link Dataset} basic features. 
+ */ +public class DatasetStructureTest { + /** + * Basic test + */ + @Test + @SuppressWarnings("unchecked") + public void testBasic() { + Assert.assertNull("Feature names constructor", new Dataset<DatasetRow<Vector>>(1, 1, + new String[] {"tests"}, false).data()); + + Dataset<DatasetRow<Vector>> dataset = new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, + new FeatureMetadata[] {}); + + Assert.assertEquals("Expect empty data", 0, dataset.data().length); + Assert.assertEquals("Expect empty meta", 0, dataset.data().length); + Assert.assertFalse("Not distributed by default", dataset.isDistributed()); + + dataset.setData(new DatasetRow[] {new DatasetRow()}); + dataset.setMeta(new FeatureMetadata[] {new FeatureMetadata()}); + dataset.setDistributed(true); + + Assert.assertEquals("Expect non empty data", 1, dataset.data().length); + Assert.assertEquals("Expect non empty meta", 1, dataset.data().length); + Assert.assertTrue("Expect distributed", dataset.isDistributed()); + Assert.assertEquals(1, dataset.meta().length); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/structures/StructuresTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/structures/StructuresTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/structures/StructuresTestSuite.java new file mode 100644 index 0000000..01064a7 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/structures/StructuresTestSuite.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in org.apache.ignite.ml.trees package. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + DatasetStructureTest.class +}) +public class StructuresTestSuite { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java index 9244c35..9a222c3 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMModelTest.java @@ -53,7 +53,11 @@ public class SVMModelTest { observation = new DenseVector(new double[]{1.0, -2.0}); TestUtils.assertEquals(1.0 + 2.0 * 1.0 - 3.0 * 2.0, mdl.apply(observation), PRECISION); - Assert.assertEquals(true, mdl.isKeepingRawLabels()); + Assert.assertTrue(mdl.isKeepingRawLabels()); + + Assert.assertTrue(mdl.toString().length() > 0); + Assert.assertTrue(mdl.toString(true).length() > 0); + Assert.assertTrue(mdl.toString(false).length() > 0); } @@ -68,6 +72,10 @@ public class SVMModelTest { mdl.add(2, new SVMLinearBinaryClassificationModel(weights2, 0.0).withRawLabels(true)); mdl.add(2, new SVMLinearBinaryClassificationModel(weights3, 0.0).withRawLabels(true)); + Assert.assertTrue(mdl.toString().length() > 0); +
Assert.assertTrue(mdl.toString(true).length() > 0); + Assert.assertTrue(mdl.toString(false).length() > 0); + Vector observation = new DenseVector(new double[]{1.0, 1.0}); TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java index c84da12..b82885e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java @@ -29,6 +29,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import static junit.framework.TestCase.assertEquals; +import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertTrue; /** @@ -40,12 +41,12 @@ public class DecisionTreeClassificationTrainerTest { private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7}; /** Number of partitions. */ - @Parameterized.Parameter(0) + @Parameterized.Parameter() public int parts; /** Use index [= 1 if true]. */ @Parameterized.Parameter(1) - public int useIndex; + public int useIdx; /** Test parameters. */ @Parameterized.Parameters(name = "Data divided on {0} partitions. 
Use index = {1}.") @@ -73,7 +74,7 @@ public class DecisionTreeClassificationTrainerTest { } DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0) - .withUseIndex(useIndex == 1); + .withUseIndex(useIdx == 1); DecisionTreeNode tree = trainer.fit( data, @@ -87,6 +88,10 @@ public class DecisionTreeClassificationTrainerTest { DecisionTreeConditionalNode node = (DecisionTreeConditionalNode)tree; assertEquals(0, node.getThreshold(), 1e-3); + assertEquals(0, node.getCol()); + assertNotNull(node.toString()); + assertNotNull(node.toString(true)); + assertNotNull(node.toString(false)); assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode); assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode); @@ -96,5 +101,9 @@ public class DecisionTreeClassificationTrainerTest { assertEquals(1, thenNode.getVal(), 1e-10); assertEquals(0, elseNode.getVal(), 1e-10); + + assertNotNull(thenNode.toString()); + assertNotNull(thenNode.toString(true)); + assertNotNull(thenNode.toString(false)); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java index 001404f..579d592 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/util/SimpleStepFunctionCompressorTest.java @@ -27,6 +27,39 @@ import static org.junit.Assert.assertArrayEquals; public class SimpleStepFunctionCompressorTest { /** */ @Test + @SuppressWarnings("unchecked") + public void testDefaultCompress() { + StepFunction<TestImpurityMeasure> function = new 
StepFunction<>( + new double[]{1, 2, 3, 4}, + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4) + ); + + SimpleStepFunctionCompressor<TestImpurityMeasure> compressor = new SimpleStepFunctionCompressor<>(); + + StepFunction<TestImpurityMeasure> resFunction = compressor.compress(new StepFunction [] {function})[0]; + + assertArrayEquals(new double[]{1, 2, 3, 4}, resFunction.getX(), 1e-10); + assertArrayEquals(TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4), resFunction.getY()); + } + + /** */ + @Test + public void testDefaults() { + StepFunction<TestImpurityMeasure> function = new StepFunction<>( + new double[]{1, 2, 3, 4}, + TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4) + ); + + SimpleStepFunctionCompressor<TestImpurityMeasure> compressor = new SimpleStepFunctionCompressor<>(); + + StepFunction<TestImpurityMeasure> resFunction = compressor.compress(function); + + assertArrayEquals(new double[]{1, 2, 3, 4}, resFunction.getX(), 1e-10); + assertArrayEquals(TestImpurityMeasure.asTestImpurityMeasures(1, 2, 3, 4), resFunction.getY()); + } + + /** */ + @Test public void testCompressSmallFunction() { StepFunction<TestImpurityMeasure> function = new StepFunction<>( new double[]{1, 2, 3, 4}, http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index 055223b..d06fa50 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -33,6 +33,9 @@ import org.junit.runners.Parameterized; import static 
org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +/** + * Tests for {@link RandomForestClassifierTrainer}. + */ @RunWith(Parameterized.class) public class RandomForestClassifierTrainerTest { /** @@ -46,6 +49,9 @@ public class RandomForestClassifierTrainerTest { @Parameterized.Parameter public int parts; + /** + * Data iterator. + */ @Parameterized.Parameters(name = "Data divided on {0} partitions") public static Iterable<Integer[]> data() { List<Integer[]> res = new ArrayList<>(); @@ -69,13 +75,24 @@ public class RandomForestClassifierTrainerTest { sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2)); } - RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(4, 3, 5, 0.3, 4, 0.1); + RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(4, 3, 5, 0.3, 4, 0.1) + .withUseIndex(false); ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); mdl.getModels().forEach(m -> { assertTrue(m instanceof ModelOnFeaturesSubspace); - assertTrue(((ModelOnFeaturesSubspace) m).getMdl() instanceof DecisionTreeConditionalNode); + + ModelOnFeaturesSubspace mdlOnFeaturesSubspace = (ModelOnFeaturesSubspace) m; + + assertTrue(mdlOnFeaturesSubspace.getMdl() instanceof DecisionTreeConditionalNode); + + assertTrue(mdlOnFeaturesSubspace.getFeaturesMapping().size() > 0); + + String expClsName = "ModelOnFeatureSubspace"; + assertTrue(mdlOnFeaturesSubspace.toString().contains(expClsName)); + assertTrue(mdlOnFeaturesSubspace.toString(true).contains(expClsName)); + assertTrue(mdlOnFeaturesSubspace.toString(false).contains(expClsName)); }); assertTrue(mdl.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator); http://git-wip-us.apache.org/repos/asf/ignite/blob/25f83819/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java index 1421e0a..987176e 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java @@ -33,6 +33,9 @@ import org.junit.runners.Parameterized; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +/** + * Tests for {@link RandomForestRegressionTrainer}. + */ @RunWith(Parameterized.class) public class RandomForestRegressionTrainerTest { /** @@ -69,7 +72,8 @@ public class RandomForestRegressionTrainerTest { sample.put(x1 * x2 + x3 * x4, new double[] {x1, x2, x3, x4}); } - RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1); + RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(4, 3, 5, 0.3, 4, 0.1) + .withUseIndex(false); ModelsComposition mdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(v), (k, v) -> k);
