This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3b4f6cdf86 [MINOR] Frame tests improvement 2
3b4f6cdf86 is described below
commit 3b4f6cdf86feb89eb83ad30529473f02aa39b9d4
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sat Sep 28 00:08:42 2024 +0200
[MINOR] Frame tests improvement 2
Add tests to reach 100% test coverage for frame/data/lib
Closes #2120
---
.../runtime/frame/data/lib/FrameLibAppend.java | 10 +-
.../frame/data/lib/FrameLibDetectSchema.java | 21 +--
.../sysds/runtime/frame/data/lib/FrameUtil.java | 39 +++--
.../frame/data/lib/MatrixBlockFromFrame.java | 114 ++++++++++-----
.../sysds/runtime/matrix/data/MatrixBlock.java | 3 +
.../test/component/frame/FrameCustomTest.java | 36 +++++
.../sysds/test/component/frame/FrameTest.java | 41 +++++-
.../sysds/test/component/frame/FrameUtilTest.java | 148 +++++++++++++++----
.../test/component/frame/MatrixFromFrameTest.java | 162 +++++++++++++++++++++
9 files changed, 474 insertions(+), 100 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
index 78177b1d2f..60c61e8f4a 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java
@@ -33,8 +33,12 @@ import
org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
public class FrameLibAppend {
-
protected static final Log LOG =
LogFactory.getLog(FrameLibAppend.class.getName());
+
+ private FrameLibAppend(){
+ // private constructor.
+ }
+
/**
* Appends the given argument FrameBlock 'that' to this FrameBlock by
creating a deep copy to prevent side effects.
* For cbind, the frames are appended column-wise (same number of
rows), while for rbind the frames are appended
@@ -50,7 +54,7 @@ public class FrameLibAppend {
return ret;
}
- public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) {
+ private static FrameBlock appendCbind(FrameBlock a, FrameBlock b) {
final int nRow = a.getNumRows();
final int nRowB = b.getNumRows();
@@ -73,7 +77,7 @@ public class FrameLibAppend {
return new FrameBlock(_schema, _colnames, _colmeta, _coldata);
}
- public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
+ private static FrameBlock appendRbind(FrameBlock a, FrameBlock b) {
final int nCol = a.getNumColumns();
final int nColB = b.getNumColumns();
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
index 889fc853f9..2e8e2ba106 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.frame.data.lib;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
@@ -67,11 +66,16 @@ public final class FrameLibDetectSchema {
}
private FrameBlock apply() {
- final int cols = in.getNumColumns();
- final FrameBlock fb = new
FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING));
- String[] schemaInfo = (k == 1) ? singleThreadApply() :
parallelApply();
- fb.appendRow(schemaInfo);
- return fb;
+ try{
+ final int cols = in.getNumColumns();
+ final FrameBlock fb = new
FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING));
+ String[] schemaInfo = (k == 1) ? singleThreadApply() :
parallelApply();
+ fb.appendRow(schemaInfo);
+ return fb;
+ }
+ catch(Exception e){
+ throw new DMLRuntimeException("Failed to detect
schema", e);
+ }
}
private String[] singleThreadApply() {
@@ -84,7 +88,7 @@ public final class FrameLibDetectSchema {
return schemaInfo;
}
- private String[] parallelApply() {
+ private String[] parallelApply() throws Exception {
final ExecutorService pool = CommonThreadPool.get(k);
try {
final int cols = in.getNumColumns();
@@ -99,9 +103,6 @@ public final class FrameLibDetectSchema {
return schemaInfo;
}
- catch(ExecutionException | InterruptedException e) {
- throw new DMLRuntimeException("Exception interrupted or
exception thrown in detectSchema", e);
- }
finally{
pool.shutdown();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
index e4ccd06c02..17315cf53f 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java
@@ -290,33 +290,30 @@ public interface FrameUtil {
}
public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock
temp2) {
- String[] rowTemp1 =
IteratorFactory.getStringRowIterator(temp1).next();
- String[] rowTemp2 =
IteratorFactory.getStringRowIterator(temp2).next();
+ final int nCol = temp1.getNumColumns();
- if(rowTemp1.length != rowTemp2.length)
- throw new DMLRuntimeException("Schema dimension " +
"mismatch: " + rowTemp1.length + " vs " + rowTemp2.length);
+ if(nCol != temp2.getNumColumns())
+ throw new DMLRuntimeException("Schema dimension
mismatch: " + nCol + " vs " + temp2.getNumColumns());
- for(int i = 0; i < rowTemp1.length; i++) {
+ // hack reuse input temp1 schema, only valid if temp1 never
change schema.
+ // However, this is typically valid.
+ FrameBlock mergedFrame = new FrameBlock(temp1.getSchema());
+ mergedFrame.ensureAllocatedColumns(1);
+ for(int i = 0; i < nCol; i++) {
+ String s1 = (String) temp1.get(0, i);
+ String s2 = (String) temp2.get(0, i);
// modify schema1 if necessary (different schema2)
- if(!rowTemp1[i].equals(rowTemp2[i])) {
- if(rowTemp1[i].equals("STRING") ||
rowTemp2[i].equals("STRING"))
- rowTemp1[i] = "STRING";
- else if(rowTemp1[i].equals("FP64") ||
rowTemp2[i].equals("FP64"))
- rowTemp1[i] = "FP64";
- else if(rowTemp1[i].equals("FP32") &&
- new ArrayList<>(Arrays.asList("INT64",
"INT32", "CHARACTER")).contains(rowTemp2[i]))
- rowTemp1[i] = "FP32";
- else if(rowTemp1[i].equals("INT64") &&
- new ArrayList<>(Arrays.asList("INT32",
"CHARACTER")).contains(rowTemp2[i]))
- rowTemp1[i] = "INT64";
- else if(rowTemp1[i].equals("INT32") ||
rowTemp2[i].equals("CHARACTER"))
- rowTemp1[i] = "INT32";
+ if(!s1.equals(s2)) {
+ ValueType v1 = ValueType.valueOf(s1);
+ ValueType v2 = ValueType.valueOf(s2);
+ ValueType vc =
ValueType.getHighestCommonTypeSafe(v1, v2);
+ mergedFrame.set(0, i, vc.toString());
+ }
+ else{
+ mergedFrame.set(0, i, s1);
}
}
- // create output block one row representing the schema as
strings
- FrameBlock mergedFrame = new
FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING));
- mergedFrame.appendRow(rowTemp1);
return mergedFrame;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java
index 8f0625409c..032afe2cd7 100644
---
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java
+++
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java
@@ -25,6 +25,7 @@ import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -40,28 +41,56 @@ public interface MatrixBlockFromFrame {
* Converts a frame block with arbitrary schema into a matrix block.
Since matrix block only supports value type
* double, we do a best effort conversion of non-double types which
might result in errors for non-numerical data.
*
- * @param frame frame block
- * @param k parallelization degree
- * @return matrix block
+ * @param frame Frame block to convert
+ * @param k The parallelization degree
+ * @return MatrixBlock
*/
public static MatrixBlock convertToMatrixBlock(FrameBlock frame, int k)
{
- final int m = frame.getNumRows();
- final int n = frame.getNumColumns();
- final MatrixBlock mb = new MatrixBlock(m, n, false);
- mb.allocateDenseBlock();
- if(k == -1)
- k = InfrastructureAnalyzer.getLocalParallelism();
-
- long nnz = 0;
- if(k == 1)
- nnz = convert(frame, mb, n, 0, m);
- else
- nnz = convertParallel(frame, mb, m, n, k);
+ return convertToMatrixBlock(frame, null, k);
+ }
- mb.setNonZeros(nnz);
+ /**
+ * Converts a frame block with arbitrary schema into a matrix block.
Since matrix block only supports value type
+ * double, we do a best effort conversion of non-double types which
might result in errors for non-numerical data.
+ *
+ * @param frame FrameBlock to convert
+ * @param ret The returned MatrixBlock
+ * @param k The parallelization degree
+ * @return MatrixBlock
+ */
+ public static MatrixBlock convertToMatrixBlock(FrameBlock frame,
MatrixBlock ret, int k) {
+ try {
- mb.examSparsity();
- return mb;
+ final int m = frame.getNumRows();
+ final int n = frame.getNumColumns();
+ ret = allocateRet(ret, m, n);
+
+ if(k == -1)
+ k =
InfrastructureAnalyzer.getLocalParallelism();
+
+ long nnz = 0;
+ if(k == 1)
+ nnz = convert(frame, ret, n, 0, m);
+ else
+ nnz = convertParallel(frame, ret, m, n, k);
+
+ ret.setNonZeros(nnz);
+ ret.examSparsity();
+ return ret;
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException("Failed to convert
FrameBlock to MatrixBlock", e);
+ }
+ }
+
+ private static MatrixBlock allocateRet(MatrixBlock ret, final int m,
final int n) {
+ if(ret == null)
+ ret = new MatrixBlock(m, n, false);
+ else if(ret.getNumRows() != m || ret.getNumColumns() != n ||
ret.isInSparseFormat())
+ ret.reset(m, n, false);
+ if(!ret.isAllocated())
+ ret.allocateDenseBlock();
+ return ret;
}
private static long convert(FrameBlock frame, MatrixBlock mb, int n,
int rl, int ru) {
@@ -71,27 +100,25 @@ public interface MatrixBlockFromFrame {
return convertGeneric(frame, mb, n, rl, ru);
}
- private static long convertParallel(FrameBlock frame, MatrixBlock mb,
int m, int n, int k){
+ private static long convertParallel(FrameBlock frame, MatrixBlock mb,
int m, int n, int k) throws Exception {
ExecutorService pool = CommonThreadPool.get(k);
- try{
+ try {
List<Future<Long>> tasks = new ArrayList<>();
final int blkz = Math.max(m / k, 1000);
- for( int i = 0; i < m; i+= blkz){
- final int start = i;
+ for(int i = 0; i < m; i += blkz) {
+ final int start = i;
final int end = Math.min(i + blkz, m);
tasks.add(pool.submit(() -> convert(frame, mb,
n, start, end)));
}
long nnz = 0;
- for( Future<Long> t : tasks)
+ for(Future<Long> t : tasks)
nnz += t.get();
return nnz;
}
- catch(Exception e){
- throw new RuntimeException(e);
- }
- finally{
+
+ finally {
pool.shutdown();
}
}
@@ -104,29 +131,42 @@ public interface MatrixBlockFromFrame {
for(int bj = 0; bj < n; bj += blocksizeIJ) {
int bimin = Math.min(bi + blocksizeIJ, ru);
int bjmin = Math.min(bj + blocksizeIJ, n);
- for(int i = bi, aix = bi * n; i < bimin; i++,
aix += n)
- for(int j = bj; j < bjmin; j++)
- lnnz += (c[aix + j] =
frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
+ lnnz = convertBlockContiguous(frame, n, lnnz,
c, bi, bj, bimin, bjmin);
}
}
return lnnz;
}
- private static long convertGeneric(final FrameBlock frame, final
MatrixBlock mb, final int n, final int rl, final int ru) {
+ private static long convertBlockContiguous(final FrameBlock frame,
final int n, long lnnz, double[] c, int rl,
+ int cl, int ru, int cu) {
+ for(int i = rl, aix = rl * n; i < ru; i++, aix += n)
+ for(int j = cl; j < cu; j++)
+ lnnz += (c[aix + j] = frame.getDoubleNaN(i, j))
!= 0 ? 1 : 0;
+ return lnnz;
+ }
+
+ private static long convertGeneric(final FrameBlock frame, final
MatrixBlock mb, final int n, final int rl,
+ final int ru) {
long lnnz = 0;
final DenseBlock c = mb.getDenseBlock();
for(int bi = rl; bi < ru; bi += blocksizeIJ) {
for(int bj = 0; bj < n; bj += blocksizeIJ) {
int bimin = Math.min(bi + blocksizeIJ, ru);
int bjmin = Math.min(bj + blocksizeIJ, n);
- for(int i = bi; i < bimin; i++) {
- double[] cvals = c.values(i);
- int cpos = c.pos(i);
- for(int j = bj; j < bjmin; j++)
- lnnz += (cvals[cpos + j] =
frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
- }
+ lnnz = convertBlockGeneric(frame, lnnz, c, bi,
bj, bimin, bjmin);
}
}
return lnnz;
}
+
+ private static long convertBlockGeneric(final FrameBlock frame, long
lnnz, final DenseBlock c, final int rl,
+ final int cl, final int ru, final int cu) {
+ for(int i = rl; i < ru; i++) {
+ final double[] cvals = c.values(i);
+ final int cpos = c.pos(i);
+ for(int j = cl; j < cu; j++)
+ lnnz += (cvals[cpos + j] =
frame.getDoubleNaN(i, j)) != 0 ? 1 : 0;
+ }
+ return lnnz;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index ead979a44b..66d616a738 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -332,6 +332,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
if(sparseBlock == null)
return;
sparseBlock.reset(estimatedNNzsPerRow, clen);
+ denseBlock = null;
}
private void resetDense(double val) {
@@ -343,6 +344,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
allocateDenseBlock(false);
denseBlock.set(val);
}
+ sparseBlock = null;
}
private void resetDense(double val, boolean dedup) {
@@ -354,6 +356,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
allocateDenseBlock(false, dedup);
denseBlock.set(val);
}
+ sparseBlock = null;
}
/**
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
index 5e482a5369..047b2da3b2 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java
@@ -19,16 +19,26 @@
package org.apache.sysds.test.component.frame;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend;
+import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
public class FrameCustomTest {
+ protected static final Log LOG =
LogFactory.getLog(FrameCustomTest.class.getName());
@Test
public void castToFrame() {
@@ -61,4 +71,30 @@ public class FrameCustomTest {
assertTrue(f.getSchema()[0] == ValueType.FP64);
}
+
+ @Test
+ public void detectSchemaError(){
+ FrameBlock f = TestUtils.generateRandomFrameBlock(10, 10, 23);
+ FrameBlock spy = spy(f);
+ when(spy.getColumn(anyInt())).thenThrow(new RuntimeException());
+
+ Exception e = assertThrows(DMLRuntimeException.class, () ->
FrameLibDetectSchema.detectSchema(spy, 3));
+
+ assertTrue(e.getMessage().contains("Failed to detect schema"));
+ }
+
+
+
+ @Test
+ public void appendUniqueColNames(){
+ FrameBlock a = new FrameBlock(new ValueType[]{ValueType.FP32},
new String[]{"Hi"});
+ a.appendRow(new String[]{"0.2"});
+ FrameBlock b = new FrameBlock(new ValueType[]{ValueType.FP32},
new String[]{"There"});
+ b.appendRow(new String[]{"0.5"});
+
+ FrameBlock c = FrameLibAppend.append(a, b, true);
+
+ assertTrue(c.getColumnName(0).equals("Hi"));
+ assertTrue(c.getColumnName(1).equals("There"));
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
index bf84898344..310d497a6f 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java
@@ -90,16 +90,16 @@ public class FrameTest {
@AfterClass
public static void cleanup() {
- try{
+ try {
IOCompressionTestUtils.deleteDirectory(new
File(nameBeginning));
}
- catch(Exception e){
+ catch(Exception e) {
e.printStackTrace();
LOG.error("failed to delete", e);
}
}
-
+
@Test
public void appendSelfRBind() {
FrameBlock ff = append(f, f, false);
@@ -173,6 +173,34 @@ public class FrameTest {
f.append(b, true);
}
+ @Test(expected = DMLRuntimeException.class)
+ public void cBindEmptyCols() {
+ // must have same number of rows.
+ FrameBlock b = new FrameBlock();
+ b.append(f, false);
+ }
+
+ @Test(expected = DMLRuntimeException.class)
+ public void cBindEmptyAfterCols() {
+ // must have same number of rows.
+ FrameBlock b = new FrameBlock();
+ f.append(b, false);
+ }
+
+ @Test
+ public void cBindEmptyR() {
+ // must have same number of rows.
+ FrameBlock b = new FrameBlock(new ValueType[0], f.getNumRows());
+ b.append(f, true);
+ }
+
+ @Test
+ public void cBindEmptyAfterR() {
+ // must have same number of rows.
+ FrameBlock b = new FrameBlock(new ValueType[0], f.getNumRows());
+ f.append(b, true);
+ }
+
@Test
public void cBindStringColAfter() {
// must have same number of rows.
@@ -308,6 +336,13 @@ public class FrameTest {
TestUtils.compareFrames(f, fs, true);
}
+ @Test
+ public void testApplyApproxSchema() {
+ final FrameBlock schema = FrameLibDetectSchema.detectSchema(f,
0.1, 1);
+ final FrameBlock fs =
FrameLibApplySchema.applySchema(f.copyShallow(), schema, 1);
+ TestUtils.compareFrames(f, fs, true);
+ }
+
protected static void writeAndRead(FrameBlock fb) {
try {
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
index 713fc73db0..035a8b3e6d 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java
@@ -283,6 +283,104 @@ public class FrameUtilTest {
assertEquals(ValueType.FP64, FrameUtil.isType(33.231425155253));
}
+ @Test
+ public void mergeSchema1() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"STRING"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"FP64"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("STRING"));
+ }
+
+ @Test
+ public void mergeSchema2() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"FP32"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"FP64"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("FP64"));
+ }
+
+ @Test
+ public void mergeSchema3() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"INT32"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"FP64"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("FP64"));
+ }
+
+ @Test
+ public void mergeSchema4() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"INT32"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"INT64"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("INT64"));
+ }
+
+ @Test
+ public void mergeSchema5() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"INT32"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"STRING"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("STRING"));
+ }
+
+ @Test
+ public void mergeSchema6() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"BOOLEAN"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"INT32"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("INT32"));
+ }
+
+ @Test
+ public void mergeSchema7() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"BOOLEAN"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"UINT8"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("UINT8"));
+ }
+
+ @Test
+ public void mergeSchema8() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ a.appendRow(new String[] {"BOOLEAN"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"BOOLEAN"});
+
+ FrameBlock c = FrameUtil.mergeSchema(a, b);
+ assertTrue(c.get(0, 0).equals("BOOLEAN"));
+ }
+
+ @Test(expected = Exception.class)
+ public void mergeSchemaInvalid() {
+ FrameBlock a = new FrameBlock(new ValueType[]
{ValueType.STRING, ValueType.STRING});
+ a.appendRow(new String[] {"BOOLEAN", "BOOLEAN"});
+ FrameBlock b = new FrameBlock(new ValueType[]
{ValueType.STRING});
+ b.appendRow(new String[] {"BOOLEAN"});
+
+ FrameUtil.mergeSchema(a, b);
+ }
+
@Test
public void testSparkFrameBlockALignment() {
ValueType[] schema = new ValueType[0];
@@ -466,35 +564,33 @@ public class FrameUtilTest {
assertTrue(ValueType.FP64 ==
FrameUtil.isType(2.2231342152323232, ValueType.FP64));
}
-
- @Test
- public void isDefault(){
+ @Test
+ public void isDefault() {
assertTrue(FrameUtil.isDefault(null, null));
assertTrue(FrameUtil.isDefault("false", ValueType.BOOLEAN));
assertTrue(FrameUtil.isDefault("f", ValueType.BOOLEAN));
assertTrue(FrameUtil.isDefault("0", ValueType.BOOLEAN));
- assertTrue(FrameUtil.isDefault("" + (char)(0),
ValueType.CHARACTER));
- assertTrue(FrameUtil.isDefault("0.0" , ValueType.FP32));
- assertTrue(FrameUtil.isDefault("0" , ValueType.FP32));
- assertTrue(FrameUtil.isDefault("0.0" , ValueType.FP64));
- assertTrue(FrameUtil.isDefault("0" , ValueType.FP64));
- assertTrue(FrameUtil.isDefault("0.0" , ValueType.INT32));
- assertTrue(FrameUtil.isDefault("0" , ValueType.INT32));
- assertTrue(FrameUtil.isDefault("0.0" , ValueType.INT64));
- assertTrue(FrameUtil.isDefault("0" , ValueType.INT64));
-
-
- assertFalse(FrameUtil.isDefault("0.0" , ValueType.STRING));
- assertFalse(FrameUtil.isDefault("0" , ValueType.STRING));
- assertFalse(FrameUtil.isDefault("" , ValueType.STRING));
- assertFalse(FrameUtil.isDefault("13" , ValueType.STRING));
- assertFalse(FrameUtil.isDefault("13" , ValueType.INT32));
- assertFalse(FrameUtil.isDefault("13" , ValueType.INT64));
- assertFalse(FrameUtil.isDefault("13" , ValueType.FP64));
- assertFalse(FrameUtil.isDefault("13" , ValueType.FP32));
- assertFalse(FrameUtil.isDefault("1" , ValueType.CHARACTER));
- assertFalse(FrameUtil.isDefault("0" , ValueType.CHARACTER));
- assertFalse(FrameUtil.isDefault("t" , ValueType.BOOLEAN));
- assertFalse(FrameUtil.isDefault("true" , ValueType.BOOLEAN));
+ assertTrue(FrameUtil.isDefault("" + (char) (0),
ValueType.CHARACTER));
+ assertTrue(FrameUtil.isDefault("0.0", ValueType.FP32));
+ assertTrue(FrameUtil.isDefault("0", ValueType.FP32));
+ assertTrue(FrameUtil.isDefault("0.0", ValueType.FP64));
+ assertTrue(FrameUtil.isDefault("0", ValueType.FP64));
+ assertTrue(FrameUtil.isDefault("0.0", ValueType.INT32));
+ assertTrue(FrameUtil.isDefault("0", ValueType.INT32));
+ assertTrue(FrameUtil.isDefault("0.0", ValueType.INT64));
+ assertTrue(FrameUtil.isDefault("0", ValueType.INT64));
+
+ assertFalse(FrameUtil.isDefault("0.0", ValueType.STRING));
+ assertFalse(FrameUtil.isDefault("0", ValueType.STRING));
+ assertFalse(FrameUtil.isDefault("", ValueType.STRING));
+ assertFalse(FrameUtil.isDefault("13", ValueType.STRING));
+ assertFalse(FrameUtil.isDefault("13", ValueType.INT32));
+ assertFalse(FrameUtil.isDefault("13", ValueType.INT64));
+ assertFalse(FrameUtil.isDefault("13", ValueType.FP64));
+ assertFalse(FrameUtil.isDefault("13", ValueType.FP32));
+ assertFalse(FrameUtil.isDefault("1", ValueType.CHARACTER));
+ assertFalse(FrameUtil.isDefault("0", ValueType.CHARACTER));
+ assertFalse(FrameUtil.isDefault("t", ValueType.BOOLEAN));
+ assertFalse(FrameUtil.isDefault("true", ValueType.BOOLEAN));
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameTest.java
b/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameTest.java
new file mode 100644
index 0000000000..ad6e638fbb
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema;
+import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(value = Parameterized.class)
+public class MatrixFromFrameTest {
+ protected static final Log LOG =
LogFactory.getLog(MatrixFromFrameTest.class.getName());
+
+ public final FrameBlock fb;
+
+ public MatrixFromFrameTest(ValueType[] schema, int rows, long seed) {
+ FrameBlock tmp = TestUtils.generateRandomFrameBlock(10, schema,
3214L);
+ fb = FrameLibApplySchema.applySchema(tmp, schema);
+ }
+
+ @Parameters
+ public static Collection<Object[]> data() {
+ ArrayList<Object[]> tests = new ArrayList<>();
+
+ try {
+ tests.add(new Object[] {new ValueType[]
{ValueType.FP64, ValueType.BOOLEAN}, 10, 1324L});
+ tests.add(new Object[] {new ValueType[]
{ValueType.CHARACTER, ValueType.INT32}, 10, 1324L});
+ tests.add(new Object[] {new ValueType[]
{ValueType.HASH64, ValueType.INT64, ValueType.BOOLEAN}, 10, 1324L});
+ ValueType t = ValueType.INT64;
+ tests.add(new Object[] {new ValueType[] {t, t, t, t, t,
t, t, t, t, t}, 100, 1324L});
+
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail("failed constructing tests");
+ }
+
+ return tests;
+ }
+
+ @Test
+ public void singleThread() {
+
+ MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb,
1);
+ compare(fb, mb);
+ }
+
+ @Test
+ public void parallelThread() {
+
+ MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb,
2);
+
+ compare(fb, mb);
+ }
+
+ @Test
+ public void dynamicThread() {
+
+ MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb,
-1);
+
+ compare(fb, mb);
+ }
+
+ @Test
+ public void allocatedOut() {
+
+ MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb,
new MatrixBlock(3, 3, true), -1);
+
+ compare(fb, mb);
+ }
+
+ @Test
+ public void allocatedOutDense() {
+ MatrixBlock mb = new MatrixBlock(fb.getNumRows(),
fb.getNumRows(), false);
+ mb.allocateBlock();
+
+ mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, -1);
+
+ compare(fb, mb);
+ }
+
+ @Test
+ public void allocatedOutSparse() {
+ MatrixBlock mb = new MatrixBlock(fb.getNumRows(),
fb.getNumRows(), true);
+ mb.allocateBlock();
+
+ mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, -1);
+
+ compare(fb, mb);
+ }
+
+ @Test
+ public void allocatedOutNonContinuous() {
+ MatrixBlock mb = new MatrixBlock(fb.getNumRows(),
fb.getNumRows(), false);
+ mb.allocateBlock();
+ DenseBlock spy = spy(mb.getDenseBlock());
+ when(spy.isContiguous()).thenReturn(false);
+ mb.setDenseBlock(spy);
+
+ mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, -1);
+
+ compare(fb, mb);
+ }
+
+ @Test
+ public void testException() {
+ MatrixBlock mb = new MatrixBlock(fb.getNumRows(),
fb.getNumRows(), false);
+ mb.allocateBlock();
+ MatrixBlock spy = spy(mb);
+ when(spy.getDenseBlockValues()).thenThrow(new
RuntimeException());
+
+ Exception e = assertThrows(DMLRuntimeException.class,
+ () -> MatrixBlockFromFrame.convertToMatrixBlock(fb,
spy, -1));
+
+ assertTrue(e.getMessage().contains("Failed to convert
FrameBlock to MatrixBlock"));
+ }
+
+ private void compare(FrameBlock fb, MatrixBlock mb) {
+ for(int i = 0; i < fb.getNumRows(); i++) {
+ for(int j = 0; j < fb.getNumColumns(); j++) {
+ assertEquals(fb.getColumn(j).getAsNaNDouble(i),
mb.get(i, j), 0.0);
+ }
+ }
+ }
+}