This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 47b1b904cb2db2b3975ed98d9c2d4318821fabbf
Author: Sebastian Baunsgaard <baunsga...@apache.org>
AuthorDate: Thu Aug 15 12:40:35 2024 +0200

    [SYSTEMDS-3548] Optimize IO Path, Extended Tests
    
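    Add missing test cases for the INT64, FP32, HASH32, HASH64 and
    CHARACTER value types, plus negative tests for null input data, a
    null value type and an unknown value type.
    
    In Py4jConverterUtils, remove the UINT4 case, decode CHARACTER
    values as two-byte chars via ByteBuffer.getChar(), and add HASH32
    support.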
    
    Closes #2065
---
 .../sysds/runtime/util/Py4jConverterUtils.java     |  12 +--
 .../frame/array/Py4jConverterUtilsTest.java        | 103 ++++++++++++++++++++-
 2 files changed, 108 insertions(+), 7 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java 
b/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java
index abc9abb4fd..be6a749fb7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/util/Py4jConverterUtils.java
@@ -130,11 +130,6 @@ public class Py4jConverterUtils {
 
                // Process the data based on the value type
                switch(valueType) {
-                       case UINT4:
-                               for(int i = 0; i < numElements; i++) {
-                                       array.set(i, (int) (buffer.get() & 
0xFF));
-                               }
-                               break;
                        case UINT8:
                                for(int i = 0; i < numElements; i++) {
                                        array.set(i, (int) (buffer.get() & 
0xFF));
@@ -177,7 +172,12 @@ public class Py4jConverterUtils {
                                break;
                        case CHARACTER:
                                for(int i = 0; i < numElements; i++) {
-                                       array.set(i, (char) buffer.get());
+                                       array.set(i, buffer.getChar());
+                               }
+                               break;
+                       case HASH32:
+                               for(int i = 0; i < numElements; i++) {
+                                       array.set(i, buffer.getInt());
                                }
                                break;
                        case HASH64:
diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
 
b/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
index 965be8d71a..980165c3ab 100644
--- 
a/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/frame/array/Py4jConverterUtilsTest.java
@@ -27,7 +27,7 @@ import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
 
 import org.apache.sysds.common.Types;
-
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.util.Py4jConverterUtils;
 import org.apache.sysds.runtime.frame.data.columns.Array;
 import org.junit.Test;
@@ -64,6 +64,75 @@ public class Py4jConverterUtilsTest {
                assertEquals(4, result.get(3));
        }
 
+       @Test
+       public void testConvertINT64() {
+               int numElements = 4;
+               ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES * 
numElements);
+               buffer.order(ByteOrder.nativeOrder());
+               for(int i = 1; i <= numElements; i++) {
+                       buffer.putLong((long) i);
+               }
+               Array<?> result = Py4jConverterUtils.convert(buffer.array(), 
numElements, Types.ValueType.INT64);
+               assertNotNull(result);
+               assertEquals(4, result.size());
+               assertEquals(1L, result.get(0));
+               assertEquals(2L, result.get(1));
+               assertEquals(3L, result.get(2));
+               assertEquals(4L, result.get(3));
+       }
+
+
+       @Test
+       public void testConvertHASH32() {
+               int numElements = 4;
+               ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES * 
numElements);
+               buffer.order(ByteOrder.nativeOrder());
+               for(int i = 1; i <= numElements; i++) {
+                       buffer.putInt(i);
+               }
+               Array<?> result = Py4jConverterUtils.convert(buffer.array(), 
numElements, Types.ValueType.HASH32);
+               assertNotNull(result);
+               assertEquals(4, result.size());
+               assertEquals("1", result.get(0));
+               assertEquals("2", result.get(1));
+               assertEquals("3", result.get(2));
+               assertEquals("4", result.get(3));
+       }
+
+       @Test
+       public void testConvertHASH64() {
+               int numElements = 4;
+               ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES * 
numElements);
+               buffer.order(ByteOrder.nativeOrder());
+               for(int i = 1; i <= numElements; i++) {
+                       buffer.putLong((long) i);
+               }
+               Array<?> result = Py4jConverterUtils.convert(buffer.array(), 
numElements, Types.ValueType.HASH64);
+               assertNotNull(result);
+               assertEquals(4, result.size());
+               assertEquals("1", result.get(0));
+               assertEquals("2", result.get(1));
+               assertEquals("3", result.get(2));
+               assertEquals("4", result.get(3));
+       }
+
+       @Test
+       public void testConvertFP32() {
+               int numElements = 4;
+               ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * 
numElements);
+               buffer.order(ByteOrder.nativeOrder());
+               for(float i = 1.1f; i <= numElements + 1; i += 1.0) {
+                       buffer.putFloat(i);
+               }
+               Array<?> result = Py4jConverterUtils.convert(buffer.array(), 
numElements, Types.ValueType.FP32);
+               assertNotNull(result);
+               assertEquals(4, result.size());
+               assertEquals(1.1f, result.get(0));
+               assertEquals(2.1f, result.get(1));
+               assertEquals(3.1f, result.get(2));
+               assertEquals(4.1f, result.get(3));
+       }
+
        @Test
        public void testConvertFP64() {
                int numElements = 4;
@@ -112,4 +181,36 @@ public class Py4jConverterUtilsTest {
                assertEquals("hello", result.get(0));
                assertEquals("world", result.get(1));
        }
+
+       @Test
+       public void testConvertChar() {
+               char[] c = {'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 
'd'};
+               ByteBuffer buffer = ByteBuffer.allocate(Character.BYTES * 
c.length);
+               buffer.order(ByteOrder.LITTLE_ENDIAN);
+               for(char s : c) {
+                       buffer.putChar(s);
+               }
+               Array<?> result = Py4jConverterUtils.convert(buffer.array(), 
c.length, Types.ValueType.CHARACTER);
+               assertNotNull(result);
+               assertEquals(c.length, result.size());
+
+               for(int i = 0; i < c.length; i++) {
+                       assertEquals(c[i], result.get(i));
+               }
+       }
+
+       @Test(expected = Exception.class)
+       public void nullData() {
+               Py4jConverterUtils.convert(null, 14, ValueType.BOOLEAN);
+       }
+
+       @Test(expected = Exception.class)
+       public void nullValueType() {
+               Py4jConverterUtils.convert(new byte[] {1, 2, 3}, 14, null);
+       }
+
+       @Test(expected = Exception.class)
+       public void unknownValueType() {
+               Py4jConverterUtils.convert(new byte[] {1, 2, 3}, 14, 
ValueType.UNKNOWN);
+       }
 }

Reply via email to