This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fa8b122ed [SYSTEMDS-3153] Fix KNN
4fa8b122ed is described below

commit 4fa8b122eddefffd4d291bcea2011fdcc8c485f1
Author: Christina Dionysio <[email protected]>
AuthorDate: Wed Oct 18 13:44:07 2023 +0200

    [SYSTEMDS-3153] Fix KNN
    
    fixes the sampling method for missing value imputation using KNN
    
    Closes #1925
---
 scripts/builtin/imputeByKNN.dml                    | 107 ++++++++-------------
 .../builtin/part1/BuiltinImputeKNNTest.java        |  12 ++-
 src/test/scripts/functions/builtin/imputeByKNN.dml |   8 +-
 3 files changed, 52 insertions(+), 75 deletions(-)

diff --git a/scripts/builtin/imputeByKNN.dml b/scripts/builtin/imputeByKNN.dml
index 240631be47..13136ff2c9 100644
--- a/scripts/builtin/imputeByKNN.dml
+++ b/scripts/builtin/imputeByKNN.dml
@@ -19,7 +19,6 @@
 #
 #-------------------------------------------------------------
 
-
 # Imputes missing values, indicated by NaNs, using KNN-based methods
 # (k-nearest neighbors by euclidean distance). In order to avoid NaNs in
 # distance computation and meaningful nearest neighbor search, we initialize
@@ -50,13 +49,12 @@
 # result     Imputed dataset
 # ------------------------------------------------------------------------------
 
