This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 48a10cb  [SYSTEMDS-3204] Frame map operations w/ margin parameter
48a10cb is described below

commit 48a10cb07f42da54da08dfb931facbca4c13206d
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Sat Dec 18 22:14:16 2021 +0100

    [SYSTEMDS-3204] Frame map operations w/ margin parameter
    
    Closes #1440.
---
 docs/site/dml-language-reference.md                |   2 +-
 scripts/pipelines/scripts/utils.dml                |   6 +-
 src/main/java/org/apache/sysds/common/Types.java   |  10 +-
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |   3 -
 src/main/java/org/apache/sysds/hops/TernaryOp.java |  36 +++++--
 src/main/java/org/apache/sysds/lops/Binary.java    |   5 +-
 .../sysds/parser/BuiltinFunctionExpression.java    |   8 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |   6 +-
 .../runtime/instructions/CPInstructionParser.java  |   2 +-
 .../runtime/instructions/SPInstructionParser.java  |   2 +-
 .../instructions/cp/BinaryCPInstruction.java       |   2 -
 .../instructions/cp/TernaryCPInstruction.java      |   7 +-
 ...n.java => TernaryFrameScalarCPInstruction.java} |  13 +--
 .../instructions/fed/BinaryFEDInstruction.java     |   2 -
 .../instructions/fed/FEDInstructionUtils.java      |  31 +++---
 .../instructions/fed/TernaryFEDInstruction.java    |   8 +-
 ....java => TernaryFrameScalarFEDInstruction.java} |  10 +-
 .../instructions/spark/BinarySPInstruction.java    |   5 +-
 ...n.java => TernaryFrameScalarSPInstruction.java} |  23 +++--
 .../instructions/spark/TernarySPInstruction.java   |   7 +-
 .../sysds/runtime/matrix/data/FrameBlock.java      |  67 ++++++++----
 .../apache/sysds/runtime/util/UtilFunctions.java   |  56 ++++++++---
 .../functions/binary/frame/FrameMapMarginTest.java | 112 +++++++++++++++++++++
 src/test/scripts/functions/binary/frame/map.dml    |   2 +-
 .../binary/frame/{map.dml => mapMargin.dml}        |  18 ++--
 .../functions/federated/FederatedFrameMapTest.dml  |   2 +-
 .../federated/FederatedFrameMapTestReference.dml   |   2 +-
 27 files changed, 326 insertions(+), 121 deletions(-)

diff --git a/docs/site/dml-language-reference.md 
b/docs/site/dml-language-reference.md
index 31d3aba..3eeda3f 100644
--- a/docs/site/dml-language-reference.md
+++ b/docs/site/dml-language-reference.md
@@ -2053,7 +2053,7 @@ The following example uses <code>transformapply()</code> 
with the input matrix a
 
 Function | Description | Parameters | Example
 -------- | ----------- | ---------- | -------
-map() | It will execute the given lambda expression on a frame.| Input: (X 
&lt;frame&gt;, y &lt;String&gt;) <br/>Output: &lt;frame&gt;. <br/> X is a frame 
and <br/>y is a String containing the lambda expression to be executed on frame 
X. | [map](#map) 
+map() | It will execute the given lambda expression on a frame (cell, row or 
column wise). | Input: (X &lt;frame&gt;, y &lt;String&gt;, \[margin 
&lt;int&gt;\]) <br/>Output: &lt;frame&gt;. <br/> X is a frame and <br/>y is a 
String containing the lambda expression to be executed on frame X. <br/> margin 
specifies how the lambda expression is applied (0 - each cell, 1 - each row, 2 - 
each column; defaults to 0). Output frame dimensions are always equal to the input. | [map](#map) 
 tokenize() | Transforms a frame to tokenized frame using specification. 
Tokenization is valid only for string columns. | Input:<br/> target = 
&lt;frame&gt; <br/> spec = &lt;json specification&gt; <br/> Outputs: 
&lt;matrix&gt;, &lt;frame&gt; | [tokenize](#tokenize)
 
 #### map
diff --git a/scripts/pipelines/scripts/utils.dml 
b/scripts/pipelines/scripts/utils.dml
index 97de7e7..14ff36e 100644
--- a/scripts/pipelines/scripts/utils.dml
+++ b/scripts/pipelines/scripts/utils.dml
@@ -174,8 +174,8 @@ return(Frame[Unknown] processedData, Matrix[Double] M)
   for(i in 1:ncol(mask))
     if(as.scalar(schema[1,i]) == "STRING")
       data[, i] = map(data[, i], "x -> x.toLowerCase()")
-  
-  # step 5 typo correction
+
+  # step 5 typo correction  
   if(CorrectTypos)
   {
     # recode data to get null mask
@@ -197,7 +197,7 @@ return(Frame[Unknown] processedData, Matrix[Double] M)
   }
   # step 6 porter stemming on all features
   print(prefix+" porter-stemming on all features");
-  data = map(data, "x -> PorterStemmer.stem(x)")
+  data = map(data, "x -> PorterStemmer.stem(x)", 0)
   
   # TODO add deduplication
   print(prefix+" deduplication via entity resolution");
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 21a874e..85a11e4 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -310,7 +310,7 @@ public class Types
                BITWXOR(true), CBIND(false), CONCAT(false), COV(false), 
DIV(true),
                DROP_INVALID_TYPE(false), DROP_INVALID_LENGTH(false), 
EQUAL(true), GREATER(true),
                GREATEREQUAL(true), INTDIV(true), INTERQUANTILE(false), 
IQM(false), LESS(true),
-               LESSEQUAL(true), LOG(true), MAP(false), MAX(true), 
MEDIAN(false), MIN(true), 
+               LESSEQUAL(true), LOG(true), MAX(true), MEDIAN(false), MIN(true),
                MINUS(true), MODULUS(true), MOMENT(false), MULT(true), 
NOTEQUAL(true), OR(true),
                PLUS(true), POW(true), PRINT(false), QUANTILE(false), 
SOLVE(false),
                RBIND(false), VALUE_SWAP(false), XOR(true),
@@ -359,7 +359,6 @@ public class Types
                                case DROP_INVALID_TYPE: return 
"dropInvalidType";
                                case DROP_INVALID_LENGTH: return 
"dropInvalidLength";
                                case VALUE_SWAP: return "valueSwap";
-                               case MAP:          return "_map";
                                default:           return name().toLowerCase();
                        }
                }
@@ -393,8 +392,7 @@ public class Types
                                case "bitwShiftR":  return BITWSHIFTR;
                                case "dropInvalidType": return 
DROP_INVALID_TYPE;
                                case "dropInvalidLength": return 
DROP_INVALID_LENGTH;
-                               case "valueSwap": return VALUE_SWAP;
-                               case "map":         return MAP;
+                               case "valueSwap":   return VALUE_SWAP;
                                default:            return 
valueOf(opcode.toUpperCase());
                        }
                }
