Repository: systemml
Updated Branches:
  refs/heads/master 8dbc93022 -> 912c65506


[MINOR] Increase MLContext test coverage

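Add MLContext tests for previously untested methods.
Update MLContext and MLContextConversionUtil to avoid NPEs.
Move the "not available" version and build-time messages into
MLContextUtil constants (VERSION_NOT_AVAILABLE, BUILD_TIME_NOT_AVAILABLE),
move the getHopDAG methods to the end of MLContextUtil, and reduce
MLContext.log visibility from public to protected.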

Closes #649.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/912c6550
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/912c6550
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/912c6550

Branch: refs/heads/master
Commit: 912c65506d626c8b0128ceb80744fde49efd4a1a
Parents: 8dbc930
Author: Deron Eriksson <de...@apache.org>
Authored: Fri Sep 1 16:02:55 2017 -0700
Committer: Deron Eriksson <de...@apache.org>
Committed: Fri Sep 1 16:02:55 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/api/mlcontext/MLContext.java   |  10 +-
 .../api/mlcontext/MLContextConversionUtil.java  |   3 +
 .../sysml/api/mlcontext/MLContextUtil.java      | 206 ++++++------
 .../integration/mlcontext/MLContextTest.java    | 314 +++++++++++++++++++
 4 files changed, 431 insertions(+), 102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index 83eedb3..35720a5 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -55,7 +55,7 @@ public class MLContext {
        /**
         * Logger for MLContext
         */
-       public static Logger log = Logger.getLogger(MLContext.class);
+       protected static Logger log = Logger.getLogger(MLContext.class);
 
        /**
         * SparkSession object.
@@ -665,7 +665,9 @@ public class MLContext {
 
                // clear local status, but do not stop sc as it
                // may be used or stopped externally
-               executionScript.clearAll();
+               if (executionScript != null) {
+                       executionScript.clearAll();
+               }
                resetConfig();
                spark = null;
        }
@@ -693,7 +695,7 @@ public class MLContext {
         */
        public String version() {
                if (info() == null) {
-                       return "Version not available";
+                       return MLContextUtil.VERSION_NOT_AVAILABLE;
                }
                return info().version();
        }
@@ -705,7 +707,7 @@ public class MLContext {
         */
        public String buildTime() {
                if (info() == null) {
-                       return "Build time not available";
+                       return MLContextUtil.BUILD_TIME_NOT_AVAILABLE;
                }
                return info().buildTime();
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 5883127..3f12ace 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -482,6 +482,9 @@ public class MLContextConversionUtil {
         *            the matrix metadata, if available
         */
        public static void determineMatrixFormatIfNeeded(Dataset<Row> 
dataFrame, MatrixMetadata matrixMetadata) {
+               if (matrixMetadata == null) {
+                       return;
+               }
                MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat();
                if (matrixFormat != null) {
                        return;

http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java 
b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 51d38a5..03184e3 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -91,122 +91,32 @@ import org.w3c.dom.NodeList;
  *
  */
 public final class MLContextUtil {
-       
+
        /**
-        * Get HOP DAG in dot format for a DML or PYDML Script.
-        *
-        * @param mlCtx
-        *            MLContext object.
-        * @param script
-        *            The DML or PYDML Script object to execute.
-        * @param lines
-        *            Only display the hops that have begin and end line number
-        *            equals to the given integers.
-        * @param performHOPRewrites
-        *            should perform static rewrites, perform
-        *            intra-/inter-procedural analysis to propagate size 
information
-        *            into functions and apply dynamic rewrites
-        * @param withSubgraph
-        *            If false, the dot graph will be created without subgraphs 
for
-        *            statement blocks.
-        * @return hop DAG in dot format
-        * @throws LanguageException
-        *             if error occurs
-        * @throws DMLRuntimeException
-        *             if error occurs
-        * @throws HopsException
-        *             if error occurs
+        * Version not available message.
         */
-       public static String getHopDAG(MLContext mlCtx, Script script, 
ArrayList<Integer> lines,
-                       boolean performHOPRewrites, boolean withSubgraph) 
throws HopsException, DMLRuntimeException,
-                       LanguageException {
-               return getHopDAG(mlCtx, script, lines, null, 
performHOPRewrites, withSubgraph);
-       }
+       public static final String VERSION_NOT_AVAILABLE = "Version not 
available";
 
        /**
-        * Get HOP DAG in dot format for a DML or PYDML Script.
-        *
-        * @param mlCtx
-        *            MLContext object.
-        * @param script
-        *            The DML or PYDML Script object to execute.
-        * @param lines
-        *            Only display the hops that have begin and end line number
-        *            equals to the given integers.
-        * @param newConf
-        *            Spark Configuration.
-        * @param performHOPRewrites
-        *            should perform static rewrites, perform
-        *            intra-/inter-procedural analysis to propagate size 
information
-        *            into functions and apply dynamic rewrites
-        * @param withSubgraph
-        *            If false, the dot graph will be created without subgraphs 
for
-        *            statement blocks.
-        * @return hop DAG in dot format
-        * @throws LanguageException
-        *             if error occurs
-        * @throws DMLRuntimeException
-        *             if error occurs
-        * @throws HopsException
-        *             if error occurs
+        * Build time not available message.
         */
-       public static String getHopDAG(MLContext mlCtx, Script script, 
ArrayList<Integer> lines, SparkConf newConf,
-                       boolean performHOPRewrites, boolean withSubgraph) 
throws HopsException, DMLRuntimeException,
-                       LanguageException {
-               SparkConf oldConf = 
mlCtx.getSparkSession().sparkContext().getConf();
-               SparkExecutionContext.SparkClusterConfig systemmlConf = 
SparkExecutionContext.getSparkClusterConfig();
-               long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory();
-               try {
-                       if (newConf != null) {
-                               systemmlConf.analyzeSparkConfiguation(newConf);
-                               
InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory"));
-                       }
-                       ScriptExecutor scriptExecutor = new ScriptExecutor();
-                       
scriptExecutor.setExecutionType(mlCtx.getExecutionType());
-                       scriptExecutor.setGPU(mlCtx.isGPU());
-                       scriptExecutor.setForceGPU(mlCtx.isForceGPU());
-                       scriptExecutor.setInit(mlCtx.isInitBeforeExecution());
-                       if (mlCtx.isInitBeforeExecution()) {
-                               mlCtx.setInitBeforeExecution(false);
-                       }
-                       
scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable());
-
-                       Long time = new Long((new Date()).getTime());
-                       if ((script.getName() == null) || 
(script.getName().equals(""))) {
-                               script.setName(time.toString());
-                       }
-                       
-                       mlCtx.setExecutionScript(script);
-                       scriptExecutor.compile(script, performHOPRewrites);
-                       Explain.reset();
-                       // To deal with potential Py4J issues
-                       lines = lines.size() == 1 && lines.get(0) == -1 ? new 
ArrayList<Integer>() : lines;
-                       return Explain.getHopDAG(scriptExecutor.dmlProgram, 
lines, withSubgraph);
-               } catch (RuntimeException e) {
-                       throw new MLContextException("Exception when compiling 
script", e);
-               } finally {
-                       if (newConf != null) {
-                               systemmlConf.analyzeSparkConfiguation(oldConf);
-                               
InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory);
-                       }
-               }
-       }
+       public static final String BUILD_TIME_NOT_AVAILABLE = "Build time not 
available";
 
        /**
-        * Basic data types supported by the MLContext API
+        * Basic data types supported by the MLContext API.
         */
        @SuppressWarnings("rawtypes")
        public static final Class[] BASIC_DATA_TYPES = { Integer.class, 
Boolean.class, Double.class, String.class };
 
        /**
-        * Complex data types supported by the MLContext API
+        * Complex data types supported by the MLContext API.
         */
        @SuppressWarnings("rawtypes")
        public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, 
RDD.class, Dataset.class, Matrix.class,
                        Frame.class, (new double[][] {}).getClass(), 
MatrixBlock.class, URL.class };
 
        /**
-        * All data types supported by the MLContext API
+        * All data types supported by the MLContext API.
         */
        @SuppressWarnings("rawtypes")
        public static final Class[] ALL_SUPPORTED_DATA_TYPES = (Class[]) 
ArrayUtils.addAll(BASIC_DATA_TYPES,
@@ -1252,4 +1162,104 @@ public final class MLContextUtil {
                        }
                }
        }
+
+       /**
+        * Get HOP DAG in dot format for a DML or PYDML Script.
+        *
+        * @param mlCtx
+        *            MLContext object.
+        * @param script
+        *            The DML or PYDML Script object to execute.
+        * @param lines
+        *            Only display the hops that have begin and end line number
+        *            equals to the given integers.
+        * @param performHOPRewrites
+        *            should perform static rewrites, perform
+        *            intra-/inter-procedural analysis to propagate size 
information
+        *            into functions and apply dynamic rewrites
+        * @param withSubgraph
+        *            If false, the dot graph will be created without subgraphs 
for
+        *            statement blocks.
+        * @return hop DAG in dot format
+        * @throws LanguageException
+        *             if error occurs
+        * @throws DMLRuntimeException
+        *             if error occurs
+        * @throws HopsException
+        *             if error occurs
+        */
+       public static String getHopDAG(MLContext mlCtx, Script script, 
ArrayList<Integer> lines, boolean performHOPRewrites,
+                       boolean withSubgraph) throws HopsException, 
DMLRuntimeException, LanguageException {
+               return getHopDAG(mlCtx, script, lines, null, 
performHOPRewrites, withSubgraph);
+       }
+
+       /**
+        * Get HOP DAG in dot format for a DML or PYDML Script.
+        *
+        * @param mlCtx
+        *            MLContext object.
+        * @param script
+        *            The DML or PYDML Script object to execute.
+        * @param lines
+        *            Only display the hops that have begin and end line number
+        *            equals to the given integers.
+        * @param newConf
+        *            Spark Configuration.
+        * @param performHOPRewrites
+        *            should perform static rewrites, perform
+        *            intra-/inter-procedural analysis to propagate size 
information
+        *            into functions and apply dynamic rewrites
+        * @param withSubgraph
+        *            If false, the dot graph will be created without subgraphs 
for
+        *            statement blocks.
+        * @return hop DAG in dot format
+        * @throws LanguageException
+        *             if error occurs
+        * @throws DMLRuntimeException
+        *             if error occurs
+        * @throws HopsException
+        *             if error occurs
+        */
+       public static String getHopDAG(MLContext mlCtx, Script script, 
ArrayList<Integer> lines, SparkConf newConf,
+                       boolean performHOPRewrites, boolean withSubgraph)
+                       throws HopsException, DMLRuntimeException, 
LanguageException {
+               SparkConf oldConf = 
mlCtx.getSparkSession().sparkContext().getConf();
+               SparkExecutionContext.SparkClusterConfig systemmlConf = 
SparkExecutionContext.getSparkClusterConfig();
+               long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory();
+               try {
+                       if (newConf != null) {
+                               systemmlConf.analyzeSparkConfiguation(newConf);
+                               
InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory"));
+                       }
+                       ScriptExecutor scriptExecutor = new ScriptExecutor();
+                       
scriptExecutor.setExecutionType(mlCtx.getExecutionType());
+                       scriptExecutor.setGPU(mlCtx.isGPU());
+                       scriptExecutor.setForceGPU(mlCtx.isForceGPU());
+                       scriptExecutor.setInit(mlCtx.isInitBeforeExecution());
+                       if (mlCtx.isInitBeforeExecution()) {
+                               mlCtx.setInitBeforeExecution(false);
+                       }
+                       
scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable());
+
+                       Long time = new Long((new Date()).getTime());
+                       if ((script.getName() == null) || 
(script.getName().equals(""))) {
+                               script.setName(time.toString());
+                       }
+
+                       mlCtx.setExecutionScript(script);
+                       scriptExecutor.compile(script, performHOPRewrites);
+                       Explain.reset();
+                       // To deal with potential Py4J issues
+                       lines = lines.size() == 1 && lines.get(0) == -1 ? new 
ArrayList<Integer>() : lines;
+                       return Explain.getHopDAG(scriptExecutor.dmlProgram, 
lines, withSubgraph);
+               } catch (RuntimeException e) {
+                       throw new MLContextException("Exception when compiling 
script", e);
+               } finally {
+                       if (newConf != null) {
+                               systemmlConf.analyzeSparkConfiguation(oldConf);
+                               
InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory);
+                       }
+               }
+       }
+
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index b08f5b9..9e4cfac 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -42,10 +42,13 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.ml.linalg.DenseVector;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
@@ -54,10 +57,12 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DoubleType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.mlcontext.MLContextConversionUtil;
 import org.apache.sysml.api.mlcontext.MLContextException;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -69,11 +74,14 @@ import 
org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.util.DataConverter;
 import org.junit.Assert;
 import org.junit.Test;
 
+import scala.Tuple1;
 import scala.Tuple2;
 import scala.Tuple3;
+import scala.Tuple4;
 import scala.collection.Iterator;
 import scala.collection.JavaConversions;
 import scala.collection.Seq;
@@ -2756,4 +2764,310 @@ public class MLContextTest extends MLContextTestBase {
                Assert.assertEquals(3, results.getLong("y"));
        }
 
+       @Test
+       public void testOutputDataFrameOfVectorsDML() {
+               System.out.println("MLContextTest - output DataFrame of vectors 
DML");
+
+               String s = "m=matrix('1 2 3 4',rows=2,cols=2);";
+               Script script = dml(s).out("m");
+               MLResults results = ml.execute(script);
+               Dataset<Row> df = results.getDataFrame("m", true);
+               Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
+
+               // verify column types
+               StructType schema = sortedDF.schema();
+               StructField[] fields = schema.fields();
+               StructField idColumn = fields[0];
+               StructField vectorColumn = fields[1];
+               Assert.assertTrue(idColumn.dataType() instanceof DoubleType);
+               Assert.assertTrue(vectorColumn.dataType() instanceof VectorUDT);
+
+               List<Row> list = sortedDF.collectAsList();
+
+               Row row1 = list.get(0);
+               Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+               Vector v1 = (DenseVector) row1.get(1);
+               double[] arr1 = v1.toArray();
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, arr1, 0.0);
+
+               Row row2 = list.get(1);
+               Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+               Vector v2 = (DenseVector) row2.get(1);
+               double[] arr2 = v2.toArray();
+               Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, arr2, 0.0);
+       }
+
+       @Test
+       public void testOutputDoubleArrayFromMatrixDML() {
+               System.out.println("MLContextTest - output double array from 
matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+               double[][] matrix = 
ml.execute(dml(s).out("M")).getMatrix("M").to2DDoubleArray();
+               Assert.assertEquals(1.0, matrix[0][0], 0);
+               Assert.assertEquals(2.0, matrix[0][1], 0);
+               Assert.assertEquals(3.0, matrix[1][0], 0);
+               Assert.assertEquals(4.0, matrix[1][1], 0);
+       }
+
+       @Test
+       public void testOutputDataFrameFromMatrixDML() {
+               System.out.println("MLContextTest - output DataFrame from 
matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+               Script script = dml(s).out("M");
+               Dataset<Row> df = ml.execute(script).getMatrix("M").toDF();
+               Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
+               List<Row> list = sortedDF.collectAsList();
+               Row row1 = list.get(0);
+               Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+               Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+               Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+               Row row2 = list.get(1);
+               Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+               Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+               Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameDoublesNoIDColumnFromMatrixDML() {
+               System.out.println("MLContextTest - output DataFrame of doubles 
with no ID column from matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+               Script script = dml(s).out("M");
+               Dataset<Row> df = 
ml.execute(script).getMatrix("M").toDFDoubleNoIDColumn();
+               List<Row> list = df.collectAsList();
+
+               Row row = list.get(0);
+               Assert.assertEquals(1.0, row.getDouble(0), 0.0);
+               Assert.assertEquals(2.0, row.getDouble(1), 0.0);
+               Assert.assertEquals(3.0, row.getDouble(2), 0.0);
+               Assert.assertEquals(4.0, row.getDouble(3), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameDoublesWithIDColumnFromMatrixDML() {
+               System.out.println("MLContextTest - output DataFrame of doubles 
with ID column from matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+               Script script = dml(s).out("M");
+               Dataset<Row> df = 
ml.execute(script).getMatrix("M").toDFDoubleWithIDColumn();
+               Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
+               List<Row> list = sortedDF.collectAsList();
+
+               Row row1 = list.get(0);
+               Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
+               Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
+               Assert.assertEquals(2.0, row1.getDouble(2), 0.0);
+
+               Row row2 = list.get(1);
+               Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
+               Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
+               Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameVectorsNoIDColumnFromMatrixDML() {
+               System.out.println("MLContextTest - output DataFrame of vectors 
with no ID column from matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+               Script script = dml(s).out("M");
+               Dataset<Row> df = 
ml.execute(script).getMatrix("M").toDFVectorNoIDColumn();
+               List<Row> list = df.collectAsList();
+
+               Row row = list.get(0);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, 
((Vector) row.get(0)).toArray(), 0.0);
+       }
+
+       @Test
+       public void testOutputDataFrameVectorsWithIDColumnFromMatrixDML() {
+               System.out.println("MLContextTest - output DataFrame of vectors 
with ID column from matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+               Script script = dml(s).out("M");
+               Dataset<Row> df = 
ml.execute(script).getMatrix("M").toDFVectorWithIDColumn();
+               List<Row> list = df.collectAsList();
+
+               Row row = list.get(0);
+               Assert.assertEquals(1.0, row.getDouble(0), 0.0);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, 
((Vector) row.get(1)).toArray(), 0.0);
+       }
+
+       @Test
+       public void testOutputJavaRDDStringCSVFromMatrixDML() {
+               System.out.println("MLContextTest - output Java RDD String CSV 
from matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+               Script script = dml(s).out("M");
+               JavaRDD<String> javaRDDStringCSV = 
ml.execute(script).getMatrix("M").toJavaRDDStringCSV();
+               List<String> lines = javaRDDStringCSV.collect();
+               Assert.assertEquals("1.0,2.0,3.0,4.0", lines.get(0));
+       }
+
+       @Test
+       public void testOutputJavaRDDStringIJVFromMatrixDML() {
+               System.out.println("MLContextTest - output Java RDD String IJV 
from matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+               Script script = dml(s).out("M");
+               MLResults results = ml.execute(script);
+               JavaRDD<String> javaRDDStringIJV = 
results.getJavaRDDStringIJV("M");
+               List<String> lines = javaRDDStringIJV.sortBy(row -> row, true, 
1).collect();
+               Assert.assertEquals("1 1 1.0", lines.get(0));
+               Assert.assertEquals("1 2 2.0", lines.get(1));
+               Assert.assertEquals("2 1 3.0", lines.get(2));
+               Assert.assertEquals("2 2 4.0", lines.get(3));
+       }
+
+       @Test
+       public void testOutputRDDStringCSVFromMatrixDML() {
+               System.out.println("MLContextTest - output RDD String CSV from 
matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=1, cols=4);";
+               Script script = dml(s).out("M");
+               RDD<String> rddStringCSV = 
ml.execute(script).getMatrix("M").toRDDStringCSV();
+               Iterator<String> iterator = rddStringCSV.toLocalIterator();
+               Assert.assertEquals("1.0,2.0,3.0,4.0", iterator.next());
+       }
+
+       @Test
+       public void testOutputRDDStringIJVFromMatrixDML() {
+               System.out.println("MLContextTest - output RDD String IJV from 
matrix DML");
+
+               String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
+               Script script = dml(s).out("M");
+               RDD<String> rddStringIJV = 
ml.execute(script).getMatrix("M").toRDDStringIJV();
+               String[] rows = (String[]) rddStringIJV.collect();
+               Arrays.sort(rows);
+               Assert.assertEquals("1 1 1.0", rows[0]);
+               Assert.assertEquals("1 2 2.0", rows[1]);
+               Assert.assertEquals("2 1 3.0", rows[2]);
+               Assert.assertEquals("2 2 4.0", rows[3]);
+       }
+
+       @Test
+       public void testMLContextVersionMessage() {
+               System.out.println("MLContextTest - version message");
+
+               String version = ml.version();
+               // not available until jar built
+               Assert.assertEquals(MLContextUtil.VERSION_NOT_AVAILABLE, 
version);
+       }
+
+       @Test
+       public void testMLContextBuildTimeMessage() {
+               System.out.println("MLContextTest - build time message");
+
+               String buildTime = ml.buildTime();
+               // not available until jar built
+               Assert.assertEquals(MLContextUtil.BUILD_TIME_NOT_AVAILABLE, 
buildTime);
+       }
+
+       @Test
+       public void testMLContextCreateAndClose() {
+               // MLContext created by the @BeforeClass method in 
MLContextTestBase
+               // MLContext closed by the @AfterClass method in 
MLContextTestBase
+               System.out.println("MLContextTest - create MLContext and close 
(without script execution)");
+       }
+
+       @Test
+       public void testDataFrameToBinaryBlocks() {
+               System.out.println("MLContextTest - DataFrame to binary 
blocks");
+
+               List<String> list = new ArrayList<String>();
+               list.add("1,2,3");
+               list.add("4,5,6");
+               list.add("7,8,9");
+               JavaRDD<String> javaRddString = sc.parallelize(list);
+
+               JavaRDD<Row> javaRddRow = javaRddString.map(new 
CommaSeparatedValueStringToDoubleArrayRow());
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", 
DataTypes.DoubleType, true));
+               fields.add(DataTypes.createStructField("C2", 
DataTypes.DoubleType, true));
+               fields.add(DataTypes.createStructField("C3", 
DataTypes.DoubleType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, 
schema);
+
+               JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = 
MLContextConversionUtil
+                               .dataFrameToMatrixBinaryBlocks(dataFrame);
+               Tuple2<MatrixIndexes, MatrixBlock> first = binaryBlocks.first();
+               MatrixBlock mb = first._2();
+               double[][] matrix = DataConverter.convertToDoubleMatrix(mb);
+               Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0 }, 
matrix[0], 0.0);
+               Assert.assertArrayEquals(new double[] { 4.0, 5.0, 6.0 }, 
matrix[1], 0.0);
+               Assert.assertArrayEquals(new double[] { 7.0, 8.0, 9.0 }, 
matrix[2], 0.0);
+       }
+
+       @Test
+       public void testGetTuple1DML() {
+               System.out.println("MLContextTest - Get Tuple1<Matrix> DML");
+               JavaRDD<String> javaRddString = sc
+                               .parallelize(Stream.of("1,2,3", "4,5,6", 
"7,8,9").collect(Collectors.toList()));
+               JavaRDD<Row> javaRddRow = javaRddString.map(new 
CommaSeparatedValueStringToDoubleArrayRow());
+               List<StructField> fields = new ArrayList<StructField>();
+               fields.add(DataTypes.createStructField("C1", 
DataTypes.DoubleType, true));
+               fields.add(DataTypes.createStructField("C2", 
DataTypes.DoubleType, true));
+               fields.add(DataTypes.createStructField("C3", 
DataTypes.DoubleType, true));
+               StructType schema = DataTypes.createStructType(fields);
+               Dataset<Row> df = spark.createDataFrame(javaRddRow, schema);
+
+               Script script = dml("N=M*2").in("M", df).out("N");
+               Tuple1<Matrix> tuple = ml.execute(script).getTuple("N");
+               double[][] n = tuple._1().to2DDoubleArray();
+               Assert.assertEquals(2.0, n[0][0], 0);
+               Assert.assertEquals(4.0, n[0][1], 0);
+               Assert.assertEquals(6.0, n[0][2], 0);
+               Assert.assertEquals(8.0, n[1][0], 0);
+               Assert.assertEquals(10.0, n[1][1], 0);
+               Assert.assertEquals(12.0, n[1][2], 0);
+               Assert.assertEquals(14.0, n[2][0], 0);
+               Assert.assertEquals(16.0, n[2][1], 0);
+               Assert.assertEquals(18.0, n[2][2], 0);
+       }
+
+       @Test
+       public void testGetTuple2DML() {
+               System.out.println("MLContextTest - Get Tuple2<Matrix,Double> 
DML");
+
+               double[][] m = new double[][] { { 1, 2 }, { 3, 4 } };
+
+               Script script = dml("N=M*2;s=sum(N)").in("M", m).out("N", "s");
+               Tuple2<Matrix, Double> tuple = ml.execute(script).getTuple("N", 
"s");
+               double[][] n = tuple._1().to2DDoubleArray();
+               double s = tuple._2();
+               Assert.assertArrayEquals(new double[] { 2, 4 }, n[0], 0.0);
+               Assert.assertArrayEquals(new double[] { 6, 8 }, n[1], 0.0);
+               Assert.assertEquals(20.0, s, 0.0);
+       }
+
+       @Test
+       public void testGetTuple3DML() {
+               System.out.println("MLContextTest - Get 
Tuple3<Long,Double,Boolean> DML");
+
+               Script script = dml("a=1+2;b=a+0.5;c=TRUE;").out("a", "b", "c");
+               Tuple3<Long, Double, Boolean> tuple = 
ml.execute(script).getTuple("a", "b", "c");
+               long a = tuple._1();
+               double b = tuple._2();
+               boolean c = tuple._3();
+               Assert.assertEquals(3, a);
+               Assert.assertEquals(3.5, b, 0.0);
+               Assert.assertEquals(true, c);
+       }
+
+       @Test
+       public void testGetTuple4DML() {
+               System.out.println("MLContextTest - Get 
Tuple4<Long,Double,Boolean,String> DML");
+
+               Script script = dml("a=1+2;b=a+0.5;c=TRUE;d=\"yes it's 
\"+c").out("a", "b", "c", "d");
+               Tuple4<Long, Double, Boolean, String> tuple = 
ml.execute(script).getTuple("a", "b", "c", "d");
+               long a = tuple._1();
+               double b = tuple._2();
+               boolean c = tuple._3();
+               String d = tuple._4();
+               Assert.assertEquals(3, a);
+               Assert.assertEquals(3.5, b, 0.0);
+               Assert.assertEquals(true, c);
+               Assert.assertEquals("yes it's TRUE", d);
+       }
+
 }

Reply via email to