This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cb932f26ce [SYSTEMDS-3365,3366] Extended transformencode w/ UDF
transformations
cb932f26ce is described below
commit cb932f26ce281aba0cdd8f39b117cbcb2d09e3d0
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun May 1 21:47:14 2022 +0200
[SYSTEMDS-3365,3366] Extended transformencode w/ UDF transformations
This patch extends transformencode with support for applying arbitrary
dml-based UDF functions on subsets of columns. Internally, these
encoders reuse the existing eval-function mechanics, which further
allows for reusing existing infrastructure for codegen, compression,
and GPU operations. This extension is experimental and serves as a
basis for a systematic experimental evaluation of the integration into
transformencode's fine-grained, column-oriented task graphs.
Furthermore, this patch also includes some minor cleanups of builtin
defaults, and javadoc/builtin documentation.
---
scripts/builtin/scale.dml | 14 +--
.../sysds/runtime/matrix/data/MatrixBlock.java | 10 +-
.../apache/sysds/runtime/transform/TfUtils.java | 2 +-
.../runtime/transform/encode/ColumnEncoder.java | 3 +-
.../runtime/transform/encode/ColumnEncoderUDF.java | 135 +++++++++++++++++++++
.../runtime/transform/encode/EncoderFactory.java | 42 +++----
.../sysds/runtime/transform/meta/TfMetaUtils.java | 103 ++++++++++------
.../transform/TransformEncodeUDFTest.java | 89 ++++++++++++++
.../functions/transform/TransformEncodeUDF1.dml | 39 ++++++
9 files changed, 363 insertions(+), 74 deletions(-)
diff --git a/scripts/builtin/scale.dml b/scripts/builtin/scale.dml
index a79f57a123..63a5f7fd87 100644
--- a/scripts/builtin/scale.dml
+++ b/scripts/builtin/scale.dml
@@ -26,20 +26,20 @@
# NAME TYPE DEFAULT MEANING
#
----------------------------------------------------------------------------------------------------------------------
# X Matrix[Double] --- Input feature matrix
-# Center Boolean TRUE Indicates whether or not to center
the feature matrix
-# Scale Boolean TRUE Indicates whether or not to scale
the feature matrix
+# center Boolean TRUE Indicates whether or not to center
the feature matrix
+# scale Boolean TRUE Indicates whether or not to scale
the feature matrix
#
----------------------------------------------------------------------------------------------------------------------
#
# OUTPUT:
#
----------------------------------------------------------------------------------------------------------------------
-# NAME TYPE MEANING
+# NAME TYPE MEANING
#
----------------------------------------------------------------------------------------------------------------------
-# Y Matrix[Double] Output feature matrix with K columns
-# Centering Matrix[Double] The column means of the input,
subtracted if Center was TRUE
-# ScaleFactor Matrix[Double] The Scaling of the values, to make
each dimension have similar value ranges
+# Y Matrix[Double] Output feature matrix with K columns
+# Centering Matrix[Double] The column means of the input,
subtracted if Center was TRUE
+# ScaleFactor Matrix[Double] The Scaling of the values, to make
each dimension have similar value ranges
#
----------------------------------------------------------------------------------------------------------------------
-m_scale = function(Matrix[Double] X, Boolean center, Boolean scale)
+m_scale = function(Matrix[Double] X, Boolean center=TRUE, Boolean scale=TRUE)
return (Matrix[Double] out, Matrix[Double] Centering, Matrix[Double]
ScaleFactor)
{
if(center){
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 182b771138..4d5b97ff3d 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4147,11 +4147,11 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
*
* This means that if you call with rl == ru then you get 1 row output.
*
- * @param rl row lower if this value is bellow 0 or above the number of
rows contained in the matrix an execption is thrown
- * @param ru row upper if this value is bellow rl or above the number
of rows contained in the matrix an exception is thrown
- * @param cl column lower if this value us bellow 0 or above the number
of columns contained in the matrix an exception is thrown
- * @param cu column upper if this value us bellow cl or above the
number of columns contained in the matrix an exception is thrown
- * @param deep should perform deep copy, this is relelvant in cases
where the matrix is in sparse format,
+ * @param rl row lower if this value is below 0 or above the number of
rows contained in the matrix an exception is thrown
+ * @param ru row upper if this value is below rl or above the number of
rows contained in the matrix an exception is thrown
+ * @param cl column lower if this value us below 0 or above the number
of columns contained in the matrix an exception is thrown
+ * @param cu column upper if this value us below cl or above the number
of columns contained in the matrix an exception is thrown
+ * @param deep should perform deep copy, this is relevant in cases
where the matrix is in sparse format,
* or the entire matrix is sliced out
* @param ret output sliced out matrix block
* @return matrix block output matrix block
diff --git a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
index d895fd15e7..ec4758a819 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
@@ -47,7 +47,7 @@ public class TfUtils implements Serializable
//transform methods
public enum TfMethod {
- IMPUTE, RECODE, HASH, BIN, DUMMYCODE, SCALE, OMIT;
+ IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT;
@Override
public String toString() {
return name().toLowerCase();
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 4e969896c3..89423521b1 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -37,7 +37,6 @@ import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseRowVector;
@@ -65,7 +64,7 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
protected int _nApplyPartitions = 0;
protected enum TransformType{
- BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, N_A
+ BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, N_A
}
protected ColumnEncoder(int colID) {
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
new file mode 100644
index 0000000000..15fa568d65
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.encode;
+
+import java.util.List;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.EvalNaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DependencyTask;
+
+public class ColumnEncoderUDF extends ColumnEncoder {
+
+ //TODO pass execution context through encoder factory for arbitrary
functions not just builtin
+ //TODO handling udf after dummy coding
+ //TODO integration into IPA to ensure existence of unoptimized functions
+
+ private final String _fName;
+
+ protected ColumnEncoderUDF(int ptCols, String name) {
+ super(ptCols); // 1-based
+ _fName = name;
+ }
+
+ public ColumnEncoderUDF() {
+ this(-1, null);
+ }
+
+ @Override
+ protected TransformType getTransformType() {
+ return TransformType.UDF;
+ }
+
+ @Override
+ public void build(CacheBlock in) {
+ // do nothing
+ }
+
+ @Override
+ public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
+ return null;
+ }
+
+ @Override
+ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol,
int rowStart, int blk) {
+ //create execution context and input
+ ExecutionContext ec = ExecutionContextFactory.createContext(new
Program(new DMLProgram()));
+ MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1,
_colID-1, new MatrixBlock());
+ ec.setVariable("I", ParamservUtils.newMatrixObject(col, true));
+ ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));
+
+ //call UDF function via eval machinery
+ var fun = new EvalNaryCPInstruction(null, "eval", "",
+ new CPOperand("O", ValueType.FP64, DataType.MATRIX),
+ new CPOperand[] {
+ new CPOperand(_fName, ValueType.STRING,
DataType.SCALAR, true),
+ new CPOperand("I", ValueType.FP64,
DataType.MATRIX)});
+ fun.processInstruction(ec);
+
+ //obtain result and in-place write back
+ MatrixBlock ret =
((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease();
+ out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1,
_colID-1, ret, UpdateType.INPLACE);
+ return out;
+ }
+
+ @Override
+ protected ColumnApplyTask<ColumnEncoderUDF> getSparseTask(CacheBlock in,
+ MatrixBlock out, int outputCol, int startRow, int blk)
+ {
+ throw new DMLRuntimeException("UDF encoders do not support
sparse tasks.");
+ }
+
+ @Override
+ public void mergeAt(ColumnEncoder other) {
+ if(other instanceof ColumnEncoderUDF)
+ return;
+ super.mergeAt(other);
+ }
+
+ @Override
+ public void allocateMetaData(FrameBlock meta) {
+ // do nothing
+ return;
+ }
+
+ @Override
+ public FrameBlock getMetaData(FrameBlock meta) {
+ // do nothing
+ return meta;
+ }
+
+ @Override
+ public void initMetaData(FrameBlock meta) {
+ // do nothing
+ }
+
+ @Override
+ protected double getCode(CacheBlock in, int row) {
+ throw new DMLRuntimeException("UDF encoders only support full
column access.");
+ }
+
+ @Override
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ throw new DMLRuntimeException("UDF encoders only support full
column access.");
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 33b7682076..f7f7a7f990 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -91,25 +91,19 @@ public class EncoderFactory {
.toObject(TfMetaUtils.parseJsonIDList(jSpec,
colnames, TfMethod.OMIT.toString(), minCol, maxCol)));
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonObjectIDList(jSpec,
colnames, TfMethod.IMPUTE.toString(), minCol, maxCol)));
-
+ List<Integer> udfIDs =
TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol);
+
// create individual encoders
- if(!rcIDs.isEmpty()) {
- for(Integer id : rcIDs) {
- ColumnEncoderRecode ra = new
ColumnEncoderRecode(id);
- addEncoderToMap(ra, colEncoders);
- }
- }
- if(!haIDs.isEmpty()) {
- for(Integer id : haIDs) {
- ColumnEncoderFeatureHash ha = new
ColumnEncoderFeatureHash(id, TfMetaUtils.getK(jSpec));
- addEncoderToMap(ha, colEncoders);
- }
- }
+ if(!rcIDs.isEmpty())
+ for(Integer id : rcIDs)
+ addEncoderToMap(new
ColumnEncoderRecode(id), colEncoders);
+ if(!haIDs.isEmpty())
+ for(Integer id : haIDs)
+ addEncoderToMap(new
ColumnEncoderFeatureHash(id, TfMetaUtils.getK(jSpec)), colEncoders);
if(!ptIDs.isEmpty())
- for(Integer id : ptIDs) {
- ColumnEncoderPassThrough pt = new
ColumnEncoderPassThrough(id);
- addEncoderToMap(pt, colEncoders);
- }
+ for(Integer id : ptIDs)
+ addEncoderToMap(new
ColumnEncoderPassThrough(id), colEncoders);
+
if(!binIDs.isEmpty())
for(Object o : (JSONArray)
jSpec.get(TfMethod.BIN.toString())) {
JSONObject colspec = (JSONObject) o;
@@ -129,10 +123,14 @@ public class EncoderFactory {
addEncoderToMap(bin, colEncoders);
}
if(!dcIDs.isEmpty())
- for(Integer id : dcIDs) {
- ColumnEncoderDummycode dc = new
ColumnEncoderDummycode(id);
- addEncoderToMap(dc, colEncoders);
- }
+ for(Integer id : dcIDs)
+ addEncoderToMap(new
ColumnEncoderDummycode(id), colEncoders);
+ if(!udfIDs.isEmpty()) {
+ String name =
jSpec.getJSONObject("udf").getString("name");
+ for(Integer id : udfIDs)
+ addEncoderToMap(new
ColumnEncoderUDF(id, name), colEncoders);
+ }
+
// create composite decoder of all created encoders
for(Entry<Integer, List<ColumnEncoder>> listEntry :
colEncoders.entrySet()) {
if(DMLScript.STATISTICS)
@@ -190,6 +188,8 @@ public class EncoderFactory {
}
public static int getEncoderType(ColumnEncoder columnEncoder) {
+ //TODO replace with columnEncoder.getType().ordinal
+ //(which requires a cleanup of all type handling)
if(columnEncoder instanceof ColumnEncoderBin)
return EncoderType.Bin.ordinal();
else if(columnEncoder instanceof ColumnEncoderDummycode)
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index ca83ff877f..801e3ef0a7 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -106,7 +106,6 @@ public class TfMetaUtils
public static int[] parseJsonIDList(JSONObject spec, String[] colnames,
String group, int minCol, int maxCol)
throws JSONException
{
- List<Integer> colList = new ArrayList<>();
int[] arr = new int[0];
boolean ids = spec.containsKey("ids") && spec.getBoolean("ids");
@@ -119,30 +118,7 @@ public class TfMetaUtils
}
else
attrs = (JSONArray)spec.get(group);
-
- //construct ID list array
- for(int i=0; i < attrs.length(); i++) {
- int ix;
- if (ids) {
- ix = UtilFunctions.toInt(attrs.get(i));
- if(maxCol != -1 && ix >= maxCol)
- ix = -1;
- if(minCol != -1 && ix >= 0)
- ix -= minCol - 1;
- }
- else {
- ix = ArrayUtils.indexOf(colnames,
attrs.get(i)) + 1;
- }
- if(ix > 0)
- colList.add(ix);
- else if(minCol == -1 && maxCol == -1)
- // only if we remove some columns, ix
-1 is expected
- throw new RuntimeException("Specified
column '" + attrs.get(i) + "' does not exist.");
- }
-
- //ensure ascending order of column IDs
- arr = colList.stream().mapToInt((i) -> i)
- .sorted().toArray();
+ arr = parseJsonPlainArrayIDList(attrs, colnames,
minCol, maxCol, ids);
}
return arr;
}
@@ -168,25 +144,59 @@ public class TfMetaUtils
public static int[] parseJsonObjectIDList(JSONObject spec, String[]
colnames, String group, int minCol, int maxCol)
throws JSONException {
- List<Integer> colList = new ArrayList<>();
int[] arr = new int[0];
boolean ids = spec.containsKey("ids") && spec.getBoolean("ids");
if(spec.containsKey(group) && spec.get(group) instanceof
JSONArray) {
JSONArray colspecs = (JSONArray) spec.get(group);
- for(Object o : colspecs) {
- JSONObject colspec = (JSONObject) o;
- int id = parseJsonObjectID(colspec, colnames,
minCol, maxCol, ids);
- if(id > 0)
- colList.add(id);
- }
-
- // ensure ascending order of column IDs
- arr = colList.stream().mapToInt((i) ->
i).sorted().toArray();
+ arr = parseJsonArrayIDList(colspecs, colnames, minCol,
maxCol, ids);
}
return arr;
}
+
+ public static int[] parseJsonArrayIDList(JSONArray arr, String[]
colnames, int minCol, int maxCol, boolean ids)
+ throws JSONException
+ {
+ List<Integer> colList = new ArrayList<>();
+ for(Object o : arr) {
+ JSONObject colspec = (JSONObject) o;
+ int id = parseJsonObjectID(colspec, colnames, minCol,
maxCol, ids);
+ if(id > 0)
+ colList.add(id);
+ }
+
+ // ensure ascending order of column IDs
+ return colList.stream().mapToInt((i) -> i).sorted().toArray();
+ }
+
+ public static int[] parseJsonPlainArrayIDList(JSONArray arr, String[]
colnames, int minCol, int maxCol, boolean ids) {
+ List<Integer> colList = new ArrayList<>();
+
+ //construct ID list array
+ for(int i=0; i < arr.length(); i++) {
+ int ix;
+ if (ids) {
+ ix = UtilFunctions.toInt(arr.get(i));
+ if(maxCol != -1 && ix >= maxCol)
+ ix = -1;
+ if(minCol != -1 && ix >= 0)
+ ix -= minCol - 1;
+ }
+ else {
+ ix = ArrayUtils.indexOf(colnames, arr.get(i)) +
1;
+ }
+ if(ix > 0)
+ colList.add(ix);
+ else if(minCol == -1 && maxCol == -1)
+ // only if we remove some columns, ix -1 is
expected
+ throw new RuntimeException("Specified column '"
+ arr.get(i) + "' does not exist.");
+ }
+
+ //ensure ascending order of column IDs
+ return colList.stream().mapToInt((i) -> i)
+ .sorted().toArray();
+ }
/**
* Get K value used for calculation during feature hashing from parsed
specifications.
@@ -429,16 +439,33 @@ public class TfMetaUtils
String binKey = TfMethod.BIN.toString();
if( jSpec.containsKey(binKey) && jSpec.get(binKey)
instanceof JSONArray ) {
return Arrays.asList(ArrayUtils.toObject(
-
TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, binKey, minCol, maxCol)));
+ parseJsonObjectIDList(jSpec, colnames,
binKey, minCol, maxCol)));
+ }
+ else { //internally generated
+ return Arrays.asList(ArrayUtils.toObject(
+ parseJsonIDList(jSpec, colnames,
binKey)));
}
- else { //internally generates
+ }
+ catch(JSONException ex) {
+ throw new IOException(ex);
+ }
+ }
+
+ public static List<Integer> parseUDFColIDs(JSONObject jSpec, String[]
colnames, int minCol, int maxCol)
+ throws IOException
+ {
+ try {
+ String binKey = TfMethod.UDF.toString();
+ if( jSpec.containsKey(binKey) ) {
+ JSONArray bin =
jSpec.getJSONObject(binKey).getJSONArray("ids");
return Arrays.asList(ArrayUtils.toObject(
-
TfMetaUtils.parseJsonIDList(jSpec, colnames, binKey)));
+ parseJsonPlainArrayIDList(bin,
colnames, minCol, maxCol, true)));
}
}
catch(JSONException ex) {
throw new IOException(ex);
}
+ return new ArrayList<>();
}
private static String getStringFromResource(String path) throws
IOException {
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
new file mode 100644
index 0000000000..aba517efbb
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeUDFTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+public class TransformEncodeUDFTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME1 = "TransformEncodeUDF1";
+ private final static String TEST_DIR = "functions/transform/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
TransformEncodeUDFTest.class.getSimpleName() + "/";
+
+ //dataset and transform tasks without missing values
+ private final static String DATASET = "homes3/homes.csv";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+ }
+
+ @Test
+ public void testUDF1Singlenode() {
+ runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME1);
+ }
+
+ @Test
+ public void testUDF1Hybrid() {
+ runTransformTest(ExecMode.HYBRID, TEST_NAME1);
+ }
+
+ private void runTransformTest(ExecMode rt, String testname)
+ {
+ //set runtime platform
+ ExecMode rtold = setExecMode(rt);
+
+ try
+ {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[]{"-explain",
+ "-nvargs", "DATA=" + DATASET_DIR + DATASET,
"R="+output("R")};
+
+ //compare transformencode+scale vs transformencode w/
UDF
+ runTest(true, false, null, -1);
+
+ double ret =
HDFSTool.readDoubleFromHDFSFile(output("R"));
+ Assert.assertEquals(Double.valueOf(148*9),
Double.valueOf(ret));
+
+ if( rt == ExecMode.HYBRID ) {
+ Long num =
Long.valueOf(Statistics.getNoOfExecutedSPInst());
+ Assert.assertEquals("Wrong number of executed
Spark instructions: " + num, Long.valueOf(0), num);
+ }
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ resetExecMode(rtold);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/transform/TransformEncodeUDF1.dml
b/src/test/scripts/functions/transform/TransformEncodeUDF1.dml
new file mode 100644
index 0000000000..64cc470a6c
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformEncodeUDF1.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F1 = read($DATA, data_type="frame", format="csv");
+
+# reference solution with scale outside transformencode
+jspec = "{ids: true, recode: [1, 2, 7]}";
+[X, M] = transformencode(target=F1, spec=jspec);
+R1 = scaleMinMax(X);
+
+while(FALSE){}
+
+# solution with scale applied as UDF inside transformencode
+jspec2 = "{ids: true, recode: [1, 2, 7], udf: {name: scaleMinMax, ids: [1, 2,
3, 4, 5, 6, 7, 8, 9]}}";
+[R2, M2] = transformencode(target=F1, spec=jspec2);
+
+while(FALSE){}
+
+R = sum(R1==R2);
+write(R, $R);
+