IGNITE-6880: KNN (k-nearest neighbors) algorithm. This closes #3117
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/8ba773bf Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/8ba773bf Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/8ba773bf Branch: refs/heads/ignite-zk-ce Commit: 8ba773bfe580ab3e0822283e7581115242303962 Parents: e3d70a8 Author: zaleslaw <[email protected]> Authored: Mon Dec 11 19:02:08 2017 +0300 Committer: Yury Babak <[email protected]> Committed: Mon Dec 11 19:02:08 2017 +0300 ---------------------------------------------------------------------- .../ml/clustering/FuzzyCMeansExample.java | 4 +- .../KMeansDistributedClustererExample.java | 2 +- .../clustering/KMeansLocalClustererExample.java | 4 +- .../ignite/ml/FuzzyCMeansModelFormat.java | 2 +- .../org/apache/ignite/ml/KMeansModelFormat.java | 2 +- .../ml/clustering/BaseFuzzyCMeansClusterer.java | 2 +- .../ml/clustering/BaseKMeansClusterer.java | 2 +- .../FuzzyCMeansDistributedClusterer.java | 2 +- .../clustering/FuzzyCMeansLocalClusterer.java | 2 +- .../ignite/ml/clustering/FuzzyCMeansModel.java | 2 +- .../clustering/KMeansDistributedClusterer.java | 2 +- .../ml/clustering/KMeansLocalClusterer.java | 2 +- .../ignite/ml/clustering/KMeansModel.java | 2 +- .../apache/ignite/ml/knn/models/KNNModel.java | 233 ++++++++++ .../ignite/ml/knn/models/KNNModelFormat.java | 92 ++++ .../ignite/ml/knn/models/KNNStrategy.java | 27 ++ .../ignite/ml/knn/models/Normalization.java | 32 ++ .../ignite/ml/knn/models/package-info.java | 22 + .../org/apache/ignite/ml/knn/package-info.java | 22 + .../regression/KNNMultipleLinearRegression.java | 83 ++++ .../ignite/ml/knn/regression/package-info.java | 22 + .../apache/ignite/ml/math/DistanceMeasure.java | 38 -- .../ignite/ml/math/EuclideanDistance.java | 58 --- .../ml/math/distances/DistanceMeasure.java | 39 ++ .../ml/math/distances/EuclideanDistance.java | 59 +++ .../ml/math/distances/HammingDistance.java | 65 +++ 
.../ml/math/distances/ManhattanDistance.java | 59 +++ .../ignite/ml/math/distances/package-info.java | 22 + .../ignite/ml/math/distributed/CacheUtils.java | 2 +- .../distributed/keys/impl/SparseMatrixKey.java | 1 + .../math/exceptions/knn/EmptyFileException.java | 37 ++ .../exceptions/knn/FileParsingException.java | 39 ++ .../exceptions/knn/NoLabelVectorException.java | 37 ++ .../knn/SmallTrainingDatasetSizeException.java | 38 ++ .../ml/math/exceptions/knn/package-info.java | 22 + .../ignite/ml/structures/LabeledDataset.java | 457 +++++++++++++++++++ .../ignite/ml/structures/LabeledVector.java | 37 +- .../org/apache/ignite/ml/IgniteMLTestSuite.java | 2 + .../org/apache/ignite/ml/LocalModelsTest.java | 39 +- .../FuzzyCMeansDistributedClustererTest.java | 4 +- .../FuzzyCMeansLocalClustererTest.java | 4 +- ...KMeansDistributedClustererTestMultiNode.java | 2 +- ...MeansDistributedClustererTestSingleNode.java | 4 +- .../ml/clustering/KMeansLocalClustererTest.java | 2 +- .../org/apache/ignite/ml/knn/BaseKNNTest.java | 91 ++++ .../ignite/ml/knn/KNNClassificationTest.java | 153 +++++++ .../ml/knn/KNNMultipleLinearRegressionTest.java | 157 +++++++ .../org/apache/ignite/ml/knn/KNNTestSuite.java | 33 ++ .../ignite/ml/knn/LabeledDatasetTest.java | 208 +++++++++ .../ignite/ml/math/MathImplLocalTestSuite.java | 4 +- .../ignite/ml/math/distances/DistanceTest.java | 75 +++ .../OLSMultipleLinearRegressionTest.java | 1 + .../ml/trees/ColumnDecisionTreeTrainerTest.java | 6 +- .../ColumnDecisionTreeTrainerBenchmark.java | 6 +- .../ml/src/test/resources/datasets/README.md | 2 + .../resources/datasets/knn/cleared_machines.txt | 209 +++++++++ .../src/test/resources/datasets/knn/empty.txt | 0 .../ml/src/test/resources/datasets/knn/iris.txt | 150 ++++++ .../resources/datasets/knn/iris_incorrect.txt | 150 ++++++ .../resources/datasets/knn/machine.data.txt | 209 +++++++++ .../test/resources/datasets/knn/missed_data.txt | 3 + .../src/test/resources/datasets/knn/no_data.txt | 6 + 
parent/pom.xml | 1 + 63 files changed, 2963 insertions(+), 131 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/FuzzyCMeansExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/FuzzyCMeansExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/FuzzyCMeansExample.java index 9c47186..3fce624 100644 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/FuzzyCMeansExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/FuzzyCMeansExample.java @@ -22,10 +22,10 @@ import org.apache.ignite.Ignition; import org.apache.ignite.ml.clustering.BaseFuzzyCMeansClusterer; import org.apache.ignite.ml.clustering.FuzzyCMeansDistributedClusterer; import org.apache.ignite.ml.clustering.FuzzyCMeansModel; -import org.apache.ignite.ml.math.DistanceMeasure; -import org.apache.ignite.ml.math.EuclideanDistance; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; +import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; import org.apache.ignite.thread.IgniteThread; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansDistributedClustererExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansDistributedClustererExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansDistributedClustererExample.java index 456e915..09f35d2 100644 --- 
a/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansDistributedClustererExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansDistributedClustererExample.java @@ -24,10 +24,10 @@ import org.apache.ignite.Ignition; import org.apache.ignite.examples.ExampleNodeStartup; import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; import org.apache.ignite.ml.clustering.KMeansDistributedClusterer; -import org.apache.ignite.ml.math.EuclideanDistance; import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.Tracer; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.EuclideanDistance; import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; import org.apache.ignite.thread.IgniteThread; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansLocalClustererExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansLocalClustererExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansLocalClustererExample.java index 970931e..28ca9d9 100644 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansLocalClustererExample.java +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/clustering/KMeansLocalClustererExample.java @@ -22,10 +22,10 @@ import java.util.Comparator; import java.util.List; import org.apache.ignite.ml.clustering.KMeansLocalClusterer; import org.apache.ignite.ml.clustering.KMeansModel; -import org.apache.ignite.ml.math.DistanceMeasure; -import org.apache.ignite.ml.math.EuclideanDistance; import org.apache.ignite.ml.math.Tracer; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; +import org.apache.ignite.ml.math.distances.EuclideanDistance; import 
org.apache.ignite.ml.math.functions.Functions; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/FuzzyCMeansModelFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/FuzzyCMeansModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/FuzzyCMeansModelFormat.java index 2b27e86..cc3d9b3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/FuzzyCMeansModelFormat.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/FuzzyCMeansModelFormat.java @@ -19,8 +19,8 @@ package org.apache.ignite.ml; import java.io.Serializable; import java.util.Arrays; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; /** Fuzzy C-Means model representation. */ public class FuzzyCMeansModelFormat implements Serializable { http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/KMeansModelFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/KMeansModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/KMeansModelFormat.java index 4f5b143..c013198 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/KMeansModelFormat.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/KMeansModelFormat.java @@ -19,8 +19,8 @@ package org.apache.ignite.ml; import java.io.Serializable; import java.util.Arrays; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; /** * K-means model representation. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseFuzzyCMeansClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseFuzzyCMeansClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseFuzzyCMeansClusterer.java index 65aaeee..2b2febf 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseFuzzyCMeansClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseFuzzyCMeansClusterer.java @@ -17,9 +17,9 @@ package org.apache.ignite.ml.clustering; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.exceptions.ConvergenceException; import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseKMeansClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseKMeansClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseKMeansClusterer.java index 570ea7e..521437c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseKMeansClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/BaseKMeansClusterer.java @@ -19,9 +19,9 @@ package org.apache.ignite.ml.clustering; import java.util.List; import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; import 
org.apache.ignite.ml.math.exceptions.ConvergenceException; import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansDistributedClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansDistributedClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansDistributedClusterer.java index a5cd871..8823c10 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansDistributedClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansDistributedClusterer.java @@ -26,9 +26,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import javax.cache.Cache; import org.apache.ignite.internal.util.GridArgumentCheck; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.distributed.CacheUtils; import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey; import org.apache.ignite.ml.math.exceptions.ConvergenceException; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansLocalClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansLocalClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansLocalClusterer.java index 1724da3..a1b6d3f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansLocalClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansLocalClusterer.java 
@@ -22,9 +22,9 @@ import java.util.Collections; import java.util.List; import java.util.Random; import org.apache.ignite.internal.util.GridArgumentCheck; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.exceptions.ConvergenceException; import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansModel.java index 41267b9..83fbf1f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/FuzzyCMeansModel.java @@ -21,8 +21,8 @@ import java.util.Arrays; import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; import org.apache.ignite.ml.FuzzyCMeansModelFormat; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; /** This class incapsulates result of clusterization. 
*/ public class FuzzyCMeansModel implements ClusterizationModel<Vector, Integer>, Exportable<FuzzyCMeansModelFormat> { http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java index 24938bc..5595b4c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansDistributedClusterer.java @@ -26,9 +26,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import javax.cache.Cache; import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.distributed.CacheUtils; import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey; import org.apache.ignite.ml.math.exceptions.ConvergenceException; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansLocalClusterer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansLocalClusterer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansLocalClusterer.java index 3d005b4..8a50e65 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansLocalClusterer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansLocalClusterer.java @@ -23,10 +23,10 @@ import java.util.Collections; import java.util.List; 
import java.util.Random; import org.apache.ignite.internal.util.GridArgumentCheck; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; import org.apache.ignite.ml.math.VectorUtils; +import org.apache.ignite.ml.math.distances.DistanceMeasure; import org.apache.ignite.ml.math.exceptions.ConvergenceException; import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansModel.java index c449b8b..381f976 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/KMeansModel.java @@ -21,8 +21,8 @@ import java.util.Arrays; import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; import org.apache.ignite.ml.KMeansModelFormat; -import org.apache.ignite.ml.math.DistanceMeasure; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; /** * This class encapsulates result of clusterization. 
http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java new file mode 100644 index 0000000..44955c8 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.knn.models; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; +import org.apache.ignite.ml.math.exceptions.knn.SmallTrainingDatasetSizeException; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledVector; +import org.jetbrains.annotations.NotNull; + +/** + * kNN algorithm is a classification algorithm. + */ +public class KNNModel implements Model<Vector, Double>, Exportable<KNNModelFormat> { + /** Amount of nearest neighbors. */ + protected final int k; + + /** Distance measure. */ + protected final DistanceMeasure distanceMeasure; + + /** Training dataset. */ + protected final LabeledDataset training; + + /** kNN strategy. */ + protected final KNNStrategy stgy; + + /** Cached distances for k-nearest neighbors. */ + protected double[] cachedDistances; + + /** + * Creates the kNN model with the given parameters. + * + * @param k Amount of nearest neighbors. + * @param distanceMeasure Distance measure. + * @param stgy Strategy of calculations. + * @param training Training dataset. 
+ */ + public KNNModel(int k, DistanceMeasure distanceMeasure, KNNStrategy stgy, LabeledDataset training) { + assert training != null; + + if (training.rowSize() < k) + throw new SmallTrainingDatasetSizeException(k, training.rowSize()); + + this.k = k; + this.distanceMeasure = distanceMeasure; + this.training = training; + this.stgy = stgy; + } + + /** {@inheritDoc} */ + @Override public Double predict(Vector v) { + LabeledVector[] neighbors = findKNearestNeighbors(v, true); + + return classify(neighbors, v, stgy); + } + + /** */ + @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) { + KNNModelFormat mdlData = new KNNModelFormat(k, distanceMeasure, training, stgy); + + exporter.save(mdlData, path); + } + + /** + * The main idea is calculation all distance pairs between given vector and all vectors in training set, sorting + * them and finding k vectors with min distance with the given vector. + * + * @param v The given vector. + * @return K-nearest neighbors. + */ + protected LabeledVector[] findKNearestNeighbors(Vector v, boolean isCashedDistance) { + LabeledVector[] trainingData = training.data(); + + TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, trainingData); + + return getKClosestVectors(trainingData, distanceIdxPairs, isCashedDistance); + } + + /** + * Iterates along entries in distance map and fill the resulting k-element array. + * + * @param trainingData The training data. + * @param distanceIdxPairs The distance map. + * @param isCashedDistances Cache distances if true. + * @return K-nearest neighbors. 
+ */ + @NotNull private LabeledVector[] getKClosestVectors(LabeledVector[] trainingData, + TreeMap<Double, Set<Integer>> distanceIdxPairs, boolean isCashedDistances) { + LabeledVector[] res = new LabeledVector[k]; + int i = 0; + final Iterator<Double> iter = distanceIdxPairs.keySet().iterator(); + while (i < k) { + double key = iter.next(); + Set<Integer> idxs = distanceIdxPairs.get(key); + for (Integer idx : idxs) { + res[i] = trainingData[idx]; + if (isCashedDistances) { + if (cachedDistances == null) + cachedDistances = new double[k]; + cachedDistances[i] = key; + } + i++; + if (i >= k) + break; // go to next while-loop iteration + } + } + return res; + } + + /** + * Computes distances between given vector and each vector in training dataset. + * + * @param v The given vector. + * @param trainingData The training dataset. + * @return Key - distanceMeasure from given features before features with idx stored in value. Value is presented + * with Set because there can be a few vectors with the same distance. 
+ */ + @NotNull private TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledVector[] trainingData) { + TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>(); + + for (int i = 0; i < trainingData.length; i++) { + + LabeledVector labeledVector = trainingData[i]; + if (labeledVector != null) { + double distance = distanceMeasure.compute(v, labeledVector.features()); + putDistanceIdxPair(distanceIdxPairs, i, distance); + } + } + return distanceIdxPairs; + } + + /** */ + private void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) { + if (distanceIdxPairs.containsKey(distance)) { + Set<Integer> idxs = distanceIdxPairs.get(distance); + idxs.add(i); + } + else { + Set<Integer> idxs = new HashSet<>(); + idxs.add(i); + distanceIdxPairs.put(distance, idxs); + } + } + + /** */ + private double classify(LabeledVector[] neighbors, Vector v, KNNStrategy stgy) { + Map<Double, Double> clsVotes = new HashMap<>(); + + for (int i = 0; i < neighbors.length; i++) { + LabeledVector neighbor = neighbors[i]; + double clsLb = (double)neighbor.label(); + + double distance = cachedDistances != null ? 
cachedDistances[i] : distanceMeasure.compute(v, neighbor.features()); + + if (clsVotes.containsKey(clsLb)) { + double clsVote = clsVotes.get(clsLb); + clsVote += getClassVoteForVector(stgy, distance); + clsVotes.put(clsLb, clsVote); + } + else { + final double val = getClassVoteForVector(stgy, distance); + clsVotes.put(clsLb, val); + } + } + return getClassWithMaxVotes(clsVotes); + } + + /** */ + private double getClassWithMaxVotes(Map<Double, Double> clsVotes) { + return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey(); + } + + /** */ + private double getClassVoteForVector(KNNStrategy stgy, double distance) { + if (stgy.equals(KNNStrategy.WEIGHTED)) + return 1 / distance; // strategy.WEIGHTED + else + return 1.0; // strategy.SIMPLE + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = 1; + + res = res * 37 + k; + res = res * 37 + distanceMeasure.hashCode(); + res = res * 37 + stgy.hashCode(); + res = res * 37 + Arrays.hashCode(training.data()); + + return res; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + + if (obj == null || getClass() != obj.getClass()) + return false; + + KNNModel that = (KNNModel)obj; + + return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy) + && Arrays.deepEquals(training.data(), that.training.data()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java new file mode 100644 index 0000000..17d9842 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn.models; + +import java.io.Serializable; +import java.util.Arrays; +import org.apache.ignite.ml.math.distances.DistanceMeasure; +import org.apache.ignite.ml.structures.LabeledDataset; + +/** */ +public class KNNModelFormat implements Serializable { + /** Amount of nearest neighbors. */ + private int k; + + /** Distance measure. */ + private DistanceMeasure distanceMeasure; + + /** Training dataset */ + private LabeledDataset training; + + /** kNN strategy. 
*/ + private KNNStrategy stgy; + + /** */ + public int getK() { + return k; + } + + /** */ + public DistanceMeasure getDistanceMeasure() { + return distanceMeasure; + } + + /** */ + public LabeledDataset getTraining() { + return training; + } + + /** */ + public KNNStrategy getStgy() { + return stgy; + } + + /** */ + public KNNModelFormat(int k, DistanceMeasure measure, LabeledDataset training, KNNStrategy stgy) { + this.k = k; + this.distanceMeasure = measure; + this.training = training; + this.stgy = stgy; + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = 1; + + res = res * 37 + k; + res = res * 37 + distanceMeasure.hashCode(); + res = res * 37 + stgy.hashCode(); + res = res * 37 + Arrays.hashCode(training.data()); + + return res; + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + + if (obj == null || getClass() != obj.getClass()) + return false; + + KNNModelFormat that = (KNNModelFormat)obj; + + return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy) + && Arrays.deepEquals(training.data(), that.training.data()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java new file mode 100644 index 0000000..d524773 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn.models; + +/** This enum contains settings for kNN algorithm. */ +public enum KNNStrategy { + /** */ + SIMPLE, + + /** */ + WEIGHTED +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java new file mode 100644 index 0000000..aa4b291 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/Normalization.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.knn.models; + +/** Names of different normalization approaches. */ +public enum Normalization { + /** Min-max normalization: rescales values of feature X into the [0, 1] range. + * + * x' = (x - MIN[X]) / (MAX[X] - MIN[X]) + */ + MINIMAX, + /** Z-score normalization (standardization), where M[X] is the mean and sigma[X] the standard deviation of X. + * + * x' = (x - M[X]) / sigma[X] + */ + Z_NORMALIZATION +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java new file mode 100644 index 0000000..7b6e678 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains main APIs for kNN classification algorithms.
+ */ +package org.apache.ignite.ml.knn.models; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/package-info.java new file mode 100644 index 0000000..0854015 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains main APIs for kNN algorithms. 
+ */ +package org.apache.ignite.ml.knn; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java new file mode 100644 index 0000000..2db8a9f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.ignite.ml.knn.regression; + +import org.apache.ignite.ml.knn.models.KNNModel; +import org.apache.ignite.ml.knn.models.KNNStrategy; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.distances.DistanceMeasure; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.structures.LabeledDataset; +import org.apache.ignite.ml.structures.LabeledVector; + +/** + * This class provides kNN Multiple Linear Regression or Locally [weighted] regression (Simple and Weighted versions). + * + * <p> This is an instance-based learning method. </p> + * + * <ul> + * <li>Local means using nearby points (i.e. a nearest neighbors approach).</li> + * <li>Weighted means closer points get larger weights (inverse-distance weighting).</li> + * <li>Regression means approximating a function.</li> + * </ul> + */ +public class KNNMultipleLinearRegression extends KNNModel { + /** Creates a kNN regression model; all arguments are passed unchanged to {@link KNNModel}. */ + public KNNMultipleLinearRegression(int k, DistanceMeasure distanceMeasure, KNNStrategy stgy, + LabeledDataset training) { + super(k, distanceMeasure, stgy, training); + } + + /** {@inheritDoc} */ + @Override public Double predict(Vector v) { + LabeledVector[] neighbors = findKNearestNeighbors(v, true); + + return predictYBasedOn(neighbors, v); + } + + /** Dispatches to the regression flavor selected by the model's {@link KNNStrategy}. */ + private double predictYBasedOn(LabeledVector[] neighbors, Vector v) { + switch (stgy) { + case SIMPLE: + return simpleRegression(neighbors); + case WEIGHTED: + return weightedRegression(neighbors, v); + default: + throw new UnsupportedOperationException("Strategy " + stgy.name() + " is not supported"); + } + } +
+ /** Inverse-distance weighted mean of the neighbors' labels; a neighbor at zero distance (exact match) would get an infinite weight, so its label is returned directly. */ + private double weightedRegression(LabeledVector<Vector, Double>[] neighbors, Vector v) { + double sum = 0.0, div = 0.0; + for (int i = 0; i < neighbors.length; i++) { + double distance = cachedDistances != null ? cachedDistances[i] : distanceMeasure.compute(v, neighbors[i].features()); + if (distance == 0.0) return neighbors[i].label(); + sum += neighbors[i].label() / distance; + div += 1.0 / distance; + } + return sum / div; + } + + /** Unweighted mean of the neighbors' labels (divides by the actual neighbor count rather than by k). */ + private double simpleRegression(LabeledVector<Vector, Double>[] neighbors) { + double sum = 0.0; + for (LabeledVector<Vector, Double> neighbor : neighbors) + sum += neighbor.label(); + return sum / neighbors.length; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java new file mode 100644 index 0000000..30023a1 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.ignite.ml.knn.regression; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/DistanceMeasure.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/DistanceMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/DistanceMeasure.java deleted file mode 100644 index df235a7..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/DistanceMeasure.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.ignite.ml.math; - -import java.io.Externalizable; -import org.apache.ignite.ml.math.exceptions.CardinalityException; - -/** - * This class is based on the corresponding class from Apache Common Math lib. - * Interface for distance measures of n-dimensional vectors. - */ -public interface DistanceMeasure extends Externalizable { - /** - * Compute the distance between two n-dimensional vectors. - * <p> - * The two vectors are required to have the same dimension. 
- * - * @param a the first vector - * @param b the second vector - * @return the distance between the two vectors - * @throws CardinalityException if the array lengths differ. - */ - public double compute(Vector a, Vector b) throws CardinalityException; -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/EuclideanDistance.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/EuclideanDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/EuclideanDistance.java deleted file mode 100644 index 5d5a64e..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/EuclideanDistance.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.ignite.ml.math; - -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import org.apache.ignite.ml.math.exceptions.CardinalityException; -import org.apache.ignite.ml.math.util.MatrixUtil; - -/** - * Calculates the L<sub>2</sub> (Euclidean) distance between two points. 
- */ -public class EuclideanDistance implements DistanceMeasure { - /** Serializable version identifier. */ - private static final long serialVersionUID = 1717556319784040040L; - - /** {@inheritDoc} */ - @Override public double compute(Vector a, Vector b) - throws CardinalityException { - return MatrixUtil.localCopyOf(a).minus(b).kNorm(2.0); - } - - /** {@inheritDoc} */ - @Override public void writeExternal(ObjectOutput out) throws IOException { - // No-op - } - - /** {@inheritDoc} */ - @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - // No-op - } - - /** {@inheritDoc} */ - @Override public boolean equals(Object obj) { - if (this == obj) - return true; - - if (obj == null || getClass() != obj.getClass()) - return false; - - return true; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java new file mode 100644 index 0000000..3fa2ec7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.io.Externalizable; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.CardinalityException; + +/** + * This interface is based on the corresponding interface from the Apache Commons Math library. + * Interface for distance measures of n-dimensional vectors. + */ +public interface DistanceMeasure extends Externalizable { + /** + * Computes the distance between two n-dimensional vectors. + * <p> + * The two vectors are required to have the same dimension. + * + * @param a The first vector. + * @param b The second vector. + * @return The distance between the two vectors. + * @throws CardinalityException If the vector sizes differ. + */ + public double compute(Vector a, Vector b) throws CardinalityException; +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java new file mode 100644 index 0000000..a0c95d2 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the L<sub>2</sub> (Euclidean) distance between two points. + */ +public class EuclideanDistance implements DistanceMeasure { + /** Serializable version identifier.
*/ + private static final long serialVersionUID = 1717556319784040040L; + + /** {@inheritDoc} */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + return MatrixUtil.localCopyOf(a).minus(b).kNorm(2.0); + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + // No-op + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // No-op + } + + /** {@inheritDoc} The measure is stateless, so all instances of the same class are equal. */ + @Override public boolean equals(Object obj) { + return this == obj || (obj != null && getClass() == obj.getClass()); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + // Must agree with equals(), which compares only the runtime class. + return getClass().hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java new file mode 100644 index 0000000..dec2d73 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.functions.Functions; +import org.apache.ignite.ml.math.functions.IgniteDoubleFunction; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the Hamming distance between two points, i.e. the number of coordinates in which they differ. + */ +public class HammingDistance implements DistanceMeasure { + /** Serializable version identifier. */ + private static final long serialVersionUID = 1771556549784040098L; + + /** {@inheritDoc} */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + IgniteDoubleFunction<Double> fun = (value -> { + if (value == 0) return 0.0; + else return 1.0; + }); + return MatrixUtil.localCopyOf(a).minus(b).foldMap(Functions.PLUS, fun, 0d); + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + // No-op + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // No-op + } + + /** {@inheritDoc} The measure is stateless, so all instances of the same class are equal. */ + @Override public boolean equals(Object obj) { + return this == obj || (obj != null && getClass() == obj.getClass()); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + // Must agree with equals(), which compares only the runtime class. + return getClass().hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java new file mode 100644 index 0000000..66394f1 --- /dev/null +++
b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.ml.math.distances; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.util.MatrixUtil; + +/** + * Calculates the L<sub>1</sub> (sum of absolute differences) distance between two points. + */ +public class ManhattanDistance implements DistanceMeasure { + /** Serializable version identifier.
*/ + private static final long serialVersionUID = 8989556319784040040L; + + /** {@inheritDoc} */ + @Override public double compute(Vector a, Vector b) + throws CardinalityException { + return MatrixUtil.localCopyOf(a).minus(b).kNorm(1.0); + } + + /** {@inheritDoc} */ + @Override public void writeExternal(ObjectOutput out) throws IOException { + // No-op + } + + /** {@inheritDoc} */ + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // No-op + } + + /** {@inheritDoc} The measure is stateless, so all instances of the same class are equal. */ + @Override public boolean equals(Object obj) { + return this == obj || (obj != null && getClass() == obj.getClass()); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + // Must agree with equals(), which compares only the runtime class. + return getClass().hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/package-info.java new file mode 100644 index 0000000..9d799b7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains main APIs for distances. + */ +package org.apache.ignite.ml.math.distances; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java index 9ca167c..3256f8a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java @@ -374,7 +374,7 @@ public class CacheUtils { else if (key instanceof VectorBlockKey) return ((VectorBlockKey)key).dataStructureId().equals(matrixUuid); else - throw new UnsupportedOperationException(); // TODO: handle my poor doubles + throw new UnsupportedOperationException(); }; } http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java index cbd5208..3669d19 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/keys/impl/SparseMatrixKey.java @@ -28,6 +28,7 @@ import org.apache.ignite.internal.util.typedef.internal.S; import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey; import 
org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; + /** * Key implementation for {@link SparseDistributedMatrix}. */ http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/EmptyFileException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/EmptyFileException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/EmptyFileException.java new file mode 100644 index 0000000..065776a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/EmptyFileException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions.knn; + +import org.apache.ignite.IgniteException; + +/** + * Indicates that a file has no content. + */ +public class EmptyFileException extends IgniteException { + /** */ + private static final long serialVersionUID = 0L; + + /** + * Creates new exception. + * + * @param filename Name of the file without content.
+ */ + public EmptyFileException(String filename) { + super("Empty file with filename " + filename); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/FileParsingException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/FileParsingException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/FileParsingException.java new file mode 100644 index 0000000..12c8fe3 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/FileParsingException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions.knn; + +import java.nio.file.Path; +import org.apache.ignite.IgniteException; + +/** + * Indicates that data in a specific row of the given file cannot be parsed. + */ +public class FileParsingException extends IgniteException { + /** */ + private static final long serialVersionUID = 0L; + + /** + * Creates new exception. + * @param parsedData Data that could not be parsed. + * @param rowIdx Index of row in file.
+ * @param file File path. + */ + public FileParsingException(String parsedData, int rowIdx, Path file) { + super("Data " + parsedData + " in row # " + rowIdx + " in file " + file + " can not be parsed to appropriate format"); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/NoLabelVectorException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/NoLabelVectorException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/NoLabelVectorException.java new file mode 100644 index 0000000..7815e0f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/NoLabelVectorException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions.knn; + +import org.apache.ignite.IgniteException; + +/** + * Indicates that a labeled dataset has no labeled vector at the requested index. + */ +public class NoLabelVectorException extends IgniteException { + /** */ + private static final long serialVersionUID = 0L; + + /** + * Creates new exception.
+ * + * @param idx Index of the missing labeled vector. + */ + public NoLabelVectorException(int idx) { + super("No vector in position " + idx); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/SmallTrainingDatasetSizeException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/SmallTrainingDatasetSizeException.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/SmallTrainingDatasetSizeException.java new file mode 100644 index 0000000..5eb3f7a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/SmallTrainingDatasetSizeException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.math.exceptions.knn; + +import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; + +/** + * Indicates that a training dataset is too small for an ML algorithm.
+ */ +public class SmallTrainingDatasetSizeException extends MathIllegalArgumentException { + /** */ + private static final long serialVersionUID = 0L; + + /** + * Creates new small training dataset size exception. + * + * @param exp Expected dataset size. + * @param act Actual dataset size. + */ + public SmallTrainingDatasetSizeException(int exp, int act) { + super("Small training dataset size [expected=%d, actual=%d]", exp, act); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/package-info.java new file mode 100644 index 0000000..e55b7b9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/exceptions/knn/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains exceptions for kNN algorithms.
+ */ +package org.apache.ignite.ml.math.exceptions.knn; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java new file mode 100644 index 0000000..81f7607 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.structures; + +import java.io.IOException; +import java.io.Serializable; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; +import org.apache.ignite.ml.knn.models.Normalization; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.exceptions.CardinalityException; +import org.apache.ignite.ml.math.exceptions.NoDataException; +import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException; +import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException; +import org.apache.ignite.ml.math.exceptions.knn.FileParsingException; +import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException; +import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; +import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector; +import org.jetbrains.annotations.NotNull; + +/** + * Class for set of labeled vectors. + */ +public class LabeledDataset implements Serializable { + /** Data to keep. */ + private final LabeledVector[] data; + + /** Feature names (one name for each attribute in vector). */ + private String[] featureNames; + + /** Amount of instances. */ + private int rowSize; + + /** Amount of attributes in each vector. */ + private int colSize; + + /** + * Creates new Labeled Dataset by given data. + * + * @param data Should be initialized with one vector at least. + * @param colSize Amount of observed attributes in each vector. + */ + public LabeledDataset(LabeledVector[] data, int colSize) { + this(data, null, colSize); + } + + /** + * Creates new Labeled Dataset by given data. + * + * @param data Given data. Should be initialized with one vector at least. + * @param featureNames Column names; generated as "f_0", "f_1", ... if null. + * @param colSize Amount of observed attributes in each vector.
+ */ + public LabeledDataset(LabeledVector[] data, String[] featureNames, int colSize) { + assert data != null; + assert data.length > 0; + + this.data = data; + this.rowSize = data.length; + this.colSize = colSize; + + if(featureNames == null) generateFeatureNames(); + else { + assert colSize == featureNames.length; + this.featureNames = featureNames; + } + } + + /** + * Creates new Labeled Dataset and initialized with empty data structure. + * + * @param rowSize Amount of instances. Should be > 0. + * @param colSize Amount of attributes. Should be > 0. + * @param isDistributed Use distributed data structures to keep data. + */ + public LabeledDataset(int rowSize, int colSize, boolean isDistributed){ + this(rowSize, colSize, null, isDistributed); + } + + /** + * Creates new local Labeled Dataset and initialized with empty data structure. + * + * @param rowSize Amount of instances. Should be > 0. + * @param colSize Amount of attributes. Should be > 0. + */ + public LabeledDataset(int rowSize, int colSize){ + this(rowSize, colSize, null, false); + } + + /** + * Creates new Labeled Dataset and initialized with empty data structure. + * Every row gets a new feature vector of {@code colSize} elements and a null label. + * + * @param rowSize Amount of instances. Should be > 0. + * @param colSize Amount of attributes. Should be > 0. + * @param featureNames Column names; generated as "f_0", "f_1", ... if null. + * @param isDistributed Use distributed data structures to keep data. + */ + public LabeledDataset(int rowSize, int colSize, String[] featureNames, boolean isDistributed){ + assert rowSize > 0; + assert colSize > 0; + + // Sizes must be assigned before generateFeatureNames(), which reads the colSize field; + // otherwise a zero-length name array is generated and getFeatureName() fails. + this.rowSize = rowSize; + this.colSize = colSize; + + if(featureNames == null) generateFeatureNames(); + else { + assert colSize == featureNames.length; + this.featureNames = featureNames; + } + + data = new LabeledVector[rowSize]; + for (int i = 0; i < rowSize; i++) + data[i] = new LabeledVector(getVector(colSize, isDistributed), null); + } + + /** + * Creates new local Labeled Dataset by matrix and vector of labels.
+ * + * @param mtx Given matrix with rows as observations. + * @param lbs Labels of observations. + */ + public LabeledDataset(double[][] mtx, double[] lbs) { + this(mtx, lbs, null, false); + } + + /** + * Creates new Labeled Dataset by matrix and vector of labels. + * + * @param mtx Given matrix with rows as observations. + * @param lbs Labels of observations. + * @param featureNames Column names; generated as "f_0", "f_1", ... if null. + * @param isDistributed Use distributed data structures to keep data. + * @throws CardinalityException If the amount of rows differs from the amount of labels. + * @throws NoDataException If the matrix is empty, its first row is null or a row is shorter than the first one. + */ + public LabeledDataset(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) { + assert mtx != null; + assert lbs != null; + + if(mtx.length != lbs.length) + throw new CardinalityException(lbs.length, mtx.length); + + // An empty matrix is rejected explicitly: reading mtx[0] would throw ArrayIndexOutOfBoundsException. + if(mtx.length == 0 || mtx[0] == null) + throw new NoDataException("Pass filled array, the first vector is empty"); + + this.rowSize = lbs.length; + this.colSize = mtx[0].length; + + if(featureNames == null) generateFeatureNames(); + else { + assert colSize == featureNames.length; + this.featureNames = featureNames; + } + + data = new LabeledVector[rowSize]; + for (int i = 0; i < rowSize; i++){ + data[i] = new LabeledVector(getVector(colSize, isDistributed), lbs[i]); + + for (int j = 0; j < colSize; j++) { + try { + data[i].features().set(j, mtx[i][j]); + } catch (ArrayIndexOutOfBoundsException e) { + throw new NoDataException("No data in given matrix by coordinates (" + i + "," + j + ")"); + } + } + } + } + + /** Generates default feature names "f_0" ... "f_(colSize - 1)" based on the colSize field. */ + private void generateFeatureNames() { + featureNames = new String[colSize]; + + for (int i = 0; i < colSize; i++) + featureNames[i] = "f_" + i; + } + + /** + * Get vectors and their labels. + * + * @return Array of Label Vector instances. + */ + public LabeledVector[] data() { + return data; + } + + /** + * Gets amount of observation. + * + * @return Amount of rows in dataset. + */ + public int rowSize(){ + return rowSize; + } + + /** + * Returns feature name for column with given index. + * + * @param i The given index. + * @return Feature name.
+ */ + public String getFeatureName(int i){ + return featureNames[i]; + } + + /** + * Gets amount of attributes. + * + * @return Amount of attributes in each Labeled Vector. + */ + public int colSize(){ + return colSize; + } + + /** + * Retrieves Labeled Vector by given index. + * + * @param idx Index of observation. + * @return Labeled features. + */ + public LabeledVector getRow(int idx){ + return data[idx]; + } + + /** + * Get the features. + * + * @param idx Index of observation. + * @return Vector with features. + */ + public Vector features(int idx){ + assert idx < rowSize; + assert data != null; + assert data[idx] != null; + + return data[idx].features(); + } + + /** + * Returns label if label is attached or {@link Double#NaN} if the observation or its label is missed. + * + * @param idx Index of observation. + * @return Label. + */ + public double label(int idx) { + LabeledVector labeledVector = data[idx]; + + // Rows created by the empty-structure constructors carry a null label, which must not be unboxed. + if(labeledVector != null && labeledVector.label() != null) + return (double)labeledVector.label(); + else + return Double.NaN; + } + + /** + * Fill the label with given value. + * + * @param idx Index of observation. + * @param lb The given label. + * @throws NoLabelVectorException If there is no Labeled Vector at the given index. + */ + public void setLabel(int idx, double lb) { + LabeledVector labeledVector = data[idx]; + + if(labeledVector != null) + labeledVector.setLabel(lb); + else + throw new NoLabelVectorException(idx); + } + + /** + * Datafile should keep class labels in the first column. + * Each row is split by {@code separator}; the amount of features is taken from the first row. + * + * @param pathToFile Path to file. + * @param separator Element to tokenize row on separate tokens. + * @param isDistributed Generates distributed dataset if true. + * @param isFallOnBadData Fall on incorrect data if true. + * @return Labeled Dataset parsed from file. + * @throws IOException If the file can not be read. + * @throws EmptyFileException If the file contains no lines. + * @throws NoDataException If the first row contains no feature columns. + * @throws FileParsingException If {@code isFallOnBadData} is true and a value can not be parsed. + * @throws CardinalityException If a row has a different amount of tokens than the first row.
+ */ + public static LabeledDataset loadTxt(Path pathToFile, String separator, boolean isDistributed, boolean isFallOnBadData) throws IOException { + List<String> list = new ArrayList<>(); + + // Files.lines() keeps the file open until the stream is closed. + try (Stream<String> stream = Files.lines(pathToFile)) { + stream.forEach(list::add); + } + + final int rowSize = list.size(); + + List<Double> labels = new ArrayList<>(); + List<Vector> vectors = new ArrayList<>(); + + if (rowSize > 0) { + + final int colSize = getColumnSize(separator, list) - 1; + + if (colSize > 0) { + + for (int i = 0; i < rowSize; i++) { + Double clsLb; + + String[] rowData = list.get(i).split(separator); + + try { + clsLb = Double.parseDouble(rowData[0]); + Vector vec = parseFeatures(pathToFile, isDistributed, isFallOnBadData, colSize, i, rowData); + labels.add(clsLb); + vectors.add(vec); + } + catch (NumberFormatException e) { + if(isFallOnBadData) + throw new FileParsingException(rowData[0], i, pathToFile); + } + } + + LabeledVector[] data = new LabeledVector[vectors.size()]; + for (int i = 0; i < vectors.size(); i++) + data[i] = new LabeledVector(vectors.get(i), labels.get(i)); + + return new LabeledDataset(data, colSize); + } + else + throw new NoDataException("File should contain first row with data"); + } + else + throw new EmptyFileException(pathToFile.toString()); + } + + /** + * Parses the feature columns of a single row (all tokens after the label). + * + * @param pathToFile Path to the source file, used in error messages. + * @param isDistributed Creates a distributed vector if true. + * @param isFallOnBadData Throws on unparsable values if true, otherwise uses fillMissedData(). + * @param colSize Expected amount of features. + * @param rowIdx Index of the row in the file. + * @param rowData Tokens of the row; the label is at index 0. + * @return Vector with parsed features. + */ + @NotNull private static Vector parseFeatures(Path pathToFile, boolean isDistributed, boolean isFallOnBadData, + int colSize, int rowIdx, String[] rowData) { + // The row length does not change inside the loop, so it is validated once up front. + if (rowData.length != colSize + 1) + throw new CardinalityException(colSize + 1, rowData.length); + + final Vector vec = getVector(colSize, isDistributed); + + for (int j = 0; j < colSize; j++) { + double val = fillMissedData(); + + try { + val = Double.parseDouble(rowData[j + 1]); + } + catch (NumberFormatException e) { + if(isFallOnBadData) + throw new FileParsingException(rowData[j + 1], rowIdx, pathToFile); + } + + vec.set(j, val); + } + return vec; + } + + // TODO: IGNITE-7025 add filling with mean, mode, ignoring and so on
+ /** */ + private static double fillMissedData() { + return 0.0; + } + + /** */ + @NotNull private static Vector getVector(int size, boolean isDistributed) { + + if(isDistributed) return new SparseBlockDistributedVector(size); + else return new DenseLocalOnHeapVector(size); + } + + /** */ + private static int getColumnSize(String separator, List<String> list) { + String[] rowData = list.get(0).split(separator, -1); // assume that all observation has the same length as a first row + + return rowData.length; + } + + /** + * Scales features in dataset. + * + * @param normalization normalization approach + * @return Labeled dataset + */ + public LabeledDataset normalizeWith(Normalization normalization) { + switch (normalization){ + case MINIMAX: minMaxFeatures(); + break; + case Z_NORMALIZATION: throw new UnsupportedOperationException("Z-normalization is not supported yet"); + } + + return this; + } + + /** + * Rescales every feature column in place with x'=(x-MIN[X])/(MAX[X]-MIN[X]). + * Constant columns (MAX[X] == MIN[X]) are set to 0 instead of producing NaN. + * Makes two passes over all rowSize * colSize feature values. + */ + private void minMaxFeatures() { + double[] mins = new double[colSize]; + double[] maxs = new double[colSize]; + + for (int j = 0; j < colSize; j++) { + // Double.MIN_VALUE is the smallest positive double, not the most negative one, + // so it cannot seed a running maximum; the infinities are neutral seeds. + double maxInCurrCol = Double.NEGATIVE_INFINITY; + double minInCurrCol = Double.POSITIVE_INFINITY; + + for (int i = 0; i < rowSize; i++) { + double e = data[i].features().get(j); + maxInCurrCol = Math.max(e, maxInCurrCol); + minInCurrCol = Math.min(e, minInCurrCol); + } + + mins[j] = minInCurrCol; + maxs[j] = maxInCurrCol; + } + + for (int j = 0; j < colSize; j++) { + double div = maxs[j] - mins[j]; + + for (int i = 0; i < rowSize; i++) { + double oldVal = data[i].features().get(j); + // x'=(x-MIN[X])/(MAX[X]-MIN[X]); a constant column would otherwise give 0/0 = NaN. + double newVal = div == 0 ? 0.0 : (oldVal - mins[j])/div; + data[i].features().set(j, newVal); + } + } + } + + /** */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + LabeledDataset that = (LabeledDataset)o; + + return rowSize == that.rowSize && colSize ==
that.colSize && Arrays.equals(data, that.data) && Arrays.equals(featureNames, that.featureNames); + } + + /** */ + @Override public int hashCode() { + int res = Arrays.hashCode(data); + res = 31 * res + Arrays.hashCode(featureNames); + res = 31 * res + rowSize; + res = 31 * res + colSize; + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/8ba773bf/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java index 51b973a..a4e218b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.structures; +import java.io.Serializable; import org.apache.ignite.ml.math.Vector; /** @@ -25,12 +26,12 @@ import org.apache.ignite.ml.math.Vector; * @param <V> Some class extending {@link Vector}. * @param <T> Type of label. */ -public class LabeledVector<V extends Vector, T> { +public class LabeledVector<V extends Vector, T> implements Serializable { /** Vector. */ private final V vector; /** Label. */ - private final T lb; + private T lb; /** * Construct labeled vector. @@ -48,7 +49,7 @@ public class LabeledVector<V extends Vector, T> { * * @return Vector. */ - public V vector() { + public V features() { return vector; } @@ -60,4 +61,34 @@ public class LabeledVector<V extends Vector, T> { public T label() { return lb; } + + /** + * Set the label + * + * @param lb Label. + */ + public void setLabel(T lb) { + this.lb = lb; + } + + /** */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + LabeledVector vector1 = (LabeledVector)o; + + if (vector != null ?
!vector.equals(vector1.vector) : vector1.vector != null) + return false; + return lb != null ? lb.equals(vector1.lb) : vector1.lb == null; + } + + /** */ + @Override public int hashCode() { + int res = vector != null ? vector.hashCode() : 0; + res = 31 * res + (lb != null ? lb.hashCode() : 0); + return res; + } }
