This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 22acbf3f3adaeac7bcf1e5806fcf71472c150cde
Author: baunsgaard <[email protected]>
AuthorDate: Wed Dec 21 13:53:35 2022 +0100

    [MINOR] Make BinaryCPInstruction use MultiThreadedOperator
    
    This commit changes the operator type of BinaryCPInstruction to
    MultiThreadedOperator, to correctly reflect that the operation is
    multithreaded and to allow applySchema to take its thread count
    from the instruction.
    
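    The change is mainly semantic, since the operator produced for these
    instructions was already a multithreaded operator. Note, however, that
    FrameBlock.applySchema(ValueType[]) now defaults to a single thread
    instead of InfrastructureAnalyzer.getLocalParallelism(); callers that
    want parallel execution must use the new applySchema(ValueType[], int).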
---
 .../org/apache/sysds/runtime/frame/data/FrameBlock.java     | 13 ++++++++++++-
 .../apache/sysds/runtime/instructions/InstructionUtils.java |  3 ++-
 .../sysds/runtime/instructions/cp/BinaryCPInstruction.java  |  3 ++-
 .../instructions/cp/BinaryFrameFrameCPInstruction.java      |  6 +++---
 4 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index 80c3508fea..bf7ce6bfb7 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -1747,7 +1747,18 @@ public class FrameBlock implements 
CacheBlock<FrameBlock>, Externalizable {
         * @return A new FrameBlock with the schema applied.
         */
        public FrameBlock applySchema(ValueType[] schema) {
-               return FrameLibApplySchema.applySchema(this, schema, 
InfrastructureAnalyzer.getLocalParallelism());
+               return FrameLibApplySchema.applySchema(this, schema, 1);
+       }
+
+       /**
+        * Method to create a new FrameBlock where the given schema is applied.
+        * 
+        * @param schema of value types.
+        * @param k parallelization degree
+        * @return A new FrameBlock with the schema applied.
+        */
+       public FrameBlock applySchema(ValueType[] schema, int k){
+               return FrameLibApplySchema.applySchema(this, schema, k);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 5026175b59..e4e641e8f1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -99,6 +99,7 @@ import org.apache.sysds.runtime.matrix.operators.CMOperator;
 import 
org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -580,7 +581,7 @@ public class InstructionUtils
                        new UnaryOperator(Builtin.getBuiltinFnObject(opcode));
        }
 
-       public static Operator parseBinaryOrBuiltinOperator(String opcode, 
CPOperand in1, CPOperand in2) {
+       public static MultiThreadedOperator parseBinaryOrBuiltinOperator(String 
opcode, CPOperand in1, CPOperand in2) {
                if( LibCommonsMath.isSupportedMatrixMatrixOperation(opcode) )
                        return null;
                boolean matrixScalar = (in1.getDataType() != in2.getDataType());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index 4d83d6dfe3..4c2ae8a2d5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -23,6 +23,7 @@ import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public abstract class BinaryCPInstruction extends ComputationCPInstruction {
@@ -45,7 +46,7 @@ public abstract class BinaryCPInstruction extends 
ComputationCPInstruction {
                if(!(in1.getDataType() == DataType.FRAME || in2.getDataType() 
== DataType.FRAME))
                        checkOutputDataType(in1, in2, out);
                
-               Operator operator = 
InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
+               MultiThreadedOperator operator = 
InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
 
                if (in1.getDataType() == DataType.SCALAR && in2.getDataType() 
== DataType.SCALAR)
                        return new BinaryScalarScalarCPInstruction(operator, 
in1, in2, out, opcode, str);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
index 85fa9ddffb..bd0ad427c1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
@@ -23,12 +23,12 @@ import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
 
 public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction {
        // private static final Log LOG = 
LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName());
 
-       protected BinaryFrameFrameCPInstruction(Operator op, CPOperand in1,
+       protected BinaryFrameFrameCPInstruction(MultiThreadedOperator op, 
CPOperand in1,
                        CPOperand in2, CPOperand out, String opcode, String 
istr) {
                super(CPType.Binary, op, in1, in2, out, opcode, istr);
        }
@@ -62,7 +62,7 @@ public class BinaryFrameFrameCPInstruction extends 
BinaryCPInstruction {
                        ValueType[] schema = new 
ValueType[inBlock2.getNumColumns()];
                        for(int i=0; i<inBlock2.getNumColumns(); i++)
                                schema[i] = 
ValueType.fromExternalString(inBlock2.get(0, i).toString());
-                       ec.setFrameOutput(output.getName(), 
inBlock1.applySchema(schema));
+                       ec.setFrameOutput(output.getName(), 
inBlock1.applySchema(schema, 
((MultiThreadedOperator)getOperator()).getNumThreads()));
                }
                else {
                        // Execute binary operations

Reply via email to