This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7115b3707a [SYSTEMDS-3613] Fix missing size propagation on 
transformapply/decode
7115b3707a is described below

commit 7115b3707a802026ee287ce82d666b1b756941b5
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Aug 11 15:08:10 2023 +0200

    [SYSTEMDS-3613] Fix missing size propagation on transformapply/decode
    
    This patch fixes the missing size propagation for transformapply and
    transformdecode. By parsing the transform spec and/or using the meta
    data frame (which has the original number of columns), we now infer
    the rows/cols unless there are encoders that change the number of
    columns. Feature hashing could be supported as well, but for the sake
    of simplicity we currently don't do it.
---
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  | 32 ++++++++++++++++++----
 .../sysds/runtime/transform/meta/TfMetaUtils.java  | 15 ++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 4e70a3bf09..01883e2f5d 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -50,7 +50,10 @@ import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
+import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
 import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONObject;
 
 
 /**
@@ -840,16 +843,35 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                        }
                        case TRANSFORMDECODE: {
                                Hop target = getTargetHop();
+                               Hop meta = getParameterHop("meta");
                                //rows remain unchanged for recoding and dummy 
coding
-                               setDim1( target.getDim1() );
-                               //cols remain unchanged only if no dummy coding
-                               //TODO parse json spec
+                               setDim1(target.getDim1());
+                               //cols remain unchanged only if no dummy 
coding, but meta aligned with input columns
+                               setDim2(meta.getDim2());
                                break;
                        }
                        case TRANSFORMAPPLY: {
                                //rows remain unchanged only if no omitting
-                               //cols remain unchanged of no dummy coding 
-                               //TODO parse json spec
+                               //cols remain unchanged if no dummy coding, 
feature hashing, word embeddings
+                               Hop target = getTargetHop();
+                               Hop spec = getParameterHop("spec");
+                               if( dimsKnown() ) {
+                                       //safe to update according to new input 
as previously parsed 
+                                       setDim1(target.getDim1());
+                                       setDim2(target.getDim2());
+                               }
+                               else if( spec instanceof LiteralOp ) {
+                                       try {
+                                               JSONObject jspec = new 
JSONObject(((LiteralOp)spec).getStringValue());
+                                               if( 
TfMetaUtils.checkValidEncoders(jspec, TfMethod.RECODE, TfMethod.BIN, 
TfMethod.UDF) ) {
+                                                       
setDim1(target.getDim1());
+                                                       
setDim2(target.getDim2());
+                                               }
+                                       }
+                                       catch(Exception ex) {
+                                               throw new HopsException(ex);
+                                       }
+                               }
                                break;
                        }
                        case TRANSFORMCOLMAP: {
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java 
b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 5ae26b1c3a..99fbe92bf2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -490,4 +490,19 @@ public class TfMetaUtils
                                throw new DMLRuntimeException("Transform 
specification includes an invalid encoder: "+key);
                }
        }
+       
+       @SuppressWarnings("unchecked") //presumably because JSONObject.keys() returns a raw Iterator -- TODO confirm
+       public static boolean checkValidEncoders(JSONObject jSpec, TfMethod... 
encoders) { //true iff every key from jSpec.keys() is "ids", "K", or the toString() of a given encoder
+               Set<String> validEncoders = new HashSet<>();
+               validEncoders.addAll(Arrays.asList("ids","K")); //"ids" and "K" are always accepted, independent of the given encoders
+               for( TfMethod tf : encoders )
+                       validEncoders.add(tf.toString()); //spec keys are compared against TfMethod.toString()
+               Iterator<String> keys = jSpec.keys();
+               while( keys.hasNext() ) {
+                       String key = keys.next();
+                       if( !validEncoders.contains(key) )
+                               return false; //first key outside the allowed set rejects the spec
+               }
+               return true; //no disallowed key found (also the result if jSpec has no keys)
+       }
 }

Reply via email to