This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b3fa93abf [SYSTEMDS-3695] Fix spark frame cbind for misaligned inputs
4b3fa93abf is described below

commit 4b3fa93abfd43b9101c09a5bba5094992112c27a
Author: e-strauss <lathan...@gmx.de>
AuthorDate: Wed Jun 5 09:30:43 2024 +0200

    [SYSTEMDS-3695] Fix spark frame cbind for misaligned inputs
    
    Closes #2031.
---
 .../spark/FrameAppendRSPInstruction.java           | 102 +++++++++++++++++++--
 .../test/functions/frame/FrameAppendDistTest.java  |  74 +++++++++++++--
 .../functions/frame/FrameNAryAppendMisalignRSP.R   |  35 +++++++
 .../functions/frame/FrameNAryAppendMisalignRSP.dml |  31 +++++++
 .../functions/frame/FrameNAryAppendMisalignRSP2.R  |  31 +++++++
 .../frame/FrameNAryAppendMisalignRSP2.dml          |  28 ++++++
 6 files changed, 288 insertions(+), 13 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
index af1be2b0c4..8774c63ed7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.spark;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -31,6 +32,11 @@ import 
org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import scala.Tuple2;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
 public class FrameAppendRSPInstruction extends AppendRSPInstruction {
 
        protected FrameAppendRSPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out, boolean cbind,
@@ -43,7 +49,7 @@ public class FrameAppendRSPInstruction extends 
AppendRSPInstruction {
                SparkExecutionContext sec = (SparkExecutionContext)ec;
                JavaPairRDD<Long,FrameBlock> in1 = 
sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() );
                JavaPairRDD<Long,FrameBlock> in2 = 
sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() );
-               JavaPairRDD<Long,FrameBlock> out = null;
+               JavaPairRDD<Long,FrameBlock> out;
                long leftRows = 
sec.getDataCharacteristics(input1.getName()).getRows();
 
                out = appendFrameRSP(in1, in2, leftRows, _cbind);
