IGNITE-8741: [ML] Make a tutorial for data preprocessing this closes #4254
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/d66ccb4a Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/d66ccb4a Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/d66ccb4a Branch: refs/heads/ignite-8446 Commit: d66ccb4aac442bfcc580ad10b1d55fc9308b1530 Parents: a1dcb54 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Tue Jun 26 20:28:19 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Tue Jun 26 20:28:19 2018 +0300 ---------------------------------------------------------------------- .../ml/preprocessing/ImputingExample.java | 2 +- ...ecisionTreeClassificationTrainerExample.java | 6 +- .../DecisionTreeRegressionTrainerExample.java | 3 +- .../RandomForestClassificationExample.java | 3 +- .../RandomForestRegressionExample.java | 3 +- .../ml/tutorial/Step_1_Read_and_Learn.java | 108 ++ .../examples/ml/tutorial/Step_2_Imputing.java | 115 ++ .../examples/ml/tutorial/Step_3_Categorial.java | 131 ++ .../ml/tutorial/Step_4_Add_age_fare.java | 131 ++ .../examples/ml/tutorial/Step_5_Scaling.java | 150 ++ .../ignite/examples/ml/tutorial/Step_6_KNN.java | 150 ++ .../ml/tutorial/Step_7_Split_train_test.java | 160 +++ .../ignite/examples/ml/tutorial/Step_8_CV.java | 218 +++ .../ml/tutorial/Step_9_Change_algorithm.java | 251 ++++ .../examples/ml/tutorial/TitanicUtils.java | 91 ++ .../examples/ml/tutorial/package-info.java | 22 + .../src/main/resources/datasets/titanic.csv | 1310 ++++++++++++++++++ .../main/resources/datasets/titanic_10_rows.csv | 11 + .../ml/composition/BaggingModelTrainer.java | 5 +- .../ml/composition/ModelsComposition.java | 24 +- .../StringEncoderPreprocessor.java | 28 +- .../stringencoder/StringEncoderTrainer.java | 87 +- .../binomial/LogisticRegressionModel.java | 6 +- .../cv/CrossValidationScoreCalculator.java | 4 +- .../CacheBasedTruthWithPredictionCursor.java | 8 +- .../util/LocalTruthWithPredictionCursor.java | 10 +- 
.../ml/tree/DecisionTreeConditionalNode.java | 6 +- .../ignite/ml/tree/DecisionTreeLeafNode.java | 4 +- .../apache/ignite/ml/tree/DecisionTreeNode.java | 3 +- .../ml/tree/data/DecisionTreeDataBuilder.java | 1 + .../encoding/StringEncoderPreprocessorTest.java | 11 +- .../encoding/StringEncoderTrainerTest.java | 4 +- ...CacheBasedTruthWithPredictionCursorTest.java | 2 +- .../LocalTruthWithPredictionCursorTest.java | 2 +- .../DecisionTreeMNISTIntegrationTest.java | 6 +- .../tree/performance/DecisionTreeMNISTTest.java | 10 +- .../RandomForestClassifierTrainerTest.java | 2 +- .../RandomForestRegressionTrainerTest.java | 2 +- 38 files changed, 3011 insertions(+), 79 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java index f873736..68483ad 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java @@ -46,7 +46,7 @@ public class ImputingExample { v.getSalary() }; - // Defines second preprocessor that normalizes features. + // Defines second preprocessor that imputing features. 
IgniteBiFunction<Integer, Person, double[]> preprocessor = new ImputerTrainer<Integer, Person>() .fit(ignite, persons, featureExtractor); http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java index 1ecf460..ca70b29 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java @@ -17,17 +17,17 @@ package org.apache.ignite.examples.ml.tree; +import java.util.Random; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; import org.apache.ignite.thread.IgniteThread; -import java.util.Random; - /** * Example of using distributed {@link DecisionTreeClassificationTrainer}. 
*/ @@ -76,7 +76,7 @@ public class DecisionTreeClassificationTrainerExample { for (int i = 0; i < 1000; i++) { LabeledPoint pnt = generatePoint(rnd); - double prediction = mdl.apply(new double[]{pnt.x, pnt.y}); + double prediction = mdl.apply(new DenseLocalOnHeapVector(new double[]{pnt.x, pnt.y})); if (prediction == pnt.lb) correctPredictions++; http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java index 19b15f3..cefeee3 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java @@ -22,6 +22,7 @@ import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.tree.DecisionTreeNode; import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer; import org.apache.ignite.thread.IgniteThread; @@ -74,7 +75,7 @@ public class DecisionTreeRegressionTrainerExample { // Calculate score. 
for (int x = 0; x < 10; x++) { - double predicted = mdl.apply(new double[] {x}); + double predicted = mdl.apply(new DenseLocalOnHeapVector(new double[] {x})); System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x)); } http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java index 0c012dc..e15b311 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java @@ -31,6 +31,7 @@ import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.tree.randomforest.RandomForestClassifierTrainer; import org.apache.ignite.ml.tree.randomforest.RandomForestTrainer; import org.apache.ignite.thread.IgniteThread; @@ -77,7 +78,7 @@ public class RandomForestClassificationExample { double[] inputs = Arrays.copyOfRange(val, 1, val.length); double groundTruth = val[0]; - double prediction = randomForest.apply(inputs); + double prediction = randomForest.apply(new DenseLocalOnHeapVector(inputs)); totalAmount++; if (groundTruth != prediction) http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java 
---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java index 6019cdb..ca330b8 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java @@ -31,6 +31,7 @@ import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.ScanQuery; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer; import org.apache.ignite.ml.tree.randomforest.RandomForestTrainer; import org.apache.ignite.thread.IgniteThread; @@ -79,7 +80,7 @@ public class RandomForestRegressionExample { double[] inputs = Arrays.copyOfRange(val, 0, val.length - 1); double groundTruth = val[val.length - 1]; - double prediction = randomForest.apply(inputs); + double prediction = randomForest.apply(new DenseLocalOnHeapVector(inputs)); mse += Math.pow(prediction - groundTruth, 2.0); mae += Math.abs(prediction - groundTruth); http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java new file mode 100644 index 0000000..1b8d2a8 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java 
@@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Usage of DecisionTreeClassificationTrainer to predict death in the disaster. + * + * Extract 3 features "pclass", "sibsp", "parch" to use in prediction. + */ +public class Step_1_Read_and_Learn { + /** Run example. 
*/ + public static void main(String[] args) throws InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + Step_1_Read_and_Learn.class.getSimpleName(), () -> { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0); + + // Train decision tree model. + DecisionTreeNode mdl = trainer.fit( + ignite, + dataCache, + (k, v) -> new double[]{(double)v[0], (double)v[5], (double)v[6]}, // "pclass", "sibsp", "parch" + (k, v) -> (double)v[1] + ); + + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|"); + System.out.println(">>> ----------------------------------------------------------------"); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Object[]> observation : observations) { + Object[] val = observation.getValue(); + double[] inputs = new double[]{(double)val[0], (double)val[5], (double)val[6]}; + double groundTruth = (double)val[1]; + String name = (String)val[2]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(inputs)); + + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + double accuracy = 1 - amountOfErrors / (double)totalAmount; + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + } + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + }); + + igniteThread.start(); + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java new file mode 100644 index 0000000..495658d --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java @@ -0,0 +1,115 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Usage of imputer to fill missed data (Double.NaN) values in the chosen columns. + */ +public class Step_2_Imputing { + /** Run example. 
*/ + public static void main(String[] args) throws InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + Step_2_Imputing.class.getSimpleName(), () -> { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + (k, v) -> new double[]{(double)v[0], (double)v[5], (double)v[6]} // "pclass", "sibsp", "parch" + ); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0); + + // Train decision tree model. + DecisionTreeNode mdl = trainer.fit( + ignite, + dataCache, + imputingPreprocessor, + (k, v) -> (double)v[1] + ); + + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|"); + System.out.println(">>> ----------------------------------------------------------------"); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Object[]> observation : observations) { + + Object[] val = observation.getValue(); + double groundTruth = (double)val[1]; + String name = (String)val[2]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(imputingPreprocessor.apply(observation.getKey(), val))); + + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + double accuracy = 1 - amountOfErrors / (double)totalAmount; + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + } + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java new file mode 100644 index 0000000..3284e94 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Let's add two categorial features "sex", "embarked" to predict more precisely. + * + * To encode categorial features the StringEncoderTrainer will be used. + */ +public class Step_3_Categorial { + /** Run example. 
*/ + public static void main(String[] args) throws InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + Step_3_Categorial.class.getSimpleName(), () -> { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + // Defines first preprocessor that extracts features from an upstream data. + IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[]{v[0], v[3], v[5], v[6], v[10]}; // "pclass", "sibsp", "parch", "sex", "embarked" + + IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + .encodeFeature(1) + .encodeFeature(4) + .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + strEncoderPreprocessor + ); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0); + + // Train decision tree model. + DecisionTreeNode mdl = trainer.fit( + ignite, + dataCache, + imputingPreprocessor, + (k, v) -> (double)v[1] + ); + + + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|"); + System.out.println(">>> ----------------------------------------------------------------"); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Object[]> observation : observations) { + + Object[] val = observation.getValue(); + double groundTruth = (double)val[1]; + String name = (String)val[2]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(imputingPreprocessor.apply(observation.getKey(), val))); + + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + double accuracy = 1 - amountOfErrors / (double)totalAmount; + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + } + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java new file mode 100644 index 0000000..22d316b --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java @@ -0,0 +1,131 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import javax.cache.Cache; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Add yet two numerical features "age", "fare" to improve our model. + */ +public class Step_4_Add_age_fare { + /** Run example. 
*/ + public static void main(String[] args) throws InterruptedException { + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + Step_4_Add_age_fare.class.getSimpleName(), () -> { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + // Defines first preprocessor that extracts features from an upstream data. + // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare" + IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]}; + + + IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>() + .encodeFeature(1) + .encodeFeature(6) // <--- Changed index here + .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + strEncoderPreprocessor + ); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0); + + // Train decision tree model. + DecisionTreeNode mdl = trainer.fit( + ignite, + dataCache, + imputingPreprocessor, + (k, v) -> (double)v[1] + ); + + + System.out.println(">>> ----------------------------------------------------------------"); + System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|"); + System.out.println(">>> ----------------------------------------------------------------"); + + int amountOfErrors = 0; + int totalAmount = 0; + + // Build confusion matrix. 
See https://en.wikipedia.org/wiki/Confusion_matrix + int[][] confusionMtx = {{0, 0}, {0, 0}}; + + try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) { + for (Cache.Entry<Integer, Object[]> observation : observations) { + + Object[] val = observation.getValue(); + double groundTruth = (double)val[1]; + String name = (String)val[2]; + + double prediction = mdl.apply(new DenseLocalOnHeapVector(imputingPreprocessor.apply(observation.getKey(), val))); + + totalAmount++; + if (groundTruth != prediction) + amountOfErrors++; + + int idx1 = (int)prediction; + int idx2 = (int)groundTruth; + + confusionMtx[idx1][idx2]++; + + System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name); + } + + System.out.println(">>> ---------------------------------"); + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + double accuracy = 1 - amountOfErrors / (double)totalAmount; + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + System.out.println(">>> ---------------------------------"); + } + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + }); + + igniteThread.start(); + + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/d66ccb4a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java new file mode 100644 index 0000000..290b73a --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.tutorial;

import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.thread.IgniteThread;

/**
 * Tutorial step 5: add feature scaling to the Titanic pipeline.
 *
 * MinMaxScalerTrainer and NormalizationTrainer are used in this example due to
 * different values distribution in columns and rows: min-max scaling equalizes
 * the per-column ranges, normalization equalizes the per-row vector norms.
 */
public class Step_5_Scaling {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            // ML trainers must run on an Ignite thread to access the cluster.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                Step_5_Scaling.class.getSimpleName(), () -> {
                try {
                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);

                    // Defines first preprocessor that extracts features from an upstream data.
                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};

                    // Encodes the string-valued columns ("sex" at index 1, "embarked" at index 6) as doubles.
                    IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
                        .encodeFeature(1)
                        .encodeFeature(6) // <--- Changed index here
                        .fit(ignite,
                            dataCache,
                            featureExtractor
                        );

                    // Replaces missing feature values with per-column statistics learned from the data.
                    IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            strEncoderPreprocessor
                        );

                    // Rescales each feature column into a common range (column-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                        .fit(
                            ignite,
                            dataCache,
                            imputingPreprocessor
                        );

                    // Normalizes each feature vector by its p-norm with p = 1 (row-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                        .withP(1)
                        .fit(
                            ignite,
                            dataCache,
                            minMaxScalerPreprocessor
                        );

                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);

                    // Train decision tree model; the label is the "survived" column (index 1).
                    DecisionTreeNode mdl = trainer.fit(
                        ignite,
                        dataCache,
                        normalizationPreprocessor,
                        (k, v) -> (double)v[1]
                    );

                    System.out.println(">>> ----------------------------------------------------------------");
                    System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|");
                    System.out.println(">>> ----------------------------------------------------------------");

                    int amountOfErrors = 0;
                    int totalAmount = 0;

                    // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                    int[][] confusionMtx = {{0, 0}, {0, 0}};

                    try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
                        for (Cache.Entry<Integer, Object[]> observation : observations) {

                            Object[] val = observation.getValue();
                            double groundTruth = (double)val[1];
                            String name = (String)val[2];

                            // Inference reuses the same preprocessing chain that was used for training.
                            double prediction = mdl.apply(new DenseLocalOnHeapVector(
                                normalizationPreprocessor.apply(observation.getKey(), val)));

                            totalAmount++;
                            if (groundTruth != prediction)
                                amountOfErrors++;

                            int idx1 = (int)prediction;
                            int idx2 = (int)groundTruth;

                            confusionMtx[idx1][idx2]++;

                            System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name);
                        }

                        System.out.println(">>> ---------------------------------");

                        System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                        double accuracy = 1 - amountOfErrors / (double)totalAmount;
                        System.out.println("\n>>> Accuracy " + accuracy);
                        System.out.println("\n>>> Test Error " + (1 - accuracy));

                        System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
                        System.out.println(">>> ---------------------------------");
                    }
                }
                catch (FileNotFoundException e) {
                    e.printStackTrace();
                }
            });

            igniteThread.start();

            igniteThread.join();
        }
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.tutorial;

