This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a11e0e2 [SYSTEMDS-3162] Fix transformapply binning w/ out-of-range values
a11e0e2 is described below
commit a11e0e27b5f926baa5741adbb6fff9416f59eb78
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 14 16:58:54 2021 +0200
[SYSTEMDS-3162] Fix transformapply binning w/ out-of-range values
Until now, transformapply with binning encoders did not correctly handle
values outside the min/max range over which the bins were built (values < min
were assigned to bin 1, values > max caused out-of-bounds exceptions). We now
handle these cases by returning NaN for such values, which can then be
post-processed at the script level as needed.
---
.../runtime/transform/encode/ColumnEncoderBin.java | 26 +++--
.../TransformFederatedEncodeApplyTest.java | 6 +-
.../transform/TransformApplyUnknownsTest.java | 105 +++++++++++++++++++++
.../TransformFrameBuildMultithreadedTest.java | 2 +-
.../TransformFrameEncodeMultithreadedTest.java | 2 +-
.../sysds/test/util/DependencyThreadPoolTest.java | 2 +-
6 files changed, 122 insertions(+), 21 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 5736c1e..8439b4e 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -47,7 +47,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
// frame transform-apply attributes
// a) column bin boundaries
- // TODO binMins is redundant and could be removed - necessary for
correct fed results
private double[] _binMins = null;
private double[] _binMaxs = null;
// b) column min/max (for partial build)
@@ -125,8 +124,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
return new BinMergePartialBuildTask(this, ret);
}
-
-
public void computeBins(double min, double max) {
// ensure allocated internal transformation metadata
if(_binMins == null || _binMaxs == null) {
@@ -159,8 +156,7 @@ public class ColumnEncoderBin extends ColumnEncoder {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
double inVal =
UtilFunctions.objectToDouble(in.getSchema()[_colID - 1], in.get(i, _colID - 1));
- int ix = Arrays.binarySearch(_binMaxs, inVal);
- int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
+ double binID = applyValue(inVal);
out.quickSetValue(i, outputCol, binID);
}
if (DMLScript.STATISTICS)
@@ -174,14 +170,21 @@ public class ColumnEncoderBin extends ColumnEncoder {
int end = getEndIndex(in.getNumRows(), rowStart, blk);
for(int i = rowStart; i < end; i++) {
double inVal = in.quickGetValueThreadSafe(i, _colID -
1);
- int ix = Arrays.binarySearch(_binMaxs, inVal);
- int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
+ double binID = applyValue(inVal);
out.quickSetValue(i, outputCol, binID);
}
if (DMLScript.STATISTICS)
Statistics.incTransformBinningApplyTime(System.nanoTime()-t0);
return out;
}
+
+	/**
+	 * Maps a single input value to its 1-based bin ID.
+	 * Values outside [_binMins[0], _binMaxs[last]] cannot be assigned to any bin
+	 * created during build and yield NaN; NaN inputs also yield NaN.
+	 *
+	 * @param inVal input value
+	 * @return 1-based bin ID, or NaN if the value is out of range or NaN
+	 */
+	private double applyValue(double inVal) {
+		//negated in-range check: all comparisons with NaN are false, so NaN
+		//inputs are rejected here instead of falling through to binarySearch
+		if( !(inVal >= _binMins[0] && inVal <= _binMaxs[_binMaxs.length-1]) )
+			return Double.NaN; //value outside min/max range (or NaN)
+		int ix = Arrays.binarySearch(_binMaxs, inVal);
+		//exact match -> ix; otherwise ix = -(insertionPoint+1), where the
+		//insertion point is the first bin whose max exceeds inVal
+		int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
+		return binID;
+	}
@Override
protected ColumnApplyTask<? extends ColumnEncoder>
@@ -301,8 +304,7 @@ public class ColumnEncoderBin extends ColumnEncoder {
for(int r = _startRow; r <
getEndIndex(_inputF.getNumRows(), _startRow, _blk); r++) {
SparseRowVector row = (SparseRowVector)
_out.getSparseBlock().get(r);
double inVal =
UtilFunctions.objectToDouble(_inputF.getSchema()[index], _inputF.get(r, index));
- int ix = Arrays.binarySearch(_encoder._binMaxs,
inVal);
- int binID = ((ix < 0) ? Math.abs(ix + 1) : ix)
+ 1;
+ double binID = _encoder.applyValue(inVal);
row.values()[index] = binID;
row.indexes()[index] = _outputCol;
}
@@ -315,7 +317,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
public String toString() {
return getClass().getSimpleName() + "<ColId: " +
_encoder._colID + ">";
}
-
}
private static class BinPartialBuildTask implements Callable<Object> {
@@ -352,7 +353,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
public String toString() {
return getClass().getSimpleName() + "<Start row: " +
_startRow + "; Block size: " + _blockSize + ">";
}
-
}
private static class BinMergePartialBuildTask implements
Callable<Object> {
@@ -384,11 +384,9 @@ public class ColumnEncoderBin extends ColumnEncoder {
public String toString() {
return getClass().getSimpleName() + "<ColId: " +
_encoder._colID + ">";
}
-
}
private static class ColumnBinBuildTask implements Callable<Object> {
-
private final ColumnEncoderBin _encoder;
private final FrameBlock _input;
@@ -407,7 +405,5 @@ public class ColumnEncoderBin extends ColumnEncoder {
public String toString() {
return getClass().getSimpleName() + "<ColId: " +
_encoder._colID + ">";
}
-
}
-
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index 1bbe444..f1fa7e6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -214,9 +214,9 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
String[] otherargs = lineage ? new String[]
{"-lineage", "reuse_full"} : null;
- t1 = startLocalFedWorkerThread(port1, otherargs,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2, otherargs,
FED_WORKER_WAIT_S);
- t3 = startLocalFedWorkerThread(port3, otherargs,
FED_WORKER_WAIT_S);
+ t1 = startLocalFedWorkerThread(port1, otherargs);
+ t2 = startLocalFedWorkerThread(port2, otherargs);
+ t3 = startLocalFedWorkerThread(port3, otherargs);
t4 = startLocalFedWorkerThread(port4, otherargs);
FileFormatPropertiesCSV ffpCSV = new
FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java
new file mode 100644
index 0000000..c0967af
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+
+public class TransformApplyUnknownsTest extends AutomatedTestBase
+{
+ private static final int rows = 70;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ }
+
+ @Test
+ public void testTransformApplyRecode() {
+ try {
+ //generate input data
+ FrameBlock data =
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(1, rows, 1));
+ FrameBlock data2 =
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(1, rows+10, 1));
+
+ //encode and obtain meta data
+ String spec = "{ids:true, recode:[1]}";
+ MultiColumnEncoder encoder =
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, null);
+ encoder.build(data);
+ FrameBlock meta = encoder.getMetaData(new FrameBlock(1,
ValueType.STRING));
+
+ //apply
+ MultiColumnEncoder encoder2 =
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, meta);
+ MatrixBlock out = encoder2.apply(data2);
+
+ //check outputs
+ Assert.assertEquals(out.getNumRows(),
data2.getNumRows());
+ Assert.assertEquals(out.getNumColumns(),
data2.getNumColumns());
+ for(int i=1; i<=rows; i++)
+ Assert.assertEquals((double)i,
out.quickGetValue(i-1, 0), 1e-8);
+ for(int i=rows+1; i<=rows+10; i++)
+
Assert.assertTrue(Double.isNaN(out.quickGetValue(i-1, 0)));
+ }
+ catch (DMLRuntimeException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Test
+ public void testTransformApplyBinning() {
+ try {
+ //generate input data
+ FrameBlock data =
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(1, rows, 1));
+ FrameBlock data2 =
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(-5, rows+5, 1));
+
+ //encode and obtain meta data
+ String spec = "{ids:true, bin:[{id:1,
method:equi-width, numbins:7}] }";
+ MultiColumnEncoder encoder =
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, null);
+ encoder.build(data);
+ FrameBlock meta = encoder.getMetaData(new FrameBlock(1,
ValueType.STRING));
+
+ //apply
+ MultiColumnEncoder encoder2 =
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, meta);
+ MatrixBlock out = encoder2.apply(data2);
+
+ //check outputs
+ Assert.assertEquals(out.getNumRows(),
data2.getNumRows());
+ Assert.assertEquals(out.getNumColumns(),
data2.getNumColumns());
+ for(int i=-5; i<=rows+5; i++) {
+ if( i < 1 | i > rows )
+
Assert.assertTrue(Double.isNaN(out.quickGetValue(i+5, 0)));
+ else
+
Assert.assertEquals((double)((i-1)/10+1), out.quickGetValue(i+5, 0), 1e-8);
+ }
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
similarity index 99%
rename from src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
rename to src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
index b70571b..0c235d6 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.sysds.test.functions.transform.mt;
+package org.apache.sysds.test.functions.transform;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
similarity index 99%
rename from src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
rename to src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
index 6679f36..4c707c0 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.sysds.test.functions.transform.mt;
+package org.apache.sysds.test.functions.transform;
import java.nio.file.Files;
import java.nio.file.Paths;
diff --git a/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java b/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java
index eb5e244..61dc4f4 100644
--- a/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java
+++ b/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java
@@ -31,7 +31,7 @@ import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import
org.apache.sysds.test.functions.transform.mt.TransformFrameBuildMultithreadedTest;
+import
org.apache.sysds.test.functions.transform.TransformFrameBuildMultithreadedTest;
import org.junit.Assert;
import org.junit.Test;