This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 024255708d [MINOR] Improved code coverage control program and symbol 
table
024255708d is described below

commit 024255708de1c9e8e534413d1121952af356084d
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Dec 10 08:48:44 2024 +0100

    [MINOR] Improved code coverage control program and symbol table
    
    This patch adds a couple of tests to systematically cover previously
    uncovered code. Furthermore, it removes incorrect methods and renames
    misleading methods on "pinned objects" that did not actually deal with
    our notion of pinned (i.e., cleanup-disabled) data objects.
---
 .../org/apache/sysds/api/jmlc/PreparedScript.java  |  2 +-
 .../runtime/controlprogram/LocalVariableMap.java   | 13 ++--
 .../sysds/runtime/controlprogram/Program.java      | 11 +---
 .../sysds/runtime/controlprogram/ProgramBlock.java | 42 +++----------
 .../federated/FederatedWorkerHandler.java          |  4 +-
 .../org/apache/sysds/runtime/meta/MetaData.java    |  4 ++
 .../sysds/test/component/cp/VariableMapTest.java   | 71 ++++++++++++++++++++++
 .../UltraSparseMRMatrixMultiplicationTest.java     |  5 +-
 .../primitives/part2/FederatedRdiagTest.java       |  4 +-
 9 files changed, 102 insertions(+), 54 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java 
b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java
index 08b7425240..31bb745722 100644
--- a/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java
+++ b/src/main/java/org/apache/sysds/api/jmlc/PreparedScript.java
@@ -83,7 +83,7 @@ public class PreparedScript implements ConfigurableAPI
        private PreparedScript(PreparedScript that) {
                //shallow copy, except for a separate symbol table
                //and related meta data of reused inputs
-               _prog = that._prog.clone(false);
+               _prog = (Program)that._prog.clone();
                _vars = new LocalVariableMap();
                for(Entry<String, Data> e : that._vars.entrySet())
                        _vars.put(e.getKey(), e.getValue());
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
index aa6184a5f4..8d27a6e8f2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/LocalVariableMap.java
@@ -47,7 +47,7 @@ public class LocalVariableMap implements Cloneable
        
        //variable map data and id
        private final ConcurrentHashMap<String, Data> localMap;
-       private final long localID;
+       private long localID;
        
        //optional set of registered outputs
        private HashSet<String> outputs = null;
@@ -61,6 +61,10 @@ public class LocalVariableMap implements Cloneable
                localMap = new ConcurrentHashMap<>(vars.localMap);
                localID = _seq.getNextID();
        }
+       
+       public void setID(long ID) {
+               localID = ID;
+       }
 
        public Set<String> keySet() {
                return localMap.keySet();
@@ -154,12 +158,7 @@ public class LocalVariableMap implements Cloneable
                return total;
        }
        
-       public long countPinnedData() {
-               return localMap.values().stream()
-                       .filter(d -> (d instanceof CacheableData)).count();
-       }
-       
-       public void releasePinnedData() {
+       public void releaseAcquiredData() {
                localMap.values().stream()
                        .filter(d -> (d instanceof CacheableData))
                        .map(d -> (CacheableData<?>) d)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
index 73ed572114..e79e48433f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/Program.java
@@ -23,7 +23,6 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Map.Entry;
 
-import org.apache.commons.lang3.NotImplementedException;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.FunctionDictionary;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -165,9 +164,8 @@ public class Program
                }
        }
 
