This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new adb8af1  [SYSTEMDS-3134] Fix robustness transformapply for unknown categories
adb8af1 is described below

commit adb8af1d5f490d58635c6e27b55cc0dd00b80a43
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Sep 15 14:39:23 2021 +0200

    [SYSTEMDS-3134] Fix robustness transformapply for unknown categories
    
    This patch fixes an issue in the cleaning pipeline enumeration where
    transformapply corrupted the output sparse matrix with negative column
    indexes, which then caused index-out-of-bounds exceptions during sparse
    operations. We now handle these unknown categories gracefully, but
    additional work is needed to set the outputs by position.
---
 .../java/org/apache/sysds/runtime/data/SparseBlockMCSR.java   | 10 ++++++----
 .../java/org/apache/sysds/runtime/data/SparseRowVector.java   |  5 +++++
 .../runtime/transform/encode/ColumnEncoderDummycode.java      | 11 ++++++++---
 .../pipelines/BuiltinTopkCleaningClassificationTest.java      |  2 --
 4 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
index 159e581..a733ea9 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockMCSR.java
@@ -195,7 +195,7 @@ public class SparseBlockMCSR extends SparseBlock
                        int[] aix = indexes(i);
                        double[] avals = values(i);
                        for (int k = apos + 1; k < apos + alen; k++) {
-                               if (aix[k-1] >= aix[k])
+                               if (aix[k-1] >= aix[k] | aix[k-1] < 0 )
                                        throw new RuntimeException("Wrong 
sparse row ordering, at row="+i+", pos="+k
                                                + " with column indexes " + 
aix[k-1] + ">=" + aix[k]);
                                if (avals[k] == 0)
@@ -205,10 +205,12 @@ public class SparseBlockMCSR extends SparseBlock
                }
 
                //3. A capacity that is no larger than nnz times resize factor
-               for( int i=0; i<rlen; i++ )
-                       if( !isEmpty(i) && values(i).length > 
nnz*RESIZE_FACTOR1 )
+               for( int i=0; i<rlen; i++ ) {
+                       long max_size = (long)Math.max(nnz*RESIZE_FACTOR1, 
INIT_CAPACITY);
+                       if( !isEmpty(i) && values(i).length > max_size )
                                throw new RuntimeException("The capacity is 
larger than nnz times a resize factor(=2). "
-                                       + "Actual length = " + 
values(i).length+", should not exceed "+nnz*RESIZE_FACTOR1);
+                                       + "Actual length = " + 
values(i).length+", should not exceed "+max_size);
+               }
 
                return true;
        }
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
index 38a9aba..6d67707 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseRowVector.java
@@ -195,6 +195,11 @@ public final class SparseRowVector extends SparseRow{
                return true; // nnz++
        }
        
+       public void setAtPos(int pos, int col, double v) {
+               indexes[pos] = col;
+               values[pos] = v;
+       }
+       
        @Override
        public boolean add(int col, double v) {
                //early abort on zero (if no overwrite)
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 1047f54..3643d00 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -75,12 +75,17 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
                for(int i = rowStart; i < getEndIndex(in.getNumRows(), 
rowStart, blk); i++) {
                        // Using outputCol here as index since we have a 
MatrixBlock as input where dummycoding could have been
                        // applied in a previous encoder
+                       // FIXME: we need a clear way of separating 
input/output (org input, pre-allocated output)
+                       // need input index to avoid inconsistencies; also need 
to set by position not binarysearch
                        double val = in.quickGetValueThreadSafe(i, outputCol);
                        int nCol = outputCol + (int) val - 1;
-                       // Setting value to 0 first in case of sparse so the 
row vector does not need to be resized
-                       if(nCol != outputCol)
+                       // Set value, w/ robustness for val=NaN (unknown 
categories)
+                       if( nCol >= 0 && !Double.isNaN(val) ) { // filter 
unknowns
+                               out.quickSetValue(i, outputCol, 0); //FIXME 
remove this workaround (see above)
+                               out.quickSetValue(i, nCol, 1);
+                       }
+                       else
                                out.quickSetValue(i, outputCol, 0);
-                       out.quickSetValue(i, nCol, 1);
                }
                if (DMLScript.STATISTICS)
                        
Statistics.incTransformDummyCodeApplyTime(System.nanoTime()-t0);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningClassificationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningClassificationTest.java
index 0c91513..47e1347 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningClassificationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkCleaningClassificationTest.java
@@ -45,8 +45,6 @@ public class BuiltinTopkCleaningClassificationTest extends 
AutomatedTestBase {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
        }
 
-       // TODO fixing ArrayIndexOutOfBounds exception
-       @Ignore
        public void testFindBestPipelineCompany() {
                runtopkCleaning(DATA_DIR+ "company.csv", RESOURCE+ 
"meta/meta_company.csv", 1.0, 3,5,
                        "FALSE", 0,0.8, Types.ExecMode.SINGLE_NODE);

Reply via email to