import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.thread.IgniteThread;

/**
 * Tutorial step 6: sometimes it is better to change the algorithm — let's try kNN
 * on the same preprocessing pipeline as the previous step.
 */
public class Step_6_KNN {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            // ML trainers must run on an Ignite thread to access the cluster.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                Step_6_KNN.class.getSimpleName(), () -> {
                try {
                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);

                    // Defines first preprocessor that extracts features from an upstream data.
                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};

                    // Encodes the string-valued columns ("sex" at index 1, "embarked" at index 6) as doubles.
                    IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
                        .encodeFeature(1)
                        .encodeFeature(6) // <--- Changed index here
                        .fit(ignite,
                            dataCache,
                            featureExtractor
                        );

                    // Replaces missing feature values with per-column statistics learned from the data.
                    IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            strEncoderPreprocessor
                        );

                    // Rescales each feature column into a common range (column-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                        .fit(
                            ignite,
                            dataCache,
                            imputingPreprocessor
                        );

                    // Normalizes each feature vector by its p-norm with p = 1 (row-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                        .withP(1)
                        .fit(
                            ignite,
                            dataCache,
                            minMaxScalerPreprocessor
                        );

                    KNNClassificationTrainer trainer = new KNNClassificationTrainer();

                    // Train the kNN model (k = 1, weighted voting); the label is the "survived" column (index 1).
                    KNNClassificationModel mdl = trainer.fit(
                        ignite,
                        dataCache,
                        normalizationPreprocessor,
                        (k, v) -> (double)v[1]
                    ).withK(1).withStrategy(KNNStrategy.WEIGHTED);

                    System.out.println(">>> ----------------------------------------------------------------");
                    System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|");
                    System.out.println(">>> ----------------------------------------------------------------");

                    int amountOfErrors = 0;
                    int totalAmount = 0;

                    // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                    int[][] confusionMtx = {{0, 0}, {0, 0}};

                    try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
                        for (Cache.Entry<Integer, Object[]> observation : observations) {

                            Object[] val = observation.getValue();
                            double groundTruth = (double)val[1];
                            String name = (String)val[2];

                            // Inference reuses the same preprocessing chain that was used for training.
                            double prediction = mdl.apply(new DenseLocalOnHeapVector(normalizationPreprocessor.apply(observation.getKey(), val)));

                            totalAmount++;
                            if (groundTruth != prediction)
                                amountOfErrors++;

                            int idx1 = (int)prediction;
                            int idx2 = (int)groundTruth;

                            confusionMtx[idx1][idx2]++;

                            System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name);
                        }

                        System.out.println(">>> ---------------------------------");

                        System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                        double accuracy = 1 - amountOfErrors / (double)totalAmount;
                        System.out.println("\n>>> Accuracy " + accuracy);
                        System.out.println("\n>>> Test Error " + (1 - accuracy));

                        System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
                        System.out.println(">>> ---------------------------------");
                    }
                }
                catch (FileNotFoundException e) {
                    e.printStackTrace();
                }
            });

            igniteThread.start();

            igniteThread.join();
        }
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.tutorial;

