This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ec292a  [SYSTEMDS-2888] Fix incomplete cbind support in codegen row templates
1ec292a is described below

commit 1ec292a932c6e732bbac835a81cdb59371002114
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Apr 10 23:17:01 2021 +0200

    [SYSTEMDS-2888] Fix incomplete cbind support in codegen row templates
    
    This patch extends the existing cbind handling (compiler and runtime) in
    fused row templates, specifically for vector-vector cbinds, where the
    right-hand-side input is also a matrix rather than just a column vector.
    
    Furthermore, this patch cleans up the recently refactored row and cell
    templates and the related recursive template expansion.
---
 .../apache/sysds/hops/codegen/SpoofCompiler.java   |  5 +++-
 .../org/apache/sysds/hops/codegen/cplan/CNode.java |  2 +-
 .../sysds/hops/codegen/cplan/CNodeBinary.java      |  8 +++---
 .../apache/sysds/hops/codegen/cplan/CNodeCell.java | 30 +++++-----------------
 .../sysds/hops/codegen/cplan/CNodeMultiAgg.java    |  2 +-
 .../hops/codegen/cplan/CNodeOuterProduct.java      |  2 +-
 .../apache/sysds/hops/codegen/cplan/CNodeRow.java  | 22 ++++++++--------
 .../sysds/hops/codegen/cplan/CodeTemplate.java     |  2 +-
 .../sysds/hops/codegen/cplan/cuda/Binary.java      | 10 ++++----
 .../sysds/hops/codegen/cplan/java/Binary.java      | 19 ++++++++------
 .../hops/codegen/cplan/java/Cellwise.java.template | 14 +++++-----
 .../hops/codegen/cplan/java/Rowwise.java.template  | 28 ++++++++++----------
 .../sysds/hops/codegen/template/TemplateRow.java   |  2 +-
 .../sysds/hops/codegen/template/TemplateUtils.java |  6 ++++-
 .../sysds/runtime/codegen/LibSpoofPrimitives.java  | 15 +++++++++++
 .../codegen/FederatedCellwiseTmplTest.java         |  6 ++---
 .../codegen/FederatedCellwiseTmplTest.dml          |  4 +--
 .../codegen/FederatedCellwiseTmplTestReference.dml | 13 ++++------
 18 files changed, 101 insertions(+), 89 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index f728b46..d3d638a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -135,7 +135,10 @@ public class SpoofCompiler {
        public enum GeneratorAPI {
                AUTO,
                JAVA,
-               CUDA
+               CUDA;
+               public boolean isJava() {
+                       return this == JAVA;
+               }
        }
 
        public enum IntegrationType {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
index 841cd1c..fd5c2e6 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
@@ -90,7 +90,7 @@ public abstract class CNode
                                return "a.cols()";
                        if(getVarname().startsWith("b"))
                                return getVarname()+".cols()";
-                       else                            
+                       else
                                return getVarname()+".length";
                }
                else {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 925e055..6d53e1e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -159,8 +159,11 @@ public class CNodeBinary extends CNode {
                boolean scalarInput = _inputs.get(0).getDataType().isScalar();
                boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
                        && _inputs.get(1).getDataType().isMatrix());
+               boolean vectorVector = _inputs.get(0).getDataType().isMatrix()
+                       && _inputs.get(1).getDataType().isMatrix();
                String var = createVarname();
-               String tmp = getLanguageTemplateClass(this, 
api).getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput);
+               String tmp = getLanguageTemplateClass(this, api)
+                       .getTemplate(_type, lsparseLhs, lsparseRhs, 
scalarVector, scalarInput, vectorVector);
 
                tmp = tmp.replace("%TMP%", var);
                
@@ -174,7 +177,6 @@ public class CNodeBinary extends CNode {
                        tmp = tmp.replace("%IN"+(j+1)+"%",
                                        varj.startsWith("a") ? (api == 
GeneratorAPI.JAVA ? varj : 
                                                (_inputs.get(j).getDataType() 
== DataType.MATRIX ? varj + ".vals(0)" : varj)) :
-//                                     varj.startsWith("b") ? (api == 
GeneratorAPI.JAVA ? varj + ".values(rix)" : varj + ".vals(0)") : varj);
                                                varj.startsWith("b") ? (api == 
GeneratorAPI.JAVA ? varj + ".values(rix)" : 
                                                                (_type == 
BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
                                                        
_inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? 
varj : varj + ".vals(0)") : varj);
@@ -186,7 +188,7 @@ public class CNodeBinary extends CNode {
                                        varj + ".pos(rix)" : "0" : "0");
                }
                //replace length information (e.g., after matrix mult)
