This is an automated email from the ASF dual-hosted git repository.
estrauss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 615cd9ac48 [SYSTEMDS-3823] Compression test case for builtin kmeans
615cd9ac48 is described below
commit 615cd9ac4860d984cd7167f223ff4b8789ac612d
Author: e-strauss <[email protected]>
AuthorDate: Wed Jan 29 11:46:24 2025 +0100
[SYSTEMDS-3823] Compression test case for builtin kmeans
Closes #2194
---
.../sysds/runtime/matrix/data/MatrixBlock.java | 3 +++
.../compress/workload/WorkloadAlgorithmTest.java | 30 +++++++++++++++++----
.../compress/workload/WorkloadAnalysisKmeans.dml | 31 ++++++++++++++++++++++
3 files changed, 59 insertions(+), 5 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 22fa5e43e7..2a0b92bf4a 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -1730,6 +1730,9 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
* (the invoker is responsible to recompute nnz after
all copies are done)
*/
public void copy(int rl, int ru, int cl, int cu, MatrixBlock src,
boolean awareDestNZ ) {
+ if (src instanceof CompressedMatrixBlock){
+ src = ((CompressedMatrixBlock)
src).getUncompressed("In-place matrix copy into indexed matrix");
+ }
if(sparse && src.sparse)
copySparseToSparse(rl, ru, cl, cu, src, awareDestNZ);
else if(sparse && !src.sparse)
diff --git
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
index 9afede6eba..f09fcf456e 100644
---
a/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/compress/workload/WorkloadAlgorithmTest.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.functions.compress.workload;
import static org.junit.Assert.fail;
+import java.io.ByteArrayOutputStream;
import java.io.File;
import org.apache.commons.logging.Log;
@@ -46,6 +47,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
private final static String TEST_NAME5 = "WorkloadAnalysisSliceFinder";
private final static String TEST_NAME6 = "WorkloadAnalysisLmCG";
private final static String TEST_NAME7 = "WorkloadAnalysisL2SVM";
+ private final static String TEST_NAME8 = "WorkloadAnalysisKmeans";
private final static String TEST_DIR = "functions/compress/workload/";
private final static String TEST_CLASS_DIR = TEST_DIR +
WorkloadAnalysisTest.class.getSimpleName() + "/";
@@ -73,6 +75,7 @@ public class WorkloadAlgorithmTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME5, new TestConfiguration(dir,
TEST_NAME5, new String[] {"B"}));
addTestConfiguration(TEST_NAME6, new TestConfiguration(dir,
TEST_NAME6, new String[] {"B"}));
addTestConfiguration(TEST_NAME7, new TestConfiguration(dir,
TEST_NAME7, new String[] {"B"}));
+ addTestConfiguration(TEST_NAME8, new TestConfiguration(dir,
TEST_NAME8, new String[] {"B"}));
}
@Test
@@ -143,8 +146,23 @@ public class WorkloadAlgorithmTest extends
AutomatedTestBase {
runWorkloadAnalysisTest(TEST_NAME7, ExecMode.SINGLE_NODE, 2,
false);
}
+ @Test
+ public void testKmeansSuccessfulCP() {
+ runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1,
false, 30);
+ }
+
+ @Test
+ public void testKmeansUnsuccessfulCP() {
+ runWorkloadAnalysisTest(TEST_NAME8, ExecMode.SINGLE_NODE, 1,
false, 10);
+ }
+
+ private void runWorkloadAnalysisTest(String testname, ExecMode mode,
int compressionCount, boolean intermediates){
+ runWorkloadAnalysisTest(testname, mode, compressionCount,
intermediates, -1);
+ }
+
// private void runWorkloadAnalysisTest(String testname, ExecMode mode,
int compressionCount) {
- private void runWorkloadAnalysisTest(String testname, ExecMode mode,
int compressionCount, boolean intermediates) {
+ private void runWorkloadAnalysisTest(String testname, ExecMode mode,
int compressionCount, boolean intermediates,
+
int maxIter) {
ExecMode oldPlatform = setExecMode(mode);
boolean oldIntermediates =
WorkloadAnalyzer.ALLOW_INTERMEDIATE_CANDIDATES;
@@ -154,19 +172,20 @@ public class WorkloadAlgorithmTest extends
AutomatedTestBase {
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] {"-stats", "20", "-args",
input("X"), input("y"), output("B")};
+ programArgs = new String[] {"-stats", "20", "-args",
input("X"), input("y"), output("B"),
+ String.valueOf(maxIter)};
writeInputMatrixWithMTD("X", X, false);
writeInputMatrixWithMTD("y", y, false);
- String ret = runTest(null).toString();
+ ByteArrayOutputStream out = runTest(null);
+ String ret = out != null ? out.toString() : "";
LOG.debug(ret);
// check various additional expectations
long actualCompressionCount = (mode == ExecMode.HYBRID
|| mode == ExecMode.SINGLE_NODE) ? Statistics
.getCPHeavyHitterCount("compress") :
Statistics.getCPHeavyHitterCount("sp_compress");
-
- Assert.assertEquals("Assert that the compression counts
expeted matches actual: " + compressionCount + " vs "
+ Assert.assertEquals("Assert that the compression counts
expected matches actual: " + compressionCount + " vs "
+ actualCompressionCount, compressionCount,
actualCompressionCount);
if(compressionCount > 0)
Assert.assertTrue(mode == ExecMode.SINGLE_NODE
|| mode == ExecMode.HYBRID ? heavyHittersContainsString(
@@ -176,6 +195,7 @@ public class WorkloadAlgorithmTest extends
AutomatedTestBase {
}
catch(Exception e) {
+ e.printStackTrace();
resetExecMode(oldPlatform);
fail("Failed workload test");
}
diff --git
a/src/test/scripts/functions/compress/workload/WorkloadAnalysisKmeans.dml
b/src/test/scripts/functions/compress/workload/WorkloadAnalysisKmeans.dml
new file mode 100644
index 0000000000..7382436bf6
--- /dev/null
+++ b/src/test/scripts/functions/compress/workload/WorkloadAnalysisKmeans.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+
+
+print("")
+print("kmeans")
+
+[data, Centering, ScaleFactor] = scale(X, TRUE, TRUE)
+# terminates with result
+[Y_n, C_n] = kmeans(X=data, k=16, runs= 1, max_iter=as.integer($4), eps=
1e-17, seed= 13, is_verbose=TRUE)
+print(sum(Y_n))