This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-math.git
The following commit(s) were added to refs/heads/master by this push: new 770882409 Allow fitting single column data 770882409 is described below commit 7708824094a0213db07f09897686645437f580f8 Author: aherbert <aherb...@apache.org> AuthorDate: Tue Mar 5 17:29:18 2024 +0000 Allow fitting single column data --- ...ariateNormalMixtureExpectationMaximization.java | 6 +-- ...teNormalMixtureExpectationMaximizationTest.java | 48 ++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java index ee5993039..8b51195ab 100644 --- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java +++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java @@ -81,7 +81,7 @@ public class MultivariateNormalMixtureExpectationMaximization { * @throws DimensionMismatchException if rows of data have different numbers * of columns * @throws NumberIsTooSmallException if the number of columns in the data is - * less than 2 + * less than 1 */ public MultivariateNormalMixtureExpectationMaximization(double[][] data) throws NotStrictlyPositiveException, @@ -99,9 +99,9 @@ public class MultivariateNormalMixtureExpectationMaximization { throw new DimensionMismatchException(data[i].length, data[0].length); } - if (data[i].length < 2) { + if (data[i].length < 1) { throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL, - data[i].length, 2, true); + data[i].length, 1, true); } this.data[i] = Arrays.copyOf(data[i], data[i].length); } diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java index a2fe49684..7a1e84282 100644 --- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java +++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java @@ -17,6 +17,7 @@ package org.apache.commons.math4.legacy.distribution.fitting; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution; @@ -241,6 +242,53 @@ public class MultivariateNormalMixtureExpectationMaximizationTest { } } + @Test + public void testFit1() { + // Test that the fit can be performed on data with a single dimension + // Use only the first column of the test data + final double[][] data = Arrays.stream(getTestSamples()) + .map(x -> new double[] {x[0]}).toArray(double[][]::new); + + // Fit the first column of test samples using Matlab R2023b (Update 6): + // GMModel = fitgmdist(X,2); + + // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations) + final double correctLogLikelihood = -2.512197016873482e+02 / data.length; + // ComponentProportion + final double[] correctWeights = new double[] {0.240510201974078, 0.759489798025922}; + // Since data has 1 dimension the means and covariances are single values + // mu + final double[] correctMeans = new double[] {-1.736139126623031, 3.899886984922886}; + // Sigma + final double[] correctCov = new double[] {1.371327786710623, 5.254286022455004}; + + MultivariateNormalMixtureExpectationMaximization fitter + = new MultivariateNormalMixtureExpectationMaximization(data); + + MixtureMultivariateNormalDistribution initialMix + = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2); + fitter.fit(initialMix); + MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel(); + List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents(); + + final double relError = 0.05; + Assert.assertEquals(correctLogLikelihood, + fitter.getLogLikelihood(), + Math.abs(correctLogLikelihood) * relError); + + int i = 0; + for (Pair<Double, MultivariateNormalDistribution> component : components) { + final double weight = component.getFirst(); + final MultivariateNormalDistribution mvn = component.getSecond(); + final double[] mean = mvn.getMeans(); + final RealMatrix covMat = mvn.getCovariances(); + Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * relError); + Assert.assertEquals(correctMeans[i], mean[0], Math.abs(correctMeans[i]) * relError); + Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), correctCov[i] * relError); + i++; + } + } + private double[][] getTestSamples() { // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and // [4, 8.2]