IGNITE-8511: [ML] Add support for Multi-Class Logistic Regression this closes #4008
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/cb8fb736 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/cb8fb736 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/cb8fb736 Branch: refs/heads/ignite-5789-1 Commit: cb8fb736597e9b3f25ef6d55a8dc4d8ad0d23b60 Parents: 436d123 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Mon May 21 15:31:16 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Mon May 21 15:31:16 2018 +0300 ---------------------------------------------------------------------- .../LogisticRegressionSGDTrainerSample.java | 239 --------------- .../LogisticRegressionSGDTrainerSample.java | 239 +++++++++++++++ .../logistic/binary/package-info.java | 22 ++ ...gressionMultiClassClassificationExample.java | 301 +++++++++++++++++++ .../logistic/multiclass/package-info.java | 22 ++ .../LogRegressionMultiClassModel.java | 96 ++++++ .../LogRegressionMultiClassTrainer.java | 222 ++++++++++++++ .../logistic/multiclass/package-info.java | 22 ++ .../ml/regressions/RegressionsTestSuite.java | 6 +- .../linear/LinearRegressionModelTest.java | 17 ++ .../logistic/LogRegMultiClassTrainerTest.java | 98 ++++++ 11 files changed, 1043 insertions(+), 241 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java deleted file mode 100644 index 0505ddd..0000000 --- 
a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/LogisticRegressionSGDTrainerSample.java +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.regression.logistic; - -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; -import org.apache.ignite.cache.query.QueryCursor; -import org.apache.ignite.cache.query.ScanQuery; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.nn.UpdatesStrategy; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; -import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; -import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; -import org.apache.ignite.thread.IgniteThread; - -import javax.cache.Cache; -import java.util.Arrays; -import java.util.UUID; - -/** - * Run logistic regression 
model over distributed cache. - * - * @see LogisticRegressionSGDTrainer - */ -public class LogisticRegressionSGDTrainerSample { - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Logistic regression model over partitioned dataset usage example started."); - // Start ignite grid. - try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), - LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> { - - IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); - - System.out.println(">>> Create new logistic regression trainer object."); - LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( - new SimpleGDUpdateCalculator(0.2), - SimpleGDParameterUpdate::sumLocal, - SimpleGDParameterUpdate::avg - ), 100000, 10, 100, 123L); - - System.out.println(">>> Perform the training to get the model."); - LogisticRegressionModel mdl = trainer.fit( - ignite, - dataCache, - (k, v) -> Arrays.copyOfRange(v, 1, v.length), - (k, v) -> v[0] - ).withRawLabels(true); - - System.out.println(">>> Logistic regression model: " + mdl); - - int amountOfErrors = 0; - int totalAmount = 0; - - // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix - int[][] confusionMtx = {{0, 0}, {0, 0}}; - - try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { - for (Cache.Entry<Integer, double[]> observation : observations) { - double[] val = observation.getValue(); - double[] inputs = Arrays.copyOfRange(val, 1, val.length); - double groundTruth = val[0]; - - double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); - - totalAmount++; - if(groundTruth != prediction) - amountOfErrors++; - - int idx1 = (int)prediction; - int idx2 = (int)groundTruth; - - confusionMtx[idx1][idx2]++; - - System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); - } - - System.out.println(">>> ---------------------------------"); - - System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); - System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); - } - - System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); - System.out.println(">>> ---------------------------------"); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } - /** - * Fills cache with data and returns it. - * - * @param ignite Ignite instance. - * @return Filled Ignite Cache. - */ - private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { - CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); - cacheConfiguration.setName("TEST_" + UUID.randomUUID()); - cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); - - IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); - - for (int i = 0; i < data.length; i++) - cache.put(i, data[i]); - - return cache; - } - - - /** The 1st and 2nd classes from the Iris dataset. 
*/ - private static final double[][] data = { - {0, 5.1, 3.5, 1.4, 0.2}, - {0, 4.9, 3, 1.4, 0.2}, - {0, 4.7, 3.2, 1.3, 0.2}, - {0, 4.6, 3.1, 1.5, 0.2}, - {0, 5, 3.6, 1.4, 0.2}, - {0, 5.4, 3.9, 1.7, 0.4}, - {0, 4.6, 3.4, 1.4, 0.3}, - {0, 5, 3.4, 1.5, 0.2}, - {0, 4.4, 2.9, 1.4, 0.2}, - {0, 4.9, 3.1, 1.5, 0.1}, - {0, 5.4, 3.7, 1.5, 0.2}, - {0, 4.8, 3.4, 1.6, 0.2}, - {0, 4.8, 3, 1.4, 0.1}, - {0, 4.3, 3, 1.1, 0.1}, - {0, 5.8, 4, 1.2, 0.2}, - {0, 5.7, 4.4, 1.5, 0.4}, - {0, 5.4, 3.9, 1.3, 0.4}, - {0, 5.1, 3.5, 1.4, 0.3}, - {0, 5.7, 3.8, 1.7, 0.3}, - {0, 5.1, 3.8, 1.5, 0.3}, - {0, 5.4, 3.4, 1.7, 0.2}, - {0, 5.1, 3.7, 1.5, 0.4}, - {0, 4.6, 3.6, 1, 0.2}, - {0, 5.1, 3.3, 1.7, 0.5}, - {0, 4.8, 3.4, 1.9, 0.2}, - {0, 5, 3, 1.6, 0.2}, - {0, 5, 3.4, 1.6, 0.4}, - {0, 5.2, 3.5, 1.5, 0.2}, - {0, 5.2, 3.4, 1.4, 0.2}, - {0, 4.7, 3.2, 1.6, 0.2}, - {0, 4.8, 3.1, 1.6, 0.2}, - {0, 5.4, 3.4, 1.5, 0.4}, - {0, 5.2, 4.1, 1.5, 0.1}, - {0, 5.5, 4.2, 1.4, 0.2}, - {0, 4.9, 3.1, 1.5, 0.1}, - {0, 5, 3.2, 1.2, 0.2}, - {0, 5.5, 3.5, 1.3, 0.2}, - {0, 4.9, 3.1, 1.5, 0.1}, - {0, 4.4, 3, 1.3, 0.2}, - {0, 5.1, 3.4, 1.5, 0.2}, - {0, 5, 3.5, 1.3, 0.3}, - {0, 4.5, 2.3, 1.3, 0.3}, - {0, 4.4, 3.2, 1.3, 0.2}, - {0, 5, 3.5, 1.6, 0.6}, - {0, 5.1, 3.8, 1.9, 0.4}, - {0, 4.8, 3, 1.4, 0.3}, - {0, 5.1, 3.8, 1.6, 0.2}, - {0, 4.6, 3.2, 1.4, 0.2}, - {0, 5.3, 3.7, 1.5, 0.2}, - {0, 5, 3.3, 1.4, 0.2}, - {1, 7, 3.2, 4.7, 1.4}, - {1, 6.4, 3.2, 4.5, 1.5}, - {1, 6.9, 3.1, 4.9, 1.5}, - {1, 5.5, 2.3, 4, 1.3}, - {1, 6.5, 2.8, 4.6, 1.5}, - {1, 5.7, 2.8, 4.5, 1.3}, - {1, 6.3, 3.3, 4.7, 1.6}, - {1, 4.9, 2.4, 3.3, 1}, - {1, 6.6, 2.9, 4.6, 1.3}, - {1, 5.2, 2.7, 3.9, 1.4}, - {1, 5, 2, 3.5, 1}, - {1, 5.9, 3, 4.2, 1.5}, - {1, 6, 2.2, 4, 1}, - {1, 6.1, 2.9, 4.7, 1.4}, - {1, 5.6, 2.9, 3.6, 1.3}, - {1, 6.7, 3.1, 4.4, 1.4}, - {1, 5.6, 3, 4.5, 1.5}, - {1, 5.8, 2.7, 4.1, 1}, - {1, 6.2, 2.2, 4.5, 1.5}, - {1, 5.6, 2.5, 3.9, 1.1}, - {1, 5.9, 3.2, 4.8, 1.8}, - {1, 6.1, 2.8, 4, 1.3}, - {1, 6.3, 2.5, 4.9, 1.5}, - {1, 6.1, 2.8, 4.7, 1.2}, - {1, 6.4, 
2.9, 4.3, 1.3}, - {1, 6.6, 3, 4.4, 1.4}, - {1, 6.8, 2.8, 4.8, 1.4}, - {1, 6.7, 3, 5, 1.7}, - {1, 6, 2.9, 4.5, 1.5}, - {1, 5.7, 2.6, 3.5, 1}, - {1, 5.5, 2.4, 3.8, 1.1}, - {1, 5.5, 2.4, 3.7, 1}, - {1, 5.8, 2.7, 3.9, 1.2}, - {1, 6, 2.7, 5.1, 1.6}, - {1, 5.4, 3, 4.5, 1.5}, - {1, 6, 3.4, 4.5, 1.6}, - {1, 6.7, 3.1, 4.7, 1.5}, - {1, 6.3, 2.3, 4.4, 1.3}, - {1, 5.6, 3, 4.1, 1.3}, - {1, 5.5, 2.5, 4, 1.3}, - {1, 5.5, 2.6, 4.4, 1.2}, - {1, 6.1, 3, 4.6, 1.4}, - {1, 5.8, 2.6, 4, 1.2}, - {1, 5, 2.3, 3.3, 1}, - {1, 5.6, 2.7, 4.2, 1.3}, - {1, 5.7, 3, 4.2, 1.2}, - {1, 5.7, 2.9, 4.2, 1.3}, - {1, 6.2, 2.9, 4.3, 1.3}, - {1, 5.1, 2.5, 3, 1.1}, - {1, 5.7, 2.8, 4.1, 1.3}, - }; - -} http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java new file mode 100644 index 0000000..215d7a4 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerSample.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.regression.logistic.binary; + +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.thread.IgniteThread; + +import javax.cache.Cache; +import java.util.Arrays; +import java.util.UUID; + +/** + * Run logistic regression model over distributed cache. + * + * @see LogisticRegressionSGDTrainer + */ +public class LogisticRegressionSGDTrainerSample { + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Logistic regression model over partitioned dataset usage example started."); + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LogisticRegressionSGDTrainerSample.class.getSimpleName(), () -> { + + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + System.out.println(">>> Create new logistic regression trainer object."); + LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ), 100000, 10, 100, 123L); + + System.out.println(">>> Perform the training to get the model."); + LogisticRegressionModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ).withRawLabels(true); + + System.out.println(">>> Logistic regression model: " + mdl); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + totalAmount++; + if(groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + } + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } + + + /** The 1st and 2nd classes from the Iris dataset. 
*/ + private static final double[][] data = { + {0, 5.1, 3.5, 1.4, 0.2}, + {0, 4.9, 3, 1.4, 0.2}, + {0, 4.7, 3.2, 1.3, 0.2}, + {0, 4.6, 3.1, 1.5, 0.2}, + {0, 5, 3.6, 1.4, 0.2}, + {0, 5.4, 3.9, 1.7, 0.4}, + {0, 4.6, 3.4, 1.4, 0.3}, + {0, 5, 3.4, 1.5, 0.2}, + {0, 4.4, 2.9, 1.4, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 5.4, 3.7, 1.5, 0.2}, + {0, 4.8, 3.4, 1.6, 0.2}, + {0, 4.8, 3, 1.4, 0.1}, + {0, 4.3, 3, 1.1, 0.1}, + {0, 5.8, 4, 1.2, 0.2}, + {0, 5.7, 4.4, 1.5, 0.4}, + {0, 5.4, 3.9, 1.3, 0.4}, + {0, 5.1, 3.5, 1.4, 0.3}, + {0, 5.7, 3.8, 1.7, 0.3}, + {0, 5.1, 3.8, 1.5, 0.3}, + {0, 5.4, 3.4, 1.7, 0.2}, + {0, 5.1, 3.7, 1.5, 0.4}, + {0, 4.6, 3.6, 1, 0.2}, + {0, 5.1, 3.3, 1.7, 0.5}, + {0, 4.8, 3.4, 1.9, 0.2}, + {0, 5, 3, 1.6, 0.2}, + {0, 5, 3.4, 1.6, 0.4}, + {0, 5.2, 3.5, 1.5, 0.2}, + {0, 5.2, 3.4, 1.4, 0.2}, + {0, 4.7, 3.2, 1.6, 0.2}, + {0, 4.8, 3.1, 1.6, 0.2}, + {0, 5.4, 3.4, 1.5, 0.4}, + {0, 5.2, 4.1, 1.5, 0.1}, + {0, 5.5, 4.2, 1.4, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 5, 3.2, 1.2, 0.2}, + {0, 5.5, 3.5, 1.3, 0.2}, + {0, 4.9, 3.1, 1.5, 0.1}, + {0, 4.4, 3, 1.3, 0.2}, + {0, 5.1, 3.4, 1.5, 0.2}, + {0, 5, 3.5, 1.3, 0.3}, + {0, 4.5, 2.3, 1.3, 0.3}, + {0, 4.4, 3.2, 1.3, 0.2}, + {0, 5, 3.5, 1.6, 0.6}, + {0, 5.1, 3.8, 1.9, 0.4}, + {0, 4.8, 3, 1.4, 0.3}, + {0, 5.1, 3.8, 1.6, 0.2}, + {0, 4.6, 3.2, 1.4, 0.2}, + {0, 5.3, 3.7, 1.5, 0.2}, + {0, 5, 3.3, 1.4, 0.2}, + {1, 7, 3.2, 4.7, 1.4}, + {1, 6.4, 3.2, 4.5, 1.5}, + {1, 6.9, 3.1, 4.9, 1.5}, + {1, 5.5, 2.3, 4, 1.3}, + {1, 6.5, 2.8, 4.6, 1.5}, + {1, 5.7, 2.8, 4.5, 1.3}, + {1, 6.3, 3.3, 4.7, 1.6}, + {1, 4.9, 2.4, 3.3, 1}, + {1, 6.6, 2.9, 4.6, 1.3}, + {1, 5.2, 2.7, 3.9, 1.4}, + {1, 5, 2, 3.5, 1}, + {1, 5.9, 3, 4.2, 1.5}, + {1, 6, 2.2, 4, 1}, + {1, 6.1, 2.9, 4.7, 1.4}, + {1, 5.6, 2.9, 3.6, 1.3}, + {1, 6.7, 3.1, 4.4, 1.4}, + {1, 5.6, 3, 4.5, 1.5}, + {1, 5.8, 2.7, 4.1, 1}, + {1, 6.2, 2.2, 4.5, 1.5}, + {1, 5.6, 2.5, 3.9, 1.1}, + {1, 5.9, 3.2, 4.8, 1.8}, + {1, 6.1, 2.8, 4, 1.3}, + {1, 6.3, 2.5, 4.9, 1.5}, + {1, 6.1, 2.8, 4.7, 1.2}, + {1, 6.4, 
2.9, 4.3, 1.3}, + {1, 6.6, 3, 4.4, 1.4}, + {1, 6.8, 2.8, 4.8, 1.4}, + {1, 6.7, 3, 5, 1.7}, + {1, 6, 2.9, 4.5, 1.5}, + {1, 5.7, 2.6, 3.5, 1}, + {1, 5.5, 2.4, 3.8, 1.1}, + {1, 5.5, 2.4, 3.7, 1}, + {1, 5.8, 2.7, 3.9, 1.2}, + {1, 6, 2.7, 5.1, 1.6}, + {1, 5.4, 3, 4.5, 1.5}, + {1, 6, 3.4, 4.5, 1.6}, + {1, 6.7, 3.1, 4.7, 1.5}, + {1, 6.3, 2.3, 4.4, 1.3}, + {1, 5.6, 3, 4.1, 1.3}, + {1, 5.5, 2.5, 4, 1.3}, + {1, 5.5, 2.6, 4.4, 1.2}, + {1, 6.1, 3, 4.6, 1.4}, + {1, 5.8, 2.6, 4, 1.2}, + {1, 5, 2.3, 3.3, 1}, + {1, 5.6, 2.7, 4.2, 1.3}, + {1, 5.7, 3, 4.2, 1.2}, + {1, 5.7, 2.9, 4.2, 1.3}, + {1, 6.2, 2.9, 4.3, 1.3}, + {1, 5.1, 2.5, 3, 1.1}, + {1, 5.7, 2.8, 4.1, 1.3}, + }; + +} http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java new file mode 100644 index 0000000..6ea42e7 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * ML binary logistic regression examples. + */ +package org.apache.ignite.examples.ml.regression.logistic.binary; http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java new file mode 100644 index 0000000..f089923 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.regression.logistic.multiclass; + +import java.util.Arrays; +import java.util.UUID; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; +import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel; +import org.apache.ignite.thread.IgniteThread; + +/** + * Run Logistic Regression multi-class classification trainer over distributed dataset to build two models: + * one with normalization and one without normalization. + * + * @see SVMLinearMultiClassClassificationModel + */ +public class LogRegressionMultiClassClassificationExample { + /** Run example. */ + public static void main(String[] args) throws InterruptedException { + System.out.println(); + System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example started."); + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + LogRegressionMultiClassClassificationExample.class.getSimpleName(), () -> { + IgniteCache<Integer, double[]> dataCache = getTestCache(ignite); + + LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + )) + .withAmountOfIterations(100000) + .withAmountOfLocIterations(10) + .withBatchSize(100) + .withSeed(123L); + + LogRegressionMultiClassModel mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ); + + System.out.println(">>> SVM Multi-class model"); + System.out.println(mdl.toString()); + + NormalizationTrainer<Integer, double[]> normalizationTrainer = new NormalizationTrainer<>(); + + IgniteBiFunction<Integer, double[], double[]> preprocessor = normalizationTrainer.fit( + ignite, + dataCache, + (k, v) -> Arrays.copyOfRange(v, 1, v.length) + ); + + LogRegressionMultiClassModel mdlWithNormalization = trainer.fit( + ignite, + dataCache, + preprocessor, + (k, v) -> v[0] + ); + + System.out.println(">>> Logistic Regression Multi-class model with normalization"); + System.out.println(mdlWithNormalization.toString()); + + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|"); + System.out.println(">>> ----------------------------------------------------------------"); + + int amountOfErrors = 0; + int amountOfErrorsWithNormalization = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, double[]> observation : observations) { + double[] val = observation.getValue(); + double[] inputs = Arrays.copyOfRange(val, 1, val.length); + double groundTruth = val[0]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + double predictionWithNormalization = mdlWithNormalization.apply(new DenseLocalOnHeapVector(inputs)); + + totalAmount++; + + // Collect data for model + if(groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2); + int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2); + + confusionMtx[idx1][idx2]++; + + // Collect data for model with normalization + if(groundTruth != predictionWithNormalization) + amountOfErrorsWithNormalization++; + + idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2); + idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 
1 : 2); + + confusionMtxWithNormalization[idx1][idx2]++; + + System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth); + } + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println("\n>>> -----------------Logistic Regression model-------------"); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount)); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + + System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------"); + System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization); + System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount)); + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization)); + } + }); + + igniteThread.start(); + igniteThread.join(); + } + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) { + CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TEST_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration); + + for (int i = 0; i < data.length; i++) + cache.put(i, data[i]); + + return cache; + } + + /** The preprocessed Glass dataset from the Machine Learning Repository https://archive.ics.uci.edu/ml/datasets/Glass+Identification + * There are 3 classes with labels: 1 {building_windows_float_processed}, 3 {vehicle_windows_float_processed}, 7 {headlamps}. 
+ * Feature names: 'Na-Sodium', 'Mg-Magnesium', 'Al-Aluminum', 'Ba-Barium', 'Fe-Iron'. + */ + private static final double[][] data = { + {1, 1.52101, 4.49, 1.10, 0.00, 0.00}, + {1, 1.51761, 3.60, 1.36, 0.00, 0.00}, + {1, 1.51618, 3.55, 1.54, 0.00, 0.00}, + {1, 1.51766, 3.69, 1.29, 0.00, 0.00}, + {1, 1.51742, 3.62, 1.24, 0.00, 0.00}, + {1, 1.51596, 3.61, 1.62, 0.00, 0.26}, + {1, 1.51743, 3.60, 1.14, 0.00, 0.00}, + {1, 1.51756, 3.61, 1.05, 0.00, 0.00}, + {1, 1.51918, 3.58, 1.37, 0.00, 0.00}, + {1, 1.51755, 3.60, 1.36, 0.00, 0.11}, + {1, 1.51571, 3.46, 1.56, 0.00, 0.24}, + {1, 1.51763, 3.66, 1.27, 0.00, 0.00}, + {1, 1.51589, 3.43, 1.40, 0.00, 0.24}, + {1, 1.51748, 3.56, 1.27, 0.00, 0.17}, + {1, 1.51763, 3.59, 1.31, 0.00, 0.00}, + {1, 1.51761, 3.54, 1.23, 0.00, 0.00}, + {1, 1.51784, 3.67, 1.16, 0.00, 0.00}, + {1, 1.52196, 3.85, 0.89, 0.00, 0.00}, + {1, 1.51911, 3.73, 1.18, 0.00, 0.00}, + {1, 1.51735, 3.54, 1.69, 0.00, 0.07}, + {1, 1.51750, 3.55, 1.49, 0.00, 0.19}, + {1, 1.51966, 3.75, 0.29, 0.00, 0.00}, + {1, 1.51736, 3.62, 1.29, 0.00, 0.00}, + {1, 1.51751, 3.57, 1.35, 0.00, 0.00}, + {1, 1.51720, 3.50, 1.15, 0.00, 0.00}, + {1, 1.51764, 3.54, 1.21, 0.00, 0.00}, + {1, 1.51793, 3.48, 1.41, 0.00, 0.00}, + {1, 1.51721, 3.48, 1.33, 0.00, 0.00}, + {1, 1.51768, 3.52, 1.43, 0.00, 0.00}, + {1, 1.51784, 3.49, 1.28, 0.00, 0.00}, + {1, 1.51768, 3.56, 1.30, 0.00, 0.14}, + {1, 1.51747, 3.50, 1.14, 0.00, 0.00}, + {1, 1.51775, 3.48, 1.23, 0.09, 0.22}, + {1, 1.51753, 3.47, 1.38, 0.00, 0.06}, + {1, 1.51783, 3.54, 1.34, 0.00, 0.00}, + {1, 1.51567, 3.45, 1.21, 0.00, 0.00}, + {1, 1.51909, 3.53, 1.32, 0.11, 0.00}, + {1, 1.51797, 3.48, 1.35, 0.00, 0.00}, + {1, 1.52213, 3.82, 0.47, 0.00, 0.00}, + {1, 1.52213, 3.82, 0.47, 0.00, 0.00}, + {1, 1.51793, 3.50, 1.12, 0.00, 0.00}, + {1, 1.51755, 3.42, 1.20, 0.00, 0.00}, + {1, 1.51779, 3.39, 1.33, 0.00, 0.00}, + {1, 1.52210, 3.84, 0.72, 0.00, 0.00}, + {1, 1.51786, 3.43, 1.19, 0.00, 0.30}, + {1, 1.51900, 3.48, 1.35, 0.00, 0.00}, + {1, 1.51869, 3.37, 
1.18, 0.00, 0.16}, + {1, 1.52667, 3.70, 0.71, 0.00, 0.10}, + {1, 1.52223, 3.77, 0.79, 0.00, 0.00}, + {1, 1.51898, 3.35, 1.23, 0.00, 0.00}, + {1, 1.52320, 3.72, 0.51, 0.00, 0.16}, + {1, 1.51926, 3.33, 1.28, 0.00, 0.11}, + {1, 1.51808, 2.87, 1.19, 0.00, 0.00}, + {1, 1.51837, 2.84, 1.28, 0.00, 0.00}, + {1, 1.51778, 2.81, 1.29, 0.00, 0.09}, + {1, 1.51769, 2.71, 1.29, 0.00, 0.24}, + {1, 1.51215, 3.47, 1.12, 0.00, 0.31}, + {1, 1.51824, 3.48, 1.29, 0.00, 0.00}, + {1, 1.51754, 3.74, 1.17, 0.00, 0.00}, + {1, 1.51754, 3.66, 1.19, 0.00, 0.11}, + {1, 1.51905, 3.62, 1.11, 0.00, 0.00}, + {1, 1.51977, 3.58, 1.32, 0.69, 0.00}, + {1, 1.52172, 3.86, 0.88, 0.00, 0.11}, + {1, 1.52227, 3.81, 0.78, 0.00, 0.00}, + {1, 1.52172, 3.74, 0.90, 0.00, 0.07}, + {1, 1.52099, 3.59, 1.12, 0.00, 0.00}, + {1, 1.52152, 3.65, 0.87, 0.00, 0.17}, + {1, 1.52152, 3.65, 0.87, 0.00, 0.17}, + {1, 1.52152, 3.58, 0.90, 0.00, 0.16}, + {1, 1.52300, 3.58, 0.82, 0.00, 0.03}, + {3, 1.51769, 3.66, 1.11, 0.00, 0.00}, + {3, 1.51610, 3.53, 1.34, 0.00, 0.00}, + {3, 1.51670, 3.57, 1.38, 0.00, 0.10}, + {3, 1.51643, 3.52, 1.35, 0.00, 0.00}, + {3, 1.51665, 3.45, 1.76, 0.00, 0.17}, + {3, 1.52127, 3.90, 0.83, 0.00, 0.00}, + {3, 1.51779, 3.65, 0.65, 0.00, 0.00}, + {3, 1.51610, 3.40, 1.22, 0.00, 0.00}, + {3, 1.51694, 3.58, 1.31, 0.00, 0.00}, + {3, 1.51646, 3.40, 1.26, 0.00, 0.00}, + {3, 1.51655, 3.39, 1.28, 0.00, 0.00}, + {3, 1.52121, 3.76, 0.58, 0.00, 0.00}, + {3, 1.51776, 3.41, 1.52, 0.00, 0.00}, + {3, 1.51796, 3.36, 1.63, 0.00, 0.09}, + {3, 1.51832, 3.34, 1.54, 0.00, 0.00}, + {3, 1.51934, 3.54, 0.75, 0.15, 0.24}, + {3, 1.52211, 3.78, 0.91, 0.00, 0.37}, + {7, 1.51131, 3.20, 1.81, 1.19, 0.00}, + {7, 1.51838, 3.26, 2.22, 1.63, 0.00}, + {7, 1.52315, 3.34, 1.23, 0.00, 0.00}, + {7, 1.52247, 2.20, 2.06, 0.00, 0.00}, + {7, 1.52365, 1.83, 1.31, 1.68, 0.00}, + {7, 1.51613, 1.78, 1.79, 0.76, 0.00}, + {7, 1.51602, 0.00, 2.38, 0.64, 0.09}, + {7, 1.51623, 0.00, 2.79, 0.40, 0.09}, + {7, 1.51719, 0.00, 2.00, 1.59, 0.08}, + {7, 1.51683, 0.00, 
1.98, 1.57, 0.07}, + {7, 1.51545, 0.00, 2.68, 0.61, 0.05}, + {7, 1.51556, 0.00, 2.54, 0.81, 0.01}, + {7, 1.51727, 0.00, 2.34, 0.66, 0.00}, + {7, 1.51531, 0.00, 2.66, 0.64, 0.00}, + {7, 1.51609, 0.00, 2.51, 0.53, 0.00}, + {7, 1.51508, 0.00, 2.25, 0.63, 0.00}, + {7, 1.51653, 0.00, 1.19, 0.00, 0.00}, + {7, 1.51514, 0.00, 2.42, 0.56, 0.00}, + {7, 1.51658, 0.00, 1.99, 1.71, 0.00}, + {7, 1.51617, 0.00, 2.27, 0.67, 0.00}, + {7, 1.51732, 0.00, 1.80, 1.55, 0.00}, + {7, 1.51645, 0.00, 1.87, 1.38, 0.00}, + {7, 1.51831, 0.00, 1.82, 2.88, 0.00}, + {7, 1.51640, 0.00, 2.74, 0.54, 0.00}, + {7, 1.51623, 0.00, 2.88, 1.06, 0.00}, + {7, 1.51685, 0.00, 1.99, 1.59, 0.00}, + {7, 1.52065, 0.00, 2.02, 1.64, 0.00}, + {7, 1.51651, 0.00, 1.94, 1.57, 0.00}, + {7, 1.51711, 0.00, 2.08, 1.67, 0.00}, + }; +} http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java new file mode 100644 index 0000000..c7b7fe8 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * ML multi-class logistic regression examples. + */ +package org.apache.ignite.examples.ml.regression.logistic.multiclass; http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java new file mode 100644 index 0000000..0817432 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.logistic.multiclass; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; + +/** Base class for multi-classification model for set of Logistic Regression classifiers. */ +public class LogRegressionMultiClassModel implements Model<Vector, Double>, Exportable<LogRegressionMultiClassModel>, Serializable { + /** */ + private static final long serialVersionUID = -114986533350117L; + + /** List of models associated with each class. */ + private Map<Double, LogisticRegressionModel> models; + + /** */ + public LogRegressionMultiClassModel() { + this.models = new HashMap<>(); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector input) { + TreeMap<Double, Double> maxMargins = new TreeMap<>(); + + models.forEach((k, v) -> maxMargins.put(1.0 / (1.0 + Math.exp(-(input.dot(v.weights()) + v.intercept()))), k)); + + return maxMargins.lastEntry().getValue(); + } + + /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<LogRegressionMultiClassModel, P> exporter, P path) { + exporter.save(this, path); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + + if (o == null || getClass() != o.getClass()) + return false; + + LogRegressionMultiClassModel mdl = (LogRegressionMultiClassModel)o; + + return Objects.equals(models, mdl.models); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return Objects.hash(models); + } + + /** {@inheritDoc} */ + @Override public String toString() { + StringBuilder wholeStr = new 
StringBuilder(); + + models.forEach((clsLb, mdl) -> { + wholeStr.append("The class with label ").append(clsLb).append(" has classifier: ").append(mdl.toString()).append(System.lineSeparator()); + }); + + return wholeStr.toString(); + } + + /** + * Adds a specific Log Regression binary classifier to the bunch of same classifiers. + * + * @param clsLb The class label for the added model. + * @param mdl The model. + */ + public void add(double clsLb, LogisticRegressionModel mdl) { + models.put(clsLb, mdl); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java new file mode 100644 index 0000000..e8ed67b --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions.logistic.multiclass; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.PartitionDataBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; +import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; + +/** + * All common parameters are shared with bunch of binary classification trainers. + */ +public class LogRegressionMultiClassTrainer<P extends Serializable> + implements SingleLabelDatasetTrainer<LogRegressionMultiClassModel> { + /** Update strategy. */ + private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; + + /** Max number of iteration. */ + private int amountOfIterations; + + /** Batch size. */ + private int batchSize; + + /** Number of local iterations. */ + private int amountOfLocIterations; + + /** Seed for random generator. */ + private long seed; + + /** + * Trains model based on the specified data. + * + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @return Model. 
+ */ + @Override public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, double[]> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); + + LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel(); + + classes.forEach(clsLb -> { + LogisticRegressionSGDTrainer<?> trainer = + new LogisticRegressionSGDTrainer<>(updatesStgy, amountOfIterations, batchSize, amountOfLocIterations, seed); + + IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> { + Double lb = lbExtractor.apply(k, v); + + if (lb.equals(clsLb)) + return 1.0; + else + return 0.0; + }; + multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer)); + }); + + return multiClsMdl; + } + + /** Iterates among dataset and collects class labels. */ + private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) { + assert datasetBuilder != null; + + PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); + + List<Double> res = new ArrayList<>(); + + try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), + partDataBuilder + )) { + final Set<Double> clsLabels = dataset.compute(data -> { + final Set<Double> locClsLabels = new HashSet<>(); + + final double[] lbs = data.getY(); + + for (double lb : lbs) locClsLabels.add(lb); + + return locClsLabels; + }, (a, b) -> a == null ? b : Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet())); + + res.addAll(clsLabels); + + } catch (Exception e) { + throw new RuntimeException(e); + } + return res; + } + + /** + * Set up the batch size. + * + * @param batchSize The size of learning batch.
+ * @return Trainer with new batch size parameter value. + */ + public LogRegressionMultiClassTrainer withBatchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Gets the batch size. + * + * @return The parameter value. + */ + public double batchSize() { + return batchSize; + } + + /** + * Gets the amount of outer iterations of SGD algorithm. + * + * @return The parameter value. + */ + public int amountOfIterations() { + return amountOfIterations; + } + + /** + * Set up the amount of outer iterations. + * + * @param amountOfIterations The parameter value. + * @return Trainer with new amountOfIterations parameter value. + */ + public LogRegressionMultiClassTrainer withAmountOfIterations(int amountOfIterations) { + this.amountOfIterations = amountOfIterations; + return this; + } + + /** + * Gets the amount of local iterations. + * + * @return The parameter value. + */ + public int amountOfLocIterations() { + return amountOfLocIterations; + } + + /** + * Set up the amount of local iterations of SGD algorithm. + * + * @param amountOfLocIterations The parameter value. + * @return Trainer with new amountOfLocIterations parameter value. + */ + public LogRegressionMultiClassTrainer withAmountOfLocIterations(int amountOfLocIterations) { + this.amountOfLocIterations = amountOfLocIterations; + return this; + } + + /** + * Set up the regularization parameter. + * + * @param seed Seed for random generator. + * @return Trainer with new seed parameter value. + */ + public LogRegressionMultiClassTrainer withSeed(long seed) { + this.seed = seed; + return this; + } + + /** + * Gets the seed for random generator. + * + * @return The parameter value. + */ + public long seed() { + return seed; + } + + /** + * Set up the regularization parameter. + * + * @param updatesStgy Update strategy. + * @return Trainer with new update strategy parameter value. 
+ */ + public LogRegressionMultiClassTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) { + this.updatesStgy = updatesStgy; + return this; + } + + /** + * Gets the update strategy. + * + * @return The parameter value. + */ + public UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy() { + return updatesStgy; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java new file mode 100644 index 0000000..2e7b947 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains multi-class logistic regression.
+ */ +package org.apache.ignite.ml.regressions.logistic.multiclass; http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java index 2d21d3b..021b567 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/RegressionsTestSuite.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml.regressions; import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainerTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionModelTest; import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainerTest; +import org.apache.ignite.ml.regressions.logistic.LogRegMultiClassTrainerTest; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModelTest; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainerTest; import org.junit.runner.RunWith; @@ -34,8 +35,9 @@ import org.junit.runners.Suite; LinearRegressionLSQRTrainerTest.class, LinearRegressionSGDTrainerTest.class, LogisticRegressionModelTest.class, - LogisticRegressionSGDTrainerTest.class + LogisticRegressionSGDTrainerTest.class, + LogRegMultiClassTrainerTest.class }) public class RegressionsTestSuite { // No-op. 
-} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java index aac24f4..7ca9121 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModelTest.java @@ -21,6 +21,8 @@ import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.exceptions.CardinalityException; import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; import org.junit.Test; /** @@ -53,6 +55,21 @@ public class LinearRegressionModelTest { } /** */ + @Test + public void testPredictWithMultiClasses() { + Vector weights1 = new DenseLocalOnHeapVector(new double[]{10.0, 0.0}); + Vector weights2 = new DenseLocalOnHeapVector(new double[]{0.0, 10.0}); + Vector weights3 = new DenseLocalOnHeapVector(new double[]{-1.0, -1.0}); + LogRegressionMultiClassModel mdl = new LogRegressionMultiClassModel(); + mdl.add(1, new LogisticRegressionModel(weights1, 0.0).withRawLabels(true)); + mdl.add(2, new LogisticRegressionModel(weights2, 0.0).withRawLabels(true)); + mdl.add(2, new LogisticRegressionModel(weights3, 0.0).withRawLabels(true)); + + Vector observation = new DenseLocalOnHeapVector(new double[]{1.0, 1.0}); + TestUtils.assertEquals( 1.0, mdl.apply(observation), PRECISION); + } + + /** */ @Test(expected = CardinalityException.class) public 
void testPredictOnAnObservationWithWrongCardinality() { Vector weights = new DenseLocalOnHeapVector(new double[]{2.0, 3.0}); http://git-wip-us.apache.org/repos/asf/ignite/blob/cb8fb736/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java new file mode 100644 index 0000000..d26a4ca --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.regressions.logistic; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.SmoothParametrized; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel; +import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer; +import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for {@link LogRegressionMultiClassTrainer}. + */ +public class LogRegMultiClassTrainerTest { + /** Fixed size of Dataset. */ + private static final int AMOUNT_OF_OBSERVATIONS = 1000; + + /** Fixed size of columns in Dataset. */ + private static final int AMOUNT_OF_FEATURES = 2; + + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + + /** + * Test trainer on classification model y = x. + */ + @Test + public void testTrainWithTheLinearlySeparableCase() { + Map<Integer, double[]> data = new HashMap<>(); + + ThreadLocalRandom rndX = ThreadLocalRandom.current(); + ThreadLocalRandom rndY = ThreadLocalRandom.current(); + + for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) { + double x = rndX.nextDouble(-1000, 1000); + double y = rndY.nextDouble(-1000, 1000); + double[] vec = new double[AMOUNT_OF_FEATURES + 1]; + vec[0] = y - x > 0 ? 1 : -1; // assign label.
+ vec[1] = x; + vec[2] = y; + data.put(i, vec); + } + + final UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> stgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>() + .withUpdatesStgy(stgy) + .withAmountOfIterations(1000) + .withAmountOfLocIterations(10) + .withBatchSize(100) + .withSeed(123L); + + Assert.assertEquals(trainer.amountOfIterations(), 1000); + Assert.assertEquals(trainer.amountOfLocIterations(), 10); + Assert.assertEquals(trainer.batchSize(), 100, PRECISION); + Assert.assertEquals(trainer.seed(), 123L); + Assert.assertEquals(trainer.updatesStgy(), stgy); + + LogRegressionMultiClassModel mdl = trainer.fit( + data, + 10, + (k, v) -> Arrays.copyOfRange(v, 1, v.length), + (k, v) -> v[0] + ); + + TestUtils.assertEquals(-1, mdl.apply(new DenseLocalOnHeapVector(new double[]{100, 10})), PRECISION); + TestUtils.assertEquals(1, mdl.apply(new DenseLocalOnHeapVector(new double[]{10, 100})), PRECISION); + } +}