This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c4bffd9cc [SYSTEMDS-3819] New sliceLineExtract builtin function
6c4bffd9cc is described below

commit 6c4bffd9cc51aaef8a6254cb21becc6e60aeeb89
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Jan 25 11:38:10 2025 +0100

    [SYSTEMDS-3819] New sliceLineExtract builtin function
    
    This new sliceLineExtract builtin function takes the output of
    sliceLine and extracts the rows from X and e which belong to the top
    k2 <= k slices.
---
 scripts/builtin/sliceLineDebug.dml                 |  2 +-
 .../{sliceLineDebug.dml => sliceLineExtract.dml}   | 54 ++++++++++------------
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../part2/BuiltinSliceLineRealDataTest.java        |  6 ++-
 .../functions/builtin/sliceLineRealData.dml        |  7 ++-
 5 files changed, 36 insertions(+), 34 deletions(-)

diff --git a/scripts/builtin/sliceLineDebug.dml 
b/scripts/builtin/sliceLineDebug.dml
index c3eec20d91..687800dfab 100644
--- a/scripts/builtin/sliceLineDebug.dml
+++ b/scripts/builtin/sliceLineDebug.dml
@@ -26,7 +26,7 @@
 # INPUT:
 # 
------------------------------------------------------------------------------
 # TK      top-k slices (k x ncol(X) if successful)
-# TKC     score, size, error of slices (k x 3)
+# TKC     score, total/max error, size of slices (k x 4)
 # tfmeta  transformencode meta data
 # tfspec  transform specification
 # 
------------------------------------------------------------------------------
diff --git a/scripts/builtin/sliceLineDebug.dml 
b/scripts/builtin/sliceLineExtract.dml
similarity index 50%
copy from scripts/builtin/sliceLineDebug.dml
copy to scripts/builtin/sliceLineExtract.dml
index c3eec20d91..a4ab980bb8 100644
--- a/scripts/builtin/sliceLineDebug.dml
+++ b/scripts/builtin/sliceLineExtract.dml
@@ -19,48 +19,42 @@
 #
 #-------------------------------------------------------------
 
-# This builtin function takes the outputs of SliceLine and the
-# original transformencode meta data in order to print a human-
-# readable debug output of the resulting top-k slices.
+# This builtin function takes the outputs of SliceLine and allows
+# extracting the rows of X and e that belong to the top k2 <= k slices.
 #
 # INPUT:
 # 
------------------------------------------------------------------------------
+# X       Feature matrix in recoded/binned representation
+# e       Error vector of trained model
 # TK      top-k slices (k x ncol(X) if successful)
-# TKC     score, size, error of slices (k x 3)
-# tfmeta  transformencode meta data
-# tfspec  transform specification
+# TKC     score, total/max error, size of slices (k x 4)
+# k2      first k2 slices to extract with k2 <= k
 # 
------------------------------------------------------------------------------
 #
 # OUTPUT:
 # 
------------------------------------------------------------------------------
-# S     debug output collected as a string
+# Xtk     Selected rows from X which belong to k2 top slices
+# etk     Selected rows from e which belong to k2 top slices
 # 
------------------------------------------------------------------------------
 
