This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 22acbf3f3adaeac7bcf1e5806fcf71472c150cde Author: baunsgaard <[email protected]> AuthorDate: Wed Dec 21 13:53:35 2022 +0100 [MINOR] Make BinaryCPInstruction MultithreadedOperators This commit changes the class of the operator to correctly reflect that it is multithreaded, which allows applySchema to include the thread count in the instruction. The change is mainly semantic, since it is already maintained that the operator class is multithreaded. --- .../org/apache/sysds/runtime/frame/data/FrameBlock.java | 13 ++++++++++++- .../apache/sysds/runtime/instructions/InstructionUtils.java | 3 ++- .../sysds/runtime/instructions/cp/BinaryCPInstruction.java | 3 ++- .../instructions/cp/BinaryFrameFrameCPInstruction.java | 6 +++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 80c3508fea..bf7ce6bfb7 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -1747,7 +1747,18 @@ public class FrameBlock implements CacheBlock<FrameBlock>, Externalizable { * @return A new FrameBlock with the schema applied. */ public FrameBlock applySchema(ValueType[] schema) { - return FrameLibApplySchema.applySchema(this, schema, InfrastructureAnalyzer.getLocalParallelism()); + return FrameLibApplySchema.applySchema(this, schema, 1); + } + + /** + * Method to create a new FrameBlock where the given schema is applied. + * + * @param schema of value types. + * @param k parallelization degree + * @return A new FrameBlock with the schema applied. 
+ */ + public FrameBlock applySchema(ValueType[] schema, int k){ + return FrameLibApplySchema.applySchema(this, schema, k); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 5026175b59..e4e641e8f1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -99,6 +99,7 @@ import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes; import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -580,7 +581,7 @@ public class InstructionUtils new UnaryOperator(Builtin.getBuiltinFnObject(opcode)); } - public static Operator parseBinaryOrBuiltinOperator(String opcode, CPOperand in1, CPOperand in2) { + public static MultiThreadedOperator parseBinaryOrBuiltinOperator(String opcode, CPOperand in1, CPOperand in2) { if( LibCommonsMath.isSupportedMatrixMatrixOperation(opcode) ) return null; boolean matrixScalar = (in1.getDataType() != in2.getDataType()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 4d83d6dfe3..4c2ae8a2d5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -23,6 +23,7 @@ import org.apache.sysds.common.Types.DataType; import 
org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; import org.apache.sysds.runtime.matrix.operators.Operator; public abstract class BinaryCPInstruction extends ComputationCPInstruction { @@ -45,7 +46,7 @@ public abstract class BinaryCPInstruction extends ComputationCPInstruction { if(!(in1.getDataType() == DataType.FRAME || in2.getDataType() == DataType.FRAME)) checkOutputDataType(in1, in2, out); - Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2); + MultiThreadedOperator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2); if (in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR) return new BinaryScalarScalarCPInstruction(operator, in1, in2, out, opcode, str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java index 85fa9ddffb..bd0ad427c1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java @@ -23,12 +23,12 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction { // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); - protected BinaryFrameFrameCPInstruction(Operator op, CPOperand in1, + 
protected BinaryFrameFrameCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { super(CPType.Binary, op, in1, in2, out, opcode, istr); } @@ -62,7 +62,7 @@ public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction { ValueType[] schema = new ValueType[inBlock2.getNumColumns()]; for(int i=0; i<inBlock2.getNumColumns(); i++) schema[i] = ValueType.fromExternalString(inBlock2.get(0, i).toString()); - ec.setFrameOutput(output.getName(), inBlock1.applySchema(schema)); + ec.setFrameOutput(output.getName(), inBlock1.applySchema(schema, ((MultiThreadedOperator)getOperator()).getNumThreads())); } else { // Execute binary operations
