This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 48a10cb [SYSTEMDS-3204] Frame map operations w/ margin parameter
48a10cb is described below
commit 48a10cb07f42da54da08dfb931facbca4c13206d
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Sat Dec 18 22:14:16 2021 +0100
[SYSTEMDS-3204] Frame map operations w/ margin parameter
Closes #1440.
---
docs/site/dml-language-reference.md | 2 +-
scripts/pipelines/scripts/utils.dml | 6 +-
src/main/java/org/apache/sysds/common/Types.java | 10 +-
src/main/java/org/apache/sysds/hops/BinaryOp.java | 3 -
src/main/java/org/apache/sysds/hops/TernaryOp.java | 36 +++++--
src/main/java/org/apache/sysds/lops/Binary.java | 5 +-
.../sysds/parser/BuiltinFunctionExpression.java | 8 +-
.../org/apache/sysds/parser/DMLTranslator.java | 6 +-
.../runtime/instructions/CPInstructionParser.java | 2 +-
.../runtime/instructions/SPInstructionParser.java | 2 +-
.../instructions/cp/BinaryCPInstruction.java | 2 -
.../instructions/cp/TernaryCPInstruction.java | 7 +-
...n.java => TernaryFrameScalarCPInstruction.java} | 13 +--
.../instructions/fed/BinaryFEDInstruction.java | 2 -
.../instructions/fed/FEDInstructionUtils.java | 31 +++---
.../instructions/fed/TernaryFEDInstruction.java | 8 +-
....java => TernaryFrameScalarFEDInstruction.java} | 10 +-
.../instructions/spark/BinarySPInstruction.java | 5 +-
...n.java => TernaryFrameScalarSPInstruction.java} | 23 +++--
.../instructions/spark/TernarySPInstruction.java | 7 +-
.../sysds/runtime/matrix/data/FrameBlock.java | 67 ++++++++----
.../apache/sysds/runtime/util/UtilFunctions.java | 56 ++++++++---
.../functions/binary/frame/FrameMapMarginTest.java | 112 +++++++++++++++++++++
src/test/scripts/functions/binary/frame/map.dml | 2 +-
.../binary/frame/{map.dml => mapMargin.dml} | 18 ++--
.../functions/federated/FederatedFrameMapTest.dml | 2 +-
.../federated/FederatedFrameMapTestReference.dml | 2 +-
27 files changed, 326 insertions(+), 121 deletions(-)
diff --git a/docs/site/dml-language-reference.md
b/docs/site/dml-language-reference.md
index 31d3aba..3eeda3f 100644
--- a/docs/site/dml-language-reference.md
+++ b/docs/site/dml-language-reference.md
@@ -2053,7 +2053,7 @@ The following example uses <code>transformapply()</code>
with the input matrix a
Function | Description | Parameters | Example
-------- | ----------- | ---------- | -------
-map() | It will execute the given lambda expression on a frame.| Input: (X
<frame>, y <String>) <br/>Output: <frame>. <br/> X is a frame
and <br/>y is a String containing the lambda expression to be executed on frame
X. | [map](#map)
+map() | It will execute the given lambda expression on a frame (cell-, row-, or
column-wise). | Input: (X <frame>, y <String>, \[margin
<int>\]) <br/>Output: <frame>. <br/> X is a frame and <br/>y is a
String containing the lambda expression to be executed on frame X. <br/>
margin - how to apply the lambda expression (0 indicates each cell, 1 - rows,
2 - columns). Output frame dimensions are always equal to the input. | [map](#map)
tokenize() | Transforms a frame to tokenized frame using specification.
Tokenization is valid only for string columns. | Input:<br/> target =
<frame> <br/> spec = <json specification> <br/> Outputs:
<matrix>, <frame> | [tokenize](#tokenize)
#### map
diff --git a/scripts/pipelines/scripts/utils.dml
b/scripts/pipelines/scripts/utils.dml
index 97de7e7..14ff36e 100644
--- a/scripts/pipelines/scripts/utils.dml
+++ b/scripts/pipelines/scripts/utils.dml
@@ -174,8 +174,8 @@ return(Frame[Unknown] processedData, Matrix[Double] M)
for(i in 1:ncol(mask))
if(as.scalar(schema[1,i]) == "STRING")
data[, i] = map(data[, i], "x -> x.toLowerCase()")
-
- # step 5 typo correction
+
+ # step 5 typo correction
if(CorrectTypos)
{
# recode data to get null mask
@@ -197,7 +197,7 @@ return(Frame[Unknown] processedData, Matrix[Double] M)
}
# step 6 porter stemming on all features
print(prefix+" porter-stemming on all features");
- data = map(data, "x -> PorterStemmer.stem(x)")
+ data = map(data, "x -> PorterStemmer.stem(x)", 0)
# TODO add deduplication
print(prefix+" deduplication via entity resolution");
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 21a874e..85a11e4 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -310,7 +310,7 @@ public class Types
BITWXOR(true), CBIND(false), CONCAT(false), COV(false),
DIV(true),
DROP_INVALID_TYPE(false), DROP_INVALID_LENGTH(false),
EQUAL(true), GREATER(true),
GREATEREQUAL(true), INTDIV(true), INTERQUANTILE(false),
IQM(false), LESS(true),
- LESSEQUAL(true), LOG(true), MAP(false), MAX(true),
MEDIAN(false), MIN(true),
+ LESSEQUAL(true), LOG(true), MAX(true), MEDIAN(false), MIN(true),
MINUS(true), MODULUS(true), MOMENT(false), MULT(true),
NOTEQUAL(true), OR(true),
PLUS(true), POW(true), PRINT(false), QUANTILE(false),
SOLVE(false),
RBIND(false), VALUE_SWAP(false), XOR(true),
@@ -359,7 +359,6 @@ public class Types
case DROP_INVALID_TYPE: return
"dropInvalidType";
case DROP_INVALID_LENGTH: return
"dropInvalidLength";
case VALUE_SWAP: return "valueSwap";
- case MAP: return "_map";
default: return name().toLowerCase();
}
}
@@ -393,8 +392,7 @@ public class Types
case "bitwShiftR": return BITWSHIFTR;
case "dropInvalidType": return
DROP_INVALID_TYPE;
case "dropInvalidLength": return
DROP_INVALID_LENGTH;
- case "valueSwap": return VALUE_SWAP;
- case "map": return MAP;
+ case "valueSwap": return VALUE_SWAP;
default: return
valueOf(opcode.toUpperCase());
}
}
@@ -402,7 +400,7 @@ public class Types
// Operations that require 3 operands
public enum OpOp3 {
- QUANTILE, INTERQUANTILE, CTABLE, MOMENT, COV, PLUS_MULT,
MINUS_MULT, IFELSE;
+ QUANTILE, INTERQUANTILE, CTABLE, MOMENT, COV, PLUS_MULT,
MINUS_MULT, IFELSE, MAP;
@Override
public String toString() {
@@ -410,6 +408,7 @@ public class Types
case MOMENT: return "cm";
case PLUS_MULT: return "+*";
case MINUS_MULT: return "-*";
+ case MAP: return "_map";
default: return name().toLowerCase();
}
}
@@ -419,6 +418,7 @@ public class Types
case "cm": return MOMENT;
case "+*": return PLUS_MULT;
case "-*": return MINUS_MULT;
+ case "map": return MAP;
default: return valueOf(opcode.toUpperCase());
}
}
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index de6ccdc..73deda4 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1076,9 +1076,6 @@ public class BinaryOp extends MultiThreadedHop {
if( !(that instanceof BinaryOp) )
return false;
- if(op == OpOp2.MAP)
- return false; // custom UDFs
-
BinaryOp that2 = (BinaryOp)that;
return ( op == that2.op
&& outer == that2.outer
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 706a319..00bf2e1 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -182,9 +182,9 @@ public class TernaryOp extends MultiThreadedHop
case PLUS_MULT:
case MINUS_MULT:
case IFELSE:
+ case MAP:
constructLopsTernaryDefault();
break;
-
default:
throw new
HopsException(this.printErrorLocation() + "Unknown TernaryOp (" + _op + ")
while constructing Lops \n");
@@ -375,6 +375,7 @@ public class TernaryOp extends MultiThreadedHop
return
OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
case PLUS_MULT:
case MINUS_MULT:
+ case MAP:
case IFELSE: {
if (isGPUEnabled()) {
// For the GPU, the input is converted
to dense
@@ -419,6 +420,15 @@ public class TernaryOp extends MultiThreadedHop
switch( _op )
{
+ case MAP:
+ long ldim1 = (mc[0].rowsKnown()) ?
mc[0].getRows() :
+ (mc[1].getRows()>=0) ? mc[1].getRows()
: -1;
+ long ldim2 = (mc[0].colsKnown()) ?
mc[0].getCols() :
+ (mc[1].getCols()>=0) ? mc[1].getCols()
: -1;
+ if( ldim1>=0 && ldim2>=0 )
+ ret = new MatrixCharacteristics(ldim1,
ldim2, -1, (long) (ldim1 * ldim2 * 1.0));
+ return ret;
+
case CTABLE:
boolean dimsSpec = (getInput().size() > 3);
@@ -515,20 +525,31 @@ public class TernaryOp extends MultiThreadedHop
@Override
public void refreshSizeInformation()
{
+ Hop input1 = getInput().get(0);
+ Hop input2 = getInput().get(1);
+ Hop input3 = getInput().get(2);
+
if ( getDataType() == DataType.SCALAR )
{
//do nothing always known
}
- else
+ else
{
switch( _op )
{
+ case MAP:
+ long ldim1, ldim2, lnnz1 = -1;
+ ldim1 = (input1.rowsKnown()) ?
input1.getDim1() : ((input2.getDim1()>=0)?input2.getDim1():-1);
+ ldim2 = (input1.colsKnown()) ?
input1.getDim2() : ((input2.getDim2()>=0)?input2.getDim2():-1);
+ lnnz1 = input1.getNnz();
+
+ setDim1( ldim1 );
+ setDim2( ldim2 );
+ setNnz(lnnz1);
+ break;
case CTABLE:
//in general, do nothing because the
output size is data dependent
- Hop input1 = getInput().get(0);
- Hop input2 = getInput().get(1);
- Hop input3 = getInput().get(2);
-
+
//TODO double check reset
(dimsInputPresent?)
if ( !dimsKnown() ) {
//for ctable_expand at least
one dimension is known
@@ -600,6 +621,9 @@ public class TernaryOp extends MultiThreadedHop
return false;
TernaryOp that2 = (TernaryOp)that;
+
+ if(_op == OpOp3.MAP)
+ return false; // custom UDFs
//compare basic inputs and weights (always existing)
boolean ret = (_op == that2._op
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java
b/src/main/java/org/apache/sysds/lops/Binary.java
index 202dce7..6949188 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -80,10 +80,7 @@ public class Binary extends Lop
return null;
ArrayList<Lop> inputs = getInputs();
- if (operation == OpOp2.MAP && inputs.get(0).getDataType() ==
DataType.MATRIX
- && inputs.get(1).getDataType() ==
DataType.MATRIX)
- return inputs.get(1);
- else if (inputs.get(0).getDataType() == DataType.FRAME &&
inputs.get(1).getDataType() == DataType.MATRIX)
+ if (inputs.get(0).getDataType() == DataType.FRAME &&
inputs.get(1).getDataType() == DataType.MATRIX)
return inputs.get(1);
else
return null;
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index cc70396..c2f72a5 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1565,20 +1565,20 @@ public class BuiltinFunctionExpression extends
DataIdentifier
break;
case MAP:
- checkNumParameters(2);
+ checkNumParameters(getThirdExpr() != null ? 3 : 2);
checkMatrixFrameParam(getFirstExpr());
checkScalarParam(getSecondExpr());
+ if(getThirdExpr() != null)
+ checkScalarParam(getThirdExpr()); // margin
output.setDataType(DataType.FRAME);
if(_args[1].getText().contains("jaccardSim")) {
output.setDimensions(id.getDim1(),
id.getDim1());
output.setValueType(ValueType.FP64);
}
else {
- output.setDimensions(id.getDim1(), 1);
+ output.setDimensions(id.getDim1(),
id.getDim2());
output.setValueType(ValueType.STRING);
}
- output.setBlocksize (id.getBlocksize());
-
break;
case LOCAL:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_LOCAL_COMMAND){
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index dc51d46..08c7ebf 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2540,10 +2540,14 @@ public class DMLTranslator
case DROP_INVALID_TYPE:
case DROP_INVALID_LENGTH:
case VALUE_SWAP:
- case MAP:
currBuiltinOp = new BinaryOp(target.getName(),
target.getDataType(),
target.getValueType(),
OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
break;
+ case MAP:
+ currBuiltinOp = new TernaryOp(target.getName(),
target.getDataType(),
+ target.getValueType(),
OpOp3.valueOf(source.getOpCode().name()),
+ expr, expr2, (expr3==null) ? new LiteralOp(0L)
: expr3);
+ break;
case LOG:
if (expr2 == null) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 22cceea..c08985b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -162,7 +162,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "dropInvalidType" ,
CPType.Binary);
String2CPInstructionType.put( "dropInvalidLength" ,
CPType.Binary);
String2CPInstructionType.put( "valueSwap" , CPType.Binary);
- String2CPInstructionType.put( "_map" , CPType.Binary); // _map
represents the operation map
+ String2CPInstructionType.put( "_map" , CPType.Ternary); //
_map represents the operation map
String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 85abf32..56cd49a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -183,7 +183,7 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "dropInvalidType", SPType.Binary);
String2SPInstructionType.put( "mapdropInvalidLength",
SPType.Binary);
String2SPInstructionType.put( "valueSwap", SPType.Binary);
- String2SPInstructionType.put( "_map", SPType.Binary); // _map
refers to the operation map
+ String2SPInstructionType.put( "_map", SPType.Ternary); // _map
refers to the operation map
// Relational Instruction Opcodes
String2SPInstructionType.put( "==" , SPType.Binary);
String2SPInstructionType.put( "!=" , SPType.Binary);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index 188b2ac..4d83d6d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -57,8 +57,6 @@ public abstract class BinaryCPInstruction extends
ComputationCPInstruction {
return new BinaryFrameFrameCPInstruction(operator, in1,
in2, out, opcode, str);
else if (in1.getDataType() == DataType.FRAME &&
in2.getDataType() == DataType.MATRIX)
return new BinaryFrameMatrixCPInstruction(operator,
in1, in2, out, opcode, str);
- else if (in1.getDataType() == DataType.FRAME &&
in2.getDataType() == DataType.SCALAR)
- return new BinaryFrameScalarCPInstruction(operator,
in1, in2, out, opcode, str);
else
return new BinaryMatrixScalarCPInstruction(operator,
in1, in2, out, opcode, str);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
index 14d0090..86733fb 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
@@ -26,7 +26,7 @@ import
org.apache.sysds.runtime.matrix.operators.TernaryOperator;
public class TernaryCPInstruction extends ComputationCPInstruction {
- private TernaryCPInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+ protected TernaryCPInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
super(CPType.Ternary, op, in1, in2, in3, out, opcode, str);
}
@@ -40,7 +40,10 @@ public class TernaryCPInstruction extends
ComputationCPInstruction {
CPOperand outOperand = new CPOperand(parts[4]);
int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) :
1;
TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode, numThreads);
- return new TernaryCPInstruction(op, operand1, operand2,
operand3, outOperand, opcode,str);
+ if(operand1.isFrame() && operand2.isScalar() &&
opcode.contains("map"))
+ return new TernaryFrameScalarCPInstruction(op,
operand1, operand2, operand3, outOperand, opcode, str);
+ else
+ return new TernaryCPInstruction(op, operand1, operand2,
operand3, outOperand, opcode,str);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryFrameScalarCPInstruction.java
similarity index 75%
rename from
src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
rename to
src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryFrameScalarCPInstruction.java
index bcf7cb5..6512343 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryFrameScalarCPInstruction.java
@@ -21,22 +21,23 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
-public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction
+public class TernaryFrameScalarCPInstruction extends TernaryCPInstruction
{
- protected BinaryFrameScalarCPInstruction(Operator op, CPOperand in1,
- CPOperand in2, CPOperand out, String opcode, String
istr) {
- super(CPType.Binary, op, in1, in2, out, opcode, istr);
+ protected TernaryFrameScalarCPInstruction(TernaryOperator op, CPOperand
in1,
+ CPOperand in2, CPOperand in3, CPOperand out, String
opcode, String istr) {
+ super(op, in1, in2, in3, out, opcode, istr);
}
@Override
public void processInstruction(ExecutionContext ec) {
// get input frames
FrameBlock inBlock = ec.getFrameInput(input1.getName());
+ ScalarObject margin = ec.getScalarInput(input3);
String stringExpression =
ec.getScalarInput(input2).getStringValue();
//compute results
- FrameBlock outBlock = inBlock.map(stringExpression);
+ FrameBlock outBlock = inBlock.map(stringExpression,
margin.getLongValue());
// Attach result frame with FrameBlock associated with
output_name
ec.setFrameOutput(output.getName(), outBlock);
// Release the memory occupied by input frames
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index c2f07f7..5e1d35f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -72,8 +72,6 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
throw new DMLRuntimeException("Federated binary tensor
tensor operations not yet supported");
else if( in1.isMatrix() && in2.isScalar() || in2.isMatrix() &&
in1.isScalar() )
return new BinaryMatrixScalarFEDInstruction(operator,
in1, in2, out, opcode, str, fedOut);
- else if( in1.isFrame() && in2.isScalar() || in2.isFrame() &&
in1.isScalar() )
- return new BinaryFrameScalarFEDInstruction(operator,
in1, in2, out, opcode, InstructionUtils.removeFEDOutputFlag(str));
else
throw new DMLRuntimeException("Federated binary
operations not yet supported:" + opcode);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 12965f6..d46d493 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -37,7 +37,6 @@ import
org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
@@ -51,6 +50,7 @@ import
org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import
org.apache.sysds.runtime.instructions.cp.TernaryFrameScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -62,7 +62,6 @@ import
org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
-import
org.apache.sysds.runtime.instructions.spark.BinaryFrameScalarSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
@@ -85,6 +84,7 @@ import
org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.TernaryFrameScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
@@ -187,9 +187,6 @@ public class FEDInstructionUtils {
else
fedinst =
BinaryFEDInstruction.parseInstruction(
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
- } else if(inst.getOpcode().equals("_map") && inst
instanceof BinaryFrameScalarCPInstruction &&
!inst.getInstructionString().contains("UtilFunctions")
- && instruction.input1.isFrame() &&
ec.getFrameObject(instruction.input1).isFederated()) {
- fedinst =
BinaryFrameScalarFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
}
}
else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
@@ -217,7 +214,15 @@ public class FEDInstructionUtils {
}
else if(inst instanceof TernaryCPInstruction) {
TernaryCPInstruction tinst = (TernaryCPInstruction)
inst;
- if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
+ if(inst.getOpcode().equals("_map") && inst instanceof
TernaryFrameScalarCPInstruction &&
!inst.getInstructionString().contains("UtilFunctions")
+ && tinst.input1.isFrame() &&
ec.getFrameObject(tinst.input1).isFederated()) {
+ long margin =
ec.getScalarInput(tinst.input3).getLongValue();
+ FrameObject fo =
ec.getFrameObject(tinst.input1);
+ if(margin == 0 || (fo.isFederated(FType.ROW) &&
margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
+ fedinst =
TernaryFrameScalarFEDInstruction
+
.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+ }
+ else if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
|| (tinst.input2.isMatrix() &&
ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
|| (tinst.input3.isMatrix() &&
ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
fedinst =
TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
@@ -435,11 +440,6 @@ public class FEDInstructionUtils {
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
}
}
- else if(inst.getOpcode().equals("_map") && inst
instanceof BinaryFrameScalarSPInstruction &&
!inst.getInstructionString().contains("UtilFunctions")
- && instruction.input1.isFrame() &&
ec.getFrameObject(instruction.input1).isFederated()) {
- fedinst =
BinaryFrameScalarFEDInstruction.parseInstruction(InstructionUtils
-
.concatOperands(inst.getInstructionString(), FederatedOutput.NONE.name()));
- }
else if( (instruction.input1.isMatrix() &&
ec.getCacheableData(instruction.input1).isFederatedExcept(FType.BROADCAST))
|| (instruction.input2.isMatrix() &&
ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
if("cov".equals(instruction.getOpcode()) &&
(ec.getMatrixObject(instruction.input1)
@@ -476,7 +476,14 @@ public class FEDInstructionUtils {
}
else if(inst instanceof TernarySPInstruction) {
TernarySPInstruction tinst = (TernarySPInstruction)
inst;
- if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
+ if(inst.getOpcode().equals("_map") && inst instanceof
TernaryFrameScalarSPInstruction &&
!inst.getInstructionString().contains("UtilFunctions")
+ && tinst.input1.isFrame() &&
ec.getFrameObject(tinst.input1).isFederated()) {
+ long margin =
ec.getScalarInput(tinst.input3).getLongValue();
+ FrameObject fo =
ec.getFrameObject(tinst.input1);
+ if(margin == 0 || (fo.isFederated(FType.ROW) &&
margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
+ fedinst =
TernaryFrameScalarFEDInstruction
+
.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+ } else if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
|| (tinst.input2.isMatrix() &&
ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
|| (tinst.input3.isMatrix() &&
ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
fedinst =
TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 954c4c4..cb8d9fa 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -38,7 +38,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class TernaryFEDInstruction extends ComputationFEDInstruction {
- private TernaryFEDInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
+ protected TernaryFEDInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String str, FederatedOutput fedOut) {
super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out,
opcode, str, fedOut);
}
@@ -50,9 +50,11 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
CPOperand operand2 = new CPOperand(parts[2]);
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
- int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) :
1;
- FederatedOutput fedOut = parts.length>7 ?
FederatedOutput.valueOf(parts[6]) : FederatedOutput.NONE;
+ int numThreads = parts.length>5 & !opcode.contains("map") ?
Integer.parseInt(parts[5]) : 1;
+ FederatedOutput fedOut = parts.length>7 &&
!opcode.contains("map") ? FederatedOutput.valueOf(parts[6]) :
FederatedOutput.NONE;
TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode, numThreads);
+ if( operand1.isFrame() && operand2.isScalar() ||
operand2.isFrame() && operand1.isScalar() )
+ return new TernaryFrameScalarFEDInstruction(op,
operand1, operand2, operand3, outOperand, opcode,
InstructionUtils.removeFEDOutputFlag(str), fedOut);
return new TernaryFEDInstruction(op, operand1, operand2,
operand3, outOperand, opcode, str, fedOut);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFrameScalarFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFrameScalarFEDInstruction.java
similarity index 83%
rename from
src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFrameScalarFEDInstruction.java
rename to
src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFrameScalarFEDInstruction.java
index ecb3a45..6205d20 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFrameScalarFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFrameScalarFEDInstruction.java
@@ -25,13 +25,13 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
-public class BinaryFrameScalarFEDInstruction extends BinaryFEDInstruction
+public class TernaryFrameScalarFEDInstruction extends TernaryFEDInstruction
{
- protected BinaryFrameScalarFEDInstruction(Operator op, CPOperand in1,
- CPOperand in2, CPOperand out, String opcode, String
istr) {
- super(FEDInstruction.FEDType.Binary, op, in1, in2, out, opcode,
istr);
+ protected TernaryFrameScalarFEDInstruction(TernaryOperator op,
CPOperand in1,
+ CPOperand in2, CPOperand in3, CPOperand out, String
opcode, String istr, FederatedOutput fedOut) {
+ super(op, in1, in2, in3, out, opcode, istr, fedOut);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index 82a95a8..16196d5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -111,9 +111,8 @@ public abstract class BinarySPInstruction extends
ComputationSPInstruction {
else if( dt1 == DataType.FRAME || dt2 == DataType.FRAME ) {
if(dt1 == DataType.FRAME && dt2 == DataType.FRAME)
return new
BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
- if(dt1 == DataType.FRAME && dt2 == DataType.SCALAR)
- return new
BinaryFrameScalarSPInstruction(operator, in1, in2, out, opcode, str);
-
+ else if(dt1 == DataType.FRAME && dt2 == DataType.SCALAR
&& opcode.equalsIgnoreCase("+"))
+ return new
BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
}
return null;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernaryFrameScalarSPInstruction.java
similarity index 74%
rename from
src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
rename to
src/main/java/org/apache/sysds/runtime/instructions/spark/TernaryFrameScalarSPInstruction.java
index b5cf078..d609515 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernaryFrameScalarSPInstruction.java
@@ -25,12 +25,12 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
-public class BinaryFrameScalarSPInstruction extends BinarySPInstruction {
- protected BinaryFrameScalarSPInstruction (Operator op, CPOperand in1,
CPOperand in2, CPOperand out,
+public class TernaryFrameScalarSPInstruction extends TernarySPInstruction {
+ protected TernaryFrameScalarSPInstruction(TernaryOperator op, CPOperand
in1, CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String istr) {
- super(SPType.Binary, op, in1, in2, out, opcode, istr);
+ super(op, in1, in2, in3, out, opcode, istr);
}
@Override
@@ -40,15 +40,19 @@ public class BinaryFrameScalarSPInstruction extends
BinarySPInstruction {
// Get input RDDs
JavaPairRDD<Long, FrameBlock> in1 =
sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
String expression = sec.getScalarInput(input2).getStringValue();
+ long margin = ec.getScalarInput(input3).getLongValue();
// Create local compiled functions (once) and execute on RDD
- JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new
RDDStringProcessing(expression));
+ JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new
RDDStringProcessing(expression, margin));
if(expression.contains("jaccardSim")) {
long rows =
sec.getDataCharacteristics(output.getName()).getRows();
sec.getDataCharacteristics(output.getName()).setDimension(rows, rows);
+ } else {
+ long rows =
sec.getDataCharacteristics(output.getName()).getRows();
+ long cols =
sec.getDataCharacteristics(output.getName()).getCols();
+
sec.getDataCharacteristics(output.getName()).setDimension(rows, cols);
}
-
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), input1.getName());
}
@@ -57,14 +61,17 @@ public class BinaryFrameScalarSPInstruction extends
BinarySPInstruction {
private static final long serialVersionUID =
5850400295183766400L;
private String _expr = null;
+ private long _margin = -1;
- public RDDStringProcessing(String expr) {
+ public RDDStringProcessing(String expr, long margin) {
_expr = expr;
+ _margin = margin;
}
@Override
public FrameBlock call(FrameBlock arg0) throws Exception {
- return arg0.map(_expr);
+ FrameBlock fb = arg0.map(_expr, _margin);
+ return fb;
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
index 2f4c129..fa1dce2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
@@ -34,7 +34,7 @@ import scala.Tuple2;
import java.io.Serializable;
public class TernarySPInstruction extends ComputationSPInstruction {
- private TernarySPInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+ protected TernarySPInstruction(TernaryOperator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
super(SPType.Ternary, op, in1, in2, in3, out, opcode, str);
}
@@ -46,7 +46,10 @@ public class TernarySPInstruction extends
ComputationSPInstruction {
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
TernaryOperator op =
InstructionUtils.parseTernaryOperator(opcode);
- return new TernarySPInstruction(op,operand1, operand2,
operand3, outOperand, opcode,str);
+ if(operand1.isFrame() && operand2.isScalar() &&
opcode.contains("map"))
+ return new TernaryFrameScalarSPInstruction(op,
operand1, operand2, operand3, outOperand, opcode, str);
+ else
+ return new TernarySPInstruction(op, operand1, operand2,
operand3, outOperand, opcode,str);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index f4f9beb..751d24e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -2320,7 +2320,7 @@ public class FrameBlock implements CacheBlock,
Externalizable {
}
}
- public FrameBlock map (String lambdaExpr){
+ public FrameBlock map (String lambdaExpr, long margin){
if(!lambdaExpr.contains("->")) {
String args =
lambdaExpr.substring(lambdaExpr.indexOf('(') + 1, lambdaExpr.indexOf(')'));
if(args.contains(",")) {
@@ -2333,8 +2333,8 @@ public class FrameBlock implements CacheBlock,
Externalizable {
}
}
if(lambdaExpr.contains("jaccardSim"))
- return mapDist(getCompiledFunction(lambdaExpr));
- return map(getCompiledFunction(lambdaExpr));
+ return mapDist(getCompiledFunction(lambdaExpr, margin));
+ return map(getCompiledFunction(lambdaExpr, margin), margin);
}
public FrameBlock valueSwap(FrameBlock schema) {
@@ -2418,17 +2418,38 @@ public class FrameBlock implements CacheBlock,
Externalizable {
return this;
}
- public FrameBlock map (FrameMapFunction lambdaExpr) {
+ public FrameBlock map (FrameMapFunction lambdaExpr, long margin) {
// Prepare temporary output array
String[][] output = new String[getNumRows()][getNumColumns()];
- // Execute map function on all cells
- for(int j = 0; j < getNumColumns(); j++) {
- Array input = getColumn(j);
- for(int i = 0; i < input._size; i++)
- if(input.get(i) != null)
- output[i][j] =
lambdaExpr.apply(String.valueOf(input.get(i)));
- }
+ if (margin == 1) {
+ // Execute map function on rows
+ for(int i = 0; i < getNumRows(); i++) {
+ String[] row = new String[getNumColumns()];
+ for(int j = 0; j < getNumColumns(); j++) {
+ Array input = getColumn(j);
+ row[j] = String.valueOf(input.get(i));
+ }
+ output[i] = lambdaExpr.apply(row);
+ }
+ } else if (margin == 2) {
+ // Execute map function on columns
+ for(int j = 0; j < getNumColumns(); j++) {
+ String[] actualColumn =
Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more
rows can be allocated, mutable array
+ String[] outColumn =
lambdaExpr.apply(actualColumn);
+
+ for(int i = 0; i < getNumRows(); i++)
+ output[i][j] = outColumn[i];
+ }
+ } else {
+ // Execute map function on all cells
+ for(int j = 0; j < getNumColumns(); j++) {
+ Array input = getColumn(j);
+ for(int i = 0; i < input._size; i++)
+ if(input.get(i) != null)
+ output[i][j] =
lambdaExpr.apply(String.valueOf(input.get(i)));
+ }
+ }
return new FrameBlock(UtilFunctions.nCopies(getNumColumns(),
ValueType.STRING), output);
}
@@ -2441,13 +2462,12 @@ public class FrameBlock implements CacheBlock,
Externalizable {
for(int i = j + 1; i < input._size; i++)
if(input.get(i) != null && input.get(j) !=
null) {
output[j][i] =
lambdaExpr.apply(String.valueOf(input.get(j)), String.valueOf(input.get(i)));
- //
output[i][j] = output[j][i];
}
}
return new FrameBlock(UtilFunctions.nCopies(getNumRows(),
ValueType.STRING), output);
}
- public static FrameMapFunction getCompiledFunction (String lambdaExpr) {
+ public static FrameMapFunction getCompiledFunction (String lambdaExpr,
long margin) {
String cname = "StringProcessing" + CLASS_ID.getNextID();
StringBuilder sb = new StringBuilder();
String[] parts = lambdaExpr.split("->");
@@ -2460,14 +2480,20 @@ public class FrameBlock implements CacheBlock,
Externalizable {
sb.append("import
org.apache.sysds.runtime.util.UtilFunctions;\n");
sb.append("import
org.apache.sysds.runtime.util.PorterStemmer;\n");
sb.append("import
org.apache.sysds.runtime.matrix.data.FrameBlock.FrameMapFunction;\n");
+ sb.append("import java.util.Arrays;\n");
sb.append("public class " + cname + " extends FrameMapFunction
{\n");
- if(varname.length == 1) {
- sb.append("public String apply(String " +
varname[0].trim() + ") {\n");
- sb.append(" return String.valueOf(" + expr + ");
}}\n");
- }
- else if(varname.length == 2) {
- sb.append("public String apply(String " +
varname[0].trim() + ", String " + varname[1].trim() + ") {\n");
- sb.append(" return String.valueOf(" + expr + ");
}}\n");
+ if(margin != 0) {
+ sb.append("public String[] apply(String[] " +
varname[0].trim() + ") {\n");
+ sb.append(" return UtilFunctions.toStringArray(" +
expr + "); }}\n");
+ } else {
+ if(varname.length == 1) {
+ sb.append("public String apply(String " +
varname[0].trim() + ") {\n");
+ sb.append(" return String.valueOf(" + expr +
"); }}\n");
+ }
+ else if(varname.length == 2) {
+ sb.append("public String apply(String " +
varname[0].trim() + ", String " + varname[1].trim() + ") {\n");
+ sb.append(" return String.valueOf(" + expr +
"); }}\n");
+ }
}
// compile class, and create FrameMapFunction object
try {
@@ -2485,6 +2511,7 @@ public class FrameBlock implements CacheBlock,
Externalizable {
private static final long serialVersionUID =
-8398572153616520873L;
public String apply(String input) {return null;}
public String apply(String input1, String input2) { return
null;}
+ public String[] apply(String[] input1) { return null;}
}
public <T> FrameBlock replaceOperations(String pattern, String
replacement) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ee64bc8..4376fa9 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -19,20 +19,6 @@
package org.apache.sysds.runtime.util;
-import org.apache.commons.lang.ArrayUtils;
-import org.apache.commons.lang3.math.NumberUtils;
-import org.apache.commons.math3.random.RandomDataGenerator;
-import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.data.SparseBlock;
-import org.apache.sysds.runtime.data.TensorIndexes;
-import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysds.runtime.matrix.data.Pair;
-import org.apache.sysds.runtime.meta.TensorCharacteristics;
-import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
-
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
@@ -47,6 +33,20 @@ import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.math.NumberUtils;
+import org.apache.commons.math3.random.RandomDataGenerator;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.TensorIndexes;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.meta.TensorCharacteristics;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
+
public class UtilFunctions {
// private static final Log LOG =
LogFactory.getLog(UtilFunctions.class.getName());
@@ -863,6 +863,12 @@ public class UtilFunctions {
return value ;
}
+ public static String[] copyAsStringToArray(String[] input, Object
value) {
+ String[] output = new String[input.length];
+ Arrays.fill(output, String.valueOf(value));
+ return output;
+ }
+
private static String getDateFormat (String dateString) {
return DATE_FORMATS.keySet().parallelStream().filter(e ->
dateString.toLowerCase().matches(e)).findFirst()
.map(DATE_FORMATS::get).orElseThrow(() -> new
NullPointerException("Unknown date format."));
@@ -1013,4 +1019,26 @@ public class UtilFunctions {
//forward to column encoder, as UtilFunctions available in map
context
return ColumnEncoderRecode.splitRecodeMapEntry(s);
}
+
+ public static String[] toStringArray(Object[] original) {
+ String[] result = new String[original.length];
+ for (int i = 0; i < result.length; i++)
+ result[i] = String.valueOf(original[i]);
+ return result;
+ }
+
+ public static double[] convertStringToDoubleArray(String[] original) {
+// double[] ret = new double[original.length];
+// for (int i = 0; i < original.length; i++) {
+// try {
+// ret[i] =
NumberFormat.getInstance().parse(original[i]).doubleValue();
+// }
+// catch(Exception e) {
+// e.printStackTrace();
+// }
+// }
+// return ret;
+
+ return
Arrays.stream(original).mapToDouble(Double::parseDouble).toArray();
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapMarginTest.java
b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapMarginTest.java
new file mode 100644
index 0000000..aca4484
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapMarginTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.binary.frame;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class FrameMapMarginTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "mapMargin";
+ private final static String TEST_DIR = "functions/binary/frame/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FrameMapMarginTest.class.getSimpleName() + "/";
+
+ private final static int rows = 100;
+ private final static Types.ValueType[] schemaStrings1 =
{Types.ValueType.STRING, Types.ValueType.STRING};
+ private final static String expression = "x ->
UtilFunctions.copyAsStringToArray(x,
Arrays.stream(UtilFunctions.convertStringToDoubleArray(x)).sum())";
+
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp() {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR +
TEST_CLASS_DIR);
+ }
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"D"}));
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @Test
+ public void testMarginColCP() { runDmlMapTest(expression, 2,
ExecType.CP); }
+
+ @Test
+ public void testMarginColSP() { runDmlMapTest(expression, 2,
ExecType.CP); }
+
+ @Test
+ public void testMarginRowCP() {
+ runDmlMapTest(expression, 1, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMarginRowSP() {
+ runDmlMapTest(expression, 1, ExecType.SPARK);
+ }
+
+ private void runDmlMapTest( String expression, int margin, ExecType et)
+ {
+ Types.ExecMode platformOld = setExecMode(et);
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] { "-stats","-args",
input("A"), expression, String.valueOf(margin), output("O")};
+
+ double[][] A = getRandomMatrix(rows, 2, 1, 1, 1, 2);
+ writeInputFrameWithMTD("A", A, true, schemaStrings1,
FileFormat.CSV);
+
+ runTest(true, false, null, -1);
+
+ FrameBlock outputFrame = readDMLFrameFromHDFS("O",
FileFormat.CSV);
+
+ for(int j = 0; j < schemaStrings1.length; j++)
+ for(int i = 0; i < rows; i++) {
+
Assert.assertEquals(Double.parseDouble(((String[])
outputFrame.getColumnData(j))[i]),
+ margin == 1 ?
schemaStrings1.length : rows,
+ 0.0);
+ }
+ }
+ catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/binary/frame/map.dml
b/src/test/scripts/functions/binary/frame/map.dml
index 482c37e..a3c2b65 100644
--- a/src/test/scripts/functions/binary/frame/map.dml
+++ b/src/test/scripts/functions/binary/frame/map.dml
@@ -22,7 +22,7 @@
# input: 1) frame, 2) lamba expression to execute for each non-null cell
# output: frame of string columns
-# Examples:
+# Examples:
# map(X, "x -> x.split(\"r\")[1]")
# map(X, "x -> x.charAt(2)")
# map(X, "y -> UtilFunctions.toMillis(y)")
diff --git a/src/test/scripts/functions/binary/frame/map.dml
b/src/test/scripts/functions/binary/frame/mapMargin.dml
similarity index 75%
copy from src/test/scripts/functions/binary/frame/map.dml
copy to src/test/scripts/functions/binary/frame/mapMargin.dml
index 482c37e..f1ccbf1 100644
--- a/src/test/scripts/functions/binary/frame/map.dml
+++ b/src/test/scripts/functions/binary/frame/mapMargin.dml
@@ -19,16 +19,14 @@
#
#-------------------------------------------------------------
-# input: 1) frame, 2) lamba expression to execute for each non-null cell
-# output: frame of string columns
+# input: 1) frame, 2) lambda expression to execute for each non-null cell, row
or col
+# output: frame
-# Examples:
-# map(X, "x -> x.split(\"r\")[1]")
-# map(X, "x -> x.charAt(2)")
-# map(X, "y -> UtilFunctions.toMillis(y)")
+# Example (convert strings to doubles, compute sum and copy to every cell):
+# map(X, "x -> UtilFunctions.copyAsStringToArray(x,
Arrays.stream(UtilFunctions.convertStringToDoubleArray(x)).sum())", 1)
X = read($1, data_type = "frame", format = "csv", header = FALSE)
-# column vector and string operation
-Y = map(X, $2)
-write(Y, $3, format="csv")
-write(X, $4, format="csv")
\ No newline at end of file
+
+X1 = map(X, "x -> x.replace(\"Str\", \"\")", 0)
+Y = map(X1, $2, $3)
+write(Y, $4, format="csv")
diff --git a/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
b/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
index b879b2f..abadf86 100644
--- a/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
+++ b/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
@@ -31,7 +31,7 @@ if ($rP) {
A = as.frame(A)
-S = map(A, "x -> x.replace(\"1\", \"2\")");
+S = map(A, "x -> x.replace(\"1\", \"2\")", 0);
write(S, $out_S);
print(toString(A[1,1]))
print(toString(S[1,1]))
diff --git
a/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
b/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
index 8b55291..b345984 100644
--- a/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
@@ -28,7 +28,7 @@ else {
A = as.frame(A)
-S = map(A, "x -> x.replace(\"1\", \"2\")");
+S = map(A, "x -> x.replace(\"1\", \"2\")", 0);
write(S, $6);
print(toString(A[1,1]))