This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 159cc8e2e4 [SYSTEMDS-3777] Improved adasyn builtin (tests, vectorized
impl)
159cc8e2e4 is described below
commit 159cc8e2e40a30715de6bbbf957fb7be8a49498d
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Nov 17 16:37:48 2024 +0100
[SYSTEMDS-3777] Improved adasyn builtin (tests, vectorized impl)
This patch adds real-data tests for the new adasyn builtin function,
and changes the implementation to a vectorized implementation that
extracts over-sampled rows via a randomized permutation matrix multiply.
On the Diabetes dataset (with moderate class imbalance of 500 vs 268),
ADASYN slightly improves the test accuracy from 78.3% to 78.7%. It is
also noteworthy that the original ADASYN paper from 2008 only achieved
accuracies of 0.6831 and 0.6833 (with ADASYN) on this dataset.
---
scripts/builtin/adasyn.dml | 93 +++++++---------------
.../builtin/part1/BuiltinAdasynRealDataTest.java | 79 ++++++++++++++++++
.../functions/builtin/part1/BuiltinAdasynTest.java | 23 ------
src/test/resources/datasets/diabetes/diabetes.json | 17 ----
.../builtin/{adasyn.dml => adasynRealData.dml} | 23 +++++-
5 files changed, 130 insertions(+), 105 deletions(-)
diff --git a/scripts/builtin/adasyn.dml b/scripts/builtin/adasyn.dml
index ca73fef9a4..458c993c7e 100644
--- a/scripts/builtin/adasyn.dml
+++ b/scripts/builtin/adasyn.dml
@@ -24,19 +24,22 @@
#
# INPUT:
#
--------------------------------------------------------------------------------------
-# minority Matrix of minority class samples
-# majority Matrix of majority class samples
-# k Number of nearest neighbors
-# beta Desired balance level after generation of synthetic data [0,
1]
+# X Feature matrix [shape: n-by-m]
+# Y Class labels [shape: n-by-1]
+# k Number of nearest neighbors
+# beta Desired balance level after generation of synthetic data [0, 1]
+# dth Distribution threshold
#
--------------------------------------------------------------------------------------
#
# OUTPUT:
#
-------------------------------------------------------------------------------------
-# Z Matrix of G synthetic minority class samples, with G = (ml-ms)*beta
+# Xp Feature matrix of n original rows followed by G = (ml-ms)*beta
synthetic rows
+# Yp Class labels aligned with output Xp
#
-------------------------------------------------------------------------------------
-m_adasyn = function(Matrix[Double] minority, Matrix[Double] majority, Integer
k = 1, Double beta = 0.8)
- return (Matrix[Double] Z)
+m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
+ Double beta = 1.0, Double dth = 0.9)
+ return (Matrix[Double] Xp, Matrix[Double] Yp)
{
if(k < 1) {
print("ADASYN: k should not be less than 1. Setting k value to default k =
1.")
@@ -44,75 +47,37 @@ m_adasyn = function(Matrix[Double] minority, Matrix[Double]
majority, Integer k
}
# Preprocessing
- dth = 0.9
- ms = nrow(minority)
- ml = nrow(majority)
- combined = rbind(minority, majority)
+ freq = t(table(Y, 1));
+ minorIdx = as.scalar(rowIndexMin(freq))
+ majorIdx = as.scalar(rowIndexMax(freq))
# (Step 1)
# Calculate the degree of class imbalance, where d in (0, 1]
- d = ms/ml
+ d = as.scalar(freq[1,minorIdx])/sum(freq)
# (Step 2)
# Check if imbalance is lower than predefined threshold
- if(d >= dth){
+ print("ADASYN: class imbalance: " + d)
+
+ if(d >= dth) {
stop("ADASYN: Class imbalance not large enough.")
}
# (Step 2a)
# Calculate number of synthetic data examples
- G = (ml-ms)*beta
+ G = as.scalar(freq[1,majorIdx]-freq[1,minorIdx])*beta
# (Step 2b)
- # For each x_i in minority class, find k nearest neighbors.
- # Then, compute ratio r of neighbors belonging to majority class to total
number of neighbors k
- NNR = knnbf(combined, minority, k+1)
- NNR = NNR[,2:ncol(NNR)]
- delta = rowSums(NNR>ms)
- r = delta/k
- r = r + 0 #only to force materialization, caught by compiler rewrites
-
- # (Step 2c)
- # Normalize ratio vector r
- rSum = sum(r)
- r = r/rSum
-
- # (Step 2d)
- # Calculate the number of synthetic data examples that need to be
- # generated for each minority example x_i
- # Then, pre-allocate the result matrix Z
- g = round(r * G)
- gSum = sum(g)
- Z = matrix(0, rows=gSum, cols=ncol(minority)) # output matrix, slightly
overallocated
-
- # (Step 2e)
- # For each minority class data example x_i, generate g_i synthetic data
examples by
- # looping from 1 to g_i and randomly choosing one minority data example x_j
from
- # the k-nearest neighbors. Then, compute the synthetic sample s_i as
- # s_i = x_i + (x_j - x_i) * lambda, with lambda being a random number in [0,
1].
- minNNR = NNR * (NNR <= ms) # set every index from majority class to zero
- zeroCount = 0
- for(i in 1:nrow(minority)){
- row = minNNR[i, ] # slice a row
- minRow = removeEmpty(target=row, margin="cols") # remove all zero
values from that row
- hasSynthetic = as.scalar(g[i])>0
- hasMinorityNN = (as.scalar(minRow[1, 1]) > 0) & (hasSynthetic)
- if(hasMinorityNN){
- for(j in 1:as.scalar(g[i])){
- randomIndex = as.scalar(sample(ncol(minRow), 1))
- lambda = as.scalar(rand(rows=1, cols=1, min=0, max=1))
- randomMinIndex = as.scalar(minRow[ , randomIndex])
- randomMinNN = minority[randomMinIndex, ]
- insIdx = i+j-1-zeroCount
- Z[insIdx, ] = minority[i, ] + (randomMinNN - minority[i, ]) *
lambda
- }
- } else {
- zeroCount = zeroCount + 1
- }
- }
-
- diff = nrow(minority) - gSum
- numTrailZeros = zeroCount - diff
- Z = Z[1:gSum-numTrailZeros, ]
+ # For each x_i in non-majority class, find k nearest neighbors.
+ # Get G random points from the KNN set via a permutation matrix multiply
+ Xnonmajor = removeEmpty(target=X, margin="rows", select=(Y!=majorIdx))
+ Ynonmajor = removeEmpty(target=Y, margin="rows", select=(Y!=majorIdx))
+ NNR = knnbf(Xnonmajor, Xnonmajor, k+1)
+ NNR = matrix(NNR, rows=length(NNR), cols=1)
+ I = rand(rows=nrow(NNR), cols=1) < (G/nrow(NNR))
+ NNRg = removeEmpty(target=NNR, margin="rows", select=I);
+ P = table(seq(1, nrow(NNRg)), NNRg, nrow(NNRg), nrow(Xnonmajor));
+ Xp = rbind(X, P %*% Xnonmajor);
+ Yp = rbind(Y, P %*% Ynonmajor); # multi-class
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
new file mode 100644
index 0000000000..e310fd877a
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class BuiltinAdasynRealDataTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "adasynRealData";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinAdasynRealDataTest.class.getSimpleName() + "/";
+
+ private final static String DIABETES_DATA = DATASET_DIR +
"diabetes/diabetes.csv";
+ private final static String DIABETES_TFSPEC = DATASET_DIR +
"diabetes/tfspec.json";
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+ }
+
+ @Test
+ public void testDiabetesNoAdasyn() {
+ runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, false, 0.783, -1,
ExecType.CP);
+ }
+
+ @Test
+ public void testDiabetesAdasynK4() {
+ runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 4,
ExecType.CP);
+ }
+
+ @Test
+ public void testDiabetesAdasynK6() {
+ runAdasynTest(DIABETES_DATA, DIABETES_TFSPEC, true, 0.787, 6,
ExecType.CP);
+ }
+
+ private void runAdasynTest(String data, String tfspec, boolean adasyn,
double minAcc, int k, ExecType instType) {
+ Types.ExecMode platformOld = setExecMode(instType);
+ try {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats",
+ "-args", data, String.valueOf(adasyn),
String.valueOf(k), output("R")};
+
+ runTest(true, false, null, -1);
+
+ double acc = readDMLMatrixFromOutputDir("R").get(new
CellIndex(1,1));
+ Assert.assertTrue(acc >= minAcc);
+ Assert.assertEquals(0,
Statistics.getNoOfExecutedSPInst());
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java
deleted file mode 100644
index 43ce08eb56..0000000000
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynTest.java
+++ /dev/null
@@ -1,23 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.functions.builtin.part1;
-
-public class BuiltinAdasynTest {
-}
diff --git a/src/test/resources/datasets/diabetes/diabetes.json
b/src/test/resources/datasets/diabetes/diabetes.json
deleted file mode 100644
index e988e42651..0000000000
--- a/src/test/resources/datasets/diabetes/diabetes.json
+++ /dev/null
@@ -1,17 +0,0 @@
-{
- "ids": true,
- "recode": [9],
- "bin": [
- {"id": 1, "method": "equi-width", "numbins": 10},
- {"id": 2, "method": "equi-width", "numbins": 10},
- {"id": 3, "method": "equi-width", "numbins": 10},
- {"id": 4, "method": "equi-width", "numbins": 10},
- {"id": 5, "method": "equi-width", "numbins": 10},
- {"id": 6, "method": "equi-width", "numbins": 10},
- {"id": 7, "method": "equi-width", "numbins": 10},
- {"id": 8, "method": "equi-width", "numbins": 10}
- ]
-}
-
-
-
diff --git a/src/test/scripts/functions/builtin/adasyn.dml
b/src/test/scripts/functions/builtin/adasynRealData.dml
similarity index 61%
rename from src/test/scripts/functions/builtin/adasyn.dml
rename to src/test/scripts/functions/builtin/adasynRealData.dml
index 1fd7838a7e..cc3e7e5170 100644
--- a/src/test/scripts/functions/builtin/adasyn.dml
+++ b/src/test/scripts/functions/builtin/adasynRealData.dml
@@ -17,4 +17,25 @@
# specific language governing permissions and limitations
# under the License.
#
-#--------------------------------------
+#-------------------------------------------------------------
+
+
+M = read($1, data_type="matrix", format="csv", header=TRUE);
+Y = M[, ncol(M)] + 1
+X = M[, 1:ncol(M)-1]
+upsample = as.logical($2)
+
+[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7);
+
+if( upsample ) {
+ # oversampling all classes other than majority
+ [Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3);
+}
+
+B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);
+[P,yhat,acc] = multiLogRegPredict(X=Xtest, Y=Ytest, B=B);
+print("accuracy: "+acc)
+
+R = as.matrix(acc/100);
+write(R, $4);
+