This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new ede3635a6c [SYSTEMDS-3443] Asynchronously execute and persist Spark
transformations
ede3635a6c is described below
commit ede3635a6c0ba9d7089044eeec5317f685e9e03b
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Oct 10 21:02:52 2022 +0200
[SYSTEMDS-3443] Asynchronously execute and persist Spark transformations
This patch adds an operator to asynchronously trigger a chain of Spark
transformations and persist the result. This is a generalization of the
Prefetch instruction and works if the consumer is a Spark transformation
or action. TODO: operator placement (in parallel with a CP instruction
chain).
Closes #1704
---
src/main/java/org/apache/sysds/common/Types.java | 2 +-
.../runtime/instructions/CPInstructionParser.java | 7 ++++-
.../runtime/instructions/cp/CPInstruction.java | 2 +-
.../instructions/cp/PrefetchCPInstruction.java | 2 +-
...perationsTask.java => TriggerPrefetchTask.java} | 8 +++---
.../cp/TriggerRemoteOperationsTask.java | 33 +++++++++-------------
...ion.java => TriggerRemoteOpsCPInstruction.java} | 21 ++++++--------
.../apache/sysds/utils/stats/SparkStatistics.java | 14 +++++++--
.../test/functions/async/PrefetchRDDTest.java | 2 +-
src/test/scripts/functions/async/PrefetchRDD3.dml | 6 ++--
10 files changed, 52 insertions(+), 45 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index a7cfa823aa..a9b8108600 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -241,7 +241,7 @@ public class Types
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR,
INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT,
STOP,
- SVD, TAN, TANH, TYPEOF,
+ SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
//fused ML-specific operators for performance
SPROP, //sample proportion: P * (1 - P)
SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b53c954fd8..f2d3080ddc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -67,6 +67,7 @@ import
org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.TriggerRemoteOpsCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -328,6 +329,7 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "spoof", CPType.SpoofFused);
String2CPInstructionType.put( "prefetch", CPType.Prefetch);
String2CPInstructionType.put( "broadcast", CPType.Broadcast);
+ String2CPInstructionType.put( "trigremote", CPType.TrigRemote);
String2CPInstructionType.put( Local.OPCODE, CPType.Local);
String2CPInstructionType.put( "sql", CPType.Sql);
@@ -477,7 +479,10 @@ public class CPInstructionParser extends InstructionParser
case Broadcast:
return
BroadcastCPInstruction.parseInstruction(str);
-
+
+ case TrigRemote:
+ return
TriggerRemoteOpsCPInstruction.parseInstruction(str);
+
default:
throw new DMLRuntimeException("Invalid CP
Instruction Type: " + cptype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index acb1fd1dae..144760b3d9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -47,7 +47,7 @@ public abstract class CPInstruction extends Instruction
MultiReturnParameterizedBuiltin, ParameterizedBuiltin,
MultiReturnBuiltin,
Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick,
Local,
MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition,
Compression, DeCompression, SpoofFused,
- StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn,
Sql, Prefetch, Broadcast }
+ StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn,
Sql, Prefetch, Broadcast, TrigRemote }
protected final CPType _cptype;
protected final boolean _requiresLabelUpdate;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 9d95a58dc3..db9bbb1b84 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -51,6 +51,6 @@ public class PrefetchCPInstruction extends UnaryCPInstruction
{
// In that case this Prefetch instruction will act like a NOOP.
if (CommonThreadPool.triggerRemoteOPsPool == null)
CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
- CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
+ CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerPrefetchTask(ec.getMatrixObject(output)));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
similarity index 87%
copy from
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
copy to
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index 255a2d61f6..26a1c8e8bb 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -24,10 +24,10 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.utils.stats.SparkStatistics;
-public class TriggerRemoteOperationsTask implements Runnable {
+public class TriggerPrefetchTask implements Runnable {
MatrixObject _prefetchMO;
- public TriggerRemoteOperationsTask(MatrixObject mo) {
+ public TriggerPrefetchTask(MatrixObject mo) {
_prefetchMO = mo;
}
@@ -35,8 +35,8 @@ public class TriggerRemoteOperationsTask implements Runnable {
public void run() {
boolean prefetched = false;
synchronized (_prefetchMO) {
- // Having this check if operations are pending inside
the
- // critical section safeguards against concurrent rmVar.
+ // Having this check inside the critical section
+ // safeguards against concurrent rmVar.
if (_prefetchMO.isPendingRDDOps() ||
_prefetchMO.isFederated()) {
// TODO: Add robust runtime constraints for
federated prefetch
// Execute and bring the result to local
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
index 255a2d61f6..63e6f56fd5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
@@ -19,37 +19,32 @@
package org.apache.sysds.runtime.instructions.cp;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.utils.stats.SparkStatistics;
public class TriggerRemoteOperationsTask implements Runnable {
- MatrixObject _prefetchMO;
+ MatrixObject _remoteOperationsRoot;
public TriggerRemoteOperationsTask(MatrixObject mo) {
- _prefetchMO = mo;
+ _remoteOperationsRoot = mo;
}
@Override
public void run() {
- boolean prefetched = false;
- synchronized (_prefetchMO) {
- // Having this check if operations are pending inside
the
- // critical section safeguards against concurrent rmVar.
- if (_prefetchMO.isPendingRDDOps() ||
_prefetchMO.isFederated()) {
- // TODO: Add robust runtime constraints for
federated prefetch
- // Execute and bring the result to local
- _prefetchMO.acquireReadAndRelease();
- prefetched = true;
+ boolean triggered = false;
+ synchronized (_remoteOperationsRoot) {
+ if (_remoteOperationsRoot.isPendingRDDOps()) {
+ JavaPairRDD<?, ?> rdd =
_remoteOperationsRoot.getRDDHandle().getRDD();
+
rdd.persist(Checkpoint.DEFAULT_STORAGE_LEVEL).count();
+
_remoteOperationsRoot.getRDDHandle().setCheckpointRDD(true);
+ triggered = true;
}
}
- if (DMLScript.STATISTICS && prefetched) {
- if (_prefetchMO.isFederated())
- FederatedStatistics.incAsyncPrefetchCount(1);
- else
- SparkStatistics.incAsyncPrefetchCount(1);
- }
- }
+ if (DMLScript.STATISTICS && triggered)
+ SparkStatistics.incAsyncTriggerRemoteCount(1);
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
similarity index 71%
copy from
src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
copy to
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
index 9d95a58dc3..98d7f440c4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
package org.apache.sysds.runtime.instructions.cp;
import java.util.concurrent.Executors;
@@ -26,29 +25,27 @@ import
org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;
-public class PrefetchCPInstruction extends UnaryCPInstruction {
- private PrefetchCPInstruction(Operator op, CPOperand in, CPOperand out,
String opcode, String istr) {
- super(CPType.Prefetch, op, in, out, opcode, istr);
+public class TriggerRemoteOpsCPInstruction extends UnaryCPInstruction {
+ private TriggerRemoteOpsCPInstruction(Operator op, CPOperand in,
CPOperand out, String opcode, String istr) {
+ super(CPType.TrigRemote, op, in, out, opcode, istr);
}
-
- public static PrefetchCPInstruction parseInstruction (String str) {
+
+ public static TriggerRemoteOpsCPInstruction parseInstruction (String
str) {
InstructionUtils.checkNumFields(str, 2);
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
CPOperand in = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
- return new PrefetchCPInstruction(null, in, out, opcode, str);
+ return new TriggerRemoteOpsCPInstruction(null, in, out, opcode,
str);
}
@Override
public void processInstruction(ExecutionContext ec) {
- //TODO: handle non-matrix objects
+ // TODO: Operator placement.
+ // Note for testing: write a method in the Dag class to place
this operator
+ // after Spark MMRJ. Then execute
PrefetchRDDTest.testAsyncSparkOPs3.
ec.setVariable(output.getName(), ec.getMatrixObject(input1));
- // Note, a Prefetch instruction doesn't guarantee an
asynchronous execution.
- // If the next instruction which takes this output as an input
comes before
- // the prefetch thread triggers, that instruction will start
the operations.
- // In that case this Prefetch instruction will act like a NOOP.
if (CommonThreadPool.triggerRemoteOPsPool == null)
CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index 5263dbd119..3965feafdc 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -34,6 +34,7 @@ public class SparkStatistics {
private static final LongAdder broadcastCount = new LongAdder();
private static final LongAdder asyncPrefetchCount = new LongAdder();
private static final LongAdder asyncBroadcastCount = new LongAdder();
+ private static final LongAdder asyncTriggerRemoteCount = new
LongAdder();
public static boolean createdSparkContext() {
return ctxCreateTime > 0;
@@ -76,6 +77,10 @@ public class SparkStatistics {
asyncBroadcastCount.add(c);
}
+ public static void incAsyncTriggerRemoteCount(long c) {
+ asyncTriggerRemoteCount.add(c);
+ }
+
public static long getSparkCollectCount() {
return collectCount.longValue();
}
@@ -88,6 +93,10 @@ public class SparkStatistics {
return asyncBroadcastCount.longValue();
}
+ public static long getAsyncTriggerRemoteCount() {
+ return asyncTriggerRemoteCount.longValue();
+ }
+
public static void reset() {
ctxCreateTime = 0;
parallelizeTime.reset();
@@ -98,6 +107,7 @@ public class SparkStatistics {
collectCount.reset();
asyncPrefetchCount.reset();
asyncBroadcastCount.reset();
+ asyncTriggerRemoteCount.reset();
}
public static String displayStatistics() {
@@ -114,8 +124,8 @@ public class SparkStatistics {
broadcastTime.longValue()*1e-9,
collectTime.longValue()*1e-9));
if (OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS)
- sb.append("Spark async. count (pf,bc): \t" +
- String.format("%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount()));
+ sb.append("Spark async. count (pf,bc,tr): \t" +
+ String.format("%d/%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount(),
getAsyncTriggerRemoteCount()));
return sb.toString();
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index 5a884724f6..61279bd036 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -64,7 +64,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
@Test
public void testAsyncSparkOPs3() {
- //SP action type consumer. No Prefetch.
+ //SP binary consumer, followed by an action. No Prefetch.
runTest(TEST_NAME+"3");
}
diff --git a/src/test/scripts/functions/async/PrefetchRDD3.dml
b/src/test/scripts/functions/async/PrefetchRDD3.dml
index 15286e3034..340115b46e 100644
--- a/src/test/scripts/functions/async/PrefetchRDD3.dml
+++ b/src/test/scripts/functions/async/PrefetchRDD3.dml
@@ -30,7 +30,7 @@ sp2 = sp1 %*% t(sp1);
v = ((v + v) * 1 - v) / (1+1);
v = ((v + v) * 2 - v) / (2+1);
-# CP binary triggers the DAG of SP operations
-cp = sp2 + sum(v);
-R = sum(cp);
+# SP sum triggers the DAG of SP operations
+SP = sp2 + sum(v); #spark transformation
+R = sum(SP); #action
write(R, $1, format="text");