-               if( _type == BinType.VECT_OUTERMULT_ADD ) {
+               if( _type == BinType.VECT_OUTERMULT_ADD || (_type == 
BinType.VECT_CBIND && vectorVector) ) {
                        for( int j=0; j<2; j++ )
                                tmp = tmp.replace("%LEN"+(j+1)+"%", 
_inputs.get(j).getVectorLength(api));
                }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index 1c67e3d..070fc9e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -124,38 +124,22 @@ public class CNodeCell extends CNodeTpl
 
                //generate dense/sparse bodies
                String tmpDense = _output.codegen(false, api);
-               // ToDo: workaround to fix name clash of cell and row template
+               // TODO: workaround to fix name clash of cell and row template
                if(api == GeneratorAPI.CUDA)
                        tmpDense = tmpDense.replace("a.vals(0)", "a");
                _output.resetGenerated();
                
-               String varName; 
-               if(getVarname() == null)
-//                     tmp = tmp.replace("%TMP%", createVarname());
-                       varName = createVarname();
-               else
-//                     tmp = tmp.replace("%TMP%", getVarname());
-                       varName = getVarname();
-               
-               if(api == GeneratorAPI.JAVA)
-                       tmp = tmp.replace("%TMP%", varName);
-               else
-                       tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
+               String varName = (getVarname() == null) ?
+                       createVarname() : getVarname();
+               tmp = tmp.replace(api.isJava() ? 
+                       "%TMP%" : "/*%TMP%*/SPOOF_OP_NAME", varName);
                
                if(tmpDense.contains("grix"))
                        tmp = tmp.replace("//%NEED_GRIX%", "\t\tuint32_t 
grix=_grix + rix;");
                else
                        tmp = tmp.replace("//%NEED_GRIX%", "");
-               
-//             if(tmpDense.contains("rix"))
-//                     tmp = tmp.replace("//%NEED_RIX%", "\t\tuint32_t rix = 
idx / A.cols();\n");
-//             else
-                       tmp = tmp.replace("//%NEED_RIX%", "");
-               
-//             if(tmpDense.contains("cix"))
-//                     tmp = tmp.replace("//%NEED_CIX%", "\t\tuint32_t cix = 
idx % A.cols();");
-//             else
-                       tmp = tmp.replace("//%NEED_CIX%", "");
+               tmp = tmp.replace("//%NEED_RIX%", "");
+               tmp = tmp.replace("//%NEED_CIX%", "");
                
                tmp = tmp.replace("%BODY_dense%", tmpDense);
                
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
index c14c2c7..0b4d625 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
@@ -55,7 +55,7 @@ public class CNodeMultiAgg extends CNodeTpl
        private static final String TEMPLATE_OUT_MIN   = "    c[%IX%] = 
Math.min(c[%IX%], %IN%);\n";
        private static final String TEMPLATE_OUT_MAX   = "    c[%IX%] = 
Math.max(c[%IX%], %IN%);\n";
        
-       private ArrayList<CNode> _outputs = null; 
+       private ArrayList<CNode> _outputs = null;
        private ArrayList<AggOp> _aggOps = null;
        private ArrayList<Hop> _roots = null;
        private boolean _sparseSafe = false;
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
index 955e9e2..d796045 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
@@ -48,7 +48,7 @@ public class CNodeOuterProduct extends CNodeTpl
                        + "  protected double genexecCellwise(double a, 
double[] a1, int a1i, double[] a2, int a2i, SideInput[] b, double[] scalars, 
int m, int n, int len, int rix, int cix) { \n"
                        + "%BODY_cellwise%"
                        + "    return %OUT_cellwise%;\n"
-                       + "  }\n"                       
+                       + "  }\n"
                        + "}\n";
        
        private OutProdType _type = null;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
index ecc1a4e..603fea9 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
@@ -96,14 +96,16 @@ public class CNodeRow extends CNodeTpl
                String tmpSparse = _output.codegen(true, api) + 
getOutputStatement(_output.getVarname());
                _output.resetGenerated();
                String varName = createVarname();
-               tmp = tmp.replace("//%TMP%", varName);
-               tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
-               tmp = tmp.replace("//%BODY_dense%", tmpDense);
-               tmp = tmp.replace("//%BODY_sparse%", tmpSparse);
+               tmp = tmp.replace(api.isJava()?"%TMP%":"//%TMP%", varName);
+               if( !api.isJava() )
+                       tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
+               String prefix = api.isJava()? "" : "//";
+               tmp = tmp.replace(prefix+"%BODY_dense%", tmpDense);
+               tmp = tmp.replace(prefix+"%BODY_sparse%", tmpSparse);
                
                //replace outputs 
-               tmp = api == GeneratorAPI.JAVA ? tmp.replace("%OUT%", "c") :
-                               tmp.replace("%OUT%", "c.vals(0)");
+               tmp = api.isJava() ? tmp.replace("%OUT%", "c") :
+                       tmp.replace("%OUT%", "c.vals(0)");
                tmp = tmp.replace("%POSOUT%", "0");
                
                //replace size information
@@ -132,14 +134,13 @@ public class CNodeRow extends CNodeTpl
                switch( _type ) {
                        case NO_AGG:
                                if(api == GeneratorAPI.CUDA)
-                                       return 
TEMPLATE_NOAGG_OUT_CUDA.replace("%IN%", varName + ".vals(0)") 
.replaceAll("%LEN%", _output.getVarname()+".length");
+                                       return 
TEMPLATE_NOAGG_OUT_CUDA.replace("%IN%", varName + 
".vals(0)").replaceAll("%LEN%", _output.getVarname()+".length");
                        case NO_AGG_B1:
                        case NO_AGG_CONST:
                                if(api == GeneratorAPI.JAVA)
-                                       return 
TEMPLATE_NOAGG_OUT.replace("%IN%", varName) .replace("%LEN%", 
_output.getVarname()+".length");
+                                       return 
TEMPLATE_NOAGG_OUT.replace("%IN%", varName).replace("%LEN%", 
_output.getVarname()+".length");
                                else
-//                                     return "";
-                                       return 
TEMPLATE_NOAGG_CONST_OUT_CUDA.replace("%IN%", varName + ".vals(0)") 
.replaceAll("%LEN%", _output.getVarname()+".length");
+                                       return 
TEMPLATE_NOAGG_CONST_OUT_CUDA.replace("%IN%", varName + 
".vals(0)").replaceAll("%LEN%", _output.getVarname()+".length");
                        case FULL_AGG:
                                if(api == GeneratorAPI.JAVA)
                                        return 
TEMPLATE_FULLAGG_OUT.replace("%IN%", varName);
@@ -237,5 +238,4 @@ public class CNodeRow extends CNodeTpl
        
        private native int compile_nvrtc(long context, String name, String src, 
int type, long constDim2, int numVectors, 
                        boolean TB1);
-
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
index 34f0b66..ce30fdb 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
@@ -33,7 +33,7 @@ public abstract class CodeTemplate {
        }
        
        public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector,
-               boolean scalarInput) {
+               boolean scalarInput, boolean vectorVector) {
                throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
        }
        
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
index cd02f5a..7d9655f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
@@ -24,12 +24,12 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
 
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
 
-public class Binary extends CodeTemplate {
-       
+public class Binary extends CodeTemplate
+{
        @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector,
-                                                         boolean scalarInput) {
-
+       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs,
+               boolean scalarVector, boolean scalarInput, boolean vectorVector)
+       {
                if(isSinglePrecision()) {
                        switch(type) {
                                case DOT_PRODUCT:
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index 1453b44..ecb7878 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -24,9 +24,9 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
 
 public class Binary extends CodeTemplate {
 
-       public String getTemplate(BinType type, boolean sparseLhs, boolean 
sparseRhs, boolean scalarVector,
-                                                         boolean scalarInput) {
-
+       public String getTemplate(BinType type, boolean sparseLhs, boolean 
sparseRhs,
+               boolean scalarVector, boolean scalarInput, boolean vectorVector)
+       {
                switch (type) {
                        case DOT_PRODUCT:
                                return sparseLhs ? "    double %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
@@ -58,7 +58,7 @@ public class Binary extends CodeTemplate {
                                String vectName = type.getVectorPrimitiveName();
                                if( scalarVector )
                                        return sparseLhs ? "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, 
%POSOUT%, alen, %LEN%);\n" :
-                                                       "       
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, 
%LEN%);\n";
+                                                       "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, 
%LEN%);\n";
                                else
                                        return sparseLhs ? "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, 
%POSOUT%, alen, %LEN%);\n" :
                                                        "    
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%, 
%LEN%);\n";
@@ -92,10 +92,14 @@ public class Binary extends CodeTemplate {
                        case VECT_CBIND:
                                if( scalarInput )
                                        return  "    double[] %TMP% = 
LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
-                               else
+                               else if( !vectorVector )
                                        return sparseLhs ?
                                                        "    double[] %TMP% = 
LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, 
%LEN%);\n" :
                                                        "    double[] %TMP% = 
LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+                               else //vect/vect
+                                       return sparseLhs ?
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, 
%LEN1%, %LEN2%);\n" :
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%, 
%LEN2%);\n";
 
                                //vector-vector operations
                        case VECT_MULT:
@@ -118,8 +122,8 @@ public class Binary extends CodeTemplate {
                                return sparseLhs ?
                                                "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, 
alen, %LEN%);\n" :
                                                sparseRhs ?
-                                                               "    double[] 
%TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, 
%POS2%, alen, %LEN%);\n" :
-                                                               "    double[] 
%TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, 
%LEN%);\n";
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%, 
alen, %LEN%);\n" :
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%, 
%LEN%);\n";
                        }
 
                        //scalar-scalar operations