@@ -402,7 +400,7 @@ public class Types
        
        // Operations that require 3 operands
        public enum OpOp3 {
-               QUANTILE, INTERQUANTILE, CTABLE, MOMENT, COV, PLUS_MULT, 
MINUS_MULT, IFELSE;
+               QUANTILE, INTERQUANTILE, CTABLE, MOMENT, COV, PLUS_MULT, 
MINUS_MULT, IFELSE, MAP;
                
                @Override
                public String toString() {
@@ -410,6 +408,7 @@ public class Types
                                case MOMENT:     return "cm";
                                case PLUS_MULT:  return "+*";
                                case MINUS_MULT: return "-*";
+                               case MAP:          return "_map";
                                default:         return name().toLowerCase();
                        }
                }
@@ -419,6 +418,7 @@ public class Types
                                case "cm": return MOMENT;
                                case "+*": return PLUS_MULT;
                                case "-*": return MINUS_MULT;
+                               case "map": return MAP;
                                default:   return valueOf(opcode.toUpperCase());
                        }
                }
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index de6ccdc..73deda4 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1076,9 +1076,6 @@ public class BinaryOp extends MultiThreadedHop {
                if( !(that instanceof BinaryOp) )
                        return false;
 
-               if(op == OpOp2.MAP)
-                       return false; // custom UDFs
-
                BinaryOp that2 = (BinaryOp)that;
                return (   op == that2.op
                                && outer == that2.outer
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java 
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 706a319..00bf2e1 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -182,9 +182,9 @@ public class TernaryOp extends MultiThreadedHop
                                case PLUS_MULT:
                                case MINUS_MULT:
                                case IFELSE:
+                               case MAP:
                                        constructLopsTernaryDefault();
                                        break;
-                                       
                                default:
                                        throw new 
HopsException(this.printErrorLocation() + "Unknown TernaryOp (" + _op + ") 
while constructing Lops \n");
 
@@ -375,6 +375,7 @@ public class TernaryOp extends MultiThreadedHop
                                return 
OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
                        case PLUS_MULT:
                        case MINUS_MULT:
+                       case MAP:
                        case IFELSE: {
                                if (isGPUEnabled()) {
                                        // For the GPU, the input is converted 
to dense
@@ -419,6 +420,15 @@ public class TernaryOp extends MultiThreadedHop
                
                switch( _op ) 
                {
+                       case MAP:
+                               long ldim1 = (mc[0].rowsKnown()) ? 
mc[0].getRows() :
+                                       (mc[1].getRows()>=0) ? mc[1].getRows() 
: -1;
+                               long ldim2 = (mc[0].colsKnown()) ? 
mc[0].getCols() :
+                                       (mc[1].getCols()>=0) ? mc[1].getCols() 
: -1;
+                               if( ldim1>=0 && ldim2>=0 )
+                                       ret = new MatrixCharacteristics(ldim1, 
ldim2, -1, (long) (ldim1 * ldim2 * 1.0));
+                               return ret;
+
                        case CTABLE:
                                boolean dimsSpec = (getInput().size() > 3); 
                                
@@ -515,20 +525,31 @@ public class TernaryOp extends MultiThreadedHop
        @Override
        public void refreshSizeInformation()
        {
+               Hop input1 = getInput().get(0);
+               Hop input2 = getInput().get(1);
+               Hop input3 = getInput().get(2);
+
                if ( getDataType() == DataType.SCALAR ) 
                {
                        //do nothing always known
                }
-               else 
+               else
                {
                        switch( _op ) 
                        {
+                               case MAP:
+                                       long ldim1, ldim2, lnnz1 = -1;
+                                       ldim1 = (input1.rowsKnown()) ? 
input1.getDim1() : ((input2.getDim1()>=0)?input2.getDim1():-1);
+                                       ldim2 = (input1.colsKnown()) ? 
input1.getDim2() : ((input2.getDim2()>=0)?input2.getDim2():-1);
+                                       lnnz1 = input1.getNnz();
+
+                                       setDim1( ldim1 );
+                                       setDim2( ldim2 );
+                                       setNnz(lnnz1);
+                                       break;
                                case CTABLE:
                                        //in general, do nothing because the 
output size is data dependent
-                                       Hop input1 = getInput().get(0);
-                                       Hop input2 = getInput().get(1);
-                                       Hop input3 = getInput().get(2);
-                                       
+
                                        //TODO double check reset 
(dimsInputPresent?)
                                        if ( !dimsKnown() ) { 
                                                //for ctable_expand at least 
one dimension is known
@@ -600,6 +621,9 @@ public class TernaryOp extends MultiThreadedHop
                        return false;
                
                TernaryOp that2 = (TernaryOp)that;
+
+               if(_op == OpOp3.MAP)
+                       return false; // custom UDFs
                
                //compare basic inputs and weights (always existing)
                boolean ret = (_op == that2._op
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java 
b/src/main/java/org/apache/sysds/lops/Binary.java
index 202dce7..6949188 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -80,10 +80,7 @@ public class Binary extends Lop
                        return null;
 
                ArrayList<Lop> inputs = getInputs();
-               if (operation == OpOp2.MAP && inputs.get(0).getDataType() == 
DataType.MATRIX 
-                               && inputs.get(1).getDataType() == 
DataType.MATRIX)
-                       return inputs.get(1);
-               else if (inputs.get(0).getDataType() == DataType.FRAME && 
inputs.get(1).getDataType() == DataType.MATRIX)
+               if (inputs.get(0).getDataType() == DataType.FRAME && 
inputs.get(1).getDataType() == DataType.MATRIX)
                        return inputs.get(1);
                else
                        return null;
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index cc70396..c2f72a5 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1565,20 +1565,20 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        break;
 
                case MAP:
-                       checkNumParameters(2);
+                       checkNumParameters(getThirdExpr() != null ? 3 : 2);
                        checkMatrixFrameParam(getFirstExpr());
                        checkScalarParam(getSecondExpr());
+                       if(getThirdExpr() != null)
+                               checkScalarParam(getThirdExpr()); // margin
                        output.setDataType(DataType.FRAME);
                        if(_args[1].getText().contains("jaccardSim")) {
                                output.setDimensions(id.getDim1(), 
id.getDim1());
                                output.setValueType(ValueType.FP64);
                        }
                        else {
-                               output.setDimensions(id.getDim1(), 1);
+                               output.setDimensions(id.getDim1(), 
id.getDim2());
                                output.setValueType(ValueType.STRING);
                        }
-                       output.setBlocksize (id.getBlocksize());
-
                        break;
                case LOCAL:
                        if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_LOCAL_COMMAND){
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index dc51d46..08c7ebf 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2540,10 +2540,14 @@ public class DMLTranslator
                case DROP_INVALID_TYPE:
                case DROP_INVALID_LENGTH:
                case VALUE_SWAP:
-               case MAP:
                        currBuiltinOp = new BinaryOp(target.getName(), 
target.getDataType(),
                                target.getValueType(), 
OpOp2.valueOf(source.getOpCode().name()), expr, expr2);
                        break;
+               case MAP:
+                       currBuiltinOp = new TernaryOp(target.getName(), 
target.getDataType(),
+                               target.getValueType(), 
OpOp3.valueOf(source.getOpCode().name()),
+                               expr, expr2, (expr3==null) ? new LiteralOp(0L) 
: expr3);
+                       break;
                
                case LOG:
                                if (expr2 == null) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 22cceea..c08985b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -162,7 +162,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "dropInvalidType"  , 
CPType.Binary);
                String2CPInstructionType.put( "dropInvalidLength"  , 
CPType.Binary);
                String2CPInstructionType.put( "valueSwap"  , CPType.Binary);
-               String2CPInstructionType.put( "_map"  , CPType.Binary); // _map 
represents the operation map
+               String2CPInstructionType.put( "_map"  , CPType.Ternary); // 
_map represents the operation map
 
                String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
                String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 85abf32..56cd49a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -183,7 +183,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "dropInvalidType", SPType.Binary);
                String2SPInstructionType.put( "mapdropInvalidLength", 
SPType.Binary);
                String2SPInstructionType.put( "valueSwap", SPType.Binary);
-               String2SPInstructionType.put( "_map", SPType.Binary); // _map 
refers to the operation map
+               String2SPInstructionType.put( "_map", SPType.Ternary); // _map 
refers to the operation map
                // Relational Instruction Opcodes
                String2SPInstructionType.put( "=="   , SPType.Binary);
                String2SPInstructionType.put( "!="   , SPType.Binary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index 188b2ac..4d83d6d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -57,8 +57,6 @@ public abstract class BinaryCPInstruction extends 
ComputationCPInstruction {
                        return new BinaryFrameFrameCPInstruction(operator, in1, 
in2, out, opcode, str);
                else if (in1.getDataType() == DataType.FRAME && 
in2.getDataType() == DataType.MATRIX)
                        return new BinaryFrameMatrixCPInstruction(operator, 
in1, in2, out, opcode, str);
-               else if (in1.getDataType() == DataType.FRAME && 
in2.getDataType() == DataType.SCALAR)
-                       return new BinaryFrameScalarCPInstruction(operator, 
in1, in2, out, opcode, str);
                else
                        return new BinaryMatrixScalarCPInstruction(operator, 
in1, in2, out, opcode, str);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
index 14d0090..86733fb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java
@@ -26,7 +26,7 @@ import 
org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 
 public class TernaryCPInstruction extends ComputationCPInstruction {
        
-       private TernaryCPInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+       protected TernaryCPInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
                super(CPType.Ternary, op, in1, in2, in3, out, opcode, str);
        }
 
@@ -40,7 +40,10 @@ public class TernaryCPInstruction extends 
ComputationCPInstruction {
                CPOperand outOperand = new CPOperand(parts[4]);
                int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) : 
1;
                TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode, numThreads);
-               return new TernaryCPInstruction(op, operand1, operand2, 
operand3, outOperand, opcode,str);
+               if(operand1.isFrame() && operand2.isScalar() && 
opcode.contains("map"))
+                       return  new TernaryFrameScalarCPInstruction(op, 
operand1, operand2, operand3, outOperand, opcode, str);
+               else
+                       return new TernaryCPInstruction(op, operand1, operand2, 
operand3, outOperand, opcode,str);
        }
        
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryFrameScalarCPInstruction.java
similarity index 75%
rename from 
src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
rename to 
src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryFrameScalarCPInstruction.java
index bcf7cb5..6512343 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryFrameScalarCPInstruction.java
@@ -21,22 +21,23 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 
-public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction
+public class TernaryFrameScalarCPInstruction extends TernaryCPInstruction
 {
-       protected BinaryFrameScalarCPInstruction(Operator op, CPOperand in1,
-                       CPOperand in2, CPOperand out, String opcode, String 
istr) {
-               super(CPType.Binary, op, in1, in2, out, opcode, istr);
+       protected TernaryFrameScalarCPInstruction(TernaryOperator op, CPOperand 
in1,
+                       CPOperand in2, CPOperand in3, CPOperand out, String 
opcode, String istr) {
+               super(op, in1, in2, in3, out, opcode, istr);
        }
 
        @Override
        public void processInstruction(ExecutionContext ec)  {
                // get input frames
                FrameBlock inBlock = ec.getFrameInput(input1.getName());
+               ScalarObject margin = ec.getScalarInput(input3);
                String stringExpression = 
ec.getScalarInput(input2).getStringValue();
                //compute results
-               FrameBlock outBlock = inBlock.map(stringExpression);
+               FrameBlock outBlock = inBlock.map(stringExpression, 
margin.getLongValue());
                // Attach result frame with FrameBlock associated with 
output_name
                ec.setFrameOutput(output.getName(), outBlock);
                // Release the memory occupied by input frames
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index c2f07f7..5e1d35f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -72,8 +72,6 @@ public abstract class BinaryFEDInstruction extends 
ComputationFEDInstruction {
                        throw new DMLRuntimeException("Federated binary tensor 
tensor operations not yet supported");
                else if( in1.isMatrix() && in2.isScalar() || in2.isMatrix() && 
in1.isScalar() )
                        return new BinaryMatrixScalarFEDInstruction(operator, 
in1, in2, out, opcode, str, fedOut);
-               else if( in1.isFrame() && in2.isScalar() || in2.isFrame() && 
in1.isScalar() )
-                       return new BinaryFrameScalarFEDInstruction(operator, 
in1, in2, out, opcode, InstructionUtils.removeFEDOutputFlag(str));
                else
                        throw new DMLRuntimeException("Federated binary 
operations not yet supported:" + opcode);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 12965f6..d46d493 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -37,7 +37,6 @@ import 
org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
@@ -51,6 +50,7 @@ import 
org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import 
org.apache.sysds.runtime.instructions.cp.TernaryFrameScalarCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -62,7 +62,6 @@ import 
org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
-import 
org.apache.sysds.runtime.instructions.spark.BinaryFrameScalarSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
@@ -85,6 +84,7 @@ import 
org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.TernaryFrameScalarSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
@@ -187,9 +187,6 @@ public class FEDInstructionUtils {
                                else
                                        fedinst = 
BinaryFEDInstruction.parseInstruction(
                                                
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
-                       } else if(inst.getOpcode().equals("_map") && inst 
instanceof BinaryFrameScalarCPInstruction && 
!inst.getInstructionString().contains("UtilFunctions")
-                               && instruction.input1.isFrame() && 
ec.getFrameObject(instruction.input1).isFederated()) {
-                               fedinst = 
BinaryFrameScalarFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
                        }
                }
                else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
@@ -217,7 +214,15 @@ public class FEDInstructionUtils {
                }
                else if(inst instanceof TernaryCPInstruction) {
                        TernaryCPInstruction tinst = (TernaryCPInstruction) 
inst;
-                       if((tinst.input1.isMatrix() && 
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
+                       if(inst.getOpcode().equals("_map") && inst instanceof 
TernaryFrameScalarCPInstruction && 
!inst.getInstructionString().contains("UtilFunctions")
+                               && tinst.input1.isFrame() && 
ec.getFrameObject(tinst.input1).isFederated()) {
+                               long margin = 
ec.getScalarInput(tinst.input3).getLongValue();
+                               FrameObject fo = 
ec.getFrameObject(tinst.input1);
+                               if(margin == 0 || (fo.isFederated(FType.ROW) && 
margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
+                                       fedinst = 
TernaryFrameScalarFEDInstruction
+                                               
.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+                       }
+                       else if((tinst.input1.isMatrix() && 
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
                                || (tinst.input2.isMatrix() && 
ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
                                || (tinst.input3.isMatrix() && 
ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
                                fedinst = 
TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
@@ -435,11 +440,6 @@ public class FEDInstructionUtils {
                                                
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
                                }
                        }
-                       else if(inst.getOpcode().equals("_map") && inst 
instanceof BinaryFrameScalarSPInstruction && 
!inst.getInstructionString().contains("UtilFunctions")
-                               && instruction.input1.isFrame() && 
ec.getFrameObject(instruction.input1).isFederated()) {
-                               fedinst = 
BinaryFrameScalarFEDInstruction.parseInstruction(InstructionUtils
-                                       
.concatOperands(inst.getInstructionString(), FederatedOutput.NONE.name()));
-                       }
                        else if( (instruction.input1.isMatrix() && 
ec.getCacheableData(instruction.input1).isFederatedExcept(FType.BROADCAST))
                                || (instruction.input2.isMatrix() && 
ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
                                if("cov".equals(instruction.getOpcode()) && 
(ec.getMatrixObject(instruction.input1)
@@ -476,7 +476,14 @@ public class FEDInstructionUtils {
                }
                else if(inst instanceof TernarySPInstruction) {
                        TernarySPInstruction tinst = (TernarySPInstruction) 
inst;
-                       if((tinst.input1.isMatrix() && 
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
+                       if(inst.getOpcode().equals("_map") && inst instanceof 
TernaryFrameScalarSPInstruction && 
!inst.getInstructionString().contains("UtilFunctions")
+                               && tinst.input1.isFrame() && 
ec.getFrameObject(tinst.input1).isFederated()) {
+                               long margin = 
ec.getScalarInput(tinst.input3).getLongValue();
+                               FrameObject fo = 
ec.getFrameObject(tinst.input1);
+                               if(margin == 0 || (fo.isFederated(FType.ROW) && 
margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
+                                       fedinst = 
TernaryFrameScalarFEDInstruction
+                                               
.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
+                       } else if((tinst.input1.isMatrix() && 
ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
                                || (tinst.input2.isMatrix() && 
ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
                                || (tinst.input3.isMatrix() && 
ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
                                fedinst = 
TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 954c4c4..cb8d9fa 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -38,7 +38,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
 public class TernaryFEDInstruction extends ComputationFEDInstruction {
 
-       private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
+       protected TernaryFEDInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
                String opcode, String str, FederatedOutput fedOut) {
                super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, 
opcode, str, fedOut);
        }
@@ -50,9 +50,11 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
                CPOperand operand2 = new CPOperand(parts[2]);
                CPOperand operand3 = new CPOperand(parts[3]);
                CPOperand outOperand = new CPOperand(parts[4]);
-               int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) : 
1;
-               FederatedOutput fedOut = parts.length>7 ? 
FederatedOutput.valueOf(parts[6]) : FederatedOutput.NONE;
+               int numThreads = parts.length>5 & !opcode.contains("map") ? 
Integer.parseInt(parts[5]) : 1;
+               FederatedOutput fedOut = parts.length>7 && 
!opcode.contains("map") ? FederatedOutput.valueOf(parts[6]) : 
FederatedOutput.NONE;
                TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode, numThreads);
+               if( operand1.isFrame() && operand2.isScalar() || 
operand2.isFrame() && operand1.isScalar() )
+                       return new TernaryFrameScalarFEDInstruction(op, 
operand1, operand2, operand3, outOperand, opcode, 
InstructionUtils.removeFEDOutputFlag(str), fedOut);
                return new TernaryFEDInstruction(op, operand1, operand2, 
operand3, outOperand, opcode, str, fedOut);
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFrameScalarFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFrameScalarFEDInstruction.java
similarity index 83%
rename from 
src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFrameScalarFEDInstruction.java
rename to 
src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFrameScalarFEDInstruction.java
index ecb3a45..6205d20 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFrameScalarFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFrameScalarFEDInstruction.java
@@ -25,13 +25,13 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 
-public class BinaryFrameScalarFEDInstruction extends BinaryFEDInstruction
+public class TernaryFrameScalarFEDInstruction extends TernaryFEDInstruction
 {
-       protected BinaryFrameScalarFEDInstruction(Operator op, CPOperand in1,
-                       CPOperand in2, CPOperand out, String opcode, String 
istr) {
-               super(FEDInstruction.FEDType.Binary, op, in1, in2, out, opcode, 
istr);
+       protected TernaryFrameScalarFEDInstruction(TernaryOperator op, 
CPOperand in1,
+                       CPOperand in2, CPOperand in3, CPOperand out, String 
opcode, String istr, FederatedOutput fedOut) {
+               super(op, in1, in2, in3, out, opcode, istr, fedOut);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
index 82a95a8..16196d5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinarySPInstruction.java
@@ -111,9 +111,8 @@ public abstract class BinarySPInstruction extends 
ComputationSPInstruction {
                else if( dt1 == DataType.FRAME || dt2 == DataType.FRAME ) {
                        if(dt1 == DataType.FRAME && dt2 == DataType.FRAME)
                                return new 
BinaryFrameFrameSPInstruction(operator, in1, in2, out, opcode, str);
-                       if(dt1 == DataType.FRAME && dt2 == DataType.SCALAR)
-                               return  new 
BinaryFrameScalarSPInstruction(operator, in1, in2, out, opcode, str);
-
+                       else if(dt1 == DataType.FRAME && dt2 == DataType.SCALAR 
&& opcode.equalsIgnoreCase("+"))
+                               return new 
BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
                }
 
                return null;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernaryFrameScalarSPInstruction.java
similarity index 74%
rename from 
src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
rename to 
src/main/java/org/apache/sysds/runtime/instructions/spark/TernaryFrameScalarSPInstruction.java
index b5cf078..d609515 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameScalarSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernaryFrameScalarSPInstruction.java
@@ -25,12 +25,12 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 
-public class BinaryFrameScalarSPInstruction extends BinarySPInstruction {
-       protected BinaryFrameScalarSPInstruction (Operator op, CPOperand in1, 
CPOperand in2, CPOperand out,
+public class TernaryFrameScalarSPInstruction extends TernarySPInstruction {
+       protected TernaryFrameScalarSPInstruction(TernaryOperator op, CPOperand 
in1, CPOperand in2, CPOperand in3, CPOperand out,
                        String opcode, String istr) {
-               super(SPType.Binary, op, in1, in2, out, opcode, istr);
+               super(op, in1, in2, in3, out, opcode, istr);
        }
 
        @Override
@@ -40,15 +40,19 @@ public class BinaryFrameScalarSPInstruction extends 
BinarySPInstruction {
                // Get input RDDs
                JavaPairRDD<Long, FrameBlock> in1 = 
sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
                String expression = sec.getScalarInput(input2).getStringValue();
+               long margin = ec.getScalarInput(input3).getLongValue();
 
                // Create local compiled functions (once) and execute on RDD
-               JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new 
RDDStringProcessing(expression));
+               JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new 
RDDStringProcessing(expression, margin));
 
                if(expression.contains("jaccardSim")) {
                        long rows = 
sec.getDataCharacteristics(output.getName()).getRows();
                        
sec.getDataCharacteristics(output.getName()).setDimension(rows, rows);
+               } else {
+                       long rows = 
sec.getDataCharacteristics(output.getName()).getRows();
+                       long cols = 
sec.getDataCharacteristics(output.getName()).getCols();
+                       
sec.getDataCharacteristics(output.getName()).setDimension(rows, cols);
                }
-
                sec.setRDDHandleForVariable(output.getName(), out);
                sec.addLineageRDD(output.getName(), input1.getName());
        }
@@ -57,14 +61,17 @@ public class BinaryFrameScalarSPInstruction extends 
BinarySPInstruction {
                private static final long serialVersionUID = 
5850400295183766400L;
 
                private String _expr = null;
+               private long _margin = -1;
 
-               public RDDStringProcessing(String expr) {
+               public RDDStringProcessing(String expr, long margin) {
                        _expr = expr;
+                       _margin = margin;
                }
 
                @Override
                public FrameBlock call(FrameBlock arg0) throws Exception {
-                       return arg0.map(_expr);
+                       FrameBlock fb =  arg0.map(_expr, _margin);
+                       return fb;
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
index 2f4c129..fa1dce2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TernarySPInstruction.java
@@ -34,7 +34,7 @@ import scala.Tuple2;
 import java.io.Serializable;
 
 public class TernarySPInstruction extends ComputationSPInstruction {
-       private TernarySPInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+       protected TernarySPInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
                super(SPType.Ternary, op, in1, in2, in3, out, opcode, str);
        }
 
@@ -46,7 +46,10 @@ public class TernarySPInstruction extends 
ComputationSPInstruction {
                CPOperand operand3 = new CPOperand(parts[3]);
                CPOperand outOperand = new CPOperand(parts[4]);
                TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode);
-               return new TernarySPInstruction(op,operand1, operand2, 
operand3, outOperand, opcode,str);
+               if(operand1.isFrame() && operand2.isScalar() && 
opcode.contains("map"))
+                       return  new TernaryFrameScalarSPInstruction(op, 
operand1, operand2, operand3, outOperand, opcode, str);
+               else
+                return new TernarySPInstruction(op, operand1, operand2, 
operand3, outOperand, opcode,str);
        }
        
        @Override
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index f4f9beb..751d24e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -2320,7 +2320,7 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        }
        }
 
-       public FrameBlock map (String lambdaExpr){
+       public FrameBlock map (String lambdaExpr, long margin){
                if(!lambdaExpr.contains("->")) {
                        String args = 
lambdaExpr.substring(lambdaExpr.indexOf('(') + 1, lambdaExpr.indexOf(')'));
                        if(args.contains(",")) {
@@ -2333,8 +2333,8 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        }
                }
                if(lambdaExpr.contains("jaccardSim"))
-                       return mapDist(getCompiledFunction(lambdaExpr));
-               return map(getCompiledFunction(lambdaExpr));
+                       return mapDist(getCompiledFunction(lambdaExpr, margin));
+               return map(getCompiledFunction(lambdaExpr, margin), margin);
        }
 
        public FrameBlock valueSwap(FrameBlock schema) {
@@ -2418,17 +2418,38 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                return this;
        }
 
-       public FrameBlock map (FrameMapFunction lambdaExpr) {
+       public FrameBlock map (FrameMapFunction lambdaExpr, long margin) {
                // Prepare temporary output array
                String[][] output = new String[getNumRows()][getNumColumns()];
-               // Execute map function on all cells
-               for(int j = 0; j < getNumColumns(); j++) {
-                       Array input = getColumn(j);
-                       for(int i = 0; i < input._size; i++)
-                               if(input.get(i) != null)
-                                       output[i][j] = 
lambdaExpr.apply(String.valueOf(input.get(i)));
-               }
 
+               if (margin == 1) {
+                       // Execute map function on rows
+                       for(int i = 0; i < getNumRows(); i++) {
+                               String[] row = new String[getNumColumns()];
+                               for(int j = 0; j < getNumColumns(); j++) {
+                                       Array input = getColumn(j);
+                                       row[j] = String.valueOf(input.get(i));
+                               }
+                               output[i] = lambdaExpr.apply(row);
+                       }
+               } else if (margin == 2) {
+                       // Execute map function on columns
+                       for(int j = 0; j < getNumColumns(); j++) {
+                               String[] actualColumn = 
Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more 
rows can be allocated, mutable array
+                               String[] outColumn = 
lambdaExpr.apply(actualColumn);
+
+                               for(int i = 0; i < getNumRows(); i++)
+                                       output[i][j] = outColumn[i];
+                       }
+               } else {
+                       // Execute map function on all cells
+                       for(int j = 0; j < getNumColumns(); j++) {
+                               Array input = getColumn(j);
+                               for(int i = 0; i < input._size; i++)
+                                       if(input.get(i) != null)
+                                               output[i][j] = 
lambdaExpr.apply(String.valueOf(input.get(i)));
+                       }
+               }
                return new FrameBlock(UtilFunctions.nCopies(getNumColumns(), 
ValueType.STRING), output);
        }
 
@@ -2441,13 +2462,12 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        for(int i = j + 1; i < input._size; i++)
                                if(input.get(i) != null && input.get(j) != 
null) {
                                        output[j][i] = 
lambdaExpr.apply(String.valueOf(input.get(j)), String.valueOf(input.get(i)));
-                                       //                                      
output[i][j] = output[j][i];
                                }
                }
                return new FrameBlock(UtilFunctions.nCopies(getNumRows(), 
ValueType.STRING), output);
        }
 
-       public static FrameMapFunction getCompiledFunction (String lambdaExpr) {
+       public static FrameMapFunction getCompiledFunction (String lambdaExpr, 
long margin) {
                String cname = "StringProcessing" + CLASS_ID.getNextID();
                StringBuilder sb = new StringBuilder();
                String[] parts = lambdaExpr.split("->");
@@ -2460,14 +2480,20 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                sb.append("import 
org.apache.sysds.runtime.util.UtilFunctions;\n");
                sb.append("import 
org.apache.sysds.runtime.util.PorterStemmer;\n");
                sb.append("import 
org.apache.sysds.runtime.matrix.data.FrameBlock.FrameMapFunction;\n");
+               sb.append("import java.util.Arrays;\n");
                sb.append("public class " + cname + " extends FrameMapFunction 
{\n");
-               if(varname.length == 1) {
-                       sb.append("public String apply(String " + 
varname[0].trim() + ") {\n");
-                       sb.append("  return String.valueOf(" + expr + "); 
}}\n");
-               }
-               else if(varname.length == 2) {
-                       sb.append("public String apply(String " + 
varname[0].trim() + ", String " + varname[1].trim() + ") {\n");
-                       sb.append("  return String.valueOf(" + expr + "); 
}}\n");
+               if(margin != 0) {
+                       sb.append("public String[] apply(String[] " + 
varname[0].trim() + ") {\n");
+                       sb.append("  return UtilFunctions.toStringArray(" + 
expr + "); }}\n");
+               } else {
+                       if(varname.length == 1) {
+                               sb.append("public String apply(String " + 
varname[0].trim() + ") {\n");
+                               sb.append("  return String.valueOf(" + expr + 
"); }}\n");
+                       }
+                       else if(varname.length == 2) {
+                               sb.append("public String apply(String " + 
varname[0].trim() + ", String " + varname[1].trim() + ") {\n");
+                               sb.append("  return String.valueOf(" + expr + 
"); }}\n");
+                       }
                }
                // compile class, and create FrameMapFunction object
                try {
@@ -2485,6 +2511,7 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                private static final long serialVersionUID = 
-8398572153616520873L;
                public String apply(String input) {return null;}
                public String apply(String input1, String input2) {     return 
null;}
+               public String[] apply(String[] input1) {        return null;}
        }
 
        public <T> FrameBlock replaceOperations(String pattern, String 
replacement) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java 
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ee64bc8..4376fa9 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -19,20 +19,6 @@
 
 package org.apache.sysds.runtime.util;
 
-import org.apache.commons.lang.ArrayUtils;
-import org.apache.commons.lang3.math.NumberUtils;
-import org.apache.commons.math3.random.RandomDataGenerator;
-import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.data.SparseBlock;
-import org.apache.sysds.runtime.data.TensorIndexes;
-import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysds.runtime.matrix.data.Pair;
-import org.apache.sysds.runtime.meta.TensorCharacteristics;
-import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
-
 import java.text.ParseException;
 import java.text.SimpleDateFormat;
 import java.util.ArrayList;
@@ -47,6 +33,20 @@ import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.math.NumberUtils;
+import org.apache.commons.math3.random.RandomDataGenerator;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.TensorIndexes;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.meta.TensorCharacteristics;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
+
 public class UtilFunctions {
        // private static final Log LOG = 
LogFactory.getLog(UtilFunctions.class.getName());
 
@@ -863,6 +863,12 @@ public class UtilFunctions {
                return value ;
        }
 
+       public static String[] copyAsStringToArray(String[] input, Object 
value) {
+               String[] output = new String[input.length];
+               Arrays.fill(output, String.valueOf(value));
+               return output;
+       }
+
        private static String getDateFormat (String dateString) {
                return DATE_FORMATS.keySet().parallelStream().filter(e -> 
dateString.toLowerCase().matches(e)).findFirst()
                        .map(DATE_FORMATS::get).orElseThrow(() -> new 
NullPointerException("Unknown date format."));
@@ -1013,4 +1019,26 @@ public class UtilFunctions {
                //forward to column encoder, as UtilFunctions available in map 
context
                return ColumnEncoderRecode.splitRecodeMapEntry(s);
        }
+
+       public static String[] toStringArray(Object[] original) {
+               String[] result = new String[original.length];
+               for (int i = 0; i < result.length; i++)
+                       result[i] = String.valueOf(original[i]);
+               return result;
+       }
+
+       public static double[] convertStringToDoubleArray(String[] original) {
+//             double[] ret = new double[original.length];
+//             for (int i = 0; i < original.length; i++) {
+//                     try {
+//                             ret[i] = 
NumberFormat.getInstance().parse(original[i]).doubleValue();
+//                     }
+//                     catch(Exception e) {
+//                             e.printStackTrace();
+//                     }
+//             }
+//             return ret;
+
+               return 
Arrays.stream(original).mapToDouble(Double::parseDouble).toArray();
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapMarginTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapMarginTest.java
new file mode 100644
index 0000000..aca4484
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameMapMarginTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.binary.frame;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class FrameMapMarginTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "mapMargin";
+       private final static String TEST_DIR = "functions/binary/frame/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FrameMapMarginTest.class.getSimpleName() + "/";
+
+       private final static int rows = 100;
+       private final static Types.ValueType[] schemaStrings1 = 
{Types.ValueType.STRING, Types.ValueType.STRING};
+       private final static String expression = "x -> 
UtilFunctions.copyAsStringToArray(x, 
Arrays.stream(UtilFunctions.convertStringToDoubleArray(x)).sum())";
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @AfterClass
+       public static void cleanUp() {
+               if (TEST_CACHE_ENABLED) {
+                       TestUtils.clearDirectory(TEST_DATA_DIR + 
TEST_CLASS_DIR);
+               }
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"D"}));
+               if (TEST_CACHE_ENABLED) {
+                       setOutAndExpectedDeletionDisabled(true);
+               }
+       }
+
+       @Test
+       public void testMarginColCP() { runDmlMapTest(expression, 2, 
ExecType.CP); }
+
+       @Test
+       public void testMarginColSP() { runDmlMapTest(expression, 2, 
ExecType.CP); }
+
+       @Test
+       public void testMarginRowCP() {
+               runDmlMapTest(expression, 1, ExecType.SPARK);
+       }
+
+       @Test
+       public void testMarginRowSP() {
+               runDmlMapTest(expression, 1, ExecType.SPARK);
+       }
+
+       private void runDmlMapTest( String expression, int margin, ExecType et)
+       {
+               Types.ExecMode platformOld = setExecMode(et);
+
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] { "-stats","-args", 
input("A"), expression, String.valueOf(margin), output("O")};
+
+                       double[][] A = getRandomMatrix(rows, 2, 1, 1, 1, 2);
+                       writeInputFrameWithMTD("A", A, true, schemaStrings1, 
FileFormat.CSV);
+                       
+                       runTest(true, false, null, -1);
+
+                       FrameBlock outputFrame = readDMLFrameFromHDFS("O", 
FileFormat.CSV);
+
+                       for(int j = 0; j < schemaStrings1.length; j++)
+                               for(int i = 0; i < rows; i++) {
+                                       
Assert.assertEquals(Double.parseDouble(((String[]) 
outputFrame.getColumnData(j))[i]),
+                                               margin == 1 ? 
schemaStrings1.length : rows,
+                                               0.0);
+                               }
+               }
+               catch (Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/binary/frame/map.dml 
b/src/test/scripts/functions/binary/frame/map.dml
index 482c37e..a3c2b65 100644
--- a/src/test/scripts/functions/binary/frame/map.dml
+++ b/src/test/scripts/functions/binary/frame/map.dml
@@ -22,7 +22,7 @@
 # input: 1) frame, 2) lamba expression to execute for each non-null cell
 # output: frame of string columns
 
-# Examples: 
+# Examples:
 #   map(X, "x -> x.split(\"r\")[1]")
 #   map(X, "x -> x.charAt(2)")
 #   map(X, "y -> UtilFunctions.toMillis(y)")
diff --git a/src/test/scripts/functions/binary/frame/map.dml 
b/src/test/scripts/functions/binary/frame/mapMargin.dml
similarity index 75%
copy from src/test/scripts/functions/binary/frame/map.dml
copy to src/test/scripts/functions/binary/frame/mapMargin.dml
index 482c37e..f1ccbf1 100644
--- a/src/test/scripts/functions/binary/frame/map.dml
+++ b/src/test/scripts/functions/binary/frame/mapMargin.dml
@@ -19,16 +19,14 @@
 #
 #-------------------------------------------------------------
 
-# input: 1) frame, 2) lamba expression to execute for each non-null cell
-# output: frame of string columns
+# input: 1) frame, 2) lambda expression to execute for each non-null cell, row 
or col
+# output: frame
 
-# Examples: 
-#   map(X, "x -> x.split(\"r\")[1]")
-#   map(X, "x -> x.charAt(2)")
-#   map(X, "y -> UtilFunctions.toMillis(y)")
+# Example (convert strings to doubles, compute sum and copy to every cell):
+#   map(X, "x -> UtilFunctions.copyAsStringToArray(x, 
Arrays.stream(UtilFunctions.convertStringToDoubleArray(x)).sum())", 1)
 
 X = read($1, data_type = "frame", format = "csv", header = FALSE)
-# column vector and string operation
-Y = map(X, $2)
-write(Y, $3,  format="csv")
-write(X, $4,  format="csv")
\ No newline at end of file
+
+X1 = map(X, "x -> x.replace(\"Str\", \"\")", 0)
+Y = map(X1, $2, $3)
+write(Y, $4,  format="csv")
diff --git a/src/test/scripts/functions/federated/FederatedFrameMapTest.dml 
b/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
index b879b2f..abadf86 100644
--- a/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
+++ b/src/test/scripts/functions/federated/FederatedFrameMapTest.dml
@@ -31,7 +31,7 @@ if ($rP) {
 
 A = as.frame(A)
 
-S = map(A, "x -> x.replace(\"1\", \"2\")");
+S = map(A, "x -> x.replace(\"1\", \"2\")", 0);
 write(S, $out_S);
 print(toString(A[1,1]))
 print(toString(S[1,1]))
diff --git 
a/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml 
b/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
index 8b55291..b345984 100644
--- a/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedFrameMapTestReference.dml
@@ -28,7 +28,7 @@ else {
 
 A = as.frame(A)
 
-S = map(A, "x -> x.replace(\"1\", \"2\")");
+S = map(A, "x -> x.replace(\"1\", \"2\")", 0);
 write(S, $6);
 
 print(toString(A[1,1]))

Reply via email to