This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8edabbc730 [SYSTEMDS-3777] Fix adasyn test flakiness via fixed seeds
8edabbc730 is described below
commit 8edabbc730dca1a104c2a0ed1f11d5ca8910bf79
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Nov 17 17:17:10 2024 +0100
[SYSTEMDS-3777] Fix adasyn test flakiness via fixed seeds
---
scripts/builtin/adasyn.dml | 5 +++--
src/test/scripts/functions/builtin/adasynRealData.dml | 4 ++--
2 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/scripts/builtin/adasyn.dml b/scripts/builtin/adasyn.dml
index 458c993c7e..6424e5b193 100644
--- a/scripts/builtin/adasyn.dml
+++ b/scripts/builtin/adasyn.dml
@@ -29,6 +29,7 @@
# k Number of nearest neighbors
# beta Desired balance level after generation of synthetic data [0, 1]
# dth Distribution threshold
+# seed Seed for randomized data point selection
# --------------------------------------------------------------------------------------
#
# OUTPUT:
@@ -38,7 +39,7 @@
# -------------------------------------------------------------------------------------
m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
- Double beta = 1.0, Double dth = 0.9)
+ Double beta = 1.0, Double dth = 0.9, Integer seed = -1)
return (Matrix[Double] Xp, Matrix[Double] Yp)
{
if(k < 1) {
@@ -74,7 +75,7 @@ m_adasyn = function(Matrix[Double] X, Matrix[Double] Y, Integer k = 2,
Ynonmajor = removeEmpty(target=Y, margin="rows", select=(Y!=majorIdx))
NNR = knnbf(Xnonmajor, Xnonmajor, k+1)
NNR = matrix(NNR, rows=length(NNR), cols=1)
- I = rand(rows=nrow(NNR), cols=1) < (G/nrow(NNR))
+ I = rand(rows=nrow(NNR), cols=1, seed=seed) < (G/nrow(NNR))
NNRg = removeEmpty(target=NNR, margin="rows", select=I);
P = table(seq(1, nrow(NNRg)), NNRg, nrow(NNRg), nrow(Xnonmajor));
Xp = rbind(X, P %*% Xnonmajor);
diff --git a/src/test/scripts/functions/builtin/adasynRealData.dml b/src/test/scripts/functions/builtin/adasynRealData.dml
index cc3e7e5170..6e401ec336 100644
--- a/src/test/scripts/functions/builtin/adasynRealData.dml
+++ b/src/test/scripts/functions/builtin/adasynRealData.dml
@@ -25,11 +25,11 @@ Y = M[, ncol(M)] + 1
X = M[, 1:ncol(M)-1]
upsample = as.logical($2)
-[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7);
+[Xtrain, Xtest, Ytrain, Ytest] = split(X=X, Y=Y, f=0.7, seed=3);
if( upsample ) {
# oversampling all classes other than majority
- [Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3);
+ [Xtrain,Ytrain] = adasyn(X=Xtrain, Y=Ytrain, k=$3, seed=7);
}
B = multiLogReg(X=Xtrain, Y=Ytrain, icpt=2);