This is an automated email from the ASF dual-hosted git repository.
ssiddiqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 8dcb720 [MINOR] Row-wise frame initialization support Closes #1228.
8dcb720 is described below
commit 8dcb720136d048b1fe7027c4b87254f1a948276e
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Thu Apr 15 16:44:38 2021 +0200
[MINOR] Row-wise frame initialization support
Closes #1228.
---
.../instructions/cp/DataGenCPInstruction.java | 16 +++++++-
.../instructions/spark/RandSPInstruction.java | 17 ++++++++-
.../test/functions/frame/FrameConstructorTest.java | 44 ++++++++++++++++------
.../functions/frame/FrameConstructorTest.dml | 3 ++
4 files changed, 65 insertions(+), 15 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index f217d6c..9017ea8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.cp;
import java.util.Arrays;
import java.util.Random;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -353,15 +354,26 @@ public class DataGenCPInstruction extends
UnaryCPInstruction {
}
else {
String[] data =
frame_data.split(DataExpression.DELIM_NA_STRING_SEP);
- if(data.length != schemaLength && data.length >
1)
+ int rowLength = data.length/lrows;
+ if(data.length != schemaLength && data.length >
1 && rowLength != schemaLength)
throw new DMLRuntimeException(
"data values should be equal to
number of columns," + " or a single values for all columns");
out = new FrameBlock(vt);
FrameBlock outF = (FrameBlock) out;
- if(data.length > 1) {
+ if(data.length > 1 && rowLength !=
schemaLength) {
for(int i = 0; i < lrows; i++)
outF.appendRow(data);
}
+ else if(data.length > 1 && rowLength ==
schemaLength)
+ {
+ int beg = 0;
+ for(int i = 1; i <= lrows; i++) {
+ int end = lcols * i;
+ String[] data1 =
ArrayUtils.subarray(data, beg, end);
+ beg = end;
+ outF.appendRow(data1);
+ }
+ }
else {
String[] data1 = new String[lcols];
Arrays.fill(data1, frame_data);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
index 3d0c2e5..cfd0087 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/RandSPInstruction.java
@@ -27,6 +27,7 @@ import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -993,14 +994,26 @@ public class RandSPInstruction extends UnarySPInstruction
{
}
else {
String[] data =
_data.split(DataExpression.DELIM_NA_STRING_SEP);
- if(data.length != _schema.length && data.length
> 1)
+ int rowLength = data.length/(int)_rlen;
+ if(data.length != _schema.length && data.length
> 1 && rowLength != _schema.length)
throw new DMLRuntimeException("data
values should be equal "
+ "to number of columns, or a
single values for all columns");
- if(data.length > 1) {
+ if(data.length > 1 && rowLength !=
_schema.length) {
out = new FrameBlock(_schema);
for(int i = 0; i < lrlen; i++)
out.appendRow(data);
}
+ else if(data.length > 1 && rowLength ==
_schema.length)
+ {
+ out = new FrameBlock(_schema);
+ int beg = 0;
+ for(int i = 1; i <= lrlen; i++) {
+ int end = (int)_clen * i;
+ String[] data1 =
ArrayUtils.subarray(data, beg, end);
+ beg = end;
+ out.appendRow(data1);
+ }
+ }
else {
out = new FrameBlock(_schema);
String[] data1 = new String[(int)_clen];
diff --git
a/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java
b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java
index ef9e8a6..a4a87c8 100644
---
a/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java
@@ -52,7 +52,8 @@ public class FrameConstructorTest extends AutomatedTestBase {
NAMED,
NO_SCHEMA,
RANDOM_DATA,
- SINGLE_DATA
+ SINGLE_DATA,
+ MULTI_ROW_DATA
}
@Override
@@ -66,25 +67,25 @@ public class FrameConstructorTest extends AutomatedTestBase
{
@Test
public void testFrameNamedParam() {
- FrameBlock exp = createExpectedFrame(schemaStrings1, false);
+ FrameBlock exp = createExpectedFrame(schemaStrings1,
rows,"mixed");
runFrameTest(TestType.NAMED, exp, Types.ExecMode.SINGLE_NODE);
}
@Test
public void testFrameNamedParamSP() {
- FrameBlock exp = createExpectedFrame(schemaStrings1, false);
+ FrameBlock exp = createExpectedFrame(schemaStrings1,
rows,"mixed");
runFrameTest(TestType.NAMED, exp, Types.ExecMode.SPARK);
}
@Test
public void testNoSchema() {
- FrameBlock exp = createExpectedFrame(schemaStrings2, false);
+ FrameBlock exp = createExpectedFrame(schemaStrings2,
rows,"mixed");
runFrameTest(TestType.NO_SCHEMA, exp,
Types.ExecMode.SINGLE_NODE);
}
@Test
public void testNoSchemaSP() {
- FrameBlock exp = createExpectedFrame(schemaStrings2, false);
+ FrameBlock exp = createExpectedFrame(schemaStrings2,
rows,"mixed");
runFrameTest(TestType.NO_SCHEMA, exp, Types.ExecMode.SPARK);
}
@@ -102,16 +103,28 @@ public class FrameConstructorTest extends
AutomatedTestBase {
@Test
public void testSingleData() {
- FrameBlock exp = createExpectedFrame(schemaStrings1, true);
+ FrameBlock exp = createExpectedFrame(schemaStrings1,
rows,"constant");
runFrameTest(TestType.SINGLE_DATA, exp,
Types.ExecMode.SINGLE_NODE);
}
@Test
public void testSingleDataSP() {
- FrameBlock exp = createExpectedFrame(schemaStrings1, true);
+ FrameBlock exp = createExpectedFrame(schemaStrings1,
rows,"constant");
runFrameTest(TestType.SINGLE_DATA, exp, Types.ExecMode.SPARK);
}
+ @Test
+ public void testMultiRowData() {
+ FrameBlock exp = createExpectedFrame(schemaStrings1,
5,"multi-row");
+ runFrameTest(TestType.MULTI_ROW_DATA, exp,
Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void testMultiRowDataSP() {
+ FrameBlock exp = createExpectedFrame(schemaStrings1,
5,"multi-row");
+ runFrameTest(TestType.MULTI_ROW_DATA, exp,
Types.ExecMode.SPARK);
+ }
+
private void runFrameTest(TestType type, FrameBlock expectedOutput,
Types.ExecMode et) {
Types.ExecMode platformOld = setExecMode(et);
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -144,11 +157,20 @@ public class FrameConstructorTest extends
AutomatedTestBase {
}
}
- private static FrameBlock createExpectedFrame(ValueType[] schema,
boolean constant) {
+ private static FrameBlock createExpectedFrame(ValueType[] schema, int
rows, String type) {
FrameBlock exp = new FrameBlock(schema);
- String[] out = constant ?
- new String[]{"1", "1", "1", "1"} :
- new String[]{"1", "abc", "2.5", "TRUE"};
+ String[] out = null;
+ if(type.equals("mixed"))
+ out = new String[]{"1", "abc", "2.5", "TRUE"};
+ else if(type.equals("constant"))
+ out = new String[]{"1", "1", "1", "1"};
+ else if (type.equals("multi-row")) //multi-row data
+ out = new String[]{"1", "abc", "2.5", "TRUE"};
+ else {
+ System.out.println("invalid test type");
+ System.exit(1);
+ }
+
for(int i=0; i<rows; i++)
exp.appendRow(out);
return exp;
diff --git a/src/test/scripts/functions/frame/FrameConstructorTest.dml
b/src/test/scripts/functions/frame/FrameConstructorTest.dml
index 8762d30..53196a6 100644
--- a/src/test/scripts/functions/frame/FrameConstructorTest.dml
+++ b/src/test/scripts/functions/frame/FrameConstructorTest.dml
@@ -28,6 +28,9 @@ if($1 == "RANDOM_DATA")
f1 = frame("", rows=40, cols=4, schema=["INT64", "STRING", "FP64",
"BOOLEAN"]) # no data
if($1 == "SINGLE_DATA")
f1 = frame(1, rows=40, cols=4, schema=["INT64", "STRING", "FP64",
"BOOLEAN"]) # no data
+if($1 == "MULTI_ROW_DATA")
+ f1 = frame(data=["1", "abc", "2.5", "TRUE", "1", "abc", "2.5", "TRUE", "1",
"abc", "2.5", "TRUE", "1", "abc", "2.5", "TRUE",
+ "1", "abc", "2.5", "TRUE" ], rows=5, cols=4, schema=["INT64", "STRING",
"FP64", "BOOLEAN"]) # initialization by row
# f1 = frame(1, 4, 3) # unnamed parameters not working
write(f1, $2, format="csv")