This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 47b1b904cb2db2b3975ed98d9c2d4318821fabbf
Author: Sebastian Baunsgaard <baunsga...@apache.org>
AuthorDate: Thu Aug 15 12:40:35 2024 +0200

    [SYSTEMDS-3548] Optimize IO Path, Extended Tests

    Add missing test cases for various data types, and negative tests.

    Closes #2065
---
 .../sysds/runtime/util/Py4jConverterUtils.java     |  12 +--
 .../frame/array/Py4jConverterUtilsTest.java        | 103 ++++++++++++++++++++-
 2 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java b/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java
index abc9abb4fd..be6a749fb7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java
@@ -130,11 +130,6 @@ public class Py4jConverterUtils {
 
 		// Process the data based on the value type
 		switch(valueType) {
-			case UINT4:
-				for(int i = 0; i < numElements; i++) {
-					array.set(i, (int) (buffer.get() & 0xFF));
-				}
-				break;
 			case UINT8:
 				for(int i = 0; i < numElements; i++) {
 					array.set(i, (int) (buffer.get() & 0xFF));
@@ -177,7 +172,12 @@ public class Py4jConverterUtils {
 				break;
 			case CHARACTER:
 				for(int i = 0; i < numElements; i++) {
-					array.set(i, (char) buffer.get());
+					array.set(i, buffer.getChar());
+				}
+				break;
+			case HASH32:
+				for(int i = 0; i < numElements; i++) {
+					array.set(i, buffer.getInt());
 				}
 				break;
 			case HASH64:
diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java b/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
index 965be8d71a..980165c3ab 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
@@ -27,7 +27,7 @@ import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
 
 import org.apache.sysds.common.Types;
-
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.util.Py4jConverterUtils;
 import org.apache.sysds.runtime.frame.data.columns.Array;
 import org.junit.Test;
@@ -64,6 +64,75 @@ public class Py4jConverterUtilsTest {
 		assertEquals(4, result.get(3));
 	}
 
+	@Test
+	public void testConvertINT64() {
+		int numElements = 4;
+		ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES * numElements);
+		buffer.order(ByteOrder.nativeOrder());
+		for(int i = 1; i <= numElements; i++) {
+			buffer.putLong((long) i);
+		}
+		Array<?> result = Py4jConverterUtils.convert(buffer.array(), numElements, Types.ValueType.INT64);
+		assertNotNull(result);
+		assertEquals(4, result.size());
+		assertEquals(1L, result.get(0));
+		assertEquals(2L, result.get(1));
+		assertEquals(3L, result.get(2));
+		assertEquals(4L, result.get(3));
+	}
+
+
+	@Test
+	public void testConvertHASH32() {
+		int numElements = 4;
+		ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES * numElements);
+		buffer.order(ByteOrder.nativeOrder());
+		for(int i = 1; i <= numElements; i++) {
+			buffer.putInt(i);
+		}
+		Array<?> result = Py4jConverterUtils.convert(buffer.array(), numElements, Types.ValueType.HASH32);
+		assertNotNull(result);
+		assertEquals(4, result.size());
+		assertEquals("1", result.get(0));
+		assertEquals("2", result.get(1));
+		assertEquals("3", result.get(2));
+		assertEquals("4", result.get(3));
+	}
+
+	@Test
+	public void testConvertHASH64() {
+		int numElements = 4;
+		ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES * numElements);
+		buffer.order(ByteOrder.nativeOrder());
+		for(int i = 1; i <= numElements; i++) {
+			buffer.putLong((long) i);
+		}
+		Array<?> result = Py4jConverterUtils.convert(buffer.array(), numElements, Types.ValueType.HASH64);
+		assertNotNull(result);
+		assertEquals(4, result.size());
+		assertEquals("1", result.get(0));
+		assertEquals("2", result.get(1));
+		assertEquals("3", result.get(2));
+		assertEquals("4", result.get(3));
+	}
+
+	@Test
+	public void testConvertFP32() {
+		int numElements = 4;
+		ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * numElements);
+		buffer.order(ByteOrder.nativeOrder());
+		for(float i = 1.1f; i <= numElements + 1; i += 1.0) {
+			buffer.putFloat(i);
+		}
+		Array<?> result = Py4jConverterUtils.convert(buffer.array(), numElements, Types.ValueType.FP32);
+		assertNotNull(result);
+		assertEquals(4, result.size());
+		assertEquals(1.1f, result.get(0));
+		assertEquals(2.1f, result.get(1));
+		assertEquals(3.1f, result.get(2));
+		assertEquals(4.1f, result.get(3));
+	}
+
 	@Test
 	public void testConvertFP64() {
 		int numElements = 4;
@@ -112,4 +181,36 @@ public class Py4jConverterUtilsTest {
 		assertEquals("hello", result.get(0));
 		assertEquals("world", result.get(1));
 	}
+
+	@Test
+	public void testConvertChar() {
+		char[] c = {'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'};
+		ByteBuffer buffer = ByteBuffer.allocate(Character.BYTES * c.length);
+		buffer.order(ByteOrder.LITTLE_ENDIAN);
+		for(char s : c) {
+			buffer.putChar(s);
+		}
+		Array<?> result = Py4jConverterUtils.convert(buffer.array(), c.length, Types.ValueType.CHARACTER);
+		assertNotNull(result);
+		assertEquals(c.length, result.size());
+
+		for(int i = 0; i < c.length; i++) {
+			assertEquals(c[i], result.get(i));
+		}
+	}
+
+	@Test(expected = Exception.class)
+	public void nullData() {
+		Py4jConverterUtils.convert(null, 14, ValueType.BOOLEAN);
+	}
+
+	@Test(expected = Exception.class)
+	public void nullValueType() {
+		Py4jConverterUtils.convert(new byte[] {1, 2, 3}, 14, null);
+	}
+
+	@Test(expected = Exception.class)
+	public void unknownValueType() {
+		Py4jConverterUtils.convert(new byte[] {1, 2, 3}, 14, ValueType.UNKNOWN);
+	}
 }