This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new a11e0e2  [SYSTEMDS-3162] Fix transformapply binning w/ out-of-range values
a11e0e2 is described below

commit a11e0e27b5f926baa5741adbb6fff9416f59eb78
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 14 16:58:54 2021 +0200

    [SYSTEMDS-3162] Fix transformapply binning w/ out-of-range values
    
    So far, transformapply of binning encoders did not correctly handle
    values outside the min/max range for which the bins were created
    (values < min were mapped to bucket 1, and values > max caused
    out-of-bounds exceptions). We now handle these cases by returning
    NaN for such values, which can be post-processed at script level
    as needed.
---
 .../runtime/transform/encode/ColumnEncoderBin.java |  26 +++--
 .../TransformFederatedEncodeApplyTest.java         |   6 +-
 .../transform/TransformApplyUnknownsTest.java      | 105 +++++++++++++++++++++
 .../TransformFrameBuildMultithreadedTest.java      |   2 +-
 .../TransformFrameEncodeMultithreadedTest.java     |   2 +-
 .../sysds/test/util/DependencyThreadPoolTest.java  |   2 +-
 6 files changed, 122 insertions(+), 21 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 5736c1e..8439b4e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -47,7 +47,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
 
        // frame transform-apply attributes
        // a) column bin boundaries
-       // TODO binMins is redundant and could be removed - necessary for 
correct fed results
        private double[] _binMins = null;
        private double[] _binMaxs = null;
        // b) column min/max (for partial build)
@@ -125,8 +124,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
                return new BinMergePartialBuildTask(this, ret);
        }
 
