This is an automated email from the ASF dual-hosted git repository.

janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 3de7cbe7e9 Add OOC WDivMM
3de7cbe7e9 is described below

commit 3de7cbe7e9e1af46510d059629e55624cd32626b
Author: Jessica Priebe <[email protected]>
AuthorDate: Wed May 13 10:00:04 2026 +0200

    Add OOC WDivMM
    
    Closes #2464.
---
 scripts/builtin/pnmf.dml                           |   6 +-
 .../java/org/apache/sysds/hops/QuaternaryOp.java   |  16 ++
 .../runtime/instructions/OOCInstructionParser.java |   3 +
 .../ooc/ComputationOOCInstruction.java             |  11 +-
 .../runtime/instructions/ooc/OOCInstruction.java   |   2 +-
 .../instructions/ooc/QuaternaryOOCInstruction.java |  54 +++++
 .../instructions/ooc/WDivMMOOCInstruction.java     | 218 +++++++++++++++++++++
 .../apache/sysds/test/functions/ooc/PNMFTest.java  |  18 +-
 .../sysds/test/functions/ooc/WDivMMTest.java       | 156 +++++++++++++++
 src/test/scripts/functions/ooc/PNMF.dml            |   6 +-
 .../{PNMF.dml => WeightedDivMM4MultMinusLeft.dml}  |  18 +-
 .../{PNMF.dml => WeightedDivMM4MultMinusRight.dml} |  18 +-
 .../ooc/{PNMF.dml => WeightedDivMMLeft.dml}        |  16 +-
 .../ooc/{PNMF.dml => WeightedDivMMLeftEps.dml}     |  18 +-
 .../ooc/{PNMF.dml => WeightedDivMMMultBasic.dml}   |  16 +-
 .../ooc/{PNMF.dml => WeightedDivMMMultLeft.dml}    |  16 +-
 .../{PNMF.dml => WeightedDivMMMultMinusLeft.dml}   |  16 +-
 .../{PNMF.dml => WeightedDivMMMultMinusRight.dml}  |  16 +-
 .../ooc/{PNMF.dml => WeightedDivMMMultRight.dml}   |  16 +-
 .../ooc/{PNMF.dml => WeightedDivMMRight.dml}       |  16 +-
 .../ooc/{PNMF.dml => WeightedDivMMRightEps.dml}    |  18 +-
 21 files changed, 595 insertions(+), 79 deletions(-)

diff --git a/scripts/builtin/pnmf.dml b/scripts/builtin/pnmf.dml
index 721ab7232b..bffc373592 100644
--- a/scripts/builtin/pnmf.dml
+++ b/scripts/builtin/pnmf.dml
@@ -42,12 +42,12 @@
 # H     List of amplitude matrices, one for each repetition.
 # 
------------------------------------------------------------------------------------
 
-m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE) 
+m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE, Integer seed=-1)
   return (Matrix[Double] W, Matrix[Double] H) 
 {
   #initialize W and H
-  W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025);
-  H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025);
+  W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed);
+  H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed);
 
   i = 0;
   while(i < maxi) {
diff --git a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
index c2be949f37..8fede5f090 100644
--- a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java
@@ -211,6 +211,8 @@ public class QuaternaryOp extends MultiThreadedHop
                                                
constructCPLopsWeightedDivMM(wtype);
                                        else if( et == ExecType.SPARK )
                                                
constructSparkLopsWeightedDivMM(wtype);
+                                       else if( et == ExecType.OOC )
+                                               
constructOOCLopsWeightedDivMM(wtype);
                                        else
                                                throw new 
HopsException("Unsupported quaternaryop-wdivmm exec type: "+et);
                                        break;
@@ -462,6 +464,20 @@ public class QuaternaryOp extends MultiThreadedHop
                }
        }
 
