This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-math.git


The following commit(s) were added to refs/heads/master by this push:
     new 770882409 Allow fitting single column data
770882409 is described below

commit 7708824094a0213db07f09897686645437f580f8
Author: aherbert <aherb...@apache.org>
AuthorDate: Tue Mar 5 17:29:18 2024 +0000

    Allow fitting single column data
---
 ...ariateNormalMixtureExpectationMaximization.java |  6 +--
 ...teNormalMixtureExpectationMaximizationTest.java | 48 ++++++++++++++++++++++
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git 
a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
 
b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
index ee5993039..8b51195ab 100644
--- 
a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
+++ 
b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
@@ -81,7 +81,7 @@ public class MultivariateNormalMixtureExpectationMaximization 
{
      * @throws DimensionMismatchException if rows of data have different 
numbers
      *             of columns
      * @throws NumberIsTooSmallException if the number of columns in the data 
is
-     *             less than 2
+     *             less than 1
      */
     public MultivariateNormalMixtureExpectationMaximization(double[][] data)
         throws NotStrictlyPositiveException,
@@ -99,9 +99,9 @@ public class MultivariateNormalMixtureExpectationMaximization 
{
                 throw new DimensionMismatchException(data[i].length,
                                                      data[0].length);
             }
-            if (data[i].length < 2) {
+            if (data[i].length < 1) {
                 throw new 
NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
-                                                    data[i].length, 2, true);
+                                                    data[i].length, 1, true);
             }
             this.data[i] = Arrays.copyOf(data[i], data[i].length);
         }
diff --git 
a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
 
b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
index a2fe49684..7a1e84282 100644
--- 
a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
+++ 
b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
@@ -17,6 +17,7 @@
 package org.apache.commons.math4.legacy.distribution.fitting;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 import 
org.apache.commons.math4.legacy.distribution.MixtureMultivariateNormalDistribution;
@@ -241,6 +242,53 @@ public class 
MultivariateNormalMixtureExpectationMaximizationTest {
         }
     }
 
+    @Test
+    public void testFit1() {
+        // Test that the fit can be performed on data with a single dimension
+        // Use only the first column of the test data
+        final double[][] data = Arrays.stream(getTestSamples())
+            .map(x -> new double[] {x[0]}).toArray(double[][]::new);
+
+        // Fit the first column of test samples using Matlab R2023b (Update 6):
+        // GMModel = fitgmdist(X,2);
+
+        // NegativeLogLikelihood (CM code uses the positive log-likelihood 
divided by the number of observations)
+        final double correctLogLikelihood = -2.512197016873482e+02 / 
data.length;
+        // ComponentProportion
+        final double[] correctWeights = new double[] {0.240510201974078, 
0.759489798025922};
+        // Since data has 1 dimension the means and covariances are single 
values
+        // mu
+        final double[] correctMeans = new double[] {-1.736139126623031, 
3.899886984922886};
+        // Sigma
+        final double[] correctCov = new double[] {1.371327786710623, 
5.254286022455004};
+
+        MultivariateNormalMixtureExpectationMaximization fitter
+            = new MultivariateNormalMixtureExpectationMaximization(data);
+
+        MixtureMultivariateNormalDistribution initialMix
+            = MultivariateNormalMixtureExpectationMaximization.estimate(data, 
2);
+        fitter.fit(initialMix);
+        MixtureMultivariateNormalDistribution fittedMix = 
fitter.getFittedModel();
+        List<Pair<Double, MultivariateNormalDistribution>> components = 
fittedMix.getComponents();
+
+        final double relError = 0.05;
+        Assert.assertEquals(correctLogLikelihood,
+                            fitter.getLogLikelihood(),
+                            Math.abs(correctLogLikelihood) * relError);
+
+        int i = 0;
+        for (Pair<Double, MultivariateNormalDistribution> component : 
components) {
+            final double weight = component.getFirst();
+            final MultivariateNormalDistribution mvn = component.getSecond();
+            final double[] mean = mvn.getMeans();
+            final RealMatrix covMat = mvn.getCovariances();
+            Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * 
relError);
+            Assert.assertEquals(correctMeans[i], mean[0], 
Math.abs(correctMeans[i]) * relError);
+            Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), 
correctCov[i] * relError);
+            i++;
+        }
+    }
+
     private double[][] getTestSamples() {
         // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
         // [4, 8.2]

Reply via email to