This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a0987e536a [SYSTEMDS-3343,3366] Fix missing handling of positional 
defaults in eval
a0987e536a is described below

commit a0987e536a2be71d16d64ac64e9873206083e49b
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue May 10 20:46:29 2022 +0200

    [SYSTEMDS-3343,3366] Fix missing handling of positional defaults in eval
    
    This patch extends the recently added support for appending named defaults
    in eval function calls (as used by generic functions like gridSearch). We
    now extend this functionality to positional defaults as well, which
    broadens the set of functions that can be used as UDF encoders in
    transformencode.
---
 .../instructions/cp/EvalNaryCPInstruction.java     | 28 +++++++++++++++++++++-
 .../runtime/transform/encode/ColumnEncoderUDF.java |  6 +++--
 .../transform/TransformEncodeUDFTest.java          | 19 +++++++--------
 .../functions/transform/TransformEncodeUDF2.dml    |  2 +-
 4 files changed, 41 insertions(+), 14 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index b7d315c612..9f151e14cf 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -142,7 +142,9 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                        && !(fpb.getInputParams().size() == 1 && 
fpb.getInputParams().get(0).getDataType().isList()))
                {
                        ListObject lo = ec.getListObject(boundInputs[0]);
-                       lo = appendNamedDefaults(lo, fpb.getStatementBlock());
+                       lo = lo.isNamedList() ?
+                               appendNamedDefaults(lo, 
fpb.getStatementBlock()) :
+                               appendPositionalDefaults(lo, 
fpb.getStatementBlock());
                        checkValidArguments(lo.getData(), lo.getNames(), 
fpb.getInputParamNames());
                        if( lo.isNamedList() )
                                lo = reorderNamedListForFunctionCall(lo, 
fpb.getInputParamNames());
@@ -305,6 +307,30 @@ public class EvalNaryCPInstruction extends 
BuiltinNaryCPInstruction {
                return ret;
        }
        
+       /**
+        * Appends the declared defaults of all trailing function parameters that
+        * are not covered by an unnamed (positional) argument list.
+        * 
+        * @param params positional arguments passed to eval
+        * @param sb statement block of the called function (may be null)
+        * @return params unchanged if sb is null, otherwise a copy of params
+        *         extended by the materialized scalar defaults
+        * @throws DMLRuntimeException if a missing parameter is non-scalar or
+        *         has no constant literal default
+        */
+       private static ListObject appendPositionalDefaults(ListObject params, StatementBlock sb) {
+               //without a statement block, there are no signature defaults to consult
+               if( sb == null )
+                       return params;
+               
+               FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
+               List<Expression> defaults = fstmt.getInputDefaults();
+               int numParams = fstmt.getInputParams().size();
+               ListObject ret = new ListObject(params);
+               
+               //fill positions [#provided, #params) with literal scalar defaults;
+               //anything else (no default, non-scalar, non-literal) is rejected
+               for( int pos = ret.getLength(); pos < numParams; pos++ ) {
+                       String name = fstmt.getInputParamNames()[pos];
+                       Expression dflt = defaults.get(pos);
+                       boolean isScalarParam = fstmt.getInputParams().get(pos).getDataType().isScalar();
+                       if( dflt == null || !isScalarParam || !(dflt instanceof ConstIdentifier) )
+                               throw new DMLRuntimeException("Unable to append positional scalar default for '"+name+"'");
+                       
+                       //parse the literal's textual value according to the declared value type
+                       ScalarObject value = ScalarObjectFactory.createScalarObject(
+                               fstmt.getInputParams().get(pos).getValueType(), dflt.toString());
+                       LineageItem lineage = null;
+                       if( DMLScript.LINEAGE )
+                               lineage = LineageItemUtils.createScalarLineageItem(
+                                       ScalarObjectFactory.createLiteralOp(value));
+                       ret.add(value, lineage);
+               }
+               return ret;
+       }
+       
        private static void checkValidArguments(List<Data> loData, List<String> 
loNames, List<String> fArgNames) {
                //check number of parameters
                int listSize = (loNames != null) ? loNames.size() : 
loData.size();
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
index 15fa568d65..a3f76623f2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
@@ -33,7 +33,9 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.EvalNaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.DependencyTask;
@@ -75,7 +77,7 @@ public class ColumnEncoderUDF extends ColumnEncoder {
                //create execution context and input
                ExecutionContext ec = ExecutionContextFactory.createContext(new 
Program(new DMLProgram()));
                MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, 
_colID-1, new MatrixBlock());
-               ec.setVariable("I", ParamservUtils.newMatrixObject(col, true));
+               ec.setVariable("I", new ListObject(new Data[] 
{ParamservUtils.newMatrixObject(col, true)}));
                ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));
                
                //call UDF function via eval machinery
@@ -83,7 +85,7 @@ public class ColumnEncoderUDF extends ColumnEncoder {
                        new CPOperand("O", ValueType.FP64, DataType.MATRIX),
                        new CPOperand[] {
                                new CPOperand(_fName, ValueType.STRING, 
DataType.SCALAR, true),
-                               new CPOperand("I", ValueType.FP64, 
DataType.MATRIX)});
+                               new CPOperand("I", ValueType.UNKNOWN, 
DataType.LIST)});
                fun.processInstruction(ec);
                
                //obtain result and in-place write back
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
index e7ccc8d582..1586a51b7d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
@@ -56,16 +56,15 @@ public class TransformEncodeUDFTest extends 
AutomatedTestBase
                runTransformTest(ExecMode.HYBRID, TEST_NAME1);
        }
 
-// TODO default handling without named lists
-//     @Test
-//     public void testUDF2Singlenode() {
-//             runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME2);
-//     }
-//     
-//     @Test
-//     public void testUDF2Hybrid() {
-//             runTransformTest(ExecMode.HYBRID, TEST_NAME2);
-//     }
+       //UDF encoder test without named argument lists (relies on positional defaults)
+       @Test
+       public void testUDF2Singlenode() {
+               runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME2);
+       }
+
+       //same UDF encoder scenario, executed in hybrid mode
+       @Test
+       public void testUDF2Hybrid() {
+               runTransformTest(ExecMode.HYBRID, TEST_NAME2);
+       }
        
        private void runTransformTest(ExecMode rt, String testname)
        {
diff --git a/src/test/scripts/functions/transform/TransformEncodeUDF2.dml 
b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
index 233f25b73c..a62ca2d860 100644
--- a/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
+++ b/src/test/scripts/functions/transform/TransformEncodeUDF2.dml
@@ -34,6 +34,6 @@ jspec2 = "{ids: true, recode: [1, 2, 7], udf: {name: scale, 
ids: [1, 2, 3, 4, 5,
 
 while(FALSE){}
 
-R = sum(R1==R2);
+R = sum(abs(R1-R2)<1e-10);
 write(R, $R);
 

Reply via email to