This is an automated email from the ASF dual-hosted git repository.

zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new 45024f07 IGNITE-9978: [ML] Implement Compound Naive Bayes classifier 
(#6567)
45024f07 is described below

commit 45024f07c0ce8c05a0abbba512da57e2cc5be435
Author: Ravil Galeyev <deh...@yandex.ru>
AuthorDate: Sun Sep 8 11:22:34 2019 +0200

    IGNITE-9978: [ML] Implement Compound Naive Bayes classifier (#6567)
---
 .../ml/naivebayes/CompoundNaiveBayesExample.java   |  95 ++++++++++++
 .../apache/ignite/ml/naivebayes/BayesModel.java    |  34 +++++
 .../compound/CompoundNaiveBayesModel.java          | 159 ++++++++++++++++++++
 .../compound/CompoundNaiveBayesTrainer.java        | 160 +++++++++++++++++++++
 .../ml/naivebayes/compound/package-info.java       |  22 +++
 .../discrete/DiscreteNaiveBayesModel.java          |  56 +++++---
 .../discrete/DiscreteNaiveBayesSumsHolder.java     |   2 +-
 .../discrete/DiscreteNaiveBayesTrainer.java        |   2 +-
 .../gaussian/GaussianNaiveBayesModel.java          |  55 ++++---
 .../gaussian/GaussianNaiveBayesSumsHolder.java     |   2 +-
 .../gaussian/GaussianNaiveBayesTrainer.java        |   2 +-
 .../apache/ignite/ml/util/MLSandboxDatasets.java   |   5 +-
 .../src/main/resources/datasets/mixed_dataset.csv  |   8 ++
 .../compound/CompoundNaiveBayesModelTest.java      |  95 ++++++++++++
 .../compound/CompoundNaiveBayesTest.java           |  65 +++++++++
 .../compound/CompoundNaiveBayesTrainerTest.java    | 108 ++++++++++++++
 .../apache/ignite/ml/naivebayes/compound/Data.java |  88 ++++++++++++
 .../discrete/DiscreteNaiveBayesModelTest.java      |   6 +-
 .../discrete/DiscreteNaiveBayesTrainerTest.java    |  12 +-
 .../gaussian/GaussianNaiveBayesModelTest.java      |   3 +-
 .../gaussian/GaussianNaiveBayesTrainerTest.java    |  22 +--
 21 files changed, 932 insertions(+), 69 deletions(-)

diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/CompoundNaiveBayesExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/CompoundNaiveBayesExample.java
new file mode 100644
index 0000000..5c0aac2
--- /dev/null
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/CompoundNaiveBayesExample.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.naivebayes;
+
+import java.io.FileNotFoundException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.util.MLSandboxDatasets;
+import org.apache.ignite.ml.util.SandboxMLCache;
+
+import static java.util.Arrays.asList;
+
+/**
+ * Run Compound Naive Bayes classification model based on the
+ * <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier">Gaussian naive Bayes classifier</a> algorithm
+ * ({@link GaussianNaiveBayesTrainer}) and the
+ * <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes">Discrete naive Bayes
+ * classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points.</p>
+ * <p>
+ * After that it trains the compound naive Bayes classification model based on the specified data: the Gaussian
+ * trainer skips features {@code 3..7} and the discrete trainer skips features {@code 0..2}.</p>
+ * <p>
+ * Finally, this example evaluates the trained model on the cached data and prints its accuracy.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class CompoundNaiveBayesExample {
+    /**
+     * Run example.
+     *
+     * @param args Command line arguments, none required.
+     * @throws FileNotFoundException If the example dataset file is not found.
+     */
+    public static void main(String[] args) throws FileNotFoundException {
+        System.out.println();
+        System.out.println(">>> Compound Naive Bayes classification model over " +
+            "partitioned dataset usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+                .fillCacheWith(MLSandboxDatasets.MIXED_DATASET);
+
+            double[] priorProbabilities = new double[] {.5, .5};
+            // Bucket thresholds: one array per feature handled by the discrete trainer.
+            double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
+
+            System.out.println(">>> Create new naive Bayes classification trainer object.");
+            CompoundNaiveBayesTrainer trainer = new CompoundNaiveBayesTrainer()
+                .withPriorProbabilities(priorProbabilities)
+                .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer())
+                .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7))
+                .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer()
+                    .setBucketThresholds(thresholds))
+                .withDiscreteFeatureIdsToSkip(asList(0, 1, 2));
+            System.out.println(">>> Perform the training to get the model.");
+
+            // The first coordinate of every cached vector is the label, the rest are features.
+            Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+                .labeled(Vectorizer.LabelCoordinate.FIRST);
+
+            CompoundNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+
+            System.out.println(">>> Compound Naive Bayes model: " + mdl);
+
+            double accuracy = Evaluator.evaluate(
+                dataCache,
+                mdl,
+                vectorizer
+            ).accuracy();
+
+            System.out.println("\n>>> Accuracy " + accuracy);
+
+            System.out.println(">>> Compound Naive bayes model over partitioned dataset usage example completed.");
+        }
+    }
+}
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/BayesModel.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/BayesModel.java
new file mode 100644
index 0000000..d08b720
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/BayesModel.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes;
+
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.IgniteModel;
+
+/**
+ * Interface for naive Bayes models which, besides predicting, can expose per-label log-probabilities of the input
+ * features, so that the outputs of several such models can be combined.
+ *
+ * @param <MODEL> Concrete model type (used for export).
+ * @param <FEATURES> Type of the input features.
+ * @param <OUTPUT> Type of the prediction.
+ */
+public interface BayesModel<MODEL extends BayesModel<?, ?, ?>, FEATURES, OUTPUT>
+        extends IgniteModel<FEATURES, OUTPUT>, Exportable<MODEL> {
+
+    /**
+     * Returns an array where the index corresponds to a label, and the value is the sum of {@code log(probability)}
+     * of the given features for that label. The prior probabilities of the labels are not included.
+     *
+     * @param vector Features.
+     * @return Per-label sums of feature log-probabilities.
+     */
+    double[] probabilityPowers(FEATURES vector);
+}
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java
new file mode 100644
index 0000000..5daebb7
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.compound;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+
+/**
+ * A compound Naive Bayes model which uses a composition of {@code GaussianNaiveBayesModel} and {@code
+ * DiscreteNaiveBayesModel}. Each sub-model receives the input vector without the features listed in its "feature ids
+ * to skip" collection; the per-label log-probabilities returned by the sub-models are added to the logarithms of the
+ * prior probabilities, and the label with the maximal sum is predicted.
+ */
+public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<CompoundNaiveBayesModel>,
+    Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = -5045925321135798960L;
+
+    /** Prior probabilities of each class. */
+    private double[] priorProbabilities;
+
+    /** Labels. */
+    private double[] labels;
+
+    /** Gaussian Bayes model. */
+    private GaussianNaiveBayesModel gaussianModel;
+
+    /** Feature ids which should be skipped in Gaussian model. */
+    private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();
+
+    /** Discrete Bayes model. */
+    private DiscreteNaiveBayesModel discreteModel;
+
+    /** Feature ids which should be skipped in Discrete model. */
+    private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<CompoundNaiveBayesModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /**
+     * Predicts the label with the maximal sum of its log prior probability and the log-probabilities reported by the
+     * configured sub-models.
+     *
+     * @param vector Full features vector; each sub-model receives it without its skipped features.
+     * @return Predicted label.
+     * @throws IllegalStateException If prior probabilities are not set.
+     */
+    @Override public Double predict(Vector vector) {
+        if (priorProbabilities == null)
+            throw new IllegalStateException("Prior probabilities are not set.");
+
+        double[] probabilityPowers = new double[priorProbabilities.length];
+        for (int i = 0; i < priorProbabilities.length; i++)
+            probabilityPowers[i] = Math.log(priorProbabilities[i]);
+
+        if (discreteModel != null) {
+            Vector discreteFeatures = skipFeatures(vector, discreteFeatureIdsToSkip);
+            probabilityPowers = sum(probabilityPowers, discreteModel.probabilityPowers(discreteFeatures));
+        }
+
+        if (gaussianModel != null) {
+            Vector gaussianFeatures = skipFeatures(vector, gaussianFeatureIdsToSkip);
+            probabilityPowers = sum(probabilityPowers, gaussianModel.probabilityPowers(gaussianFeatures));
+        }
+
+        int maxLabelIdx = 0;
+        for (int i = 0; i < probabilityPowers.length; i++) {
+            if (probabilityPowers[i] > probabilityPowers[maxLabelIdx])
+                maxLabelIdx = i;
+        }
+        return labels[maxLabelIdx];
+    }
+
+    /** Returns a gaussian model. */
+    public GaussianNaiveBayesModel getGaussianModel() {
+        return gaussianModel;
+    }
+
+    /** Returns a discrete model. */
+    public DiscreteNaiveBayesModel getDiscreteModel() {
+        return discreteModel;
+    }
+
+    /**
+     * Sets prior probabilities.
+     *
+     * @param priorProbabilities Prior probability of each class, in the same order as the labels.
+     * @return This model.
+     */
+    public CompoundNaiveBayesModel withPriorProbabilities(double[] priorProbabilities) {
+        this.priorProbabilities = priorProbabilities.clone();
+        return this;
+    }
+
+    /**
+     * Sets prior probabilities. Alias of {@link #withPriorProbabilities(double[])} kept for backward compatibility
+     * (this name contains a typo); prefer the correctly spelled method.
+     *
+     * @param priorProbabilities Prior probability of each class, in the same order as the labels.
+     * @return This model.
+     */
+    public CompoundNaiveBayesModel wirhPriorProbabilities(double[] priorProbabilities) {
+        return withPriorProbabilities(priorProbabilities);
+    }
+
+    /** Sets labels. */
+    public CompoundNaiveBayesModel withLabels(double[] labels) {
+        this.labels = labels.clone();
+        return this;
+    }
+
+    /** Sets a gaussian model. */
+    public CompoundNaiveBayesModel withGaussianModel(GaussianNaiveBayesModel gaussianModel) {
+        this.gaussianModel = gaussianModel;
+        return this;
+    }
+
+    /** Sets a discrete model. */
+    public CompoundNaiveBayesModel withDiscreteModel(DiscreteNaiveBayesModel discreteModel) {
+        this.discreteModel = discreteModel;
+        return this;
+    }
+
+    /** Sets feature ids to skip in Gaussian Bayes. */
+    public CompoundNaiveBayesModel withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) {
+        this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip;
+        return this;
+    }
+
+    /** Sets feature ids to skip in discrete Bayes. */
+    public CompoundNaiveBayesModel withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) {
+        this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip;
+        return this;
+    }
+
+    /** Returns index by index sum of two arrays of the same length. */
+    private static double[] sum(double[] arr1, double[] arr2) {
+        assert arr1.length == arr2.length;
+
+        double[] result = new double[arr1.length];
+
+        for (int i = 0; i < arr1.length; i++)
+            result[i] = arr1[i] + arr2[i];
+
+        return result;
+    }
+
+    /**
+     * Returns a new (shorter) vector without the features whose ids are in {@code featureIdsToSkip}. Duplicated or
+     * out-of-range ids are ignored.
+     */
+    private static Vector skipFeatures(Vector vector, Collection<Integer> featureIdsToSkip) {
+        int size = vector.size();
+
+        // Count the kept features explicitly: subtracting featureIdsToSkip.size() breaks on duplicated or
+        // out-of-range ids.
+        int newSize = 0;
+        for (int j = 0; j < size; j++) {
+            if (!featureIdsToSkip.contains(j))
+                newSize++;
+        }
+
+        double[] newFeaturesValues = new double[newSize];
+        int idx = 0;
+        for (int j = 0; j < size; j++) {
+            if (!featureIdsToSkip.contains(j))
+                newFeaturesValues[idx++] = vector.get(j);
+        }
+        return VectorUtils.of(newFeaturesValues);
+    }
+}
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.java
new file mode 100644
index 0000000..9b8cd46
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainer.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.compound;
+
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.preprocessing.Preprocessor;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * Trainer for the compound Naive Bayes classifier model. It uses a model composition of {@code
+ * GaussianNaiveBayesTrainer} and {@code DiscreteNaiveBayesTrainer}. To distinguish which features with which trainer
+ * should be used, each trainer should have a collection of feature ids which should be skipped. They can be set by
+ * {@link #withGaussianFeatureIdsToSkip(Collection)} and {@link #withDiscreteFeatureIdsToSkip(Collection)}.
+ * <p>
+ * If no prior probabilities are set via {@link #withPriorProbabilities(double[])}, the compound model uses the class
+ * probabilities of the trained discrete model, so a discrete trainer is required in that case.</p>
+ */
+public class CompoundNaiveBayesTrainer extends SingleLabelDatasetTrainer<CompoundNaiveBayesModel> {
+
+    /** Prior probabilities of each class. */
+    private double[] priorProbabilities;
+
+    /** Gaussian Naive Bayes trainer. */
+    private GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer;
+
+    /** Feature ids which should be skipped in Gaussian model. */
+    private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();
+
+    /** Discrete Naive Bayes trainer. */
+    private DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer;
+
+    /** Feature ids which should be skipped in Discrete model. */
+    private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();
+
+    /** {@inheritDoc} */
+    @Override public <K, V> CompoundNaiveBayesModel fit(DatasetBuilder<K, V> datasetBuilder,
+        Preprocessor<K, V> extractor) {
+        return updateModel(null, datasetBuilder, extractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean isUpdateable(CompoundNaiveBayesModel mdl) {
+        // A sub-trainer that is not configured has nothing to update, so it never blocks the update.
+        return (gaussianNaiveBayesTrainer == null || gaussianNaiveBayesTrainer.isUpdateable(mdl.getGaussianModel()))
+            && (discreteNaiveBayesTrainer == null || discreteNaiveBayesTrainer.isUpdateable(mdl.getDiscreteModel()));
+    }
+
+    /** {@inheritDoc} */
+    @Override public CompoundNaiveBayesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (CompoundNaiveBayesTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
+
+    /**
+     * Trains (or, when {@code mdl} is not {@code null}, updates) the configured sub-models and combines them into a
+     * compound model.
+     *
+     * @param mdl Model to update, or {@code null} to train from scratch.
+     * @param datasetBuilder Dataset builder.
+     * @param extractor Preprocessor producing labeled vectors with all features.
+     * @return Compound model.
+     * @throws IllegalStateException If prior probabilities are not set and no discrete trainer is configured.
+     */
+    @Override protected <K, V> CompoundNaiveBayesModel updateModel(CompoundNaiveBayesModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
+        // Fail fast, before any training: without explicit priors they can only be taken from the discrete model.
+        if (priorProbabilities == null && discreteNaiveBayesTrainer == null) {
+            throw new IllegalStateException("Prior probabilities must be set via withPriorProbabilities() " +
+                "when no discrete Naive Bayes trainer is configured.");
+        }
+
+        CompoundNaiveBayesModel compoundModel = new CompoundNaiveBayesModel();
+        double[] priors = priorProbabilities;
+
+        if (gaussianNaiveBayesTrainer != null) {
+            if (priorProbabilities != null)
+                gaussianNaiveBayesTrainer.setPriorProbabilities(priorProbabilities);
+
+            GaussianNaiveBayesModel model = (mdl == null)
+                ? gaussianNaiveBayesTrainer.fit(datasetBuilder,
+                    extractor.map(skipFeatures(gaussianFeatureIdsToSkip)))
+                : gaussianNaiveBayesTrainer.update(mdl.getGaussianModel(), datasetBuilder,
+                    extractor.map(skipFeatures(gaussianFeatureIdsToSkip)));
+
+            compoundModel.withGaussianModel(model)
+                .withGaussianFeatureIdsToSkip(gaussianFeatureIdsToSkip)
+                .withLabels(model.getLabels());
+        }
+
+        if (discreteNaiveBayesTrainer != null) {
+            if (priorProbabilities != null)
+                discreteNaiveBayesTrainer.setPriorProbabilities(priorProbabilities);
+
+            DiscreteNaiveBayesModel model = (mdl == null)
+                ? discreteNaiveBayesTrainer.fit(datasetBuilder,
+                    extractor.map(skipFeatures(discreteFeatureIdsToSkip)))
+                : discreteNaiveBayesTrainer.update(mdl.getDiscreteModel(), datasetBuilder,
+                    extractor.map(skipFeatures(discreteFeatureIdsToSkip)));
+
+            compoundModel.withDiscreteModel(model)
+                .withDiscreteFeatureIdsToSkip(discreteFeatureIdsToSkip)
+                .withLabels(model.getLabels());
+
+            // No explicit priors: reuse the class probabilities held by the trained discrete model.
+            if (priors == null)
+                priors = model.getClsProbabilities();
+        }
+
+        return compoundModel.wirhPriorProbabilities(priors);
+    }
+
+    /**
+     * Sets prior probabilities.
+     *
+     * @param priorProbabilities Prior probability of each class, or {@code null} to derive them from the discrete
+     * model.
+     * @return This trainer.
+     */
+    public CompoundNaiveBayesTrainer withPriorProbabilities(double[] priorProbabilities) {
+        this.priorProbabilities = priorProbabilities == null ? null : priorProbabilities.clone();
+        return this;
+    }
+
+    /** Sets a gaussian trainer. */
+    public CompoundNaiveBayesTrainer withGaussianNaiveBayesTrainer(GaussianNaiveBayesTrainer gaussianNaiveBayesTrainer) {
+        this.gaussianNaiveBayesTrainer = gaussianNaiveBayesTrainer;
+        return this;
+    }
+
+    /** Sets a discrete trainer. */
+    public CompoundNaiveBayesTrainer withDiscreteNaiveBayesTrainer(DiscreteNaiveBayesTrainer discreteNaiveBayesTrainer) {
+        this.discreteNaiveBayesTrainer = discreteNaiveBayesTrainer;
+        return this;
+    }
+
+    /** Sets feature ids to skip in Gaussian Bayes. */
+    public CompoundNaiveBayesTrainer withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) {
+        this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip;
+        return this;
+    }
+
+    /** Sets feature ids to skip in discrete Bayes. */
+    public CompoundNaiveBayesTrainer withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) {
+        this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip;
+        return this;
+    }
+
+    /**
+     * Returns a function which removes the features whose ids are in {@code featureIdsToSkip} from a labeled vector,
+     * keeping its label. Duplicated or out-of-range ids are ignored.
+     *
+     * @param featureIdsToSkip Ids of features to remove.
+     * @return Function producing a new (shorter) labeled vector.
+     */
+    private static IgniteFunction<LabeledVector<Object>, LabeledVector<Object>> skipFeatures(
+        Collection<Integer> featureIdsToSkip) {
+        return featureValues -> {
+            final int size = featureValues.features().size();
+
+            // Count the kept features explicitly: subtracting featureIdsToSkip.size() breaks on duplicated or
+            // out-of-range ids.
+            int newSize = 0;
+            for (int j = 0; j < size; j++) {
+                if (!featureIdsToSkip.contains(j))
+                    newSize++;
+            }
+
+            double[] newFeaturesValues = new double[newSize];
+            int idx = 0;
+            for (int j = 0; j < size; j++) {
+                if (!featureIdsToSkip.contains(j))
+                    newFeaturesValues[idx++] = featureValues.get(j);
+            }
+            return new LabeledVector<>(VectorUtils.of(newFeaturesValues), featureValues.label());
+        };
+    }
+}
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/package-info.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/package-info.java
new file mode 100644
index 0000000..0806c48
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains Compound naive Bayes classifier.
+ */
+package org.apache.ignite.ml.naivebayes.compound;
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
index 3c35841..44d0767 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
@@ -17,20 +17,17 @@
 
 package org.apache.ignite.ml.naivebayes.discrete;
 
-import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
-import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-
-import java.io.Serializable;
+import org.apache.ignite.ml.naivebayes.BayesModel;
 
 /**
  * Discrete naive Bayes model which predicts result value {@code y} belongs to 
a class {@code C_k, k in [0..K]} as
  * {@code p(C_k,y) =x_1*p_k1^x *...*x_i*p_ki^x_i}. Where {@code x_i} is a 
discrete feature, {@code p_ki} is a prior
  * probability probability of class {@code p(x|C_k)}. Returns the number of 
the most possible class.
  */
-public class DiscreteNaiveBayesModel implements IgniteModel<Vector, Double>, 
Exportable<DiscreteNaiveBayesModel>, Serializable {
-    /** */
+public class DiscreteNaiveBayesModel implements 
BayesModel<DiscreteNaiveBayesModel, Vector, Double> {
+    /** Serial version uid. */
     private static final long serialVersionUID = -127386523291350345L;
 
     /**
@@ -58,9 +55,9 @@ public class DiscreteNaiveBayesModel implements 
IgniteModel<Vector, Double>, Exp
     /**
      * @param probabilities Probabilities of features for classes.
      * @param clsProbabilities Prior probabilities for classes.
+     * @param labels Labels.
      * @param bucketThresholds The threshold to convert a feature to a binary 
value.
      * @param sumsHolder Amount values which are abouve the threshold per 
label.
-     * @param labels Labels.
      */
     public DiscreteNaiveBayesModel(double[][][] probabilities, double[] 
clsProbabilities, double[] labels,
         double[][] bucketThresholds, DiscreteNaiveBayesSumsHolder sumsHolder) {
@@ -81,42 +78,57 @@ public class DiscreteNaiveBayesModel implements 
IgniteModel<Vector, Double>, Exp
      * @return A label with max probability.
      */
     @Override public Double predict(Vector vector) {
-        double maxProbapilityPower = -Double.MAX_VALUE;
-        int maxLabelIndex = -1;
+        double[] probapilityPowers = probabilityPowers(vector);
 
-        for (int i = 0; i < clsProbabilities.length; i++) {
-            double probabilityPower = Math.log(clsProbabilities[i]);
+        int maxLabelIndex = 0;
+        for (int i = 0; i < probapilityPowers.length; i++) {
+            probapilityPowers[i] += Math.log(clsProbabilities[i]);
+
+            if (probapilityPowers[i] > probapilityPowers[maxLabelIndex]) {
+                maxLabelIndex = i;
+            }
+        }
+
+        return labels[maxLabelIndex];
+    }
 
+    /** {@inheritDoc} */
+    @Override
+    public double[] probabilityPowers(Vector vector) {
+        double[] probapilityPowers = new double[clsProbabilities.length];
+
+        for (int i = 0; i < clsProbabilities.length; i++) {
             for (int j = 0; j < probabilities[0].length; j++) {
                 int x = toBucketNumber(vector.get(j), bucketThresholds[j]);
                 double p = probabilities[i][j][x];
-                probabilityPower += (p > 0 ? Math.log(p) : .0);
-            }
-
-            if (probabilityPower > maxProbapilityPower) {
-                maxLabelIndex = i;
-                maxProbapilityPower = probabilityPower;
+                probapilityPowers[i] += (p > 0 ? Math.log(p) : .0);
             }
         }
-        return labels[maxLabelIndex];
+
+        return probapilityPowers;
     }
 
-    /** */
+    /** A getter for probabilities.*/
     public double[][][] getProbabilities() {
         return probabilities;
     }
 
-    /** */
+    /** A getter for clsProbabilities.*/
     public double[] getClsProbabilities() {
         return clsProbabilities;
     }
 
-    /** */
+    /** A getter for bucketThresholds.*/
     public double[][] getBucketThresholds() {
         return bucketThresholds;
     }
 
-    /** */
+    /** A getter for labels.*/
+    public double[] getLabels() {
+        return labels;
+    }
+
+    /** A getter for sumsHolder.*/
     public DiscreteNaiveBayesSumsHolder getSumsHolder() {
         return sumsHolder;
     }
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
index 9ea18ca..50b335e 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
@@ -49,7 +49,7 @@ public class DiscreteNaiveBayesSumsHolder implements 
AutoCloseable, Serializable
         return arr1;
     }
 
-    /** */
+    /** {@inheritDoc} */
     @Override public void close() {
         // Do nothing, GC will clean up.
     }
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
index f5dd6e6..a80c402 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
@@ -115,7 +115,7 @@ public class DiscreteNaiveBayesTrainer extends 
SingleLabelDatasetTrainer<Discret
             })) {
             DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, 
(a, b) -> {
                 if (a == null)
-                    return b == null ? new DiscreteNaiveBayesSumsHolder() : b;
+                    return b;
                 if (b == null)
                     return a;
                 return a.merge(b);
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
index bef2e39..29b8cc2 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
@@ -17,18 +17,16 @@
 
 package org.apache.ignite.ml.naivebayes.gaussian;
 
-import java.io.Serializable;
-import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
-import org.apache.ignite.ml.IgniteModel;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.naivebayes.BayesModel;
 
 /**
  * Simple naive Bayes model which predicts result value {@code y} belongs to a 
class {@code C_k, k in [0..K]} as {@code
  * p(C_k,y) = p(C_k)*p(y_1,C_k) *...*p(y_n,C_k) / p(y)}. Return the number of 
the most possible class.
  */
-public class GaussianNaiveBayesModel implements IgniteModel<Vector, Double>, 
Exportable<GaussianNaiveBayesModel>, Serializable {
-    /** */
+public class GaussianNaiveBayesModel implements 
BayesModel<GaussianNaiveBayesModel, Vector, Double> {
+    /** Serial version uid. */
     private static final long serialVersionUID = -127386523291350345L;
 
     /** Means of features for all classes. kth row contains means for 
labels[k] class. */
@@ -69,48 +67,61 @@ public class GaussianNaiveBayesModel implements 
IgniteModel<Vector, Double>, Exp
 
     /** Returns a number of class to which the input belongs. */
     @Override public Double predict(Vector vector) {
-        int k = classProbabilities.length;
+        double[] probapilityPowers = probabilityPowers(vector);
 
-        double maxProbability = .0;
         int max = 0;
+        for (int i = 0; i < probapilityPowers.length; i++) {
+            probapilityPowers[i] += Math.log(classProbabilities[i]);
 
-        for (int i = 0; i < k; i++) {
-            double p = classProbabilities[i];
-            for (int j = 0; j < vector.size(); j++) {
-                double x = vector.get(j);
-                double g = gauss(x, means[i][j], variances[i][j]);
-                p *= g;
-            }
-            if (p > maxProbability) {
+            if (probapilityPowers[i] > probapilityPowers[max]) {
                 max = i;
-                maxProbability = p;
             }
         }
         return labels[max];
     }
 
-    /** */
+    /** {@inheritDoc} */
+    @Override public double[] probabilityPowers(Vector vector) {
+        double[] probabilityPowers = new double[classProbabilities.length];
+
+        for (int i = 0; i < classProbabilities.length; i++) {
+            for (int j = 0; j < vector.size(); j++) {
+                double variance = variances[i][j];
+                // Closed-form log(gauss(x, mean, variance)): Math.log(gauss(..)) underflows to log(0) for distant x.
+                double logDensity = -Math.pow(vector.get(j) - means[i][j], 2) / (2. * variance) - .5 * Math.log(2. * Math.PI * variance);
+                probabilityPowers[i] += Double.isNaN(logDensity) ? .0 : logDensity; // NaN (e.g. zero variance) keeps the zero contribution.
+            }
+        }
+        return probabilityPowers;
+    }
+
+    /** A getter for means.*/
     public double[][] getMeans() {
         return means;
     }
 
-    /** */
+    /** A getter for variances.*/
     public double[][] getVariances() {
         return variances;
     }
 
-    /** */
+    /** A getter for classProbabilities.*/
     public double[] getClassProbabilities() {
         return classProbabilities;
     }
 
-    /** */
+    /** A getter for labels.*/
+    public double[] getLabels() {
+        return labels;
+    }
+
+    /** A getter for sumsHolder.*/
     public GaussianNaiveBayesSumsHolder getSumsHolder() {
         return sumsHolder;
     }
 
-    /** Gauss distribution */
-    private double gauss(double x, double mean, double variance) {
+    /** Gauss distribution. */
+    private static double gauss(double x, double mean, double variance) {
         return Math.exp(-1. * Math.pow(x - mean, 2) / (2. * variance)) / 
Math.sqrt(2. * Math.PI * variance);
     }
 }
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
index eecec74..7b95ff8 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
@@ -51,7 +51,7 @@ class GaussianNaiveBayesSumsHolder implements Serializable, 
AutoCloseable {
         return arr1;
     }
 
-    /** */
+    /** {@inheritDoc} */
     @Override public void close() {
         // Do nothing, GC will clean up.
     }
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
index 40ca840..84a16ad 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
@@ -106,7 +106,7 @@ public class GaussianNaiveBayesTrainer extends 
SingleLabelDatasetTrainer<Gaussia
         )) {
             GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, 
(a, b) -> {
                 if (a == null)
-                    return b == null ? new GaussianNaiveBayesSumsHolder() : b;
+                    return b;
                 if (b == null)
                     return a;
                 return a.merge(b);
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
index 28e4c9a..6574902 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MLSandboxDatasets.java
@@ -59,7 +59,10 @@ public enum MLSandboxDatasets {
     
WHOLESALE_CUSTOMERS("modules/ml/src/main/resources/datasets/wholesale_customers.csv",
 true, ","),
 
     /** Fraud detection problem [part of whole dataset]. Could be found <a 
href="https://www.kaggle.com/mlg-ulb/creditcardfraud/";>here</a>. */
-    
FRAUD_DETECTION("modules/ml/src/main/resources/datasets/fraud_detection.csv", 
false, ",");
+    
FRAUD_DETECTION("modules/ml/src/main/resources/datasets/fraud_detection.csv", 
false, ","),
+
+    /** A dataset with discrete and continuous features. */
+    MIXED_DATASET("modules/ml/src/main/resources/datasets/mixed_dataset.csv", 
true, ",");
 
     /** Filename. */
     private final String filename;
diff --git a/modules/ml/src/main/resources/datasets/mixed_dataset.csv 
b/modules/ml/src/main/resources/datasets/mixed_dataset.csv
new file mode 100644
index 0000000..5a0df2b
--- /dev/null
+++ b/modules/ml/src/main/resources/datasets/mixed_dataset.csv
@@ -0,0 +1,8 @@
+1, 6,    180, 12, 0, 0, 1, 1, 1
+1, 5.92, 190, 11, 1, 0, 1, 1, 0
+1, 5.58, 170, 12, 1, 1, 0, 0, 1
+1, 5.92, 165, 10, 1, 1, 0, 0, 0
+0, 5,    100,  6, 1, 0, 0, 1, 1
+0, 5.5,  150,  8, 1, 1, 0, 0, 1
+0, 5.42, 130,  7, 1, 1, 1, 1, 0
+0, 5.75, 150,  9, 1, 1, 0, 1, 0
\ No newline at end of file
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModelTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModelTest.java
new file mode 100644
index 0000000..227a179
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModelTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.compound;
+
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+import org.junit.Test;
+
+import static java.util.Arrays.asList;
+import static org.apache.ignite.ml.naivebayes.compound.Data.LABEL_2;
+import static 
org.apache.ignite.ml.naivebayes.compound.Data.binarizedDataThresholds;
+import static org.apache.ignite.ml.naivebayes.compound.Data.classProbabilities;
+import static org.apache.ignite.ml.naivebayes.compound.Data.labels;
+import static org.apache.ignite.ml.naivebayes.compound.Data.means;
+import static org.apache.ignite.ml.naivebayes.compound.Data.probabilities;
+import static org.apache.ignite.ml.naivebayes.compound.Data.variances;
+import static org.junit.Assert.assertEquals;
+
+/** Tests for {@link CompoundNaiveBayesModel}. */
+public class CompoundNaiveBayesModelTest {
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Checks prediction when only the Gaussian sub-model is configured. */
+    @Test
+    public void testPredictOnlyGauss() {
+        GaussianNaiveBayesModel gaussianModel =
+            new GaussianNaiveBayesModel(means, variances, classProbabilities, 
labels, null);
+
+        Vector observation = VectorUtils.of(6, 130, 8);
+
+        CompoundNaiveBayesModel model = new CompoundNaiveBayesModel()
+            .wirhPriorProbabilities(classProbabilities)
+            .withLabels(labels)
+            .withGaussianModel(gaussianModel);
+
+        assertEquals(LABEL_2, model.predict(observation), PRECISION);
+    }
+
+    /** Checks prediction when only the discrete sub-model is configured. */
+    @Test
+    public void testPredictOnlyDiscrete() {
+        DiscreteNaiveBayesModel discreteModel =
+            new DiscreteNaiveBayesModel(probabilities, classProbabilities, 
labels, binarizedDataThresholds, null);
+
+        Vector observation = VectorUtils.of(1, 0, 1, 1, 0);
+
+        CompoundNaiveBayesModel model = new CompoundNaiveBayesModel()
+            .wirhPriorProbabilities(classProbabilities)
+            .withLabels(labels)
+            .withDiscreteModel(discreteModel);
+
+        assertEquals(LABEL_2, model.predict(observation), PRECISION);
+    }
+
+    /** Checks prediction with both sub-models, each skipping the feature ids handled by the other. */
+    @Test
+    public void testPredictGausAndDiscrete() {
+        DiscreteNaiveBayesModel discreteModel =
+            new DiscreteNaiveBayesModel(probabilities, classProbabilities, 
labels, binarizedDataThresholds, null);
+
+        GaussianNaiveBayesModel gaussianModel =
+            new GaussianNaiveBayesModel(means, variances, classProbabilities, 
labels, null);
+
+        CompoundNaiveBayesModel model = new CompoundNaiveBayesModel()
+            .wirhPriorProbabilities(classProbabilities)
+            .withLabels(labels)
+            .withGaussianModel(gaussianModel)
+            .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7))
+            .withDiscreteModel(discreteModel)
+            .withDiscreteFeatureIdsToSkip( asList(0, 1, 2));
+
+        Vector observation = VectorUtils.of(6, 130, 8, 1, 0, 1, 1, 0);
+
+        assertEquals(LABEL_2, model.predict(observation), PRECISION);
+    }
+}
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTest.java
new file mode 100644
index 0000000..c4cef0e
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.compound;
+
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import 
org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.junit.Test;
+
+import static java.util.Arrays.asList;
+import static org.apache.ignite.ml.naivebayes.compound.Data.LABEL_1;
+import static org.apache.ignite.ml.naivebayes.compound.Data.LABEL_2;
+import static 
org.apache.ignite.ml.naivebayes.compound.Data.binarizedDataThresholds;
+import static org.apache.ignite.ml.naivebayes.compound.Data.classProbabilities;
+import static org.apache.ignite.ml.naivebayes.compound.Data.data;
+import static org.junit.Assert.assertEquals;
+
+/** Integration test for the compound naive Bayes trainer and model on the mixed dataset from {@code Data}. */
+public class CompoundNaiveBayesTest {
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Trains on {@code Data.data}, then checks one {@code LABEL_1} and one {@code LABEL_2} prediction. */
+    @Test
+    public void testLearnsAndPredictCorrently() {
+        CompoundNaiveBayesTrainer trainer = new CompoundNaiveBayesTrainer()
+            .withPriorProbabilities(classProbabilities)
+            .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer())
+            .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7))
+            .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer()
+                .setBucketThresholds(binarizedDataThresholds))
+            .withDiscreteFeatureIdsToSkip(asList(0, 1, 2));
+
+        CompoundNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, 2),
+                new 
DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
+        );
+
+        Vector observation1 = VectorUtils.of(5.92, 165, 10, 1, 1, 0, 0, 0);
+        assertEquals(LABEL_1, model.predict(observation1), PRECISION);
+
+        Vector observation2 = VectorUtils.of(6, 130, 8, 1, 0, 1, 1, 0);
+        assertEquals(LABEL_2, model.predict(observation2), PRECISION);
+    }
+}
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainerTest.java
new file mode 100644
index 0000000..9b40fbb
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesTrainerTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.compound;
+
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import 
org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.junit.Before;
+import org.junit.Test;
+
+import static java.util.Arrays.asList;
+import static 
org.apache.ignite.ml.naivebayes.compound.Data.binarizedDataThresholds;
+import static org.apache.ignite.ml.naivebayes.compound.Data.classProbabilities;
+import static org.apache.ignite.ml.naivebayes.compound.Data.data;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests for {@link CompoundNaiveBayesTrainer}. */
+public class CompoundNaiveBayesTrainerTest extends TrainerTest {
+
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /** Trainer under test. */
+    private CompoundNaiveBayesTrainer trainer;
+
+    /** Creates the trainer: the Gaussian part skips features 3-7, the discrete part skips features 0-2. */
+    @Before
+    public void createTrainer() {
+        trainer = new CompoundNaiveBayesTrainer()
+            .withPriorProbabilities(classProbabilities)
+            .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer())
+                .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7))
+            .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer()
+                .setBucketThresholds(binarizedDataThresholds))
+                .withDiscreteFeatureIdsToSkip(asList(0, 1, 2));
+    }
+
+    /** Fits on {@code Data.data} and checks the statistics learned by both sub-models. */
+    @Test
+    public void test() {
+        CompoundNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+                new 
DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
+        );
+
+        assertDiscreteModel(model.getDiscreteModel());
+        assertGaussianModel(model.getGaussianModel());
+    }
+
+    /** Checks per-class bucket probabilities and class priors of the discrete sub-model. */
+    private void assertDiscreteModel(DiscreteNaiveBayesModel model) {
+        double[][][] expectedProbabilites = new double[][][] {
+            {
+                {.25, .75},
+                {.5, .5},
+                {.5, .5},
+                {.5, .5},
+                {.5, .5}
+            },
+            {
+                {.0, 1},
+                {.25, .75},
+                {.75, .25},
+                {.25, .75},
+                {.5, .5}
+            }
+        };
+
+        for (int i = 0; i < expectedProbabilites.length; i++) {
+            for (int j = 0; j < expectedProbabilites[i].length; j++)
+                assertArrayEquals(expectedProbabilites[i][j], 
model.getProbabilities()[i][j], PRECISION);
+        }
+        assertArrayEquals(new double[] {.5, .5}, model.getClsProbabilities(), 
PRECISION);
+    }
+
+    /** Checks priors, means and first-class population variances (divided by n) of the Gaussian sub-model. */
+    private void assertGaussianModel(GaussianNaiveBayesModel model) {
+        double[] priorProbabilities = new double[] {.5, .5};
+
+        assertEquals(priorProbabilities[0], model.getClassProbabilities()[0], 
PRECISION);
+        assertEquals(priorProbabilities[1], model.getClassProbabilities()[1], 
PRECISION);
+        assertArrayEquals(new double[] {5.855, 176.25, 11.25}, 
model.getMeans()[0], PRECISION);
+        assertArrayEquals(new double[] {5.4175, 132.5, 7.5}, 
model.getMeans()[1], PRECISION);
+        double[] expectedVars = {0.026274999999999, 92.1875, 0.6875};
+        assertArrayEquals(expectedVars, model.getVariances()[0], PRECISION);
+    }
+}
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/Data.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/Data.java
new file mode 100644
index 0000000..ee34f85
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/compound/Data.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.compound;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Shared test fixture: eight labeled rows (3 continuous and 5 binary features) with precalculated statistics. */
+final class Data {
+
+    /** Not instantiable: the class only holds static fixture data. */
+    private Data() {
+    }
+
+    /** The first label. */
+    static final double LABEL_1 = 1.;
+
+    /** The second label. */
+    static final double LABEL_2 = 2.;
+
+    /** Labels; index k corresponds to row k of the per-class statistics below. */
+    static final double[] labels = {LABEL_1, LABEL_2};
+
+    /** Training rows: columns 0-2 are continuous, 3-7 binary, and the last element is the label. */
+    static final Map<Integer, double[]> data = new HashMap<>();
+
+    /** Per-class means of the continuous features (columns 0-2). */
+    static double[][] means;
+
+    /** Per-class sample variances (divided by n - 1) of the continuous features (columns 0-2). */
+    static double[][] variances;
+
+    /** Prior class probabilities, aligned with {@code labels}. */
+    static double[] classProbabilities;
+
+    /** Thresholds to binarize the discrete features (columns 3-7). */
+    static double[][] binarizedDataThresholds;
+
+    /** Per-class bucket probabilities of the discrete features. NOTE(review): [0][1] is {.25, .75} but {@code data} gives {.5, .5}; CompoundNaiveBayesModelTest appears to depend on it — confirm before changing. */
+    static double[][][] probabilities;
+
+    static {
+        data.put(0, new double[] {6, 180, 12, 0, 0, 1, 1, 1, LABEL_1});
+        data.put(1, new double[] {5.92, 190, 11, 1, 0, 1, 1, 0, LABEL_1});
+        data.put(2, new double[] {5.58, 170, 12, 1, 1, 0, 0, 1, LABEL_1});
+        data.put(3, new double[] {5.92, 165, 10, 1, 1, 0, 0, 0, LABEL_1});
+
+        data.put(4, new double[] {5, 100, 6, 1, 0, 0, 1, 1, LABEL_2});
+        data.put(5, new double[] {5.5, 150, 8, 1, 1, 0, 0, 1, LABEL_2});
+        data.put(6, new double[] {5.42, 130, 7, 1, 1, 1, 1, 0, LABEL_2});
+        data.put(7, new double[] {5.75, 150, 9, 1, 1, 0, 1, 0, LABEL_2});
+
+        classProbabilities = new double[] {.5, .5};
+
+        means = new double[][] {
+            {5.855, 176.25, 11.25},
+            {5.4175, 132.5, 7.5},
+        };
+
+        variances = new double[][] {
+            {3.5033E-2, 1.2292E2, 9.1667E-1},
+            {9.7225E-2, 5.5833E2, 1.6667},
+        };
+
+        binarizedDataThresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, 
{.5}};
+
+        probabilities = new double[][][] {
+            {{.25, .75}, {.25, .75}, {.5, .5}, {.5, .5}, {.5, .5}},
+            {{0, 1}, {.25, .75}, {.75, .25}, {.25, .75}, {.5, .5}}
+        };
+    }
+
+}
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
index 41d320d..60d0cdd 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModelTest.java
@@ -24,7 +24,8 @@ import org.junit.Test;
 
 /** Tests for {@code DiscreteNaiveBayesModel} */
 public class DiscreteNaiveBayesModelTest {
-    /** */
+
+    /** Test. */
     @Test
     public void testPredictWithTwoClasses() {
         double first = 1;
@@ -36,7 +37,8 @@ public class DiscreteNaiveBayesModelTest {
 
         double[] classProbabilities = new double[] {6. / 13, 7. / 13};
         double[][] thresholds = new double[][] {{.5}, {.2, .7}, {.5}, {.5, 
1.5}, {.5}};
-        DiscreteNaiveBayesModel mdl = new 
DiscreteNaiveBayesModel(probabilities, classProbabilities, new double[] {first, 
second}, thresholds, new DiscreteNaiveBayesSumsHolder());
+        DiscreteNaiveBayesModel mdl = new 
DiscreteNaiveBayesModel(probabilities, classProbabilities,
+                new double[] {first, second}, thresholds, new 
DiscreteNaiveBayesSumsHolder());
         Vector observation = VectorUtils.of(2, 0, 1, 2, 0);
 
         Assert.assertEquals(second, mdl.predict(observation), 0.0001);
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
index a864fd6..7426f56 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainerTest.java
@@ -79,7 +79,7 @@ public class DiscreteNaiveBayesTrainerTest extends 
TrainerTest {
         data.put(11, new double[] {2, 1, .45, 748, LABEL_2});
     }
 
-    /** */
+    /** Trainer under test. */
     private DiscreteNaiveBayesTrainer trainer;
 
     /** Initialization {@code DiscreteNaiveBayesTrainer}. */
@@ -88,7 +88,7 @@ public class DiscreteNaiveBayesTrainerTest extends 
TrainerTest {
         trainer = new 
DiscreteNaiveBayesTrainer().setBucketThresholds(binarizedDatathresholds);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsCorrectLabelProbalities() {
 
@@ -101,7 +101,7 @@ public class DiscreteNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertArrayEquals(expectedProbabilities, 
model.getClsProbabilities(), PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() {
         DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
@@ -116,7 +116,7 @@ public class DiscreteNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertArrayEquals(new double[] {.5, .5}, 
model.getClsProbabilities(), PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() {
         double[] priorProbabilities = new double[] {.35, .65};
@@ -132,7 +132,7 @@ public class DiscreteNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertArrayEquals(priorProbabilities, 
model.getClsProbabilities(), PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsCorrectPriorProbabilities() {
         double[][][] expectedPriorProbabilites = new double[][][] {
@@ -151,7 +151,7 @@ public class DiscreteNaiveBayesTrainerTest extends 
TrainerTest {
         }
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsCorrectPriorProbabilitiesWithDefferentThresholds() {
         double[][][] expectedPriorProbabilites = new double[][][] {
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
index d35ea3d..043ef38 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
@@ -26,7 +26,8 @@ import org.junit.Test;
  * Tests for {@link GaussianNaiveBayesModel}.
  */
 public class GaussianNaiveBayesModelTest {
-    /** */
+
+    /** Test. */
     @Test
     public void testPredictWithTwoClasses() {
         double first = 1;
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
index 80927e9..d00cfc6 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
@@ -36,19 +36,19 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
     /** Precision in test checks. */
     private static final double PRECISION = 1e-2;
 
-    /** */
+    /** Label. */
     private static final double LABEL_1 = 1.;
 
-    /** */
+    /** Label. */
     private static final double LABEL_2 = 2.;
 
     /** Data. */
     private static final Map<Integer, double[]> data = new HashMap<>();
 
-    /** */
+    /** {@code LABEL_1} data. */
     private static final Map<Integer, double[]> singleLabeldata1 = new 
HashMap<>();
 
-    /** */
+    /** {@code LABEL_2} data. */
     private static final Map<Integer, double[]> singleLabeldata2 = new 
HashMap<>();
 
     static {
@@ -75,7 +75,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         trainer = new GaussianNaiveBayesTrainer();
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testWithLinearlySeparableData() {
         Map<Integer, double[]> cacheMock = new HashMap<>();
@@ -92,7 +92,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(10, 100)), 
PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsCorrectLabelProbalities() {
 
@@ -105,7 +105,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertEquals(2. / data.size(), 
model.getClassProbabilities()[1], PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() {
         GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer()
@@ -120,7 +120,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertEquals(.5, model.getClassProbabilities()[1], PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() {
         double[] priorProbabilities = new double[] {.35, .65};
@@ -136,7 +136,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertEquals(priorProbabilities[1], 
model.getClassProbabilities()[1], PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsCorrectMeans() {
 
@@ -148,7 +148,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertArrayEquals(new double[] {2.0, 2. / 3.}, 
model.getMeans()[0], PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testReturnsCorrectVariances() {
 
@@ -161,7 +161,7 @@ public class GaussianNaiveBayesTrainerTest extends 
TrainerTest {
         Assert.assertArrayEquals(expectedVars, model.getVariances()[0], 
PRECISION);
     }
 
-    /** */
+    /** Test. */
     @Test
     public void testUpdatigModel() {
         Vectorizer<Integer, double[], Integer, Double> vectorizer = new 
DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);

Reply via email to