This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit bf0b4ef680920bda0af4d85728de712f24cad6e7 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Wed Apr 17 20:26:27 2024 +0200 [SYSTEMDS-3695] Fix frame builtin parsing and missing nary append This patch fixes issues with parsing frame builtin functions (wrong scalar function name), and missing nary frame append instructions (compiled but not correctly executed, resulting in scalar 0). --- .../java/org/apache/sysds/common/Builtins.java | 3 ++- .../controlprogram/context/ExecutionContext.java | 16 +++++++++++- .../cp/MatrixBuiltinNaryCPInstruction.java | 30 ++++++++++++++++------ 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 8f113c092f..82b6441f58 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -477,6 +477,7 @@ public enum Builtins { public static String getInternalFName(String name, DataType dt) { return !contains(name, true, false) ? name : // private builtin - (dt.isMatrix() ? "m_" : "s_") + name; // public builtin + (dt.isMatrix() ? "m_" : // public builtin + dt.isFrame() ? 
"f_" : "s_") + name; } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index 0903b5abca..20d602fc38 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -310,7 +310,7 @@ public class ExecutionContext { if( dat == null ) throw new DMLRuntimeException(getNonExistingVarError(varname)); if( !(dat instanceof FrameObject) ) - throw new DMLRuntimeException("Variable '"+varname+"' is not a frame."); + throw new DMLRuntimeException("Variable '"+varname+"' is not a frame: "+dat.getDataType()); return (FrameObject) dat; } @@ -513,6 +513,10 @@ public class ExecutionContext { getMatrixObject(varName).getGPUObject(getGPUContext(0)).releaseInput(); } + public FrameBlock getFrameInput(CPOperand input) { + return getFrameInput(input.getName()); + } + /** * Pins a frame variable into memory and returns the internal frame block. 
* @@ -531,6 +535,11 @@ public class ExecutionContext { public void releaseFrameInput(String varName) { getFrameObject(varName).release(); } + + public void releaseFrameInputs(CPOperand[] inputs) { + Arrays.stream(inputs).filter(in -> in.isFrame()) + .forEach(in -> releaseFrameInput(in.getName())); + } public void releaseTensorInput(String varName) { getTensorObject(varName).release(); @@ -725,6 +734,11 @@ public class ExecutionContext { .map(in -> getScalarInput(in)).collect(Collectors.toList()); } + public List<FrameBlock> getFrameInputs(CPOperand[] inputs) { + return Arrays.stream(inputs).filter(in -> in.isFrame()) + .map(in -> getFrameInput(in)).collect(Collectors.toList()); + } + public void releaseMatrixInputs(CPOperand[] inputs) { releaseMatrixInputs(inputs, false); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java index 75dadc00f4..8611ade3eb 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java @@ -24,7 +24,9 @@ import java.util.List; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.lineage.LineageTraceable; @@ -42,16 +44,24 @@ public class MatrixBuiltinNaryCPInstruction extends BuiltinNaryCPInstruction imp //separate scalars and matrices and pin all input matrices List<MatrixBlock> matrices = ec.getMatrixInputs(inputs, true); 
List<ScalarObject> scalars = ec.getScalarInputs(inputs); + List<FrameBlock> frames = ec.getFrameInputs(inputs); - MatrixBlock outBlock = null; + CacheBlock<?> outBlock = null; if( "cbind".equals(getOpcode()) || "rbind".equals(getOpcode()) ) { boolean cbind = "cbind".equals(getOpcode()); - //robustness for empty lists: create 0-by-0 matrix block - outBlock = matrices.size() == 0 ? new MatrixBlock(0, 0, 0) : - matrices.get(0).append(matrices.subList(1, matrices.size()) - .toArray(new MatrixBlock[0]), new MatrixBlock(), cbind); + if(frames.size() == 0 ) { //matrix/scalar + //robustness for empty lists: create 0-by-0 matrix block + outBlock = matrices.size() == 0 ? new MatrixBlock(0, 0, 0) : + matrices.get(0).append(matrices.subList(1, matrices.size()) + .toArray(new MatrixBlock[0]), new MatrixBlock(), cbind); + } + else { + //TODO native nary frame append + outBlock = frames.get(0); + for(int i=1; i<frames.size(); i++) + outBlock = ((FrameBlock)outBlock).append(frames.get(i), cbind); + } } - else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+"}, getOpcode()) ) { outBlock = MatrixBlock.naryOperations(_optr, matrices.toArray(new MatrixBlock[0]), scalars.toArray(new ScalarObject[0]), new MatrixBlock()); @@ -62,12 +72,16 @@ public class MatrixBuiltinNaryCPInstruction extends BuiltinNaryCPInstruction imp //release inputs and set output matrix or scalar ec.releaseMatrixInputs(inputs, true); + ec.releaseFrameInputs(inputs); if( output.getDataType().isMatrix()) { - ec.setMatrixOutput(output.getName(), outBlock); + ec.setMatrixOutput(output.getName(), (MatrixBlock)outBlock); + } + else if( output.getDataType().isFrame()) { + ec.setFrameOutput(output.getName(), (FrameBlock)outBlock); } else { ec.setVariable(output.getName(), ScalarObjectFactory.createScalarObject( - output.getValueType(), outBlock.quickGetValue(0, 0))); + output.getValueType(), ((MatrixBlock)outBlock).quickGetValue(0, 0))); } }