This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 422cc0421985d88491c1a58c11652693c77f2417 Author: Matthias Boehm <[email protected]> AuthorDate: Sat Sep 3 22:13:42 2022 +0200 [SYSTEMDS-3435] Fix robustness of multi-class svm for missing classes This patch fixes issues of multi-class SVM with scenarios where the training data does not contain all classes in [1,max class ID]. We now only train one-against-the-rest models if a class exists in at least one training example. Furthermore, this patch also adds support for ternary ifelse scalar operations with string arguments which so far always evaluated to "0.0", but were used in verbose debug output of MSVM calling L2SVM. Finally, the documentation of MSVM and L2SVM has been cleaned up to make the descriptions more understandable, add expected shapes, and remove non-existing parameters. --- scripts/builtin/l2svm.dml | 53 ++++++++++------------ scripts/builtin/msvm.dml | 49 +++++++++++--------- .../instructions/cp/TernaryCPInstruction.java | 22 ++++++--- 3 files changed, 68 insertions(+), 56 deletions(-) diff --git a/scripts/builtin/l2svm.dml b/scripts/builtin/l2svm.dml index cdcc1ba4df..8b25915625 100644 --- a/scripts/builtin/l2svm.dml +++ b/scripts/builtin/l2svm.dml @@ -19,27 +19,29 @@ # #------------------------------------------------------------- -# Builtin function Implements binary-class SVM with squared slack variables +# This builtin function implements binary-class Support Vector Machine (SVM) +# with squared slack variables (l2 regularization). # # INPUT: -# ----------------------------------------------------------------------------------------- -# X matrix X of feature vectors -# Y matrix Y of class labels have to be a single column -# intercept No Intercept ( If set to TRUE then a constant bias column is added to X) -# epsilon Procedure terminates early if the reduction in objective function value is less -# than epsilon (tolerance) times the initial objective function value. 
-# reg Regularization parameter (reg) for L2 regularization -# maxIterations Maximum number of conjugate gradient iterations -# maxii max inner for loop iterations -# verbose Set to true if one wants print statements updating on loss. -# columnId The column Id used if one wants to add a ID to the print statement, +# ------------------------------------------------------------------------------ +# X Feature matrix X (shape: m x n) +# Y Label vector y of class labels (shape: m x 1), assumed binary +# in -1/+1 or 1/2 encoding. +# intercept Indicator if a bias column should be added to X and the model +# epsilon Tolerance for early termination if the reduction of objective +# function is less than epsilon times the initial objective +# reg Regularization parameter (lambda) for L2 regularization +# maxIterations Maximum number of conjugate gradient (outer) iterations +# maxii Maximum number of line search (inner) iterations +# verbose Indicator if training details should be printed +# columnId An optional class ID used in verbose print output, # eg. used when L2SVM is used in MSVM. 
-# ----------------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------ # # OUTPUT: -# ------------------------------------------------------------------------------------------ -# model the trained model -# ------------------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ +# model Trained model/weights (shape: n x 1, w/ intercept: n+1) +# ------------------------------------------------------------------------------ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE, Double epsilon = 0.001, Double reg = 1, Integer maxIterations = 100, @@ -73,29 +75,22 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE if(check_min != -1 | check_max != +1) Y = 2/(check_max - check_min)*Y - (check_min + check_max)/(check_max - check_min) - # If column_id is -1 then we assume that the fundamental algorithm is MSVM, - # Therefore don't print message. 
+ # If columnId is -1 then we assume that l2svm is not called from within MSVM if(verbose & columnId == -1) print('Running L2-SVM ') - num_samples = nrow(X) - num_classes = ncol(Y) - # Add Bias - num_rows_in_w = ncol(X) if (intercept) { - ones = matrix(1, rows=num_samples, cols=1) + ones = matrix(1, rows=nrow(X), cols=1) X = cbind(X, ones); - num_rows_in_w += 1 } - w = matrix(0, rows=num_rows_in_w, cols=1) + w = matrix(0, rows=ncol(X), cols=1) + Xw = matrix(0, rows=nrow(X), cols=1) g_old = t(X) %*% Y s = g_old - Xw = matrix(0, rows=nrow(X), cols=1) - iter = 0 continue = TRUE while(continue & iter < maxIterations) { @@ -129,8 +124,8 @@ m_l2svm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE g_new = t(X) %*% (out * Y) - reg * w if(verbose) { - colstr = ifelse(columnId!=-1, ", Col:"+columnId + " ,", " ,") - print("Iter: " + toString(iter) + " InnerIter: " + toString(iiter) +" --- "+ colstr + " Obj:" + obj) + colstr = ifelse(columnId!=-1, "-- MSVM class="+columnId+": ", "") + print(colstr + "Iter: " + iter + " InnerIter: " + iiter +" --- " + " Obj:" + obj) } tmp = sum(s * g_old) diff --git a/scripts/builtin/msvm.dml b/scripts/builtin/msvm.dml index 4bf904f822..076b9eb597 100644 --- a/scripts/builtin/msvm.dml +++ b/scripts/builtin/msvm.dml @@ -19,26 +19,27 @@ # #------------------------------------------------------------- -# Implements builtin multi-class SVM with squared slack variables, -# learns one-against-the-rest binary-class classifiers by making a function call to l2SVM +# This builtin function implements a multi-class Support Vector Machine (SVM) +# with squared slack variables. The trained model comprises #classes +# one-against-the-rest binary-class l2svm classification models. 
# # INPUT: -#------------------------------------------------------------------------------------------ -# X matrix X of feature vectors -# Y matrix Y of class labels -# intercept No Intercept ( If set to TRUE then a constant bias column is added to X) -# num_classes Number of classes -# epsilon Procedure terminates early if the reduction in objective function -# value is less than epsilon (tolerance) times the initial objective function value. +#------------------------------------------------------------------------------- +# X Feature matrix X (shape: m x n) +# Y Label vector y of class labels (shape: m x 1), +# where max(Y) is assumed to be the number of classes +# intercept Indicator if a bias column should be added to X and the model +# epsilon Tolerance for early termination if the reduction of objective +# function is less than epsilon times the initial objective # reg Regularization parameter (lambda) for L2 regularization -# maxIterations Maximum number of conjugate gradient iterations -# verbose Set to true to print while training. 
-# ----------------------------------------------------------------------------------------- +# maxIterations Maximum number of conjugate gradient (outer l2svm) iterations +# verbose Indicator if training details should be printed +# ------------------------------------------------------------------------------ # # OUTPUT: -#----------------------------------------------------------------------------------- -# model model matrix -#----------------------------------------------------------------------------------- +#------------------------------------------------------------------------------- +# model Trained model/weights (shape: n x max(Y), w/ intercept: n+1) +#------------------------------------------------------------------------------- m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE, Double epsilon = 0.001, Double reg = 1.0, Integer maxIterations = 100, @@ -55,25 +56,31 @@ m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE, print("msvm: matrix X contains "+numNaNs+" missing values, replacing with 0.") X = replace(target=X, pattern=NaN, replacement=0); } - num_rows_in_w = ncol(X) + # append once, and call l2svm always with intercept=FALSE if(intercept) { - # append once, and call l2svm always with intercept=FALSE ones = matrix(1, rows=nrow(X), cols=1) X = cbind(X, ones); - num_rows_in_w += 1 } if(ncol(Y) > 1) - Y = rowMaxs(Y * t(seq(1,ncol(Y)))) + Y = rowIndexMax(Y) # Assuming number of classes to be max contained in Y - w = matrix(0, rows=num_rows_in_w, cols=max(Y)) + w = matrix(0, rows=ncol(X), cols=max(Y)) parfor(class in 1:max(Y)) { + # extract the class' binary labels and convert to -1/+1 Y_local = 2 * (Y == class) - 1 - w[,class] = l2svm(X=X, Y=Y_local, intercept=FALSE, + # train l2svm model with robustness for non-existing classes + nnzY = sum(Y == class); + if( nnzY > 0 ) { + w[,class] = l2svm(X=X, Y=Y_local, intercept=FALSE, epsilon=epsilon, reg=reg, maxIterations=maxIterations, 
verbose=verbose, columnId=class) + } + else { + w[,class] = matrix(-Inf, ncol(X), 1); + } } model = w diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java index 86733fb06f..76a8924ae1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java @@ -19,7 +19,9 @@ package org.apache.sysds.runtime.instructions.cp; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.IfElse; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.TernaryOperator; @@ -71,12 +73,20 @@ public class TernaryCPInstruction extends ComputationCPInstruction { ec.setMatrixOutput(output.getName(), out); } else { //SCALARS - double value = ((TernaryOperator)_optr).fn.execute( - ec.getScalarInput(input1).getDoubleValue(), - ec.getScalarInput(input2).getDoubleValue(), - ec.getScalarInput(input3).getDoubleValue()); - ec.setScalarOutput(output.getName(), ScalarObjectFactory - .createScalarObject(output.getValueType(), value)); + if( ((TernaryOperator)_optr).fn instanceof IfElse + && output.getValueType() == ValueType.STRING) { + String value = (ec.getScalarInput(input1).getDoubleValue() != 0 ? 
+ ec.getScalarInput(input2) : ec.getScalarInput(input3)).getStringValue(); + ec.setScalarOutput(output.getName(), new StringObject(value)); + } + else { + double value = ((TernaryOperator)_optr).fn.execute( + ec.getScalarInput(input1).getDoubleValue(), + ec.getScalarInput(input2).getDoubleValue(), + ec.getScalarInput(input3).getDoubleValue()); + ec.setScalarOutput(output.getName(), ScalarObjectFactory + .createScalarObject(output.getValueType(), value)); + } } } }
