This is an automated email from the ASF dual-hosted git repository.
markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new d4ba502 [MINOR] Fix null check in EncoderMVImpute for Global_Mean; Also:
d4ba502 is described below
commit d4ba5028635ff2011b22828ad1e372d1811fab95
Author: Lukas Erlbacher <[email protected]>
AuthorDate: Mon Mar 1 18:46:28 2021 +0100
[MINOR] Fix null check in EncoderMVImpute for Global_Mean; Also:
* Adapted test so NaN is checked
* removed unnecessary import
* adapted count for mean calculation
Closes #1190
---
.../runtime/transform/encode/EncoderMVImpute.java | 24 +++++++++++++---------
src/test/java/org/apache/sysds/test/TestUtils.java | 7 +++++++
.../transform/TransformFrameEncodeApplyTest.java | 10 +++++----
3 files changed, 27 insertions(+), 14 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index 16377ba..2effddb 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -48,8 +48,7 @@ import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
-public class EncoderMVImpute extends Encoder
-{
+public class EncoderMVImpute extends Encoder {
private static final long serialVersionUID = 9057868620144662194L;
public enum MVMethod { INVALID, GLOBAL_MEAN, GLOBAL_MODE, CONSTANT }
@@ -69,8 +68,7 @@ public class EncoderMVImpute extends Encoder
public KahanObject[] getMeans() { return _meanList; }
public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int
clen, int minCol, int maxCol)
- throws JSONException
- {
+ throws JSONException {
super(null, clen);
//handle column list
@@ -166,9 +164,15 @@ public class EncoderMVImpute extends Encoder
if( _mvMethodList[j] == MVMethod.GLOBAL_MEAN ) {
//compute global column mean (scale)
long off = _countList[j];
- for( int i=0; i<in.getNumRows(); i++ )
+ for( int i=0; i<in.getNumRows(); i++ ){
+ Object key = in.get(i, colID-1);
+ if(key == null){
+ off--;
+ continue;
+ }
_meanFn.execute2(_meanList[j],
UtilFunctions.objectToDouble(
-
in.getSchema()[colID-1], in.get(i, colID-1)), off+i+1);
+
in.getSchema()[colID-1], key), off+i+1);
+ }
_replacementList[j] =
String.valueOf(_meanList[j]._sum);
_countList[j] += in.getNumRows();
}
@@ -241,8 +245,8 @@ public class EncoderMVImpute extends Encoder
}
private static void fillListsFromMap(Map<Integer, ColInfo> map, int[]
colList, MVMethod[] mvMethodList,
- String[] replacementList, KahanObject[] meanList, long[]
countList,
- HashMap<Integer, HashMap<String, Long>> hist) {
+ String[] replacementList, KahanObject[] meanList,
long[] countList,
+ HashMap<Integer, HashMap<String, Long>> hist) {
int i = 0;
for(Entry<Integer, ColInfo> entry : map.entrySet()) {
colList[i] = entry.getKey();
@@ -267,7 +271,7 @@ public class EncoderMVImpute extends Encoder
for(int i = 0; i < other._colList.length; i++) {
int column = other._colList[i] + (col - 1);
ColInfo otherColInfo = new
ColInfo(otherImpute._mvMethodList[i], otherImpute._replacementList[i],
- otherImpute._meanList[i],
otherImpute._countList[i], otherImpute._hist.get(i + 1));
+ otherImpute._meanList[i],
otherImpute._countList[i], otherImpute._hist.get(i + 1));
ColInfo colInfo = map.get(column);
if(colInfo == null)
map.put(column, otherColInfo);
@@ -432,7 +436,7 @@ public class EncoderMVImpute extends Encoder
public void merge(ColInfo otherColInfo) {
if(_method != otherColInfo._method)
throw new DMLRuntimeException("Tried to merge
two different impute methods: " + _method.name() + " vs. "
- + otherColInfo._method.name());
+ + otherColInfo._method.name());
switch(_method) {
case CONSTANT:
assert
_replacement.equals(otherColInfo._replacement);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 0d541e7..b8c6304 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -3056,4 +3056,11 @@ public class TestUtils
return y;
}
+
+ public static boolean containsNan(double[][] data, int col) {
+ for (double[] datum : data)
+ if (Double.isNaN(datum[col]))
+ return true;
+ return false;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
index 18f5fd7..65b9fa8 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeApplyTest.java
@@ -31,8 +31,7 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
-public class TransformFrameEncodeApplyTest extends AutomatedTestBase
-{
+public class TransformFrameEncodeApplyTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "TransformFrameEncodeApply";
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR +
TransformFrameEncodeApplyTest.class.getSimpleName() + "/";
@@ -357,8 +356,7 @@ public class TransformFrameEncodeApplyTest extends
AutomatedTestBase
runTransformTest(ExecMode.HYBRID, "csv",
TransformType.HASH_RECODE, false);
}
- private void runTransformTest( ExecMode rt, String ofmt, TransformType
type, boolean colnames )
- {
+ private void runTransformTest( ExecMode rt, String ofmt, TransformType
type, boolean colnames ) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
if( rtplatform == ExecMode.SPARK || rtplatform ==
ExecMode.HYBRID)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
@@ -431,6 +429,10 @@ public class TransformFrameEncodeApplyTest extends
AutomatedTestBase
1:0, R1[i][10+j], 1e-8);
}
}
+ } else if (type == TransformType.IMPUTE){
+ // Column 8 had GLOBAL_MEAN applied
+ Assert.assertFalse(TestUtils.containsNan(R1,
8));
+ Assert.assertFalse(TestUtils.containsNan(R2,
8));
}
}
catch(Exception ex) {