+       private void constructOOCLopsWeightedDivMM(WDivMMType wtype)
+       {
+               WeightedDivMM wdiv = new WeightedDivMM(
+                       getInput().get(0).constructLops(),
+                       getInput().get(1).constructLops(),
+                       getInput().get(2).constructLops(),
+                       getInput().get(3).constructLops(),
+                       getDataType(), getValueType(), wtype, ExecType.OOC);
+
+               setOutputDimensions(wdiv);
+               setLineNumbers(wdiv);
+               setLops(wdiv);
+       }
+
        private void constructCPLopsWeightedCeMM(WCeMMType wtype) 
        {
                WeightedCrossEntropy wcemm = new WeightedCrossEntropy(
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index affda5910d..ae41639687 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -43,6 +43,7 @@ import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction;
 
 public class OOCInstructionParser extends InstructionParser {
        protected static final Log LOG = 
LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -111,6 +112,8 @@ public class OOCInstructionParser extends InstructionParser {
                                return 
DataGenOOCInstruction.parseInstruction(str);
                        case Append:
                                return 
AppendOOCInstruction.parseInstruction(str);
+                       case Quaternary:
+                               return 
QuaternaryOOCInstruction.parseInstruction(str);
 
                        default:
                                throw new DMLRuntimeException("Invalid OOC 
Instruction Type: " + ooctype);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
index 4dcdffcb0d..d6686c1156 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java
@@ -24,7 +24,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public abstract class ComputationOOCInstruction extends OOCInstruction {
        public CPOperand output;
-       public CPOperand input1, input2, input3;
+       public CPOperand input1, input2, input3, input4;
 
        protected ComputationOOCInstruction(OOCType type, Operator op, 
CPOperand in1, CPOperand out, String opcode, String istr) {
                super(type, op, opcode, istr);
@@ -50,6 +50,15 @@ public abstract class ComputationOOCInstruction extends OOCInstruction {
                output = out;
        }
 
+       protected ComputationOOCInstruction(OOCType type, Operator op, 
CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, 
String opcode, String istr) {
+               super(type, op, opcode, istr);
+               input1 = in1;
+               input2 = in2;
+               input3 = in3;
+               input4 = in4;
+               output = out;
+       }
+
        public String getOutputVariableName() {
                return output.getName();
        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index be9728d87b..679e7187e5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {
 
        public enum OOCType {
                Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, 
AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
-               MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, 
ParameterizedBuiltin, Rand, Append
+               MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, 
ParameterizedBuiltin, Rand, Append, Quaternary
        }
 
        protected final OOCInstruction.OOCType _ooctype;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java
new file mode 100644
index 0000000000..8df1e33c59
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+public abstract class QuaternaryOOCInstruction extends 
ComputationOOCInstruction {
+
+       protected QuaternaryOOCInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand in4,
+                       CPOperand out, String opcode, String istr) {
+               super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, 
istr);
+       }
+
+       public static QuaternaryOOCInstruction parseInstruction(String str) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+
+               if(opcode.contains(Opcodes.WEIGHTEDDIVMM.toString())) {
+                       InstructionUtils.checkNumFields(parts, 6);
+                       CPOperand in1 = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand in4 = new CPOperand(parts[4]);
+                       CPOperand out = new CPOperand(parts[5]);
+                       QuaternaryOperator qop = new 
QuaternaryOperator(WDivMMType.valueOf(parts[6]));
+                       return new WDivMMOOCInstruction(qop, in1, in2, in3, 
in4, out, opcode, str);
+               }
+               throw new DMLRuntimeException("Not implemented yet opcode " + 
opcode);
+       }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java
new file mode 100644
index 0000000000..ec9a7bcd4f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+import java.util.function.Function;
+
+public class WDivMMOOCInstruction extends QuaternaryOOCInstruction {
+
+       protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand in4,
+               CPOperand out, String opcode, String istr) {
+               super(op, in1, in2, in3, in4, out, opcode, istr);
+       }
+
+       public static WDivMMOOCInstruction 
parseInstruction(QuaternaryOOCInstruction instr) {
+               String instrStr = instr.getInstructionString();
+               String opcode = 
InstructionUtils.getInstructionPartsWithValueType(instr.getInstructionString())[0];
+               return new WDivMMOOCInstruction((QuaternaryOperator) 
instr.getOperator(), instr.input1, instr.input2,
+                       instr.input3, instr.input4, instr.output, opcode, 
instrStr);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               QuaternaryOperator qop = ((QuaternaryOperator) _optr);
+               final WDivMMType wt = qop.wtype3;
+
+               CachingStream X = new 
CachingStream(ec.getMatrixObject(input1).getStreamHandle());
+               CachingStream U = new 
CachingStream(ec.getMatrixObject(input2).getStreamHandle());
+               CachingStream V = new 
CachingStream(ec.getMatrixObject(input3).getStreamHandle());
+
+               boolean basic = wt.isBasic();
+               boolean left = wt.isLeft();
+               boolean mult = wt.isMult();
+               boolean minus = wt.isMinus();
+               boolean four = wt.hasFourInputs();
+               boolean scalar = wt.hasScalar();
+
+               OOCStream<IndexedMatrixValue> mmt = 
matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(),
+                       V.getDataCharacteristics(), false, true);
+               OOCStream<IndexedMatrixValue> inter;
+               OOCStream<IndexedMatrixValue> out;
+
+               if(basic) {
+                       out = elemMultOOC(X.getReadStream(), mmt);
+                       ec.getMatrixObject(output).setStreamHandle(out);
+                       return;
+               }
+               else if(four) {
+                       if(scalar) {
+                               double eps = 
ec.getScalarInput(input4).getDoubleValue();
+                               inter = elemDivOOC(X.getReadStream(), 
elemPlusOOC(mmt, eps));
+                       }
+                       else {
+                               CachingStream W = new 
CachingStream(ec.getMatrixObject(input4).getStreamHandle());
+                               inter = elemMultOOC(X.getReadStream(), 
elemMinusOOC(mmt, W.getReadStream()));
+                       }
+               }
+               else {
+                       if(minus)
+                               inter = maskOOC(X.getReadStream(), 
elemMinusOOC(mmt, X.getReadStream()));
+                       else {
+                               if(mult)
+                                       inter = elemMultOOC(X.getReadStream(), 
mmt);
+                               else
+                                       inter = elemDivOOC(X.getReadStream(), 
mmt);
+                       }
+               }
+
+               if(left)
+                       out = matMultOOC(inter, U.getReadStream(), 
X.getDataCharacteristics(), U.getDataCharacteristics(),
+                               true, false);
+               else
+                       out = matMultOOC(inter, V.getReadStream(), 
X.getDataCharacteristics(), V.getDataCharacteristics(),
+                               false, false);
+
+               ec.getMatrixObject(output).setStreamHandle(out);
+       }
+
+       private OOCStream<IndexedMatrixValue> 
matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2,
+               DataCharacteristics dc1, DataCharacteristics dc2, boolean 
leftTranspose, boolean rightTranspose) {
+
+               int emitLeftThreshold = rightTranspose ? (int) 
dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
+               int emitRightThreshold = leftTranspose ? (int) 
dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();
+
+               OOCStream<IndexedMatrixValue> intermediateStream = 
createWritableStream();
+               OOCStream<IndexedMatrixValue> out = createWritableStream();
+
+               AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
+               AggregateBinaryOperator op = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+
+               joinManyOOC(m1, m2, intermediateStream, (left, right) -> {
+                       MatrixBlock leftBlock = (MatrixBlock) left.getValue();
+                       MatrixBlock rightBlock = (MatrixBlock) right.getValue();
+                       if(leftTranspose)
+                               leftBlock = leftBlock.transpose();
+                       if(rightTranspose)
+                               rightBlock = rightBlock.transpose();
+
+                       MatrixBlock partialResult = 
leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), 
op);
+                       int lidx = (int) (leftTranspose ? 
left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
+                       int ridx = (int) (rightTranspose ? 
right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
+                       return new IndexedMatrixValue(new MatrixIndexes(lidx, 
ridx), partialResult);
+               }, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : 
tmp.getIndexes().getColumnIndex(),
+                       tmp -> rightTranspose ? 
tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
+                       emitLeftThreshold, emitRightThreshold);
+
+               BinaryOperator plus = 
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
+               int emitAggThreshold = leftTranspose ? (int) 
dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();
+
+               groupedReduceOOC(intermediateStream, out, (left, right) -> {
+                       MatrixBlock mb = ((MatrixBlock) 
left.getValue()).binaryOperationsInPlace(plus, right.getValue());
+                       left.setValue(mb);
+                       return left;
+               }, emitAggThreshold);
+
+               return out;
+       }
+
+       private OOCStream<IndexedMatrixValue> 
elemOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2, 
BinaryOperator bop) {
+               SubscribableTaskQueue<IndexedMatrixValue> out = new 
SubscribableTaskQueue<>();
+               Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
+                       new MatrixIndexes(imv.getIndexes().getRowIndex(), 
imv.getIndexes().getColumnIndex());
+
+               joinOOC(m1, m2, out, (left, right) -> {
+                       MatrixBlock lb = (MatrixBlock) left.getValue();
+                       MatrixBlock rb = (MatrixBlock) right.getValue();
+                       MatrixBlock combined = lb.binaryOperations(bop, rb);
+                       return new IndexedMatrixValue(
+                               new 
MatrixIndexes(left.getIndexes().getRowIndex(), 
left.getIndexes().getColumnIndex()), combined);
+               }, key);
+
+               return out;
+       }
+
+       private OOCStream<IndexedMatrixValue> 
elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
+               BinaryOperator div = 
InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
+               return elemOOC(m1, m2, div);
+       }
+
+       private OOCStream<IndexedMatrixValue> 
elemMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) 
{
+               BinaryOperator div = 
InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString());
+               return elemOOC(m1, m2, div);
+       }
+
+       private OOCStream<IndexedMatrixValue> 
elemMinusOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> 
m2) {
+               BinaryOperator div = 
InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString());
+               return elemOOC(m1, m2, div);
+       }
+
+       private OOCStream<IndexedMatrixValue> 
elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps) {
+               SubscribableTaskQueue<IndexedMatrixValue> out = new 
SubscribableTaskQueue<>();
+               mapOOC(m1, out, blk -> {
+                       MatrixBlock res = ((MatrixBlock) blk.getValue())
+                               .scalarOperations(new 
RightScalarOperator(Plus.getPlusFnObject(), eps), null);
+                       return new IndexedMatrixValue(
+                               new 
MatrixIndexes(blk.getIndexes().getRowIndex(), 
blk.getIndexes().getColumnIndex()), res);
+               });
+               return out;
+       }
+
+       private OOCStream<IndexedMatrixValue> 
maskOOC(OOCStream<IndexedMatrixValue> mask, OOCStream<IndexedMatrixValue> m1) {
+               SubscribableTaskQueue<IndexedMatrixValue> out = new 
SubscribableTaskQueue<>();
+               Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
+                       new MatrixIndexes(imv.getIndexes().getRowIndex(), 
imv.getIndexes().getColumnIndex());
+
+               joinOOC(mask, m1, out, (left, right) -> {
+                       MatrixBlock lb = (MatrixBlock) left.getValue();
+                       MatrixBlock rb = (MatrixBlock) right.getValue();
+                       MatrixBlock combined = mask(lb, rb);
+                       return new IndexedMatrixValue(
+                               new 
MatrixIndexes(left.getIndexes().getRowIndex(), 
left.getIndexes().getColumnIndex()), combined);
+               }, key);
+
+               return out;
+       }
+
+       private MatrixBlock mask(MatrixBlock mask, MatrixBlock blk) {
+               for(int i = 0; i < blk.getNumRows(); i++) {
+                       for(int j = 0; j < blk.getNumColumns(); j++) {
+                               if(mask.get(i,j) ==0) blk.set(i, j, 0);
+                       }
+               }
+               return blk;
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
index a25249985d..d7186f2bbe 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java
@@ -21,12 +21,16 @@ package org.apache.sysds.test.functions.ooc;
 
 import java.io.IOException;
 
+import org.apache.sysds.common.Opcodes;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
 
 public class PNMFTest extends AutomatedTestBase {
        private static final String TEST_NAME = "PNMF";
@@ -44,6 +48,7 @@ public class PNMFTest extends AutomatedTestBase {
        private static final int RANK = 20;
        private static final int MAX_ITER = 10;
        private static final int BLOCK_SIZE = 1000;
+       private static final int SEED = 7;
 
        private static final double SPARSITY = 0.7;
        private static final double EPS = 1e-6;
@@ -54,7 +59,7 @@ public class PNMFTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
        }
 
-       //@Test
+       @Test
        public void testPNMFOOCVsCP() {
                runPNMFTest();
        }
@@ -71,13 +76,16 @@ public class PNMFTest extends AutomatedTestBase {
                        double[][] xData = getRandomMatrix(ROWS, COLS, 1, 10, 
SPARSITY, 7);
                        writeBinaryWithMTD(INPUT_X, 
DataConverter.convertToMatrixBlock(xData));
 
-                       programArgs = new String[] {"-explain", "-stats", 
"-seed", "7", "-ooc", "-args",
-                               input(INPUT_X), String.valueOf(RANK), 
String.valueOf(MAX_ITER),
+                       programArgs = new String[] {"-explain", "-stats", 
"-ooc", "-args",
+                               input(INPUT_X), String.valueOf(RANK), 
String.valueOf(MAX_ITER), String.valueOf(SEED),
                                output(OUTPUT_W_OOC), output(OUTPUT_H_OOC)};
                        runTest(true, false, null, -1);
 
-                       programArgs = new String[] {"-explain", "-stats", 
"-seed", "7", "-args",
-                               input(INPUT_X), String.valueOf(RANK), 
String.valueOf(MAX_ITER),
+                       Assert.assertTrue("OOC wasn't used for pnmf",
+                               
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + 
Opcodes.WEIGHTEDDIVMM));
+
+                       programArgs = new String[] {"-explain", "-stats", 
"-args",
+                               input(INPUT_X), String.valueOf(RANK), 
String.valueOf(MAX_ITER), String.valueOf(SEED),
                                output(OUTPUT_W_CP), output(OUTPUT_H_CP)};
                        runTest(true, false, null, -1);
 
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java
new file mode 100644
index 0000000000..549fdc764d
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class WDivMMTest extends AutomatedTestBase {
+       private final static String INPUT_NAME_1 = "W";
+       private final static String INPUT_NAME_2 = "U";
+       private final static String INPUT_NAME_3 = "V";
+       private final static String OUTPUT_NAME = "R";
+       private static final String TEST_DIR = "functions/ooc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
WDivMMTest.class.getSimpleName() + "/";
+
+       private static final int rows = 2201;
+       private static final int cols = 1103;
+       private static final int rank = 20;
+       private static final int blen = 1000;
+       private static final double eps = 1e-6;
+       private static final double div_eps = 0.1;
+
+       private final static String TEST_NAME_1 = "WeightedDivMMLeft";
+       private final static String TEST_NAME_2 = "WeightedDivMMRight";
+       private final static String TEST_NAME_3 = "WeightedDivMMMultBasic";
+       private final static String TEST_NAME_4 = "WeightedDivMMMultLeft";
+       private final static String TEST_NAME_5 = "WeightedDivMMMultRight";
+       private final static String TEST_NAME_6 = "WeightedDivMMMultMinusLeft";
+       private final static String TEST_NAME_7 = "WeightedDivMMMultMinusRight";
+       private final static String TEST_NAME_8 = "WeightedDivMM4MultMinusLeft";
+       private final static String TEST_NAME_9 = 
"WeightedDivMM4MultMinusRight";
+       private final static String TEST_NAME_10 = "WeightedDivMMLeftEps";
+       private final static String TEST_NAME_11 = "WeightedDivMMRightEps";
+       private String TEST_NAME;
+
+       public WDivMMTest(String testName) {
+               this.TEST_NAME = testName;
+       }
+
+       @Parameterized.Parameters(name = "{0}")
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {{TEST_NAME_1}, 
{TEST_NAME_2}, {TEST_NAME_3}, {TEST_NAME_4}, {TEST_NAME_5},
+                       {TEST_NAME_6}, {TEST_NAME_7}, {TEST_NAME_8}, 
{TEST_NAME_9}, {TEST_NAME_10}, {TEST_NAME_11}});
+       }
+
+       @Before
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME}));
+       }
+
+       @Test
+       public void testWeightedDivMM() {
+               runWeightedDivMMTest(TEST_NAME);
+       }
+
+       private void runWeightedDivMMTest(String TEST_NAME) {
+               Types.ExecMode platformOld = 
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+               try {
+                       boolean basic = TEST_NAME.equals(TEST_NAME_3);
+                       boolean left = TEST_NAME.equals(TEST_NAME_1) || 
TEST_NAME.equals(TEST_NAME_4) ||
+                               TEST_NAME.equals(TEST_NAME_6) || 
TEST_NAME.equals(TEST_NAME_8) || TEST_NAME.equals(TEST_NAME_10);
+
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+                       double[][] W = getRandomMatrix(rows, cols, 0, 1, 0.7, 
7);
+                       double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 
713);
+                       double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 
812);
+
+                       MatrixWriter writer = 
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+                       
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(W), 
input(INPUT_NAME_1), rows,
+                               cols, blen, rows * cols);
+                       
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(U), 
input(INPUT_NAME_2), rows,
+                               rank, blen, rows * rank);
+                       
writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(V), 
input(INPUT_NAME_3), cols,
+                               rank, blen, cols * rank);
+
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + 
".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(rows, cols, blen, 
rows * cols), Types.FileFormat.BINARY);
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + 
".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(rows, rank, blen, 
rows * rank), Types.FileFormat.BINARY);
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_3 + 
".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(cols, rank, blen, 
cols * rank), Types.FileFormat.BINARY);
+
+                       programArgs = new String[] {"-ooc", "-stats", 
"-explain", "runtime", "-args", input(INPUT_NAME_1),
+                               input(INPUT_NAME_2), input(INPUT_NAME_3), 
output(OUTPUT_NAME), Double.toString(div_eps)};
+
+                       runTest(true, false, null, -1);
+
+                       Assert.assertTrue("OOC wasn't used for wdivmm",
+                               
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + 
Opcodes.WEIGHTEDDIVMM));
+
+                       programArgs = new String[] {"-stats", "-explain", 
"runtime", "-args", input(INPUT_NAME_1),
+                               input(INPUT_NAME_2), input(INPUT_NAME_3), 
output(OUTPUT_NAME + "_target"), Double.toString(div_eps)};
+
+                       runTest(true, false, null, -1);
+
+                       int rows2 = left ? cols : rows;
+                       int cols2 = basic ? cols : rank;
+                       checkDMLMetaDataFile("R", new 
MatrixCharacteristics(rows2, cols2));
+
+                       MatrixBlock actual = 
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+                               Types.FileFormat.BINARY, rows2, cols2, blen);
+                       MatrixBlock expected = 
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+                               Types.FileFormat.BINARY, rows2, cols2, blen);
+                       TestUtils.compareMatrices(expected, actual, eps);
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/ooc/PNMF.dml b/src/test/scripts/functions/ooc/PNMF.dml
index 60aecb8963..bc0fd5b100 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/PNMF.dml
@@ -20,7 +20,7 @@
 #-------------------------------------------------------------
 
 X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
