This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1d1886958d [SYSTEMDS-3631] Fix Integer casting bug in as.frame
1d1886958d is described below
commit 1d1886958d6ec7214ae99fd3776add59fd5eea7c
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Tue Oct 17 01:06:48 2023 +0200
[SYSTEMDS-3631] Fix Integer casting bug in as.frame
This commit fixes the bug by tightening the isType detection so that
values above Integer.MAX_VALUE are not cast to int and values above
Long.MAX_VALUE are not cast to long. A few tests are also added to
verify the behavior in some edge cases.
Closes #1923
---
.../frame/data/lib/FrameFromMatrixBlock.java | 21 ++++---
.../sysds/runtime/frame/data/lib/FrameUtil.java | 22 +++-----
.../test/component/frame/FrameCustomTest.java | 65 ++++++++++++++++++++++
3 files changed, 88 insertions(+), 20 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java
index 3164724f4f..7e75c621a5 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.frame.data.lib;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
@@ -60,7 +61,6 @@ public class FrameFromMatrixBlock {
this.mb = mb;
m = mb.getNumRows();
n = mb.getNumColumns();
-
this.schema = schema == null ? getSchema(mb) : schema;
this.frame = new FrameBlock(this.schema);
this.k = k;
@@ -82,14 +82,21 @@ public class FrameFromMatrixBlock {
return new FrameFromMatrixBlock(mb, schema, k).apply();
}
- private ValueType[] getSchema(MatrixBlock mb) {
+ private static ValueType[] getSchema(MatrixBlock mb) {
final int nCol = mb.getNumColumns();
final int nRow = mb.getNumRows();
- ValueType[] schema = UtilFunctions.nCopies(nCol,
ValueType.BOOLEAN);
- for(int r = 0; r < nRow; r++)
- for(int c = 0; c < nCol; c++)
- schema[c] =
FrameUtil.isType(mb.quickGetValue(r, c), schema[c]);
-
+ // default boolean if possible.
+ final ValueType[] schema = UtilFunctions.nCopies(nCol,
ValueType.BOOLEAN);
+ for(int c = 0; c < nCol; c++){
+ for(int r = 0; r < nRow; r++){
+ switch(schema[c]){
+ case FP64:
+ break;
+ default:
+ schema[c] =
FrameUtil.isType(mb.quickGetValue(r, c), schema[c]);
+ }
+ }
+ }
return schema;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
index d9b1f739ba..c2d53a650b 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
@@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.BooleanArray;
@@ -193,32 +194,27 @@ public interface FrameUtil {
public static ValueType isType(double val) {
if(val == 1.0d || val == 0.0d)
return ValueType.BOOLEAN;
- else if((long) (val) == val) {
- if((int) val == val)
- return ValueType.INT32;
- else
+ else if(val < Integer.MAX_VALUE && Util.eq((int) val,val))
+ return ValueType.INT32;
+ else if(val < Long.MAX_VALUE && Util.eq((long) val, val))
return ValueType.INT64;
- }
else if(same(val, (float) val))
return ValueType.FP32;
else
return ValueType.FP64;
-
}
public static ValueType isType(double val, ValueType min) {
switch(min) {
case BOOLEAN:
return isType(val);
- case INT32:
case UINT8:
+ case INT32:
+ if(val < Integer.MAX_VALUE && Util.eq((int)
val,val))
+ return ValueType.INT32;
case INT64:
- if((long) (val) == val) {
- if((int) val == val)
- return ValueType.INT32;
- else
- return ValueType.INT64;
- }
+ if(val < Long.MAX_VALUE && Util.eq((long) val,
val))
+ return ValueType.INT64;
case FP32:
if(same(val, (float) val))
return ValueType.FP32;
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
new file mode 100644
index 0000000000..f5fbfd8588
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame;
+
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class FrameCustomTest {
+
+ @Test
+ public void castToFrame() {
+ double maxp1 = ((double) Integer.MAX_VALUE) + 1.0;
+ MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100,
maxp1, maxp1, 1.0, 23);
+ FrameBlock f = DataConverter.convertToFrameBlock(mb);
+ assertTrue(f.getSchema()[0] == ValueType.INT64);
+ }
+
+ @Test
+ public void castToFrame3() {
+ double maxp1 = ((double) Integer.MAX_VALUE) - 1.0;
+ MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100,
maxp1, maxp1, 1.0, 23);
+ FrameBlock f = DataConverter.convertToFrameBlock(mb);
+ assertTrue(f.getSchema()[0] == ValueType.INT32);
+ }
+
+ @Test
+ public void castErrorValue() {
+ MatrixBlock mb = new MatrixBlock(10, 10,
Double.parseDouble("2.572306572E9"));
+ FrameBlock f = DataConverter.convertToFrameBlock(mb);
+ assertTrue(f.getSchema()[0] == ValueType.INT64);
+
+ }
+
+ @Test
+ public void castToFrame2() {
+ double maxp1 = ((double) Integer.MAX_VALUE) + 1.1111;
+ MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100,
maxp1, maxp1, 1.0, 23);
+ FrameBlock f = DataConverter.convertToFrameBlock(mb);
+ assertTrue(f.getSchema()[0] == ValueType.FP64);
+ }
+
+}