This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3ae6a7764a [SYSTEMDS-3898] Fix correctness CP quantile pick instruction
3ae6a7764a is described below
commit 3ae6a7764ae3375247d1687a785e61e798399b60
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Aug 31 14:36:13 2025 +0200
[SYSTEMDS-3898] Fix correctness CP quantile pick instruction
This patch fixes an issue with quantiles for even-length arrays, where,
for example, the median is not a single picked value but the average of
the two middle values. As it turns out, the quantile kernel already
supported averaging; it was invoked correctly for median() but
incorrectly for quantile().
Results now match R (checked by the new QuantileBug test with p=0.5),
and quantile(X, 0.5) == median(X) holds consistently.
The distributed Spark operations and weighted kernels are not covered by
this patch; they need additional thought and a more involved
implementation.
Thanks to Ramon Schoendorf for catching and reporting this issue.
---
.../instructions/cp/QuantilePickCPInstruction.java | 8 +++-
.../sysds/runtime/matrix/data/MatrixBlock.java | 3 ++
.../test/functions/binary/matrix/QuantileTest.java | 44 +++++++++++-----------
.../scripts/functions/binary/matrix/QuantileBug.R | 35 +++++++++++++++++
.../functions/binary/matrix/QuantileBug.dml | 25 ++++++++++++
5 files changed, 91 insertions(+), 24 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java
index 0a818afe78..a7bfbf5a16 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantilePickCPInstruction.java
@@ -90,12 +90,16 @@ public class QuantilePickCPInstruction extends
BinaryCPInstruction {
if ( input2.getDataType() ==
DataType.SCALAR ) {
ScalarObject quantile =
ec.getScalarInput(input2);
- double picked =
matBlock.pickValue(quantile.getDoubleValue());
+ //pick value w/ explicit
averaging for even-length arrays
+ double picked =
matBlock.pickValue(
+
quantile.getDoubleValue(), matBlock.getLength()%2==0);
ec.setScalarOutput(output.getName(), new DoubleObject(picked));
}
else {
MatrixBlock quantiles =
ec.getMatrixInput(input2.getName());
- MatrixBlock resultBlock =
matBlock.pickValues(quantiles, new MatrixBlock());
+ //pick value w/ explicit
averaging for even-length arrays
+ MatrixBlock resultBlock =
matBlock.pickValues(
+ quantiles, new
MatrixBlock(), matBlock.getLength()%2==0);
quantiles = null;
ec.releaseMatrixInput(input2.getName());
ec.setMatrixOutput(output.getName(), resultBlock);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index d15e9e5c2b..524598b119 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4755,7 +4755,10 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) {
+ return pickValues(quantiles, ret, false);
+ }
+ public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret,
boolean average) {
MatrixBlock qs=checkType(quantiles);
if ( qs.clen != 1 ) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/QuantileTest.java
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/QuantileTest.java
index 0b009481ec..cd10151007 100644
---
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/QuantileTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/QuantileTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.binary.matrix;
import java.util.HashMap;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -30,15 +29,12 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-/**
- *
- */
public class QuantileTest extends AutomatedTestBase
{
-
private final static String TEST_NAME1 = "Quantile";
private final static String TEST_NAME2 = "Median";
private final static String TEST_NAME3 = "IQM";
+ private final static String TEST_NAME4 = "QuantileBug";
private final static String TEST_DIR = "functions/binary/matrix/";
private final static String TEST_CLASS_DIR = TEST_DIR +
QuantileTest.class.getSimpleName() + "/";
@@ -59,6 +55,8 @@ public class QuantileTest extends AutomatedTestBase
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new
String[] { "R" }) );
addTestConfiguration(TEST_NAME3,
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new
String[] { "R" }) );
+ addTestConfiguration(TEST_NAME4,
+ new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new
String[] { "R" }) );
}
@Test
@@ -161,19 +159,21 @@ public class QuantileTest extends AutomatedTestBase
runQuantileTest(TEST_NAME3, -1, true, ExecType.SPARK);
}
+ @Test
+ public void testQuantileBugCP() {
+ runQuantileTest(TEST_NAME4, 0.5, false, ExecType.CP);
+ }
+
+// TODO reimplement distributed value pick logic
+// @Test
+// public void testQuantileBugSP() {
+// runQuantileTest(TEST_NAME4, 0.5, false, ExecType.SPARK);
+// }
+
private void runQuantileTest( String TEST_NAME, double p, boolean
sparse, ExecType et)
{
- //rtplatform for MR
- ExecMode platformOld = rtplatform;
- switch( et ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
+ ExecMode platformOld = setExecMode(et);
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
try
{
getAndLoadTestConfiguration(TEST_NAME);
@@ -185,9 +185,11 @@ public class QuantileTest extends AutomatedTestBase
rCmd = "Rscript" + " " + fullRScriptName + " " +
inputDir() + " " + p + " "+ expectedDir();
//generate actual dataset (always dense because values
<=0 invalid)
- double sparsitya = sparse ? sparsity2 : sparsity1;
- double[][] A = getRandomMatrix(rows, 1, 1, maxVal,
sparsitya, 1236);
- writeInputMatrixWithMTD("A", A, true);
+ if( !TEST_NAME.equals(TEST_NAME4) ) {
+ double sparsitya = sparse ? sparsity2 :
sparsity1;
+ double[][] A = getRandomMatrix(rows, 1, 1,
maxVal, sparsitya, 1236);
+ writeInputMatrixWithMTD("A", A, true);
+ }
runTest(true, false, null, -1);
runRScript(true);
@@ -198,9 +200,7 @@ public class QuantileTest extends AutomatedTestBase
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
}
finally {
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
-
-}
\ No newline at end of file
+}
diff --git a/src/test/scripts/functions/binary/matrix/QuantileBug.R
b/src/test/scripts/functions/binary/matrix/QuantileBug.R
new file mode 100644
index 0000000000..d32c16fe3d
--- /dev/null
+++ b/src/test/scripts/functions/binary/matrix/QuantileBug.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(c(1,5,7,10))
+p = as.double(args[2]);
+
+s = quantile(A, p);
+m = as.matrix(s);
+
+writeMM(as(m, "CsparseMatrix"), paste(args[3], "R", sep=""));
+
+
diff --git a/src/test/scripts/functions/binary/matrix/QuantileBug.dml
b/src/test/scripts/functions/binary/matrix/QuantileBug.dml
new file mode 100644
index 0000000000..761f8c8c64
--- /dev/null
+++ b/src/test/scripts/functions/binary/matrix/QuantileBug.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = as.matrix(list(1,5,7,10));
+s = quantile(A, $2);
+m = as.matrix(s);
+write(m, $3, format="text");