This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 159cc8e2e4 [SYSTEMDS-3777] Improved adasyn builtin (tests, vectorized impl)
159cc8e2e4 is described below

commit 159cc8e2e40a30715de6bbbf957fb7be8a49498d
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Nov 17 16:37:48 2024 +0100

    [SYSTEMDS-3777] Improved adasyn builtin (tests, vectorized impl)
    
    This patch adds real-data tests for the new adasyn builtin function,
    and changes the implementation to a vectorized implementation that
    extracts over-sampled rows via a randomized permutation matrix multiply.
    On the Diabetes dataset (with moderate class imbalance of 500 vs 268)
    ADASYN slightly improves the test accuracy from 78.3% to 78.7%. It is
    also noteworthy that the original ADASYN paper from 2008 only achieved
    accuracies of 0.6831 and 0.6833 (without and with ADASYN, respectively)
    on this dataset.
---
 scripts/builtin/adasyn.dml                         | 93 +++++++---------------
 .../builtin/part1/BuiltinAdasynRealDataTest.java   | 79 ++++++++++++++++++
 .../functions/builtin/part1/BuiltinAdasynTest.java | 23 ------
 src/test/resources/datasets/diabetes/diabetes.json | 17 ----
 .../builtin/{adasyn.dml => adasynRealData.dml}     | 23 +++++-
 5 files changed, 130 insertions(+), 105 deletions(-)

diff --git a/scripts/builtin/adasyn.dml b/scripts/builtin/adasyn.dml
index ca73fef9a4..458c993c7e 100644
--- a/scripts/builtin/adasyn.dml
+++ b/scripts/builtin/adasyn.dml
@@ -24,19 +24,22 @@
 #
 # INPUT:
 # --------------------------------------------------------------------------------------
-# minority        Matrix of minority class samples
-# majority        Matrix of majority class samples
-# k               Number of nearest neighbors
-# beta            Desired balance level after generation of synthetic data [0, 1]
+# X        Feature matrix [shape: n-by-m]
+# Y        Class labels [shape: n-by-1]
+# k        Number of nearest neighbors
+# beta     Desired balance level after generation of synthetic data [0, 1]
+# dth      Distribution threshold
 # --------------------------------------------------------------------------------------
 #
 # OUTPUT:
 # -------------------------------------------------------------------------------------
-# Z     Matrix of G synthetic minority class samples, with G = (ml-ms)*beta
+# Xp       Feature matrix of n original rows followed by G = (ml-ms)*beta synthetic rows
+# Yp       Class labels aligned with output X
 # -------------------------------------------------------------------------------------
 
