This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9da2fb8 [SYSTEMDS-3119] Multi-return eval function calls (evalList)
9da2fb8 is described below
commit 9da2fb82dd2ab8857cd42947ad96278980b16eb8
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Dec 29 23:51:20 2021 +0100
[SYSTEMDS-3119] Multi-return eval function calls (evalList)
This patch extends the existing eval() second-order function calls
(which return a single matrix) with evalList, which bundles multiple
returns into a named list. This approach allows reusing all existing
primitives as they are, yet supports better state management in data
cleaning pipelines. In detail, we provide a new language-level
builtin function evalList, but both eval and evalList are parsed
to eval operations, simply with different output types, and at
runtime, we handle the functions accordingly.
Additional changes that showed up during the tests include:
* New rewrite for list indexes (avoid unnecessary instructions)
* Extended rewrite for DAG splits after data-dependent operators
(take persistent writes into consideration to avoid Spark ops)
* Cleanup of right indexing lop construction (old MR code)
* Fix for invalid dimension checks for list indexing
* Fix selected tests for reduced number of expected Spark jobs
---
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../java/org/apache/sysds/hops/IndexingOp.java | 15 ++---
.../hops/ipa/IPAPassReplaceEvalFunctionCalls.java | 5 ++
.../RewriteAlgebraicSimplificationStatic.java | 16 ++++-
.../RewriteSplitDagDataDependentOperators.java | 2 +-
.../java/org/apache/sysds/lops/RightIndex.java | 58 ++++------------
.../sysds/parser/BuiltinFunctionExpression.java | 6 +-
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../org/apache/sysds/parser/StatementBlock.java | 24 +++----
.../controlprogram/FunctionProgramBlock.java | 4 ++
.../instructions/cp/EvalNaryCPInstruction.java | 77 ++++++++++++++--------
.../instructions/cp/VariableCPInstruction.java | 2 +-
.../builtin/part2/BuiltinNormalizeTest.java | 48 +++++++++++++-
.../functions/recompile/RandJobRecompileTest.java | 14 ++--
.../functions/builtin/normalizeListEval.dml | 25 +++++++
.../functions/builtin/normalizeListEvalAll.dml | 30 +++++++++
16 files changed, 219 insertions(+), 109 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 58fced1..c59a76f 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -123,6 +123,7 @@ public enum Builtins {
EXECUTE_PIPELINE("executePipeline", true),
EXP("exp", false),
EVAL("eval", false),
+ EVALLIST("evalList", false),
FIX_INVALID_LENGTHS("fixInvalidLengths", true),
FF_TRAIN("ffTrain", true),
FF_PREDICT("ffPredict", true),
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 17c097e..f10400d 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -26,7 +26,6 @@ import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
-import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.RightIndex;
@@ -140,10 +139,8 @@ public class IndexingOp extends Hop
SparkAggType aggtype =
(method==IndexingMethod.MR_VRIX || isBlockAligned()) ?
SparkAggType.NONE :
SparkAggType.MULTI_BLOCK;
- Lop dummy =
Data.createLiteralLop(ValueType.INT64, Integer.toString(-1));
- RightIndex reindex = new RightIndex(
- input.constructLops(),
getInput().get(1).constructLops(), getInput().get(2).constructLops(),
-
getInput().get(3).constructLops(), getInput().get(4).constructLops(), dummy,
dummy,
+ RightIndex reindex = new
RightIndex(input.constructLops(), getInput(1).constructLops(),
+ getInput(2).constructLops(),
getInput(3).constructLops(), getInput(4).constructLops(),
getDataType(), getValueType(),
aggtype, et);
setOutputDimensions(reindex);
@@ -152,11 +149,9 @@ public class IndexingOp extends Hop
}
else //CP or GPU
{
- Lop dummy =
Data.createLiteralLop(ValueType.INT64, Integer.toString(-1));
- RightIndex reindex = new RightIndex(
- input.constructLops(),
getInput().get(1).constructLops(), getInput().get(2).constructLops(),
-
getInput().get(3).constructLops(), getInput().get(4).constructLops(), dummy,
dummy,
- getDataType(),
getValueType(), et);
+ RightIndex reindex = new
RightIndex(input.constructLops(), getInput(1).constructLops(),
+ getInput(2).constructLops(),
getInput(3).constructLops(), getInput(4).constructLops(),
+ getDataType(), getValueType(),
et);
setOutputDimensions(reindex);
setLineNumbers(reindex);
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java
index f3d5b93..07b6fca 100644
---
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java
+++
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassReplaceEvalFunctionCalls.java
@@ -150,6 +150,11 @@ public class IPAPassReplaceEvalFunctionCalls extends
IPAPass
+ "applicable for replacement,
but list inputs not yet supported.");
continue;
}
+ if( eval.getDataType().isList() ) {
+ LOG.warn("IPA:
eval("+fnamespace+"::"+fname+") "
+ + "applicable for replacement,
but list output not yet supported.");
+ continue;
+ }
if( fstmt.getOutputParams().size() != 1 ||
!fstmt.getOutputParams().get(0).getDataType().isMatrix() ) {
LOG.warn("IPA:
eval("+fnamespace+"::"+fname+") "
+ "applicable for replacement,
but function output is not a matrix.");
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 56854ff..7f658ac 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -165,6 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = fuseBinarySubDAGToUnaryOperation(hop, hi,
i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) ->
selp(X)
hi = simplifyTraceMatrixMult(hop, hi, i);
//e.g., trace(X%*%Y)->sum(X*t(Y));
hi = simplifySlicedMatrixMult(hop, hi, i);
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
+ hi = simplifyListIndexing(hi);
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
hi = simplifyConstantSort(hop, hi, i);
//e.g., order(matrix())->matrix/seq;
hi = simplifyOrderedSort(hop, hi, i);
//e.g., order(matrix())->seq;
hi = fuseOrderOperationChain(hi);
//e.g., order(order(X,2),1) -> order(X,(12))
@@ -1390,12 +1391,23 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
mm.refreshSizeInformation();
hi = mm;
-
- LOG.debug("Applied simplifySlicedMatrixMult");
+
+ LOG.debug("Applied simplifySlicedMatrixMult");
}
return hi;
}
+
+ private static Hop simplifyListIndexing(Hop hi) {
+ //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
+ if( hi instanceof IndexingOp && hi.getDataType().isList()
+ && !(hi.getInput(4) instanceof LiteralOp) )
+ {
+ HopRewriteUtils.replaceChildReference(hi,
hi.getInput(4), new LiteralOp(1));
+ LOG.debug("Applied simplifyListIndexing (line
"+hi.getBeginLine()+").");
+ }
+ return hi;
+ }
private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
{
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index 25dbd78..239ca23 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -230,7 +230,7 @@ public class RewriteSplitDagDataDependentOperators extends
StatementBlockRewrite
return;
//prevent unnecessary dag split (dims known or no consumer
operations)
- boolean noSplitRequired =
(HopRewriteUtils.hasOnlyWriteParents(hop, true, true)
+ boolean noSplitRequired =
(HopRewriteUtils.hasOnlyWriteParents(hop, true, false)
|| hop.dimsKnown() || DMLScript.getGlobalExecMode() ==
ExecMode.SINGLE_NODE);
boolean investigateChilds = true;
diff --git a/src/main/java/org/apache/sysds/lops/RightIndex.java
b/src/main/java/org/apache/sysds/lops/RightIndex.java
index bc2e414..7858638 100644
--- a/src/main/java/org/apache/sysds/lops/RightIndex.java
+++ b/src/main/java/org/apache/sysds/lops/RightIndex.java
@@ -36,46 +36,42 @@ public class RightIndex extends Lop
//optional attribute for spark exec type
private SparkAggType _aggtype = SparkAggType.MULTI_BLOCK;
- public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
Lop rowDim, Lop colDim,
- DataType dt, ValueType vt, ExecType et, boolean forleft)
+ public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
+ DataType dt, ValueType vt, ExecType et, boolean forleft)
{
super(Lop.Type.RightIndex, dt, vt);
- init(input, rowL, rowU, colL, colU, rowDim, colDim, dt, vt, et,
forleft);
+ init(input, rowL, rowU, colL, colU, dt, vt, et, forleft);
}
- public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
Lop rowDim, Lop colDim,
- DataType dt, ValueType vt, ExecType et)
+ public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
+ DataType dt, ValueType vt, ExecType et)
{
super(Lop.Type.RightIndex, dt, vt);
- init(input, rowL, rowU, colL, colU, rowDim, colDim, dt, vt, et,
false);
+ init(input, rowL, rowU, colL, colU, dt, vt, et, false);
}
- public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
Lop rowDim, Lop colDim,
- DataType dt, ValueType vt, SparkAggType aggtype,
ExecType et)
+ public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
+ DataType dt, ValueType vt, SparkAggType aggtype, ExecType et)
{
super(Lop.Type.RightIndex, dt, vt);
_aggtype = aggtype;
- init(input, rowL, rowU, colL, colU, rowDim, colDim, dt, vt, et,
false);
+ init(input, rowL, rowU, colL, colU, dt, vt, et, false);
}
- private void init(Lop inputMatrix, Lop rowL, Lop rowU, Lop colL, Lop
colU, Lop leftMatrixRowDim,
- Lop leftMatrixColDim, DataType dt, ValueType vt,
ExecType et, boolean forleft)
- {
+ private void init(Lop inputMatrix, Lop rowL, Lop rowU, Lop colL, Lop
colU,
+ DataType dt, ValueType vt, ExecType et, boolean forleft)
+ {
addInput(inputMatrix);
addInput(rowL);
addInput(rowU);
addInput(colL);
addInput(colU);
- addInput(leftMatrixRowDim);
- addInput(leftMatrixColDim);
inputMatrix.addOutput(this);
rowL.addOutput(this);
rowU.addOutput(this);
colL.addOutput(this);
colU.addOutput(this);
- leftMatrixRowDim.addOutput(this);
- leftMatrixColDim.addOutput(this);
lps.setProperties(inputs, et);
forLeftIndexing=forleft;
}
@@ -93,7 +89,7 @@ public class RightIndex extends Lop
}
@Override
- public String getInstructions(String input, String rowl, String rowu,
String coll, String colu, String leftRowDim, String leftColDim, String output) {
+ public String getInstructions(String input, String rowl, String rowu,
String coll, String colu, String output) {
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
sb.append( OPERAND_DELIMITOR );
@@ -124,40 +120,14 @@ public class RightIndex extends Lop
//in case of spark, we also compile the optional aggregate flag
into the instruction.
if( getExecType() == ExecType.SPARK ) {
sb.append( OPERAND_DELIMITOR );
- sb.append( _aggtype );
+ sb.append( _aggtype );
}
return sb.toString();
}
@Override
- public String getInstructions(int input_index1, int input_index2, int
input_index3, int input_index4, int input_index5, int input_index6, int
input_index7, int output_index) {
- /*
- * Example: B = A[row_l:row_u, col_l:col_u]
- * A - input matrix (input_index1)
- * row_l - lower bound in row dimension
- * row_u - upper bound in row dimension
- * col_l - lower bound in column dimension
- * col_u - upper bound in column dimension
- *
- * Since row_l,row_u,col_l,col_u are scalars, values for
input_index(2,3,4,5,6,7)
- * will be equal to -1. They should be ignored and the scalar
value labels must
- * be derived from input lops.
- */
- String rowl = getInputs().get(1).prepScalarLabel();
- String rowu = getInputs().get(2).prepScalarLabel();
- String coll = getInputs().get(3).prepScalarLabel();
- String colu = getInputs().get(4).prepScalarLabel();
-
- String left_nrow = getInputs().get(5).prepScalarLabel();
- String left_ncol = getInputs().get(6).prepScalarLabel();
-
- return getInstructions(Integer.toString(input_index1), rowl,
rowu, coll, colu, left_nrow, left_ncol, Integer.toString(output_index));
- }
-
- @Override
public String toString() {
return getOpcode();
}
-
}
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index c2f72a5..e3cb0ee 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -550,11 +550,13 @@ public class BuiltinFunctionExpression extends
DataIdentifier
switch (getOpCode()) {
case EVAL:
+ case EVALLIST:
if (_args.length == 0)
raiseValidateError("Function eval should
provide at least one argument, i.e., the function name.", false);
checkValueTypeParam(_args[0], ValueType.STRING);
- output.setDataType(DataType.MATRIX);
- output.setValueType(ValueType.FP64);
+ boolean listReturn = (getOpCode()==Builtins.EVALLIST);
+ output.setDataType(listReturn ? DataType.LIST :
DataType.MATRIX);
+ output.setValueType(listReturn ? ValueType.UNKNOWN :
ValueType.FP64);
output.setDimensions(-1, -1);
output.setBlocksize(ConfigurationManager.getBlocksize());
break;
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 08c7ebf..50e036d 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2244,6 +2244,7 @@ public class DMLTranslator
switch (source.getOpCode()) {
case EVAL:
+ case EVALLIST:
currBuiltinOp = new NaryOp(target.getName(),
target.getDataType(), target.getValueType(),
OpOpN.EVAL,
processAllExpressions(source.getAllExpr(), hops));
break;
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 570af39..4f8cd1b 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -994,19 +994,21 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
// validate that size of LHS index ranges is being
assigned:
// (a) a matrix value of same size as LHS
- // (b) singleton value (semantics: initialize
enitre submatrix with this value)
+ // (b) singleton value (semantics: initialize
entire submatrix with this value)
IndexPair targetSize =
((IndexedIdentifier)target).calculateIndexedDimensions(ids.getVariables(),
currConstVars, conditional);
- if (targetSize._row >= 1 &&
source.getOutput().getDim1() > 1 && targetSize._row !=
source.getOutput().getDim1()){
- target.raiseValidateError("Dimension mismatch.
Indexed expression " + target.toString() + " can only be assigned matrix with
dimensions "
- + targetSize._row + " rows and
" + targetSize._col + " cols. Attempted to assign matrix with dimensions "
- + source.getOutput().getDim1()
+ " rows and " + source.getOutput().getDim2() + " cols ", conditional);
- }
-
- if (targetSize._col >= 1 &&
source.getOutput().getDim2() > 1 && targetSize._col !=
source.getOutput().getDim2()){
- target.raiseValidateError("Dimension mismatch.
Indexed expression " + target.toString() + " can only be assigned matrix with
dimensions "
- + targetSize._row + " rows and
" + targetSize._col + " cols. Attempted to assign matrix with dimensions "
- + source.getOutput().getDim1()
+ " rows and " + source.getOutput().getDim2() + " cols ", conditional);
+ if( target.getDataType().isMatrixOrFrame() ) {
+ if (targetSize._row >= 1 &&
source.getOutput().getDim1() > 1 && targetSize._row !=
source.getOutput().getDim1()){
+ target.raiseValidateError("Dimension
mismatch. Indexed expression " + target.toString() + " can only be assigned
matrix with dimensions "
+ + targetSize._row + "
rows and " + targetSize._col + " cols. Attempted to assign matrix with
dimensions "
+ +
source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + "
cols ", conditional);
+ }
+
+ if (targetSize._col >= 1 &&
source.getOutput().getDim2() > 1 && targetSize._col !=
source.getOutput().getDim2()){
+ target.raiseValidateError("Dimension
mismatch. Indexed expression " + target.toString() + " can only be assigned
matrix with dimensions "
+ + targetSize._row + "
rows and " + targetSize._col + " cols. Attempted to assign matrix with
dimensions "
+ +
source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + "
cols ", conditional);
+ }
}
((IndexedIdentifier)target).setDimensions(targetSize._row, targetSize._col);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index b23c784..f8cded4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -72,6 +72,10 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
return _inputParams.stream().map(d ->
d.getName()).collect(Collectors.toList());
}
+ public List<String> getOutputParamNames() {
+ return _outputParams.stream().map(d ->
d.getName()).collect(Collectors.toList());
+ }
+
public ArrayList<DataIdentifier> getInputParams(){
return _inputParams;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
index 1f22e98..0bbb242 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.java
@@ -67,7 +67,14 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
- //1. get the namespace and func
+ // There are two main types of eval function calls, which share
most of the
+ // code for lazy function loading and execution:
+ // a) a single-return eval fcall returns a matrix which is
bound to the output
+ // (if the function returns multiple objects, the first one
is used as output)
+ // b) a multi-return eval fcall gets all returns of the
function call and
+ // creates a named list using the names of the function
signature
+
+ //1. get the namespace and function names
String funcName = ec.getScalarInput(inputs[0]).getStringValue();
String nsName = null; //default namespace
if( funcName.contains(Program.KEY_DELIM) ) {
@@ -76,14 +83,13 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
nsName = parts[0];
}
- // bound the inputs to avoiding being deleted after the
function call
+ // bind the inputs to avoid them being deleted after the function
call
CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1,
inputs.length);
- List<String> boundOutputNames = new ArrayList<>();
- boundOutputNames.add(output.getName());
-
+
//2. copy the created output matrix
- MatrixObject outputMO = new
MatrixObject(ec.getMatrixObject(output.getName()));
-
+ MatrixObject outputMO = !output.isMatrix() ? null :
+ new MatrixObject(ec.getMatrixObject(output.getName()));
+
//3. lazy loading of dml-bodied builtin functions (incl. rename
// of function name to dml-bodied builtin scheme
(data-type-specific)
DataType dt1 = boundInputs[0].getDataType().isList() ?
@@ -138,34 +144,51 @@ public class EvalNaryCPInstruction extends
BuiltinNaryCPInstruction {
boundInputs2[i] = new CPOperand(varName, in);
}
boundInputs = boundInputs2;
- lineageInputs = DMLScript.LINEAGE
- ? lo.getLineageItems().toArray(new
LineageItem[lo.getLength()]) : null;
+ lineageInputs = !DMLScript.LINEAGE ? null :
+ lo.getLineageItems().toArray(new
LineageItem[lo.getLength()]);
}
+ // bind the outputs
+ List<String> boundOutputNames = new ArrayList<>();
+ if( output.getDataType().isMatrix() )
+ boundOutputNames.add(output.getName());
+ else //list
+ boundOutputNames.addAll(fpb.getOutputParamNames());
+
//5. call the function (to unoptimized function)
FunctionCallCPInstruction fcpi = new
FunctionCallCPInstruction(nsName, funcName,
false, boundInputs, lineageInputs,
fpb.getInputParamNames(), boundOutputNames, "eval func");
fcpi.processInstruction(ec);
- //6. convert the result to matrix
- Data newOutput = ec.getVariable(output);
- if (!(newOutput instanceof MatrixObject)) {
- MatrixBlock mb = null;
- if (newOutput instanceof ScalarObject) {
- //convert scalar to matrix
- mb = new MatrixBlock(((ScalarObject)
newOutput).getDoubleValue());
- } else if (newOutput instanceof FrameObject) {
- //convert frame to matrix
- mb =
DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
- ec.cleanupCacheableData((FrameObject)
newOutput);
+ //6a. convert the result to matrix
+ if( output.getDataType().isMatrix() ) {
+ Data newOutput = ec.getVariable(output);
+ if (!(newOutput instanceof MatrixObject)) {
+ MatrixBlock mb = null;
+ if (newOutput instanceof ScalarObject) {
+ //convert scalar to matrix
+ mb = new MatrixBlock(((ScalarObject)
newOutput).getDoubleValue());
+ } else if (newOutput instanceof FrameObject) {
+ //convert frame to matrix
+ mb =
DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
+ ec.cleanupCacheableData((FrameObject)
newOutput);
+ }
+ else {
+ throw new DMLRuntimeException("Invalid
eval return type: "+newOutput.getDataType().name()
+ + " (valid:
matrix/frame/scalar; where frames or scalars are converted to output
matrices)");
+ }
+ outputMO.acquireModify(mb);
+ outputMO.release();
+ ec.setVariable(output.getName(), outputMO);
}
- else {
- throw new DMLRuntimeException("Invalid eval
return type: "+newOutput.getDataType().name()
- + " (valid: matrix/frame/scalar; where
frames or scalars are converted to output matrices)");
- }
- outputMO.acquireModify(mb);
- outputMO.release();
- ec.setVariable(output.getName(), outputMO);
+ }
+ //6b. wrap outputs in named list (evalList)
+ else {
+ Data[] ldata = boundOutputNames.stream()
+ .map(n ->
ec.getVariable(n)).toArray(Data[]::new);
+ String[] lnames = boundOutputNames.toArray(new
String[0]);
+ ListObject listOutput = new ListObject(ldata, lnames);
+ ec.setVariable(output.getName(), listOutput);
}
//7. cleanup of variable expanded from list
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index c579c96..c4cf00a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -736,7 +736,7 @@ public class VariableCPInstruction extends CPInstruction
implements LineageTrace
if ( srcData == null ) {
throw new DMLRuntimeException("Unexpected
error: could not find a data object "
- + "for variable name:" +
getInput1().getName() + ", while processing instruction ");
+ + "for variable name: " +
getInput1().getName() + ", while processing instruction ");
}
// remove existing variable bound to target name and
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinNormalizeTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinNormalizeTest.java
index 6bc7028..e1f1aa1 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinNormalizeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinNormalizeTest.java
@@ -36,6 +36,8 @@ public class BuiltinNormalizeTest extends AutomatedTestBase
{
private final static String TEST_NAME = "normalize";
private final static String TEST_NAME2 = "normalizeAll";
+ private final static String TEST_NAME3 = "normalizeListEval";
+ private final static String TEST_NAME4 = "normalizeListEvalAll";
private final static String TEST_DIR = "functions/builtin/";
private static final String TEST_CLASS_DIR = TEST_DIR +
BuiltinNormalizeTest.class.getSimpleName() + "/";
@@ -48,6 +50,7 @@ public class BuiltinNormalizeTest extends AutomatedTestBase
@Override
public void setUp() {
+ //only needed for directory here
addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
}
@@ -91,6 +94,46 @@ public class BuiltinNormalizeTest extends AutomatedTestBase
runNormalizeTest(TEST_NAME2, true, ExecType.SPARK);
}
+ @Test
+ public void testNormalizeListEvalMatrixDenseCP() {
+ runNormalizeTest(TEST_NAME3, false, ExecType.CP);
+ }
+
+ @Test
+ public void testNormalizeListEvalMatrixSparseCP() {
+ runNormalizeTest(TEST_NAME3, true, ExecType.CP);
+ }
+
+ @Test
+ public void testNormalizeListEvalMatrixDenseSP() {
+ runNormalizeTest(TEST_NAME3, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testNormalizeListEvalMatrixSparseSP() {
+ runNormalizeTest(TEST_NAME3, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testNormalizeListEval2MatrixDenseCP() {
+ runNormalizeTest(TEST_NAME4, false, ExecType.CP);
+ }
+
+ @Test
+ public void testNormalizeListEval2MatrixSparseCP() {
+ runNormalizeTest(TEST_NAME4, true, ExecType.CP);
+ }
+
+ @Test
+ public void testNormalizeListEval2MatrixDenseSP() {
+ runNormalizeTest(TEST_NAME4, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testNormalizeListEval2MatrixSparseSP() {
+ runNormalizeTest(TEST_NAME4, true, ExecType.SPARK);
+ }
+
private void runNormalizeTest(String testname, boolean sparse, ExecType
instType)
{
ExecMode platformOld = setExecMode(instType);
@@ -102,7 +145,7 @@ public class BuiltinNormalizeTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{"-args", input("A"),
output("B") };
+ programArgs = new String[]{"-explain","-args",
input("A"), output("B") };
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
@@ -120,7 +163,8 @@ public class BuiltinNormalizeTest extends AutomatedTestBase
//check number of compiler Spark instructions
if( instType == ExecType.CP ) {
- Assert.assertEquals(1,
Statistics.getNoOfCompiledSPInst()); //reblock
+ int expected = testname.equals(TEST_NAME4) ? 2
: 1;
+ Assert.assertEquals(expected,
Statistics.getNoOfCompiledSPInst()); //reblock, [write]
Assert.assertEquals(0,
Statistics.getNoOfExecutedSPInst());
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/recompile/RandJobRecompileTest.java
b/src/test/java/org/apache/sysds/test/functions/recompile/RandJobRecompileTest.java
index 4231481..ed9174f 100644
---
a/src/test/java/org/apache/sysds/test/functions/recompile/RandJobRecompileTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/recompile/RandJobRecompileTest.java
@@ -42,23 +42,19 @@ public class RandJobRecompileTest extends AutomatedTestBase
@Override
- public void setUp()
- {
+ public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME,
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new
String[] { "Z" }) );
}
-
@Test
- public void testRandRecompileNoEstSizeEval()
- {
+ public void testRandRecompileNoEstSizeEval() {
runRandJobRecompileTest(false);
}
@Test
- public void testRandRecompilEstSizeEval()
- {
+ public void testRandRecompileEstSizeEval() {
runRandJobRecompileTest(true);
}
@@ -74,7 +70,7 @@ public class RandJobRecompileTest extends AutomatedTestBase
/* This is for running the junit test the new way,
i.e., construct the arguments directly */
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-args", input("X"),
Integer.toString(rows), output("Z") };
+ programArgs = new String[]{"-explain","-args",
input("X"), Integer.toString(rows), output("Z") };
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + expectedDir();
@@ -94,7 +90,7 @@ public class RandJobRecompileTest extends AutomatedTestBase
//check expected number of compiled and executed Spark
jobs
int expectedNumCompiled = (estSizeEval?1:3); //rbl,
rand, write
- int expectedNumExecuted = (estSizeEval?0:1); //write
+ int expectedNumExecuted = 0;
checkNumCompiledSparkInst(expectedNumCompiled);
checkNumExecutedSparkInst(expectedNumExecuted);
diff --git a/src/test/scripts/functions/builtin/normalizeListEval.dml
b/src/test/scripts/functions/builtin/normalizeListEval.dml
new file mode 100644
index 0000000..3a7e56f
--- /dev/null
+++ b/src/test/scripts/functions/builtin/normalizeListEval.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+L = evalList("normalize", X);
+Y = as.matrix(L[1]);
+write(Y, $2);
diff --git a/src/test/scripts/functions/builtin/normalizeListEvalAll.dml
b/src/test/scripts/functions/builtin/normalizeListEvalAll.dml
new file mode 100644
index 0000000..461338a
--- /dev/null
+++ b/src/test/scripts/functions/builtin/normalizeListEvalAll.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+L = evalList("normalize", X);
+# if the output signature of normalize would match
+# the input signature of normalizeApply, direct hand-over possible
+[L, Y] = remove(L, 1);
+L = append(L, list(X=X));
+Y = eval("normalizeApply", L);
+
+write(Y, $2);