import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.thread.IgniteThread;

/**
 * Tutorial step 7: the highest accuracy in the previous example is the result of overfitting.
 *
 * For real model estimation it is better to use a train-test split via {@link TrainTestDatasetSplitter}:
 * the model is trained on 75% of the data and evaluated on the held-out 25%.
 */
public class Step_7_Split_train_test {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            // ML trainers must run on an Ignite thread to access the cluster.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                Step_7_Split_train_test.class.getSimpleName(), () -> {
                try {
                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);

                    // Defines first preprocessor that extracts features from an upstream data.
                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};

                    // 75% of rows go to the train filter, the remaining 25% to the test filter.
                    TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
                        .split(0.75);

                    // Encodes the string-valued columns ("sex" at index 1, "embarked" at index 6) as doubles.
                    IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
                        .encodeFeature(1)
                        .encodeFeature(6) // <--- Changed index here
                        .fit(ignite,
                            dataCache,
                            featureExtractor
                        );

                    // Replaces missing feature values with per-column statistics learned from the data.
                    IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            strEncoderPreprocessor
                        );

                    // Rescales each feature column into a common range (column-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                        .fit(
                            ignite,
                            dataCache,
                            imputingPreprocessor
                        );

                    // Normalizes each feature vector by its p-norm with p = 1 (row-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                        .withP(1)
                        .fit(
                            ignite,
                            dataCache,
                            minMaxScalerPreprocessor
                        );

                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);

                    // Train decision tree model on the training part only; label is the "survived" column (index 1).
                    DecisionTreeNode mdl = trainer.fit(
                        ignite,
                        dataCache,
                        split.getTrainFilter(),
                        normalizationPreprocessor,
                        (k, v) -> (double)v[1]
                    );

                    System.out.println(">>> ----------------------------------------------------------------");
                    System.out.println(">>> | Prediction\t| Ground Truth\t| Name\t|");
                    System.out.println(">>> ----------------------------------------------------------------");

                    int amountOfErrors = 0;
                    int totalAmount = 0;

                    // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                    int[][] confusionMtx = {{0, 0}, {0, 0}};

                    // Evaluate on the held-out test part only.
                    ScanQuery<Integer, Object[]> qry = new ScanQuery<>();
                    qry.setFilter(split.getTestFilter());

                    try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(qry)) {
                        for (Cache.Entry<Integer, Object[]> observation : observations) {

                            Object[] val = observation.getValue();
                            double groundTruth = (double)val[1];
                            String name = (String)val[2];

                            // Inference reuses the same preprocessing chain that was used for training.
                            double prediction = mdl.apply(new DenseLocalOnHeapVector(
                                normalizationPreprocessor.apply(observation.getKey(), val)));

                            totalAmount++;
                            if (groundTruth != prediction)
                                amountOfErrors++;

                            int idx1 = (int)prediction;
                            int idx2 = (int)groundTruth;

                            confusionMtx[idx1][idx2]++;

                            System.out.printf(">>>| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name);
                        }

                        System.out.println(">>> ---------------------------------");

                        System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                        double accuracy = 1 - amountOfErrors / (double)totalAmount;
                        System.out.println("\n>>> Accuracy " + accuracy);
                        System.out.println("\n>>> Test Error " + (1 - accuracy));

                        System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
                        System.out.println(">>> ---------------------------------");
                    }
                }
                catch (FileNotFoundException e) {
                    e.printStackTrace();
                }
            });

            igniteThread.start();

            igniteThread.join();
        }
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.tutorial;

