This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new d4ba502  [MINOR] Fix null check in EncoderMVImpute for Global_Mean; Also:
d4ba502 is described below

commit d4ba5028635ff2011b22828ad1e372d1811fab95
Author: Lukas Erlbacher <[email protected]>
AuthorDate: Mon Mar 1 18:46:28 2021 +0100

    [MINOR] Fix null check in EncoderMVImpute for Global_Mean; Also:
    
    * Adapted test so NaN is checked
    * removed unnecessary import
    * adapted count for mean calculation
    
    Closes #1190
---
 .../runtime/transform/encode/EncoderMVImpute.java  | 24 +++++++++++++---------
 src/test/java/org/apache/sysds/test/TestUtils.java |  7 +++++++
 .../transform/TransformFrameEncodeApplyTest.java   | 10 +++++----
 3 files changed, 27 insertions(+), 14 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index 16377ba..2effddb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -48,8 +48,7 @@ import org.apache.wink.json4j.JSONArray;
 import org.apache.wink.json4j.JSONException;
 import org.apache.wink.json4j.JSONObject;
 
-public class EncoderMVImpute extends Encoder
-{
+public class EncoderMVImpute extends Encoder {
        private static final long serialVersionUID = 9057868620144662194L;
 
        public enum MVMethod { INVALID, GLOBAL_MEAN, GLOBAL_MODE, CONSTANT }
@@ -69,8 +68,7 @@ public class EncoderMVImpute extends Encoder
        public KahanObject[] getMeans()   { return _meanList; }
        
        public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int 
clen, int minCol, int maxCol)
-               throws JSONException
-       {
+                       throws JSONException {
                super(null, clen);
                
                //handle column list
@@ -166,9 +164,15 @@ public class EncoderMVImpute extends Encoder
                                if( _mvMethodList[j] == MVMethod.GLOBAL_MEAN ) {
                                        //compute global column mean (scale)
                                        long off = _countList[j];
-                                       for( int i=0; i<in.getNumRows(); i++ )
+                                       for( int i=0; i<in.getNumRows(); i++ ){
+                                               Object key = in.get(i, colID-1);
+                                               if(key == null){
+                                                       off--;
+                                                       continue;
+                                               }
                                                _meanFn.execute2(_meanList[j], 
UtilFunctions.objectToDouble(
-                                                       
in.getSchema()[colID-1], in.get(i, colID-1)), off+i+1);
+                                                               
in.getSchema()[colID-1], key), off+i+1);
+                                       }
                                        _replacementList[j] = 
String.valueOf(_meanList[j]._sum);
                                        _countList[j] += in.getNumRows();
                                }
@@ -241,8 +245,8 @@ public class EncoderMVImpute extends Encoder
        }
 
        private static void fillListsFromMap(Map<Integer, ColInfo> map, int[] 
colList, MVMethod[] mvMethodList,
-               String[] replacementList, KahanObject[] meanList, long[] 
countList,
-               HashMap<Integer, HashMap<String, Long>> hist) {
+                       String[] replacementList, KahanObject[] meanList, 
long[] countList,
+                       HashMap<Integer, HashMap<String, Long>> hist) {
                int i = 0;
                for(Entry<Integer, ColInfo> entry : map.entrySet()) {
                        colList[i] = entry.getKey();
@@ -267,7 +271,7 @@ public class EncoderMVImpute extends Encoder
                        for(int i = 0; i < other._colList.length; i++) {
                                int column = other._colList[i] + (col - 1);
                                ColInfo otherColInfo = new 
ColInfo(otherImpute._mvMethodList[i], otherImpute._replacementList[i],
-                                       otherImpute._meanList[i], 
otherImpute._countList[i], otherImpute._hist.get(i + 1));
+                                               otherImpute._meanList[i], 
otherImpute._countList[i], otherImpute._hist.get(i + 1));
                                ColInfo colInfo = map.get(column);
                                if(colInfo == null)
                                        map.put(column, otherColInfo);
@@ -432,7 +436,7 @@ public class EncoderMVImpute extends Encoder
                public void merge(ColInfo otherColInfo) {
                        if(_method != otherColInfo._method)
                                throw new DMLRuntimeException("Tried to merge 
two different impute methods: " + _method.name() + " vs. "
-                                       + otherColInfo._method.name());
+                                               + otherColInfo._method.name());
                        switch(_method) {
                                case CONSTANT:
                                        assert 
_replacement.equals(otherColInfo._replacement);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index 0d541e7..b8c6304 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -3056,4 +3056,11 @@ public class TestUtils
 
                return y;
        }
+
+       public static boolean containsNan(double[][] data, int col) {
+               for (double[] datum : data)
+                       if (Double.isNaN(datum[col]))
+                               return true;
+               return false;
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
index 18f5fd7..65b9fa8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
@@ -31,8 +31,7 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Statistics;
 
-public class TransformFrameEncodeApplyTest extends AutomatedTestBase 
-{
+public class TransformFrameEncodeApplyTest extends AutomatedTestBase {
        private final static String TEST_NAME1 = "TransformFrameEncodeApply";
        private final static String TEST_DIR = "functions/transform/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeApplyTest.class.getSimpleName() + "/";
@@ -357,8 +356,7 @@ public class TransformFrameEncodeApplyTest extends 
AutomatedTestBase
                runTransformTest(ExecMode.HYBRID, "csv", 
TransformType.HASH_RECODE, false);
        }
        
-       private void runTransformTest( ExecMode rt, String ofmt, TransformType 
type, boolean colnames )
-       {
+       private void runTransformTest( ExecMode rt, String ofmt, TransformType 
type, boolean colnames ) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                if( rtplatform == ExecMode.SPARK || rtplatform == 
ExecMode.HYBRID)
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
@@ -431,6 +429,10 @@ public class TransformFrameEncodeApplyTest extends 
AutomatedTestBase
                                                        1:0, R1[i][10+j], 1e-8);
                                        }
                                }
+                       } else if (type == TransformType.IMPUTE){
+                               // Column 8 had GLOBAL_MEAN applied
+                               Assert.assertFalse(TestUtils.containsNan(R1, 
8));
+                               Assert.assertFalse(TestUtils.containsNan(R2, 
8));
                        }
                }
                catch(Exception ex) {

Reply via email to