This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 422cc0421985d88491c1a58c11652693c77f2417
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Sep 3 22:13:42 2022 +0200

    [SYSTEMDS-3435] Fix robustness of multi-class svm for missing classes
    
    This patch fixes issues of multi-class SVM with scenarios where the
    training data does not contain all classes in [1,max class ID]. We
    now only train one-against-the-rest models if a class exists in at
    least one training example.
    
    Furthermore, this patch also adds support for ternary ifelse scalar
    operations with string arguments, which so far always evaluated to "0.0"
    but were used in verbose debug output of MSVM calling L2SVM.
    
    Finally, the documentation of MSVM and L2SVM has been cleaned up to
    make the descriptions more understandable, add expected shapes, and
    remove non-existing parameters.
---
 scripts/builtin/l2svm.dml                          | 53 ++++++++++------------
 scripts/builtin/msvm.dml                           | 49 +++++++++++---------
 .../instructions/cp/TernaryCPInstruction.java      | 22 ++++++---
 3 files changed, 68 insertions(+), 56 deletions(-)

diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml
index cdcc1ba4df..8b25915625 100644
--- a/scripts/builtin/l2svm.dml
+++ b/scripts/builtin/l2svm.dml
@@ -19,27 +19,29 @@
 #
 #-------------------------------------------------------------
 
-# Builtin function Implements binary-class SVM with squared slack variables
+# This builtin function implements binary-class Support Vector Machine (SVM)
+# with squared slack variables (l2 regularization).
 #
 # INPUT:
-# 
-----------------------------------------------------------------------------------------
-# X              matrix X of feature vectors
-# Y              matrix Y of class labels have to be a single column
-# intercept      No Intercept ( If set to TRUE then a constant bias column is 
added to X)
-# epsilon        Procedure terminates early if the reduction in objective 
function value is less
-#                than epsilon (tolerance) times the initial objective function 
value.
-# reg            Regularization parameter (reg) for L2 regularization
-# maxIterations  Maximum number of conjugate gradient iterations
-# maxii          max inner for loop iterations
-# verbose        Set to true if one wants print statements updating on loss.
-# columnId       The column Id used if one wants to add a ID to the print 
statement, 
+# 
------------------------------------------------------------------------------
+# X              Feature matrix X (shape: m x n)
+# Y              Label vector y of class labels (shape: m x 1), assumed binary
+#                in -1/+1 or 1/2 encoding.
+# intercept      Indicator if a bias column should be added to X and the model
+# epsilon        Tolerance for early termination if the reduction of objective
+#                function is less than epsilon times the initial objective
+# reg            Regularization parameter (lambda) for L2 regularization
+# maxIterations  Maximum number of conjugate gradient (outer) iterations
+# maxii          Maximum number of line search (inner) iterations
+# verbose        Indicator if training details should be printed
+# columnId       An optional class ID used in verbose print output,
 #                eg. used when L2SVM is used in MSVM.
-# 
-----------------------------------------------------------------------------------------
+# 
------------------------------------------------------------------------------
 #
 # OUTPUT:
-# 
------------------------------------------------------------------------------------------
-# model  the trained model
-# 
------------------------------------------------------------------------------------------
+# 
------------------------------------------------------------------------------
+# model          Trained model/weights (shape: n x 1, w/ intercept: n+1)
+# 
------------------------------------------------------------------------------
 
 m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = 
FALSE,
     Double epsilon = 0.001, Double reg = 1, Integer maxIterations = 100, 
@@ -73,29 +75,22 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, 
Boolean intercept = FALSE
   if(check_min != -1 | check_max != +1)
     Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - 
check_min)
     
-  # If column_id is -1 then we assume that the fundamental algorithm is MSVM, 
-  # Therefore don't print message.
+  # If column_id is -1 then we assume that it's called from within MSVM
   if(verbose & columnId == -1)
     print('Running L2-SVM ')
 
-  num_samples = nrow(X)
-  num_classes = ncol(Y)
-  
   # Add Bias 
-  num_rows_in_w = ncol(X)
   if (intercept) {
-    ones  = matrix(1, rows=num_samples, cols=1)
+    ones  = matrix(1, rows=nrow(X), cols=1)
     X = cbind(X, ones);
-    num_rows_in_w += 1
   }
   
-  w = matrix(0, rows=num_rows_in_w, cols=1)
+  w = matrix(0, rows=ncol(X), cols=1)
+  Xw = matrix(0, rows=nrow(X), cols=1)
 
   g_old = t(X) %*% Y
   s = g_old
 
