Repository: ignite Updated Branches: refs/heads/master 03bb55138 -> da782958a
IGNITE-7079: Add examples for kNN classification and for kNN regression this closes #3218 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/da782958 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/da782958 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/da782958 Branch: refs/heads/master Commit: da782958adad85b750ffdc39266644c5750d0f2f Parents: 03bb551 Author: zaleslaw <[email protected]> Authored: Thu Dec 14 20:27:15 2017 +0300 Committer: Yury Babak <[email protected]> Committed: Thu Dec 14 20:27:15 2017 +0300 ---------------------------------------------------------------------- .../KNNClassificationExample.java | 151 ++++++++++++++ .../ml/knn/classification/package-info.java | 22 ++ .../ignite/examples/ml/knn/package-info.java | 22 ++ .../ml/knn/regression/KNNRegressionExample.java | 152 ++++++++++++++ .../ml/knn/regression/package-info.java | 22 ++ .../src/main/resources/datasets/knn/README.md | 2 + .../resources/datasets/knn/cleared_machines.txt | 209 +++++++++++++++++++ .../src/main/resources/datasets/knn/iris.txt | 150 +++++++++++++ .../ignite/ml/structures/LabeledDataset.java | 18 ++ .../structures/LabeledDatasetTestTrainPair.java | 116 ++++++++++ .../ignite/ml/knn/LabeledDatasetTest.java | 58 +++++ 11 files changed, 922 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java new file mode 100644 index 0000000..a92e9af --- /dev/null +++ 
b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.knn.classification; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import org.apache.ignite.Ignite; +import org.apache.ignite.Ignition; +import org.apache.ignite.examples.ExampleNodeStartup; +import org.apache.ignite.ml.knn.models.KNNModel; +import org.apache.ignite.ml.knn.models.KNNStrategy; +import org.apache.ignite.ml.math.distances.EuclideanDistance; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; +import org.apache.ignite.thread.IgniteThread; + +/** + * <p> + * Example of using {@link KNNModel} with iris dataset.</p> + * <p> + * Note that in this example we cannot guarantee order in which nodes return results of intermediate + * computations and therefore algorithm can return different results.</p> + * <p> + * Remote nodes should always be started with special configuration file which + * enables P2P class loading: {@code 
'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> + * <p> + * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node + * with {@code examples/config/example-ignite.xml} configuration.</p> + */ +public class KNNClassificationExample { + /** Separator. */ + private static final String SEPARATOR = "\t"; + + /** Path to the Iris dataset. */ + static final String KNN_IRIS_TXT = "datasets/knn/iris.txt"; + + /** + * Executes example. + * + * @param args Command line arguments, none required. + */ + public static void main(String[] args) throws InterruptedException { + System.out.println(">>> kNN classification example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + KNNClassificationExample.class.getSimpleName(), () -> { + + try { + // Prepare path to read + Path path = Paths.get(KNNClassificationExample.class.getClassLoader().getResource(KNN_IRIS_TXT).toURI()); + + // Read dataset from file + LabeledDataset dataset = LabeledDataset.loadTxt(path, SEPARATOR, true, false); + + // Random splitting of iris data as 70% train and 30% test datasets + LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.3); + + System.out.println("\n>>> Amount of observations in train dataset " + split.train().rowSize()); + System.out.println("\n>>> Amount of observations in test dataset " + split.test().rowSize()); + + LabeledDataset test = split.test(); + LabeledDataset train = split.train(); + + KNNModel knnMdl = new KNNModel(5, new EuclideanDistance(), KNNStrategy.SIMPLE, train); + + // Keep a copy of the true labels: the loop below overwrites the test labels with predictions + final double[] labels = test.labels(); + + // Save predicted classes to test dataset + for (int i = 0; i < test.rowSize(); i++) { + double predictedCls = knnMdl.predict(test.getRow(i).features()); + test.setLabel(i, 
predictedCls); + } + + // Calculate amount of errors on test dataset + int amountOfErrors = 0; + for (int i = 0; i < test.rowSize(); i++) { + if (test.label(i) != labels[i]) + amountOfErrors++; + } + + System.out.println("\n>>> Absolute amount of errors " + amountOfErrors); + // Accuracy is the share of correctly classified observations, i.e. 1 - (errors / observations) + System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)test.rowSize())); + + // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix + // Rows are indexed by predicted class, columns by actual class (class labels start from 1) + int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}}; + for (int i = 0; i < test.rowSize(); i++) { + int idx1 = (int)test.label(i); + int idx2 = (int)labels[i]; + confusionMtx[idx1 - 1][idx2 - 1]++; + } + System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx)); + + // Calculate precision, recall and F-metric for each class + for (int i = 0; i < 3; i++) { + // Precision: correct predictions of class i over all predictions of class i (row sum) + double precision = 0.0; + for (int j = 0; j < 3; j++) + precision += confusionMtx[i][j]; + precision = confusionMtx[i][i] / precision; + + double clsLb = (double)(i + 1); + + System.out.println("\n>>> Precision for class " + clsLb + " is " + precision); + + // Recall: correct predictions of class i over all actual observations of class i (column sum) + double recall = 0.0; + for (int j = 0; j < 3; j++) + recall += confusionMtx[j][i]; + recall = confusionMtx[i][i] / recall; + System.out.println("\n>>> Recall for class " + clsLb + " is " + recall); + + double fScore = 2 * precision * recall / (precision + recall); + System.out.println("\n>>> F-score for class " + clsLb + " is " + fScore); + } + + } + catch (URISyntaxException | IOException e) { + e.printStackTrace(); + System.out.println("\n>>> Check resources"); + } + finally { + System.out.println("\n>>> kNN classification 
example completed."); + } + + }); + + igniteThread.start(); + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/package-info.java ---------------------------------------------------------------------- diff --git 
a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/package-info.java new file mode 100644 index 0000000..d853f0d --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/classification/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * kNN classification examples. + */ +package org.apache.ignite.examples.ml.knn.classification; http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/ml/org/apache/ignite/examples/ml/knn/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/package-info.java new file mode 100644 index 0000000..8de4656 --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * kNN examples. + */ +package org.apache.ignite.examples.ml.knn; http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java new file mode 100644 index 0000000..f4a9e1c --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.knn.regression; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.apache.ignite.Ignite; +import org.apache.ignite.Ignition; +import org.apache.ignite.examples.ExampleNodeStartup; +import org.apache.ignite.ml.knn.models.KNNStrategy; +import org.apache.ignite.ml.knn.models.Normalization; +import org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression; +import org.apache.ignite.ml.math.distances.ManhattanDistance; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; +import org.apache.ignite.thread.IgniteThread; + +/** + * <p> + * Example of using {@link KNNMultipleLinearRegression} with the CPU performance (machine) dataset.</p> + * <p> + * Note that in this example we cannot guarantee order in which nodes return results of intermediate + * computations and therefore algorithm can return different results.</p> + * <p> + * Remote nodes should always be started with special configuration file which + * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p> + * <p> + * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node + * with {@code examples/config/example-ignite.xml} configuration.</p> + */ +public class KNNRegressionExample { + /** Separator. 
*/ + private static final String SEPARATOR = ","; + + /** Path to the cleared machines dataset. */ + public static final String KNN_CLEARED_MACHINES_TXT = "datasets/knn/cleared_machines.txt"; + + /** + * Executes example. + * + * @param args Command line arguments, none required. + */ + public static void main(String[] args) throws InterruptedException { + System.out.println(">>> kNN regression example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + KNNRegressionExample.class.getSimpleName(), () -> { + + try { + // Prepare path to read + Path path = Paths.get(KNNRegressionExample.class.getClassLoader().getResource(KNN_CLEARED_MACHINES_TXT).toURI()); + + // Read dataset from file + LabeledDataset dataset = LabeledDataset.loadTxt(path, SEPARATOR, false, false); + + // Normalize dataset + dataset.normalizeWith(Normalization.MINIMAX); + + // Random splitting of machines data as 80% train and 20% test datasets + LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.2); + + System.out.println("\n>>> Amount of observations in train dataset " + split.train().rowSize()); + System.out.println("\n>>> Amount of observations in test dataset " + split.test().rowSize()); + + LabeledDataset test = split.test(); + LabeledDataset train = split.train(); + + // Builds weighted kNN-regression with Manhattan Distance + KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(7, new ManhattanDistance(), KNNStrategy.WEIGHTED, train); + + // Keep a copy of the true labels: the loop below overwrites the test labels with predictions + final double[] labels = test.labels(); + + // Save predicted values to test dataset + for (int i = 0; i < test.rowSize(); i++) { + double predictedCls = knnMdl.predict(test.getRow(i).features()); + test.setLabel(i, predictedCls); + } + + // 
Calculate mean squared error (MSE) + double mse = 0.0; + + for (int i = 0; i < 
test.rowSize(); i++) + mse += Math.pow(test.label(i) - labels[i], 2.0); + mse = mse / test.rowSize(); + + System.out.println("\n>>> Mean squared error (MSE) " + mse); + + // Calculate mean absolute error (MAE) + double mae = 0.0; + + for (int i = 0; i < test.rowSize(); i++) + mae += Math.abs(test.label(i) - labels[i]); + mae = mae / test.rowSize(); + + System.out.println("\n>>> Mean absolute error (MAE) " + mae); + + // Calculate R^2 as 1 - RSS/TSS, where TSS is measured around the mean of the TRUE labels + double avg = 0.0; + + for (int i = 0; i < test.rowSize(); i++) + avg += labels[i]; + + avg = avg / test.rowSize(); + + double detCf = 0.0; + double tss = 0.0; + + for (int i = 0; i < test.rowSize(); i++) { + detCf += Math.pow(test.label(i) - labels[i], 2.0); + tss += Math.pow(labels[i] - avg, 2.0); + } + + detCf = 1 - detCf / tss; + + System.out.println("\n>>> R^2 " + detCf); + } + catch (URISyntaxException | IOException e) { + e.printStackTrace(); + System.out.println("\n>>> Check resources"); + } + finally { + System.out.println("\n>>> kNN regression example completed."); + } + }); + + igniteThread.start(); + igniteThread.join(); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/package-info.java new file mode 100644 index 0000000..e7ac336 --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/knn/regression/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * kNN regression examples. + */ +package org.apache.ignite.examples.ml.knn.regression; http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/resources/datasets/knn/README.md ---------------------------------------------------------------------- diff --git a/examples/src/main/resources/datasets/knn/README.md b/examples/src/main/resources/datasets/knn/README.md new file mode 100644 index 0000000..2f9c5ec --- /dev/null +++ b/examples/src/main/resources/datasets/knn/README.md @@ -0,0 +1,2 @@ +iris.txt and cleared_machines are from Lichman, M. (2013). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. 
+Read more about machine dataset http://archive.ics.uci.edu/ml/machine-learning-databases/cpu-performance/machine.names \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/resources/datasets/knn/cleared_machines.txt ---------------------------------------------------------------------- diff --git a/examples/src/main/resources/datasets/knn/cleared_machines.txt b/examples/src/main/resources/datasets/knn/cleared_machines.txt new file mode 100644 index 0000000..cf8b6b0 --- /dev/null +++ b/examples/src/main/resources/datasets/knn/cleared_machines.txt @@ -0,0 +1,209 @@ +199,125,256,6000,256,16,128 +253,29,8000,32000,32,8,32 +253,29,8000,32000,32,8,32 +253,29,8000,32000,32,8,32 +132,29,8000,16000,32,8,16 +290,26,8000,32000,64,8,32 +381,23,16000,32000,64,16,32 +381,23,16000,32000,64,16,32 +749,23,16000,64000,64,16,32 +1238,23,32000,64000,128,32,64 +23,400,1000,3000,0,1,2 +24,400,512,3500,4,1,6 +70,60,2000,8000,65,1,8 +117,50,4000,16000,65,1,8 +15,350,64,64,0,1,4 +64,200,512,16000,0,4,32 +23,167,524,2000,8,4,15 +29,143,512,5000,0,7,32 +22,143,1000,2000,0,5,16 +124,110,5000,5000,142,8,64 +35,143,1500,6300,0,5,32 +39,143,3100,6200,0,5,20 +40,143,2300,6200,0,6,64 +45,110,3100,6200,0,6,64 +28,320,128,6000,0,1,12 +21,320,512,2000,4,1,3 +28,320,256,6000,0,1,6 +22,320,256,3000,4,1,3 +28,320,512,5000,4,1,5 +27,320,256,5000,4,1,6 +102,25,1310,2620,131,12,24 +102,25,1310,2620,131,12,24 +74,50,2620,10480,30,12,24 +74,50,2620,10480,30,12,24 +138,56,5240,20970,30,12,24 +136,64,5240,20970,30,12,24 +23,50,500,2000,8,1,4 +29,50,1000,4000,8,1,5 +44,50,2000,8000,8,1,5 +30,50,1000,4000,8,3,5 +41,50,1000,8000,8,3,5 +74,50,2000,16000,8,3,5 +74,50,2000,16000,8,3,6 +74,50,2000,16000,8,3,6 +54,133,1000,12000,9,3,12 +41,133,1000,8000,9,3,12 +18,810,512,512,8,1,1 +28,810,1000,5000,0,1,1 +36,320,512,8000,4,1,5 +38,200,512,8000,8,1,8 +34,700,384,8000,0,1,1 +19,700,256,2000,0,1,1 +72,140,1000,16000,16,1,3 +36,200,1000,8000,0,1,2 
+30,110,1000,4000,16,1,2 +56,110,1000,12000,16,1,2 +42,220,1000,8000,16,1,2 +34,800,256,8000,0,1,4 +34,800,256,8000,0,1,4 +34,800,256,8000,0,1,4 +34,800,256,8000,0,1,4 +34,800,256,8000,0,1,4 +19,125,512,1000,0,8,20 +75,75,2000,8000,64,1,38 +113,75,2000,16000,64,1,38 +157,75,2000,16000,128,1,38 +18,90,256,1000,0,3,10 +20,105,256,2000,0,3,10 +28,105,1000,4000,0,3,24 +33,105,2000,4000,8,3,19 +47,75,2000,8000,8,3,24 +54,75,3000,8000,8,3,48 +20,175,256,2000,0,3,24 +23,300,768,3000,0,6,24 +25,300,768,3000,6,6,24 +52,300,768,12000,6,6,24 +27,300,768,4500,0,1,24 +50,300,384,12000,6,1,24 +18,300,192,768,6,6,24 +53,180,768,12000,6,1,31 +23,330,1000,3000,0,2,4 +30,300,1000,4000,8,3,64 +73,300,1000,16000,8,2,112 +20,330,1000,2000,0,1,2 +25,330,1000,4000,0,3,6 +28,140,2000,4000,0,3,6 +29,140,2000,4000,0,4,8 +32,140,2000,4000,8,1,20 +175,140,2000,32000,32,1,20 +57,140,2000,8000,32,1,54 +181,140,2000,32000,32,1,54 +181,140,2000,32000,32,1,54 +32,140,2000,4000,8,1,20 +82,57,4000,16000,1,6,12 +171,57,4000,24000,64,12,16 +361,26,16000,32000,64,16,24 +350,26,16000,32000,64,8,24 +220,26,8000,32000,0,8,24 +113,26,8000,16000,0,8,16 +15,480,96,512,0,1,1 +21,203,1000,2000,0,1,5 +35,115,512,6000,16,1,6 +18,1100,512,1500,0,1,1 +20,1100,768,2000,0,1,1 +20,600,768,2000,0,1,1 +28,400,2000,4000,0,1,1 +45,400,4000,8000,0,1,1 +18,900,1000,1000,0,1,2 +17,900,512,1000,0,1,2 +26,900,1000,4000,4,1,2 +28,900,1000,4000,8,1,2 +28,900,2000,4000,0,3,6 +31,225,2000,4000,8,3,6 +31,225,2000,4000,8,3,6 +42,180,2000,8000,8,1,6 +76,185,2000,16000,16,1,6 +76,180,2000,16000,16,1,6 +26,225,1000,4000,2,3,6 +59,25,2000,12000,8,1,4 +65,25,2000,12000,16,3,5 +101,17,4000,16000,8,6,12 +116,17,4000,16000,32,6,12 +18,1500,768,1000,0,0,0 +20,1500,768,2000,0,0,0 +20,800,768,2000,0,0,0 +30,50,2000,4000,0,3,6 +44,50,2000,8000,8,3,6 +44,50,2000,8000,8,1,6 +82,50,2000,16000,24,1,6 +82,50,2000,16000,24,1,6 +128,50,8000,16000,48,1,10 +37,100,1000,8000,0,2,6 +46,100,1000,8000,24,2,6 +46,100,1000,8000,24,3,6 
+80,50,2000,16000,12,3,16 +88,50,2000,16000,24,6,16 +88,50,2000,16000,24,6,16 +33,150,512,4000,0,8,128 +46,115,2000,8000,16,1,3 +29,115,2000,4000,2,1,5 +53,92,2000,8000,32,1,6 +53,92,2000,8000,32,1,6 +41,92,2000,8000,4,1,6 +86,75,4000,16000,16,1,6 +95,60,4000,16000,32,1,6 +107,60,2000,16000,64,5,8 +117,60,4000,16000,64,5,8 +119,50,4000,16000,64,5,10 +120,72,4000,16000,64,8,16 +48,72,2000,8000,16,6,8 +126,40,8000,16000,32,8,16 +266,40,8000,32000,64,8,24 +270,35,8000,32000,64,8,24 +426,38,16000,32000,128,16,32 +151,48,4000,24000,32,8,24 +267,38,8000,32000,64,8,24 +603,30,16000,32000,256,16,24 +19,112,1000,1000,0,1,4 +21,84,1000,2000,0,1,6 +26,56,1000,4000,0,1,6 +35,56,2000,6000,0,1,8 +41,56,2000,8000,0,1,8 +47,56,4000,8000,0,1,8 +62,56,4000,12000,0,1,8 +78,56,4000,16000,0,1,8 +80,38,4000,8000,32,16,32 +80,38,4000,8000,32,16,32 +142,38,8000,16000,64,4,8 +281,38,8000,24000,160,4,8 +190,38,4000,16000,128,16,32 +21,200,1000,2000,0,1,2 +25,200,1000,4000,0,1,4 +67,200,2000,8000,64,1,5 +24,250,512,4000,0,1,7 +24,250,512,4000,0,4,7 +64,250,1000,16000,1,1,8 +25,160,512,4000,2,1,5 +20,160,512,2000,2,3,8 +29,160,1000,4000,8,1,14 +43,160,1000,8000,16,1,14 +53,160,2000,8000,32,1,13 +19,240,512,1000,8,1,3 +22,240,512,2000,8,1,5 +31,105,2000,4000,8,3,8 +41,105,2000,6000,16,6,16 +47,105,2000,8000,16,4,14 +99,52,4000,16000,32,4,12 +67,70,4000,12000,8,6,8 +81,59,4000,12000,32,6,12 +149,59,8000,16000,64,12,24 +183,26,8000,24000,32,8,16 +275,26,8000,32000,64,12,16 +382,26,8000,32000,128,24,32 +56,116,2000,8000,32,5,28 +182,50,2000,32000,24,6,26 +227,50,2000,32000,48,26,52 +341,50,2000,32000,112,52,104 +360,50,4000,32000,112,52,104 +919,30,8000,64000,96,12,176 +978,30,8000,64000,128,12,176 +24,180,262,4000,0,1,3 +24,180,512,4000,0,1,3 +24,180,262,4000,0,1,3 +24,180,512,4000,0,1,3 +37,124,1000,8000,0,1,8 +50,98,1000,8000,32,2,8 +41,125,2000,8000,0,2,14 +47,480,512,8000,32,0,0 +25,480,1000,4000,0,0,0 
http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/examples/src/main/resources/datasets/knn/iris.txt ---------------------------------------------------------------------- diff --git a/examples/src/main/resources/datasets/knn/iris.txt b/examples/src/main/resources/datasets/knn/iris.txt new file mode 100644 index 0000000..18f5f7c --- /dev/null +++ b/examples/src/main/resources/datasets/knn/iris.txt @@ -0,0 +1,150 @@ +1.0 5.1 3.5 1.4 0.2 +1.0 4.9 3.0 1.4 0.2 +1.0 4.7 3.2 1.3 0.2 +1.0 4.6 3.1 1.5 0.2 +1.0 5.0 3.6 1.4 0.2 +1.0 5.4 3.9 1.7 0.4 +1.0 4.6 3.4 1.4 0.3 +1.0 5.0 3.4 1.5 0.2 +1.0 4.4 2.9 1.4 0.2 +1.0 4.9 3.1 1.5 0.1 +1.0 5.4 3.7 1.5 0.2 +1.0 4.8 3.4 1.6 0.2 +1.0 4.8 3.0 1.4 0.1 +1.0 4.3 3.0 1.1 0.1 +1.0 5.8 4.0 1.2 0.2 +1.0 5.7 4.4 1.5 0.4 +1.0 5.4 3.9 1.3 0.4 +1.0 5.1 3.5 1.4 0.3 +1.0 5.7 3.8 1.7 0.3 +1.0 5.1 3.8 1.5 0.3 +1.0 5.4 3.4 1.7 0.2 +1.0 5.1 3.7 1.5 0.4 +1.0 4.6 3.6 1.0 0.2 +1.0 5.1 3.3 1.7 0.5 +1.0 4.8 3.4 1.9 0.2 +1.0 5.0 3.0 1.6 0.2 +1.0 5.0 3.4 1.6 0.4 +1.0 5.2 3.5 1.5 0.2 +1.0 5.2 3.4 1.4 0.2 +1.0 4.7 3.2 1.6 0.2 +1.0 4.8 3.1 1.6 0.2 +1.0 5.4 3.4 1.5 0.4 +1.0 5.2 4.1 1.5 0.1 +1.0 5.5 4.2 1.4 0.2 +1.0 4.9 3.1 1.5 0.1 +1.0 5.0 3.2 1.2 0.2 +1.0 5.5 3.5 1.3 0.2 +1.0 4.9 3.1 1.5 0.1 +1.0 4.4 3.0 1.3 0.2 +1.0 5.1 3.4 1.5 0.2 +1.0 5.0 3.5 1.3 0.3 +1.0 4.5 2.3 1.3 0.3 +1.0 4.4 3.2 1.3 0.2 +1.0 5.0 3.5 1.6 0.6 +1.0 5.1 3.8 1.9 0.4 +1.0 4.8 3.0 1.4 0.3 +1.0 5.1 3.8 1.6 0.2 +1.0 4.6 3.2 1.4 0.2 +1.0 5.3 3.7 1.5 0.2 +1.0 5.0 3.3 1.4 0.2 +2.0 7.0 3.2 4.7 1.4 +2.0 6.4 3.2 4.5 1.5 +2.0 6.9 3.1 4.9 1.5 +2.0 5.5 2.3 4.0 1.3 +2.0 6.5 2.8 4.6 1.5 +2.0 5.7 2.8 4.5 1.3 +2.0 6.3 3.3 4.7 1.6 +2.0 4.9 2.4 3.3 1.0 +2.0 6.6 2.9 4.6 1.3 +2.0 5.2 2.7 3.9 1.4 +2.0 5.0 2.0 3.5 1.0 +2.0 5.9 3.0 4.2 1.5 +2.0 6.0 2.2 4.0 1.0 +2.0 6.1 2.9 4.7 1.4 +2.0 5.6 2.9 3.6 1.3 +2.0 6.7 3.1 4.4 1.4 +2.0 5.6 3.0 4.5 1.5 +2.0 5.8 2.7 4.1 1.0 +2.0 6.2 2.2 4.5 1.5 +2.0 5.6 2.5 3.9 1.1 +2.0 5.9 3.2 4.8 1.8 +2.0 6.1 2.8 4.0 1.3 +2.0 6.3 2.5 4.9 1.5 +2.0 6.1 2.8 4.7 1.2 +2.0 6.4 2.9 4.3 
1.3 +2.0 6.6 3.0 4.4 1.4 +2.0 6.8 2.8 4.8 1.4 +2.0 6.7 3.0 5.0 1.7 +2.0 6.0 2.9 4.5 1.5 +2.0 5.7 2.6 3.5 1.0 +2.0 5.5 2.4 3.8 1.1 +2.0 5.5 2.4 3.7 1.0 +2.0 5.8 2.7 3.9 1.2 +2.0 6.0 2.7 5.1 1.6 +2.0 5.4 3.0 4.5 1.5 +2.0 6.0 3.4 4.5 1.6 +2.0 6.7 3.1 4.7 1.5 +2.0 6.3 2.3 4.4 1.3 +2.0 5.6 3.0 4.1 1.3 +2.0 5.5 2.5 4.0 1.3 +2.0 5.5 2.6 4.4 1.2 +2.0 6.1 3.0 4.6 1.4 +2.0 5.8 2.6 4.0 1.2 +2.0 5.0 2.3 3.3 1.0 +2.0 5.6 2.7 4.2 1.3 +2.0 5.7 3.0 4.2 1.2 +2.0 5.7 2.9 4.2 1.3 +2.0 6.2 2.9 4.3 1.3 +2.0 5.1 2.5 3.0 1.1 +2.0 5.7 2.8 4.1 1.3 +3.0 6.3 3.3 6.0 2.5 +3.0 5.8 2.7 5.1 1.9 +3.0 7.1 3.0 5.9 2.1 +3.0 6.3 2.9 5.6 1.8 +3.0 6.5 3.0 5.8 2.2 +3.0 7.6 3.0 6.6 2.1 +3.0 4.9 2.5 4.5 1.7 +3.0 7.3 2.9 6.3 1.8 +3.0 6.7 2.5 5.8 1.8 +3.0 7.2 3.6 6.1 2.5 +3.0 6.5 3.2 5.1 2.0 +3.0 6.4 2.7 5.3 1.9 +3.0 6.8 3.0 5.5 2.1 +3.0 5.7 2.5 5.0 2.0 +3.0 5.8 2.8 5.1 2.4 +3.0 6.4 3.2 5.3 2.3 +3.0 6.5 3.0 5.5 1.8 +3.0 7.7 3.8 6.7 2.2 +3.0 7.7 2.6 6.9 2.3 +3.0 6.0 2.2 5.0 1.5 +3.0 6.9 3.2 5.7 2.3 +3.0 5.6 2.8 4.9 2.0 +3.0 7.7 2.8 6.7 2.0 +3.0 6.3 2.7 4.9 1.8 +3.0 6.7 3.3 5.7 2.1 +3.0 7.2 3.2 6.0 1.8 +3.0 6.2 2.8 4.8 1.8 +3.0 6.1 3.0 4.9 1.8 +3.0 6.4 2.8 5.6 2.1 +3.0 7.2 3.0 5.8 1.6 +3.0 7.4 2.8 6.1 1.9 +3.0 7.9 3.8 6.4 2.0 +3.0 6.4 2.8 5.6 2.2 +3.0 6.3 2.8 5.1 1.5 +3.0 6.1 2.6 5.6 1.4 +3.0 7.7 3.0 6.1 2.3 +3.0 6.3 3.4 5.6 2.4 +3.0 6.4 3.1 5.5 1.8 +3.0 6.0 3.0 4.8 1.8 +3.0 6.9 3.1 5.4 2.1 +3.0 6.7 3.1 5.6 2.4 +3.0 6.9 3.1 5.1 2.3 +3.0 5.8 2.7 5.1 1.9 +3.0 6.8 3.2 5.9 2.3 +3.0 6.7 3.3 5.7 2.5 +3.0 6.7 3.0 5.2 2.3 +3.0 6.3 2.5 5.0 1.9 +3.0 6.5 3.0 5.2 2.0 +3.0 6.2 3.4 5.4 2.3 +3.0 5.9 3.0 5.1 1.8 http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java index 81f7607..ee2f442 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java @@ -270,6 +270,24 @@ public class LabeledDataset implements Serializable { } /** + * Returns a new copy of the labels of all labeled vectors. NOTE: This method is useful for copying labels from a + * test dataset. + * + * @return Copy of labels. + */ + public double[] labels() { + assert data != null; + assert data.length > 0; + + double[] labels = new double[data.length]; + + for (int i = 0; i < data.length; i++) + labels[i] = (double)data[i].label(); + + return labels; + } + + /** * Fill the label with given value. * * @param idx Index of observation. http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java new file mode 100644 index 0000000..dd3d244 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.structures; + +import java.io.Serializable; +import java.util.Random; +import java.util.TreeSet; +import org.jetbrains.annotations.NotNull; + +/** + * Class for splitting Labeled Dataset on train and test sets. + */ +public class LabeledDatasetTestTrainPair implements Serializable { + /** Data to keep train set. */ + private LabeledDataset train; + + /** Data to keep test set. */ + private LabeledDataset test; + + /** + * Creates two subsets of given dataset. + * <p> + * NOTE: The split takes O(n log n) time and O(n) memory, where n is the size of the dataset. + * </p> + * @param dataset The dataset to split into train and test subsets. + * @param testPercentage The fraction of rows (strictly between 0 and 1) to put into the test subset. 
+ */ + public LabeledDatasetTestTrainPair(LabeledDataset dataset, double testPercentage) { + assert testPercentage > 0.0; + assert testPercentage < 1.0; + final int datasetSize = dataset.rowSize(); + assert datasetSize > 2; + + final int testSize = (int)Math.floor(datasetSize * testPercentage); + final int trainSize = datasetSize - testSize; + + final TreeSet<Integer> sortedTestIndices = getSortedIndices(datasetSize, testSize); + + + LabeledVector[] testVectors = new LabeledVector[testSize]; + LabeledVector[] trainVectors = new LabeledVector[trainSize]; + + + int datasetCntr = 0; + int trainCntr = 0; + int testCntr = 0; + + for (Integer idx: sortedTestIndices){ // guarantee order as iterator + testVectors[testCntr] = dataset.getRow(idx); + testCntr++; + + for (int i = datasetCntr; i < idx; i++) { + trainVectors[trainCntr] = dataset.getRow(i); + trainCntr++; + } + + datasetCntr = idx + 1; + } + if(datasetCntr < datasetSize){ + for (int i = datasetCntr; i < datasetSize; i++) { + trainVectors[trainCntr] = dataset.getRow(i); + trainCntr++; + } + } + + test = new LabeledDataset(testVectors, testSize); + train = new LabeledDataset(trainVectors, trainSize); + } + + /** + * Picks {@code testSize} distinct random indices from {@code [0, datasetSize)} with a partial Fisher-Yates + * shuffle. The picks are duplicate-free by construction, so exactly {@code testSize} indices are returned. + * + * @param datasetSize Size of the whole dataset. + * @param testSize Amount of indices to pick, must not exceed {@code datasetSize}. + * @return Picked indices in ascending order. + */ + @NotNull private TreeSet<Integer> getSortedIndices(int datasetSize, int testSize) { + Random rnd = new Random(); + + int[] idxs = new int[datasetSize]; + + for (int i = 0; i < datasetSize; i++) + idxs[i] = i; + + TreeSet<Integer> res = new TreeSet<>(); + + for (int i = 0; i < testSize; i++) { + // Swap a random index from the not yet picked tail [i, datasetSize) into position i. 
+ int j = i + rnd.nextInt(datasetSize - i); + + int tmp = idxs[i]; + idxs[i] = idxs[j]; + idxs[j] = tmp; + + res.add(idxs[i]); + } + + return res; + } + + /** + * Train subset of the whole dataset. + * @return Train subset. + */ + public LabeledDataset train() { + return train; + } + + /** + * Test subset of the whole dataset. 
+ * @return Test subset. + */ + public LabeledDataset test() { + return test; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/da782958/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java index 32bd37b..c64a8d8 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java @@ -28,6 +28,7 @@ import org.apache.ignite.ml.math.exceptions.NoDataException; import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException; import org.apache.ignite.ml.math.exceptions.knn.FileParsingException; import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair; import org.apache.ignite.ml.structures.LabeledVector; /** Tests behaviour of KNNClassificationTest. 
*/ @@ -205,4 +206,61 @@ public class LabeledDatasetTest extends BaseKNNTest { assertEquals(training.features(2).get(1), 0.0); } + + /** */ + public void testSplitting() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + double[][] mtx = + new double[][] { + {1.0, 1.0}, + {1.0, 2.0}, + {2.0, 1.0}, + {-1.0, -1.0}, + {-1.0, -2.0}, + {-2.0, -1.0}}; + double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0}; + + LabeledDataset training = new LabeledDataset(mtx, lbs); + + LabeledDatasetTestTrainPair split1 = new LabeledDatasetTestTrainPair(training, 0.67); + + assertEquals(4, split1.test().rowSize()); + assertEquals(2, split1.train().rowSize()); + + LabeledDatasetTestTrainPair split2 = new LabeledDatasetTestTrainPair(training, 0.65); + + assertEquals(3, split2.test().rowSize()); + assertEquals(3, split2.train().rowSize()); + + LabeledDatasetTestTrainPair split3 = new LabeledDatasetTestTrainPair(training, 0.4); + + assertEquals(2, split3.test().rowSize()); + assertEquals(4, split3.train().rowSize()); + + LabeledDatasetTestTrainPair split4 = new LabeledDatasetTestTrainPair(training, 0.3); + + assertEquals(1, split4.test().rowSize()); + assertEquals(5, split4.train().rowSize()); + } + + /** */ + public void testLabels() { + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + + double[][] mtx = + new double[][] { + {1.0, 1.0}, + {1.0, 2.0}, + {2.0, 1.0}, + {-1.0, -1.0}, + {-1.0, -2.0}, + {-2.0, -1.0}}; + double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0}; + + LabeledDataset dataset = new LabeledDataset(mtx, lbs); + final double[] labels = dataset.labels(); + for (int i = 0; i < lbs.length; i++) + assertEquals(lbs[i], labels[i]); + } }