-
-
        public void computeBins(double min, double max) {
                // ensure allocated internal transformation metadata
                if(_binMins == null || _binMaxs == null) {
@@ -159,8 +156,7 @@ public class ColumnEncoderBin extends ColumnEncoder {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                for(int i = rowStart; i < getEndIndex(in.getNumRows(), 
rowStart, blk); i++) {
                        double inVal = 
UtilFunctions.objectToDouble(in.getSchema()[_colID - 1], in.get(i, _colID - 1));
-                       int ix = Arrays.binarySearch(_binMaxs, inVal);
-                       int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
+                       double binID = applyValue(inVal);
                        out.quickSetValue(i, outputCol, binID);
                }
                if (DMLScript.STATISTICS)
@@ -174,14 +170,21 @@ public class ColumnEncoderBin extends ColumnEncoder {
                int end = getEndIndex(in.getNumRows(), rowStart, blk);
                for(int i = rowStart; i < end; i++) {
                        double inVal = in.quickGetValueThreadSafe(i, _colID - 
1);
-                       int ix = Arrays.binarySearch(_binMaxs, inVal);
-                       int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
+                       double binID = applyValue(inVal);
                        out.quickSetValue(i, outputCol, binID);
                }
                if (DMLScript.STATISTICS)
                        
Statistics.incTransformBinningApplyTime(System.nanoTime()-t0);
                return out;
        }
+       
+       private double applyValue(double inVal) {
+               if( inVal < _binMins[0] | inVal > _binMaxs[_binMaxs.length-1] )
+                       return Double.NaN; //value outside min/max range
+               int ix = Arrays.binarySearch(_binMaxs, inVal);
+               int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
+               return binID;
+       }
 
        @Override
        protected ColumnApplyTask<? extends ColumnEncoder> 
@@ -301,8 +304,7 @@ public class ColumnEncoderBin extends ColumnEncoder {
                        for(int r = _startRow; r < 
getEndIndex(_inputF.getNumRows(), _startRow, _blk); r++) {
                                SparseRowVector row = (SparseRowVector) 
_out.getSparseBlock().get(r);
                                double inVal = 
UtilFunctions.objectToDouble(_inputF.getSchema()[index], _inputF.get(r, index));
-                               int ix = Arrays.binarySearch(_encoder._binMaxs, 
inVal);
-                               int binID = ((ix < 0) ? Math.abs(ix + 1) : ix) 
+ 1;
+                               double binID = _encoder.applyValue(inVal);
                                row.values()[index] = binID;
                                row.indexes()[index] = _outputCol;
                        }
@@ -315,7 +317,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
                public String toString() {
                        return getClass().getSimpleName() + "<ColId: " + 
_encoder._colID + ">";
                }
-
        }
 
        private static class BinPartialBuildTask implements Callable<Object> {
@@ -352,7 +353,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
                public String toString() {
                        return getClass().getSimpleName() + "<Start row: " + 
_startRow + "; Block size: " + _blockSize + ">";
                }
-
        }
 
        private static class BinMergePartialBuildTask implements 
Callable<Object> {
@@ -384,11 +384,9 @@ public class ColumnEncoderBin extends ColumnEncoder {
                public String toString() {
                        return getClass().getSimpleName() + "<ColId: " + 
_encoder._colID + ">";
                }
-
        }
 
        private static class ColumnBinBuildTask implements Callable<Object> {
-
                private final ColumnEncoderBin _encoder;
                private final FrameBlock _input;
 
@@ -407,7 +405,5 @@ public class ColumnEncoderBin extends ColumnEncoder {
                public String toString() {
                        return getClass().getSimpleName() + "<ColId: " + 
_encoder._colID + ">";
                }
-
        }
-
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index 1bbe444..f1fa7e6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -214,9 +214,9 @@ public class TransformFederatedEncodeApplyTest extends 
AutomatedTestBase {
                        int port3 = getRandomAvailablePort();
                        int port4 = getRandomAvailablePort();
                        String[] otherargs = lineage ? new String[] 
{"-lineage", "reuse_full"} : null;
-                       t1 = startLocalFedWorkerThread(port1, otherargs, 
FED_WORKER_WAIT_S);
-                       t2 = startLocalFedWorkerThread(port2, otherargs, 
FED_WORKER_WAIT_S);
-                       t3 = startLocalFedWorkerThread(port3, otherargs, 
FED_WORKER_WAIT_S);
+                       t1 = startLocalFedWorkerThread(port1, otherargs);
+                       t2 = startLocalFedWorkerThread(port2, otherargs);
+                       t3 = startLocalFedWorkerThread(port3, otherargs);
                        t4 = startLocalFedWorkerThread(port4, otherargs);
 
                        FileFormatPropertiesCSV ffpCSV = new 
FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java
new file mode 100644
index 0000000..c0967af
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+
+public class TransformApplyUnknownsTest extends AutomatedTestBase 
+{
+       private static final int rows = 70;
+       
+       @Override
+       public void setUp()  {
+               TestUtils.clearAssertionInformation();
+       }
+       
+       @Test
+       public void testTransformApplyRecode() {
+               try {
+                       //generate input data
+                       FrameBlock data = 
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(1, rows, 1));
+                       FrameBlock data2 = 
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(1, rows+10, 1));
+                       
+                       //encode and obtain meta data
+                       String spec = "{ids:true, recode:[1]}";
+                       MultiColumnEncoder encoder = 
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, null);
+                       encoder.build(data);
+                       FrameBlock meta = encoder.getMetaData(new FrameBlock(1, 
ValueType.STRING));
+                       
+                       //apply
+                       MultiColumnEncoder encoder2 = 
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, meta);
+                       MatrixBlock out = encoder2.apply(data2);
+                       
+                       //check outputs
+                       Assert.assertEquals(out.getNumRows(), 
data2.getNumRows());
+                       Assert.assertEquals(out.getNumColumns(), 
data2.getNumColumns());
+                       for(int i=1; i<=rows; i++)
+                               Assert.assertEquals((double)i, 
out.quickGetValue(i-1, 0), 1e-8);
+                       for(int i=rows+1; i<=rows+10; i++)
+                               
Assert.assertTrue(Double.isNaN(out.quickGetValue(i-1, 0)));
+               } 
+               catch (DMLRuntimeException e) {
+                       throw new RuntimeException(e);
+               }
+       }
+       
+       @Test
+       public void testTransformApplyBinning() {
+               try {
+                       //generate input data
+                       FrameBlock data = 
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(1, rows, 1));
+                       FrameBlock data2 = 
DataConverter.convertToFrameBlock(MatrixBlock.seqOperations(-5, rows+5, 1));
+                       
+                       //encode and obtain meta data
+                       String spec = "{ids:true, bin:[{id:1, 
method:equi-width, numbins:7}] }";
+                       MultiColumnEncoder encoder = 
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, null);
+                       encoder.build(data);
+                       FrameBlock meta = encoder.getMetaData(new FrameBlock(1, 
ValueType.STRING));
+                       
+                       //apply
+                       MultiColumnEncoder encoder2 = 
EncoderFactory.createEncoder(spec, data.getColumnNames(), 1, meta);
+                       MatrixBlock out = encoder2.apply(data2);
+                       
+                       //check outputs
+                       Assert.assertEquals(out.getNumRows(), 
data2.getNumRows());
+                       Assert.assertEquals(out.getNumColumns(), 
data2.getNumColumns());
+                       for(int i=-5; i<=rows+5; i++) {
+                               if( i < 1 | i > rows )
+                                       
Assert.assertTrue(Double.isNaN(out.quickGetValue(i+5, 0)));
+                               else
+                                       
Assert.assertEquals((double)((i-1)/10+1), out.quickGetValue(i+5, 0), 1e-8);
+                       }
+               }
+               catch (Exception e) {
+                       e.printStackTrace();
+                       throw new RuntimeException(e);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
similarity index 99%
rename from 
src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
rename to 
src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
index b70571b..0c235d6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameBuildMultithreadedTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.transform.mt;
+package org.apache.sysds.test.functions.transform;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
similarity index 99%
rename from 
src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
rename to 
src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
index 6679f36..4c707c0 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/mt/TransformFrameEncodeMultithreadedTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.transform.mt;
+package org.apache.sysds.test.functions.transform;
 
 import java.nio.file.Files;
 import java.nio.file.Paths;
diff --git 
a/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java 
b/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java
index eb5e244..61dc4f4 100644
--- a/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java
+++ b/src/test/java/org/apache/sysds/test/util/DependencyThreadPoolTest.java
@@ -31,7 +31,7 @@ import org.apache.sysds.runtime.util.DependencyThreadPool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import 
org.apache.sysds.test.functions.transform.mt.TransformFrameBuildMultithreadedTest;
+import 
org.apache.sysds.test.functions.transform.TransformFrameBuildMultithreadedTest;
 import org.junit.Assert;
 import org.junit.Test;
 

Reply via email to