This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cb932f26ce [SYSTEMDS-3365,3366] Extended transformencode w/ UDF 
transformations
cb932f26ce is described below

commit cb932f26ce281aba0cdd8f39b117cbcb2d09e3d0
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun May 1 21:47:14 2022 +0200

    [SYSTEMDS-3365,3366] Extended transformencode w/ UDF transformations
    
    This patch extends transformencode with support for applying arbitrary
    dml-based UDF functions on subsets of columns. Internally, these
    encoders reuse the existing eval-function mechanics, which further
    allows for reusing existing infrastructure for codegen, compression,
and GPU operations. This extension is experimental and serves as a
    basis for a systematic experimental evaluation of the integration into
    transformencode's fine-grained, column-oriented task graphs.
    
    Furthermore, this patch also includes some minor cleanups of builtin
    defaults, and javadoc/builtin documentation.
---
 scripts/builtin/scale.dml                          |  14 +--
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  10 +-
 .../apache/sysds/runtime/transform/TfUtils.java    |   2 +-
 .../runtime/transform/encode/ColumnEncoder.java    |   3 +-
 .../runtime/transform/encode/ColumnEncoderUDF.java | 135 +++++++++++++++++++++
 .../runtime/transform/encode/EncoderFactory.java   |  42 +++----
 .../sysds/runtime/transform/meta/TfMetaUtils.java  | 103 ++++++++++------
 .../transform/TransformEncodeUDFTest.java          |  89 ++++++++++++++
 .../functions/transform/TransformEncodeUDF1.dml    |  39 ++++++
 9 files changed, 363 insertions(+), 74 deletions(-)

diff --git a/scripts/builtin/scale.dml b/scripts/builtin/scale.dml
index a79f57a123..63a5f7fd87 100644
--- a/scripts/builtin/scale.dml
+++ b/scripts/builtin/scale.dml
@@ -26,20 +26,20 @@
 # NAME         TYPE              DEFAULT  MEANING
 # 
----------------------------------------------------------------------------------------------------------------------
 # X            Matrix[Double]    ---      Input feature matrix
-# Center       Boolean           TRUE     Indicates whether or not to center 
the feature matrix
-# Scale        Boolean           TRUE     Indicates whether or not to scale 
the feature matrix
+# center       Boolean           TRUE     Indicates whether or not to center 
the feature matrix
+# scale        Boolean           TRUE     Indicates whether or not to scale 
the feature matrix
 # 
----------------------------------------------------------------------------------------------------------------------
 #
 # OUTPUT:
 # 
----------------------------------------------------------------------------------------------------------------------
-# NAME         TYPE                      MEANING
+# NAME         TYPE                       MEANING
 # 
----------------------------------------------------------------------------------------------------------------------
-# Y            Matrix[Double]            Output feature matrix with K columns
-# Centering      Matrix[Double]            The column means of the input, 
subtracted if Center was TRUE
-# ScaleFactor  Matrix[Double]            The Scaling of the values, to make 
each dimension have similar value ranges
+# Y            Matrix[Double]             Output feature matrix with K columns
+# Centering    Matrix[Double]             The column means of the input, 
subtracted if Center was TRUE
+# ScaleFactor  Matrix[Double]             The Scaling of the values, to make 
each dimension have similar value ranges
 # 
----------------------------------------------------------------------------------------------------------------------
 
