This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 021467d59266e4f7a0e97c7cc0040a52d550e86d Author: baunsgaard <[email protected]> AuthorDate: Sat Jan 21 16:20:16 2023 +0100 [MINOR] 100% FrameIterator Tests This commit adds full iterator tests, which found 1 null pointer exception. --- .../frame/data/iterators/IteratorFactory.java | 48 +++- .../frame/data/iterators/ObjectRowIterator.java | 8 +- .../runtime/frame/data/iterators/RowIterator.java | 6 +- .../frame/data/iterators/StringRowIterator.java | 4 +- .../component/frame/iterators/IteratorTest.java | 268 +++++++++++++++++++++ 5 files changed, 314 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java index 4560b1736e..a9bf87e681 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.frame.data.iterators; -import java.util.Iterator; - import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -35,7 +33,7 @@ public interface IteratorFactory { * @param fb The frame to iterate through * @return string array iterator */ - public static Iterator<String[]> getStringRowIterator(FrameBlock fb) { + public static RowIterator<String> getStringRowIterator(FrameBlock fb) { return new StringRowIterator(fb, 0, fb.getNumRows()); } @@ -47,7 +45,7 @@ public interface IteratorFactory { * @param cols column selection, 1-based * @return string array iterator */ - public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int[] cols) { + public static RowIterator<String> getStringRowIterator(FrameBlock fb, int[] cols) { return new StringRowIterator(fb, 0, fb.getNumRows(), cols); } @@ -59,7 +57,7 @@ public interface IteratorFactory { * @param colID column selection, 1-based * @return string array 
iterator */ - public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int colID) { + public static RowIterator<String> getStringRowIterator(FrameBlock fb, int colID) { return new StringRowIterator(fb, 0, fb.getNumRows(), new int[] {colID}); } @@ -71,7 +69,7 @@ public interface IteratorFactory { * @param ru upper row index * @return string array iterator */ - public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int rl, int ru) { + public static RowIterator<String> getStringRowIterator(FrameBlock fb, int rl, int ru) { return new StringRowIterator(fb, rl, ru); } @@ -85,7 +83,7 @@ public interface IteratorFactory { * @param cols column selection, 1-based * @return string array iterator */ - public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + public static RowIterator<String> getStringRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { return new StringRowIterator(fb, rl, ru, cols); } @@ -99,7 +97,7 @@ public interface IteratorFactory { * @param colID columnID, 1-based * @return string array iterator */ - public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int rl, int ru, int colID) { + public static RowIterator<String> getStringRowIterator(FrameBlock fb, int rl, int ru, int colID) { return new StringRowIterator(fb, rl, ru, new int[] {colID}); } @@ -109,7 +107,7 @@ public interface IteratorFactory { * @param fb The frame to iterate through * @return object array iterator */ - public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb) { + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb) { return new ObjectRowIterator(fb, 0, fb.getNumRows()); } @@ -121,7 +119,7 @@ public interface IteratorFactory { * @param schema target schema of objects * @return object array iterator */ - public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, ValueType[] schema) { + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb, 
ValueType[] schema) { return new ObjectRowIterator(fb, 0, fb.getNumRows(), schema); } @@ -133,10 +131,21 @@ public interface IteratorFactory { * @param cols column selection, 1-based * @return object array iterator */ - public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, int[] cols) { + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb, int[] cols) { return new ObjectRowIterator(fb, 0, fb.getNumRows(), cols); } + /** + * Get a row iterator over the frame where all selected fields are encoded as objects according to their value types. + * + * @param fb The frame to iterate through + * @param colID column selection, 1-based + * @return object array iterator + */ + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb, int colID) { + return new ObjectRowIterator(fb, 0, fb.getNumRows(), new int[] {colID}); + } + /** * Get a row iterator over the frame where all fields are encoded as boxed objects according to their value types. * @@ -145,7 +154,7 @@ public interface IteratorFactory { * @param ru upper row index * @return object array iterator */ - public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, int rl, int ru) { + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb, int rl, int ru) { return new ObjectRowIterator(fb, rl, ru); } @@ -159,8 +168,21 @@ public interface IteratorFactory { * @param cols column selection, 1-based * @return object array iterator */ - public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { return new ObjectRowIterator(fb, rl, ru, cols); } + /** + * Get a row iterator over the frame where all selected fields are encoded as boxed objects according to their value + * types. 
+ * + * @param fb The frame to iterate through + * @param rl lower row index + * @param ru upper row index + * @param colID column selection, 1-based + * @return object array iterator + */ + public static RowIterator<Object> getObjectRowIterator(FrameBlock fb, int rl, int ru, int colID) { + return new ObjectRowIterator(fb, rl, ru, new int[] {colID}); + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java index 584a6a3173..d88b0c8781 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java @@ -26,19 +26,19 @@ import org.apache.sysds.runtime.util.UtilFunctions; public class ObjectRowIterator extends RowIterator<Object> { private final ValueType[] _tgtSchema; - public ObjectRowIterator(FrameBlock fb, int rl, int ru) { + protected ObjectRowIterator(FrameBlock fb, int rl, int ru) { this(fb, rl, ru, UtilFunctions.getSeqArray(1, fb.getNumColumns(), 1), null); } - public ObjectRowIterator(FrameBlock fb, int rl, int ru, ValueType[] schema) { + protected ObjectRowIterator(FrameBlock fb, int rl, int ru, ValueType[] schema) { this(fb, rl, ru, UtilFunctions.getSeqArray(1, fb.getNumColumns(), 1), schema); } - public ObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + protected ObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { this(fb, rl, ru, cols, null); } - public ObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols, ValueType[] schema){ + protected ObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols, ValueType[] schema){ super(fb, rl, ru, cols); _tgtSchema = schema; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java index 68266fd9fa..8aea65bace 100644 --- 
a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java @@ -23,6 +23,7 @@ import java.util.Iterator; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -41,6 +42,9 @@ public abstract class RowIterator<T> implements Iterator<T[]> { } protected RowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + if(rl < 0 || ru > fb.getNumRows() || rl > ru) + throw new DMLRuntimeException("Invalid range of iterator: " + rl + "->" + ru); + _fb = fb; _curRow = createRow(cols.length); _cols = cols; @@ -55,7 +59,7 @@ public abstract class RowIterator<T> implements Iterator<T[]> { @Override public void remove() { - throw new RuntimeException("RowIterator.remove is unsupported!"); + throw new DMLRuntimeException("RowIterator.remove() is unsupported!"); } protected abstract T[] createRow(int size); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java index 3647ce5106..605471f270 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java @@ -22,11 +22,11 @@ package org.apache.sysds.runtime.frame.data.iterators; import org.apache.sysds.runtime.frame.data.FrameBlock; public class StringRowIterator extends RowIterator<String> { - public StringRowIterator(FrameBlock fb, int rl, int ru) { + protected StringRowIterator(FrameBlock fb, int rl, int ru) { super(fb, rl, ru); } - public StringRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + protected StringRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { super(fb, rl, ru, cols); } 
diff --git a/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java b/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java new file mode 100644 index 0000000000..c6f5bfd621 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.frame.iterators; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; +import org.apache.sysds.runtime.frame.data.iterators.RowIterator; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class IteratorTest { + + private final FrameBlock fb1 = TestUtils.generateRandomFrameBlock(10, 10, 23); + private final FrameBlock fb2 = TestUtils.generateRandomFrameBlock(40, 30, 22); + + @Test + public void StringObjectStringFB1() { + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb1); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb1); + compareIterators(a, b); + } + + @Test + public void StringObjectStringFB2() { + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb2); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2); + compareIterators(a, b); + } + + @Test + public void StringObjectStringNotEquals() { + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb1); + a.next(); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb1); + assertNotEquals(Arrays.toString(a.next()), Arrays.toString(b.next())); + } + + @Test + public void StringObjectStringNotEqualsFB1vsFB2() { + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb1); + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb2); + assertNotEquals(Arrays.toString(a.next()), Arrays.toString(b.next())); + } + + @Test + public void compareSubRangesFB1() { + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb1, 1, fb1.getNumRows()); + 
RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb1); + b.next(); + compareIterators(a, b); + } + + @Test + public void compareSubRangesFB2() { + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb2, 1, fb2.getNumRows()); + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb2); + b.next(); + compareIterators(a, b); + } + + @Test + public void compareSubRangesStringFB1() { + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb1, 1, fb1.getNumRows()); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb1); + b.next(); + compareIterators(a, b); + } + + @Test + public void compareSubRangesStringFB2() { + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2, 1, fb2.getNumRows()); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2); + b.next(); + compareIterators(a, b); + } + + @Test + public void iteratorObjectSelectColumns() { + FrameBlock fb1Slice = fb1.slice(0, fb1.getNumRows() - 1, 1, fb1.getNumColumns() - 1); + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb1Slice); + int[] select = new int[fb1.getNumColumns() - 1]; + for(int i = 0; i < fb1.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb1, select); + compareIterators(a, b); + } + + @Test + public void iteratorObjectSelectColumnsFB2() { + FrameBlock fb2Slice = fb2.slice(0, fb2.getNumRows() - 1, 1, fb2.getNumColumns() - 1); + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb2Slice); + int[] select = new int[fb2.getNumColumns() - 1]; + for(int i = 0; i < fb2.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb2, select); + compareIterators(a, b); + } + + @Test + public void iteratorStringSelectColumns() { + FrameBlock fb1Slice = fb1.slice(0, fb1.getNumRows() - 1, 1, fb1.getNumColumns() - 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb1Slice); + 
int[] select = new int[fb1.getNumColumns() - 1]; + for(int i = 0; i < fb1.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb1, select); + compareIterators(a, b); + } + + @Test + public void iteratorStringSelectColumnsFB2() { + FrameBlock fb2Slice = fb2.slice(0, fb2.getNumRows() - 1, 1, fb2.getNumColumns() - 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2Slice); + int[] select = new int[fb2.getNumColumns() - 1]; + for(int i = 0; i < fb2.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2, select); + compareIterators(a, b); + } + + @Test + public void iteratorStringSelectColumnsSubRowsFB2() { + FrameBlock fb2Slice = fb2.slice(1, fb2.getNumRows() - 1, 1, fb2.getNumColumns() - 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2Slice); + int[] select = new int[fb2.getNumColumns() - 1]; + for(int i = 0; i < fb2.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2, 1, fb2.getNumRows(), select); + compareIterators(a, b); + } + + @Test + public void iteratorObjectSelectColumnsSubRowsFB2() { + FrameBlock fb2Slice = fb2.slice(1, fb2.getNumRows() - 1, 1, fb2.getNumColumns() - 1); + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb2Slice); + int[] select = new int[fb2.getNumColumns() - 1]; + for(int i = 0; i < fb2.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb2, 1, fb2.getNumRows(), select); + compareIterators(a, b); + } + + @Test + public void iteratorStringSelectSingleColumnSubRowsFB2() { + FrameBlock fb2Slice = fb2.slice(1, fb2.getNumRows() - 1, 1, 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2Slice); + int[] select = new int[fb2.getNumColumns() - 1]; + for(int i = 0; i < fb2.getNumColumns() - 1; i++) { + select[i] = i + 2; + 
} + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2, 1, fb2.getNumRows(), 2); + compareIterators(a, b); + } + + @Test + public void iteratorObjectSelectSingleColumnSubRowsFB2() { + FrameBlock fb2Slice = fb2.slice(1, fb2.getNumRows() - 1, 1, 1); + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb2Slice); + int[] select = new int[fb2.getNumColumns() - 1]; + for(int i = 0; i < fb2.getNumColumns() - 1; i++) { + select[i] = i + 2; + } + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb2, 1, fb2.getNumRows(), 2); + compareIterators(a, b); + } + + @Test + public void iteratorColumnIdFB1() { + FrameBlock fb1Slice = fb1.slice(0, fb1.getNumRows() - 1, 1, 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb1Slice); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb1, 2); + compareIterators(a, b); + } + + @Test + public void iteratorColumnId() { + FrameBlock fb2Slice = fb2.slice(0, fb2.getNumRows() - 1, 1, 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2Slice); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2, 2); + compareIterators(a, b); + } + + @Test + public void iteratorColumnIdObjectFB1() { + FrameBlock fb1Slice = fb1.slice(0, fb1.getNumRows() - 1, 1, 1); + RowIterator<Object> a = IteratorFactory.getObjectRowIterator(fb1Slice); + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb1, 2); + compareIterators(a, b); + } + + @Test + public void iteratorColumnObjectId() { + FrameBlock fb2Slice = fb2.slice(0, fb2.getNumRows() - 1, 1, 1); + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2Slice); + RowIterator<String> b = IteratorFactory.getStringRowIterator(fb2, 2); + compareIterators(a, b); + } + + @Test + public void iteratorWithSchema() { + RowIterator<String> a = IteratorFactory.getStringRowIterator(fb2); + RowIterator<Object> b = IteratorFactory.getObjectRowIterator(fb2, // + UtilFunctions.nCopies(fb2.getNumColumns(), 
ValueType.STRING)); + compareIterators(a, b); + } + + + @Test(expected= DMLRuntimeException.class) + public void invalidRange1(){ + IteratorFactory.getStringRowIterator(fb2, -1, 1); + } + + @Test(expected= DMLRuntimeException.class) + public void invalidRange2(){ + IteratorFactory.getStringRowIterator(fb2, 132415, 132416); + } + + @Test(expected= DMLRuntimeException.class) + public void invalidRange3(){ + IteratorFactory.getStringRowIterator(fb2, 13, 4); + } + + @Test(expected= DMLRuntimeException.class) + public void remove(){ + RowIterator<?> a =IteratorFactory.getStringRowIterator(fb2, 0, 4); + a.remove(); + } + + + private static void compareIterators(RowIterator<?> a, RowIterator<?> b) { + while(a.hasNext()) { + assertTrue(b.hasNext()); + assertEquals(Arrays.toString(a.next()), Arrays.toString(b.next())); + } + } +}
