This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 1ec292a [SYSTEMDS-2888] Fix incomplete cbind support in codegen row
templates
1ec292a is described below
commit 1ec292a932c6e732bbac835a81cdb59371002114
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Apr 10 23:17:01 2021 +0200
[SYSTEMDS-2888] Fix incomplete cbind support in codegen row templates
This patch extends the existing cbind handling (compiler and runtime) in
fused row templates, specifically for vector-vector cbinds where the
right-hand-side input is also a matrix, not just a column vector.
Furthermore, we also clean up the recently refactored row and cell
templates and the related recursive template expansion.
---
.../apache/sysds/hops/codegen/SpoofCompiler.java | 5 +++-
.../org/apache/sysds/hops/codegen/cplan/CNode.java | 2 +-
.../sysds/hops/codegen/cplan/CNodeBinary.java | 8 +++---
.../apache/sysds/hops/codegen/cplan/CNodeCell.java | 30 +++++-----------------
.../sysds/hops/codegen/cplan/CNodeMultiAgg.java | 2 +-
.../hops/codegen/cplan/CNodeOuterProduct.java | 2 +-
.../apache/sysds/hops/codegen/cplan/CNodeRow.java | 22 ++++++++--------
.../sysds/hops/codegen/cplan/CodeTemplate.java | 2 +-
.../sysds/hops/codegen/cplan/cuda/Binary.java | 10 ++++----
.../sysds/hops/codegen/cplan/java/Binary.java | 19 ++++++++------
.../hops/codegen/cplan/java/Cellwise.java.template | 14 +++++-----
.../hops/codegen/cplan/java/Rowwise.java.template | 28 ++++++++++----------
.../sysds/hops/codegen/template/TemplateRow.java | 2 +-
.../sysds/hops/codegen/template/TemplateUtils.java | 6 ++++-
.../sysds/runtime/codegen/LibSpoofPrimitives.java | 15 +++++++++++
.../codegen/FederatedCellwiseTmplTest.java | 6 ++---
.../codegen/FederatedCellwiseTmplTest.dml | 4 +--
.../codegen/FederatedCellwiseTmplTestReference.dml | 13 ++++------
18 files changed, 101 insertions(+), 89 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index f728b46..d3d638a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -135,7 +135,10 @@ public class SpoofCompiler {
public enum GeneratorAPI {
AUTO,
JAVA,
- CUDA
+ CUDA;
+ public boolean isJava() {
+ return this == JAVA;
+ }
}
public enum IntegrationType {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
index 841cd1c..fd5c2e6 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
@@ -90,7 +90,7 @@ public abstract class CNode
return "a.cols()";
if(getVarname().startsWith("b"))
return getVarname()+".cols()";
- else
+ else
return getVarname()+".length";
}
else {
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 925e055..6d53e1e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -159,8 +159,11 @@ public class CNodeBinary extends CNode {
boolean scalarInput = _inputs.get(0).getDataType().isScalar();
boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
&& _inputs.get(1).getDataType().isMatrix());
+ boolean vectorVector = _inputs.get(0).getDataType().isMatrix()
+ && _inputs.get(1).getDataType().isMatrix();
String var = createVarname();
- String tmp = getLanguageTemplateClass(this,
api).getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput);
+ String tmp = getLanguageTemplateClass(this, api)
+ .getTemplate(_type, lsparseLhs, lsparseRhs,
scalarVector, scalarInput, vectorVector);
tmp = tmp.replace("%TMP%", var);
@@ -174,7 +177,6 @@ public class CNodeBinary extends CNode {
tmp = tmp.replace("%IN"+(j+1)+"%",
varj.startsWith("a") ? (api ==
GeneratorAPI.JAVA ? varj :
(_inputs.get(j).getDataType()
== DataType.MATRIX ? varj + ".vals(0)" : varj)) :
-// varj.startsWith("b") ? (api ==
GeneratorAPI.JAVA ? varj + ".values(rix)" : varj + ".vals(0)") : varj);
varj.startsWith("b") ? (api ==
GeneratorAPI.JAVA ? varj + ".values(rix)" :
(_type ==
BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
_inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ?
varj : varj + ".vals(0)") : varj);
@@ -186,7 +188,7 @@ public class CNodeBinary extends CNode {
varj + ".pos(rix)" : "0" : "0");
}
//replace length information (e.g., after matrix mult)
- if( _type == BinType.VECT_OUTERMULT_ADD ) {
+ if( _type == BinType.VECT_OUTERMULT_ADD || (_type ==
BinType.VECT_CBIND && vectorVector) ) {
for( int j=0; j<2; j++ )
tmp = tmp.replace("%LEN"+(j+1)+"%",
_inputs.get(j).getVectorLength(api));
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index 1c67e3d..070fc9e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -124,38 +124,22 @@ public class CNodeCell extends CNodeTpl
//generate dense/sparse bodies
String tmpDense = _output.codegen(false, api);
- // ToDo: workaround to fix name clash of cell and row template
+ // TODO: workaround to fix name clash of cell and row template
if(api == GeneratorAPI.CUDA)
tmpDense = tmpDense.replace("a.vals(0)", "a");
_output.resetGenerated();
- String varName;
- if(getVarname() == null)
-// tmp = tmp.replace("%TMP%", createVarname());
- varName = createVarname();
- else
-// tmp = tmp.replace("%TMP%", getVarname());
- varName = getVarname();
-
- if(api == GeneratorAPI.JAVA)
- tmp = tmp.replace("%TMP%", varName);
- else
- tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
+ String varName = (getVarname() == null) ?
+ createVarname() : getVarname();
+ tmp = tmp.replace(api.isJava() ?
+ "%TMP%" : "/*%TMP%*/SPOOF_OP_NAME", varName);
if(tmpDense.contains("grix"))
tmp = tmp.replace("//%NEED_GRIX%", "\t\tuint32_t
grix=_grix + rix;");
else
tmp = tmp.replace("//%NEED_GRIX%", "");
-
-// if(tmpDense.contains("rix"))
-// tmp = tmp.replace("//%NEED_RIX%", "\t\tuint32_t rix =
idx / A.cols();\n");
-// else
- tmp = tmp.replace("//%NEED_RIX%", "");
-
-// if(tmpDense.contains("cix"))
-// tmp = tmp.replace("//%NEED_CIX%", "\t\tuint32_t cix =
idx % A.cols();");
-// else
- tmp = tmp.replace("//%NEED_CIX%", "");
+ tmp = tmp.replace("//%NEED_RIX%", "");
+ tmp = tmp.replace("//%NEED_CIX%", "");
tmp = tmp.replace("%BODY_dense%", tmpDense);
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
index c14c2c7..0b4d625 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeMultiAgg.java
@@ -55,7 +55,7 @@ public class CNodeMultiAgg extends CNodeTpl
private static final String TEMPLATE_OUT_MIN = " c[%IX%] =
Math.min(c[%IX%], %IN%);\n";
private static final String TEMPLATE_OUT_MAX = " c[%IX%] =
Math.max(c[%IX%], %IN%);\n";
- private ArrayList<CNode> _outputs = null;
+ private ArrayList<CNode> _outputs = null;
private ArrayList<AggOp> _aggOps = null;
private ArrayList<Hop> _roots = null;
private boolean _sparseSafe = false;
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
index 955e9e2..d796045 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeOuterProduct.java
@@ -48,7 +48,7 @@ public class CNodeOuterProduct extends CNodeTpl
+ " protected double genexecCellwise(double a,
double[] a1, int a1i, double[] a2, int a2i, SideInput[] b, double[] scalars,
int m, int n, int len, int rix, int cix) { \n"
+ "%BODY_cellwise%"
+ " return %OUT_cellwise%;\n"
- + " }\n"
+ + " }\n"
+ "}\n";
private OutProdType _type = null;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
index ecc1a4e..603fea9 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeRow.java
@@ -96,14 +96,16 @@ public class CNodeRow extends CNodeTpl
String tmpSparse = _output.codegen(true, api) +
getOutputStatement(_output.getVarname());
_output.resetGenerated();
String varName = createVarname();
- tmp = tmp.replace("//%TMP%", varName);
- tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
- tmp = tmp.replace("//%BODY_dense%", tmpDense);
- tmp = tmp.replace("//%BODY_sparse%", tmpSparse);
+ tmp = tmp.replace(api.isJava()?"%TMP%":"//%TMP%", varName);
+ if( !api.isJava() )
+ tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
+ String prefix = api.isJava()? "" : "//";
+ tmp = tmp.replace(prefix+"%BODY_dense%", tmpDense);
+ tmp = tmp.replace(prefix+"%BODY_sparse%", tmpSparse);
//replace outputs
- tmp = api == GeneratorAPI.JAVA ? tmp.replace("%OUT%", "c") :
- tmp.replace("%OUT%", "c.vals(0)");
+ tmp = api.isJava() ? tmp.replace("%OUT%", "c") :
+ tmp.replace("%OUT%", "c.vals(0)");
tmp = tmp.replace("%POSOUT%", "0");
//replace size information
@@ -132,14 +134,13 @@ public class CNodeRow extends CNodeTpl
switch( _type ) {
case NO_AGG:
if(api == GeneratorAPI.CUDA)
- return
TEMPLATE_NOAGG_OUT_CUDA.replace("%IN%", varName + ".vals(0)")
.replaceAll("%LEN%", _output.getVarname()+".length");
+ return
TEMPLATE_NOAGG_OUT_CUDA.replace("%IN%", varName +
".vals(0)").replaceAll("%LEN%", _output.getVarname()+".length");
case NO_AGG_B1:
case NO_AGG_CONST:
if(api == GeneratorAPI.JAVA)
- return
TEMPLATE_NOAGG_OUT.replace("%IN%", varName) .replace("%LEN%",
_output.getVarname()+".length");
+ return
TEMPLATE_NOAGG_OUT.replace("%IN%", varName).replace("%LEN%",
_output.getVarname()+".length");
else
-// return "";
- return
TEMPLATE_NOAGG_CONST_OUT_CUDA.replace("%IN%", varName + ".vals(0)")
.replaceAll("%LEN%", _output.getVarname()+".length");
+ return
TEMPLATE_NOAGG_CONST_OUT_CUDA.replace("%IN%", varName +
".vals(0)").replaceAll("%LEN%", _output.getVarname()+".length");
case FULL_AGG:
if(api == GeneratorAPI.JAVA)
return
TEMPLATE_FULLAGG_OUT.replace("%IN%", varName);
@@ -237,5 +238,4 @@ public class CNodeRow extends CNodeTpl
private native int compile_nvrtc(long context, String name, String src,
int type, long constDim2, int numVectors,
boolean TB1);
-
}
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
index 34f0b66..ce30fdb 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
@@ -33,7 +33,7 @@ public abstract class CodeTemplate {
}
public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs,
boolean sparseRhs, boolean scalarVector,
- boolean scalarInput) {
+ boolean scalarInput, boolean vectorVector) {
throw new RuntimeException("Calling wrong getTemplate method on
" + getClass().getCanonicalName());
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
index cd02f5a..7d9655f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
@@ -24,12 +24,12 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
import static
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
-public class Binary extends CodeTemplate {
-
+public class Binary extends CodeTemplate
+{
@Override
- public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs,
boolean sparseRhs, boolean scalarVector,
- boolean scalarInput) {
-
+ public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs,
boolean sparseRhs,
+ boolean scalarVector, boolean scalarInput, boolean vectorVector)
+ {
if(isSinglePrecision()) {
switch(type) {
case DOT_PRODUCT:
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index 1453b44..ecb7878 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -24,9 +24,9 @@ import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
public class Binary extends CodeTemplate {
- public String getTemplate(BinType type, boolean sparseLhs, boolean
sparseRhs, boolean scalarVector,
- boolean scalarInput) {
-
+ public String getTemplate(BinType type, boolean sparseLhs, boolean
sparseRhs,
+ boolean scalarVector, boolean scalarInput, boolean vectorVector)
+ {
switch (type) {
case DOT_PRODUCT:
return sparseLhs ? " double %TMP% =
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
@@ -58,7 +58,7 @@ public class Binary extends CodeTemplate {
String vectName = type.getVectorPrimitiveName();
if( scalarVector )
return sparseLhs ? "
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%,
%POSOUT%, alen, %LEN%);\n" :
- "
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%,
%LEN%);\n";
+ "
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%,
%LEN%);\n";
else
return sparseLhs ? "
LibSpoofPrimitives.vect"+vectName+"Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%,
%POSOUT%, alen, %LEN%);\n" :
"
LibSpoofPrimitives.vect"+vectName+"Add(%IN1%, %IN2%, %OUT%, %POS1%, %POSOUT%,
%LEN%);\n";
@@ -92,10 +92,14 @@ public class Binary extends CodeTemplate {
case VECT_CBIND:
if( scalarInput )
return " double[] %TMP% =
LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%);\n";
- else
+ else if( !vectorVector )
return sparseLhs ?
" double[] %TMP% =
LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen,
%LEN%);\n" :
" double[] %TMP% =
LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+ else //vect/vect
+ return sparseLhs ?
+ " double[] %TMP% =
LibSpoofPrimitives.vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen,
%LEN1%, %LEN2%);\n" :
+ " double[] %TMP% =
LibSpoofPrimitives.vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%,
%LEN2%);\n";
//vector-vector operations
case VECT_MULT:
@@ -118,8 +122,8 @@ public class Binary extends CodeTemplate {
return sparseLhs ?
" double[] %TMP% =
LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%,
alen, %LEN%);\n" :
sparseRhs ?
- " double[]
%TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%,
%POS2%, alen, %LEN%);\n" :
- " double[]
%TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%,
%LEN%);\n";
+ " double[] %TMP% =
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2v%, %POS1%, %IN2i%, %POS2%,
alen, %LEN%);\n" :
+ " double[] %TMP% =
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %POS2%,
%LEN%);\n";
}
//scalar-scalar operations
@@ -174,5 +178,4 @@ public class Binary extends CodeTemplate {
throw new RuntimeException("Invalid binary
type: "+this.toString());
}
}
-
}
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
index 84183ae..3f7c6fe 100644
---
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
+++
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
@@ -25,11 +25,13 @@ import
org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;
import org.apache.commons.math3.util.FastMath;
-/* This is a SPOOF code generation template */
public final class %TMP% extends SpoofCellwise {
- public %TMP%() { super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%,
%AGG_OP_NAME%); }
+ public %TMP%() {
+ super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, %AGG_OP_NAME%);
+ }
- protected double genexec(double a, SideInput[] b, double[] scalars, int
m, int n, long grix, int rix, int cix) {
-%BODY_dense% return %OUT%;
- }
-};
+ protected double genexec(double a, SideInput[] b, double[] scalars, int m,
int n, long grix, int rix, int cix) {
+%BODY_dense%
+ return %OUT%;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
index a2c871e..89f5015 100644
---
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
+++
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template
@@ -24,18 +24,20 @@ import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.commons.math3.util.FastMath;
-/* This is a SPOOF code generation template */
-public final class /*%TMP%*/SPOOF_OP_NAME extends SpoofRowwise {
- public /*%TMP%*/SPOOF_OP_NAME() { super(RowType.%TYPE%, %CONST_DIM2%,
%TB1%, %VECT_MEM%); }
+public final class %TMP% extends SpoofRowwise {
+ public %TMP%() {
+ super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);
+ }
- protected void genexec(double[] a, int ai, SideInput[] b, double[]
scalars, double[] c,
- int ci, int len, long grix, int rix) {
-//%BODY_dense% //System.out.println("TMP3=" + TMP3);
- //for(int i=1; i<=TMP2.length; i++)
- // System.out.println(" " + TMP2[i]);
- }
+ protected void genexec(double[] a, int ai, SideInput[] b,
+ double[] scalars, double[] c, int ci, int len, long grix, int rix)
+ {
+%BODY_dense%
+ }
- protected void genexec(double[] avals, int[] aix, int ai, SideInput[]
b, double[] scalars, double[] c, int ci,
- int alen, int len, long grix, int rix) {
-//%BODY_sparse%}
-};
+ protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b,
+ double[] scalars, double[] c, int ci, int alen, int len, long grix, int
rix)
+ {
+%BODY_sparse%
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index 866af1a..1962a65 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -415,7 +415,7 @@ public class TemplateRow extends TemplateBase
}
else {
cdata2 =
tmp.get(hop.getInput().get(1).getHopID());
- cdata2 =
TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
+ cdata2 =
TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1), true);
}
out = new CNodeBinary(cdata1, cdata2,
BinType.VECT_CBIND);
if( cdata1 instanceof CNodeData &&
!inHops2.containsKey("X") )
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index e3920d8..f61305f 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -99,13 +99,17 @@ public class TemplateUtils
}
public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
+ return wrapLookupIfNecessary(node, hop, false);
+ }
+
+ public static CNode wrapLookupIfNecessary(CNode node, Hop hop, boolean
rowTpl) {
CNode ret = node;
if( isColVector(node) )
ret = new CNodeUnary(node, UnaryType.LOOKUP_R);
else if( isRowVector(node) )
ret = new CNodeUnary(node, UnaryType.LOOKUP_C);
else if( node instanceof CNodeData &&
hop.getDataType().isMatrix() )
- ret = new CNodeUnary(node, UnaryType.LOOKUP_RC);
+ ret = rowTpl ? node : new CNodeUnary(node,
UnaryType.LOOKUP_RC);
return ret;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index c148718..905b392 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -265,6 +265,21 @@ public class LibSpoofPrimitives
c[len] = b;
return c;
}
+
+ public static double[] vectCbindWrite(double[] a, double[] b, int ai,
int bi, int alen, int blen) {
+ double[] c = allocVector(alen+blen, false);
+ System.arraycopy(a, ai, c, 0, alen);
+ System.arraycopy(b, bi, c, alen, blen);
+ return c;
+ }
+
+ public static double[] vectCbindWrite(double[] a, double[] b, int[]
aix, int ai, int bi, int alen, int alen2, int blen) {
+ double[] c = allocVector(alen2+blen, true);
+ for( int j = ai; j < ai+alen; j++ )
+ c[aix[j]] = a[j];
+ System.arraycopy(b, bi, c, alen2, blen);
+ return c;
+ }
// custom vector sums, mins, maxs
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
index 653e622..17f7426 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -49,7 +49,7 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
private final static String OUTPUT_NAME = "Z";
- private final static double TOLERANCE = 0;
+ private final static double TOLERANCE = 1e-8;
private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
@@ -102,7 +102,7 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
{14, 1100, 200, 1, false},
// not working because of fused sequence operation
- // (wrong grix inside genexec call of fed worker)
+ // (wrong grix inside genexec call of fed worker)
// {7, 1000, 1, 1, true},
// not creating a FedSpoof instruction
@@ -186,7 +186,7 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
-
+
TestUtils.shutdownThreads(thread1, thread2);
// check for federated operations
diff --git
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
index 68d48fb..3f91385 100644
--- a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
+++ b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
@@ -87,7 +87,7 @@ else if(test_num == 9) {
Y = matrix(seq(6, 1005), 500, 2);
U = X + 7 * Y;
- Z = as.matrix(sum(U^2))
+ Z = as.matrix(sum(log(U)))
}
else if(test_num == 10) {
# X ... 500x2 matrix
@@ -106,7 +106,7 @@ else if(test_num == 12) {
Y = matrix(seq(6, 1005), 2, 500);
U = X + 7 * Y;
- Z = as.matrix(sum(U^2))
+ Z = as.matrix(sum(sqrt(U)))
}
else if(test_num == 13) {
# X ... 2x4 matrix
diff --git
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
index 1826fb2..2c13e6a 100644
---
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
+++
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
@@ -27,10 +27,6 @@ if(row_part) {
}
else {
X = cbind(read($in_X1), read($in_X2));
-
- # TODO: remove as soon as jira ticket SYSTEMDS-2888 has been solved
- # needed to seperate the cbind from the code generation
- while(FALSE) {}
}
if(test_num == 1) {
@@ -89,11 +85,11 @@ else if(test_num == 9) {
Y = matrix(seq(6, 1005), 500, 2);
U = X + 7 * Y;
- Z = as.matrix(sum(U^2))
+ Z = as.matrix(sum(log(U)))
}
else if(test_num == 10) {
+ while(FALSE){} #TODO
# X ... 500x2 matrix
-
Y = (0 / (X - 500))+1;
Z = replace(target=Y, pattern=0/0, replacement=7);
}
@@ -108,7 +104,7 @@ else if(test_num == 12) {
Y = matrix(seq(6, 1005), 2, 500);
U = X + 7 * Y;
- Z = as.matrix(sum(U^2))
+ Z = as.matrix(sum(sqrt(U)))
}
else if(test_num == 13) {
# X ... 2x4 matrix
@@ -117,8 +113,9 @@ else if(test_num == 13) {
Z = 10 + floor(round(abs((X + w) * v)));
}
else if(test_num == 14) {
- # X ... 1100x200 matrix
+ while(FALSE){} #TODO
+ # X ... 1100x200 matrix
Z = colMins(2 * log(X));
}