This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c4166647bb [SYSTEMDS-2864] Fix opcode merge conflicts and lineage bug
c4166647bb is described below
commit c4166647bb952df54625510015a7fa32bd4d20fb
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 1 14:08:24 2025 +0100
[SYSTEMDS-2864] Fix opcode merge conflicts and lineage bug
* revert the bad merge of the previous ctable modification
* fix the handling of opcodes in the lineage program reconstruction
---
.../instructions/cp/CtableCPInstruction.java | 29 ++++++++++++++--------
.../runtime/lineage/LineageRecomputeUtils.java | 4 +--
2 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
index 55d98481b6..4f508cd5b8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
@@ -31,6 +31,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.CTableMap;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.LongLongDoubleHashMap.EntryType;
@@ -40,21 +41,23 @@ public class CtableCPInstruction extends ComputationCPInstruction {
private final CPOperand _outDim2;
private final boolean _isExpand;
private final boolean _ignoreZeros;
+ private final int _k;
private CtableCPInstruction(CPOperand in1, CPOperand in2, CPOperand
in3, CPOperand out,
String outputDim1, boolean dim1Literal, String
outputDim2, boolean dim2Literal, boolean isExpand,
- boolean ignoreZeros, String opcode, String istr) {
+ boolean ignoreZeros, String opcode, String istr, int k)
{
super(CPType.Ctable, null, in1, in2, in3, out, opcode, istr);
_outDim1 = new CPOperand(outputDim1, ValueType.FP64,
DataType.SCALAR, dim1Literal);
_outDim2 = new CPOperand(outputDim2, ValueType.FP64,
DataType.SCALAR, dim2Literal);
_isExpand = isExpand;
_ignoreZeros = ignoreZeros;
+ _k = k;
}
public static CtableCPInstruction parseInstruction(String inst)
{
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(inst);
- InstructionUtils.checkNumFields ( parts, 7 );
+ InstructionUtils.checkNumFields ( parts, 8 );
String opcode = parts[0];
@@ -76,8 +79,12 @@ public class CtableCPInstruction extends ComputationCPInstruction {
CPOperand out = new CPOperand(parts[6]);
boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
+ int k = Integer.parseInt(parts[8]);
+
// ctable does not require any operator, so we simply pass-in a
dummy operator with null functionobject
- return new CtableCPInstruction(in1, in2, in3, out,
dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0],
Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst);
+ return new CtableCPInstruction(in1, in2, in3, out,
+ dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]),
dim2Fields[0],
+ Boolean.parseBoolean(dim2Fields[1]), isExpand,
ignoreZeros, opcode, inst, k);
}
private Ctable.OperationTypes findCtableOperation() {
@@ -89,8 +96,8 @@ public class CtableCPInstruction extends ComputationCPInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
- MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
- MatrixBlock matBlock2=null, wtBlock=null;
+ MatrixBlock matBlock1 = !_isExpand ? ec.getMatrixInput(input1):
null;
+ MatrixBlock matBlock2 = null, wtBlock=null;
double cst1, cst2;
CTableMap resultMap = new CTableMap(EntryType.INT);
@@ -111,9 +118,6 @@ public class CtableCPInstruction extends ComputationCPInstruction {
if( !sparse )
resultBlock = new MatrixBlock((int)outputDim1,
(int)outputDim2, false);
}
- if( _isExpand ){
- resultBlock = new MatrixBlock( matBlock1.getNumRows(),
Integer.MAX_VALUE, true );
- }
switch(ctableOp) {
case CTABLE_TRANSFORM: //(VECTOR)
@@ -130,10 +134,13 @@ public class CtableCPInstruction extends ComputationCPInstruction {
break;
case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR)
// F = ctable(seq,A) or F = ctable(seq,B,1)
+ // ignore first argument
+ if(input1.getDataType() == DataType.MATRIX){
+ LOG.warn("rewrite for table expand not
activated please fix");
+ }
matBlock2 = ec.getMatrixInput(input2.getName());
cst1 =
ec.getScalarInput(input3).getDoubleValue();
- // only resultBlock.rlen known,
resultBlock.clen set in operation
- matBlock1.ctableSeqOperations(matBlock2, cst1,
resultBlock);
+ resultBlock =
LibMatrixReorg.fusedSeqRexpand(matBlock2.getNumRows(), matBlock2, cst1,
resultBlock, true, _k);
break;
case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
// F=ctable(A,1) or F = ctable(A,1,1)
@@ -152,7 +159,7 @@ public class CtableCPInstruction extends ComputationCPInstruction {
throw new DMLRuntimeException("Encountered an
invalid ctable operation ("+ctableOp+") while executing instruction: " +
this.toString());
}
- if(input1.getDataType() == DataType.MATRIX)
+ if(input1.getDataType() == DataType.MATRIX && ctableOp !=
Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT)
ec.releaseMatrixInput(input1.getName());
if(input2.getDataType() == DataType.MATRIX)
ec.releaseMatrixInput(input2.getName());
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index c3bd095aca..e64c742888 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -88,7 +88,7 @@ public class LineageRecomputeUtils {
public static Data parseNComputeLineageTrace(String mainTrace) {
if (DEBUG)
System.out.println(mainTrace);
-
+
// Separate the global trace and the dedup patches
String[] patches =
LineageParser.separateMainAndDedupPatches(mainTrace);
LineageItem root = LineageParser.parseLineageTrace(patches[0]);
//global trace
@@ -307,7 +307,7 @@ public class LineageRecomputeUtils {
break;
}
case Instruction: {
- CPType ctype =
InstructionUtils.getCPTypeByOpcode(item.getOpcode());
+ CPType ctype =
Opcodes.getCPTypeByOpcode(item.getOpcode());
SPType stype =
InstructionUtils.getSPTypeByOpcode(item.getOpcode());
if (ctype != null) {