This is an automated email from the ASF dual-hosted git repository. zaleslaw pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push: new 45024f07 IGNITE-9978: [ML] Implement Compound Naive Bayes classifier (#6567) 45024f07 is described below commit 45024f07c0ce8c05a0abbba512da57e2cc5be435 Author: Ravil Galeyev <deh...@yandex.ru> AuthorDate: Sun Sep 8 11:22:34 2019 +0200 IGNITE-9978: [ML] Implement Compound Naive Bayes classifier (#6567) --- .../ml/naivebayes/CompoundNaiveBayesExample.java | 95 ++++++++++++ .../apache/ignite/ml/naivebayes/BayesModel.java | 34 +++++ .../compound/CompoundNaiveBayesModel.java | 159 ++++++++++++++++++++ .../compound/CompoundNaiveBayesTrainer.java | 160 +++++++++++++++++++++ .../ml/naivebayes/compound/package-info.java | 22 +++ .../discrete/DiscreteNaiveBayesModel.java | 56 +++++--- .../discrete/DiscreteNaiveBayesSumsHolder.java | 2 +- .../discrete/DiscreteNaiveBayesTrainer.java | 2 +- .../gaussian/GaussianNaiveBayesModel.java | 55 ++++--- .../gaussian/GaussianNaiveBayesSumsHolder.java | 2 +- .../gaussian/GaussianNaiveBayesTrainer.java | 2 +- .../apache/ignite/ml/util/MLSandboxDatasets.java | 5 +- .../src/main/resources/datasets/mixed_dataset.csv | 8 ++ .../compound/CompoundNaiveBayesModelTest.java | 95 ++++++++++++ .../compound/CompoundNaiveBayesTest.java | 65 +++++++++ .../compound/CompoundNaiveBayesTrainerTest.java | 108 ++++++++++++++ .../apache/ignite/ml/naivebayes/compound/Data.java | 88 ++++++++++++ .../discrete/DiscreteNaiveBayesModelTest.java | 6 +- .../discrete/DiscreteNaiveBayesTrainerTest.java | 12 +- .../gaussian/GaussianNaiveBayesModelTest.java | 3 +- .../gaussian/GaussianNaiveBayesTrainerTest.java | 22 +-- 21 files changed, 932 insertions(+), 69 deletions(-) diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/CompoundNaiveBayesExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/CompoundNaiveBayesExample.java new file mode 100644 index 0000000..5c0aac2 --- /dev/null +++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/CompoundNaiveBayesExample.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.naivebayes; + +import java.io.FileNotFoundException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer; +import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesTrainer; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer; +import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.util.MLSandboxDatasets; +import org.apache.ignite.ml.util.SandboxMLCache; + +import static java.util.Arrays.asList; + +/** + * Run naive Compound Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier"> + * Nnaive Bayes 
classifier</a> algorithm ({@link GaussianNaiveBayesTrainer})and <a + * href=https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes"> Discrete naive Bayes + * classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) over distributed cache. + * <p> + * Code in this example launches Ignite grid and fills the cache with test data points. + * <p> + * After that it trains the naive Bayes classification model based on the specified data.</p> + * <p> + * Finally, this example loops over the test set of data points, applies the trained model to predict the target value, + * compares prediction to expected outcome (ground truth), and builds + * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p> + * <p> + * You can change the test data used in this example and re-run it to explore this algorithm further.</p> + */ +public class CompoundNaiveBayesExample { + public static void main(String[] args) throws FileNotFoundException { + System.out.println(); + System.out.println(">>> Compound Naive Bayes classification model over partitioned dataset usage example started."); + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.MIXED_DATASET); + + double[] priorProbabilities = new double[] {.5, .5}; + double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}}; + + System.out.println(">>> Create new naive Bayes classification trainer object."); + CompoundNaiveBayesTrainer trainer = new CompoundNaiveBayesTrainer() + .withPriorProbabilities(priorProbabilities) + .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer()) + .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7)) + .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer() + .setBucketThresholds(thresholds)) + .withDiscreteFeatureIdsToSkip(asList(0, 1, 2)); + System.out.println(">>> Perform the training to get the model."); + + Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>() + .labeled(Vectorizer.LabelCoordinate.FIRST); + + CompoundNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer); + + System.out.println(">>> Compound Naive Bayes model: " + mdl); + + double accuracy = Evaluator.evaluate( + dataCache, + mdl, + vectorizer + ).accuracy(); + + System.out.println("\n>>> Accuracy " + accuracy); + + System.out.println(">>> Compound Naive bayes model over partitioned dataset usage example completed."); + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/BayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/BayesModel.java new file mode 100644 index 0000000..d08b720 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/BayesModel.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes; + +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.IgniteModel; + +/** + * Interface for Bayes Models. + */ +public interface BayesModel<MODEL extends BayesModel, FEATURES, OUTPUT> + extends IgniteModel<FEATURES, OUTPUT>, Exportable<MODEL> { + + /** + * Returns an array where each index corresponds to a label, and each value is the {@code log(probability)} of the input belonging to that label. + * The prior probabilities are not included. + */ + double[] probabilityPowers(FEATURES vector); +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java new file mode 100644 index 0000000..5daebb7 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes.compound; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; + +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel; + +/** + * A compound Naive Bayes model which uses a composition of {@code GaussianNaiveBayesModel} and {@code + * DiscreteNaiveBayesModel}. + */ +public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<CompoundNaiveBayesModel>, Serializable { + /** Serial version uid. */ + private static final long serialVersionUID = -5045925321135798960L; + + /** Prior probabilities of each class. */ + private double[] priorProbabilities; + + /** Labels. */ + private double[] labels; + + /** Gaussian Bayes model. */ + private GaussianNaiveBayesModel gaussianModel; + + /** Feature ids which should be skipped in Gaussian model. */ + private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList(); + + /** Discrete Bayes model. 
*/ + private DiscreteNaiveBayesModel discreteModel; + + /** Feature ids which should be skipped in Discrete model. */ + private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList(); + + /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<CompoundNaiveBayesModel, P> exporter, P path) { + exporter.save(this, path); + } + + /** {@inheritDoc} */ + @Override public Double predict(Vector vector) { + double[] probapilityPowers = new double[priorProbabilities.length]; + for (int i = 0; i < priorProbabilities.length; i++) { + probapilityPowers[i] = Math.log(priorProbabilities[i]); + } + + if (discreteModel != null) { + probapilityPowers = sum(probapilityPowers, discreteModel.probabilityPowers(skipFeatures(vector, discreteFeatureIdsToSkip))); + } + + if (gaussianModel != null) { + probapilityPowers = sum(probapilityPowers, gaussianModel.probabilityPowers(skipFeatures(vector, gaussianFeatureIdsToSkip))); + } + + int maxLabelIndex = 0; + for (int i = 0; i < probapilityPowers.length; i++) { + if (probapilityPowers[i] > probapilityPowers[maxLabelIndex]) { + maxLabelIndex = i; + } + } + return labels[maxLabelIndex]; + } + + /** Returns a gaussian model. */ + public GaussianNaiveBayesModel getGaussianModel() { + return gaussianModel; + } + + /** Returns a discrete model. */ + public DiscreteNaiveBayesModel getDiscreteModel() { + return discreteModel; + } + + /** Sets prior probabilities. */ + public CompoundNaiveBayesModel wirhPriorProbabilities(double[] priorProbabilities) { + this.priorProbabilities = priorProbabilities.clone(); + return this; + } + + /** Sets labels. */ + public CompoundNaiveBayesModel withLabels(double[] labels) { + this.labels = labels.clone(); + return this; + } + + /** Sets a gaussian model. */ + public CompoundNaiveBayesModel withGaussianModel(GaussianNaiveBayesModel gaussianModel) { + this.gaussianModel = gaussianModel; + return this; + } + + /** Sets a discrete model. 
*/ + public CompoundNaiveBayesModel withDiscreteModel(DiscreteNaiveBayesModel discreteModel) { + this.discreteModel = discreteModel; + return this; + } + + /** Sets feature ids to skip in Gaussian Bayes. */ + public CompoundNaiveBayesModel withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) { + this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip; + return this; + } + + /** Sets feature ids to skip in discrete Bayes. */ + public CompoundNaiveBayesModel withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) { + this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip; + return this; + } + + /** Returns index by index sum of two arrays. */ + private static double[] sum(double[] arr1, double[] arr2) { + assert arr1.length == arr2.length; + + double[] result = new double[arr1.length]; + + for (int i = 0; i < arr1.length; i++) { + result[i] = arr1[i] + arr2[i]; + } + return result; + } + + /** Returns a new (shorter) vector without features provided in {@param featureIdsToSkip}. */ + private static Vector skipFeatures(Vector vector, Collection<Integer> featureIdsToSkip) { + int newSize = vector.size() - featureIdsToSkip.size(); + double[] newFeaturesValues = new double[newSize]; + + int index = 0; + for (int j = 0; j < vector.size(); j++) { + if(featureIdsToSkip.contains(j)) continue; + + newFeaturesValues[index] = vector.get(j); + ++index; + } + return VectorUtils.of(newFeaturesValues); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.java new file mode 100644 index 0000000..9b8cd46 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes.compound; + +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer; +import org.apache.ignite.ml.preprocessing.Preprocessor; +import org.apache.ignite.ml.structures.LabeledVector; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; + +import java.util.Collection; +import java.util.Collections; + +/** + * Trainer for the compound Naive Bayes classifier model. It uses a model composition of {@code + * GaussianNaiveBayesTrainer} and {@code DiscreteNaiveBayesTrainer}. To determine which features are used by which trainer, + * each trainer has a collection of feature ids which should be skipped. These collections are set with the + * {@code withGaussianFeatureIdsToSkip()} and {@code withDiscreteFeatureIdsToSkip()} methods. 
+ */ +public class CompoundNaiveBayesTrainer extends SingleLabelDatasetTrainer<CompoundNaiveBayesModel> { + + /** Prior probabilities of each class. */ + private double[] priorProbabilities; + + /** Gaussian Naive Bayes trainer. */ + private GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer; + + /** Feature ids which should be skipped in Gaussian model. */ + private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList(); + + /** Discrete Naive Bayes trainer. */ + private DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer; + + /** Feature ids which should be skipped in Discrete model. */ + private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList(); + + /** {@inheritDoc} */ + @Override public <K, V> CompoundNaiveBayesModel fit(DatasetBuilder<K, V> datasetBuilder, + Preprocessor<K, V> extractor) { + return updateModel(null, datasetBuilder, extractor); + } + + /** {@inheritDoc} */ + @Override public boolean isUpdateable(CompoundNaiveBayesModel mdl) { + return gaussianNaiveBayesTrainer.isUpdateable(mdl.getGaussianModel()) + && discreteNaiveBayesTrainer.isUpdateable(mdl.getDiscreteModel()); + } + + /** {@inheritDoc} */ + @Override public CompoundNaiveBayesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (CompoundNaiveBayesTrainer)super.withEnvironmentBuilder(envBuilder); + } + + /** {@inheritDoc} */ + @Override protected <K, V> CompoundNaiveBayesModel updateModel(CompoundNaiveBayesModel mdl, + DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) { + + CompoundNaiveBayesModel compoundModel = new CompoundNaiveBayesModel() + .wirhPriorProbabilities(priorProbabilities); + + if (gaussianNaiveBayesTrainer != null) { + if (priorProbabilities != null) { + gaussianNaiveBayesTrainer.setPriorProbabilities(priorProbabilities); + } + GaussianNaiveBayesModel model = (mdl == null) + ? 
gaussianNaiveBayesTrainer.fit(datasetBuilder, extractor.map(skipFeatures(gaussianFeatureIdsToSkip))) + : gaussianNaiveBayesTrainer.update(mdl.getGaussianModel(), datasetBuilder, extractor.map(skipFeatures(gaussianFeatureIdsToSkip))); + + compoundModel.withGaussianModel(model) + .withGaussianFeatureIdsToSkip(gaussianFeatureIdsToSkip) + .withLabels(model.getLabels()) + .wirhPriorProbabilities(priorProbabilities); + } + + if (discreteNaiveBayesTrainer != null) { + if (priorProbabilities != null) { + discreteNaiveBayesTrainer.setPriorProbabilities(priorProbabilities); + } + DiscreteNaiveBayesModel model = (mdl == null) + ? discreteNaiveBayesTrainer.fit(datasetBuilder, extractor.map(skipFeatures(discreteFeatureIdsToSkip))) + : discreteNaiveBayesTrainer.update(mdl.getDiscreteModel(), datasetBuilder, extractor.map(skipFeatures(discreteFeatureIdsToSkip))); + + compoundModel.withDiscreteModel(model) + .withDiscreteFeatureIdsToSkip(discreteFeatureIdsToSkip) + .withLabels(model.getLabels()) + .wirhPriorProbabilities(priorProbabilities); + } + + return compoundModel; + } + + /** Sets prior probabilities. */ + public CompoundNaiveBayesTrainer withPriorProbabilities(double[] priorProbabilities) { + this.priorProbabilities = priorProbabilities.clone(); + return this; + } + + /** Sets a gaussian trainer. */ + public CompoundNaiveBayesTrainer withGaussianNaiveBayesTrainer(GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer) { + this.gaussianNaiveBayesTrainer = gaussianNaiveBayesTrainer; + return this; + } + + /** Sets a discrete trainer. */ + public CompoundNaiveBayesTrainer withDiscreteNaiveBayesTrainer(DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer) { + this.discreteNaiveBayesTrainer = discreteNaiveBayesTrainer; + return this; + } + + /** Sets feature ids to skip in Gaussian Bayes. 
*/ + public CompoundNaiveBayesTrainer withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) { + this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip; + return this; + } + + /** Sets feature ids to skip in discrete Bayes. */ + public CompoundNaiveBayesTrainer withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) { + this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip; + return this; + } + + /** Removes features provided in {@param featureIdsToSkip} from a vector. */ + private static IgniteFunction<LabeledVector<Object>, LabeledVector<Object>> skipFeatures(Collection<Integer> featureIdsToSkip) { + return featureValues -> { + final int size = featureValues.features().size(); + int newSize = size - featureIdsToSkip.size(); + + double[] newFeaturesValues = new double[newSize]; + int index = 0; + for (int j = 0; j < size; j++) { + if(featureIdsToSkip.contains(j)) continue; + + newFeaturesValues[index] = featureValues.get(j); + ++index; + } + return new LabeledVector<>(VectorUtils.of(newFeaturesValues), featureValues.label()); + }; + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/package-info.java new file mode 100644 index 0000000..0806c48 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * <!-- Package description. --> + * Contains Compound naive Bayes classifier. + */ +package org.apache.ignite.ml.naivebayes.compound; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java index 3c35841..44d0767 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java @@ -17,20 +17,17 @@ package org.apache.ignite.ml.naivebayes.discrete; -import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.math.primitives.vector.Vector; - -import java.io.Serializable; +import org.apache.ignite.ml.naivebayes.BayesModel; /** * Discrete naive Bayes model which predicts result value {@code y} belongs to a class {@code C_k, k in [0..K]} as * {@code p(C_k,y) =x_1*p_k1^x *...*x_i*p_ki^x_i}. Where {@code x_i} is a discrete feature, {@code p_ki} is a prior * probability probability of class {@code p(x|C_k)}. Returns the number of the most possible class. */ -public class DiscreteNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<DiscreteNaiveBayesModel>, Serializable { - /** */ +public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesModel, Vector, Double> { + /** Serial version uid. 
*/ private static final long serialVersionUID = -127386523291350345L; /** @@ -58,9 +55,9 @@ public class DiscreteNaiveBayesModel implements IgniteModel<Vector, Double>, Exp /** * @param probabilities Probabilities of features for classes. * @param clsProbabilities Prior probabilities for classes. + * @param labels Labels. * @param bucketThresholds The threshold to convert a feature to a binary value. * @param sumsHolder Amount values which are abouve the threshold per label. - * @param labels Labels. */ public DiscreteNaiveBayesModel(double[][][] probabilities, double[] clsProbabilities, double[] labels, double[][] bucketThresholds, DiscreteNaiveBayesSumsHolder sumsHolder) { @@ -81,42 +78,57 @@ public class DiscreteNaiveBayesModel implements IgniteModel<Vector, Double>, Exp * @return A label with max probability. */ @Override public Double predict(Vector vector) { - double maxProbapilityPower = -Double.MAX_VALUE; - int maxLabelIndex = -1; + double[] probapilityPowers = probabilityPowers(vector); - for (int i = 0; i < clsProbabilities.length; i++) { - double probabilityPower = Math.log(clsProbabilities[i]); + int maxLabelIndex = 0; + for (int i = 0; i < probapilityPowers.length; i++) { + probapilityPowers[i] += Math.log(clsProbabilities[i]); + + if (probapilityPowers[i] > probapilityPowers[maxLabelIndex]) { + maxLabelIndex = i; + } + } + + return labels[maxLabelIndex]; + } + /** {@inheritDoc} */ + @Override + public double[] probabilityPowers(Vector vector) { + double[] probapilityPowers = new double[clsProbabilities.length]; + + for (int i = 0; i < clsProbabilities.length; i++) { for (int j = 0; j < probabilities[0].length; j++) { int x = toBucketNumber(vector.get(j), bucketThresholds[j]); double p = probabilities[i][j][x]; - probabilityPower += (p > 0 ? Math.log(p) : .0); - } - - if (probabilityPower > maxProbapilityPower) { - maxLabelIndex = i; - maxProbapilityPower = probabilityPower; + probapilityPowers[i] += (p > 0 ? 
Math.log(p) : .0); } } - return labels[maxLabelIndex]; + + return probapilityPowers; } - /** */ + /** A getter for probabilities.*/ public double[][][] getProbabilities() { return probabilities; } - /** */ + /** A getter for clsProbabilities.*/ public double[] getClsProbabilities() { return clsProbabilities; } - /** */ + /** A getter for bucketThresholds.*/ public double[][] getBucketThresholds() { return bucketThresholds; } - /** */ + /** A getter for labels.*/ + public double[] getLabels() { + return labels; + } + + /** A getter for sumsHolder.*/ public DiscreteNaiveBayesSumsHolder getSumsHolder() { return sumsHolder; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java index 9ea18ca..50b335e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java @@ -49,7 +49,7 @@ public class DiscreteNaiveBayesSumsHolder implements AutoCloseable, Serializable return arr1; } - /** */ + /** {@inheritDoc} */ @Override public void close() { // Do nothing, GC will clean up. } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java index f5dd6e6..a80c402 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java @@ -115,7 +115,7 @@ public class DiscreteNaiveBayesTrainer extends SingleLabelDatasetTrainer<Discret })) { DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> { if (a == null) - return b == null ? 
new DiscreteNaiveBayesSumsHolder() : b; + return b; if (b == null) return a; return a.merge(b); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java index bef2e39..29b8cc2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java @@ -17,18 +17,16 @@ package org.apache.ignite.ml.naivebayes.gaussian; -import java.io.Serializable; -import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; -import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.naivebayes.BayesModel; /** * Simple naive Bayes model which predicts result value {@code y} belongs to a class {@code C_k, k in [0..K]} as {@code * p(C_k,y) = p(C_k)*p(y_1,C_k) *...*p(y_n,C_k) / p(y)}. Return the number of the most possible class. */ -public class GaussianNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<GaussianNaiveBayesModel>, Serializable { - /** */ +public class GaussianNaiveBayesModel implements BayesModel<GaussianNaiveBayesModel, Vector, Double> { + /** Serial version uid. */ private static final long serialVersionUID = -127386523291350345L; /** Means of features for all classes. kth row contains means for labels[k] class. */ @@ -69,48 +67,61 @@ public class GaussianNaiveBayesModel implements IgniteModel<Vector, Double>, Exp /** Returns a number of class to which the input belongs. 
*/ @Override public Double predict(Vector vector) { - int k = classProbabilities.length; + double[] probapilityPowers = probabilityPowers(vector); - double maxProbability = .0; int max = 0; + for (int i = 0; i < probapilityPowers.length; i++) { + probapilityPowers[i] += Math.log(classProbabilities[i]); - for (int i = 0; i < k; i++) { - double p = classProbabilities[i]; - for (int j = 0; j < vector.size(); j++) { - double x = vector.get(j); - double g = gauss(x, means[i][j], variances[i][j]); - p *= g; - } - if (p > maxProbability) { + if (probapilityPowers[i] > probapilityPowers[max]) { max = i; - maxProbability = p; } } return labels[max]; } - /** */ + /** {@inheritDoc} */ + @Override + public double[] probabilityPowers(Vector vector) { + double[] probapilityPowers = new double[classProbabilities.length]; + + for (int i = 0; i < classProbabilities.length; i++) { + for (int j = 0; j < vector.size(); j++) { + double x = vector.get(j); + double parobability = gauss(x, means[i][j], variances[i][j]); + probapilityPowers[i] += (parobability > 0 ? Math.log(parobability) : .0); + } + } + return probapilityPowers; + } + + /** A getter for means.*/ public double[][] getMeans() { return means; } - /** */ + /** A getter for variances.*/ public double[][] getVariances() { return variances; } - /** */ + /** A getter for classProbabilities.*/ public double[] getClassProbabilities() { return classProbabilities; } - /** */ + /** A getter for labels.*/ + public double[] getLabels() { + return labels; + } + + /** A getter for sumsHolder.*/ public GaussianNaiveBayesSumsHolder getSumsHolder() { return sumsHolder; } - /** Gauss distribution */ - private double gauss(double x, double mean, double variance) { + /** Gauss distribution. */ + private static double gauss(double x, double mean, double variance) { return Math.exp(-1. * Math.pow(x - mean, 2) / (2. * variance)) / Math.sqrt(2. 
* Math.PI * variance); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java index eecec74..7b95ff8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java @@ -51,7 +51,7 @@ class GaussianNaiveBayesSumsHolder implements Serializable, AutoCloseable { return arr1; } - /** */ + /** {@inheritDoc} */ @Override public void close() { // Do nothing, GC will clean up. } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java index 40ca840..84a16ad 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java @@ -106,7 +106,7 @@ public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<Gaussia )) { GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> { if (a == null) - return b == null ? 
new GaussianNaiveBayesSumsHolder() : b; + return b; if (b == null) return a; return a.merge(b); diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java index 28e4c9a..6574902 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java @@ -59,7 +59,10 @@ public enum MLSandboxDatasets { WHOLESALE_CUSTOMERS("modules/ml/src/main/resources/datasets/wholesale_customers.csv", true, ","), /** Fraud detection problem [part of whole dataset]. Could be found <a href="https://www.kaggle.com/mlg-ulb/creditcardfraud/">here</a>. */ - FRAUD_DETECTION("modules/ml/src/main/resources/datasets/fraud_detection.csv", false, ","); + FRAUD_DETECTION("modules/ml/src/main/resources/datasets/fraud_detection.csv", false, ","), + + /** A dataset with discrete and continuous features. */ + MIXED_DATASET("modules/ml/src/main/resources/datasets/mixed_dataset.csv", true, ","); /** Filename. 
*/ private final String filename; diff --git a/modules/ml/src/main/resources/datasets/mixed_dataset.csv b/modules/ml/src/main/resources/datasets/mixed_dataset.csv new file mode 100644 index 0000000..5a0df2b --- /dev/null +++ b/modules/ml/src/main/resources/datasets/mixed_dataset.csv @@ -0,0 +1,8 @@ +1, 6, 180, 12, 0, 0, 1, 1, 1 +1, 5.92, 190, 11, 1, 0, 1, 1, 0 +1, 5.58, 170, 12, 1, 1, 0, 0, 1 +1, 5.92, 165, 10, 1, 1, 0, 0, 0 +0, 5, 100, 6, 1, 0, 0, 1, 1 +0, 5.5, 150, 8, 1, 1, 0, 0, 1 +0, 5.42, 130, 7, 1, 1, 1, 1, 0 +0, 5.75, 150, 9, 1, 1, 0, 1, 0 \ No newline at end of file diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModelTest.java new file mode 100644 index 0000000..227a179 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModelTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.naivebayes.compound; + +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel; +import org.junit.Test; + +import static java.util.Arrays.asList; +import static org.apache.ignite.ml.naivebayes.compound.Data.LABEL_2; +import static org.apache.ignite.ml.naivebayes.compound.Data.binarizedDataThresholds; +import static org.apache.ignite.ml.naivebayes.compound.Data.classProbabilities; +import static org.apache.ignite.ml.naivebayes.compound.Data.labels; +import static org.apache.ignite.ml.naivebayes.compound.Data.means; +import static org.apache.ignite.ml.naivebayes.compound.Data.probabilities; +import static org.apache.ignite.ml.naivebayes.compound.Data.variances; +import static org.junit.Assert.assertEquals; + +/** Tests for {@link CompoundNaiveBayesModel} */ +public class CompoundNaiveBayesModelTest { + + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + + /** Test. */ + @Test + public void testPredictOnlyGauss() { + GaussianNaiveBayesModel gaussianModel = + new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, null); + + Vector observation = VectorUtils.of(6, 130, 8); + + CompoundNaiveBayesModel model = new CompoundNaiveBayesModel() + .wirhPriorProbabilities(classProbabilities) + .withLabels(labels) + .withGaussianModel(gaussianModel); + + assertEquals(LABEL_2, model.predict(observation), PRECISION); + } + + /** Test. 
*/ + @Test + public void testPredictOnlyDiscrete() { + DiscreteNaiveBayesModel discreteModel = + new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, binarizedDataThresholds, null); + + Vector observation = VectorUtils.of(1, 0, 1, 1, 0); + + CompoundNaiveBayesModel model = new CompoundNaiveBayesModel() + .wirhPriorProbabilities(classProbabilities) + .withLabels(labels) + .withDiscreteModel(discreteModel); + + assertEquals(LABEL_2, model.predict(observation), PRECISION); + } + + /** Test. */ + @Test + public void testPredictGausAndDiscrete() { + DiscreteNaiveBayesModel discreteModel = + new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, binarizedDataThresholds, null); + + GaussianNaiveBayesModel gaussianModel = + new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, null); + + CompoundNaiveBayesModel model = new CompoundNaiveBayesModel() + .wirhPriorProbabilities(classProbabilities) + .withLabels(labels) + .withGaussianModel(gaussianModel) + .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7)) + .withDiscreteModel(discreteModel) + .withDiscreteFeatureIdsToSkip( asList(0, 1, 2)); + + Vector observation = VectorUtils.of(6, 130, 8, 1, 0, 1, 1, 0); + + assertEquals(LABEL_2, model.predict(observation), PRECISION); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTest.java new file mode 100644 index 0000000..c4cef0e --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes.compound; + +import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer; +import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer; +import org.junit.Test; + +import static java.util.Arrays.asList; +import static org.apache.ignite.ml.naivebayes.compound.Data.LABEL_1; +import static org.apache.ignite.ml.naivebayes.compound.Data.LABEL_2; +import static org.apache.ignite.ml.naivebayes.compound.Data.binarizedDataThresholds; +import static org.apache.ignite.ml.naivebayes.compound.Data.classProbabilities; +import static org.apache.ignite.ml.naivebayes.compound.Data.data; +import static org.junit.Assert.assertEquals; + +/** Integration tests for Compound naive Bayes algorithm with different datasets. */ +public class CompoundNaiveBayesTest { + + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + + /** Test. 
*/ + @Test + public void testLearnsAndPredictCorrently() { + CompoundNaiveBayesTrainer trainer = new CompoundNaiveBayesTrainer() + .withPriorProbabilities(classProbabilities) + .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer()) + .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7)) + .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer() + .setBucketThresholds(binarizedDataThresholds)) + .withDiscreteFeatureIdsToSkip(asList(0, 1, 2)); + + CompoundNaiveBayesModel model = trainer.fit( + new LocalDatasetBuilder<>(data, 2), + new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST) + ); + + Vector observation1 = VectorUtils.of(5.92, 165, 10, 1, 1, 0, 0, 0); + assertEquals(LABEL_1, model.predict(observation1), PRECISION); + + Vector observation2 = VectorUtils.of(6, 130, 8, 1, 0, 1, 1, 0); + assertEquals(LABEL_2, model.predict(observation2), PRECISION); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainerTest.java new file mode 100644 index 0000000..9b40fbb --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainerTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes.compound; + +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer; +import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel; +import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer; +import org.junit.Before; +import org.junit.Test; + +import static java.util.Arrays.asList; +import static org.apache.ignite.ml.naivebayes.compound.Data.binarizedDataThresholds; +import static org.apache.ignite.ml.naivebayes.compound.Data.classProbabilities; +import static org.apache.ignite.ml.naivebayes.compound.Data.data; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Test for {@link CompoundNaiveBayesTrainer} */ +public class CompoundNaiveBayesTrainerTest extends TrainerTest { + + /** Precision in test checks. */ + private static final double PRECISION = 1e-2; + + /** Trainer under test. */ + private CompoundNaiveBayesTrainer trainer; + + /** Initialization {@code CompoundNaiveBayesTrainer}. 
*/ + @Before + public void createTrainer() { + trainer = new CompoundNaiveBayesTrainer() + .withPriorProbabilities(classProbabilities) + .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer()) + .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7)) + .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer() + .setBucketThresholds(binarizedDataThresholds)) + .withDiscreteFeatureIdsToSkip(asList(0, 1, 2)); + } + + /** Test. */ + @Test + public void test() { + CompoundNaiveBayesModel model = trainer.fit( + new LocalDatasetBuilder<>(data, parts), + new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST) + ); + + assertDiscreteModel(model.getDiscreteModel()); + assertGaussianModel(model.getGaussianModel()); + } + + /** Discrete model assertions. */ + private void assertDiscreteModel(DiscreteNaiveBayesModel model) { + double[][][] expectedProbabilites = new double[][][] { + { + {.25, .75}, + {.5, .5}, + {.5, .5}, + {.5, .5}, + {.5, .5} + }, + { + {.0, 1}, + {.25, .75}, + {.75, .25}, + {.25, .75}, + {.5, .5} + } + }; + + for (int i = 0; i < expectedProbabilites.length; i++) { + for (int j = 0; j < expectedProbabilites[i].length; j++) + assertArrayEquals(expectedProbabilites[i][j], model.getProbabilities()[i][j], PRECISION); + } + assertArrayEquals(new double[] {.5, .5}, model.getClsProbabilities(), PRECISION); + } + + /** Gaussian model assertions. 
*/ + private void assertGaussianModel(GaussianNaiveBayesModel model) { + double[] priorProbabilities = new double[] {.5, .5}; + + assertEquals(priorProbabilities[0], model.getClassProbabilities()[0], PRECISION); + assertEquals(priorProbabilities[1], model.getClassProbabilities()[1], PRECISION); + assertArrayEquals(new double[] {5.855, 176.25, 11.25}, model.getMeans()[0], PRECISION); + assertArrayEquals(new double[] {5.4175, 132.5, 7.5}, model.getMeans()[1], PRECISION); + double[] expectedVars = {0.026274999999999, 92.1875, 0.6875}; + assertArrayEquals(expectedVars, model.getVariances()[0], PRECISION); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/Data.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/Data.java new file mode 100644 index 0000000..ee34f85 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/Data.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.naivebayes.compound; + +import java.util.HashMap; +import java.util.Map; + +/** Data class which contains test data with precalculated statistics. */ +final class Data { + + /** Private constructor. 
*/ + private Data() { + } + + /** The first label. */ + static final double LABEL_1 = 1.; + + /** The second label. */ + static final double LABEL_2 = 2.; + + /** Labels. */ + static final double[] labels = {LABEL_1, LABEL_2}; + + /** */ + static final Map<Integer, double[]> data = new HashMap<>(); + + /** Means for gaussian data part. */ + static double[][] means; + + /** Variances for gaussian data part. */ + static double[][] variances; + + /** */ + static double[] classProbabilities; + + /** Thresholds to binarize discrete data. */ + static double[][] binarizedDataThresholds; + + /** Discrete probabilities. */ + static double[][][] probabilities; + + static { + data.put(0, new double[] {6, 180, 12, 0, 0, 1, 1, 1, LABEL_1}); + data.put(1, new double[] {5.92, 190, 11, 1, 0, 1, 1, 0, LABEL_1}); + data.put(2, new double[] {5.58, 170, 12, 1, 1, 0, 0, 1, LABEL_1}); + data.put(3, new double[] {5.92, 165, 10, 1, 1, 0, 0, 0, LABEL_1}); + + data.put(4, new double[] {5, 100, 6, 1, 0, 0, 1, 1, LABEL_2}); + data.put(5, new double[] {5.5, 150, 8, 1, 1, 0, 0, 1, LABEL_2}); + data.put(6, new double[] {5.42, 130, 7, 1, 1, 1, 1, 0, LABEL_2}); + data.put(7, new double[] {5.75, 150, 9, 1, 1, 0, 1, 0, LABEL_2}); + + classProbabilities = new double[] {.5, .5}; + + means = new double[][] { + {5.855, 176.25, 11.25}, + {5.4175, 132.5, 7.5}, + }; + + variances = new double[][] { + {3.5033E-2, 1.2292E2, 9.1667E-1}, + {9.7225E-2, 5.5833E2, 1.6667}, + }; + + binarizedDataThresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}}; + + probabilities = new double[][][] { + {{.25, .75}, {.25, .75}, {.5, .5}, {.5, .5}, {.5, .5}}, + {{0, 1}, {.25, .75}, {.75, .25}, {.25, .75}, {.5, .5}} + }; + } + +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java index 41d320d..60d0cdd 100644 --- 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java @@ -24,7 +24,8 @@ import org.junit.Test; /** Tests for {@code DiscreteNaiveBayesModel} */ public class DiscreteNaiveBayesModelTest { - /** */ + + /** Test. */ @Test public void testPredictWithTwoClasses() { double first = 1; @@ -36,7 +37,8 @@ public class DiscreteNaiveBayesModelTest { double[] classProbabilities = new double[] {6. / 13, 7. / 13}; double[][] thresholds = new double[][] {{.5}, {.2, .7}, {.5}, {.5, 1.5}, {.5}}; - DiscreteNaiveBayesModel mdl = new DiscreteNaiveBayesModel(probabilities, classProbabilities, new double[] {first, second}, thresholds, new DiscreteNaiveBayesSumsHolder()); + DiscreteNaiveBayesModel mdl = new DiscreteNaiveBayesModel(probabilities, classProbabilities, + new double[] {first, second}, thresholds, new DiscreteNaiveBayesSumsHolder()); Vector observation = VectorUtils.of(2, 0, 1, 2, 0); Assert.assertEquals(second, mdl.predict(observation), 0.0001); diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java index a864fd6..7426f56 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java @@ -79,7 +79,7 @@ public class DiscreteNaiveBayesTrainerTest extends TrainerTest { data.put(11, new double[] {2, 1, .45, 748, LABEL_2}); } - /** */ + /** Trainer under test. */ private DiscreteNaiveBayesTrainer trainer; /** Initialization {@code DiscreteNaiveBayesTrainer}. 
*/ @@ -88,7 +88,7 @@ public class DiscreteNaiveBayesTrainerTest extends TrainerTest { trainer = new DiscreteNaiveBayesTrainer().setBucketThresholds(binarizedDatathresholds); } - /** */ + /** Test. */ @Test public void testReturnsCorrectLabelProbalities() { @@ -101,7 +101,7 @@ public class DiscreteNaiveBayesTrainerTest extends TrainerTest { Assert.assertArrayEquals(expectedProbabilities, model.getClsProbabilities(), PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() { DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer() @@ -116,7 +116,7 @@ public class DiscreteNaiveBayesTrainerTest extends TrainerTest { Assert.assertArrayEquals(new double[] {.5, .5}, model.getClsProbabilities(), PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() { double[] priorProbabilities = new double[] {.35, .65}; @@ -132,7 +132,7 @@ public class DiscreteNaiveBayesTrainerTest extends TrainerTest { Assert.assertArrayEquals(priorProbabilities, model.getClsProbabilities(), PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsCorrectPriorProbabilities() { double[][][] expectedPriorProbabilites = new double[][][] { @@ -151,7 +151,7 @@ public class DiscreteNaiveBayesTrainerTest extends TrainerTest { } } - /** */ + /** Test. 
*/ @Test public void testReturnsCorrectPriorProbabilitiesWithDefferentThresholds() { double[][][] expectedPriorProbabilites = new double[][][] { diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java index d35ea3d..043ef38 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java @@ -26,7 +26,8 @@ import org.junit.Test; * Tests for {@link GaussianNaiveBayesModel}. */ public class GaussianNaiveBayesModelTest { - /** */ + + /** Test. */ @Test public void testPredictWithTwoClasses() { double first = 1; diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java index 80927e9..d00cfc6 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java @@ -36,19 +36,19 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { /** Precision in test checks. */ private static final double PRECISION = 1e-2; - /** */ + /** Label. */ private static final double LABEL_1 = 1.; - /** */ + /** Label. */ private static final double LABEL_2 = 2.; /** Data. */ private static final Map<Integer, double[]> data = new HashMap<>(); - /** */ + /** {@code LABEL_1} data. */ private static final Map<Integer, double[]> singleLabeldata1 = new HashMap<>(); - /** */ + /** {@code LABEL_2} data. 
*/ private static final Map<Integer, double[]> singleLabeldata2 = new HashMap<>(); static { @@ -75,7 +75,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { trainer = new GaussianNaiveBayesTrainer(); } - /** */ + /** Test. */ @Test public void testWithLinearlySeparableData() { Map<Integer, double[]> cacheMock = new HashMap<>(); @@ -92,7 +92,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsCorrectLabelProbalities() { @@ -105,7 +105,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { Assert.assertEquals(2. / data.size(), model.getClassProbabilities()[1], PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() { GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer() @@ -120,7 +120,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { Assert.assertEquals(.5, model.getClassProbabilities()[1], PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() { double[] priorProbabilities = new double[] {.35, .65}; @@ -136,7 +136,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { Assert.assertEquals(priorProbabilities[1], model.getClassProbabilities()[1], PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsCorrectMeans() { @@ -148,7 +148,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { Assert.assertArrayEquals(new double[] {2.0, 2. / 3.}, model.getMeans()[0], PRECISION); } - /** */ + /** Test. */ @Test public void testReturnsCorrectVariances() { @@ -161,7 +161,7 @@ public class GaussianNaiveBayesTrainerTest extends TrainerTest { Assert.assertArrayEquals(expectedVars, model.getVariances()[0], PRECISION); } - /** */ + /** Test. 
*/ @Test public void testUpdatigModel() { Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);