import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidationScoreCalculator;
import org.apache.ignite.ml.selection.score.AccuracyScoreCalculator;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.thread.IgniteThread;

/**
 * Tutorial step 8: to choose the best hyperparameters, k-fold cross-validation is used in this example.
 *
 * The purpose of cross-validation is model checking, not model building.
 *
 * We train k different models. They differ in that 1/(k-1)th of the training data is exchanged
 * against other cases. These models are sometimes called surrogate models because the (average)
 * performance measured for these models is taken as a surrogate of the performance of the model
 * trained on all cases.
 *
 * All scenarios are described there: https://sebastianraschka.com/faq/docs/evaluate-a-model.html
 */
public class Step_8_CV {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            // ML trainers must run on an Ignite thread to access the cluster.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                Step_8_CV.class.getSimpleName(), () -> {
                try {
                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);

                    // Defines first preprocessor that extracts features from an upstream data.
                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};

                    // 75% of rows go to the train filter, the remaining 25% to the test filter.
                    TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
                        .split(0.75);

                    // Encodes the string-valued columns ("sex" at index 1, "embarked" at index 6) as doubles.
                    IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
                        .encodeFeature(1)
                        .encodeFeature(6) // <--- Changed index here
                        .fit(ignite,
                            dataCache,
                            featureExtractor
                        );

                    // Replaces missing feature values with per-column statistics learned from the data.
                    IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            strEncoderPreprocessor
                        );

                    // Rescales each feature column into a common range (column-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                        .fit(
                            ignite,
                            dataCache,
                            imputingPreprocessor
                        );

                    // Tune hyperparameters with k-fold cross-validation on the split training set.
                    int[] pSet = new int[]{1, 2};
                    int[] maxDeepSet = new int[]{1, 2, 3, 4, 5, 10, 20};
                    int bestP = 1;
                    int bestMaxDeep = 1;

                    // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest
                    // *positive* double, so a fold average of exactly 0.0 could never win.
                    double avg = Double.NEGATIVE_INFINITY;

                    for (int p : pSet) {
                        for (int maxDeep : maxDeepSet) {
                            IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                                .withP(p)
                                .fit(
                                    ignite,
                                    dataCache,
                                    minMaxScalerPreprocessor
                                );

                            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(maxDeep, 0);

                            CrossValidationScoreCalculator<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator
                                = new CrossValidationScoreCalculator<>();

                            // 3-fold CV over the training part only; the test part stays untouched.
                            double[] scores = scoreCalculator.score(
                                trainer,
                                new AccuracyScoreCalculator<>(),
                                ignite,
                                dataCache,
                                split.getTrainFilter(),
                                normalizationPreprocessor,
                                (k, v) -> (double)v[1],
                                3
                            );

                            System.out.println("Scores are: " + Arrays.toString(scores));

                            final double currAvg = Arrays.stream(scores).average().orElse(Double.NEGATIVE_INFINITY);

                            if (currAvg > avg) {
                                avg = currAvg;
                                bestP = p;
                                bestMaxDeep = maxDeep;
                            }

                            System.out.println("Avg is: " + currAvg + " with p: " + p + " with maxDeep: " + maxDeep);
                        }
                    }

                    System.out.println("Train with p: " + bestP + " and maxDeep: " + bestMaxDeep);

                    // Refit the normalization preprocessor with the winning p.
                    IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                        .withP(bestP)
                        .fit(
                            ignite,
                            dataCache,
                            minMaxScalerPreprocessor
                        );

                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(bestMaxDeep, 0);

                    // Train decision tree model with the best hyperparameters on the full training part.
                    DecisionTreeNode bestMdl = trainer.fit(
                        ignite,
                        dataCache,
                        split.getTrainFilter(),
                        normalizationPreprocessor,
                        (k, v) -> (double)v[1]
                    );

                    System.out.println("----------------------------------------------------------------");
                    System.out.println("| Prediction\t| Ground Truth\t| Name\t|");
                    System.out.println("----------------------------------------------------------------");

                    int amountOfErrors = 0;
                    int totalAmount = 0;

                    // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                    int[][] confusionMtx = {{0, 0}, {0, 0}};

                    // Evaluate on the held-out test part only.
                    ScanQuery<Integer, Object[]> qry = new ScanQuery<>();
                    qry.setFilter(split.getTestFilter());

                    try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(qry)) {
                        for (Cache.Entry<Integer, Object[]> observation : observations) {

                            Object[] val = observation.getValue();
                            double groundTruth = (double)val[1];
                            String name = (String)val[2];

                            // Inference reuses the same preprocessing chain that was used for training.
                            double prediction = bestMdl.apply(new DenseLocalOnHeapVector(
                                normalizationPreprocessor.apply(observation.getKey(), val)));

                            totalAmount++;
                            if (groundTruth != prediction)
                                amountOfErrors++;

                            int idx1 = (int)prediction;
                            int idx2 = (int)groundTruth;

                            confusionMtx[idx1][idx2]++;

                            System.out.printf("| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name);
                        }

                        System.out.println("---------------------------------");

                        System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                        double accuracy = 1 - amountOfErrors / (double)totalAmount;
                        System.out.println("\n>>> Accuracy " + accuracy);
                        System.out.println("\n>>> Test Error " + (1 - accuracy));

                        System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
                        System.out.println(">>> ---------------------------------");
                    }
                }
                catch (FileNotFoundException e) {
                    e.printStackTrace();
                }
            });

            igniteThread.start();

            igniteThread.join();
        }
    }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.examples.ml.tutorial;