-m_sliceLineDebug = function(Matrix[Double] TK,
-  Matrix[Double] TKC, Frame[Unknown] tfmeta, String tfspec)
-  return(Matrix[Double] S)
+m_sliceLineExtract = function(Matrix[Double] X, Matrix[Double] e,
+  Matrix[Double] TK, Matrix[Double] TKC, Integer k2 = -1)
+  return(Matrix[Double] Xtk, Matrix[Double] etk)
 {
-  print("\nsliceLineDebug: 
input\n"+toString(TK)+"\n"+toString(TKC)+"\n"+toString(tfmeta));
-
-  # prepare essential decoding info
-  N = colnames(tfmeta);
-  TKsafe = TK + (TK==0); # for vectorized decoding
-  FTK = transformdecode(target=TKsafe, meta=tfmeta, spec=tfspec);
+  # check valid parameters
+  if( k2 > nrow(TK) )
+    stop("sliceLineExtract: invalid number of slices to extract: "+k2+" > 
"+nrow(TK)).
+  if( k2 <= 0 )
+    k2 = nrow(TK);
 
-  # actual debug output
-  for(i in 1:nrow(TK)) {
-    TKi = TK[i,]; FTKi = FTK[i,];
-    print("-- Slice #"+i+": score="+as.scalar(TKC[i,1])+", 
size="+as.scalar(TKC[i,4]));
-    print("---- avg error="+as.scalar(TKC[i,2]/TKC[i,4])+", max 
error="+as.scalar(TKC[i,3]));
-    pred = "";
-    for(j in 1:ncol(TKi)) {
-        if( as.scalar(TKi[1,j]) != 0 ) {
-           tmp = as.scalar(N[1,j]) + " = " + as.scalar(FTK[i,j]);
-           pred = ifelse(pred=="", tmp, pred+" AND "+tmp);
-        }
-    }
-    print("---- predicate: "+pred);
+  # extract first k2 slices from X and e
+  I = matrix(0, k2, nrow(X));
+  parfor(i in 1:k2) {
+    I[i,] = t(rowSums(X == TK[i,]) == sum(TK[i,]))
   }
-  S = TK;
+  I = t(colSums(I)); #union
+
+  Xtk = removeEmpty(target=X, margin="rows", select=I);
+  etk = removeEmpty(target=e, margin="rows", select=I);
 }
 
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 5429cb287c..ab7400df44 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -312,6 +312,7 @@ public enum Builtins {
        SLICEFINDER("slicefinder", true), //TODO remove
        SLICELINE("sliceLine", true),
        SLICELINE_DEBUG("sliceLineDebug", true),
+       SLICELINE_EXTRACT("sliceLineExtract", true),
        SKEWNESS("skewness", true),
        SMAPE("smape", true),
        SMOTE("smote", true),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSliceLineRealDataTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSliceLineRealDataTest.java
index 4de711b37c..218051a2f6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSliceLineRealDataTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSliceLineRealDataTest.java
@@ -39,7 +39,7 @@ public class BuiltinSliceLineRealDataTest extends 
AutomatedTestBase {
        @Override
        public void setUp() {
                for(int i=1; i<=1; i++)
-                       addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+                       addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R","V"}));
        }
 
        @Test
@@ -55,12 +55,14 @@ public class BuiltinSliceLineRealDataTest extends 
AutomatedTestBase {
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
                        programArgs = new String[] {"-stats",
-                               "-args", data, tfspec, output("R")};
+                               "-args", data, tfspec, output("R"), 
output("V")};
 
                        runTest(true, false, null, -1);
 
                        double acc = readDMLMatrixFromOutputDir("R").get(new 
CellIndex(1,1));
+                       double val = readDMLMatrixFromOutputDir("V").get(new 
CellIndex(1,1));
                        Assert.assertTrue(acc >= minAcc);
+                       Assert.assertTrue(val >= 0.99);
                        Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
                }
                finally {
diff --git a/src/test/scripts/functions/builtin/sliceLineRealData.dml 
b/src/test/scripts/functions/builtin/sliceLineRealData.dml
index 3145f2fb9d..91477c4f3c 100644
--- a/src/test/scripts/functions/builtin/sliceLineRealData.dml
+++ b/src/test/scripts/functions/builtin/sliceLineRealData.dml
@@ -45,9 +45,14 @@ acc = lmPredictStats(yhat, y, TRUE);
 e = (y-yhat)^2;
 
 # model debugging via sliceline
-[TK,TKC,D] = slicefinder(X=X, e=e, k=4, alpha=0.95, minSup=32, tpBlksz=16, 
verbose=TRUE)
+[TK,TKC,D] = sliceLine(X=X, e=e, k=4, alpha=0.95, minSup=32, tpBlksz=16, 
verbose=TRUE)
 tfspec2 = "{ ids:true, recode:[1,2,5], bin:[{id:3, method:equi-width, 
numbins:10},{id:4, method:equi-width, numbins:10}]}"
 XYZ = sliceLineDebug(TK=TK, TKC=TKC, tfmeta=meta, tfspec=tfspec2)
+[Xtk,etk] = sliceLineExtract(X=X, e=e, TK=TK, TKC=TKC, k2=3);
 
 acc = acc[3,1];
+val = as.matrix((sum(TKC[1,4]) >= nrow(Xtk)) & (nrow(Xtk) == nrow(etk)))
+
 write(acc, $3);
+write(val, $4);
+

Reply via email to