-  Xw = matrix(0, rows=nrow(X), cols=1)
-
   iter = 0
   continue = TRUE
   while(continue & iter < maxIterations)  {
@@ -129,8 +124,8 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, 
Boolean intercept = FALSE
     g_new = t(X) %*% (out * Y) - reg * w
 
     if(verbose) {
-      colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,")
-      print("Iter: " + toString(iter) + " InnerIter: " + toString(iiter) +" 
--- "+ colstr + " Obj:" + obj)
+      colstr = ifelse(columnId!=-1, "-- MSVM class="+columnId+": ", "")
+      print(colstr + "Iter: " + iter + " InnerIter: " + iiter +" --- " + " 
Obj:" + obj)
     }
 
     tmp = sum(s * g_old)
diff --git a/scripts/builtin/msvm.dml b/scripts/builtin/msvm.dml
index 4bf904f822..076b9eb597 100644
--- a/scripts/builtin/msvm.dml
+++ b/scripts/builtin/msvm.dml
@@ -19,26 +19,27 @@
 #
 #-------------------------------------------------------------
 
-# Implements builtin multi-class SVM with squared slack variables, 
-# learns one-against-the-rest binary-class classifiers by making a function 
call to l2SVM
+# This builtin function implements a multi-class Support Vector Machine (SVM)
+# with squared slack variables. The trained model comprises #classes
+# one-against-the-rest binary-class l2svm classification models.
 #
 # INPUT:
-#------------------------------------------------------------------------------------------
-# X              matrix X of feature vectors
-# Y              matrix Y of class labels
-# intercept      No Intercept ( If set to TRUE then a constant bias column is 
added to X)
-# num_classes    Number of classes
-# epsilon        Procedure terminates early if the reduction in objective 
function
-#                value is less than epsilon (tolerance) times the initial 
objective function value.
+#-------------------------------------------------------------------------------
+# X              Feature matrix X (shape: m x n)
+# Y              Label vector y of class labels (shape: m x 1),
+#                where max(Y) is assumed to be the number of classes
+# intercept      Indicator if a bias column should be added to X and the model
+# epsilon        Tolerance for early termination if the reduction of objective
+#                function is less than epsilon times the initial objective
 # reg            Regularization parameter (lambda) for L2 regularization
-# maxIterations  Maximum number of conjugate gradient iterations
-# verbose        Set to true to print while training.
-# 
-----------------------------------------------------------------------------------------
+# maxIterations  Maximum number of conjugate gradient (outer l2svm) iterations
+# verbose        Indicator if training details should be printed
+# 
------------------------------------------------------------------------------
 #
 # OUTPUT:
-#-----------------------------------------------------------------------------------
-# model  model matrix
-#-----------------------------------------------------------------------------------
+#-------------------------------------------------------------------------------
+# model          Trained model/weights (shape: n x max(Y), w/ intercept: n+1)
+#-------------------------------------------------------------------------------
 
 m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = 
FALSE,
     Double epsilon = 0.001, Double reg = 1.0, Integer maxIterations = 100,
@@ -55,25 +56,31 @@ m_msvm = function(Matrix[Double] X, Matrix[Double] Y, 
Boolean intercept = FALSE,
     print("msvm: matrix X contains "+numNaNs+" missing values, replacing with 
0.")
     X = replace(target=X, pattern=NaN, replacement=0);
   }
-  num_rows_in_w = ncol(X)
+  # append once, and call l2svm always with intercept=FALSE
   if(intercept) {
-    # append once, and call l2svm always with intercept=FALSE 
     ones = matrix(1, rows=nrow(X), cols=1)
     X = cbind(X, ones);
-    num_rows_in_w += 1
   }
 
   if(ncol(Y) > 1)
-    Y = rowMaxs(Y * t(seq(1,ncol(Y))))
+    Y = rowIndexMax(Y)
 
   # Assuming number of classes to be max contained in Y
-  w = matrix(0, rows=num_rows_in_w, cols=max(Y))
+  w = matrix(0, rows=ncol(X), cols=max(Y))
 
   parfor(class in 1:max(Y)) {
+    # extract the class' binary labels and convert to -1/+1
     Y_local = 2 * (Y == class) - 1
-    w[,class] = l2svm(X=X, Y=Y_local, intercept=FALSE,
+    # train l2svm model with robustness for non-existing classes
+    nnzY = sum(Y == class);
+    if( nnzY > 0 ) {
+      w[,class] = l2svm(X=X, Y=Y_local, intercept=FALSE,
         epsilon=epsilon, reg=reg, maxIterations=maxIterations,
         verbose=verbose, columnId=class)
+    }
+    else {
+      w[,class] = matrix(-Inf, ncol(X), 1);
+    }
   }
 
   model = w
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
index 86733fb06f..76a8924ae1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
@@ -19,7 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.IfElse;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
@@ -71,12 +73,20 @@ public class TernaryCPInstruction extends 
ComputationCPInstruction {
                        ec.setMatrixOutput(output.getName(), out);
                }
                else { //SCALARS
-                       double value = ((TernaryOperator)_optr).fn.execute(
-                               ec.getScalarInput(input1).getDoubleValue(),
-                               ec.getScalarInput(input2).getDoubleValue(),
-                               ec.getScalarInput(input3).getDoubleValue());
-                       ec.setScalarOutput(output.getName(), ScalarObjectFactory
-                               .createScalarObject(output.getValueType(), 
value));
+                       if( ((TernaryOperator)_optr).fn instanceof IfElse
+                               && output.getValueType() == ValueType.STRING) {
+                               String value = 
(ec.getScalarInput(input1).getDoubleValue() != 0 ?
+                                       ec.getScalarInput(input2) : 
ec.getScalarInput(input3)).getStringValue();
+                               ec.setScalarOutput(output.getName(), new 
StringObject(value));
+                       }
+                       else {
+                               double value = 
((TernaryOperator)_optr).fn.execute(
+                                       
ec.getScalarInput(input1).getDoubleValue(),
+                                       
ec.getScalarInput(input2).getDoubleValue(),
+                                       
ec.getScalarInput(input3).getDoubleValue());
+                               ec.setScalarOutput(output.getName(), 
ScalarObjectFactory
+                                       
.createScalarObject(output.getValueType(), value));
+                       }
                }
        }
 }

Reply via email to