-m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1, Double sample_frac = 0.1)
+m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1, Double sample_frac=0.1)
   return(Matrix[Double] result)
 {
   #KNN-Imputation Script
 
-  #Create a mask for placeholder and to check for missing values
-  masked = is.nan(X)
+  imputedValue = X
 
   #Impute NaN value with temporary mean value of the column
   filled_matrix = imputeByMean(X, matrix(0, cols = ncol(X), rows = 1))
@@ -66,103 +64,76 @@ m_imputeByKNN = function(Matrix[Double] X, String method="dist", Int seed=-1, Do
     distance_matrix = dist(filled_matrix)
 
     #Change 0 value so rowIndexMin will ignore that diagonal value
-    distance_matrix = replace(target = distance_matrix, pattern = 0, replacement = 999)
+    distance_matrix = replace(target=distance_matrix, pattern=0, replacement=999)
 
     #Get the minimum distance row-wise computation
     minimum_index = rowIndexMin(distance_matrix)
 
     #Create aligned matrix from minimum index
-    aligned = table(minimum_index, seq(1, nrow(X)), odim1 = nrow(X), odim2 = nrow(X))
+    aligned = table(minimum_index, seq(1, nrow(X)), odim1=nrow(X), odim2=nrow(X))
 
     #Get the X records that need to be imputed
     imputedValue = t(filled_matrix) %*% aligned
-
-    #Update the mask value
-    masked = t(imputedValue) * masked
+    imputedValue = t(imputedValue)
   }
   else if(method == "dist_missing") {
     #assuming small missing values
-    #Split the matrix into containing NaN values (missing records) and not containing NaN values (M2 records)
-    I = (rowSums(is.nan(X))!=0)
-    missing = removeEmpty(target=filled_matrix, margin="rows", select=I)
-
-    Y = (rowSums(is.nan(X))==0)
-    M2 = removeEmpty(target=filled_matrix, margin = "rows", select = Y)
-
-    #Calculate the euclidean distance between fully records and missing records, and then find the min value row wise
-    dotM2 = rowSums(M2 * M2) %*% matrix(1, rows = 1, cols = nrow(missing))
-    dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols = nrow(M2)))
-    D = sqrt(dotM2 + dotMissing - 2 * (M2 %*% t(missing)))
-    minD = rowIndexMin(t(D))
-
-    #Get the index location of the missing value
-    pos = rowMaxs(is.nan(X))
-    missing_indices = seq(1, nrow(pos)) * pos
-
-    #Put the replacement value in the missing indices
-    I2 = removeEmpty(target=missing_indices, margin="rows")
-    R = table(I2,1,minD,odim1 = nrow(X), odim2=1)
-
-    #Replace the 0 to avoid error in table()
-    R = replace(target = R, pattern = 0, replacement = nrow(X)+1)
-
-    #Create aligned matrix from minimum index
-    aligned = table(R, seq(1, nrow(X)), odim1 = nrow(X), odim2 = nrow(X))
-
-    #Reshape the subset
-    reshaped = rbind(M2, matrix(0, rows = nrow(X) - nrow(M2), cols = ncol(X)))
-
-    #Get the M2 records that need to be imputed
-    imputedValue = t(reshaped) %*% aligned
-
-    #Update the mask value
-    masked = t(imputedValue) * masked
+    imputedValue = compute_missing_values(X, filled_matrix, seed, 1.0)
   }
   else if(method == "dist_sample"){
     #assuming large missing values
+    imputedValue = compute_missing_values(X, filled_matrix, seed, sample_frac)
+  }
+  else {
+    stop("Method is unknown or not yet implemented")
+  }
+
+  #Impute the value
+  result = replace(target=X, pattern=NaN, replacement=0)
+  result = result + (imputedValue * is.nan(X))
+}
+
+compute_missing_values = function (Matrix[Double] X, Matrix[Double] filled_matrix, Int seed, Double sample_frac)
+    return (Matrix[Double] imputedValue)
+{
     #Split the matrix into containing NaN values (missing records) and not containing NaN values (M2 records)
-    I = rowSums(is.nan(X)) != 0
+    maskNAN = is.nan(X)
+    I = rowSums(maskNAN) != 0
     missing = removeEmpty(target=filled_matrix, margin="rows", select=I)
 
-    #Create permutation matrix for sampling sample_frac*nrow(X) rows
-    I = rand(rows=nrow(X), cols=1, seed=seed) <= sample_frac;
-    subset = removeEmpty(target=filled_matrix, margin="rows", select=I);
+    Y = (rowSums(maskNAN)==0)
+    M2 = removeEmpty(target=X, margin = "rows", select = Y)
+
+    if (sample_frac != 1.0) {
+        #Create permutation matrix for sampling sample_frac*nrow(X) rows
+        I = rand(rows=nrow(M2), cols=1, seed=seed) <= sample_frac;
+        M2 = removeEmpty(target=M2, margin="rows", select=I);
+    }
 
     #Calculate the euclidean distance between fully records and missing records, and then find the min value row wise
-    dotSubset = rowSums(subset * subset) %*% matrix(1, rows = 1, cols = nrow(missing))
-    dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols = nrow(subset)))
-    D = sqrt(dotSubset + dotMissing - 2 * (subset %*% t(missing)))
+    dotM2 = rowSums(M2 * M2) %*% matrix(1, rows = 1, cols = nrow(missing))
+    dotMissing = t(rowSums(missing * missing) %*% matrix(1, rows = 1, cols = nrow(M2)))
+    D = sqrt(dotM2 + dotMissing - 2 * (M2 %*% t(missing)))
     minD = rowIndexMin(t(D))
 
     #Get the index location of the missing value
-    pos = rowMaxs(is.nan(X))
+    pos = rowMaxs(maskNAN)
     missing_indices = seq(1, nrow(pos)) * pos
 
     #Put the replacement value in the missing indices
     I2 = removeEmpty(target=missing_indices, margin="rows")
-    R = table(I2,1,minD,odim1 = nrow(X), odim2=1)
+    R = table(I2, 1, minD, odim1=nrow(X), odim2=1)
 
     #Replace the 0 to avoid error in table()
-    R = replace(target = R, pattern = 0, replacement = nrow(X)+1)
+    R = replace(target=R, pattern=0, replacement=nrow(X)+1)
 
     #Create aligned matrix from minimum index
-    aligned = table(R, seq(1, nrow(X)), odim1 = nrow(X), odim2 = nrow(X))
+    aligned = table(R, seq(1, nrow(X)), odim1=nrow(X), odim2=nrow(X))
 
     #Reshape the subset