@@ -65,11 +71,14 @@ public class FrameAppendRSPInstruction extends 
AppendRSPInstruction {
 
        public static JavaPairRDD<Long, FrameBlock> 
appendFrameRSP(JavaPairRDD<Long, FrameBlock> in1, JavaPairRDD<Long, FrameBlock> 
in2, long leftRows, boolean cbind) {
                if(cbind) {
-                       JavaPairRDD<Long,FrameBlock> in1Aligned = 
in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
-                       in1Aligned = 
FrameRDDAggregateUtils.mergeByKey(in1Aligned);
-                       JavaPairRDD<Long,FrameBlock> in2Aligned = 
in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
+                       //TODO preserve info if already aligned, and only align 
if necessary
+                       //get in1 keys
+                       long[] row_indices = 
in1.keys().collect().stream().mapToLong(Long::longValue).toArray();
+                       Arrays.sort(row_indices);
+                       //Align the blocks of in2 on the blocks of in1
+                       JavaPairRDD<Long,FrameBlock> in2Aligned = 
in2.flatMapToPair(new ReduceSideAppendAlignToLHSFunction(row_indices, 
leftRows));
                        in2Aligned = 
FrameRDDAggregateUtils.mergeByKey(in2Aligned);
-                       return in1Aligned.join(in2Aligned).mapValues(new 
ReduceSideColumnsFunction(cbind));
+                       return in1.join(in2Aligned).mapValues(new 
ReduceSideColumnsFunction(cbind));
                } else {        //rbind
                        JavaPairRDD<Long,FrameBlock> right = in2.mapToPair( new 
ReduceSideAppendRowsFunction(leftRows));
                        return in1.union(right);
@@ -94,6 +103,86 @@ public class FrameAppendRSPInstruction extends 
AppendRSPInstruction {
                }
        }
 
+       private static class ReduceSideAppendAlignToLHSFunction implements 
PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
+       {
+               private static final long serialVersionUID = 
5850400295183766409L;
+
+               private final long[] _indices;
+               private final long lastIndex; //max_rows + 1
+
+               public ReduceSideAppendAlignToLHSFunction(long[] indices, long 
max_rows) {
+                       _indices = indices;
+                       lastIndex = max_rows + 1;
+               }
+
+               @Override
+               public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, 
FrameBlock> arg0)
+               {
+                       List<Tuple2<Long, FrameBlock>> aligned_blocks = new 
ArrayList<>();
+                       long indexRHS = arg0._1();
+                       FrameBlock fb = arg0._2();
+
+                       //find the block index ix in the LHS with the smallest 
index s.t. following LHS indix ix' > indexRHS >= ix
+                       //doing binary search
+                       int L = 0;
+                       int R = _indices.length - 1;
+                       int m;
+                       while(L <= R){
+                               m = (L + R) / 2;
+                               if(_indices[m] == indexRHS){
+                                       R = m;
+                                       break;
+                               }
+                               if(_indices[m] < indexRHS)
+                                       L = m + 1;
+                       else
+                                       R = m - 1;
+                       }
+                       // search terminates if we have found the exact 
indexRHS or binary search reached the leaf nodes where
+                       // L == R (bucket size = 1) and m == L which implies 
that _indices[m+1] > indexRHS (otherwise we would
+                       // have considered this index in the search
+                       // if _indices[m] < indexRHS than m contains the index 
which fits our definition and R == m
+                       // else (m - 1) fits our definition which is stored and 
R = m - 1
+                       // therefore in all cases the correct position of the 
indexLHS  is stored in R
+                       long indexLHS = _indices[R];
+
+                       //assumes total num rows LHS == RHS
+                       long nextIndexLHS = R < _indices.length - 1? 
_indices[R+1] : this.lastIndex;
+                       int blkSizeLHS = (int) (nextIndexLHS -  indexLHS);
+                       int offsetLHS = (int) (indexRHS - indexLHS);
+                       int offsetRHS = 0;
+                       int sizeOfSlice = blkSizeLHS - offsetLHS;
+
+                       FrameBlock resultBlock = new FrameBlock(fb.getSchema());
+                       resultBlock.ensureAllocatedColumns(blkSizeLHS);
+
+                       int sizeOfRHS = fb.getNumRows();
+                       while(sizeOfSlice < sizeOfRHS){
+                               FrameBlock fb_sliced = fb.slice(offsetRHS, 
offsetRHS + sizeOfSlice - 1);
+                               resultBlock = 
resultBlock.leftIndexingOperations(fb_sliced,offsetLHS, offsetLHS + sizeOfSlice 
- 1, 0, fb.getNumColumns()-1, new FrameBlock());
+                               aligned_blocks.add(new Tuple2<>(indexLHS, 
resultBlock));
+                               resultBlock = new FrameBlock(fb.getSchema());
+                               if(R >= _indices.length - 1)
+                                       throw new RuntimeException("Alignment 
Error while CBIND: LHS has fewer rows than RHS");
+                               indexLHS = nextIndexLHS;
+                               offsetRHS += sizeOfSlice;
+                               offsetLHS = 0;
+                               sizeOfRHS -= sizeOfSlice;
+                               R++;
+                               nextIndexLHS =  R < _indices.length - 1? 
_indices[R+1] : this.lastIndex;
+                               sizeOfSlice = (int) (nextIndexLHS -  indexLHS); 
//sizeOfSlice = blkSizeLHS
+                               resultBlock.ensureAllocatedColumns(sizeOfSlice);
+                       }
+                       //RHS fits into aligned LHS block
+                       if(offsetRHS != 0)
+                               fb = fb.slice(offsetRHS, offsetRHS + sizeOfRHS 
- 1);
+                       resultBlock = resultBlock.leftIndexingOperations(fb, 
offsetLHS, offsetLHS + fb.getNumRows() - 1, 0, fb.getNumColumns()-1, new 
FrameBlock());
+                       aligned_blocks.add(new Tuple2<>(indexLHS, resultBlock));
+
+                       return aligned_blocks.iterator();
+               }
+       }
+
        private static class ReduceSideAppendRowsFunction implements 
PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> 
        {
                private static final long serialVersionUID = 
1723795153048336791L;
@@ -112,12 +201,13 @@ public class FrameAppendRSPInstruction extends 
AppendRSPInstruction {
                }
        }
 
+       @SuppressWarnings("unused")
        private static class ReduceSideAppendAlignFunction implements 
PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
        {
                private static final long serialVersionUID = 
5850400295183766409L;
 
                private long _rows;
-                               
+               
                public ReduceSideAppendAlignFunction(long rows) {
                        _rows = rows;
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java 
b/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
index c6cde96167..ebf6682742 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
@@ -43,6 +43,9 @@ public class FrameAppendDistTest extends AutomatedTestBase
        private final static String TEST_NAME = "FrameAppend";
        private final static String TEST_NAME2 = "FrameNAryAppend";
        private final static String TEST_NAME3 = "FrameNAryAppendMisalign";
+       private final static String TEST_NAME4 = "FrameNAryAppendMisalignRSP";
+       private final static String TEST_NAME5 = "FrameNAryAppendMisalignRSP2";
+
        private final static String TEST_DIR = "functions/frame/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FrameAppendDistTest.class.getSimpleName() + "/";
 
@@ -71,17 +74,19 @@ public class FrameAppendDistTest extends AutomatedTestBase
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[] {"C"}));
                addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3,new String[] {"C"}));
+               addTestConfiguration(TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4,new String[] {"C"}));
+               addTestConfiguration(TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5,new String[] {"C"}));
        }
 
        @Test
        public void testAppendInBlock1DenseSP() {
                commonAppendTest(ExecMode.SPARK, rows1, rows1, cols1a, cols2a, 
false, AppendMethod.MR_RAPPEND, false, TEST_NAME);
-       }   
+       }
        
        @Test
        public void testAppendInBlock1SparseSP() {
                commonAppendTest(ExecMode.SPARK, rows1, rows1, cols1a, cols2a, 
true, AppendMethod.MR_RAPPEND, false, TEST_NAME);
-       }   
+       }
        
        @Test
        public void testAppendInBlock1DenseRBindSP() {
@@ -116,22 +121,77 @@ public class FrameAppendDistTest extends AutomatedTestBase
 
        @Test
        public void testNAryCAppendMSP(){
-               commonAppendTest(ExecMode.SPARK ,100, 100, 5, 10, false, null, 
false, TEST_NAME2);;
+               commonAppendTest(ExecMode.SPARK ,100, 100, 5, 10, false, null, 
false, TEST_NAME2);
        }
 
        @Test
        public void testNAryCAppendRSP(){
-               commonAppendTest(ExecMode.SPARK ,30, 30, 5, 1001, false, null, 
false, TEST_NAME2);;
+               commonAppendTest(ExecMode.SPARK ,30, 30, 5, 1001, false, null, 
false, TEST_NAME2);
        }
 
        @Test
        public void testNAryRAppendSP(){
-               commonAppendTest(ExecMode.SPARK ,100, 100, 5, 5, false, null, 
true, TEST_NAME2);;
+               commonAppendTest(ExecMode.SPARK ,100, 100, 5, 5, false, null, 
true, TEST_NAME2);
        }
 
        @Test
        public void testNAryAppendWithMisalignmentMSP(){
-               commonAppendTest(ExecMode.SPARK ,5, 10, 5, 5, false, null, 
false, TEST_NAME3);;
+               commonAppendTest(ExecMode.SPARK ,5, 10, 5, 5, false, null, 
false, TEST_NAME3);
+       }
+
+       @Test
+       public void testNAryAppendWithMisalignmentRSP() {
+               commonAppendTest(ExecMode.SPARK, 5, 10, 1001, 1001, false, 
null, false, TEST_NAME3);
+       }
+
+// NAryAppendWithMisalignmentRSP2:
+// LHS:                RHS:
+// +---------+         +-----+
+// |         |         +-----+
+// |         |         +-----+
+// |         |         +-----+
+// +---------+         +-----+
+       @Test
+       public void testNAryAppendWithMisalignmentRSP2(){
+               commonAppendTest(ExecMode.SPARK ,20, 5, 1001, 1005, false, 
null, false, TEST_NAME4);
+       }
+
+// NAryAppendWithMisalignmentRSP3:
+//      LHS:            RHS:
+//      +-----+         +---------+
+//      +-----+         |         |
+//      +-----+         |         |
+//      +-----+         |         |
+//      +-----+         +---------+
+       @Test
+       public void testNAryAppendWithMisalignmentRSP3(){
+               commonAppendTest(ExecMode.SPARK ,5, 20, 1001, 1005, false, 
null, false, TEST_NAME4);
+       }
+// NAryAppendWithMisalignmentRSP4:
+//      LHS:            RHS:
+//      +-----+         +---------+
+//      |     |         +---------+
+//      +-----+         |         |
+//      |     |         +---------+
+//      +-----+         |         |
+//      |     |         +---------+
+//      +-----+         |         |
+//      +-----+         +---------+
+       @Test
+       public void testNAryAppendWithMisalignmentRSP4(){
+               commonAppendTest(ExecMode.SPARK ,20, 5, 1001, 1001, false, 
null, false, TEST_NAME5);
+       }
+// NAryAppendWithMisalignmentRSP5:
+//      LHS:            RHS:
+//      +-----+         +---------+
+//      +-----+         |         |
+//      +-----+         +---------+
+//      +-----+         +---------+
+//      |     |         +---------+
+//      +-----+         +---------+
+       @Test
+       public void testNAryAppendWithMisalignmentRSP5(){
+               commonAppendTest(ExecMode.SPARK ,8, 20, 1001, 1001, false, 
null, false, TEST_NAME5);
        }
 
        
@@ -142,7 +202,7 @@ public class FrameAppendDistTest extends AutomatedTestBase
                
                ExecMode prevPlfm=rtplatform;
                
-               double sparsity = (sparse) ? sparsity2 : sparsity1; 
+               double sparsity = (sparse) ? sparsity2 : sparsity1;
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                //setOutputBuffering(true);
                try
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R 
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R
new file mode 100644
index 0000000000..1ec13f690a
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
args <- commandArgs(TRUE)
options(digits=22)
library("Matrix")

# Reference result: stack the input with fewer rows four times and append it
# column-wise to the other input (the Java tests use a 4:1 row ratio).
# stringsAsFactors=FALSE keeps string columns as character instead of factor.
readInput <- function(file) {
  read.csv(paste(args[1], file, sep=""), header = FALSE, stringsAsFactors=FALSE)
}
A = readInput("A.csv")
B = readInput("B.csv")
if (nrow(A) > nrow(B)) {
  C = cbind(A, rbind(B, B, B, B))
} else {
  C = cbind(rbind(A, A, A, A), B)
}
write.csv(C, paste(args[2], "C.csv", sep=""), row.names = FALSE, quote = FALSE)
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml 
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml
new file mode 100644
index 0000000000..3633733da6
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
A = read($1, data_type="frame", rows=$2, cols=$3, format="binary")
B = read($4, data_type="frame", rows=$5, cols=$6, format="binary")

# Stack the input with fewer rows ($2 vs. $5) four times, then append column-wise.
# The row blocks of the stacked side differ from those of the other side.
if( $2 > $5 ) {
  stacked = rbind(B, B, B, B)
  C = cbind(A, stacked)
} else {
  stacked = rbind(A, A, A, A)
  C = cbind(stacked, B)
}
write(C, $7, format="binary")
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R 
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R
new file mode 100644
index 0000000000..9265c9b3a4
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
args <- commandArgs(TRUE)
options(digits=22)
library("Matrix")

# Reference result for the DML counterpart: both cbind sides stack A and B in a
# different order before being appended column-wise.
inputPath <- function(file) paste(args[1], file, sep="")
A = read.csv(inputPath("A.csv"), header = FALSE, stringsAsFactors=FALSE)
B = read.csv(inputPath("B.csv"), header = FALSE, stringsAsFactors=FALSE)
lhs = rbind(A, A, A, B)
rhs = rbind(B, A, A, A)
C = cbind(lhs, rhs)
write.csv(C, paste(args[2], "C.csv", sep=""), row.names = FALSE, quote = FALSE)
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml 
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml
new file mode 100644
index 0000000000..c761c428fd
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
A = read($1, data_type="frame", rows=$2, cols=$3, format="binary")
B = read($4, data_type="frame", rows=$5, cols=$6, format="binary")

# Stack A and B in different orders on each side, so that (for differing row
# counts of A and B) the row blocks of the two cbind inputs do not line up.
rhsStack = rbind(B, A, A, A)
lhsStack = rbind(A, A, A, B)
C = cbind(lhsStack, rhsStack)
write(C, $7, format="binary")

Reply via email to