This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1dfffd5  [SYSTEMDS-3338] Multi-threaded local Qsort instruction
1dfffd5 is described below

commit 1dfffd5337802b5513104386f3ec49a1d8853880
Author: arnabp <[email protected]>
AuthorDate: Wed Mar 23 16:27:51 2022 +0100

    [SYSTEMDS-3338] Multi-threaded local Qsort instruction
    
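    This patch updates the QuantileSort instruction to use a multi-threaded
    sort for local (CP) execution and for column-partitioned federated sites.
    The thread count is passed to every sort-based operation (quantile,
    interquantile, median, and IQM). This change speeds up quantile
    computation by 2.5x for 100M rows.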
    
    Closes #1571
---
 src/main/java/org/apache/sysds/hops/BinaryOp.java  | 14 ++++--
 src/main/java/org/apache/sysds/hops/TernaryOp.java |  7 ++-
 src/main/java/org/apache/sysds/hops/UnaryOp.java   | 20 ++++----
 src/main/java/org/apache/sysds/lops/SortKeys.java  | 57 ++++++++++++++--------
 .../runtime/instructions/InstructionUtils.java     | 10 +++-
 .../instructions/cp/QuantileSortCPInstruction.java | 44 +++++++++++++----
 .../fed/QuantileSortFEDInstruction.java            | 55 ++++++++++++++++-----
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  6 ++-
 8 files changed, 152 insertions(+), 61 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 73deda4..1151135 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -181,7 +181,10 @@ public class BinaryOp extends MultiThreadedHop {
        public boolean isMultiThreadedOpType() {
                return !getDataType().isScalar()
                        || getOp() == OpOp2.COV
-                       || getOp() == OpOp2.MOMENT;
+                       || getOp() == OpOp2.MOMENT
+                       || getOp() == OpOp2.IQM
+                       || getOp() == OpOp2.MEDIAN
+                       || getOp() == OpOp2.QUANTILE;
        }
        
        @Override
@@ -233,11 +236,12 @@ public class BinaryOp extends MultiThreadedHop {
        }
        
        private void constructLopsIQM(ExecType et) {
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                SortKeys sort = SortKeys.constructSortByValueLop(
                                getInput().get(0).constructLops(), 
                                getInput().get(1).constructLops(), 
                                SortKeys.OperationTypes.WithWeights, 
-                               getInput().get(0).getDataType(), 
getInput().get(0).getValueType(), et);
+                               getInput().get(0).getDataType(), 
getInput().get(0).getValueType(), et, k);
                sort.getOutputParameters().setDimensions(
                                getInput().get(0).getDim1(),
                                getInput().get(0).getDim2(), 
@@ -256,11 +260,12 @@ public class BinaryOp extends MultiThreadedHop {
        }
        
        private void constructLopsMedian(ExecType et) {
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                SortKeys sort = SortKeys.constructSortByValueLop(
                                getInput().get(0).constructLops(), 
                                getInput().get(1).constructLops(), 
                                SortKeys.OperationTypes.WithWeights, 
-                               getInput().get(0).getDataType(), 
getInput().get(0).getValueType(), et);
+                               getInput().get(0).getDataType(), 
getInput().get(0).getValueType(), et, k);
                sort.getOutputParameters().setDimensions(
                                getInput().get(0).getDim1(),
                                getInput().get(0).getDim2(),
@@ -317,10 +322,11 @@ public class BinaryOp extends MultiThreadedHop {
                else
                        pick_op = PickByCount.OperationTypes.RANGEPICK;
 
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                SortKeys sort = SortKeys.constructSortByValueLop(
                        getInput().get(0).constructLops(), 
                        SortKeys.OperationTypes.WithoutWeights, 
-                       DataType.MATRIX, ValueType.FP64, et );
+                       DataType.MATRIX, ValueType.FP64, et, k );
                sort.getOutputParameters().setDimensions(
                        getInput().get(0).getDim1(),
                        getInput().get(0).getDim2(),
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index b7ad4fd..a754f1b 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -149,7 +149,9 @@ public class TernaryOp extends MultiThreadedHop
        public boolean isMultiThreadedOpType() {
                return _op == OpOp3.IFELSE
                        || _op == OpOp3.MINUS_MULT
-                       || _op == OpOp3.PLUS_MULT;
+                       || _op == OpOp3.PLUS_MULT
+                       || _op == OpOp3.QUANTILE
+                       || _op == OpOp3.INTERQUANTILE;
        }
        
        @Override
@@ -247,9 +249,10 @@ public class TernaryOp extends MultiThreadedHop
                        throw new HopsException("Unexpected operation: " + _op 
+ ", expecting " + OpOp3.QUANTILE + " or " + OpOp3.INTERQUANTILE );
                
                ExecType et = optFindExecType();
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                SortKeys sort = 
SortKeys.constructSortByValueLop(getInput().get(0).constructLops(),
                        getInput().get(1).constructLops(), 
SortKeys.OperationTypes.WithWeights, 
-                       getInput().get(0).getDataType(), 
getInput().get(0).getValueType(), et);
+                       getInput().get(0).getDataType(), 
getInput().get(0).getValueType(), et, k);
                PickByCount pick = new PickByCount(sort, 
getInput().get(2).constructLops(),
                        getDataType(), getValueType(), (_op == OpOp3.QUANTILE) ?
                        PickByCount.OperationTypes.VALUEPICK : 
PickByCount.OperationTypes.RANGEPICK, et, true);
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 009d6ff..25a1202 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -197,12 +197,11 @@ public class UnaryOp extends MultiThreadedHop
        private Lop constructLopsMedian() 
        {
                ExecType et = optFindExecType();
-
-               
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                SortKeys sort = SortKeys.constructSortByValueLop(
                                                        
getInput().get(0).constructLops(), 
                                                        
SortKeys.OperationTypes.WithoutWeights, 
-                                                       DataType.MATRIX, 
ValueType.FP64, et );
+                                                       DataType.MATRIX, 
ValueType.FP64, et, k );
                sort.getOutputParameters().setDimensions(
                                getInput().get(0).getDim1(),
                                getInput().get(0).getDim2(),
@@ -225,14 +224,13 @@ public class UnaryOp extends MultiThreadedHop
        
        private Lop constructLopsIQM() 
        {
-
                ExecType et = optFindExecType();
-
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                Hop input = getInput().get(0);
-                               SortKeys sort = 
SortKeys.constructSortByValueLop(
-                               input.constructLops(), 
-                               SortKeys.OperationTypes.WithoutWeights, 
-                               DataType.MATRIX, ValueType.FP64, et );
+               SortKeys sort = SortKeys.constructSortByValueLop(
+                       input.constructLops(),
+                       SortKeys.OperationTypes.WithoutWeights,
+                       DataType.MATRIX, ValueType.FP64, et, k );
                sort.getOutputParameters().setDimensions(
                                input.getDim1(),
                                input.getDim2(),
@@ -456,7 +454,9 @@ public class UnaryOp extends MultiThreadedHop
                        || _op == OpOp1.LOG
                        || _op == OpOp1.SIGMOID
                        || _op == OpOp1.COMPRESS
-                       || _op == OpOp1.DECOMPRESS);
+                       || _op == OpOp1.DECOMPRESS
+                       || _op == OpOp1.MEDIAN
+                       || _op == OpOp1.IQM);
        }
        
        public boolean isMetadataOperation() {
diff --git a/src/main/java/org/apache/sysds/lops/SortKeys.java b/src/main/java/org/apache/sysds/lops/SortKeys.java
index 01c1594..f7a0f82 100644
--- a/src/main/java/org/apache/sysds/lops/SortKeys.java
+++ b/src/main/java/org/apache/sysds/lops/SortKeys.java
@@ -39,31 +39,34 @@ public class SortKeys extends Lop
        }
        
        private OperationTypes operation;
-       
+
+       private int _numThreads;
+
        public OperationTypes getOpType() {
                return operation;
        }
 
-       public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType 
vt, ExecType et) {
+       public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType 
vt, ExecType et, int numThreads) {
                super(Lop.Type.SortKeys, dt, vt);
-               init(input, null, op, et);
+               init(input, null, op, et, numThreads);
        }
        
        public SortKeys(Lop input, boolean desc, OperationTypes op, DataType 
dt, ValueType vt, ExecType et) {
                super(Lop.Type.SortKeys, dt, vt);
-               init(input, null, op, et);
+               init(input, null, op, et, 1);
        }
 
-       public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, 
ValueType vt, ExecType et) {
-               super(Lop.Type.SortKeys, dt, vt);               
-               init(input1, input2, op, et);
+       public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, 
ValueType vt, ExecType et, int numThreads) {
+               super(Lop.Type.SortKeys, dt, vt);
+               init(input1, input2, op, et, numThreads);
        }
        
-       private void init(Lop input1, Lop input2, OperationTypes op, ExecType 
et) {
+       private void init(Lop input1, Lop input2, OperationTypes op, ExecType 
et, int numThreads) {
                addInput(input1);
                input1.addOutput(this);
                
                operation = op;
+               _numThreads = numThreads;
                
                // SortKeys can accept a optional second input only when 
executing in CP
                // Example: sorting with weights inside CP
@@ -82,43 +85,57 @@ public class SortKeys extends Lop
 
        @Override
        public String getInstructions(String input, String output) {
-               return InstructionUtils.concatOperands(
+               StringBuilder sb = new StringBuilder();
+               sb.append(InstructionUtils.concatOperands(
                        getExecType().name(),
                        OPCODE,
                        getInputs().get(0).prepInputOperand(input),
-                       prepOutputOperand(output));
+                       prepOutputOperand(output)));
+
+                       if( getExecType() == ExecType.CP ) {
+                               sb.append( OPERAND_DELIMITOR );
+                               sb.append(_numThreads);
+                       }
+               return sb.toString();
        }
        
        @Override
        public String getInstructions(String input1, String input2, String 
output) {
-               return InstructionUtils.concatOperands(
+               StringBuilder sb = new StringBuilder();
+               sb.append(InstructionUtils.concatOperands(
                        getExecType().name(),
                        OPCODE,
                        getInputs().get(0).prepInputOperand(input1),
                        getInputs().get(1).prepInputOperand(input2),
-                       prepOutputOperand(output));
+                       prepOutputOperand(output)));
+
+               if( getExecType() == ExecType.CP ) {
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append(_numThreads);
+               }
+               return sb.toString();
        }
-       
+
        // This method is invoked in two cases:
        // 1) SortKeys (both weighted and unweighted) executes in MR
        // 2) Unweighted SortKeys executes in CP
-       public static SortKeys constructSortByValueLop(Lop input1, 
OperationTypes op, 
-                       DataType dt, ValueType vt, ExecType et) {
-               
+       public static SortKeys constructSortByValueLop(Lop input1, 
OperationTypes op,
+               DataType dt, ValueType vt, ExecType et, int numThreads) {
+
                for (Lop lop  : input1.getOutputs()) {
                        if ( lop.type == Lop.Type.SortKeys ) {
                                return (SortKeys)lop;
                        }
                }
-               
-               SortKeys retVal = new SortKeys(input1, op, dt, vt, et);
+
+               SortKeys retVal = new SortKeys(input1, op, dt, vt, et, 
numThreads);
                retVal.setAllPositions(input1.getFilename(), 
input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), 
input1.getEndColumn());
                return retVal;
        }
 
        // This method is invoked ONLY for the case of Weighted SortKeys 
executing in CP
        public static SortKeys constructSortByValueLop(Lop input1, Lop input2, 
OperationTypes op, 
-                       DataType dt, ValueType vt, ExecType et) {
+                       DataType dt, ValueType vt, ExecType et, int numThreads) 
{
                
                HashSet<Lop> set1 = new HashSet<>();
                set1.addAll(input1.getOutputs());
@@ -131,7 +148,7 @@ public class SortKeys extends Lop
                        }
                }
                
-               SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et);
+               SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et, 
numThreads);
                retVal.setAllPositions(input1.getFilename(), 
input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), 
input1.getEndColumn());
                return retVal;
        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index f22fdfe..39fbef2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -225,7 +225,15 @@ public class InstructionUtils
                
                return ret;
        }
-       
+
+       public static String stripThreadCount(String str) {
+               String[] parts = str.split(Instruction.OPERAND_DELIM, -1);
+               String[] ret = new String[parts.length-1];
+               for (int i=0; i<parts.length-1; i++) //strip-off the thread 
count
+                       ret[i] = parts[i];
+               return concatOperands(ret);
+       }
+
        public static ExecType getExecType( String str ) {
                try{
                        int ix = str.indexOf(Instruction.OPERAND_DELIM);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java
index 3e953d2..3123af5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java
@@ -36,14 +36,35 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
  *  
  */
 public class QuantileSortCPInstruction extends UnaryCPInstruction {
+       int _numThreads;
 
-       private QuantileSortCPInstruction(CPOperand in, CPOperand out, String 
opcode, String istr) {
-               this(in, null, out, opcode, istr);
+       private QuantileSortCPInstruction(CPOperand in, CPOperand out, String 
opcode, String istr, int k) {
+               this(in, null, out, opcode, istr, k);
        }
 
        private QuantileSortCPInstruction(CPOperand in1, CPOperand in2, 
CPOperand out, String opcode,
-                       String istr) {
+                       String istr, int k) {
                super(CPType.QSort, null, in1, in2, out, opcode, istr);
+               _numThreads = k;
+       }
+
+       private static void parseInstruction(String instr, CPOperand in1, 
CPOperand in2, CPOperand out) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(instr);
+
+               out.split(parts[parts.length-2]);
+
+               switch(parts.length) {
+                       case 4:
+                               in1.split(parts[1]);
+                               in2 = null;
+                               break;
+                       case 5:
+                               in1.split(parts[1]);
+                               in2.split(parts[2]);
+                               break;
+                       default:
+                               throw new DMLRuntimeException("Unexpected 
number of operands in the instruction: " + instr);
+               }
        }
 
        public static QuantileSortCPInstruction parseInstruction ( String str ) 
{
@@ -55,16 +76,19 @@ public class QuantileSortCPInstruction extends UnaryCPInstruction {
                String opcode = parts[0];
                
                if ( opcode.equalsIgnoreCase(SortKeys.OPCODE) ) {
-                       if ( parts.length == 3 ) {
+                       int k = Integer.parseInt(parts[parts.length-1]); 
//#threads
+                       if ( parts.length == 4 ) {
                                // Example: sort:mVar1:mVar2 (input=mVar1, 
output=mVar2)
-                               parseUnaryInstruction(str, in1, out);
-                               return new QuantileSortCPInstruction(in1, out, 
opcode, str);
+                               InstructionUtils.checkNumFields(str, 3);
+                               parseInstruction(str, in1, null, out);
+                               return new QuantileSortCPInstruction(in1, out, 
opcode, str, k);
                        }
-                       else if ( parts.length == 4 ) {
+                       else if ( parts.length == 5 ) {
                                // Example: sort:mVar1:mVar2:mVar3 
(input=mVar1, weights=mVar2, output=mVar3)
+                               InstructionUtils.checkNumFields(str, 4);
                                in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                               parseUnaryInstruction(str, in1, in2, out);
-                               return new QuantileSortCPInstruction(in1, in2, 
out, opcode, str);
+                               parseInstruction(str, in1, in2, out);
+                               return new QuantileSortCPInstruction(in1, in2, 
out, opcode, str, k);
                        }
                        else {
                                throw new DMLRuntimeException("Invalid number 
of operands in instruction: " + str);
@@ -85,7 +109,7 @@ public class QuantileSortCPInstruction extends UnaryCPInstruction {
                }
                
                //process core instruction
-               MatrixBlock resultBlock = matBlock.sortOperations(wtBlock, new 
MatrixBlock());
+               MatrixBlock resultBlock = matBlock.sortOperations(wtBlock, new 
MatrixBlock(), _numThreads);
                
                //release inputs
                ec.releaseMatrixInput(input1.getName());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index f84be32..91aaf81 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -37,15 +37,36 @@ import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
+public class QuantileSortFEDInstruction extends UnaryFEDInstruction {
+       int _numThreads;
 
-       private QuantileSortFEDInstruction(CPOperand in, CPOperand out, String 
opcode, String istr) {
-               this(in, null, out, opcode, istr);
+       private QuantileSortFEDInstruction(CPOperand in, CPOperand out, String 
opcode, String istr, int k) {
+               this(in, null, out, opcode, istr, k);
        }
 
        private QuantileSortFEDInstruction(CPOperand in1, CPOperand in2, 
CPOperand out, String opcode,
-               String istr) {
+               String istr, int k) {
                super(FEDInstruction.FEDType.QSort, null, in1, in2, out, 
opcode, istr);
+               _numThreads = k;
+       }
+
+       private static void parseInstruction(String instr, CPOperand in1, 
CPOperand in2, CPOperand out) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(instr);
+
+               out.split(parts[parts.length-2]);
+
+               switch(parts.length) {
+                       case 4:
+                               in1.split(parts[1]);
+                               in2 = null;
+                               break;
+                       case 5:
+                               in1.split(parts[1]);
+                               in2.split(parts[2]);
+                               break;
+                       default:
+                               throw new DMLRuntimeException("Unexpected 
number of operands in the instruction: " + instr);
+               }
        }
 
        public static QuantileSortFEDInstruction parseInstruction ( String str 
) {
@@ -55,18 +76,23 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
 
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
+               boolean isSpark = str.startsWith("SPARK");
+               int k = isSpark ? 1 : Integer.parseInt(parts[parts.length-1]);
 
                if ( opcode.equalsIgnoreCase(SortKeys.OPCODE) ) {
-                       if ( parts.length == 3 ) {
+                       int oneInputLength = isSpark ? 3 : 4;
+                       int twoInputLength = isSpark ? 4 : 5;
+                       if ( parts.length == oneInputLength ) {
                                // Example: sort:mVar1:mVar2 (input=mVar1, 
output=mVar2)
                                parseUnaryInstruction(str, in1, out);
-                               return new QuantileSortFEDInstruction(in1, out, 
opcode, str);
+                               return new QuantileSortFEDInstruction(in1, out, 
opcode, str, k);
                        }
-                       else if ( parts.length == 4 ) {
+                       else if ( parts.length == twoInputLength ) {
                                // Example: sort:mVar1:mVar2:mVar3 
(input=mVar1, weights=mVar2, output=mVar3)
                                in2 = new CPOperand("", 
Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
-                               parseUnaryInstruction(str, in1, in2, out);
-                               return new QuantileSortFEDInstruction(in1, in2, 
out, opcode, str);
+                               InstructionUtils.checkNumFields(str, 
twoInputLength-1);
+                               parseInstruction(str, in1, in2, out);
+                               return new QuantileSortFEDInstruction(in1, in2, 
out, opcode, str, k);
                        }
                        else {
                                throw new DMLRuntimeException("Invalid number 
of operands in instruction: " + str);
@@ -91,7 +117,8 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
                // TODO make sure that qsort result is used by qpick only where 
the main operation happens
                if(input2 != null) {
                        MatrixObject weights = ec.getMatrixObject(input2);
-                       String newInst = 
InstructionUtils.replaceOperand(instString, 1, "append");
+                       String newInst = _numThreads > 1 ? 
InstructionUtils.stripThreadCount(instString) : instString;
+                       newInst = InstructionUtils.replaceOperand(newInst, 1, 
"append");
                        newInst = InstructionUtils.concatOperands(newInst, 
"true");
                        FederatedRequest[] fr1 = 
in.getFedMapping().broadcastSliced(weights, false);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(newInst, output,
@@ -123,7 +150,7 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
 
                                FederatedResponse response = data
                                        .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                               new GetSorted(data.getVarID(), 
varID, wtBlock))).get();
+                                               new GetSorted(data.getVarID(), 
varID, wtBlock, _numThreads))).get();
                                if(!response.isSuccessful())
                                        response.throwExceptionFromResponse();
                        }
@@ -145,17 +172,19 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
                private static final long serialVersionUID = 
-1969015577260167645L;
                private final long _outputID;
                private final MatrixBlock _weights;
+               private final int _numThreads;
 
-               protected GetSorted(long input, long outputID, MatrixBlock 
weights) {
+               protected GetSorted(long input, long outputID, MatrixBlock 
weights, int k) {
                        super(new long[] {input});
                        _outputID = outputID;
                        _weights = weights;
+                       _numThreads = k;
                }
                @Override
                public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
                        MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
 
-                       MatrixBlock res = mb.sortOperations(_weights, new 
MatrixBlock());
+                       MatrixBlock res = mb.sortOperations(_weights, new 
MatrixBlock(), _numThreads);
 
                        MatrixObject mout = 
ExecutionContext.createMatrixObject(res);
 
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index eb152ab..5df5d62 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4820,6 +4820,10 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
        }
 
        public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock 
result) {
+               return sortOperations(weights, result, 1);
+       }
+
+       public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock 
result, int k) {
                boolean wtflag = (weights!=null);
                
                MatrixBlock wts= (weights == null ? null : checkType(weights));
@@ -4877,7 +4881,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
                
                // Sort td and tw based on values inside td (ascending sort), 
incl copy into result
                SortIndex sfn = new SortIndex(1, false, false);
-               ReorgOperator rop = new ReorgOperator(sfn);
+               ReorgOperator rop = new ReorgOperator(sfn, k);
                LibMatrixReorg.reorg(tdw, result, rop);
                
                return result;

Reply via email to