@@ -174,5 +178,4 @@ public class Binary extends CodeTemplate {
                                throw new RuntimeException("Invalid binary 
type: "+this.toString());
                }
        }
-
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
index 84183ae..3f7c6fe 100644
--- 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
+++ 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
@@ -25,11 +25,13 @@ import 
org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
 import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;
 import org.apache.commons.math3.util.FastMath;
 
-/* This is a SPOOF code generation template */
 public final class %TMP% extends SpoofCellwise {
-       public %TMP%() { super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, 
%AGG_OP_NAME%); }
+  public %TMP%() {
+    super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);
+  }
 
-       protected double genexec(double a, SideInput[] b, double[] scalars, int 
m, int n, long grix, int rix, int cix) {
-%BODY_dense%   return %OUT%;
-       }
-};
+  protected double genexec(double a, SideInput[] b, double[] scalars, int m, 
int n, long grix, int rix, int cix) {
+%BODY_dense%
+    return %OUT%;
+  }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
index a2c871e..89f5015 100644
--- 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
+++ 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
@@ -24,18 +24,20 @@ import org.apache.sysds.runtime.codegen.SpoofRowwise;
 import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
 import org.apache.commons.math3.util.FastMath;
 
-/* This is a SPOOF code generation template */
-public final class /*%TMP%*/SPOOF_OP_NAME extends SpoofRowwise {
-       public /*%TMP%*/SPOOF_OP_NAME() {  super(RowType.%TYPE%, %CONST_DIM2%, 
%TB1%, %VECT_MEM%); }
+public final class %TMP% extends SpoofRowwise {
+  public %TMP%() {
+    super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);
+  }
 
-       protected void genexec(double[] a, int ai, SideInput[] b, double[] 
scalars, double[] c, 
-                       int ci, int len, long grix, int rix) {                  
 
-//%BODY_dense%                 //System.out.println("TMP3=" + TMP3);
-               //for(int i=1; i<=TMP2.length; i++)
-          //  System.out.println(" " + TMP2[i]);
-       }
+  protected void genexec(double[] a, int ai, SideInput[] b,
+    double[] scalars, double[] c, int ci, int len, long grix, int rix)
+  {
+%BODY_dense%
+  }
 
-       protected void genexec(double[] avals, int[] aix, int ai, SideInput[] 
b, double[] scalars, double[] c, int ci, 
-                       int alen, int len, long grix, int rix) {                
         
-//%BODY_sparse%}
-};
+  protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b,
+    double[] scalars, double[] c, int ci, int alen, int len, long grix, int 
rix)
+  {
+%BODY_sparse%
+  }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index 866af1a..1962a65 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -415,7 +415,7 @@ public class TemplateRow extends TemplateBase
                        }
                        else {
                                cdata2 = 
tmp.get(hop.getInput().get(1).getHopID());
-                               cdata2 = 
TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
+                               cdata2 = 
TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1), true);
                        }
                        out = new CNodeBinary(cdata1, cdata2, 
BinType.VECT_CBIND);
                        if( cdata1 instanceof CNodeData && 
!inHops2.containsKey("X") )
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index e3920d8..f61305f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -99,13 +99,17 @@ public class TemplateUtils
        }
        
        public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
