This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new b8d373a889 [SYSTEMDS-3853] Fix ampute outer broadcasting and error handling
b8d373a889 is described below
commit b8d373a889963ca2845ced0f8d717a3d26295186
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Apr 16 13:42:04 2025 +0200
[SYSTEMDS-3853] Fix ampute outer broadcasting and error handling
This patch fixes invalid left-hand-side broadcasting (as well as combined
left- and right-hand-side broadcasting) in the new ampute builtin function.
The hop now also performs proper error handling, which tells script
developers that broadcasting is only supported from the right-hand side.
---
scripts/builtin/ampute.dml | 7 +++----
src/main/java/org/apache/sysds/hops/BinaryOp.java | 9 +++++++++
2 files changed, 12 insertions(+), 4 deletions(-)
diff --git a/scripts/builtin/ampute.dml b/scripts/builtin/ampute.dml
index 7d96136b7c..691e5b48e2 100644
--- a/scripts/builtin/ampute.dml
+++ b/scripts/builtin/ampute.dml
@@ -72,8 +72,8 @@ m_ampute = function(Matrix[Double] X,
# 4. Use probabilities to ampute pattern candidates:
random = rand(rows=groupSize, cols=1, min=0, max=1, pdf="uniform",
seed=seed)
- amputeds = (random <= probs) * (1 - patterns[patternNum]) # Obtains
matrix with 1's at indices to ampute.
- while (FALSE) {} # FIX ME
+ # Obtains matrix with 1's at indices to ampute.
+ amputeds = outer((random <= probs), (1 - patterns[patternNum]), "*")
groupSamples = groupSamples + replace(target=amputeds, pattern=1,
replacement=NaN)
# 5. Update output matrix:
@@ -241,7 +241,6 @@ return (Matrix[Double] groupAssignments, Matrix[Double]
groupCounts) {
for (i in 1:numGroups) {
assigned = (random >= cumSum[i]) & (random < cumSum[i + 1])
- while (FALSE) {} # FIX ME
groupCounts[i] = sum(assigned)
groupAssignments = groupAssignments + i * assigned
}
@@ -308,4 +307,4 @@ return(Integer start, Integer end) {
start = sum(numPerGroup[1:(patternNum - 1), ]) + 1
}
end = start + groupSize - 1
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 8d2b00c1aa..bbcb8b121b 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -1092,6 +1092,15 @@ public class BinaryOp extends MultiThreadedHop {
}
else //GENERAL CASE
{
+ //check correct broadcasting
dimensions
+ if( (input1.getDim1()==1 &&
input2.getDim1() > 1)
+ || (input1.getDim2()==1
&& input2.getDim2() > 1) )
+ {
+ throw new
HopsException("Invalid binary broadcasting from left: "
+ +
input1.getDataCharacteristics()+" "+getOp().name()+" "
+
+input2.getDataCharacteristics());
+ }
+
ldim1 = (input1.rowsKnown()) ?
input1.getDim1()
:
((input2.getDim1()>1)?input2.getDim1():-1);
ldim2 = (input1.colsKnown()) ?
input1.getDim2()