-       public Program clone(boolean deep) {
-               if( deep )
-                       throw new NotImplementedException();
+       @Override
+       public Object clone() {
                Program ret = new Program(_prog);
                //shallow copy of all program blocks
                ret._programBlocks.addAll(_programBlocks);
@@ -179,11 +177,6 @@ public class Program
                return ret;
        }
        
-       @Override
-       public Object clone() {
-               return clone(true);
-       }
-       
        private static String getSafeNamespace(String namespace) {
                return (namespace == null) ? DMLProgram.DEFAULT_NAMESPACE : 
namespace;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 0739334680..aee08516db 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -41,12 +41,9 @@ import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
-import org.apache.sysds.runtime.instructions.cp.BooleanObject;
 import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
 import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
@@ -60,7 +57,7 @@ public abstract class ProgramBlock implements ParseInfo {
        public static final String PRED_VAR = "__pred";
 
        protected static final Log LOG = 
LogFactory.getLog(ProgramBlock.class.getName());
-       private static final boolean CHECK_MATRIX_PROPERTIES = false;
+       public static boolean CHECK_MATRIX_PROPERTIES = false;
 
        protected Program _prog; // pointer to Program this ProgramBlock is 
part of
 
@@ -84,10 +81,6 @@ public abstract class ProgramBlock implements ParseInfo {
                return _prog;
        }
 
-       public void setProgram(Program prog) {
-               _prog = prog;
-       }
-
        public StatementBlock getStatementBlock() {
                return _sb;
        }
@@ -216,22 +209,7 @@ public abstract class ProgramBlock implements ParseInfo {
 
                // check and correct scalar ret type (incl save double to int)
                if(retType != null && retType != ret.getValueType())
-                       switch(retType) {
-                               case BOOLEAN:
-                                       ret = new 
BooleanObject(ret.getBooleanValue());
-                                       break;
-                               case INT64:
-                                       ret = new IntObject(ret.getLongValue());
-                                       break;
-                               case FP64:
-                                       ret = new 
DoubleObject(ret.getDoubleValue());
-                                       break;
-                               case STRING:
-                                       ret = new 
StringObject(ret.getStringValue());
-                                       break;
-                               default:
-                                       // do nothing
-                       }
+                       ret = ScalarObjectFactory.createScalarObject(retType, 
ret);
 
                // remove predicate variable
                ec.removeVariable(PRED_VAR);
@@ -350,12 +328,10 @@ public abstract class ProgramBlock implements ParseInfo {
                                        synchronized(mb) { // potential state 
change
                                                mb.recomputeNonZeros();
                                                mb.examSparsity();
-
                                        }
                                        if(mb.isInSparseFormat() && 
mb.isAllocated()) {
                                                
mb.getSparseBlock().checkValidity(mb.getNumRows(), mb.getNumColumns(), 
mb.getNonZeros(), true);
                                        }
-
                                        boolean sparse2 = mb.isInSparseFormat();
                                        long nnz2 = mb.getNonZeros();
                                        mo.release();
@@ -473,11 +449,11 @@ public abstract class ProgramBlock implements ParseInfo {
         *                  position, ending column position, text, and filename
         */
        public void setParseInfo(ParseInfo parseInfo) {
-               _beginLine = parseInfo.getBeginLine();
-               _beginColumn = parseInfo.getBeginColumn();
-               _endLine = parseInfo.getEndLine();
-               _endColumn = parseInfo.getEndColumn();
-               _text = parseInfo.getText();
-               _filename = parseInfo.getFilename();
+               setBeginLine(parseInfo.getBeginLine());
+               setBeginColumn(parseInfo.getBeginColumn());
+               setEndLine(parseInfo.getEndLine());
+               setEndColumn(parseInfo.getEndColumn());
+               setText(parseInfo.getText());
+               setFilename(parseInfo.getFilename());
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index d8d3c262ef..ceaf61c225 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -619,9 +619,9 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        pb.execute(ec);
                }
                catch(Exception ex) {
-                       // ensure all variables are properly unpinned, even in 
case
+                       // ensure all variables are properly released, even in 
case
                        // of failures because federated workers are stateful 
servers
-                       ec.getVariables().releasePinnedData();
+                       ec.getVariables().releaseAcquiredData();
                        throw ex;
                }
        }
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaData.java 
b/src/main/java/org/apache/sysds/runtime/meta/MetaData.java
index aa820b5cb4..925fe7f186 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MetaData.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MetaData.java
@@ -27,6 +27,10 @@ public class MetaData
 {
        protected final DataCharacteristics _dc;
        
+       public MetaData() {
+               this(new MatrixCharacteristics());
+       }
+       
        public MetaData(DataCharacteristics dc) {
                _dc = dc;
        }
diff --git 
a/src/test/java/org/apache/sysds/test/component/cp/VariableMapTest.java 
b/src/test/java/org/apache/sysds/test/component/cp/VariableMapTest.java
new file mode 100644
index 0000000000..74263fff80
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/cp/VariableMapTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.cp;
+
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class VariableMapTest {
+
+       @Test
+       public void testPinnedMethods() {
+               LocalVariableMap vars = createSymbolTable();
+               Assert.assertTrue(vars.getPinnedDataSize() > 2e5);
+               vars.releaseAcquiredData(); //no impact on pinned status
+               Assert.assertTrue(vars.getPinnedDataSize() > 2e5);
+               vars.removeAll();
+       }
+       
+       @Test
+       public void testSerializeDeserialize() {
+               LocalVariableMap vars = createSymbolTable();
+               LocalVariableMap vars2 = 
LocalVariableMap.deserialize(vars.serialize());
+               vars2.setID(1);
+               Assert.assertEquals(vars.toString(), vars2.toString());
+               LocalVariableMap vars3 = (LocalVariableMap) vars2.clone();
+               vars3.setID(1);
+               Assert.assertEquals(vars.toString(), vars3.toString());
+       }
+       
+       private LocalVariableMap createSymbolTable() {
+               LocalVariableMap vars = new LocalVariableMap();
+               vars.put("a", createPinnedMatrixObject(1));
+               vars.put("b", createPinnedMatrixObject(2));
+               return vars;
+       }
+       
+       private MatrixObject createPinnedMatrixObject(int seed) {
+               MatrixBlock mb1 = MatrixBlock.randOperations(150, 167, 0.3, 1, 
1, "uniform", seed);
+               MatrixObject mo = new MatrixObject(ValueType.FP64, "./tmp", 
+                       new MetaDataFormat(new MatrixCharacteristics(), 
FileFormat.BINARY));
+               mo.acquireModify(mb1);
+               mo.release();
+               mo.enableCleanup(false);
+               mo.setDirty(false);
+               return mo;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java
index 6c712a8ad2..0fbe20a320 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UltraSparseMRMatrixMultiplicationTest.java
@@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggBinaryOp.MMultMethod;
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -116,12 +117,13 @@ public class UltraSparseMRMatrixMultiplicationTest 
extends AutomatedTestBase
                if( rtplatform == ExecMode.SPARK )
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 
+               ProgramBlock.CHECK_MATRIX_PROPERTIES = true;
+               
                if(forcePMMJ)
                        AggBinaryOp.FORCED_MMULT_METHOD = MMultMethod.PMM;
                        
                try
                {
-                       setOutputBuffering(true);
                        String TEST_NAME = (rowwise) ? TEST_NAME1 : TEST_NAME2;
                        getAndLoadTestConfiguration(TEST_NAME);
                        
@@ -154,6 +156,7 @@ public class UltraSparseMRMatrixMultiplicationTest extends 
AutomatedTestBase
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                        AggBinaryOp.FORCED_MMULT_METHOD = null;
+                       ProgramBlock.CHECK_MATRIX_PROPERTIES = false;
                }
        }
 }
\ No newline at end of file
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
index 8e45c7347d..46bf7a4565 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
@@ -25,6 +25,7 @@ import java.util.Collection;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -116,7 +117,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
                Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
                Process t4 = startLocalFedWorker(port4);
 
-               
+               ProgramBlock.CHECK_MATRIX_PROPERTIES = true;
                try {
                        if(!isAlive(t1, t2, t3, t4))
                                throw new RuntimeException("Failed starting 
federated worker");
@@ -162,6 +163,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                        OptimizerUtils.FEDERATED_COMPILATION = false;
+                       ProgramBlock.CHECK_MATRIX_PROPERTIES = false;
                }
        }
 }

Reply via email to