-    reshaped = rbind(subset, matrix(0, rows = nrow(X) - nrow(subset), cols = ncol(X)))
+    reshaped = rbind(M2, matrix(0, rows=nrow(X) - nrow(M2), cols=ncol(X)))
 
     #Get the subset records that need to be imputed
     imputedValue = t(reshaped) %*% aligned
-
-    #Update the mask value
-    masked = t(imputedValue) * masked
-  }
-  else {
-    print("Method is unknown or not yet implemented")
-  }
-
-  #Impute the value
-  result = replace(target = X, pattern = NaN, replacement = 0)
-  result = result + masked
+    imputedValue = t(imputedValue)
 }
-
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
index 627437fe76..2b7c422978 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImputeKNNTest.java
@@ -41,34 +41,36 @@ public class BuiltinImputeKNNTest extends AutomatedTestBase {
     @Override
     public void setUp() {
         TestUtils.clearAssertionInformation();
-        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B","B2"}));
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B","B2","B3"}));
     }
 
     @Test
     public void testDefaultCP()throws IOException{
-        runImputeKNN(true, Types.ExecType.CP);
+        runImputeKNN(Types.ExecType.CP);
     }
 
     @Test
     public void testDefaultSpark()throws IOException{
-        runImputeKNN(true, Types.ExecType.SPARK);
+        runImputeKNN(Types.ExecType.SPARK);
     }
 
-    private void runImputeKNN(boolean defaultProb, ExecType instType) throws IOException {
+    private void runImputeKNN(ExecType instType) throws IOException {
         ExecMode platform_old = setExecMode(instType);
         try {
             loadTestConfiguration(getTestConfiguration(TEST_NAME));
             String HOME = SCRIPT_DIR + TEST_DIR;
             fullDMLScriptName = HOME + TEST_NAME + ".dml";
             programArgs = new String[] {"-args", DATASET_DIR+"Salaries.csv", 
-               "dist", "dist_missing", output("B"), output("B2")};
+               "dist", "dist_missing", "dist_sample", "42", "0.9", output("B"), output("B2"), output("B3")};
 
             runTest(true, false, null, -1);
 
             //Compare matrices, check if the sum of the imputed value is roughly the same
             double sum1 = readDMLMatrixFromOutputDir("B").get(new CellIndex(1,1));
             double sum2 = readDMLMatrixFromOutputDir("B2").get(new CellIndex(1,1));
+            double sum3 = readDMLMatrixFromOutputDir("B3").get(new CellIndex(1,1));
             Assert.assertEquals(sum1, sum2, eps);
+            Assert.assertEquals(sum2, sum3, eps);
         }
         finally {
             rtplatform = platform_old;
diff --git a/src/test/scripts/functions/builtin/imputeByKNN.dml b/src/test/scripts/functions/builtin/imputeByKNN.dml
index 299c1dda30..0e87026e2b 100644
--- a/src/test/scripts/functions/builtin/imputeByKNN.dml
+++ b/src/test/scripts/functions/builtin/imputeByKNN.dml
@@ -28,15 +28,19 @@ mask = is.nan(X)
 #Perform the KNN imputation
 result = imputeByKNN(X = X, method = $2)
 result2 = imputeByKNN(X = X, method = $3)
+result3 = imputeByKNN(X = X, method = $4, seed = $5, sample_frac = $6)
 
 #Get the imputed value
 I = (mask[,2] == 1);
 value = removeEmpty(target = result, margin = "rows", select = I)
 value2 = removeEmpty(target = result2, margin = "rows", select = I)
+value3 = removeEmpty(target = result3, margin = "rows", select = I)
 
 #Get the sum of the imputed value
 value = colSums(value[,2])
 value2 = colSums(value2[,2])
+value3 = colSums(value3[,2])
 
-write(value, $4)
-write(value2, $5)
\ No newline at end of file
+write(value, $7)
+write(value2, $8)
+write(value3, $9)

Reply via email to