+[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE, seed=$4);  # $4 is forwarded to pnmf as its seed (presumably for reproducible runs)
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+write(W, $5, format="binary");  # output paths shift to $5/$6 because $4 is now the seed
+write(H, $6, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml
similarity index 86%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml
index 60aecb8963..42bd4c96a0 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (four-input mult-minus, left): W=$1, U=$2, V=$3; writes R = t(t(U) %*% (W * (U %*% t(V) - X))) to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+X = W/0.7;  # derived fourth input with the same non-zero pattern as W
+while(FALSE){}  # NOTE(review): presumably a statement-block cut so X is treated as an independent input; confirm
+R = t(t(U) %*% (W*(U%*%t(V)-X)));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml
index 60aecb8963..7b393f1231 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (four-input mult-minus, right): W=$1, U=$2, V=$3; writes R = (W * (U %*% t(V) - X)) %*% V to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+X = W/0.3;  # derived fourth input with the same non-zero pattern as W
+while(FALSE){}  # NOTE(review): presumably a statement-block cut so X is treated as an independent input; confirm
+R = (W*(U%*%t(V)-X)) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMLeft.dml
index 60aecb8963..48639a176a 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (divide, left): W=$1, U=$2, V=$3; writes R = t(t(U) %*% (W / (U %*% t(V)))) to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = t(t(U) %*% (W/(U%*%t(V))));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml
index 60aecb8963..dc07670fea 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (divide with epsilon, left): W=$1, U=$2, V=$3, x=$5; writes R = t(t(U) %*% (W / (U %*% t(V) + x))) to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+x = $5;  # scalar added to every cell of the denominator U %*% t(V)
+
+R = t(t(U) %*% (W/(U%*%t(V) + x)));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml
index 60aecb8963..144e59a773 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (mult, basic): W=$1, U=$2, V=$3; writes R = W * (U %*% t(V)) (cell-wise, same shape as W) to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = W*(U%*%t(V));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml
index 60aecb8963..93bc765617 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (mult, left): W=$1, U=$2, V=$3; writes R = t(t(U) %*% (W * (U %*% t(V)))) to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = t(t(U) %*% (W*(U%*%t(V))));
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml
index 60aecb8963..84ac35ad89 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (mult-minus, left): W=$1, U=$2, V=$3; writes R = t(t(U) %*% ((W != 0) * (U %*% t(V) - W))) to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = t(t(U) %*% ((W != 0)*(U%*%t(V)-W)));  # (W != 0) restricts the residual to W's non-zero cells
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml
index 60aecb8963..59caa4d17b 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (mult-minus, right): W=$1, U=$2, V=$3; writes R = ((W != 0) * (U %*% t(V) - W)) %*% V to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = ((W != 0)*(U%*%t(V)-W)) %*% V;  # (W != 0) restricts the residual to W's non-zero cells
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml
index 60aecb8963..fbb1224d17 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (mult, right): W=$1, U=$2, V=$3; writes R = (W * (U %*% t(V))) %*% V to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = (W*(U%*%t(V))) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMRight.dml
index 60aecb8963..e878a81d14 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (divide, right): W=$1, U=$2, V=$3; writes R = (W / (U %*% t(V))) %*% V to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+R = (W/(U%*%t(V))) %*% V;
+
+write(R, $4, format="binary");
diff --git a/src/test/scripts/functions/ooc/PNMF.dml 
b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml
similarity index 87%
copy from src/test/scripts/functions/ooc/PNMF.dml
copy to src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml
index 60aecb8963..9ecbaf5663 100644
--- a/src/test/scripts/functions/ooc/PNMF.dml
+++ b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-#
+# 
 #   http://www.apache.org/licenses/LICENSE-2.0
-#
+# 
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,8 +19,14 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE);
 
-write(W, $4, format="binary");
-write(H, $5, format="binary");
+# wdivmm test (divide with epsilon, right): W=$1, U=$2, V=$3, x=$5; writes R = (W / (U %*% t(V) + x)) %*% V to $4
+W = read($1);
+U = read($2);
+V = read($3);
+
+x = $5;  # scalar added to every cell of the denominator U %*% t(V)
+
+R = (W/(U%*%t(V) + x)) %*% V;
+
+write(R, $4, format="binary");


Reply via email to