+               return wrapLookupIfNecessary(node, hop, false);
+       }
+       
+       public static CNode wrapLookupIfNecessary(CNode node, Hop hop, boolean 
rowTpl) {
                CNode ret = node;
                if( isColVector(node) )
                        ret = new CNodeUnary(node, UnaryType.LOOKUP_R);
                else if( isRowVector(node) )
                        ret = new CNodeUnary(node, UnaryType.LOOKUP_C);
                else if( node instanceof CNodeData && 
hop.getDataType().isMatrix() )
-                       ret = new CNodeUnary(node, UnaryType.LOOKUP_RC);
+                       ret = rowTpl ? node : new CNodeUnary(node, 
UnaryType.LOOKUP_RC);
                return ret;
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index c148718..905b392 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -265,6 +265,21 @@ public class LibSpoofPrimitives
                c[len] = b;
                return c;
        }
+       
+       public static double[] vectCbindWrite(double[] a, double[] b, int ai, 
int bi, int alen, int blen) {
+               double[] c = allocVector(alen+blen, false);
+               System.arraycopy(a, ai, c, 0, alen);
+               System.arraycopy(b, bi, c, alen, blen);
+               return c;
+       }
+       
+       public static double[] vectCbindWrite(double[] a, double[] b, int[] 
aix, int ai, int bi, int alen, int alen2, int blen) {
+               double[] c = allocVector(alen2+blen, true);
+               for( int j = ai; j < ai+alen; j++ )
+                       c[aix[j]] = a[j];
+               System.arraycopy(b, bi, c, alen2, blen);
+               return c;
+       }
 
        // custom vector sums, mins, maxs
        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
