This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ddbf556f2 [SYSTEMDS-3088] No prefetch for List type consumers
8ddbf556f2 is described below

commit 8ddbf556f2d740900438f5fa4734fffc840644a0
Author: Arnab Phani <[email protected]>
AuthorDate: Sun Jun 11 14:31:36 2023 +0200

    [SYSTEMDS-3088] No prefetch for List type consumers
    
    This patch fixes a bug in prefetch placement and prevents
    prefetching when the consumer is of List type. A List is not an
    operation, so a prefetch can wrongly pull a Spark intermediate
    whose output goes into a List (e.g., rightindex -> List).
    
    Closes #1840
---
 .../sysds/lops/rewrite/RewriteAddPrefetchLop.java  |  6 +++-
 .../test/functions/async/PrefetchRDDTest.java      | 13 ++++++--
 src/test/scripts/functions/async/PrefetchRDD5.dml  | 35 ++++++++++++++++++++++
 3 files changed, 51 insertions(+), 3 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
index 6eb52e0d9f..91b7f81e71 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
@@ -105,13 +105,17 @@ public class RewriteAddPrefetchLop extends LopRewriteRule
                        && !(lop instanceof MMTSJ) && !(lop instanceof 
UAggOuterChain)
                        && !(lop instanceof ParameterizedBuiltin) && !(lop 
instanceof SpoofFused);
 
+               // Exclude List consumers. List is just a metadata handle.
+               boolean anyOutputList = lop.getOutputs().stream()
+                       .anyMatch(out -> out.getDataType() == 
Types.DataType.LIST);
+
                //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
                boolean hasParameterizedOut = lop.getOutputs().stream()
                        .anyMatch(out -> ((out instanceof ParameterizedBuiltin)
                                || (out instanceof GroupedAggregate)
                                || (out instanceof GroupedAggregateM)));
                //TODO: support non-matrix outputs
-               return transformOP && !hasParameterizedOut
+               return transformOP && !hasParameterizedOut && !anyOutputList
                        && (lop.isAllOutputsCP() || 
OperatorOrderingUtils.isCollectForBroadcast(lop))
                        && lop.getDataType() == Types.DataType.MATRIX;
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index f821af5eb0..886a850d22 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -39,7 +39,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
        
        protected static final String TEST_DIR = "functions/async/";
        protected static final String TEST_NAME = "PrefetchRDD";
-       protected static final int TEST_VARIANTS = 4;
+       protected static final int TEST_VARIANTS = 5;
        protected static String TEST_CLASS_DIR = TEST_DIR + 
PrefetchRDDTest.class.getSimpleName() + "/";
        
        @Override
@@ -73,6 +73,12 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                runTest(TEST_NAME+"4");
        }
 
+       @Test
+       public void testAsyncSparkOPs5() {
+               //List type consumer. No Prefetch.
+               runTest(TEST_NAME+"5");
+       }
+
        public void runTest(String testname) {
                boolean old_trans_exec_type = 
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
                ExecMode oldPlatform = setExecMode(ExecMode.HYBRID);
@@ -108,7 +114,10 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        if (!matchVal)
                                System.out.println("Value w/o Prefetch "+R+" w/ 
Prefetch "+R_pf);
                        //assert Prefetch instructions and number of success.
-                       long expected_numPF = 
!testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
+                       long expected_numPF = 1;
+                       if (testname.equalsIgnoreCase(TEST_NAME+"3")
+                               || testname.equalsIgnoreCase(TEST_NAME+"5"))
+                               expected_numPF = 0;
                        //long expected_successPF = 
!testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
                        long numPF = 
Statistics.getCPHeavyHitterCount("prefetch");
                        Assert.assertTrue("Violated Prefetch instruction count: 
"+numPF, numPF == expected_numPF);
diff --git a/src/test/scripts/functions/async/PrefetchRDD5.dml 
b/src/test/scripts/functions/async/PrefetchRDD5.dml
new file mode 100644
index 0000000000..13d272c5a8
--- /dev/null
+++ b/src/test/scripts/functions/async/PrefetchRDD5.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+k = 2;
+#create empty lists
+dataset_X = list(); #empty list
+fs = ceil(nrow(X)/k);
+off = fs - 1;
+#devide X into lists of k matrices
+for (i in seq(1, k)) {
+  #List type consumer. No prefetch after rightindex.
+  dataset_X = append(dataset_X, X[i*fs-off : min(i*fs, nrow(X)),]);
+}
+[tmpX, testX] = remove(dataset_X, 1);
+R = sum(rbind(testX));
+write(R, $1, format="text");
+

Reply via email to