This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 1a24d38 [SYSTEMDS-2922] Federated fused codegen operators (cellwise)
1a24d38 is described below
commit 1a24d3869e8c8770d8ed4445452aa465589b7e4f
Author: ywcb00 <[email protected]>
AuthorDate: Sun Apr 4 16:17:59 2021 +0200
[SYSTEMDS-2922] Federated fused codegen operators (cellwise)
Closes #1214.
---
.../apache/sysds/runtime/codegen/CodegenUtils.java | 116 +++++-----
.../controlprogram/federated/FederatedRequest.java | 46 ++--
.../controlprogram/federated/FederationMap.java | 34 ++-
.../instructions/cp/SpoofCPInstruction.java | 39 ++--
.../fed/BinaryMatrixMatrixFEDInstruction.java | 20 +-
.../runtime/instructions/fed/FEDInstruction.java | 11 +-
.../instructions/fed/FEDInstructionUtils.java | 14 ++
.../instructions/fed/SpoofFEDInstruction.java | 239 +++++++++++++++++++++
.../instructions/spark/SpoofSPInstruction.java | 233 ++++++++++----------
.../java/org/apache/sysds/utils/Statistics.java | 1 -
.../codegen/FederatedCellwiseTmplTest.java | 211 ++++++++++++++++++
.../codegen/FederatedCellwiseTmplTest.dml | 123 +++++++++++
.../codegen/FederatedCellwiseTmplTestReference.dml | 125 +++++++++++
.../federated/codegen/SystemDS-config-codegen.xml | 32 +++
14 files changed, 1013 insertions(+), 231 deletions(-)
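
For context, a SPOOF cellwise template applies a generated scalar expression to every cell of its main input and optionally aggregates the result (no/row/column/full aggregation). This is what makes the operator easy to federate: each worker can evaluate the same generated class on its local partition, and the coordinator only merges partial results. The following standalone sketch illustrates the idea for a full aggregation; the class and method names are illustrative only and not part of SystemDS.

// Simplified, standalone sketch (not SystemDS code): a cellwise operator with
// full aggregation applied per worker partition, with partials merged by sum.
import java.util.function.DoubleUnaryOperator;

public class CellwiseFullAggSketch {
	// full aggregation (sum) of a cellwise expression over one dense partition
	static double fullAggSum(double[][] partition, DoubleUnaryOperator cellFn) {
		double sum = 0;
		for(double[] row : partition)
			for(double v : row)
				sum += cellFn.applyAsDouble(v);
		return sum;
	}

	public static void main(String[] args) {
		DoubleUnaryOperator fn = v -> v * v + 1;   // e.g., sum(X^2 + 1)
		double[][] workerA = {{1, 2}, {3, 4}};     // row partition at worker A
		double[][] workerB = {{5, 6}};             // row partition at worker B
		double merged = fullAggSum(workerA, fn) + fullAggSum(workerB, fn);
		System.out.println("merged full-agg result: " + merged); // 97.0
	}
}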
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java b/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
index fb801ab..a390aa6 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -54,13 +54,13 @@ import java.util.Map.Entry;
import java.util.Scanner;
import java.util.concurrent.ConcurrentHashMap;
-public class CodegenUtils
+public class CodegenUtils
{
private static final Log LOG =
LogFactory.getLog(CodegenUtils.class.getName());
-
- //cache to reuse compiled and loaded classes
+
+ //cache to reuse compiled and loaded classes
private static ConcurrentHashMap<String, Class<?>> _cache = new
ConcurrentHashMap<>();
-
+
//janino-specific map of source code transfer/recompile on-demand
private static ConcurrentHashMap<String, String> _src = new
ConcurrentHashMap<>();
@@ -69,36 +69,36 @@ public class CodegenUtils
//javac-specific working directory for src/class files
private static String _workingDir = null;
-
+
public static Class<?> compileClass(String name, String src) {
//reuse existing compiled class
Class<?> ret = _cache.get(name);
- if( ret != null )
+ if( ret != null )
return ret;
-
+
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-
+
//compile java source w/ specific compiler
if( SpoofCompiler.JAVA_COMPILER == CompilerType.JANINO )
ret = compileClassJanino(name, src);
else
ret = compileClassJavac(name, src);
-
+
//keep compiled class for reuse
_cache.put(name, ret);
-
+
if( DMLScript.STATISTICS ) {
Statistics.incrementCodegenClassCompile();
Statistics.incrementCodegenClassCompileTime(System.nanoTime()-t0);
}
-
+
return ret;
}
-
+
public static Class<?> getClass(String name) {
return getClass(name, null);
}
-
+
	public synchronized static Class<?> getClassSync(String name, byte[] classBytes) {
		//In order to avoid anomalies of concurrently compiling and loading the same
		//class with the same name multiple times in spark executors, this indirection
@@ -108,13 +108,13 @@ public class CodegenUtils
		//multiple times which causes unnecessary JIT compilation overhead.
		return getClass(name, classBytes);
	}
-
+
public static Class<?> getClass(String name, byte[] classBytes) {
//reuse existing compiled class
Class<?> ret = _cache.get(name);
- if( ret != null )
+ if( ret != null )
return ret;
-
+
//get class in a compiler-specific manner
if( SpoofCompiler.JAVA_COMPILER == CompilerType.JANINO )
ret = compileClassJanino(name, new String(classBytes));
@@ -125,7 +125,7 @@ public class CodegenUtils
_cache.put(name, ret);
return ret;
}
-
+
public static byte[] getClassData(String name) {
//get class in a compiler-specific manner
if( SpoofCompiler.JAVA_COMPILER == CompilerType.JANINO )
@@ -133,12 +133,12 @@ public class CodegenUtils
else
return getClassAsByteArray(name);
}
-
+
public static void clearClassCache() {
_cache.clear();
_src.clear();
}
-
+
public static void clearClassCache(Class<?> cla) {
//one-pass, in-place filtering of class cache
Iterator<Entry<String,Class<?>>> iter =
_cache.entrySet().iterator();
@@ -146,17 +146,17 @@ public class CodegenUtils
if( iter.next().getValue()==cla )
iter.remove();
}
-
+
public static SpoofOperator createInstance(Class<?> cla) {
SpoofOperator ret = null;
-
+
try {
ret = (SpoofOperator) cla.newInstance();
}
catch( Exception ex ) {
throw new DMLRuntimeException(ex);
}
-
+
return ret;
}
@@ -167,18 +167,18 @@ public class CodegenUtils
public static void putCUDAOpID(String name, int id) {
_CUDA_op_IDs.put(name, id);
}
-
+
public static void putCUDASource(int id, String src) {
_CUDA_op_src.put(id, src);
}
-
+
public static SideInput createSideInput(MatrixBlock in) {
SideInput ret = (in.isInSparseFormat() || !in.isAllocated()) ?
new SideInput(null, in, in.getNumColumns()) :
new SideInput(in.getDenseBlock(), null,
in.getNumColumns());
return (ret.mdat != null) ? new SideInputSparseCell(ret) : ret;
}
-
+
////////////////////////////
//JANINO-specific methods (used for spark environments)
@@ -187,10 +187,10 @@ public class CodegenUtils
//compile source code
SimpleCompiler compiler = new SimpleCompiler();
compiler.cook(src);
-
+
//keep source code for later re-construction
_src.put(name, src);
-
+
//load compile class
return compiler.getClassLoader()
.loadClass(name);
@@ -199,8 +199,8 @@ public class CodegenUtils
LOG.error("Failed to compile class "+name+": \n"+src);
throw new DMLRuntimeException("Failed to compile class
"+name+".", ex);
}
- }
-
+ }
+
////////////////////////////
//JAVAC-specific methods (used for hadoop environments)
@@ -210,46 +210,46 @@ public class CodegenUtils
//create working dir on demand
if( _workingDir == null )
createWorkingDir();
-
+
//write input file (for debugging / classpath handling)
File ftmp = new File(_workingDir+"/"+name.replace(".",
"/")+".java");
if( !ftmp.getParentFile().exists() )
ftmp.getParentFile().mkdirs();
LocalFileUtils.writeTextFile(ftmp, src);
-
+
//get system java compiler
JavaCompiler compiler =
ToolProvider.getSystemJavaCompiler();
if( compiler == null )
throw new RuntimeException("Unable to obtain
system java compiler.");
-
+
//prepare file manager
- DiagnosticCollector<JavaFileObject> diagnostics = new
DiagnosticCollector<>();
+ DiagnosticCollector<JavaFileObject> diagnostics = new
DiagnosticCollector<>();
try(StandardJavaFileManager fileManager =
compiler.getStandardFileManager(diagnostics, null, null))
{
//prepare input source code
Iterable<? extends JavaFileObject> sources =
fileManager
.getJavaFileObjectsFromFiles(Arrays.asList(ftmp));
-
- //prepare class path
- URL runDir =
CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
- String classpath =
System.getProperty("java.class.path") +
+
+ //prepare class path
+ URL runDir =
CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
+ String classpath =
System.getProperty("java.class.path") +
File.pathSeparator +
runDir.getPath();
List<String> options =
Arrays.asList("-classpath",classpath);
-
+
//compile source code
CompilationTask task = compiler.getTask(null,
fileManager, diagnostics, options, null, sources);
Boolean success = task.call();
-
+
//output diagnostics and error handling
for(Diagnostic<? extends JavaFileObject> tmp :
diagnostics.getDiagnostics())
if( tmp.getKind()==Kind.ERROR )
System.err.println("ERROR:
"+tmp.toString());
if( success == null || !success )
throw new RuntimeException("Failed to
compile class "+name);
-
+
//dynamically load compiled class
try (URLClassLoader classLoader = new
URLClassLoader(
- new URL[]{new
File(_workingDir).toURI().toURL(), runDir},
+ new URL[]{new
File(_workingDir).toURI().toURL(), runDir},
CodegenUtils.class.getClassLoader()))
{
return classLoader.loadClass(name);
@@ -261,49 +261,49 @@ public class CodegenUtils
throw new DMLRuntimeException("Failed to compile class
"+name+".", ex);
}
}
-
+
private static Class<?> loadFromClassFile(String name, byte[]
classBytes) {
if(classBytes != null) {
//load from byte representation of class file
- try(ByteClassLoader byteLoader = new
ByteClassLoader(new URL[]{},
- CodegenUtils.class.getClassLoader(),
classBytes))
+ try(ByteClassLoader byteLoader = new
ByteClassLoader(new URL[]{},
+ CodegenUtils.class.getClassLoader(),
classBytes))
{
return byteLoader.findClass(name);
- }
+ }
catch (Exception e) {
throw new DMLRuntimeException(e);
}
}
else {
//load compiled class file
- URL runDir =
CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
+ URL runDir =
CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
try(URLClassLoader classLoader = new URLClassLoader(new
URL[]{new File(_workingDir)
- .toURI().toURL(), runDir},
CodegenUtils.class.getClassLoader()))
+ .toURI().toURL(), runDir},
CodegenUtils.class.getClassLoader()))
{
return classLoader.loadClass(name);
- }
+ }
catch (Exception e) {
throw new DMLRuntimeException(e);
}
- }
+ }
}
-
+
@SuppressWarnings("resource")
private static byte[] getClassAsByteArray(String name) {
String classAsPath = name.replace('.', '/') + ".class";
-
+
URLClassLoader classLoader = null;
InputStream stream = null;
-
+
try {
//dynamically load compiled class
- URL runDir =
CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
+ URL runDir =
CodegenUtils.class.getProtectionDomain().getCodeSource().getLocation();
classLoader = new URLClassLoader(
- new URL[]{new
File(_workingDir).toURI().toURL(), runDir},
+ new URL[]{new
File(_workingDir).toURI().toURL(), runDir},
CodegenUtils.class.getClassLoader());
stream = classLoader.getResourceAsStream(classAsPath);
return IOUtils.toByteArray(stream);
- }
+ }
catch (IOException e) {
throw new DMLRuntimeException(e);
}
@@ -312,7 +312,7 @@ public class CodegenUtils
IOUtilFunctions.closeSilently(stream);
}
}
-
+
private static void createWorkingDir() {
if( _workingDir != null )
return;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 00d0ac5..abc3437 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -38,7 +38,7 @@ import org.apache.sysds.utils.Statistics;
public class FederatedRequest implements Serializable {
private static final long serialVersionUID = 5946781306963870394L;
-
+
// commands sent to and excuted by federated workers
public enum RequestType {
READ_VAR, // create variable for local data, read on first
access
@@ -48,27 +48,26 @@ public class FederatedRequest implements Serializable {
EXEC_UDF, // execute arbitrary user-defined function
CLEAR, // clear all variables and execution contexts (i.e.,
rmvar ALL)
}
-
+
private RequestType _method;
private long _id;
private long _tid;
private List<Object> _data;
private boolean _checkPrivacy;
private List<Long> _checksums;
-
-
+
public FederatedRequest(RequestType method) {
this(method, FederationUtils.getNextFedDataID(), new
ArrayList<>());
}
-
+
public FederatedRequest(RequestType method, long id) {
this(method, id, new ArrayList<>());
}
-
+
public FederatedRequest(RequestType method, long id, Object ... data) {
this(method, id, Arrays.asList(data));
}
-
+
public FederatedRequest(RequestType method, long id, List<Object> data)
{
Statistics.incFederated(method);
_method = method;
@@ -78,41 +77,41 @@ public class FederatedRequest implements Serializable {
if (DMLScript.LINEAGE && method == RequestType.PUT_VAR)
setChecksum();
}
-
+
public RequestType getType() {
return _method;
}
-
+
public long getID() {
return _id;
}
-
+
public long getTID() {
return _tid;
}
-
+
public void setTID(long tid) {
_tid = tid;
}
-
+
public Object getParam(int i) {
return _data.get(i);
}
-
+
public FederatedRequest appendParam(Object obj) {
_data.add(obj);
return this;
}
-
+
public FederatedRequest appendParams(Object ... objs) {
_data.addAll(Arrays.asList(objs));
return this;
}
-
+
public int getNumParams() {
return _data.size();
}
-
+
public FederatedRequest deepClone() {
return new FederatedRequest(_method, _id, new
ArrayList<>(_data));
}
@@ -128,7 +127,7 @@ public class FederatedRequest implements Serializable {
public boolean checkPrivacy(){
return _checkPrivacy;
}
-
+
public void setChecksum() {
// Calculate Adler32 checksum. This is used as a leaf node of
Lineage DAGs
// in the workers, and helps to uniquely identify a node
(tracing PUT)
@@ -141,23 +140,23 @@ public class FederatedRequest implements Serializable {
throw new DMLException(e);
}
}
-
+
public long getChecksum(int i) {
return _checksums.get(i);
}
-
+
private void calcChecksum() throws IOException {
for (Object ob : _data) {
if (!(ob instanceof CacheBlock) && !(ob instanceof
ScalarObject))
continue;
-
+
Checksum checksum = new Adler32();
if (ob instanceof ScalarObject) {
byte bytes[] =
((ScalarObject)ob).getStringValue().getBytes();
checksum.update(bytes, 0, bytes.length);
_checksums.add(checksum.getValue());
}
-
+
if (ob instanceof CacheBlock) {
try {
CacheBlock cb = (CacheBlock)ob;
@@ -174,7 +173,7 @@ public class FederatedRequest implements Serializable {
}
}
}
-
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder("FederatedRequest[");
@@ -182,7 +181,8 @@ public class FederatedRequest implements Serializable {
sb.append(_id); sb.append(";");
sb.append("t"); sb.append(_tid); sb.append(";");
if( _method != RequestType.PUT_VAR )
-			sb.append(Arrays.toString(_data.toArray())); sb.append("]");
+			sb.append(Arrays.toString(_data.toArray()));
+		sb.append("]");
return sb.toString();
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b066d1e..13e9fb7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -30,13 +30,14 @@ import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.BiFunction;
+import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -134,7 +135,7 @@ public class FederationMap {
/**
* Creates separate slices of an input data object according to the
index ranges of federated data. Theses slices
* are then wrapped in separate federated requests for broadcasting.
- *
+ *
* @param data input data object (matrix, tensor, frame)
* @param transposed false: slice according to federated data, true:
slice according to transposed federated data
* @return array of federated requests corresponding to federated data
@@ -142,11 +143,11 @@ public class FederationMap {
public FederatedRequest[] broadcastSliced(CacheableData<?> data,
boolean transposed) {
if( _type == FType.FULL )
return new FederatedRequest[]{broadcast(data)};
-
+
// prepare broadcast id and pin input
long id = FederationUtils.getNextFedDataID();
CacheBlock cb = data.acquireReadAndRelease();
-
+
// prepare indexing ranges
int[][] ix = new int[_fedMap.size()][];
int pos = 0;
@@ -232,6 +233,29 @@ public class FederationMap {
return ret.toArray(new Future[0]);
}
+	@SuppressWarnings("unchecked")
+	public Future<FederatedResponse>[] executeMultipleSlices(long tid, boolean wait,
+		FederatedRequest[][] frSlices, FederatedRequest[] fr) {
+		// executes step1[] - ... - stepM[] - stepM+1 - ... stepN (only first step federated-data-specific)
+		FederatedRequest[] allSlices = Arrays.stream(frSlices).flatMap(Stream::of).toArray(FederatedRequest[]::new);
+		setThreadID(tid, allSlices, fr);
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
+		int pos = 0;
+		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+			FederatedRequest[] fedReq = fr;
+			for(FederatedRequest[] slice : frSlices)
+				fedReq = addAll(slice[pos], fedReq);
+			ret.add(e.getValue().executeFederatedOperation(fedReq));
+			pos++;
+		}
+
+		// prepare results (future federated responses), with optional wait to ensure the
+		// order of requests without data dependencies (e.g., cleanup RPCs)
+		if(wait)
+			FederationUtils.waitFor(ret);
+		return ret.toArray(new Future[0]);
+	}
+
	public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
		if(!isInitialized())
			throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
@@ -360,7 +384,7 @@ public class FederationMap {
	 * Execute a function for each <code>FederatedRange</code> + <code>FederatedData</code> pair. The function should
	 * not change any data of the pair and instead use <code>mapParallel</code> if that is a necessity. Note that this
	 * operation is parallel and necessary synchronisation has to be performed.
-	 * 
+	 *
	 * @param forEachFunction function to execute for each pair
	 */
	public void forEachParallel(BiFunction<FederatedRange, FederatedData, Void> forEachFunction) {
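
The new executeMultipleSlices is what later allows SpoofFEDInstruction to combine several sliced broadcasts (one slice per worker for each side input) with a shared tail of compute/get/cleanup requests issued to every worker. A simplified, standalone sketch of how the per-worker request sequence is assembled is shown below; plain strings stand in for FederatedRequest objects, and the exact prepend order of the real implementation is not reproduced.

// Standalone sketch (illustrative only): for worker p, take slice p of every
// sliced broadcast and append the shared requests that are identical for all
// workers (execute fused operator, get partial result, cleanup).
import java.util.ArrayList;
import java.util.List;

public class MultiSliceRequestSketch {
	static List<String> requestsForWorker(int p, String[][] slicedBroadcasts, String[] shared) {
		List<String> seq = new ArrayList<>();
		for(String[] slices : slicedBroadcasts)
			seq.add(slices[p]);      // worker-specific slice of each broadcast input
		for(String req : shared)
			seq.add(req);            // compute / get / cleanup, same for every worker
		return seq;
	}

	public static void main(String[] args) {
		String[][] sliced = {{"PUT Y[0:50]", "PUT Y[50:100]"}, {"PUT Z[0:50]", "PUT Z[50:100]"}};
		String[] shared = {"EXEC spoofCell", "GET partial", "CLEANUP"};
		for(int p = 0; p < 2; p++)
			System.out.println("worker " + p + ": " + requestsForWorker(p, sliced, shared));
	}
}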
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
index 4acba1e..0ba12a2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,6 +28,7 @@ import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageCodegenItem;
@@ -44,8 +45,9 @@ public class SpoofCPInstruction extends ComputationCPInstruction {
	private final int _numThreads;
	private final CPOperand[] _in;
-	private SpoofCPInstruction(SpoofOperator op, Class<?> cla, int k, CPOperand[] in, CPOperand out, String opcode,
-		String str) {
+	private SpoofCPInstruction(SpoofOperator op, Class<?> cla, int k,
+		CPOperand[] in, CPOperand out, String opcode, String str)
+	{
		super(CPType.SpoofFused, null, null, null, out, opcode, str);
		_class = cla;
		_op = op;
@@ -59,17 +61,17 @@ public class SpoofCPInstruction extends ComputationCPInstruction {
public static SpoofCPInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
-
+
ArrayList<CPOperand> inlist = new ArrayList<>();
Class<?> cla = CodegenUtils.getClass(parts[2]);
SpoofOperator op = CodegenUtils.createInstance(cla);
String opcode = parts[0] + op.getSpoofType();
-
+
for( int i=3; i<parts.length-2; i++ )
inlist.add(new CPOperand(parts[i]));
CPOperand out = new CPOperand(parts[parts.length-2]);
int k = Integer.parseInt(parts[parts.length-1]);
-
+
		return new SpoofCPInstruction(op, cla, k, inlist.toArray(new CPOperand[0]), out, opcode, str);
}
@@ -78,10 +80,12 @@ public class SpoofCPInstruction extends ComputationCPInstruction {
		//get input matrices and scalars, incl pinning of matrices
		ArrayList<MatrixBlock> inputs = new ArrayList<>();
		ArrayList<ScalarObject> scalars = new ArrayList<>();
-		LOG.debug("executing spoof instruction " + _op);
+		if( LOG.isDebugEnabled() )
+			LOG.debug("executing spoof instruction " + _op);
		for (CPOperand input : _in) {
			if(input.getDataType()==DataType.MATRIX){
				MatrixBlock mb = ec.getMatrixInput(input.getName());
+				//FIXME fused codegen operators already support compressed main inputs
				if(mb instanceof CompressedMatrixBlock){
					LOG.warn("Spoof instruction decompressed matrix");
					mb = ((CompressedMatrixBlock) mb).decompress(_numThreads);
@@ -93,7 +97,7 @@ public class SpoofCPInstruction extends
ComputationCPInstruction {
scalars.add(ec.getScalarInput(input));
}
}
-
+
// set the output dimensions to the hop node matrix dimensions
if( output.getDataType() == DataType.MATRIX) {
			MatrixBlock out = _op.execute(inputs, scalars, new MatrixBlock(), _numThreads);
@@ -103,13 +107,13 @@ public class SpoofCPInstruction extends ComputationCPInstruction {
			ScalarObject out = _op.execute(inputs, scalars, _numThreads);
ec.setScalarOutput(output.getName(), out);
}
-
+
// release input matrices
for (CPOperand input : _in)
if(input.getDataType()==DataType.MATRIX)
ec.releaseMatrixInput(input.getName());
}
-
+
@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
//return the lineage item if already traced once
@@ -119,10 +123,19 @@ public class SpoofCPInstruction extends ComputationCPInstruction {
		//read and deepcopy the corresponding lineage DAG (pre-codegen)
		LineageItem LIroot = LineageCodegenItem.getCodegenLTrace(getOperatorClass().getName()).deepCopy();
-		
-		//replace the placeholders with original instruction inputs. 
+
+		//replace the placeholders with original instruction inputs.
		LineageItemUtils.replaceDagLeaves(ec, LIroot, _in);
		return Pair.of(output.getName(), LIroot);
	}
+
+	public boolean isFederated(ExecutionContext ec) {
+		for(CPOperand input : _in) {
+			Data data = ec.getVariable(input);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+				return true;
+		}
+		return false;
+	}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index b6d0227..cbe9bad 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -70,22 +70,7 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
			}
		}
		else { // matrix-matrix binary operations -> lhs fed input -> fed output
-			if(mo1.isFederated(FType.FULL)) {
-				// full federated (row and col)
-				if(mo1.getFedMapping().getSize() == 1) {
-					// only one partition (MM on a single fed worker)
-					FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
-					fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
-						new long[]{mo1.getFedMapping().getID(), fr1.getID()});
-					FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1.getID());
-					//execute federated instruction and cleanup intermediates
-					mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-				}
-				else {
-					throw new DMLRuntimeException("Matrix-matrix binary operations with a full partitioned federated input with multiple partitions are not supported yet.");
-				}
-			}
-			else if((mo1.isFederated(FType.ROW) && mo2.getNumRows() == 1 && mo2.getNumColumns() > 1)
+			if((mo1.isFederated(FType.ROW) && mo2.getNumRows() == 1 && mo2.getNumColumns() > 1)
				|| (mo1.isFederated(FType.COL) && mo2.getNumRows() > 1 && mo2.getNumColumns() == 1)) {
				// MV row partitioned row vector, MV col partitioned col vector
				FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
@@ -95,7 +80,8 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
				//execute federated instruction and cleanup intermediates
				mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
			}
-			else if(mo1.isFederated(FType.ROW) ^ mo1.isFederated(FType.COL)) {
+			else if((mo1.isFederated(FType.ROW) ^ mo1.isFederated(FType.COL))
+				|| (mo1.isFederated(FType.FULL) && mo1.getFedMapping().getSize() == 1)) {
				// row partitioned MM or col partitioned MM
				FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
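
The removed FType.FULL branch is subsumed by widening the sliced-broadcast branch: a fully federated left-hand side with a single partition now takes the same path as row- or column-partitioned inputs. A standalone sketch of the revised routing predicate, with an illustrative enum instead of the SystemDS FType/FederationMap API:

// Standalone sketch of the revised branch condition (illustrative only).
public class BinaryFedRoutingSketch {
	enum FType { ROW, COL, FULL }

	// row- or col-partitioned lhs, or a FULL-federated lhs with exactly one
	// partition, now takes the sliced-broadcast matrix-matrix path
	static boolean usesSlicedBroadcast(FType lhsType, int numPartitions) {
		return lhsType == FType.ROW || lhsType == FType.COL
			|| (lhsType == FType.FULL && numPartitions == 1);
	}

	public static void main(String[] args) {
		System.out.println(usesSlicedBroadcast(FType.FULL, 1)); // true
		System.out.println(usesSlicedBroadcast(FType.FULL, 2)); // false
	}
}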
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 8ed9aba..8f58a8b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -34,16 +34,17 @@ public abstract class FEDInstruction extends Instruction {
Binary,
Init,
MultiReturnParameterizedBuiltin,
- ParameterizedBuiltin,
- Tsmm,
MMChain,
- Reorg,
- Reshape,
MatrixIndexing,
- Ternary,
+ Ternary,
+ ParameterizedBuiltin,
Quaternary,
QSort,
QPick,
+ Reorg,
+ Reshape,
+ SpoofFused,
+ Tsmm,
Unary
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index a4c750c..214023f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -39,6 +40,7 @@ import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstructio
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -57,6 +59,7 @@ import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
@@ -214,6 +217,11 @@ public class FEDInstructionUtils {
			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
				fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
		}
+		else if(inst instanceof SpoofCPInstruction) {
+			SpoofCPInstruction instruction = (SpoofCPInstruction) inst;
+			if(instruction.getOperatorClass().getSuperclass() == SpoofCellwise.class && instruction.isFederated(ec))
+				fedinst = SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
+		}
		//set thread id for federated context management
		if( fedinst != null ) {
@@ -305,6 +313,12 @@ public class FEDInstructionUtils {
			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
				fedinst = QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
		}
+		else if(inst instanceof SpoofSPInstruction) {
+			SpoofSPInstruction instruction = (SpoofSPInstruction) inst;
+			if(instruction.getOperatorClass().getSuperclass() == SpoofCellwise.class && instruction.isFederated(ec)) {
+				fedinst = SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
+			}
+		}
		//set thread id for federated context management
		if( fedinst != null ) {
			fedinst.setTID(ec.getTID());
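
Both rewrite paths above apply the same dispatch rule: a fused codegen instruction is replaced by SpoofFEDInstruction only if its generated operator class extends SpoofCellwise and at least one input is a federated matrix. A simplified, standalone version of that check follows; the stand-in classes are illustrative and not the SystemDS types.

// Standalone sketch of the dispatch rule (illustrative stand-ins only).
import java.util.Arrays;
import java.util.List;

public class SpoofFedDispatchSketch {
	static class CellwiseTemplate {}                          // stands in for SpoofCellwise
	static class GeneratedCellOp extends CellwiseTemplate {}  // stands in for a generated operator

	static class MatrixInput {
		final boolean federated;
		MatrixInput(boolean federated) { this.federated = federated; }
	}

	// rewrite to the federated instruction iff the operator is a cellwise
	// template and any matrix input lives on federated workers
	static boolean routeToFederated(Class<?> operatorClass, List<MatrixInput> inputs) {
		return operatorClass.getSuperclass() == CellwiseTemplate.class
			&& inputs.stream().anyMatch(in -> in.federated);
	}

	public static void main(String[] args) {
		List<MatrixInput> inputs = Arrays.asList(new MatrixInput(true), new MatrixInput(false));
		System.out.println(routeToFederated(GeneratedCellOp.class, inputs)); // true
	}
}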
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
new file mode 100644
index 0000000..6e59813
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
+import org.apache.sysds.runtime.codegen.SpoofCellwise;
+import org.apache.sysds.runtime.codegen.SpoofOperator;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+
+import java.util.ArrayList;
+import java.util.concurrent.Future;
+
+public class SpoofFEDInstruction extends FEDInstruction
+{
+ private final SpoofOperator _op;
+ private final CPOperand[] _inputs;
+ private final CPOperand _output;
+
+ private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in,
+ CPOperand out, String opcode, String inst_str)
+ {
+ super(FEDInstruction.FEDType.SpoofFused, opcode, inst_str);
+ _op = op;
+ _inputs = in;
+ _output = out;
+ }
+
+ public static SpoofFEDInstruction parseInstruction(String str)
+ {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+
+ CPOperand[] inputCpo = new CPOperand[parts.length - 3 - 2];
+ Class<?> cla = CodegenUtils.getClass(parts[2]);
+ SpoofOperator op = CodegenUtils.createInstance(cla);
+ String opcode = parts[0] + op.getSpoofType();
+
+ for(int counter = 3; counter < parts.length - 2; counter++)
+ inputCpo[counter - 3] = new CPOperand(parts[counter]);
+ CPOperand out = new CPOperand(parts[parts.length - 2]);
+
+ return new SpoofFEDInstruction(op, inputCpo, out, opcode, str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec)
+ {
+ ArrayList<CPOperand> inCpoMat = new ArrayList<>();
+ ArrayList<CPOperand> inCpoScal = new ArrayList<>();
+ ArrayList<MatrixObject> inMo = new ArrayList<>();
+ ArrayList<ScalarObject> inSo = new ArrayList<>();
+ MatrixObject fedMo = null;
+ FederationMap fedMap = null;
+ for(CPOperand cpo : _inputs) {
+ Data tmpData = ec.getVariable(cpo);
+ if(tmpData instanceof MatrixObject) {
+ MatrixObject tmp = (MatrixObject) tmpData;
+				if(fedMo == null & tmp.isFederated()) { //take first
+					inCpoMat.add(0, cpo); // insert federated CPO at the beginning
+ fedMo = tmp;
+ fedMap = tmp.getFedMapping();
+ }
+ else {
+ inCpoMat.add(cpo);
+ inMo.add(tmp);
+ }
+ }
+ else if(tmpData instanceof ScalarObject) {
+ ScalarObject tmp = (ScalarObject) tmpData;
+ inCpoScal.add(cpo);
+ inSo.add(tmp);
+ }
+ }
+
+		ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
+		ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<>();
+		long[] frIds = new long[1 + inMo.size() + inSo.size()];
+		int index = 0;
+		frIds[index++] = fedMap.getID(); // insert federation map id at the beginning
+		for(MatrixObject mo : inMo) {
+			if((fedMo.isFederated(FType.ROW) && mo.getNumRows() > 1 && (mo.getNumColumns() == 1 || mo.getNumColumns() == fedMap.getSize()))
+				|| (fedMo.isFederated(FType.ROW) && mo.getNumColumns() > 1 && mo.getNumRows() == fedMap.getSize())
+				|| (fedMo.isFederated(FType.COL) && (mo.getNumRows() == 1 || mo.getNumRows() == fedMap.getSize()) && mo.getNumColumns() > 1)
+				|| (fedMo.isFederated(FType.COL) && mo.getNumRows() > 1 && mo.getNumColumns() == fedMap.getSize())) {
+				FederatedRequest[] tmpFr = fedMap.broadcastSliced(mo, false);
+				frIds[index++] = tmpFr[0].getID();
+				frBroadcastSliced.add(tmpFr);
+			}
+			else {
+				FederatedRequest tmpFr = fedMap.broadcast(mo);
+				frIds[index++] = tmpFr.getID();
+				frBroadcast.add(tmpFr);
+			}
+		}
+		for(ScalarObject so : inSo) {
+			FederatedRequest tmpFr = fedMap.broadcast(so);
+			frIds[index++] = tmpFr.getID();
+			frBroadcast.add(tmpFr);
+		}
+
+		// change the is_literal flag from true to false because when broadcasted it is not a literal anymore
+		instString = instString.replace("true", "false");
+
+		CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new CPOperand[0]), inCpoScal.toArray(new CPOperand[0]));
+		FederatedRequest frCompute = FederationUtils.callInstruction(instString, _output, inCpo, frIds);
+
+		// get partial results from federated workers
+		FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frCompute.getID());
+
+		ArrayList<FederatedRequest> frCleanup = new ArrayList<>();
+		frCleanup.add(fedMap.cleanup(getTID(), frCompute.getID()));
+		for(FederatedRequest fr : frBroadcast)
+			frCleanup.add(fedMap.cleanup(getTID(), fr.getID()));
+		for(FederatedRequest[] fr : frBroadcastSliced)
+			frCleanup.add(fedMap.cleanup(getTID(), fr[0].getID()));
+
+		FederatedRequest[] frAll = ArrayUtils.addAll(ArrayUtils.addAll(
+			frBroadcast.toArray(new FederatedRequest[0]), frCompute, frGet),
+			frCleanup.toArray(new FederatedRequest[0]));
+		Future<FederatedResponse>[] response = fedMap.executeMultipleSlices(
+			getTID(), true, frBroadcastSliced.toArray(new FederatedRequest[0][]), frAll);
+
+		if(((SpoofCellwise)_op).getCellType() == SpoofCellwise.CellType.FULL_AGG) { // full aggregation
+			if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.SUM
+				|| ((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.SUM_SQ) {
+				//aggregate partial results from federated responses as sum
+				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+			}
+			else if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.MIN) {
+				//aggregate partial results from federated responses as min
+				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
+				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+			}
+			else if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.MAX) {
+				//aggregate partial results from federated responses as max
+				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
+				ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+			}
+			else {
+				throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
+			}
+		}
+		else if(((SpoofCellwise)_op).getCellType() == SpoofCellwise.CellType.ROW_AGG) { // row aggregation
+			if(fedMo.isFederated(FType.ROW)) {
+				// bind partial results from federated responses
+				ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
+			}
+			else if(fedMo.isFederated(FType.COL)) {
+				if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.SUM
+					|| ((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.SUM_SQ) {
+					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
+				}
+				else if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.MIN) {
+					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
+					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
+				}
+				else if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.MAX) {
+					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
+					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
+				}
+			}
+			else {
+				throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
+			}
+		}
+		else if(((SpoofCellwise)_op).getCellType() == SpoofCellwise.CellType.COL_AGG) { // col aggregation
+			if(fedMo.isFederated(FType.ROW)) {
+				if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.SUM
+					|| ((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.SUM_SQ) {
+					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
+				}
+				else if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.MIN) {
+					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
+					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
+				}
+				else if(((SpoofCellwise)_op).getAggOp() == SpoofCellwise.AggOp.MAX) {
+					AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
+					ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
+				}
+			}
+			else if(fedMo.isFederated(FType.COL)) {
+				// bind partial results from federated responses
+				ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
+			}
+			else {
+				throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
+			}
+		}
+		else if(((SpoofCellwise)_op).getCellType() == SpoofCellwise.CellType.NO_AGG) { // no aggregation
+			if(fedMo.isFederated(FType.ROW)) {
+				// bind partial results from federated responses
+				ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
+			}
+			else if(fedMo.isFederated(FType.COL)) {
+				// bind partial results from federated responses
+				ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
+			}
+			else {
+				throw new DMLRuntimeException("Only row partitioned or column partitioned federated matrices supported yet.");
+			}
+		}
+		else {
+			throw new DMLRuntimeException("Aggregation type not supported yet.");
+		}
+	}
+}
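
The merge step at the end of processInstruction is where the cellwise aggregation type matters: full aggregates are combined scalar-wise (sum/min/max), row/column aggregates are either bound back together or aggregated element-wise depending on the partitioning, and NO_AGG results are simply rbind/cbind-ed. The standalone sketch below mirrors that decision for row-partitioned inputs, using plain double arrays instead of FederatedResponse futures; the names are illustrative only.

// Standalone sketch (illustrative only): merging per-worker partials of a
// cellwise operator over row-partitioned data.
import java.util.Arrays;

public class CellwiseMergeSketch {
	enum CellType { NO_AGG, ROW_AGG, COL_AGG, FULL_AGG }

	// partials[w] holds worker w's local result (flattened rows or aggregates)
	static double[] merge(CellType type, double[][] partials) {
		switch(type) {
			case NO_AGG:   // row partitions: concatenate (rbind) the partial blocks
			case ROW_AGG:  // per-row aggregates are disjoint across row partitions
				return Arrays.stream(partials).flatMapToDouble(Arrays::stream).toArray();
			case COL_AGG: { // per-column aggregates overlap: combine element-wise (sum here)
				double[] out = new double[partials[0].length];
				for(double[] p : partials)
					for(int j = 0; j < out.length; j++)
						out[j] += p[j];
				return out;
			}
			case FULL_AGG: // one scalar per worker: combine (sum here; min/max analogous)
				return new double[]{ Arrays.stream(partials).mapToDouble(p -> p[0]).sum() };
			default:
				throw new IllegalArgumentException("unsupported cell type");
		}
	}

	public static void main(String[] args) {
		double[][] colAggPartials = {{1, 2, 3}, {4, 5, 6}}; // two workers, three columns
		System.out.println(Arrays.toString(merge(CellType.COL_AGG, colAggPartials))); // [5.0, 7.0, 9.0]
	}
}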
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
index 6dc2832..dff78d8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
@@ -39,6 +39,7 @@ import
org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -46,6 +47,7 @@ import
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -84,36 +86,40 @@ public class SpoofSPInstruction extends SPInstruction {
public static SpoofSPInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
-
+
//String opcode = parts[0];
ArrayList<CPOperand> inlist = new ArrayList<>();
Class<?> cls = CodegenUtils.getClass(parts[2]);
byte[] classBytes = CodegenUtils.getClassData(parts[2]);
String opcode = parts[0] +
CodegenUtils.createInstance(cls).getSpoofType();
-
+
for( int i=3; i<parts.length-2; i++ )
inlist.add(new CPOperand(parts[i]));
CPOperand out = new CPOperand(parts[parts.length-2]);
//note: number of threads parts[parts.length-1] always ignored
-
+
return new SpoofSPInstruction(cls, classBytes,
inlist.toArray(new CPOperand[0]), out, opcode, str);
}
+ public Class<?> getOperatorClass() {
+ return _class;
+ }
+
@Override
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
-
+
//decide upon broadcast side inputs
boolean[] bcVect = determineBroadcastInputs(sec, _in);
boolean[] bcVect2 = getMatrixBroadcastVector(sec, _in, bcVect);
int main = getMainInputIndex(_in, bcVect);
-
+
//create joined input rdd w/ replication if needed
DataCharacteristics mcIn =
sec.getDataCharacteristics(_in[main].getName());
JavaPairRDD<MatrixIndexes, MatrixBlock[]> in =
createJoinedInputRDD(
sec, _in, bcVect, (_class.getSuperclass() ==
SpoofOuterProduct.class));
JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
-
+
//create lists of input broadcasts and scalars
ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new
ArrayList<>();
ArrayList<ScalarObject> scalars = new ArrayList<>();
@@ -126,27 +132,27 @@ public class SpoofSPInstruction extends SPInstruction {
scalars.add(sec.getScalarInput(_in[i]));
}
}
-
+
//execute generated operator
if(_class.getSuperclass() == SpoofCellwise.class) //CELL
{
SpoofCellwise op = (SpoofCellwise)
CodegenUtils.createInstance(_class);
AggregateOperator aggop =
getAggregateOperator(op.getAggOp());
-
+
if( _out.getDataType()==DataType.MATRIX ) {
//execute codegen block operation
out = in.mapPartitionsToPair(new
CellwiseFunction(_class.getName(),
_classBytes, bcVect2, bcMatrices,
scalars, mcIn.getBlocksize()), true);
-
+
if( (op.getCellType()==CellType.ROW_AGG &&
mcIn.getCols() > mcIn.getBlocksize())
|| (op.getCellType()==CellType.COL_AGG
&& mcIn.getRows() > mcIn.getBlocksize())) {
- long numBlocks =
(op.getCellType()==CellType.ROW_AGG ) ?
+ long numBlocks =
(op.getCellType()==CellType.ROW_AGG ) ?
mcIn.getNumRowBlocks() :
mcIn.getNumColBlocks();
out =
RDDAggregateUtils.aggByKeyStable(out, aggop,
(int)Math.min(out.getNumPartitions(), numBlocks), false);
}
sec.setRDDHandleForVariable(_out.getName(),
out);
-
+
//maintain lineage info and output
characteristics
maintainLineageInfo(sec, _in, bcVect, _out);
updateOutputDataCharacteristics(sec, op);
@@ -176,7 +182,7 @@ public class SpoofSPInstruction extends SPInstruction {
//update matrix characteristics
updateOutputDataCharacteristics(sec, op);
DataCharacteristics mcOut =
sec.getDataCharacteristics(_out.getName());
-
+
out = in.mapPartitionsToPair(new
OuterProductFunction(
_class.getName(), _classBytes, bcVect2,
bcMatrices, scalars), true);
if(type == OutProdType.LEFT_OUTER_PRODUCT ||
type == OutProdType.RIGHT_OUTER_PRODUCT ) {
@@ -185,7 +191,7 @@ public class SpoofSPInstruction extends SPInstruction {
(int)Math.min(out.getNumPartitions(), numBlocks), false);
}
sec.setRDDHandleForVariable(_out.getName(),
out);
-
+
//maintain lineage info and output
characteristics
maintainLineageInfo(sec, _in, bcVect, _out);
}
@@ -198,7 +204,7 @@ public class SpoofSPInstruction extends SPInstruction {
}
else if( _class.getSuperclass() == SpoofRowwise.class ) { //ROW
if( mcIn.getCols() > mcIn.getBlocksize() ) {
- throw new DMLRuntimeException("Invalid spark
rowwise operator w/ ncol=" +
+ throw new DMLRuntimeException("Invalid spark
rowwise operator w/ ncol=" +
mcIn.getCols()+",
ncolpb="+mcIn.getBlocksize()+".");
}
SpoofRowwise op = (SpoofRowwise)
CodegenUtils.createInstance(_class);
@@ -208,23 +214,23 @@ public class SpoofSPInstruction extends SPInstruction {
bcMatrices, scalars, mcIn.getBlocksize(),
(int)mcIn.getCols(), (int)clen2);
out = in.mapPartitionsToPair(fmmc,
op.getRowType()==RowType.ROW_AGG
|| op.getRowType() == RowType.NO_AGG);
-
+
if( op.getRowType().isColumnAgg() ||
op.getRowType()==RowType.FULL_AGG ) {
MatrixBlock tmpMB =
RDDAggregateUtils.sumStable(out);
if( op.getRowType().isColumnAgg() )
sec.setMatrixOutput(_out.getName(),
tmpMB);
else
- sec.setScalarOutput(_out.getName(),
+ sec.setScalarOutput(_out.getName(),
new
DoubleObject(tmpMB.quickGetValue(0, 0)));
}
- else //row-agg or no-agg
+ else //row-agg or no-agg
{
if( op.getRowType()==RowType.ROW_AGG &&
mcIn.getCols() > mcIn.getBlocksize() ) {
out =
RDDAggregateUtils.sumByKeyStable(out,
(int)Math.min(out.getNumPartitions(), mcIn.getNumRowBlocks()), false);
}
sec.setRDDHandleForVariable(_out.getName(),
out);
-
+
//maintain lineage info and output
characteristics
maintainLineageInfo(sec, _in, bcVect, _out);
updateOutputDataCharacteristics(sec, op);
@@ -234,16 +240,16 @@ public class SpoofSPInstruction extends SPInstruction {
throw new DMLRuntimeException("Operator " +
_class.getSuperclass() + " is not supported on Spark");
}
}
-
+
private static boolean[] determineBroadcastInputs(SparkExecutionContext
sec, CPOperand[] inputs) {
boolean[] ret = new boolean[inputs.length];
double localBudget = OptimizerUtils.getLocalMemBudget()
- CacheableData.getBroadcastSize(); //account for other
broadcasts
double bcBudget =
SparkExecutionContext.getBroadcastMemoryBudget();
-
+
//decided for each matrix input if it fits into remaining memory
//budget; the major input, i.e., inputs[0] is always an RDD
- for( int i=0; i<inputs.length; i++ )
+ for( int i=0; i<inputs.length; i++ )
if( inputs[i].getDataType().isMatrix() ) {
DataCharacteristics mc =
sec.getDataCharacteristics(inputs[i].getName());
double sizeL =
OptimizerUtils.estimateSizeExactSparsity(mc);
@@ -253,14 +259,14 @@ public class SpoofSPInstruction extends SPInstruction {
localBudget -= ret[i] ? sizeP : 0; //in local
block manager
bcBudget -= ret[i] ? sizeP : 0; //in remote
block managers
}
-
+
//ensure there is at least one RDD input, with awareness for
scalars
if( !IntStream.range(0, ret.length).anyMatch(i ->
inputs[i].isMatrix() && !ret[i]) )
ret[0] = false;
-
+
return ret;
}
-
+
private static boolean[] getMatrixBroadcastVector(SparkExecutionContext
sec, CPOperand[] inputs, boolean[] bcVect) {
int numMtx = (int) Arrays.stream(inputs)
.filter(in -> in.getDataType().isMatrix()).count();
@@ -270,17 +276,17 @@ public class SpoofSPInstruction extends SPInstruction {
ret[pos++] = bcVect[i];
return ret;
}
-
+
private static JavaPairRDD<MatrixIndexes, MatrixBlock[]>
createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[]
bcVect, boolean outer) {
//get input rdd for main input
int main = getMainInputIndex(inputs, bcVect);
DataCharacteristics mcIn =
sec.getDataCharacteristics(inputs[main].getName());
JavaPairRDD<MatrixIndexes, MatrixBlock> in =
sec.getBinaryMatrixBlockRDDHandleForVariable(inputs[main].getName());
JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret =
in.mapValues(new MapInputSignature());
-
+
for( int i=0; i<inputs.length; i++ )
if( i != main && inputs[i].getDataType().isMatrix() &&
!bcVect[i] ) {
- //create side input rdd
+ //create side input rdd
String varname = inputs[i].getName();
JavaPairRDD<MatrixIndexes, MatrixBlock> tmp =
sec
.getBinaryMatrixBlockRDDHandleForVariable(varname);
@@ -296,22 +302,22 @@ public class SpoofSPInstruction extends SPInstruction {
ret = ret.join(tmp)
.mapValues(new MapJoinSignature());
}
-
+
return ret;
}
-
+
private static void maintainLineageInfo(SparkExecutionContext sec,
CPOperand[] inputs, boolean[] bcVect, CPOperand output) {
- //add lineage info for all rdd/broadcast inputs
+ //add lineage info for all rdd/broadcast inputs
for( int i=0; i<inputs.length; i++ )
if( inputs[i].getDataType().isMatrix() )
sec.addLineage(output.getName(),
inputs[i].getName(), bcVect[i]);
}
-
+
private static int getMainInputIndex(CPOperand[] inputs, boolean[]
bcVect) {
return IntStream.range(0, bcVect.length)
.filter(i -> inputs[i].isMatrix() &&
!bcVect[i]).min().orElse(0);
}
-
+
private void updateOutputDataCharacteristics(SparkExecutionContext sec,
SpoofOperator op) {
if(op instanceof SpoofCellwise) {
DataCharacteristics mcIn =
sec.getDataCharacteristics(_in[0].getName());
@@ -327,7 +333,7 @@ public class SpoofSPInstruction extends SPInstruction {
DataCharacteristics mcIn3 =
sec.getDataCharacteristics(_in[2].getName()); //V
DataCharacteristics mcOut =
sec.getDataCharacteristics(_out.getName());
OutProdType type =
((SpoofOuterProduct)op).getOuterProdType();
-
+
if( type == OutProdType.CELLWISE_OUTER_PRODUCT)
mcOut.set(mcIn1.getRows(), mcIn1.getCols(),
mcIn1.getBlocksize(), mcIn1.getBlocksize());
else if( type == OutProdType.LEFT_OUTER_PRODUCT)
@@ -342,7 +348,7 @@ public class SpoofSPInstruction extends SPInstruction {
if( type == RowType.NO_AGG )
mcOut.set(mcIn);
else if( type == RowType.ROW_AGG )
- mcOut.set(mcIn.getRows(), 1,
+ mcOut.set(mcIn.getRows(), 1,
mcIn.getBlocksize(),
mcIn.getBlocksize());
else if( type == RowType.COL_AGG )
mcOut.set(1, mcIn.getCols(),
mcIn.getBlocksize(), mcIn.getBlocksize());
@@ -350,17 +356,17 @@ public class SpoofSPInstruction extends SPInstruction {
mcOut.set(mcIn.getCols(), 1,
mcIn.getBlocksize(), mcIn.getBlocksize());
}
}
-
- private static class SpoofFunction implements Serializable
- {
+
+ private static class SpoofFunction implements Serializable
+ {
private static final long serialVersionUID =
2953479427746463003L;
-
+
protected final boolean[] _bcInd;
protected final ArrayList<PartitionedBroadcast<MatrixBlock>>
_inputs;
protected final ArrayList<ScalarObject> _scalars;
protected final byte[] _classBytes;
protected final String _className;
-
+
protected SpoofFunction(String className, byte[] classBytes,
boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices,
ArrayList<ScalarObject> scalars) {
_bcInd = bcInd;
_inputs = bcMatrices;
@@ -368,20 +374,20 @@ public class SpoofSPInstruction extends SPInstruction {
_classBytes = classBytes;
_className = className;
}
-
+
protected ArrayList<MatrixBlock>
getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) {
return getAllMatrixInputs(ixIn, blkIn, false);
}
-
+
protected ArrayList<MatrixBlock>
getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) {
ArrayList<MatrixBlock> ret = new ArrayList<>();
//add all rdd/broadcast inputs (main and side inputs)
for( int i=0, posRdd=0, posBc=0; i<_bcInd.length; i++ )
{
if( _bcInd[i] ) {
PartitionedBroadcast<MatrixBlock> pb =
_inputs.get(posBc++);
- int rowIndex = (int) ((outer && i==2) ?
ixIn.getColumnIndex() :
+ int rowIndex = (int) ((outer && i==2) ?
ixIn.getColumnIndex() :
(pb.getNumRowBlocks()>=ixIn.getRowIndex())?ixIn.getRowIndex():1);
- int colIndex = (int) ((outer && i==2) ?
1 :
+ int colIndex = (int) ((outer && i==2) ?
1 :
(pb.getNumColumnBlocks()>=ixIn.getColumnIndex())?ixIn.getColumnIndex():1);
ret.add(pb.getBlock(rowIndex,
colIndex));
}
@@ -391,9 +397,9 @@ public class SpoofSPInstruction extends SPInstruction {
return ret;
}
}
-
+
private static class RowwiseFunction extends SpoofFunction
- implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,
MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
+ implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,
MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID =
-7926980450209760212L;
@@ -401,7 +407,7 @@ public class SpoofSPInstruction extends SPInstruction {
private final int _clen;
private final int _clen2;
private SpoofRowwise _op = null;
-
+
public RowwiseFunction(String className, byte[] classBytes,
boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices,
ArrayList<ScalarObject> scalars, int blen, int clen,
int clen2) {
super(className, classBytes, bcInd, bcMatrices,
scalars);
@@ -409,30 +415,30 @@ public class SpoofSPInstruction extends SPInstruction {
_clen = clen;
_clen2 = clen;
}
-
+
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(
Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg ) {
//lazy load of shipped class
if( _op == null ) {
Class<?> loadedClass =
CodegenUtils.getClassSync(_className, _classBytes);
- _op = (SpoofRowwise)
CodegenUtils.createInstance(loadedClass);
+ _op = (SpoofRowwise)
CodegenUtils.createInstance(loadedClass);
}
-
+
//setup local memory for reuse
LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen,
_clen2);
-
+
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new
ArrayList<>();
boolean aggIncr = (_op.getRowType().isColumnAgg()
//aggregate entire partition
- || _op.getRowType() == RowType.FULL_AGG);
+ || _op.getRowType() == RowType.FULL_AGG);
MatrixBlock blkOut = aggIncr ? new MatrixBlock() : null;
-
+
while( arg.hasNext() ) {
//get main input block and indexes
Tuple2<MatrixIndexes,MatrixBlock[]> e =
arg.next();
MatrixIndexes ixIn = e._1();
MatrixBlock[] blkIn = e._2();
long rix = (ixIn.getRowIndex()-1) * _blen;
//0-based
-
+
//prepare output and execute single-threaded
operator
ArrayList<MatrixBlock> inputs =
getAllMatrixInputs(ixIn, blkIn);
blkOut = aggIncr ? blkOut : new MatrixBlock();
@@ -443,7 +449,7 @@ public class SpoofSPInstruction extends SPInstruction {
ret.add(new Tuple2<>(ixOut, blkOut));
}
}
-
+
//cleanup and final result preparations
LibSpoofPrimitives.cleanupThreadLocalMemory();
if( aggIncr ) {
@@ -451,45 +457,45 @@ public class SpoofSPInstruction extends SPInstruction {
blkOut.examSparsity(); //deferred format change
ret.add(new Tuple2<>(new MatrixIndexes(1,1),
blkOut));
}
-
+
return ret.iterator();
}
}
-
+
private static class CellwiseFunction extends SpoofFunction
- implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,
MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
+ implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,
MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID =
-8209188316939435099L;
-
+
private SpoofCellwise _op = null;
private final int _blen;
-
+
public CellwiseFunction(String className, byte[] classBytes,
boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices,
ArrayList<ScalarObject> scalars, int blen) {
super(className, classBytes, bcInd, bcMatrices,
scalars);
_blen = blen;
}
-
+
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>>
call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
- throws Exception
+ throws Exception
{
//lazy load of shipped class
if( _op == null ) {
Class<?> loadedClass =
CodegenUtils.getClassSync(_className, _classBytes);
- _op = (SpoofCellwise)
CodegenUtils.createInstance(loadedClass);
+ _op = (SpoofCellwise)
CodegenUtils.createInstance(loadedClass);
}
-
+
List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new
ArrayList<>();
- while(arg.hasNext())
+ while(arg.hasNext())
{
Tuple2<MatrixIndexes,MatrixBlock[]> tmp =
arg.next();
MatrixIndexes ixIn = tmp._1();
MatrixBlock[] blkIn = tmp._2();
- MatrixIndexes ixOut = ixIn;
+ MatrixIndexes ixOut = ixIn;
MatrixBlock blkOut = new MatrixBlock();
ArrayList<MatrixBlock> inputs =
getAllMatrixInputs(ixIn, blkIn);
long rix = (ixIn.getRowIndex()-1) * _blen;
//0-based
-
+
//execute core operation
if( _op.getCellType()==CellType.FULL_AGG ) {
ScalarObject obj = _op.execute(inputs,
_scalars, 1, rix);
@@ -508,53 +514,53 @@ public class SpoofSPInstruction extends SPInstruction {
return ret.iterator();
}
}
-
+
private static class MultiAggregateFunction extends SpoofFunction
- implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>,
MatrixIndexes, MatrixBlock>
+ implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>,
MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID =
-5224519291577332734L;
-
+
private SpoofMultiAggregate _op = null;
private final int _blen;
-
+
public MultiAggregateFunction(String className, byte[]
classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>>
bcMatrices, ArrayList<ScalarObject> scalars, int blen) {
super(className, classBytes, bcInd, bcMatrices,
scalars);
_blen = blen;
}
-
+
@Override
public Tuple2<MatrixIndexes, MatrixBlock>
call(Tuple2<MatrixIndexes, MatrixBlock[]> arg)
- throws Exception
+ throws Exception
{
//lazy load of shipped class
if( _op == null ) {
Class<?> loadedClass =
CodegenUtils.getClassSync(_className, _classBytes);
- _op = (SpoofMultiAggregate)
CodegenUtils.createInstance(loadedClass);
+ _op = (SpoofMultiAggregate)
CodegenUtils.createInstance(loadedClass);
}
-
+
//execute core operation
ArrayList<MatrixBlock> inputs =
getAllMatrixInputs(arg._1(), arg._2());
MatrixBlock blkOut = new MatrixBlock();
long rix = (arg._1().getRowIndex()-1) * _blen; //0-based
blkOut = _op.execute(inputs, _scalars, blkOut, 1, rix);
-
+
return new Tuple2<>(arg._1(), blkOut);
}
}
-
- private static class MultiAggAggregateFunction implements
Function2<MatrixBlock, MatrixBlock, MatrixBlock>
+
+ private static class MultiAggAggregateFunction implements
Function2<MatrixBlock, MatrixBlock, MatrixBlock>
{
private static final long serialVersionUID =
5978731867787952513L;
-
+
private AggOp[] _ops = null;
-
+
public MultiAggAggregateFunction( AggOp[] ops ) {
- _ops = ops;
+ _ops = ops;
}
-
+
@Override
public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1)
- throws Exception
+ throws Exception
{
//prepare combiner block
if( arg0.getNumRows() <= 0 || arg0.getNumColumns() <=
0) {
@@ -564,35 +570,35 @@ public class SpoofSPInstruction extends SPInstruction {
else if( arg1.getNumRows() <= 0 || arg1.getNumColumns()
<= 0 ) {
return arg0;
}
-
+
//aggregate second input (in-place)
SpoofMultiAggregate.aggregatePartialResults(_ops, arg0,
arg1);
-
+
return arg0;
}
}
-
+
private static class OuterProductFunction extends SpoofFunction
- implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,
MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
+ implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes,
MatrixBlock[]>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID =
-8209188316939435099L;
-
+
private SpoofOperator _op = null;
-
+
public OuterProductFunction(String className, byte[]
classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>>
bcMatrices, ArrayList<ScalarObject> scalars) {
super(className, classBytes, bcInd, bcMatrices,
scalars);
}
-
+
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>>
call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg)
- throws Exception
+ throws Exception
{
//lazy load of shipped class
if( _op == null ) {
Class<?> loadedClass =
CodegenUtils.getClassSync(_className, _classBytes);
- _op = CodegenUtils.createInstance(loadedClass);
+ _op = CodegenUtils.createInstance(loadedClass);
}
-
+
List<Tuple2<MatrixIndexes, MatrixBlock>> ret = new
ArrayList<>();
while(arg.hasNext())
{
@@ -611,45 +617,45 @@ public class SpoofSPInstruction extends SPInstruction {
else {
blkOut = _op.execute(inputs, _scalars,
blkOut);
}
-
+
ret.add(new
Tuple2<>(createOutputIndexes(ixIn,_op), blkOut));
}
-
+
return ret.iterator();
}
-
+
private static MatrixIndexes createOutputIndexes(MatrixIndexes
in, SpoofOperator spoofOp) {
- if( ((SpoofOuterProduct)spoofOp).getOuterProdType() ==
OutProdType.LEFT_OUTER_PRODUCT )
+ if( ((SpoofOuterProduct)spoofOp).getOuterProdType() ==
OutProdType.LEFT_OUTER_PRODUCT )
return new MatrixIndexes(in.getColumnIndex(),
1);
else if (
((SpoofOuterProduct)spoofOp).getOuterProdType() ==
OutProdType.RIGHT_OUTER_PRODUCT)
return new MatrixIndexes(in.getRowIndex(), 1);
- else
+ else
return in;
}
}
-
- public static class ReplicateRightFactorFunction implements
PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes,
MatrixBlock>
+
+ public static class ReplicateRightFactorFunction implements
PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes,
MatrixBlock>
{
private static final long serialVersionUID =
-7295989688796126442L;
-
+
private final long _len;
private final long _blen;
-
+
public ReplicateRightFactorFunction(long len, long blen) {
_len = len;
_blen = blen;
}
-
+
@Override
- public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(
Tuple2<MatrixIndexes, MatrixBlock> arg0 )
- throws Exception
+ public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(
Tuple2<MatrixIndexes, MatrixBlock> arg0 )
+ throws Exception
{
LinkedList<Tuple2<MatrixIndexes, MatrixBlock>> ret =
new LinkedList<>();
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn = arg0._2();
-
- long numBlocks = (long) Math.ceil((double)_len/_blen);
-
+
+ long numBlocks = (long) Math.ceil((double)_len/_blen);
+
//replicate wrt # row blocks in LHS
long j = ixIn.getRowIndex();
for( long i=1; i<=numBlocks; i++ ) {
@@ -657,12 +663,12 @@ public class SpoofSPInstruction extends SPInstruction {
MatrixBlock tmpblk = blkIn;
ret.add( new Tuple2<>(tmpix, tmpblk) );
}
-
+
//output list of new tuples
return ret.iterator();
}
}
-
+
public static AggregateOperator getAggregateOperator(AggOp aggop) {
if( aggop == AggOp.SUM || aggop == AggOp.SUM_SQ )
return new AggregateOperator(0,
KahanPlus.getKahanPlusFnObject(), CorrectionLocationType.NONE);
@@ -672,4 +678,13 @@ public class SpoofSPInstruction extends SPInstruction {
return new AggregateOperator(Double.NEGATIVE_INFINITY,
Builtin.getBuiltinFnObject(BuiltinCode.MAX), CorrectionLocationType.NONE);
return null;
}
+
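+ // returns true if any matrix input of this instruction is bound to a federated MatrixObject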
+ public boolean isFederated(ExecutionContext ec) {
+ for(CPOperand input : _in) {
+ Data data = ec.getVariable(input);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated())
+ return true;
+ }
+ return false;
+ }
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index bc22b4c..a76db81 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -45,7 +45,6 @@ import
org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
-import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
/**
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
new file mode 100644
index 0000000..653e622
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.codegen;
+
+import java.io.File;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedCellwiseTmplTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "FederatedCellwiseTmplTest";
+
+ private final static String TEST_DIR = "functions/federated/codegen/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedCellwiseTmplTest.class.getSimpleName() + "/";
+
+ private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+
+ private final static String OUTPUT_NAME = "Z";
+ private final static double TOLERANCE = 0;
+ private final static int BLOCKSIZE = 1024;
+
+ @Parameterized.Parameter()
+ public int test_num;
+ @Parameterized.Parameter(1)
+ public int rows;
+ @Parameterized.Parameter(2)
+ public int cols;
+ @Parameterized.Parameter(3)
+ public double sparsity;
+ @Parameterized.Parameter(4)
+ public boolean row_partitioned;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows must be even for row-partitioned X
+ // cols must be even for column-partitioned X
+ return Arrays.asList(new Object[][] {
+ // {test_num, rows, cols, sparsity, row_partitioned}
+
+ // row partitioned
+ {1, 2000, 2000, 1, true},
+ {2, 10, 10, 1, true},
+ {3, 4, 4, 1, true},
+ {4, 4, 4, 1, true},
+ {5, 4, 4, 1, true},
+ {6, 4, 1, 1, true},
+ {9, 500, 2, 1, true},
+ {10, 500, 2, 1, true},
+ {11, 1100, 2000, 1, true},
+ {12, 2, 500, 1, true},
+ {13, 2, 4, 1, true},
+ {14, 1100, 200, 1, true},
+
+ // column partitioned
+ {1, 2000, 2000, 1, false},
+ {2, 10, 10, 1, false},
+ {3, 4, 4, 1, false},
+ {4, 4, 4, 1, false},
+ {5, 4, 4, 1, false},
+ {9, 500, 2, 1, false},
+ {10, 500, 2, 1, false},
+ {11, 1100, 2000, 1, false},
+ {12, 2, 500, 1, false},
+ {14, 1100, 200, 1, false},
+
+ // not working because of the fused sequence operation
+ // (wrong grix inside the genexec call of the federated worker)
+ // {7, 1000, 1, 1, true},
+
+ // not compiled into a federated spoof instruction
+ // {8, 1002, 24, 1, true},
+ // {8, 1002, 24, 1, false},
+ });
+ }
+
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+// @Test
+// public void federatedCodegenCellwiseSingleNode() {
+// testFederatedCodegen(ExecMode.SINGLE_NODE);
+// }
+//
+// @Test
+// public void federatedCodegenCellwiseSpark() {
+// testFederatedCodegen(ExecMode.SPARK);
+// }
+
+ @Test
+ public void federatedCodegenCellwiseHybrid() {
+ testFederatedCodegen(ExecMode.HYBRID);
+ }
+
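+ // runs the reference script on the plain input and the actual script on a federated
+ // matrix hosted by two local federated workers, then compares both outputs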
+ private void testFederatedCodegen(ExecMode exec_mode) {
+ // store the previous platform config to restore it after the
test
+ ExecMode platform_old = setExecMode(exec_mode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ int fed_rows = rows;
+ int fed_cols = cols;
+ if(row_partitioned)
+ fed_rows /= 2;
+ else
+ fed_cols /= 2;
+
+ // generate the dataset
+ // (matrix handled by two federated workers)
+ double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 3);
+ double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 7);
+
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ Thread thread2 = startLocalFedWorkerThread(port2);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+ "in_rp=" +
Boolean.toString(row_partitioned).toUpperCase(),
+ "in_test_num=" + Integer.toString(test_num),
+ "out_Z=" + expected(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_rp=" +
Boolean.toString(row_partitioned).toUpperCase(),
+ "in_test_num=" + Integer.toString(test_num),
+ "rows=" + rows, "cols=" + cols,
+ "out_Z=" + output(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // compare the results via files
+ HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+ HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
+ TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
+
+ TestUtils.shutdownThreads(thread1, thread2);
+
+ // check for federated operations
+
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
+
+ // check that the federated input files still exist
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ resetExecMode(platform_old);
+ }
+
+ /**
+ * Override default configuration with custom test configuration to
ensure
+ * scratch space and local temporary directory locations are also
updated.
+ */
+ @Override
+ protected File getConfigTemplateFile() {
+ // Instrumentation in this test's output log to show the custom configuration file used as the template.
+ File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR,
TEST_CONF);
+ return TEST_CONF_FILE;
+ }
+}
diff --git
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
new file mode 100644
index 0000000..68d48fb
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
@@ -0,0 +1,123 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num;
+row_part = $in_rp;
+
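+# construct the federated matrix X from the two workers,
+# partitioned either row-wise or column-wise at the midpoint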
+if(row_part) {
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)));
+}
+else {
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2),
list($rows, $cols)));
+}
+
+if(test_num == 1) {
+ # X ... 2000x2000 matrix
+ Y = matrix(2, rows=2000, cols=1);
+
+ lamda = sum(Y);
+ Z = round(abs(X + lamda)) + 5;
+}
+else if(test_num == 2) {
+ # X ... 10x10 matrix
+ Z = 1/(1+exp(-X));
+}
+else if(test_num == 3) {
+ # X ... 4x4 matrix
+ Z = 10 + floor(round(abs(7 + (1 / exp(X)))));
+}
+else if(test_num == 4) {
+ # X ... 4x4 matrix
+ w = matrix(3, rows=4, cols=4);
+ v = matrix(5, rows=4, cols=4);
+ Z = 10 + floor(round(abs((X + w) * v)));
+}
+else if(test_num == 5) {
+ # X ... 4x4 matrix
+ w = matrix("1 2 3 4", rows=4, cols=1);
+ v = matrix("4 4 4 4", rows=4, cols=1);
+
+ G = abs(exp(X));
+ Y = 10 + floor(round(abs((X / w) + v)));
+ Z = G + Y;
+}
+else if(test_num == 6) {
+ # X ... 4x1 vector
+ y = matrix("1 1 1 1", rows=4, cols=1);
+ v = matrix("3 3 3 3", rows=4, cols=1);
+
+ Z = as.matrix(sum(X * y * v));
+}
+else if(test_num == 7) {
+ # X ... 1000x1 vector
+ Y = seq(6, 1006);
+
+ U = X + Y - 7 + abs(X);
+ Z = t(U) %*% U;
+}
+else if(test_num == 8) {
+ # X ... 1002x24 matrix
+ Y = seq(1, 1002);
+ X[100:900,] = matrix(0, rows=801, cols=24);
+
+ Z = X * ((X + 7.7) * Y);
+}
+else if(test_num == 9) {
+ # X ... 500x2 matrix
+ Y = matrix(seq(6, 1005), 500, 2);
+
+ U = X + 7 * Y;
+ Z = as.matrix(sum(U^2))
+}
+else if(test_num == 10) {
+ # X ... 500x2 matrix
+
+ Y = (0 / (X - 500))+1;
+ Z = replace(target=Y, pattern=0/0, replacement=7);
+}
+else if(test_num == 11) {
+ # X ... 1100x2000 matrix
+ Y = seq(1, 2000);
+
+ Z = -2 * X + t(Y);
+}
+else if(test_num == 12) {
+ # X ... 2x500 matrix
+ Y = matrix(seq(6, 1005), 2, 500);
+
+ U = X + 7 * Y;
+ Z = as.matrix(sum(U^2))
+}
+else if(test_num == 13) {
+ # X ... 2x4 matrix
+ w = matrix(seq(1,8), rows=2, cols=4);
+ v = matrix(5, rows=2, cols=4);
+ Z = 10 + floor(round(abs((X + w) * v)));
+}
+else if(test_num == 14) {
+ # X ... 1100x200 matrix
+
+ Z = colMins(2 * log(X));
+}
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
new file mode 100644
index 0000000..1826fb2
--- /dev/null
+++
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
@@ -0,0 +1,125 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num;
+row_part = $in_rp;
+
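+# reference: read both partitions locally and bind them into a single matrix X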
+if(row_part) {
+ X = rbind(read($in_X1), read($in_X2));
+}
+else {
+ X = cbind(read($in_X1), read($in_X2));
+
+ # TODO: remove as soon as JIRA ticket SYSTEMDS-2888 has been resolved
+ # needed to separate the cbind from the code generation
+ while(FALSE) {}
+}
+
+if(test_num == 1) {
+ # X ... 2000x2000 matrix
+ Y = matrix(2, rows=2000, cols=1);
+
+ lamda = sum(Y);
+ Z = round(abs(X + lamda)) + 5;
+}
+else if(test_num == 2) {
+ # X ... 10x10 matrix
+ Z = 1/(1+exp(-X));
+}
+else if(test_num == 3) {
+ # X ... 4x4 matrix
+ Z = 10 + floor(round(abs(7 + (1 / exp(X)))));
+}
+else if(test_num == 4) {
+ # X ... 4x4 matrix
+ w = matrix(3, rows=4, cols=4);
+ v = matrix(5, rows=4, cols=4);
+ Z = 10 + floor(round(abs((X + w) * v)));
+}
+else if(test_num == 5) {
+ # X ... 4x4 matrix
+ w = matrix("1 2 3 4", rows=4, cols=1);
+ v = matrix("4 4 4 4", rows=4, cols=1);
+
+ G = abs(exp(X));
+ Y = 10 + floor(round(abs((X / w) + v)));
+ Z = G + Y;
+}
+else if(test_num == 6) {
+ # X ... 4x1 vector
+ y = matrix("1 1 1 1", rows=4, cols=1);
+ v = matrix("3 3 3 3", rows=4, cols=1);
+
+ Z = as.matrix(sum(X * y * v));
+}
+else if(test_num == 7) {
+ # X ... 1000x1 vector
+ Y = seq(6, 1006);
+
+ U = X + Y - 7 + abs(X);
+ Z = t(U) %*% U;
+}
+else if(test_num == 8) {
+ # X ... 1002x24 matrix
+ Y = seq(1, 1002);
+ X[100:900,] = matrix(0, rows=801, cols=24);
+
+ Z = X * ((X + 7.7) * Y);
+}
+else if(test_num == 9) {
+ # X ... 500x2 matrix
+ Y = matrix(seq(6, 1005), 500, 2);
+
+ U = X + 7 * Y;
+ Z = as.matrix(sum(U^2))
+}
+else if(test_num == 10) {
+ # X ... 500x2 matrix
+
+ Y = (0 / (X - 500))+1;
+ Z = replace(target=Y, pattern=0/0, replacement=7);
+}
+else if(test_num == 11) {
+ # X ... 1100x2000 matrix
+ Y = seq(1, 2000);
+
+ Z = -2 * X + t(Y);
+}
+else if(test_num == 12) {
+ # X ... 2x500 matrix
+ Y = matrix(seq(6, 1005), 2, 500);
+
+ U = X + 7 * Y;
+ Z = as.matrix(sum(U^2))
+}
+else if(test_num == 13) {
+ # X ... 2x4 matrix
+ w = matrix(seq(1,8), rows=2, cols=4);
+ v = matrix(5, rows=2, cols=4);
+ Z = 10 + floor(round(abs((X + w) * v)));
+}
+else if(test_num == 14) {
+ # X ... 1100x200 matrix
+
+ Z = colMins(2 * log(X));
+}
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/codegen/SystemDS-config-codegen.xml
b/src/test/scripts/functions/federated/codegen/SystemDS-config-codegen.xml
new file mode 100644
index 0000000..b3d4712
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/SystemDS-config-codegen.xml
@@ -0,0 +1,32 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.localtmpdir>/tmp/systemds</sysds.localtmpdir>
+ <sysds.scratch>scratch_space</sysds.scratch>
+ <sysds.optlevel>7</sysds.optlevel>
+ <sysds.codegen.enabled>true</sysds.codegen.enabled>
+ <sysds.codegen.plancache>true</sysds.codegen.plancache>
+ <sysds.codegen.literals>1</sysds.codegen.literals>
+
+ <!-- The number of threads for the local Spark instance, artificially selected -->
+ <sysds.local.spark.number.threads>16</sysds.local.spark.number.threads>
+
+ <sysds.codegen.api>auto</sysds.codegen.api>
+</root>