This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0eba4dc  [SYSTEMDS-2735] Builtin function gmmPredict for clustering instances
0eba4dc is described below

commit 0eba4dcdd3d92c91b5192e1e7d2d84cff5326068
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Sat Jan 23 23:58:03 2021 +0100

    [SYSTEMDS-2735] Builtin function gmmPredict for clustering instances
    
    Closes #1108.
---
 scripts/builtin/gmmPredict.dml                     | 108 +++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../functions/builtin/BuiltinGMMPredictTest.java   | 150 +++++++++++++++++++++
 src/test/scripts/functions/builtin/GMM_Predict.dml |  54 ++++++++
 4 files changed, 313 insertions(+)

diff --git a/scripts/builtin/gmmPredict.dml b/scripts/builtin/gmmPredict.dml
new file mode 100644
index 0000000..e054902
--- /dev/null
+++ b/scripts/builtin/gmmPredict.dml
@@ -0,0 +1,108 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# ------------------------------------------
+# Gaussian Mixture Model Predict
+# ------------------------------------------
+
+# INPUT PARAMETERS:
+# ---------------------------------------------------------------------------------------------
+# NAME                   TYPE     DEFAULT   MEANING
+# ---------------------------------------------------------------------------------------------
+# X                      Double   ---       Matrix X (instances to be clustered)
+# weight                 Double   ---       Weight of learned model
+# mu                     Double   ---       fitted clusters mean
+# precisions_cholesky    Double   ---       fitted precision matrix for each mixture
+# model                  String   ---       fitted model: "VVV", "EEE", "VVI" or "VII"
+# ---------------------------------------------------------------------------------------------
+
+# OUTPUT:
+# ---------------------------------------------------------------------------------------------
+# NAME                   TYPE     DEFAULT   MEANING
+# ---------------------------------------------------------------------------------------------
+# predict                Double   ---       predicted cluster labels
+# posterior_prob         Double   ---       posterior probability of each instance belonging to each cluster
+# ---------------------------------------------------------------------------------------------
+
+# compute posterior probabilities for new instances given the variance and mean of the fitted data
+
+# Assign new instances to the components of an already fitted Gaussian mixture.
+# Returns the index of the best-scoring component per row of X (predict) and the
+# row-normalised probabilities over all components (posterior_prob).
+m_gmmPredict = function(Matrix[Double] X, Matrix[Double] weight,
+  Matrix[Double] mu, Matrix[Double] precisions_cholesky, String model)
+  return(Matrix[Double] predict, Matrix[Double] posterior_prob)
+{
+  # per-component log density of every instance, shifted by the log of the model weights
+  log_joint = compute_log_gaussian_prob(X, mu, precisions_cholesky, model) + log(weight)
+
+  # normalise every row in log space, then leave log space to get the posteriors
+  posterior_prob = exp(log_joint - logSumExp(log_joint, "rows"))
+
+  # hard assignment: column index of the largest weighted log probability per row
+  predict = rowIndexMax(log_joint)
+}
+
+# Compute, for every instance in X and every mixture component, the Gaussian
+# log density (without the mixture weight).
+#
+#   X          instances, one per row
+#   mu         component means, one row per component
+#   prec_chol  Cholesky factor(s) of the precision matrix; how it is read depends on model:
+#                "VVV"  ncol(X) consecutive rows per component, stacked row-wise
+#                "EEE"  one factor that is applied to every component
+#                "VVI"  one row per component (presumably the same shape as mu - confirm)
+#                "VII"  one value per component (presumably n_components x 1 - confirm)
+#   model      one of "VVV", "EEE", "VVI", "VII"; any other value stops with an error
+#
+# Returns es_log_prob with nrow(X) rows and one column per component.
+compute_log_gaussian_prob = function(Matrix[Double] X, Matrix[Double] mu,
+  Matrix[Double] prec_chol, String model)
+  return(Matrix[Double] es_log_prob ) # nrow(X) * n_components
+{
+  n_components = nrow(mu)
+  d = ncol(X)
+
+  if(model == "VVV") {
+    log_prob = matrix(0, nrow(X), n_components) # log probabilities
+    log_det_chol = matrix(0, 1, n_components)   # log determinant
+    for(k in 1:n_components) {
+      # rows (k-1)*d+1 .. k*d of prec_chol hold the factor of component k
+      prec = prec_chol[((k-1)*d+1):(k*d),]
+      y = X %*% prec - mu[k,] %*% prec
+      log_prob[, k] = rowSums(y*y)
+      # compute log_det_cholesky (the diagonal is unaffected by a transpose)
+      log_det_chol[1,k] = sum(log(diag(prec)))
+    }
+  }
+  else if(model == "EEE") {
+    log_prob = matrix(0, nrow(X), n_components)
+    # a single factor is used for all components, hence a single log determinant
+    log_det_chol = as.matrix(sum(log(diag(prec_chol))))
+    prec = prec_chol
+    for(k in 1:n_components) {
+      y = X %*% prec - mu[k,] %*% prec
+      log_prob[, k] = rowSums(y*y)
+    }
+  }
+  else if(model ==  "VVI") {
+    log_det_chol = t(rowSums(log(prec_chol)))
+    prec = prec_chol
+    precisions = prec^2
+    # all-ones matrix used to expand the per-component row vector to nrow(X) rows
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix*t(rowSums(mu^2 * precisions))
+      - 2 * (X %*% t(mu * precisions)) + X^2 %*% t(precisions))
+  }
+  else if (model == "VII") {
+    log_det_chol = t(d * log(prec_chol))
+    prec = prec_chol
+    precisions = prec^ 2
+    # all-ones matrix used to expand the per-component row vector to nrow(X) rows
+    bc_matrix = matrix(1,nrow(X), nrow(mu))
+    log_prob = (bc_matrix * t(rowSums(mu^2) * precisions)
+      - 2 * X %*% t(mu * precisions) + rowSums(X*X) %*% t(precisions) )
+  }
+  else {
+    # fail early with a clear message instead of an undefined-variable error further down
+    stop("gmmPredict: unsupported model '" + model + "' (expected VVV, EEE, VVI or VII)")
+  }
+
+  # a single log determinant (e.g. "EEE") is replicated across all components
+  if(ncol(log_det_chol) == 1)
+    log_det_chol = matrix(1, 1, ncol(log_prob)) * log_det_chol
+
+  es_log_prob = -.5 * (ncol(X) * log(2 * pi) + log_prob) + log_det_chol
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 5010c21..a1f372c 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -116,6 +116,7 @@ public enum Builtins {
        GET_PERMUTATIONS("getPermutations", true),
        GLM("glm", true),
        GMM("gmm", true),
+       GMM_PREDICT("gmmPredict", true),
        GNMF("gnmf", true),
        GRID_SEARCH("gridSearch", true),
        HYPERBAND("hyperband", true),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java
new file mode 100644
index 0000000..f0d2cc7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGMMPredictTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Runs the GMM_Predict.dml test script on the iris CSV for several model /
+ * initialization combinations and asserts on the scalar the script writes
+ * to its output file.
+ */
+public class BuiltinGMMPredictTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "GMM_Predict";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinGMMPredictTest.class.getSimpleName() + "/";
+
+       // Absolute tolerance of the final assertion.
+       // NOTE(review): GMM_Predict.dml writes an error measure (mean absolute difference
+       // of cluster sizes), yet the assertion expects 1 +/- eps, so with eps = 2 every
+       // value in [-1, 3] passes - confirm this is intended.
+       private final static double eps = 2;
+       // Values forwarded to the script, which passes them to gmm as its tol argument.
+       private final static double tol = 1e-3;
+       private final static double tol2 = 1e-5;
+
+       // Input data set, passed to the script as its first argument.
+       private final static String DATASET = SCRIPT_DIR + 
"functions/transform/input/iris/iris.csv";
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
+       }
+
+       // "VVI" model, "random" initialization, 10 iterations, CP execution.
+       @Test
+       public void testGMMMPredictCP1() {
+               runGMMPredictTest(3, "VVI", "random", 10,
+                       0.000000001, tol,42,true, LopProperties.ExecType.CP);
+       }
+
+       // "VII" model, "random" initialization, 50 iterations, CP execution.
+       @Test
+       public void testGMMMPredictCP2() {
+               runGMMPredictTest(3, "VII", "random", 50,
+                       0.000001, tol2,42,true, LopProperties.ExecType.CP);
+       }
+
+       // "VVV" model, "kmeans" initialization, 10 iterations, CP execution.
+       @Test
+       public void testGMMMPredictCPKmean1() {
+               runGMMPredictTest(3, "VVV", "kmeans", 10,
+                       0.0000001, tol,42,true, LopProperties.ExecType.CP);
+       }
+
+       // "EEE" model, "kmeans" initialization, 150 iterations, CP execution.
+       @Test
+       public void testGMMMPredictCPKmean2() {
+               runGMMPredictTest(3, "EEE", "kmeans", 150,
+                       0.000001, tol,42,true, LopProperties.ExecType.CP);
+       }
+
+       // "VII" model, "kmeans" initialization, 50 iterations, CP execution.
+       @Test
+       public void testGMMMPredictCPKmean3() {
+               runGMMPredictTest(3, "VII", "kmeans", 50,
+                       0.000001, tol2,42,true, LopProperties.ExecType.CP);
+       }
+
+       // NOTE(review): the Spark variants below are disabled. Before re-enabling them,
+       // note that some reference an undeclared `tol1` and several pass ExecType.CP
+       // despite their "Spark" names.
+//     @Test
+//     public void testGMMM1Spark() {
+//             runGMMPredictTest(3, "VVV", "random", 10,
+//             0.0000001, tol,42,true, LopProperties.ExecType.SPARK); }
+//
+//     @Test
+//     public void testGMMM2Spark() {
+//             runGMMPredictTest(3, "EEE", "random", 50,
+//                     0.0000001, tol,42,true, LopProperties.ExecType.CP);
+//     }
+//
+//     @Test
+//     public void testGMMMS3Spark() {
+//             runGMMPredictTest(3, "VVI", "random", 100,
+//                     0.000001, tol,42,true, LopProperties.ExecType.CP);
+//     }
+//
+//     @Test
+//     public void testGMMM4Spark() {
+//             runGMMPredictTest(3, "VII", "random", 100,
+//                     0.000001, tol1,42,true, LopProperties.ExecType.CP);
+//     }
+//
+//     @Test
+//     public void testGMMM1KmeanSpark() {
+//             runGMMPredictTest(3, "VVV", "kmeans", 100,
+//                     0.000001, tol2,42,false, LopProperties.ExecType.SPARK);
+//     }
+//
+//     @Test
+//     public void testGMMM2KmeanSpark() {
+//             runGMMPredictTest(3, "EEE", "kmeans", 50,
+//                     0.00000001, tol1,42,false, 
LopProperties.ExecType.SPARK);
+//     }
+//
+//     @Test
+//     public void testGMMM3KmeanSpark() {
+//             runGMMPredictTest(3, "VVI", "kmeans", 100,
+//                     0.000001, tol,42,false, LopProperties.ExecType.SPARK);
+//     }
+//
+//     @Test
+//     public void testGMMM4KmeanSpark() {
+//             runGMMPredictTest(3, "VII", "kmeans", 100,
+//                     0.000001, tol,42,false, LopProperties.ExecType.SPARK);
+//     }
+
+       /**
+        * Executes GMM_Predict.dml with the given settings and asserts on the scalar result.
+        *
+        * @param G_mixtures script argument 2, passed to gmm as n_components
+        * @param model      script argument 3, passed to gmm and gmmPredict as the model
+        * @param init_param script argument 4, passed to gmm as init_params
+        * @param iter       script argument 5, passed to gmm as iter
+        * @param reg        script argument 6, passed to gmm as reg_covar
+        * @param tol        script argument 7, passed to gmm as tol (shadows the class constant)
+        * @param seed       script argument 8, passed to gmm as seed
+        * @param rewrite    value assigned to OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION for the run
+        * @param instType   execution type handed to setExecMode
+        */
+       private void runGMMPredictTest(int G_mixtures, String model, String 
init_param, int iter,
+               double reg, double tol, int seed, boolean rewrite, 
LopProperties.ExecType instType) {
+
+               // remember the global settings so that the finally block can restore them
+               Types.ExecMode platformOld = setExecMode(instType);
+               boolean rewriteOld = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       // NOTE(review): the result goes to "O" while setUp() declares "B" as
+                       // the output name - confirm whether the mismatch matters.
+                       String outFile = output("O");
+                       System.out.println(outFile);
+                       // positional script arguments $1 .. $9
+                       programArgs = new String[] {"-args", DATASET,
+                               String.valueOf(G_mixtures), model, init_param, 
String.valueOf(iter), String.valueOf(reg),
+                               String.valueOf(tol), String.valueOf(seed), 
outFile};
+
+                       runTest(true, false, null, -1);
+                       // compare results
+                       // the scalar is the error value written by the script, despite the local name
+                       double accuracy = TestUtils.readDMLScalar(outFile);
+                       Assert.assertEquals(1, accuracy, eps);
+               }
+               finally {
+                       resetExecMode(platformOld);
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewriteOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/GMM_Predict.dml b/src/test/scripts/functions/builtin/GMM_Predict.dml
new file mode 100644
index 0000000..283db7c
--- /dev/null
+++ b/src/test/scripts/functions/builtin/GMM_Predict.dml
@@ -0,0 +1,54 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Test driver: fit a GMM on a training split of the input data, predict the cluster
+# of the held-out rows with gmmPredict and write an error value to $9.
+
+# read the CSV (with header row) as a frame
+X = read($1, data_type = "frame", format = "csv", header=TRUE)
+# NOTE(review): presumably selects columns 2 .. ncol(X)-1, i.e. drops the first and
+# the last column - confirm the precedence of ":" and "-" in DML indexing
+X = X[ , 2:ncol(X) - 1]
+X = as.matrix(X)
+
+# divide in train and test set
+# train: rows 1-45, 52-95 and 102-145
+train = X[1:45,]
+train = rbind(train, X[52:95,])
+train = rbind(train, X[102:145,])
+
+# test: rows 46-51, 96-101 and 146-150 (6 + 6 + 5 = 17 rows)
+test = X[46:51,]
+test = rbind(test, X[96:101,])
+test = rbind(test, X[146:150,])
+
+# train GMM
+# $2 n_components, $3 model, $4 init_params, $5 iter, $6 reg_covar, $7 tol, $8 seed
+[labels, prob, df, bic, mu, prec_chol, w] = gmm(X=train, n_components = $2,
+  model = $3, init_params = $4, iter = $5, reg_covar = $6, tol = $7, seed=$8, 
verbose=TRUE)
+ 
+# predict labels
+# reuses the fitted weights, means and precision factors; $3 is the same model string
+[pred, pp] = gmmPredict(test, w, mu, prec_chol, $3)  
+
+# expected clusters/predictions
+# i.e. the expected number of test instances per cluster
+expected = matrix("6 6 5", 3, 1)
+
+# 17 x 3 matrix whose every row is (1, 2, 3); comparing it with the predicted labels
+# gives a 0/1 indicator of the predicted cluster per test instance
+resp = matrix(1, 17, 3) * t(seq(1,3))
+resp = resp == pred
+# number of test instances assigned to each cluster, as a column vector
+cluster = t(colSums(resp))
+
+# sort both count vectors ascending so the comparison ignores the cluster numbering
+cluster = order(target = cluster, by = 1, decreasing = FALSE, 
index.return=FALSE)
+correct_Predictions = order(target = expected, by = 1, decreasing = FALSE, 
index.return=FALSE)
+
+# mean absolute difference between expected and predicted cluster sizes (0 = identical)
+error = mean(abs(correct_Predictions - cluster))
+write(error, $9, format = "text")

Reply via email to