-m_adasyn = function(Matrix[Double] minority, Matrix[Double] majority, Integer k = 1, Double beta = 0.8)
-  return (Matrix[Double] Z)
+m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
+  Double beta = 1.0, Double dth = 0.9)
+  return (Matrix[Double] Xp, Matrix[Double] Yp)
 {
   if(k < 1) {
     print("ADASYN: k should not be less than 1. Setting k value to default k = 1.")
@@ -44,75 +47,37 @@ m_adasyn = function(Matrix[Double] minority, Matrix[Double] majority, Integer k
   }
 
   # Preprocessing
-  dth = 0.9
-  ms = nrow(minority)
-  ml = nrow(majority)
-  combined = rbind(minority, majority)
+  freq = t(table(Y, 1));
+  minorIdx = as.scalar(rowIndexMin(freq))
+  majorIdx = as.scalar(rowIndexMax(freq))
 
   # (Step 1)
   # Calculate the degree of class imbalance, where d in (0, 1]
-  d = ms/ml
+  d = as.scalar(freq[1,minorIdx])/sum(freq)
 
   # (Step 2)
   # Check if imbalance is lower than predefined threshold
-  if(d >= dth){
+  print("ADASYN: class imbalance: " + d)
+
+  if(d >= dth) {
       stop("ADASYN: Class imbalance not large enough.")
   }
 
   # (Step 2a)
   # Calculate number of synthetic data examples
-  G = (ml-ms)*beta
+  G = as.scalar(freq[1,majorIdx]-freq[1,minorIdx])*beta
 
   # (Step 2b)
-  # For each x_i in minority class, find k nearest neighbors.
-  # Then, compute ratio r of neighbors belonging to majority class to total number of neighbors k
-  NNR = knnbf(combined, minority, k+1)
-  NNR = NNR[,2:ncol(NNR)]
-  delta = rowSums(NNR>ms)
-  r = delta/k
-  r = r + 0   #only to force materialization, caught by compiler rewrites
-
-  # (Step 2c)
-  # Normalize ratio vector r
-  rSum = sum(r)
-  r = r/rSum
-
-  # (Step 2d)
-  # Calculate the number of synthetic data examples that need to be
-  # generated for each minority example x_i
-  # Then, pre-allocate the result matrix Z
-  g = round(r * G)
-  gSum = sum(g)
-  Z = matrix(0, rows=gSum, cols=ncol(minority)) # output matrix, slightly overallocated
-
-  # (Step 2e)
-  # For each minority class data example x_i, generate g_i synthetic data examples by
-  # looping from 1 to g_i and randomly choosing one minority data example x_j from
-  # the k-nearest neighbors. Then, compute the synthetic sample s_i as
-  # s_i = x_i + (x_j - x_i) * lambda, with lambda being a random number in [0, 1].
-  minNNR = NNR * (NNR <= ms)  # set every index from majority class to zero
-  zeroCount = 0
-  for(i in 1:nrow(minority)){
-      row = minNNR[i, ]       # slice a row
-      minRow = removeEmpty(target=row, margin="cols")     # remove all zero values from that row
-      hasSynthetic = as.scalar(g[i])>0
-      hasMinorityNN = (as.scalar(minRow[1, 1]) > 0) & (hasSynthetic)
-      if(hasMinorityNN){
-          for(j in 1:as.scalar(g[i])){
-              randomIndex = as.scalar(sample(ncol(minRow), 1))
-              lambda = as.scalar(rand(rows=1, cols=1, min=0, max=1))
-              randomMinIndex = as.scalar(minRow[ , randomIndex])
-              randomMinNN = minority[randomMinIndex, ]
-              insIdx = i+j-1-zeroCount
-              Z[insIdx, ] = minority[i, ] + (randomMinNN - minority[i, ]) * lambda
-          }
-      } else {
-          zeroCount = zeroCount + 1
-      }
-  }
-
-  diff = nrow(minority) - gSum
-  numTrailZeros = zeroCount - diff
-  Z = Z[1:gSum-numTrailZeros, ]
+  # For each x_i in non-majority class, find k nearest neighbors.
+  # Get G random points from the KNN set via a permutation matrix multiply
+  Xnonmajor = removeEmpty(target=X, margin="rows", select=(Y!=majorIdx))
+  Ynonmajor = removeEmpty(target=Y, margin="rows", select=(Y!=majorIdx))
+  NNR = knnbf(Xnonmajor, Xnonmajor, k+1)
+  NNR = matrix(NNR, rows=length(NNR), cols=1)
+  I = rand(rows=nrow(NNR), cols=1) < (G/nrow(NNR))
+  NNRg = removeEmpty(target=NNR, margin="rows", select=I);
+  P = table(seq(1, nrow(NNRg)), NNRg, nrow(NNRg), nrow(Xnonmajor));
+  Xp = rbind(X, P %*% Xnonmajor);
+  Yp = rbind(Y, P %*% Ynonmajor); # multi-class
 }
 
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
new file mode 100644
index 0000000000..e310fd877a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class BuiltinAdasynRealDataTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "adasynRealData";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinAdasynRealDataTest.class.getSimpleName() + "/";
+
+       private final static String DIABETES_DATA = DATASET_DIR + "diabetes/diabetes.csv";
+       private final static String DIABETES_TFSPEC = DATASET_DIR + "diabetes/tfspec.json";
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+       }
+
+       @Test
+       public void testDiabetesNoAdasyn() {
+               runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, false, 0.783, -1, ExecType.CP);
+       }
+       
+       @Test
+       public void testDiabetesAdasynK4() {
+               runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 4, ExecType.CP);
+       }
+       
+       @Test
+       public void testDiabetesAdasynK6() {
+               runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 6, ExecType.CP);
+       }
+       
+       private void runAdasynTest(String data, String tfspec, boolean adasyn, double minAcc, int k, ExecType instType) {
+               Types.ExecMode platformOld = setExecMode(instType);
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats",
+                               "-args", data, String.valueOf(adasyn), String.valueOf(k), output("R")};
+
+                       runTest(true, false, null, -1);
+
+                       double acc = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
+                       Assert.assertTrue(acc >= minAcc);
+                       Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java
deleted file mode 100644
index 43ce08eb56..0000000000
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.functions.builtin.part1;
-
-public class BuiltinAdasynTest {
-}
diff --git a/src/test/resources/datasets/diabetes/diabetes.json b/src/test/resources/datasets/diabetes/diabetes.json
deleted file mode 100644
index e988e42651..0000000000
--- a/src/test/resources/datasets/diabetes/diabetes.json
+++ /dev/null
@@ -1,17 +0,0 @@
-{
-  "ids": true,
-  "recode": [9],
-  "bin": [
-    {"id": 1, "method": "equi-width", "numbins": 10},
-    {"id": 2, "method": "equi-width", "numbins": 10},
-    {"id": 3, "method": "equi-width", "numbins": 10},
-    {"id": 4, "method": "equi-width", "numbins": 10},
-    {"id": 5, "method": "equi-width", "numbins": 10},
-    {"id": 6, "method": "equi-width", "numbins": 10},
-    {"id": 7, "method": "equi-width", "numbins": 10},
-    {"id": 8, "method": "equi-width", "numbins": 10}
-  ]
-}
-
-
-  
diff --git a/src/test/scripts/functions/builtin/adasyn.dml b/src/test/scripts/functions/builtin/adasynRealData.dml
similarity index 61%
rename from src/test/scripts/functions/builtin/adasyn.dml
rename to src/test/scripts/functions/builtin/adasynRealData.dml
index 1fd7838a7e..cc3e7e5170 100644
--- a/src/test/scripts/functions/builtin/adasyn.dml
+++ b/src/test/scripts/functions/builtin/adasynRealData.dml
@@ -17,4 +17,25 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-#--------------------------------------
+#-------------------------------------------------------------
+
+
+M = read($1, data_type="matrix", format="csv", header=TRUE);
+Y = M[, ncol(M)] + 1
+X = M[, 1:ncol(M)-1]
+upsample = as.logical($2)
+
+[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7);
+
+if( upsample ) {
+  # oversampling all classes other than majority
+  [Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3);
+}
+
+B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);
+[P,yhat,acc] = multiLogRegPredict(X=Xtest, Y=Ytest, B=B);
+print("accuracy: "+acc)
+
+R = as.matrix(acc/100);
+write(R, $4);
+

Reply via email to