IGNITE-8176: Integrate gradient descent linear regression with partition based dataset
this closes #3787 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/df6356d5 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/df6356d5 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/df6356d5 Branch: refs/heads/ignite-7708 Commit: df6356d5d1470337a6ea705a332cf07f1dce2222 Parents: 67023a8 Author: dmitrievanthony <dmitrievanth...@gmail.com> Authored: Thu Apr 12 11:16:22 2018 +0300 Committer: YuriBabak <y.ch...@gmail.com> Committed: Thu Apr 12 11:16:22 2018 +0300 ---------------------------------------------------------------------- .../ml/knn/KNNClassificationExample.java | 11 +- .../examples/ml/nn/MLPTrainerExample.java | 4 +- .../ml/preprocessing/NormalizationExample.java | 17 +-- ...nWithLSQRTrainerAndNormalizationExample.java | 23 ++-- ...dLinearRegressionWithLSQRTrainerExample.java | 14 +-- ...tedLinearRegressionWithQRTrainerExample.java | 9 +- ...edLinearRegressionWithSGDTrainerExample.java | 78 +++++++++--- .../binary/SVMBinaryClassificationExample.java | 11 +- .../SVMMultiClassClassificationExample.java | 24 ++-- ...ecisionTreeClassificationTrainerExample.java | 7 +- .../DecisionTreeRegressionTrainerExample.java | 4 +- .../org/apache/ignite/ml/nn/Activators.java | 20 ++++ .../org/apache/ignite/ml/nn/MLPTrainer.java | 46 ++++++-- .../ml/preprocessing/PreprocessingTrainer.java | 41 ++++++- .../normalization/NormalizationTrainer.java | 35 ++++-- .../linear/FeatureExtractorWrapper.java | 55 +++++++++ .../linear/LinearRegressionLSQRTrainer.java | 38 +----- .../linear/LinearRegressionSGDTrainer.java | 118 +++++++++++++------ .../ignite/ml/trainers/DatasetTrainer.java | 46 ++++++++ .../ignite/ml/knn/KNNClassificationTest.java | 20 ++-- .../ignite/ml/nn/MLPTrainerIntegrationTest.java | 14 +-- .../org/apache/ignite/ml/nn/MLPTrainerTest.java | 22 ++-- .../MLPTrainerMnistIntegrationTest.java | 7 +- .../ml/nn/performance/MLPTrainerMnistTest.java | 11 +- 
.../normalization/NormalizationTrainerTest.java | 10 +- .../ml/regressions/RegressionsTestSuite.java | 15 +-- ...stributedLinearRegressionSGDTrainerTest.java | 35 ------ ...stributedLinearRegressionSGDTrainerTest.java | 35 ------ ...wareAbstractLinearRegressionTrainerTest.java | 3 + .../linear/LinearRegressionLSQRTrainerTest.java | 14 ++- .../linear/LinearRegressionSGDTrainerTest.java | 94 +++++++++++++++ .../LocalLinearRegressionSGDTrainerTest.java | 35 ------ .../ignite/ml/svm/SVMBinaryTrainerTest.java | 11 +- .../ignite/ml/svm/SVMMultiClassTrainerTest.java | 11 +- ...reeClassificationTrainerIntegrationTest.java | 9 +- .../DecisionTreeClassificationTrainerTest.java | 12 +- ...ionTreeRegressionTrainerIntegrationTest.java | 9 +- .../tree/DecisionTreeRegressionTrainerTest.java | 12 +- .../DecisionTreeMNISTIntegrationTest.java | 7 +- .../tree/performance/DecisionTreeMNISTTest.java | 11 +- 40 files changed, 612 insertions(+), 386 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java index f3cdbbe..39a8431 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java @@ -17,9 +17,6 @@ package org.apache.ignite.examples.ml.knn; -import java.util.Arrays; -import java.util.UUID; -import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; @@ -27,7 +24,6 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import 
org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; import org.apache.ignite.ml.knn.classification.KNNStrategy; @@ -35,6 +31,10 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.thread.IgniteThread; +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + /** * Run kNN multi-class classification trainer over distributed dataset. * @@ -56,7 +56,8 @@ public class KNNClassificationExample { KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), + ignite, + dataCache, (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ).withK(3) http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java index efa1ba7..ce44cc6 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java @@ -23,7 +23,6 @@ import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.examples.ExampleNodeStartup; -import 
org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.Activators; @@ -99,7 +98,8 @@ public class MLPTrainerExample { // Train neural network and get multilayer perceptron model. MultilayerPerceptron mlp = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, trainingSet), + ignite, + trainingSet, (k, v) -> new double[] {v.x, v.y}, (k, v) -> new double[] {v.lb} ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java index e0bcd08..b2c4e12 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java @@ -17,21 +17,19 @@ package org.apache.ignite.examples.ml.preprocessing; -import java.util.Arrays; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.examples.ml.dataset.model.Person; -import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.DatasetFactory; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.primitive.SimpleDataset; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor; import 
org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import java.util.Arrays; + /** * Example that shows how to use normalization preprocessor to normalize data. * @@ -47,8 +45,6 @@ public class NormalizationExample { IgniteCache<Integer, Person> persons = createCache(ignite); - DatasetBuilder<Integer, Person> builder = new CacheBasedDatasetBuilder<>(ignite, persons); - // Defines first preprocessor that extracts features from an upstream data. IgniteBiFunction<Integer, Person, double[]> featureExtractor = (k, v) -> new double[] { v.getAge(), @@ -56,14 +52,11 @@ public class NormalizationExample { }; // Defines second preprocessor that normalizes features. - NormalizationPreprocessor<Integer, Person> preprocessor = new NormalizationTrainer<Integer, Person>() - .fit(builder, featureExtractor, 2); + IgniteBiFunction<Integer, Person, double[]> preprocessor = new NormalizationTrainer<Integer, Person>() + .fit(ignite, persons, featureExtractor); // Creates a cache based simple dataset containing features and providing standard dataset API. - try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset( - builder, - preprocessor - )) { + try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) { // Calculation of the mean value. This calculation will be performed in map-reduce manner. 
double[] mean = dataset.mean(); System.out.println("Mean \n\t" + Arrays.toString(mean)); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java index 567a599..99e6577 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample.java @@ -17,9 +17,6 @@ package org.apache.ignite.examples.ml.regression.linear; -import java.util.Arrays; -import java.util.UUID; -import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; @@ -28,7 +25,7 @@ import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; @@ -36,6 +33,10 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; import 
org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.thread.IgniteThread; +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + /** * Run linear regression model over distributed matrix. * @@ -119,21 +120,17 @@ public class DistributedLinearRegressionWithLSQRTrainerAndNormalizationExample { NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>(); System.out.println(">>> Perform the training to get the normalization preprocessor."); - NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), - (k, v) -> Arrays.copyOfRange(v, 1, v.length), - 4 + IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length) ); System.out.println(">>> Create new linear regression trainer object."); LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), - preprocessor, - (k, v) -> v[0] - ); + LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v[0]); System.out.println(">>> Linear regression model: " + mdl); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java index a853092..25aec0c 100644 --- 
a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithLSQRTrainerExample.java @@ -17,9 +17,6 @@ package org.apache.ignite.examples.ml.regression.linear; -import java.util.Arrays; -import java.util.UUID; -import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; @@ -27,13 +24,15 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.thread.IgniteThread; +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + /** * Run linear regression model over distributed matrix. * @@ -108,7 +107,7 @@ public class DistributedLinearRegressionWithLSQRTrainerExample { // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread // because we create ignite cache internally. 
IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { + DistributedLinearRegressionWithLSQRTrainerExample.class.getSimpleName(), () -> { IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); System.out.println(">>> Create new linear regression trainer object."); @@ -116,7 +115,8 @@ public class DistributedLinearRegressionWithLSQRTrainerExample { System.out.println(">>> Perform the training to get the model."); LinearRegressionModel mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), + ignite, + dataCache, (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java index 2b45aa2..98d5e4e 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithQRTrainerExample.java @@ -17,7 +17,6 @@ package org.apache.ignite.examples.ml.regression.linear; -import java.util.Arrays; import org.apache.ignite.Ignite; import org.apache.ignite.Ignition; import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; @@ -30,6 +29,8 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; import org.apache.ignite.thread.IgniteThread; +import 
java.util.Arrays; + /** * Run linear regression model over distributed matrix. * @@ -113,15 +114,15 @@ public class DistributedLinearRegressionWithQRTrainerExample { Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionQRTrainer(); System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel model = trainer.train(distributedMatrix); - System.out.println(">>> Linear regression model: " + model); + LinearRegressionModel mdl = trainer.train(distributedMatrix); + System.out.println(">>> Linear regression model: " + mdl); System.out.println(">>> ---------------------------------"); System.out.println(">>> | Prediction\t| Ground Truth\t|"); System.out.println(">>> ---------------------------------"); for (double[] observation : data) { Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); - double prediction = model.apply(inputs); + double prediction = mdl.apply(inputs); double groundTruth = observation[0]; System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java index f3b2655..44366e1 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/DistributedLinearRegressionWithSGDTrainerExample.java @@ -17,20 +17,26 @@ package org.apache.ignite.examples.ml.regression.linear; -import java.util.Arrays; 
import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; import org.apache.ignite.ml.regressions.linear.LinearRegressionQRTrainer; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer; +import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.thread.IgniteThread; +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + /** * Run linear regression model over distributed matrix. * @@ -104,28 +110,43 @@ public class DistributedLinearRegressionWithSGDTrainerExample { // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread // because we create ignite cache internally. IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - SparseDistributedMatrixExample.class.getSimpleName(), () -> { + DistributedLinearRegressionWithSGDTrainerExample.class.getSimpleName(), () -> { - // Create SparseDistributedMatrix, new cache will be created automagically. 
- System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(data); + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); System.out.println(">>> Create new linear regression trainer object."); - Trainer<LinearRegressionModel, Matrix> trainer = new LinearRegressionSGDTrainer(100_000, 1e-12); + LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>( + new RPropUpdateCalculator(), + RPropParameterUpdate::sumLocal, + RPropParameterUpdate::avg + ), 100000, 10, 100, 123L); System.out.println(">>> Perform the training to get the model."); - LinearRegressionModel model = trainer.train(distributedMatrix); - System.out.println(">>> Linear regression model: " + model); + LinearRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ); + + System.out.println(">>> Linear regression model: " + mdl); System.out.println(">>> ---------------------------------"); System.out.println(">>> | Prediction\t| Ground Truth\t|"); System.out.println(">>> ---------------------------------"); - for (double[] observation : data) { - Vector inputs = new SparseDistributedVector(Arrays.copyOfRange(observation, 1, observation.length)); - double prediction = model.apply(inputs); - double groundTruth = observation[0]; - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } } + System.out.println(">>> 
---------------------------------"); }); @@ -134,4 +155,23 @@ public class DistributedLinearRegressionWithSGDTrainerExample { igniteThread.join(); } } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java index f8bf521..ce37112 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java @@ -17,9 +17,6 @@ package org.apache.ignite.examples.ml.svm.binary; -import java.util.Arrays; -import java.util.UUID; -import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; @@ -27,12 +24,15 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; -import 
org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel; import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer; import org.apache.ignite.thread.IgniteThread; +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + /** * Run SVM binary-class classification model over distributed dataset. * @@ -54,7 +54,8 @@ public class SVMBinaryClassificationExample { SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer(); SVMLinearBinaryClassificationModel mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), + ignite, + dataCache, (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java index f8281e4..4054201 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/multiclass/SVMMultiClassClassificationExample.java @@ -17,9 +17,6 @@ package org.apache.ignite.examples.ml.svm.multiclass; -import java.util.Arrays; -import java.util.UUID; -import javax.cache.Cache; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; @@ -27,14 +24,17 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.cache.query.QueryCursor; import 
org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.preprocessing.normalization.NormalizationPreprocessor; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationTrainer; import org.apache.ignite.thread.IgniteThread; +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + /** * Run SVM multi-class classification trainer over distributed dataset to build two models: * one with normalization and one without normalization. @@ -57,7 +57,8 @@ public class SVMMultiClassClassificationExample { SVMLinearMultiClassClassificationTrainer trainer = new SVMLinearMultiClassClassificationTrainer(); SVMLinearMultiClassClassificationModel mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), + ignite, + dataCache, (k, v) -> Arrays.copyOfRange(v, 1, v.length), (k, v) -> v[0] ); @@ -67,14 +68,15 @@ public class SVMMultiClassClassificationExample { NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>(); - NormalizationPreprocessor<Integer, double[]> preprocessor = normalizationTrainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), - (k, v) -> Arrays.copyOfRange(v, 1, v.length), - 5 + IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length) ); SVMLinearMultiClassClassificationModel mdlWithNormalization = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, dataCache), + ignite, + dataCache, preprocessor, (k, v) -> v[0] ); 
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java index cef6368..1ecf460 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java @@ -17,17 +17,17 @@ package org.apache.ignite.examples.ml.tree; -import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; import org.apache.ignite.thread.IgniteThread; +import java.util.Random; + /** * Example of using distributed {@link DecisionTreeClassificationTrainer}. */ @@ -65,7 +65,8 @@ public class DecisionTreeClassificationTrainerExample { // Train decision tree model. 
DecisionTreeNode mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, trainingSet), + ignite, + trainingSet, (k, v) -> new double[]{v.x, v.y}, (k, v) -> v.lb ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java index 61ba5f9..19b15f3 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java @@ -22,7 +22,6 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.tree.DecisionTreeNode; import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; import org.apache.ignite.thread.IgniteThread; @@ -61,7 +60,8 @@ public class DecisionTreeRegressionTrainerExample { // Train decision tree model. 
DecisionTreeNode mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, trainingSet), + ignite, + trainingSet, (k, v) -> new double[] {v.x}, (k, v) -> v.y ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java index f05bde8..4c34cd2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/Activators.java @@ -58,4 +58,24 @@ public class Activators { return Math.max(val, 0); } }; + + /** + * Linear unit activation function. + */ + public static IgniteDifferentiableDoubleToDoubleFunction LINEAR = new IgniteDifferentiableDoubleToDoubleFunction() { + /** {@inheritDoc} */ + @Override public double differential(double pnt) { + return 1.0; + } + + /** + * Differential of linear at pnt. + * + * @param pnt Point to differentiate at. + * @return Differential at pnt. 
+ */ + @Override public Double apply(double pnt) { + return pnt; + } + }; } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index 47d2022..fe955cb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -17,11 +17,6 @@ package org.apache.ignite.ml.nn; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder; @@ -37,17 +32,23 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.nn.initializers.RandomInitializer; import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator; +import org.apache.ignite.ml.trainers.MultiLabelDatasetTrainer; import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.ml.util.Utils; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + /** * Multilayer perceptron trainer based on partition based {@link Dataset}. * * @param <P> Type of model update used in this trainer. */ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrainer<MultilayerPerceptron> { - /** Multilayer perceptron architecture that defines layers and activators. 
*/ - private final MLPArchitecture arch; + /** Multilayer perceptron architecture supplier that defines layers and activators. */ + private final IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier; /** Loss function to be minimized during the training. */ private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss; @@ -81,7 +82,25 @@ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrai public MLPTrainer(MLPArchitecture arch, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, int locIterations, long seed) { - this.arch = arch; + this(dataset -> arch, loss, updatesStgy, maxIterations, batchSize, locIterations, seed); + } + + /** + * Constructs a new instance of multilayer perceptron trainer. + * + * @param archSupplier Multilayer perceptron architecture supplier that defines layers and activators. + * @param loss Loss function to be minimized during the training. + * @param updatesStgy Update strategy that defines how to update model parameters during the training. + * @param maxIterations Maximal number of iterations before the training will be stopped. + * @param batchSize Batch size (per every partition). + * @param locIterations Maximal number of local iterations before synchronization. + * @param seed Random initializer seed. + */ + public MLPTrainer(IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier, + IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss, + UpdatesStrategy<? 
super MultilayerPerceptron, P> updatesStgy, int maxIterations, int batchSize, + int locIterations, long seed) { + this.archSupplier = archSupplier; this.loss = loss; this.updatesStgy = updatesStgy; this.maxIterations = maxIterations; @@ -94,13 +113,14 @@ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrai public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { - MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed)); - ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator(); - try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor) )) { + MLPArchitecture arch = archSupplier.apply(dataset); + MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed)); + ParameterUpdateCalculator<? 
super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator(); + for (int i = 0; i < maxIterations; i += locIterations) { MultilayerPerceptron finalMdl = mdl; @@ -163,12 +183,12 @@ public class MLPTrainer<P extends Serializable> implements MultiLabelDatasetTrai P update = updatesStgy.allUpdatesReducer().apply(totUp); mdl = updater.update(mdl, update); } + + return mdl; } catch (Exception e) { throw new RuntimeException(e); } - - return mdl; } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java index f5a6bb0..1886ee5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java @@ -17,9 +17,15 @@ package org.apache.ignite.ml.preprocessing; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import java.util.Map; + /** * Trainer for preprocessor. * @@ -34,9 +40,40 @@ public interface PreprocessingTrainer<K, V, T, R> { * * @param datasetBuilder Dataset builder. * @param basePreprocessor Base preprocessor. - * @param cols Number of columns. * @return Preprocessor. */ public IgniteBiFunction<K, V, R> fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, T> basePreprocessor, int cols); + IgniteBiFunction<K, V, T> basePreprocessor); + + /** + * Fits preprocessor. 
+ * + * @param ignite Ignite instance. + * @param cache Ignite cache. + * @param basePreprocessor Base preprocessor. + * @return Preprocessor. + */ + public default IgniteBiFunction<K, V, R> fit(Ignite ignite, IgniteCache<K, V> cache, + IgniteBiFunction<K, V, T> basePreprocessor) { + return fit( + new CacheBasedDatasetBuilder<>(ignite, cache), + basePreprocessor + ); + } + + /** + * Fits preprocessor. + * + * @param data Data. + * @param parts Number of partitions. + * @param basePreprocessor Base preprocessor. + * @return Preprocessor. + */ + public default IgniteBiFunction<K, V, R> fit(Map<K, V> data, int parts, + IgniteBiFunction<K, V, T> basePreprocessor) { + return fit( + new LocalDatasetBuilder<>(data, parts), + basePreprocessor + ); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java index 16623ba..57acbad 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainer.java @@ -33,33 +33,48 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTrainer; public class NormalizationTrainer<K, V> implements PreprocessingTrainer<K, V, double[], double[]> { /** {@inheritDoc} */ @Override public NormalizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, double[]> basePreprocessor, int cols) { + IgniteBiFunction<K, V, double[]> basePreprocessor) { try (Dataset<EmptyContext, NormalizationPartitionData> dataset = datasetBuilder.build( (upstream, upstreamSize) -> new EmptyContext(), 
(upstream, upstreamSize, ctx) -> { - double[] min = new double[cols]; - double[] max = new double[cols]; - - for (int i = 0; i < cols; i++) { - min[i] = Double.MAX_VALUE; - max[i] = -Double.MAX_VALUE; - } + double[] min = null; + double[] max = null; while (upstream.hasNext()) { UpstreamEntry<K, V> entity = upstream.next(); double[] row = basePreprocessor.apply(entity.getKey(), entity.getValue()); - for (int i = 0; i < cols; i++) { + + if (min == null) { + min = new double[row.length]; + for (int i = 0; i < min.length; i++) + min[i] = Double.MAX_VALUE; + } + else + assert min.length == row.length : "Base preprocessor must return exactly " + min.length + + " features"; + + if (max == null) { + max = new double[row.length]; + for (int i = 0; i < max.length; i++) + max[i] = -Double.MAX_VALUE; + } + else + assert max.length == row.length : "Base preprocessor must return exactly " + min.length + + " features"; + + for (int i = 0; i < row.length; i++) { if (row[i] < min[i]) min[i] = row[i]; if (row[i] > max[i]) max[i] = row[i]; } } + return new NormalizationPartitionData(min, max); } )) { double[][] minMax = dataset.compute( - data -> new double[][]{ data.getMin(), data.getMax() }, + data -> data.getMin() != null ? 
new double[][]{ data.getMin(), data.getMax() } : null, (a, b) -> { if (a == null) return b; http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java new file mode 100644 index 0000000..8e8f467 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/FeatureExtractorWrapper.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.linear; + +import org.apache.ignite.ml.math.functions.IgniteBiFunction; + +import java.util.Arrays; + +/** + * Feature extractor wrapper that adds additional column filled by 1. + * + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. 
+ */ +public class FeatureExtractorWrapper<K, V> implements IgniteBiFunction<K, V, double[]> { + /** */ + private static final long serialVersionUID = -2686524650955735635L; + + /** Underlying feature extractor. */ + private final IgniteBiFunction<K, V, double[]> featureExtractor; + + /** + * Constructs a new instance of feature extractor wrapper. + * + * @param featureExtractor Underlying feature extractor. + */ + FeatureExtractorWrapper(IgniteBiFunction<K, V, double[]> featureExtractor) { + this.featureExtractor = featureExtractor; + } + + /** {@inheritDoc} */ + @Override public double[] apply(K k, V v) { + double[] featureRow = featureExtractor.apply(k, v); + double[] row = Arrays.copyOf(featureRow, featureRow.length + 1); + + row[featureRow.length] = 1.0; + + return row; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java index ae15f2f..9526db1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java @@ -17,8 +17,6 @@ package org.apache.ignite.ml.regressions.linear; -import java.util.Arrays; -import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.functions.IgniteBiFunction; @@ -27,6 +25,9 @@ import org.apache.ignite.ml.math.isolve.LinSysPartitionDataBuilderOnHeap; import org.apache.ignite.ml.math.isolve.lsqr.AbstractLSQR; import org.apache.ignite.ml.math.isolve.lsqr.LSQROnHeap; import 
org.apache.ignite.ml.math.isolve.lsqr.LSQRResult; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; + +import java.util.Arrays; /** * Trainer of the linear regression model based on LSQR algorithm. @@ -55,37 +56,4 @@ public class LinearRegressionLSQRTrainer implements SingleLabelDatasetTrainer<Li return new LinearRegressionModel(weights, x[x.length - 1]); } - - /** - * Feature extractor wrapper that adds additional column filled by 1. - * - * @param <K> Type of a key in {@code upstream} data. - * @param <V> Type of a value in {@code upstream} data. - */ - private static class FeatureExtractorWrapper<K, V> implements IgniteBiFunction<K, V, double[]> { - /** */ - private static final long serialVersionUID = -2686524650955735635L; - - /** Underlying feature extractor. */ - private final IgniteBiFunction<K, V, double[]> featureExtractor; - - /** - * Constructs a new instance of feature extractor wrapper. - * - * @param featureExtractor Underlying feature extractor. - */ - FeatureExtractorWrapper(IgniteBiFunction<K, V, double[]> featureExtractor) { - this.featureExtractor = featureExtractor; - } - - /** {@inheritDoc} */ - @Override public double[] apply(K k, V v) { - double[] featureRow = featureExtractor.apply(k, v); - double[] row = Arrays.copyOf(featureRow, featureRow.length + 1); - - row[featureRow.length] = 1.0; - - return row; - } - } } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java index aad4c7a..9be3fdd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java +++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java @@ -17,51 +17,99 @@ package org.apache.ignite.ml.regressions.linear; -import org.apache.ignite.ml.Trainer; -import org.apache.ignite.ml.math.Matrix; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.optimization.BarzilaiBorweinUpdater; -import org.apache.ignite.ml.optimization.GradientDescent; -import org.apache.ignite.ml.optimization.LeastSquaresGradientFunction; -import org.apache.ignite.ml.optimization.SimpleUpdater; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.Activators; +import org.apache.ignite.ml.nn.MLPTrainer; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; +import org.apache.ignite.ml.trainers.group.UpdatesStrategy; + +import java.io.Serializable; +import java.util.Arrays; /** - * Linear regression trainer based on least squares loss function and gradient descent optimization algorithm. + * Trainer of the linear regression model based on stochastic gradient descent algorithm. */ -public class LinearRegressionSGDTrainer implements Trainer<LinearRegressionModel, Matrix> { - /** - * Gradient descent optimizer. - */ - private final GradientDescent gradientDescent; +public class LinearRegressionSGDTrainer<P extends Serializable> implements SingleLabelDatasetTrainer<LinearRegressionModel> { + /** Update strategy. 
*/ + private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; - /** */ - public LinearRegressionSGDTrainer(GradientDescent gradientDescent) { - this.gradientDescent = gradientDescent; - } + /** Max number of iterations. */ + private final int maxIterations; - /** */ - public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol) { - this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new BarzilaiBorweinUpdater()) - .withMaxIterations(maxIterations) - .withConvergenceTol(convergenceTol); - } + /** Batch size. */ + private final int batchSize; - /** */ - public LinearRegressionSGDTrainer(int maxIterations, double convergenceTol, double learningRate) { - this.gradientDescent = new GradientDescent(new LeastSquaresGradientFunction(), new SimpleUpdater(learningRate)) - .withMaxIterations(maxIterations) - .withConvergenceTol(convergenceTol); - } + /** Number of local iterations. */ + private final int locIterations; + + /** Seed for random generator. */ + private final long seed; /** - * {@inheritDoc} + * Constructs a new instance of linear regression SGD trainer. + * + * @param updatesStgy Update strategy. + * @param maxIterations Max number of iterations. + * @param batchSize Batch size. + * @param locIterations Number of local iterations. + * @param seed Seed for random generator. */ - @Override public LinearRegressionModel train(Matrix data) { - Vector variables = gradientDescent.optimize(data, data.likeVector(data.columnSize())); - Vector weights = variables.viewPart(1, variables.size() - 1); + public LinearRegressionSGDTrainer(UpdatesStrategy<? 
super MultilayerPerceptron, P> updatesStgy, int maxIterations, + int batchSize, int locIterations, long seed) { + this.updatesStgy = updatesStgy; + this.maxIterations = maxIterations; + this.batchSize = batchSize; + this.locIterations = locIterations; + this.seed = seed; + } + + /** {@inheritDoc} */ + @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { + + int cols = dataset.compute(data -> { + if (data.getFeatures() == null) + return null; + return data.getFeatures().length / data.getRows(); + }, (a, b) -> a == null ? b : a); + + MLPArchitecture architecture = new MLPArchitecture(cols); + architecture = architecture.withAddedLayer(1, true, Activators.LINEAR); + + return architecture; + }; + + MLPTrainer<?> trainer = new MLPTrainer<>( + archSupplier, + LossFunctions.MSE, + updatesStgy, + maxIterations, + batchSize, + locIterations, + seed + ); + + IgniteBiFunction<K, V, double[]> lbE = new IgniteBiFunction<K, V, double[]>() { + @Override public double[] apply(K k, V v) { + return new double[]{lbExtractor.apply(k, v)}; + } + }; + + MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE); - double intercept = variables.get(0); + double[] p = mlp.parameters().getStorage().data(); - return new LinearRegressionModel(weights, intercept); + return new LinearRegressionModel(new DenseLocalOnHeapVector(Arrays.copyOf(p, p.length - 1)), p[p.length - 1]); } } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 8119a29..fcde3f5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -17,10 +17,16 @@ package org.apache.ignite.ml.trainers; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import java.util.Map; + /** * Interface for trainers. Trainer is just a function which produces model from the data. * @@ -40,4 +46,44 @@ public interface DatasetTrainer<M extends Model, L> { */ public <K, V> M fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor); + + /** + * Trains model based on the specified data. + * + * @param ignite Ignite instance. + * @param cache Ignite cache. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Model. + */ + public default <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache, IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + return fit( + new CacheBasedDatasetBuilder<>(ignite, cache), + featureExtractor, + lbExtractor + ); + } + + /** + * Trains model based on the specified data. + * + * @param data Data. + * @param parts Number of partitions. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return Model. 
+ */ + public default <K, V> M fit(Map<K, V> data, int parts, IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + return fit( + new LocalDatasetBuilder<>(data, parts), + featureExtractor, + lbExtractor + ); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index b5a4b54..b27fcba 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -17,11 +17,7 @@ package org.apache.ignite.ml.knn; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.knn.classification.KNNClassificationModel; import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; import org.apache.ignite.ml.knn.classification.KNNStrategy; @@ -29,6 +25,10 @@ import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + /** Tests behaviour of KNNClassificationTest. 
*/ public class KNNClassificationTest extends BaseKNNTest { /** */ @@ -46,7 +46,8 @@ public class KNNClassificationTest extends BaseKNNTest { KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( - new LocalDatasetBuilder<>(data, 2), + data, + 2, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(3) @@ -74,7 +75,8 @@ public class KNNClassificationTest extends BaseKNNTest { KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( - new LocalDatasetBuilder<>(data, 2), + data, + 2, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(1) @@ -102,7 +104,8 @@ public class KNNClassificationTest extends BaseKNNTest { KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( - new LocalDatasetBuilder<>(data, 2), + data, + 2, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(3) @@ -128,7 +131,8 @@ public class KNNClassificationTest extends BaseKNNTest { KNNClassificationTrainer trainer = new KNNClassificationTrainer(); KNNClassificationModel knnMdl = trainer.fit( - new LocalDatasetBuilder<>(data, 2), + data, + 2, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[2] ).withK(3) http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java index 5ca661f..038b880 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerIntegrationTest.java @@ -17,7 +17,6 @@ package org.apache.ignite.ml.nn; -import java.io.Serializable; 
import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; @@ -25,22 +24,18 @@ import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.internal.util.typedef.X; import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Tracer; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; -import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator; -import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.*; import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import java.io.Serializable; + /** * Tests for {@link MLPTrainer} that require to start the whole Ignite infrastructure. 
*/ @@ -137,7 +132,8 @@ public class MLPTrainerIntegrationTest extends GridCommonAbstractTest { ); MultilayerPerceptron mlp = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, xorCache), + ignite, + xorCache, (k, v) -> new double[]{ v.x, v.y }, (k, v) -> new double[]{ v.lb} ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java index 6906424..c53f6f1 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java @@ -17,24 +17,13 @@ package org.apache.ignite.ml.nn; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import org.apache.ignite.ml.TestUtils; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; -import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator; -import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.optimization.updatecalculators.*; import 
org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.junit.Before; import org.junit.Test; @@ -42,6 +31,12 @@ import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + /** * Tests for {@link MLPTrainer} that don't require to start the whole Ignite infrastructure. */ @@ -140,7 +135,8 @@ public class MLPTrainerTest { ); MultilayerPerceptron mlp = trainer.fit( - new LocalDatasetBuilder<>(xorData, parts), + xorData, + parts, (k, v) -> v[0], (k, v) -> v[1] ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java index c787a47..a64af9b 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistIntegrationTest.java @@ -17,13 +17,11 @@ package org.apache.ignite.ml.nn.performance; -import java.io.IOException; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; @@ -38,6 +36,8 @@ import 
org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.ml.util.MnistUtils; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; +import java.io.IOException; + /** * Tests {@link MLPTrainer} on the MNIST dataset that require to start the whole Ignite infrastructure. */ @@ -104,7 +104,8 @@ public class MLPTrainerMnistIntegrationTest extends GridCommonAbstractTest { System.out.println("Start training..."); long start = System.currentTimeMillis(); MultilayerPerceptron mdl = trainer.fit( - new CacheBasedDatasetBuilder<>(ignite, trainingSet), + ignite, + trainingSet, (k, v) -> v.getPixels(), (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data() ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java index 354af2c..d966484 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MLPTrainerMnistTest.java @@ -17,10 +17,6 @@ package org.apache.ignite.ml.nn.performance; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.VectorUtils; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; @@ -35,6 +31,10 @@ import org.apache.ignite.ml.trainers.group.UpdatesStrategy; import org.apache.ignite.ml.util.MnistUtils; import org.junit.Test; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + import static org.junit.Assert.assertTrue; /** @@ -74,7 +74,8 @@ public 
class MLPTrainerMnistTest { System.out.println("Start training..."); long start = System.currentTimeMillis(); MultilayerPerceptron mdl = trainer.fit( - new LocalDatasetBuilder<>(trainingSet, 1), + trainingSet, + 1, (k, v) -> v.getPixels(), (k, v) -> VectorUtils.num2Vec(v.getLabel(), 10).getStorage().data() ); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java index 1548253..e7a0d47 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/normalization/NormalizationTrainerTest.java @@ -17,15 +17,16 @@ package org.apache.ignite.ml.preprocessing.normalization; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + import static org.junit.Assert.assertArrayEquals; /** @@ -66,8 +67,7 @@ public class NormalizationTrainerTest { NormalizationPreprocessor<Integer, double[]> preprocessor = standardizationTrainer.fit( datasetBuilder, - (k, v) -> v, - 3 + (k, v) -> v ); assertArrayEquals(new double[] {0, 4, 1}, preprocessor.getMin(), 1e-8); http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index 82b3a1b..b3c9368 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -17,14 +17,7 @@ package org.apache.ignite.ml.regressions; -import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionQRTrainerTest; -import org.apache.ignite.ml.regressions.linear.BlockDistributedLinearRegressionSGDTrainerTest; -import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionQRTrainerTest; -import org.apache.ignite.ml.regressions.linear.DistributedLinearRegressionSGDTrainerTest; -import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest; -import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; -import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionQRTrainerTest; -import org.apache.ignite.ml.regressions.linear.LocalLinearRegressionSGDTrainerTest; +import org.apache.ignite.ml.regressions.linear.*; import org.junit.runner.RunWith; import org.junit.runners.Suite; @@ -35,12 +28,10 @@ import org.junit.runners.Suite; @Suite.SuiteClasses({ LinearRegressionModelTest.class, LocalLinearRegressionQRTrainerTest.class, - LocalLinearRegressionSGDTrainerTest.class, DistributedLinearRegressionQRTrainerTest.class, - DistributedLinearRegressionSGDTrainerTest.class, BlockDistributedLinearRegressionQRTrainerTest.class, - BlockDistributedLinearRegressionSGDTrainerTest.class, - LinearRegressionLSQRTrainerTest.class + LinearRegressionLSQRTrainerTest.class, + LinearRegressionSGDTrainerTest.class }) public class RegressionsTestSuite { // No-op. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java deleted file mode 100644 index 58037e2..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/BlockDistributedLinearRegressionSGDTrainerTest.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.ml.math.impls.matrix.SparseBlockDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; - -/** - * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseBlockDistributedMatrix}. 
- */ -public class BlockDistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { - /** */ - public BlockDistributedLinearRegressionSGDTrainerTest() { - super( - new LinearRegressionSGDTrainer(100_000, 1e-12), - SparseBlockDistributedMatrix::new, - SparseBlockDistributedVector::new, - 1e-2); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java deleted file mode 100644 index 71d3b3b..0000000 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/DistributedLinearRegressionSGDTrainerTest.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.ml.regressions.linear; - -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector; - -/** - * Tests for {@link LinearRegressionSGDTrainer} on {@link SparseDistributedMatrix}. - */ -public class DistributedLinearRegressionSGDTrainerTest extends GridAwareAbstractLinearRegressionTrainerTest { - /** */ - public DistributedLinearRegressionSGDTrainerTest() { - super( - new LinearRegressionSGDTrainer(100_000, 1e-12), - SparseDistributedMatrix::new, - SparseDistributedVector::new, - 1e-2); - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java index 1a60b80..9b75bd4 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/GridAwareAbstractLinearRegressionTrainerTest.java @@ -26,6 +26,9 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; import org.junit.Test; +/** + * Grid aware abstract linear regression trainer test. 
+ */ public abstract class GridAwareAbstractLinearRegressionTrainerTest extends GridCommonAbstractTest { /** Number of nodes in grid */ private static final int NODE_COUNT = 3; http://git-wip-us.apache.org/repos/asf/ignite/blob/df6356d5/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java index e3f60ec..2414236 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java @@ -17,14 +17,14 @@ package org.apache.ignite.ml.regressions.linear; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Random; -import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -72,7 +72,8 @@ public class LinearRegressionLSQRTrainerTest { LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); LinearRegressionModel mdl = trainer.fit( - new LocalDatasetBuilder<>(data, parts), + data, + parts, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1), (k, v) -> v[4] ); @@ -110,7 +111,8 @@ public class LinearRegressionLSQRTrainerTest { LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer(); LinearRegressionModel mdl = trainer.fit( - new LocalDatasetBuilder<>(data, parts), + data, + parts, (k, v) -> Arrays.copyOfRange(v, 0, v.length - 
1), (k, v) -> v[coef.length] );