index 653e622..17f7426 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -49,7 +49,7 @@ public class FederatedCellwiseTmplTest extends 
AutomatedTestBase
        private final static String TEST_CONF = "SystemDS-config-codegen.xml";
 
        private final static String OUTPUT_NAME = "Z";
-       private final static double TOLERANCE = 0;
+       private final static double TOLERANCE = 1e-8;
        private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
@@ -102,7 +102,7 @@ public class FederatedCellwiseTmplTest extends 
AutomatedTestBase
                        {14, 1100, 200, 1, false},
 
                        // not working because of fused sequence operation
-                       //      (wrong grix inside genexec call of fed worker)
+                       // (wrong grix inside genexec call of fed worker)
                        // {7, 1000, 1, 1, true},
 
                        // not creating a FedSpoof instruction
@@ -186,7 +186,7 @@ public class FederatedCellwiseTmplTest extends 
AutomatedTestBase
                HashMap<CellIndex, Double> refResults  = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
                HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
                TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, 
"Fed", "Ref");
-
+               
                TestUtils.shutdownThreads(thread1, thread2);
 
                // check for federated operations
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml 
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
index 68d48fb..3f91385 100644
--- a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
+++ b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
@@ -87,7 +87,7 @@ else if(test_num == 9) {
   Y = matrix(seq(6, 1005), 500, 2);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(U^2))
+  Z = as.matrix(sum(log(U)))
 }
 else if(test_num == 10) {
   # X ... 500x2 matrix
@@ -106,7 +106,7 @@ else if(test_num == 12) {
   Y = matrix(seq(6, 1005), 2, 500);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(U^2))
+  Z = as.matrix(sum(sqrt(U)))
 }
 else if(test_num == 13) {
   # X ... 2x4 matrix
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
 
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
index 1826fb2..2c13e6a 100644
--- 
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
+++ 
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
@@ -27,10 +27,6 @@ if(row_part) {
 }
 else {
   X = cbind(read($in_X1), read($in_X2));
-
-  # TODO: remove as soon as jira ticket SYSTEMDS-2888 has been solved
-  # needed to seperate the cbind from the code generation
-  while(FALSE) {}
 }
 
 if(test_num == 1) {
@@ -89,11 +85,11 @@ else if(test_num == 9) {
   Y = matrix(seq(6, 1005), 500, 2);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(U^2))
+  Z = as.matrix(sum(log(U)))
 }
 else if(test_num == 10) {
+  while(FALSE){} #TODO
   # X ... 500x2 matrix
-
   Y = (0 / (X - 500))+1;
   Z = replace(target=Y, pattern=0/0, replacement=7);
 }
@@ -108,7 +104,7 @@ else if(test_num == 12) {
   Y = matrix(seq(6, 1005), 2, 500);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(U^2))
+  Z = as.matrix(sum(sqrt(U)))
 }
 else if(test_num == 13) {
   # X ... 2x4 matrix
@@ -117,8 +113,9 @@ else if(test_num == 13) {
   Z = 10 + floor(round(abs((X + w) * v)));
 }
 else if(test_num == 14) {
-  # X ... 1100x200 matrix
+  while(FALSE){} #TODO
 
+  # X ... 1100x200 matrix  
   Z = colMins(2 * log(X));
 }
 

Reply via email to