This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 3b4f6cdf86 [MINOR] Frame tests improvement 2 3b4f6cdf86 is described below commit 3b4f6cdf86feb89eb83ad30529473f02aa39b9d4 Author: Sebastian Baunsgaard <baunsga...@apache.org> AuthorDate: Sat Sep 28 00:08:42 2024 +0200 [MINOR] Frame tests improvement 2 Add tests 100% test coverage for Frame/data/lib Closes #2120 --- .../runtime/frame/data/lib/FrameLibAppend.java | 10 +- .../frame/data/lib/FrameLibDetectSchema.java | 21 +-- .../sysds/runtime/frame/data/lib/FrameUtil.java | 39 +++-- .../frame/data/lib/MatrixBlockFromFrame.java | 114 ++++++++++----- .../sysds/runtime/matrix/data/MatrixBlock.java | 3 + .../test/component/frame/FrameCustomTest.java | 36 +++++ .../sysds/test/component/frame/FrameTest.java | 41 +++++- .../sysds/test/component/frame/FrameUtilTest.java | 148 +++++++++++++++---- .../test/component/frame/MatrixFromFrameTest.java | 162 +++++++++++++++++++++ 9 files changed, 474 insertions(+), 100 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java index 78177b1d2f..60c61e8f4a 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java @@ -33,8 +33,12 @@ import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; public class FrameLibAppend { - protected static final Log LOG = LogFactory.getLog(FrameLibAppend.class.getName()); + + private FrameLibAppend(){ + // private constructor. + } + /** * Appends the given argument FrameBlock 'that' to this FrameBlock by creating a deep copy to prevent side effects. * For cbind, the frames are appended column-wise (same number of rows), while for rbind the frames are appended @@ -50,7 +54,7 @@ public class FrameLibAppend { return ret; } - public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { + private static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { final int nRow = a.getNumRows(); final int nRowB = b.getNumRows(); @@ -73,7 +77,7 @@ public class FrameLibAppend { return new FrameBlock(_schema, _colnames, _colmeta, _coldata); } - public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { + private static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { final int nCol = a.getNumColumns(); final int nColB = b.getNumColumns(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java index 889fc853f9..2e8e2ba106 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java @@ -22,7 +22,6 @@ package org.apache.sysds.runtime.frame.data.lib; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -67,11 +66,16 @@ public final class FrameLibDetectSchema { } private FrameBlock apply() { - final int cols = in.getNumColumns(); - final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING)); - String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply(); - fb.appendRow(schemaInfo); - return fb; + try{ + final int cols = in.getNumColumns(); + final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING)); + String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply(); + fb.appendRow(schemaInfo); + return fb; + } + catch(Exception e){ + throw new DMLRuntimeException("Failed to detect schema", e); + } } private String[] singleThreadApply() { @@ -84,7 +88,7 @@ public final class FrameLibDetectSchema { return schemaInfo; } - private String[] parallelApply() { + private String[] parallelApply() throws Exception { final ExecutorService pool = CommonThreadPool.get(k); try { final int cols = in.getNumColumns(); @@ -99,9 +103,6 @@ public final class FrameLibDetectSchema { return schemaInfo; } - catch(ExecutionException | InterruptedException e) { - throw new DMLRuntimeException("Exception interrupted or exception thrown in detectSchema", e); - } finally{ pool.shutdown(); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java index e4ccd06c02..17315cf53f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java @@ -290,33 +290,30 @@ public interface FrameUtil { } public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { - String[] rowTemp1 = IteratorFactory.getStringRowIterator(temp1).next(); - String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); + final int nCol = temp1.getNumColumns(); - if(rowTemp1.length != rowTemp2.length) - throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); + if(nCol != temp2.getNumColumns()) + throw new DMLRuntimeException("Schema dimension mismatch: " + nCol + " vs " + temp2.getNumColumns()); - for(int i = 0; i < rowTemp1.length; i++) { + // hack reuse input temp1 schema, only valid if temp1 never change schema. + // However, this is typically valid. + FrameBlock mergedFrame = new FrameBlock(temp1.getSchema()); + mergedFrame.ensureAllocatedColumns(1); + for(int i = 0; i < nCol; i++) { + String s1 = (String) temp1.get(0, i); + String s2 = (String) temp2.get(0, i); // modify schema1 if necessary (different schema2) - if(!rowTemp1[i].equals(rowTemp2[i])) { - if(rowTemp1[i].equals("STRING") || rowTemp2[i].equals("STRING")) - rowTemp1[i] = "STRING"; - else if(rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64")) - rowTemp1[i] = "FP64"; - else if(rowTemp1[i].equals("FP32") && - new ArrayList<>(Arrays.asList("INT64", "INT32", "CHARACTER")).contains(rowTemp2[i])) - rowTemp1[i] = "FP32"; - else if(rowTemp1[i].equals("INT64") && - new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i])) - rowTemp1[i] = "INT64"; - else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER")) - rowTemp1[i] = "INT32"; + if(!s1.equals(s2)) { + ValueType v1 = ValueType.valueOf(s1); + ValueType v2 = ValueType.valueOf(s2); + ValueType vc = ValueType.getHighestCommonTypeSafe(v1, v2); + mergedFrame.set(0, i, vc.toString()); + } + else{ + mergedFrame.set(0, i, s1); } } - // create output block one row representing the schema as strings - FrameBlock mergedFrame = new FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING)); - mergedFrame.appendRow(rowTemp1); return mergedFrame; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java index 8f0625409c..032afe2cd7 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java @@ -25,6 +25,7 @@ import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -40,28 +41,56 @@ public interface MatrixBlockFromFrame { * Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. * - * @param frame frame block - * @param k parallelization degree - * @return matrix block + * @param frame Frame block to convert + * @param k The parallelization degree + * @return MatrixBlock */ public static MatrixBlock convertToMatrixBlock(FrameBlock frame, int k) { - final int m = frame.getNumRows(); - final int n = frame.getNumColumns(); - final MatrixBlock mb = new MatrixBlock(m, n, false); - mb.allocateDenseBlock(); - if(k == -1) - k = InfrastructureAnalyzer.getLocalParallelism(); - - long nnz = 0; - if(k == 1) - nnz = convert(frame, mb, n, 0, m); - else - nnz = convertParallel(frame, mb, m, n, k); + return convertToMatrixBlock(frame, null, k); + } - mb.setNonZeros(nnz); + /** + * Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type + * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. + * + * @param frame FrameBlock to convert + * @param ret The returned MatrixBlock + * @param k The parallelization degree + * @return MatrixBlock + */ + public static MatrixBlock convertToMatrixBlock(FrameBlock frame, MatrixBlock ret, int k) { + try { - mb.examSparsity(); - return mb; + final int m = frame.getNumRows(); + final int n = frame.getNumColumns(); + ret = allocateRet(ret, m, n); + + if(k == -1) + k = InfrastructureAnalyzer.getLocalParallelism(); + + long nnz = 0; + if(k == 1) + nnz = convert(frame, ret, n, 0, m); + else + nnz = convertParallel(frame, ret, m, n, k); + + ret.setNonZeros(nnz); + ret.examSparsity(); + return ret; + } + catch(Exception e) { + throw new DMLRuntimeException("Failed to convert FrameBlock to MatrixBlock", e); + } + } + + private static MatrixBlock allocateRet(MatrixBlock ret, final int m, final int n) { + if(ret == null) + ret = new MatrixBlock(m, n, false); + else if(ret.getNumRows() != m || ret.getNumColumns() != n || ret.isInSparseFormat()) + ret.reset(m, n, false); + if(!ret.isAllocated()) + ret.allocateDenseBlock(); + return ret; } private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { @@ -71,27 +100,25 @@ public interface MatrixBlockFromFrame { return convertGeneric(frame, mb, n, rl, ru); } - private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k){ + private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception { ExecutorService pool = CommonThreadPool.get(k); - try{ + try { List<Future<Long>> tasks = new ArrayList<>(); final int blkz = Math.max(m / k, 1000); - for( int i = 0; i < m; i+= blkz){ - final int start = i; + for(int i = 0; i < m; i += blkz) { + final int start = i; final int end = Math.min(i + blkz, m); tasks.add(pool.submit(() -> convert(frame, mb, n, start, end))); } long nnz = 0; - for( Future<Long> t : tasks) + for(Future<Long> t : tasks) nnz += t.get(); return nnz; } - catch(Exception e){ - throw new RuntimeException(e); - } - finally{ + + finally { pool.shutdown(); } } @@ -104,29 +131,42 @@ public interface MatrixBlockFromFrame { for(int bj = 0; bj < n; bj += blocksizeIJ) { int bimin = Math.min(bi + blocksizeIJ, ru); int bjmin = Math.min(bj + blocksizeIJ, n); - for(int i = bi, aix = bi * n; i < bimin; i++, aix += n) - for(int j = bj; j < bjmin; j++) - lnnz += (c[aix + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + lnnz = convertBlockContiguous(frame, n, lnnz, c, bi, bj, bimin, bjmin); } } return lnnz; } - private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, final int ru) { + private static long convertBlockContiguous(final FrameBlock frame, final int n, long lnnz, double[] c, int rl, + int cl, int ru, int cu) { + for(int i = rl, aix = rl * n; i < ru; i++, aix += n) + for(int j = cl; j < cu; j++) + lnnz += (c[aix + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + return lnnz; + } + + private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, + final int ru) { long lnnz = 0; final DenseBlock c = mb.getDenseBlock(); for(int bi = rl; bi < ru; bi += blocksizeIJ) { for(int bj = 0; bj < n; bj += blocksizeIJ) { int bimin = Math.min(bi + blocksizeIJ, ru); int bjmin = Math.min(bj + blocksizeIJ, n); - for(int i = bi; i < bimin; i++) { - double[] cvals = c.values(i); - int cpos = c.pos(i); - for(int j = bj; j < bjmin; j++) - lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; - } + lnnz = convertBlockGeneric(frame, lnnz, c, bi, bj, bimin, bjmin); } } return lnnz; } + + private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final DenseBlock c, final int rl, + final int cl, final int ru, final int cu) { + for(int i = rl; i < ru; i++) { + final double[] cvals = c.values(i); + final int cpos = c.pos(i); + for(int j = cl; j < cu; j++) + lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + } + return lnnz; + } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index ead979a44b..66d616a738 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -332,6 +332,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, if(sparseBlock == null) return; sparseBlock.reset(estimatedNNzsPerRow, clen); + denseBlock = null; } private void resetDense(double val) { @@ -343,6 +344,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, allocateDenseBlock(false); denseBlock.set(val); } + sparseBlock = null; } private void resetDense(double val, boolean dedup) { @@ -354,6 +356,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, allocateDenseBlock(false, dedup); denseBlock.set(val); } + sparseBlock = null; } /** diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java index 5e482a5369..047b2da3b2 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java @@ -19,16 +19,26 @@ package org.apache.sysds.test.component.frame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend; +import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.TestUtils; import org.junit.Test; public class FrameCustomTest { + protected static final Log LOG = LogFactory.getLog(FrameCustomTest.class.getName()); @Test public void castToFrame() { @@ -61,4 +71,30 @@ public class FrameCustomTest { assertTrue(f.getSchema()[0] == ValueType.FP64); } + + @Test + public void detectSchemaError(){ + FrameBlock f = TestUtils.generateRandomFrameBlock(10, 10, 23); + FrameBlock spy = spy(f); + when(spy.getColumn(anyInt())).thenThrow(new RuntimeException()); + + Exception e = assertThrows(DMLRuntimeException.class, () -> FrameLibDetectSchema.detectSchema(spy, 3)); + + assertTrue(e.getMessage().contains("Failed to detect schema")); + } + + + + @Test + public void appendUniqueColNames(){ + FrameBlock a = new FrameBlock(new ValueType[]{ValueType.FP32}, new String[]{"Hi"}); + a.appendRow(new String[]{"0.2"}); + FrameBlock b = new FrameBlock(new ValueType[]{ValueType.FP32}, new String[]{"There"}); + b.appendRow(new String[]{"0.5"}); + + FrameBlock c = FrameLibAppend.append(a, b, true); + + assertTrue(c.getColumnName(0).equals("Hi")); + assertTrue(c.getColumnName(1).equals("There")); + } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java index bf84898344..310d497a6f 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java @@ -90,16 +90,16 @@ public class FrameTest { @AfterClass public static void cleanup() { - try{ + try { IOCompressionTestUtils.deleteDirectory(new File(nameBeginning)); } - catch(Exception e){ + catch(Exception e) { e.printStackTrace(); LOG.error("failed to delete", e); } } - + @Test public void appendSelfRBind() { FrameBlock ff = append(f, f, false); @@ -173,6 +173,34 @@ public class FrameTest { f.append(b, true); } + @Test(expected = DMLRuntimeException.class) + public void cBindEmptyCols() { + // must have same number of rows. + FrameBlock b = new FrameBlock(); + b.append(f, false); + } + + @Test(expected = DMLRuntimeException.class) + public void cBindEmptyAfterCols() { + // must have same number of rows. + FrameBlock b = new FrameBlock(); + f.append(b, false); + } + + @Test + public void cBindEmptyR() { + // must have same number of rows. + FrameBlock b = new FrameBlock(new ValueType[0], f.getNumRows()); + b.append(f, true); + } + + @Test + public void cBindEmptyAfterR() { + // must have same number of rows. + FrameBlock b = new FrameBlock(new ValueType[0], f.getNumRows()); + f.append(b, true); + } + @Test public void cBindStringColAfter() { // must have same number of rows. @@ -308,6 +336,13 @@ public class FrameTest { TestUtils.compareFrames(f, fs, true); } + @Test + public void testApplyApproxSchema() { + final FrameBlock schema = FrameLibDetectSchema.detectSchema(f, 0.1, 1); + final FrameBlock fs = FrameLibApplySchema.applySchema(f.copyShallow(), schema, 1); + TestUtils.compareFrames(f, fs, true); + } + protected static void writeAndRead(FrameBlock fb) { try { diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java index 713fc73db0..035a8b3e6d 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java @@ -283,6 +283,104 @@ public class FrameUtilTest { assertEquals(ValueType.FP64, FrameUtil.isType(33.231425155253)); } + @Test + public void mergeSchema1() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"STRING"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"FP64"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("STRING")); + } + + @Test + public void mergeSchema2() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"FP32"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"FP64"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("FP64")); + } + + @Test + public void mergeSchema3() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"INT32"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"FP64"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("FP64")); + } + + @Test + public void mergeSchema4() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"INT32"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"INT64"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("INT64")); + } + + @Test + public void mergeSchema5() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"INT32"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"STRING"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("STRING")); + } + + @Test + public void mergeSchema6() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"BOOLEAN"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"INT32"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("INT32")); + } + + @Test + public void mergeSchema7() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"BOOLEAN"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"UINT8"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("UINT8")); + } + + @Test + public void mergeSchema8() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING}); + a.appendRow(new String[] {"BOOLEAN"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"BOOLEAN"}); + + FrameBlock c = FrameUtil.mergeSchema(a, b); + assertTrue(c.get(0, 0).equals("BOOLEAN")); + } + + @Test(expected = Exception.class) + public void mergeSchemaInvalid() { + FrameBlock a = new FrameBlock(new ValueType[] {ValueType.STRING, ValueType.STRING}); + a.appendRow(new String[] {"BOOLEAN", "BOOLEAN"}); + FrameBlock b = new FrameBlock(new ValueType[] {ValueType.STRING}); + b.appendRow(new String[] {"BOOLEAN"}); + + FrameUtil.mergeSchema(a, b); + } + @Test public void testSparkFrameBlockALignment() { ValueType[] schema = new ValueType[0]; @@ -466,35 +564,33 @@ public class FrameUtilTest { assertTrue(ValueType.FP64 == FrameUtil.isType(2.2231342152323232, ValueType.FP64)); } - - @Test - public void isDefault(){ + @Test + public void isDefault() { assertTrue(FrameUtil.isDefault(null, null)); assertTrue(FrameUtil.isDefault("false", ValueType.BOOLEAN)); assertTrue(FrameUtil.isDefault("f", ValueType.BOOLEAN)); assertTrue(FrameUtil.isDefault("0", ValueType.BOOLEAN)); - assertTrue(FrameUtil.isDefault("" + (char)(0), ValueType.CHARACTER)); - assertTrue(FrameUtil.isDefault("0.0" , ValueType.FP32)); - assertTrue(FrameUtil.isDefault("0" , ValueType.FP32)); - assertTrue(FrameUtil.isDefault("0.0" , ValueType.FP64)); - assertTrue(FrameUtil.isDefault("0" , ValueType.FP64)); - assertTrue(FrameUtil.isDefault("0.0" , ValueType.INT32)); - assertTrue(FrameUtil.isDefault("0" , ValueType.INT32)); - assertTrue(FrameUtil.isDefault("0.0" , ValueType.INT64)); - assertTrue(FrameUtil.isDefault("0" , ValueType.INT64)); - - - assertFalse(FrameUtil.isDefault("0.0" , ValueType.STRING)); - assertFalse(FrameUtil.isDefault("0" , ValueType.STRING)); - assertFalse(FrameUtil.isDefault("" , ValueType.STRING)); - assertFalse(FrameUtil.isDefault("13" , ValueType.STRING)); - assertFalse(FrameUtil.isDefault("13" , ValueType.INT32)); - assertFalse(FrameUtil.isDefault("13" , ValueType.INT64)); - assertFalse(FrameUtil.isDefault("13" , ValueType.FP64)); - assertFalse(FrameUtil.isDefault("13" , ValueType.FP32)); - assertFalse(FrameUtil.isDefault("1" , ValueType.CHARACTER)); - assertFalse(FrameUtil.isDefault("0" , ValueType.CHARACTER)); - assertFalse(FrameUtil.isDefault("t" , ValueType.BOOLEAN)); - assertFalse(FrameUtil.isDefault("true" , ValueType.BOOLEAN)); + assertTrue(FrameUtil.isDefault("" + (char) (0), ValueType.CHARACTER)); + assertTrue(FrameUtil.isDefault("0.0", ValueType.FP32)); + assertTrue(FrameUtil.isDefault("0", ValueType.FP32)); + assertTrue(FrameUtil.isDefault("0.0", ValueType.FP64)); + assertTrue(FrameUtil.isDefault("0", ValueType.FP64)); + assertTrue(FrameUtil.isDefault("0.0", ValueType.INT32)); + assertTrue(FrameUtil.isDefault("0", ValueType.INT32)); + assertTrue(FrameUtil.isDefault("0.0", ValueType.INT64)); + assertTrue(FrameUtil.isDefault("0", ValueType.INT64)); + + assertFalse(FrameUtil.isDefault("0.0", ValueType.STRING)); + assertFalse(FrameUtil.isDefault("0", ValueType.STRING)); + assertFalse(FrameUtil.isDefault("", ValueType.STRING)); + assertFalse(FrameUtil.isDefault("13", ValueType.STRING)); + assertFalse(FrameUtil.isDefault("13", ValueType.INT32)); + assertFalse(FrameUtil.isDefault("13", ValueType.INT64)); + assertFalse(FrameUtil.isDefault("13", ValueType.FP64)); + assertFalse(FrameUtil.isDefault("13", ValueType.FP32)); + assertFalse(FrameUtil.isDefault("1", ValueType.CHARACTER)); + assertFalse(FrameUtil.isDefault("0", ValueType.CHARACTER)); + assertFalse(FrameUtil.isDefault("t", ValueType.BOOLEAN)); + assertFalse(FrameUtil.isDefault("true", ValueType.BOOLEAN)); } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameTest.java b/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameTest.java new file mode 100644 index 0000000000..ad6e638fbb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Collection; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(value = Parameterized.class) +public class MatrixFromFrameTest { + protected static final Log LOG = LogFactory.getLog(MatrixFromFrameTest.class.getName()); + + public final FrameBlock fb; + + public MatrixFromFrameTest(ValueType[] schema, int rows, long seed) { + FrameBlock tmp = TestUtils.generateRandomFrameBlock(10, schema, 3214L); + fb = FrameLibApplySchema.applySchema(tmp, schema); + } + + @Parameters + public static Collection<Object[]> data() { + ArrayList<Object[]> tests = new ArrayList<>(); + + try { + tests.add(new Object[] {new ValueType[] {ValueType.FP64, ValueType.BOOLEAN}, 10, 1324L}); + tests.add(new Object[] {new ValueType[] {ValueType.CHARACTER, ValueType.INT32}, 10, 1324L}); + tests.add(new Object[] {new ValueType[] {ValueType.HASH64, ValueType.INT64, ValueType.BOOLEAN}, 10, 1324L}); + ValueType t = ValueType.INT64; + tests.add(new Object[] {new ValueType[] {t, t, t, t, t, t, t, t, t, t}, 100, 1324L}); + + } + catch(Exception e) { + e.printStackTrace(); + fail("failed constructing tests"); + } + + return tests; + } + + @Test + public void singleThread() { + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + compare(fb, mb); + } + + @Test + public void parallelThread() { + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 2); + + compare(fb, mb); + } + + @Test + public void dynamicThread() { + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, -1); + + compare(fb, mb); + } + + @Test + public void allocatedOut() { + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, new MatrixBlock(3, 3, true), -1); + + compare(fb, mb); + } + + @Test + public void allocatedOutDense() { + MatrixBlock mb = new MatrixBlock(fb.getNumRows(), fb.getNumRows(), false); + mb.allocateBlock(); + + mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, -1); + + compare(fb, mb); + } + + @Test + public void allocatedOutSparse() { + MatrixBlock mb = new MatrixBlock(fb.getNumRows(), fb.getNumRows(), true); + mb.allocateBlock(); + + mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, -1); + + compare(fb, mb); + } + + @Test + public void allocatedOutNonContinuous() { + MatrixBlock mb = new MatrixBlock(fb.getNumRows(), fb.getNumRows(), false); + mb.allocateBlock(); + DenseBlock spy = spy(mb.getDenseBlock()); + when(spy.isContiguous()).thenReturn(false); + mb.setDenseBlock(spy); + + mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, -1); + + compare(fb, mb); + } + + @Test + public void testException() { + MatrixBlock mb = new MatrixBlock(fb.getNumRows(), fb.getNumRows(), false); + mb.allocateBlock(); + MatrixBlock spy = spy(mb); + when(spy.getDenseBlockValues()).thenThrow(new RuntimeException()); + + Exception e = assertThrows(DMLRuntimeException.class, + () -> MatrixBlockFromFrame.convertToMatrixBlock(fb, spy, -1)); + + assertTrue(e.getMessage().contains("Failed to convert FrameBlock to MatrixBlock")); + } + + private void compare(FrameBlock fb, MatrixBlock mb) { + for(int i = 0; i < fb.getNumRows(); i++) { + for(int j = 0; j < fb.getNumColumns(); j++) { + assertEquals(fb.getColumn(j).getAsNaNDouble(i), mb.get(i, j), 0.0); + } + } + } +}