import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidationScoreCalculator;
import org.apache.ignite.ml.selection.score.AccuracyScoreCalculator;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.thread.IgniteThread;

/**
 * Tutorial step 9: maybe another algorithm can give us higher accuracy?
 *
 * Let's try to win with {@link LogisticRegressionSGDTrainer}: the hyperparameter grid
 * (p, max iterations, batch size, local iterations, learning rate) is searched with
 * 3-fold cross-validation on the training part, then the best model is evaluated on
 * the held-out test part.
 */
public class Step_9_Change_algorithm {
    /** Run example. */
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            // ML trainers must run on an Ignite thread to access the cluster.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                Step_9_Change_algorithm.class.getSimpleName(), () -> {
                try {
                    IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);

                    // Defines first preprocessor that extracts features from an upstream data.
                    // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
                    IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
                        = (k, v) -> new Object[]{v[0], v[3], v[4], v[5], v[6], v[8], v[10]};

                    // 75% of rows go to the train filter, the remaining 25% to the test filter.
                    TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
                        .split(0.75);

                    // Encodes the string-valued columns ("sex" at index 1, "embarked" at index 6) as doubles.
                    IgniteBiFunction<Integer, Object[], double[]> strEncoderPreprocessor = new StringEncoderTrainer<Integer, Object[]>()
                        .encodeFeature(1)
                        .encodeFeature(6) // <--- Changed index here
                        .fit(ignite,
                            dataCache,
                            featureExtractor
                        );