-m_scale = function(Matrix[Double] X, Boolean center, Boolean scale) 
+m_scale = function(Matrix[Double] X, Boolean center=TRUE, Boolean scale=TRUE) 
   return (Matrix[Double] out, Matrix[Double] Centering, Matrix[Double] 
ScaleFactor) 
 {
   if(center){
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 182b771138..4d5b97ff3d 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4147,11 +4147,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
         * 
         * This means that if you call with rl == ru then you get 1 row output.
         * 
-        * @param rl row lower if this value is bellow 0 or above the number of 
rows contained in the matrix an execption is thrown
-        * @param ru row upper if this value is bellow rl or above the number 
of rows contained in the matrix an exception is thrown
-        * @param cl column lower if this value us bellow 0 or above the number 
of columns contained in the matrix an exception is thrown
-        * @param cu column upper if this value us bellow cl or above the 
number of columns contained in the matrix an exception is thrown
-        * @param deep should perform deep copy, this is relelvant in cases 
where the matrix is in sparse format,
+        * @param rl row lower if this value is below 0 or above the number of 
rows contained in the matrix an exception is thrown
+        * @param ru row upper if this value is below rl or above the number of 
rows contained in the matrix an exception is thrown
+        * @param cl column lower if this value is below 0 or above the number 
of columns contained in the matrix an exception is thrown
+        * @param cu column upper if this value is below cl or above the number 
of columns contained in the matrix an exception is thrown
+        * @param deep should perform deep copy, this is relevant in cases 
where the matrix is in sparse format,
         *            or the entire matrix is sliced out
         * @param ret output sliced out matrix block
         * @return matrix block output matrix block
diff --git a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java 
b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
index d895fd15e7..ec4758a819 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
@@ -47,7 +47,7 @@ public class TfUtils implements Serializable
        
        //transform methods
        public enum TfMethod {
-               IMPUTE, RECODE, HASH, BIN, DUMMYCODE, SCALE, OMIT;
+               IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT;
                @Override
                public String toString() {
                        return name().toLowerCase();
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 4e969896c3..89423521b1 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -37,7 +37,6 @@ import java.util.concurrent.Callable;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.data.SparseRowVector;
@@ -65,7 +64,7 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
        protected int _nApplyPartitions = 0;
 
        protected enum TransformType{
-               BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, N_A
+               BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, N_A
        }
 
        protected ColumnEncoder(int colID) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
new file mode 100644
index 0000000000..15fa568d65
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.encode;
+
+import java.util.List;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.EvalNaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DependencyTask;
+
+public class ColumnEncoderUDF extends ColumnEncoder {
+
+       //TODO pass execution context through encoder factory for arbitrary 
functions not just builtin
+       //TODO handling udf after dummy coding
+       //TODO integration into IPA to ensure existence of unoptimized functions
+       
+       private final String _fName;
+       
+       protected ColumnEncoderUDF(int ptCols, String name) {
+               super(ptCols); // 1-based
+               _fName = name;
+       }
+
+       public ColumnEncoderUDF() {
+               this(-1, null);
+       }
+
+       @Override
+       protected TransformType getTransformType() {
+               return TransformType.UDF;
+       }
+
+       @Override
+       public void build(CacheBlock in) {
+               // do nothing
+       }
+
+       @Override
+       public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
+               return null;
+       }
+       
+       @Override
+       public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, 
int rowStart, int blk) {
+               //create execution context and input
+               ExecutionContext ec = ExecutionContextFactory.createContext(new 
Program(new DMLProgram()));
+               MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, 
_colID-1, new MatrixBlock());
+               ec.setVariable("I", ParamservUtils.newMatrixObject(col, true));
+               ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));
+               
+               //call UDF function via eval machinery
+               var fun = new EvalNaryCPInstruction(null, "eval", "",
+                       new CPOperand("O", ValueType.FP64, DataType.MATRIX),
+                       new CPOperand[] {
+                               new CPOperand(_fName, ValueType.STRING, 
DataType.SCALAR, true),
+                               new CPOperand("I", ValueType.FP64, 
DataType.MATRIX)});
+               fun.processInstruction(ec);
+               
+               //obtain result and in-place write back
+               MatrixBlock ret = 
((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease();
+               out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, 
_colID-1, ret, UpdateType.INPLACE);
+               return out;
+       }
+       
+       @Override
+       protected ColumnApplyTask<ColumnEncoderUDF> getSparseTask(CacheBlock in,
+               MatrixBlock out, int outputCol, int startRow, int blk)
+       {
+               throw new DMLRuntimeException("UDF encoders do not support 
sparse tasks.");
+       }
+       
+       @Override
+       public void mergeAt(ColumnEncoder other) {
+               if(other instanceof ColumnEncoderUDF)
+                       return;
+               super.mergeAt(other);
+       }
+
+       @Override
+       public void allocateMetaData(FrameBlock meta) {
+               // do nothing
+               return;
+       }
+
+       @Override
+       public FrameBlock getMetaData(FrameBlock meta) {
+               // do nothing
+               return meta;
+       }
+
+       @Override
+       public void initMetaData(FrameBlock meta) {
+               // do nothing
+       }
+
+       @Override
+       protected double getCode(CacheBlock in, int row) {
+               throw new DMLRuntimeException("UDF encoders only support full 
column access.");
+       }
+
+       @Override
+       protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) 
{
+               throw new DMLRuntimeException("UDF encoders only support full 
column access.");
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 33b7682076..f7f7a7f990 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -91,25 +91,19 @@ public class EncoderFactory {
                                .toObject(TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.OMIT.toString(), minCol, maxCol)));
                        List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
                                TfMetaUtils.parseJsonObjectIDList(jSpec, 
colnames, TfMethod.IMPUTE.toString(), minCol, maxCol)));
-
+                       List<Integer> udfIDs = 
TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol);
+                       
                        // create individual encoders
