This is an automated email from the ASF dual-hosted git repository.

estrauss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 615cd9ac48 [SYSTEMDS-3823] Compression test case for builtin kmeans
615cd9ac48 is described below

commit 615cd9ac4860d984cd7167f223ff4b8789ac612d
Author: e-strauss <[email protected]>
AuthorDate: Wed Jan 29 11:46:24 2025 +0100

    [SYSTEMDS-3823] Compression test case for builtin kmeans
    
    Closes #2194
---
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  3 +++
 .../compress/workload/WorkloadAlgorithmTest.java   | 30 +++++++++++++++++----
 .../compress/workload/WorkloadAnalysisKmeans.dml   | 31 ++++++++++++++++++++++
 3 files changed, 59 insertions(+), 5 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 22fa5e43e7..2a0b92bf4a 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -1730,6 +1730,9 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
         *                  (the invoker is responsible to recompute nnz after 
all copies are done) 
         */
        public void copy(int rl, int ru, int cl, int cu, MatrixBlock src, 
boolean awareDestNZ ) {
+               if (src instanceof CompressedMatrixBlock){
+                       src = ((CompressedMatrixBlock) 
src).getUncompressed("In-place matrix copy into indexed matrix");
+               }
                if(sparse && src.sparse)
                        copySparseToSparse(rl, ru, cl, cu, src, awareDestNZ);
                else if(sparse && !src.sparse)
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index 9afede6eba..f09fcf456e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.functions.compress.workload;
 
 import static org.junit.Assert.fail;
 
+import java.io.ByteArrayOutputStream;
 import java.io.File;
 
 import org.apache.commons.logging.Log;
@@ -46,6 +47,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
        private final static String TEST_NAME5 = "WorkloadAnalysisSliceFinder";
        private final static String TEST_NAME6 = "WorkloadAnalysisLmCG";
        private final static String TEST_NAME7 = "WorkloadAnalysisL2SVM";
+       private final static String TEST_NAME8 = "WorkloadAnalysisKmeans";
        private final static String TEST_DIR = "functions/compress/workload/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
WorkloadAnalysisTest.class.getSimpleName() + "/";
 
@@ -73,6 +75,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME5, new TestConfiguration(dir, 
TEST_NAME5, new String[] {"B"}));
                addTestConfiguration(TEST_NAME6, new TestConfiguration(dir, 
TEST_NAME6, new String[] {"B"}));
                addTestConfiguration(TEST_NAME7, new TestConfiguration(dir, 
TEST_NAME7, new String[] {"B"}));
+               addTestConfiguration(TEST_NAME8, new TestConfiguration(dir, 
TEST_NAME8, new String[] {"B"}));
        }
 
        @Test
@@ -143,8 +146,23 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
                runWorkloadAnalysisTest(TEST_NAME7, ExecMode.SINGLE_NODE, 2, 
false);
        }
 
+       @Test
+       public void testKmeansSuccessfulCP() {
+               runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, 
false, 30);
+       }
+
+       @Test
+       public void testKmeansUnsuccessfulCP() {
+               runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1, 
false, 10);
+       }
+
+       private void runWorkloadAnalysisTest(String testname, ExecMode mode, 
int compressionCount, boolean intermediates){
+               runWorkloadAnalysisTest(testname, mode, compressionCount, 
intermediates, -1);
+       }
+
        // private void runWorkloadAnalysisTest(String testname, ExecMode mode, 
int compressionCount) {
-       private void runWorkloadAnalysisTest(String testname, ExecMode mode, 
int compressionCount, boolean intermediates) {
+       private void runWorkloadAnalysisTest(String testname, ExecMode mode, 
int compressionCount, boolean intermediates,
+                                                                               
 int maxIter) {
                ExecMode oldPlatform = setExecMode(mode);
                boolean oldIntermediates = 
WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
 
@@ -154,19 +172,20 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
 
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[] {"-stats", "20", "-args", 
input("X"), input("y"), output("B")};
+                       programArgs = new String[] {"-stats", "20", "-args", 
input("X"), input("y"), output("B"),
+                                       String.valueOf(maxIter)};
 
                        writeInputMatrixWithMTD("X", X, false);
                        writeInputMatrixWithMTD("y", y, false);
 
-                       String ret = runTest(null).toString();
+                       ByteArrayOutputStream out = runTest(null);
+                       String ret = out != null ? out.toString() : "";
                        LOG.debug(ret);
 
                        // check various additional expectations
                        long actualCompressionCount = (mode == ExecMode.HYBRID 
|| mode == ExecMode.SINGLE_NODE) ? Statistics
                                .getCPHeavyHitterCount("compress") : 
Statistics.getCPHeavyHitterCount("sp_compress");
-
-                       Assert.assertEquals("Assert that the compression counts 
expeted matches actual: " + compressionCount + " vs "
+                       Assert.assertEquals("Assert that the compression counts 
expected matches actual: " + compressionCount + " vs "
                                + actualCompressionCount, compressionCount, 
actualCompressionCount);
                        if(compressionCount > 0)
                                Assert.assertTrue(mode == ExecMode.SINGLE_NODE 
|| mode == ExecMode.HYBRID ? heavyHittersContainsString(
@@ -176,6 +195,7 @@ public class WorkloadAlgorithmTest extends 
AutomatedTestBase {
 
                }
                catch(Exception e) {
+                       e.printStackTrace();
                        resetExecMode(oldPlatform);
                        fail("Failed workload test");
                }
diff --git 
a/src/test/scripts/functions/compress/workload/WorkloadAnalysisKmeans.dml 
b/src/test/scripts/functions/compress/workload/WorkloadAnalysisKmeans.dml
new file mode 100644
index 0000000000..7382436bf6
--- /dev/null
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisKmeans.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+
+
+print("")
+print("kmeans")
+
+[data, Centering, ScaleFactor] = scale(X, TRUE, TRUE)
+# terminates with result
+[Y_n, C_n] = kmeans(X=data, k=16, runs= 1, max_iter=as.integer($4), eps= 
1e-17, seed= 13, is_verbose=TRUE)
+print(sum(Y_n))

Reply via email to