                    // Replaces missing feature values with per-column statistics learned from the data.
                    IgniteBiFunction<Integer, Object[], double[]> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
                        .fit(ignite,
                            dataCache,
                            strEncoderPreprocessor
                        );

                    // Rescales each feature column into a common range (column-wise scaling).
                    IgniteBiFunction<Integer, Object[], double[]> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
                        .fit(
                            ignite,
                            dataCache,
                            imputingPreprocessor
                        );

                    // Tune hyperparameters with k-fold cross-validation on the split training set.
                    int[] pSet = new int[]{1, 2};
                    int[] maxIterationsSet = new int[]{100, 1000};
                    int[] batchSizeSet = new int[]{100, 10};
                    int[] locIterationsSet = new int[]{10, 100};
                    double[] learningRateSet = new double[]{0.1, 0.2, 0.5};

                    int bestP = 1;
                    int bestMaxIterations = 100;
                    int bestBatchSize = 10;
                    int bestLocIterations = 10;
                    // Initialized to the first candidate of the grid (was 0.0, which is not in the set).
                    double bestLearningRate = 0.1;

                    // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest
                    // *positive* double, so a fold average of exactly 0.0 could never win.
                    double avg = Double.NEGATIVE_INFINITY;

                    for (int p : pSet) {
                        for (int maxIterations : maxIterationsSet) {
                            for (int batchSize : batchSizeSet) {
                                for (int locIterations : locIterationsSet) {
                                    for (double learningRate : learningRateSet) {
                                        IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                                            .withP(p)
                                            .fit(
                                                ignite,
                                                dataCache,
                                                minMaxScalerPreprocessor
                                            );

                                        // Fixed seed (123L) keeps the SGD runs reproducible across grid points.
                                        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
                                            new SimpleGDUpdateCalculator(learningRate),
                                            SimpleGDParameterUpdate::sumLocal,
                                            SimpleGDParameterUpdate::avg
                                        ), maxIterations, batchSize, locIterations, 123L);

                                        CrossValidationScoreCalculator<LogisticRegressionModel, Double, Integer, Object[]> scoreCalculator
                                            = new CrossValidationScoreCalculator<>();

                                        // 3-fold CV over the training part only; the test part stays untouched.
                                        double[] scores = scoreCalculator.score(
                                            trainer,
                                            new AccuracyScoreCalculator<>(),
                                            ignite,
                                            dataCache,
                                            split.getTrainFilter(),
                                            normalizationPreprocessor,
                                            (k, v) -> (double)v[1],
                                            3
                                        );

                                        System.out.println("Scores are: " + Arrays.toString(scores));

                                        final double currAvg = Arrays.stream(scores).average().orElse(Double.NEGATIVE_INFINITY);

                                        if (currAvg > avg) {
                                            avg = currAvg;
                                            bestP = p;
                                            bestMaxIterations = maxIterations;
                                            bestBatchSize = batchSize;
                                            bestLearningRate = learningRate;
                                            bestLocIterations = locIterations;
                                        }

                                        System.out.println("Avg is: " + currAvg
                                            + " with p: " + p
                                            + " with maxIterations: " + maxIterations
                                            + " with batchSize: " + batchSize
                                            + " with learningRate: " + learningRate
                                            + " with locIterations: " + locIterations
                                        );
                                    }
                                }
                            }
                        }
                    }

                    System.out.println("Train "
                        + " with p: " + bestP
                        + " with maxIterations: " + bestMaxIterations
                        + " with batchSize: " + bestBatchSize
                        + " with learningRate: " + bestLearningRate
                        + " with locIterations: " + bestLocIterations
                    );

                    // Refit the normalization preprocessor with the winning p.
                    IgniteBiFunction<Integer, Object[], double[]> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
                        .withP(bestP)
                        .fit(
                            ignite,
                            dataCache,
                            minMaxScalerPreprocessor
                        );

                    LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
                        new SimpleGDUpdateCalculator(bestLearningRate),
                        SimpleGDParameterUpdate::sumLocal,
                        SimpleGDParameterUpdate::avg
                    ), bestMaxIterations, bestBatchSize, bestLocIterations, 123L);

                    System.out.println(">>> Perform the training to get the model.");

                    // Train the final model with the best hyperparameters on the full training part.
                    LogisticRegressionModel bestMdl = trainer.fit(
                        ignite,
                        dataCache,
                        split.getTrainFilter(),
                        normalizationPreprocessor,
                        (k, v) -> (double)v[1]
                    );

                    System.out.println("----------------------------------------------------------------");
                    System.out.println("| Prediction\t| Ground Truth\t| Name\t|");
                    System.out.println("----------------------------------------------------------------");

                    int amountOfErrors = 0;
                    int totalAmount = 0;

                    // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                    int[][] confusionMtx = {{0, 0}, {0, 0}};

                    // Evaluate on the held-out test part only.
                    ScanQuery<Integer, Object[]> qry = new ScanQuery<>();
                    qry.setFilter(split.getTestFilter());

                    try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(qry)) {
                        for (Cache.Entry<Integer, Object[]> observation : observations) {

                            Object[] val = observation.getValue();
                            double groundTruth = (double)val[1];
                            String name = (String)val[2];

                            // Inference reuses the same preprocessing chain that was used for training.
                            double prediction = bestMdl.apply(new DenseLocalOnHeapVector(
                                normalizationPreprocessor.apply(observation.getKey(), val)));

                            totalAmount++;
                            if (groundTruth != prediction)
                                amountOfErrors++;

                            int idx1 = (int)prediction;
                            int idx2 = (int)groundTruth;

                            confusionMtx[idx1][idx2]++;

                            System.out.printf("| %.4f\t\t| %.4f\t\t\t\t\t\t| %s\t\t\t\t\t\t\t\t\t\t|\n", prediction, groundTruth, name);
                        }

                        System.out.println("---------------------------------");

                        System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                        double accuracy = 1 - amountOfErrors / (double)totalAmount;
                        System.out.println("\n>>> Accuracy " + accuracy);
                        System.out.println("\n>>> Test Error " + (1 - accuracy));

                        System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
                        System.out.println(">>> ---------------------------------");
                    }
                }
                catch (FileNotFoundException e) {
                    e.printStackTrace();
                }
            });

            igniteThread.start();

            igniteThread.join();
        }
    }
}
b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TitanicUtils.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.File; +import java.io.FileNotFoundException; +import java.text.NumberFormat; +import java.text.ParseException; +import java.util.Locale; +import java.util.Scanner; +import java.util.UUID; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; + +/** + * The utility class. + */ +public class TitanicUtils { + /** + * Read passengers data from csv file. + * + * @param ignite The ignite. + * @return The filled cache. 
+ * @throws FileNotFoundException + */ + public static IgniteCache<Integer, Object[]> readPassengers(Ignite ignite) + throws FileNotFoundException { + IgniteCache<Integer, Object[]> cache = getCache(ignite); + Scanner scanner = new Scanner(new File("examples/src/main/resources/datasets/titanic.csv")); + + int cnt = 0; + while (scanner.hasNextLine()) { + String row = scanner.nextLine(); + if(cnt == 0) { + cnt++; + continue; + } + String[] cells = row.split(";"); + Object[] data = new Object[cells.length]; + NumberFormat format = NumberFormat.getInstance(Locale.FRANCE); + + for (int i = 0; i < cells.length; i++) + try{ + if(cells[i].equals("")) data[i] = Double.NaN; + else data[i] = Double.valueOf(cells[i]); + } catch (java.lang.NumberFormatException e) { + + try { + data[i] = format.parse(cells[i]).doubleValue(); + } + catch (ParseException e1) { + data[i] = cells[i]; + } + } + cache.put(cnt++, data); + } + return cache; + } + + /** + * Fills cache with data and returns it. + * + * @param ignite Ignite instance. + * @return Filled Ignite Cache. + */ + private static IgniteCache<Integer, Object[]> getCache(Ignite ignite) { + + CacheConfiguration<Integer, Object[]> cacheConfiguration = new CacheConfiguration<>(); + cacheConfiguration.setName("TUTORIAL_" + UUID.randomUUID()); + cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10)); + + return ignite.createCache(cacheConfiguration); + } +}