-                       if(!rcIDs.isEmpty()) {
-                               for(Integer id : rcIDs) {
-                                       ColumnEncoderRecode ra = new 
ColumnEncoderRecode(id);
-                                       addEncoderToMap(ra, colEncoders);
-                               }
-                       }
-                       if(!haIDs.isEmpty()) {
-                               for(Integer id : haIDs) {
-                                       ColumnEncoderFeatureHash ha = new 
ColumnEncoderFeatureHash(id, TfMetaUtils.getK(jSpec));
-                                       addEncoderToMap(ha, colEncoders);
-                               }
-                       }
+                       if(!rcIDs.isEmpty())
+                               for(Integer id : rcIDs)
+                                       addEncoderToMap(new 
ColumnEncoderRecode(id), colEncoders);
+                       if(!haIDs.isEmpty())
+                               for(Integer id : haIDs)
+                                       addEncoderToMap(new 
ColumnEncoderFeatureHash(id, TfMetaUtils.getK(jSpec)), colEncoders);
                        if(!ptIDs.isEmpty())
-                               for(Integer id : ptIDs) {
-                                       ColumnEncoderPassThrough pt = new 
ColumnEncoderPassThrough(id);
-                                       addEncoderToMap(pt, colEncoders);
-                               }
+                               for(Integer id : ptIDs)
+                                       addEncoderToMap(new 
ColumnEncoderPassThrough(id), colEncoders);
+                       
                        if(!binIDs.isEmpty())
                                for(Object o : (JSONArray) 
jSpec.get(TfMethod.BIN.toString())) {
                                        JSONObject colspec = (JSONObject) o;
@@ -129,10 +123,14 @@ public class EncoderFactory {
                                        addEncoderToMap(bin, colEncoders);
                                }
                        if(!dcIDs.isEmpty())
-                               for(Integer id : dcIDs) {
-                                       ColumnEncoderDummycode dc = new 
ColumnEncoderDummycode(id);
-                                       addEncoderToMap(dc, colEncoders);
-                               }
+                               for(Integer id : dcIDs)
+                                       addEncoderToMap(new 
ColumnEncoderDummycode(id), colEncoders);
+                       if(!udfIDs.isEmpty()) {
+                               String name = 
jSpec.getJSONObject("udf").getString("name");
+                               for(Integer id : udfIDs)
+                                       addEncoderToMap(new 
ColumnEncoderUDF(id, name), colEncoders);
+                       }
+                       
                        // create composite decoder of all created encoders
                        for(Entry<Integer, List<ColumnEncoder>> listEntry : 
colEncoders.entrySet()) {
                                if(DMLScript.STATISTICS)
@@ -190,6 +188,8 @@ public class EncoderFactory {
        }
 
        public static int getEncoderType(ColumnEncoder columnEncoder) {
+               //TODO replace with columnEncoder.getType().ordinal
+               //(which requires a cleanup of all type handling)
                if(columnEncoder instanceof ColumnEncoderBin)
                        return EncoderType.Bin.ordinal();
                else if(columnEncoder instanceof ColumnEncoderDummycode)
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java 
b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index ca83ff877f..801e3ef0a7 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -106,7 +106,6 @@ public class TfMetaUtils
        public static int[] parseJsonIDList(JSONObject spec, String[] colnames, 
String group, int minCol, int maxCol)
                throws JSONException
        {
-               List<Integer> colList = new ArrayList<>();
                int[] arr = new int[0];
                boolean ids = spec.containsKey("ids") && spec.getBoolean("ids");
                
@@ -119,30 +118,7 @@ public class TfMetaUtils
                        }
                        else
                                attrs = (JSONArray)spec.get(group);
-                       
-                       //construct ID list array
-                       for(int i=0; i < attrs.length(); i++) {
-                               int ix;
-                               if (ids) {
-                                       ix = UtilFunctions.toInt(attrs.get(i));
-                                       if(maxCol != -1 && ix >= maxCol)
-                                               ix = -1;
-                                       if(minCol != -1 && ix >= 0)
-                                               ix -= minCol - 1;
-                               }
-                               else {
-                                       ix = ArrayUtils.indexOf(colnames, 
attrs.get(i)) + 1;
-                               }
-                               if(ix > 0)
-                                       colList.add(ix);
-                               else if(minCol == -1 && maxCol == -1)
-                                       // only if we remove some columns, ix 
-1 is expected
-                                       throw new RuntimeException("Specified 
column '" + attrs.get(i) + "' does not exist.");
-                       }
-                       
-                       //ensure ascending order of column IDs
-                       arr = colList.stream().mapToInt((i) -> i)
-                               .sorted().toArray();
+                       arr = parseJsonPlainArrayIDList(attrs, colnames, 
minCol, maxCol, ids);
                }
                return arr;
        }
@@ -168,25 +144,59 @@ public class TfMetaUtils
 
        public static int[] parseJsonObjectIDList(JSONObject spec, String[] 
colnames, String group, int minCol, int maxCol)
                throws JSONException {
-               List<Integer> colList = new ArrayList<>();
                int[] arr = new int[0];
                boolean ids = spec.containsKey("ids") && spec.getBoolean("ids");
 
                if(spec.containsKey(group) && spec.get(group) instanceof 
JSONArray) {
                        JSONArray colspecs = (JSONArray) spec.get(group);
-                       for(Object o : colspecs) {
-                               JSONObject colspec = (JSONObject) o;
-                               int id = parseJsonObjectID(colspec, colnames, 
minCol, maxCol, ids);
-                               if(id > 0)
-                                       colList.add(id);
-                       }
-
-                       // ensure ascending order of column IDs
-                       arr = colList.stream().mapToInt((i) -> 
i).sorted().toArray();
+                       arr = parseJsonArrayIDList(colspecs, colnames, minCol, 
maxCol, ids);
                }
 
                return arr;
        }
+       
+       public static int[] parseJsonArrayIDList(JSONArray arr, String[] 
colnames, int minCol, int maxCol, boolean ids)
+               throws JSONException
+       {
+               List<Integer> colList = new ArrayList<>();
+               for(Object o : arr) {
+                       JSONObject colspec = (JSONObject) o;
+                       int id = parseJsonObjectID(colspec, colnames, minCol, 
maxCol, ids);
+                       if(id > 0)
+                               colList.add(id);
+               }
+
+               // ensure ascending order of column IDs
+               return colList.stream().mapToInt((i) -> i).sorted().toArray();
+       }
+       
+       public static int[] parseJsonPlainArrayIDList(JSONArray arr, String[] 
colnames, int minCol, int maxCol, boolean ids) {
+               List<Integer> colList = new ArrayList<>();
+               
+               //construct ID list array
+               for(int i=0; i < arr.length(); i++) {
+                       int ix;
+                       if (ids) {
+                               ix = UtilFunctions.toInt(arr.get(i));
+                               if(maxCol != -1 && ix >= maxCol)
+                                       ix = -1;
+                               if(minCol != -1 && ix >= 0)
+                                       ix -= minCol - 1;
+                       }
+                       else {
+                               ix = ArrayUtils.indexOf(colnames, arr.get(i)) + 
1;
+                       }
+                       if(ix > 0)
+                               colList.add(ix);
+                       else if(minCol == -1 && maxCol == -1)
+                               // only if we remove some columns, ix -1 is 
expected
+                               throw new RuntimeException("Specified column '" 
+ arr.get(i) + "' does not exist.");
+               }
+               
+               //ensure ascending order of column IDs
+               return colList.stream().mapToInt((i) -> i)
+                       .sorted().toArray();
+       }
 
        /**
         * Get K value used for calculation during feature hashing from parsed 
specifications.
@@ -429,16 +439,33 @@ public class TfMetaUtils
                        String binKey = TfMethod.BIN.toString();
                        if( jSpec.containsKey(binKey) && jSpec.get(binKey) 
instanceof JSONArray ) {
                                return Arrays.asList(ArrayUtils.toObject(
-                                               
TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, binKey, minCol, maxCol)));
+                                       parseJsonObjectIDList(jSpec, colnames, 
binKey, minCol, maxCol)));
+                       }
+                       else { //internally generated
+                               return Arrays.asList(ArrayUtils.toObject(
+                                       parseJsonIDList(jSpec, colnames, 
binKey)));
                        }
-                       else { //internally generates
+               }
+               catch(JSONException ex) {
+                       throw new IOException(ex);
+               }
+       }
+       
+       public static List<Integer> parseUDFColIDs(JSONObject jSpec, String[] 
colnames, int minCol, int maxCol)
+               throws IOException 
+       {
+               try {
+                       String binKey = TfMethod.UDF.toString();
+                       if( jSpec.containsKey(binKey) ) {
+                               JSONArray bin = 
jSpec.getJSONObject(binKey).getJSONArray("ids");
                                return Arrays.asList(ArrayUtils.toObject(
-                                               
TfMetaUtils.parseJsonIDList(jSpec, colnames, binKey)));
+                                       parseJsonPlainArrayIDList(bin, 
colnames, minCol, maxCol, true)));
                        }
                }
                catch(JSONException ex) {
                        throw new IOException(ex);
                }
+               return new ArrayList<>();
        }
        
        private static String getStringFromResource(String path) throws 
IOException {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
new file mode 100644
index 0000000000..aba517efbb
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+public class TransformEncodeUDFTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME1 = "TransformEncodeUDF1";
+       private final static String TEST_DIR = "functions/transform/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformEncodeUDFTest.class.getSimpleName() + "/";
+       
+       //dataset and transform tasks without missing values
+       private final static String DATASET = "homes3/homes.csv";
+       
+       @Override
+       public void setUp()  {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+       }
+       
+       @Test
+       public void testUDF1Singlenode() {
+               runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME1);
+       }
+       
+       @Test
+       public void testUDF1Hybrid() {
+               runTransformTest(ExecMode.HYBRID, TEST_NAME1);
+       }
+       
+       private void runTransformTest(ExecMode rt, String testname)
+       {
+               //set runtime platform
+               ExecMode rtold = setExecMode(rt);
+               
+               try
+               {
+                       getAndLoadTestConfiguration(TEST_NAME1);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+                       programArgs = new String[]{"-explain",
+                               "-nvargs", "DATA=" + DATASET_DIR + DATASET, 
"R="+output("R")};
+
+                       //compare transformencode+scale vs transformencode w/ 
UDF
+                       runTest(true, false, null, -1); 
+                       
+                       double ret = 
HDFSTool.readDoubleFromHDFSFile(output("R"));
+                       Assert.assertEquals(Double.valueOf(148*9), 
Double.valueOf(ret));
+                       
+                       if( rt == ExecMode.HYBRID ) {
+                               Long num = 
Long.valueOf(Statistics.getNoOfExecutedSPInst());
+                               Assert.assertEquals("Wrong number of executed 
Spark instructions: " + num, Long.valueOf(0), num);
+                       }
+               }
+               catch(Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       resetExecMode(rtold);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/transform/TransformEncodeUDF1.dml 
b/src/test/scripts/functions/transform/TransformEncodeUDF1.dml
new file mode 100644
index 0000000000..64cc470a6c
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformEncodeUDF1.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F1 = read($DATA, data_type="frame", format="csv");
+
+# reference solution with scale outside transformencode
+jspec = "{ids: true, recode: [1, 2, 7]}";
+[X, M] = transformencode(target=F1, spec=jspec);
+R1 = scaleMinMax(X);
+
+while(FALSE){}
+
+# UDF-based solution with scale applied inside transformencode
+jspec2 = "{ids: true, recode: [1, 2, 7], udf: {name: scaleMinMax, ids: [1, 2, 
3, 4, 5, 6, 7, 8, 9]}}";
+[R2, M2] = transformencode(target=F1, spec=jspec2);
+
+while(FALSE){}
+
+R = sum(R1==R2);
+write(R, $R);
+

Reply via email to