Repository: systemml Updated Branches: refs/heads/master 8dbc93022 -> 912c65506
[MINOR] Increase MLContext test coverage. Create MLContext tests to test previously untested methods. Update MLContext and MLContextConversionUtil to avoid NPEs. Closes #649. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/912c6550 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/912c6550 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/912c6550 Branch: refs/heads/master Commit: 912c65506d626c8b0128ceb80744fde49efd4a1a Parents: 8dbc930 Author: Deron Eriksson <de...@apache.org> Authored: Fri Sep 1 16:02:55 2017 -0700 Committer: Deron Eriksson <de...@apache.org> Committed: Fri Sep 1 16:02:55 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/api/mlcontext/MLContext.java | 10 +- .../api/mlcontext/MLContextConversionUtil.java | 3 + .../sysml/api/mlcontext/MLContextUtil.java | 206 ++++++------ .../integration/mlcontext/MLContextTest.java | 314 +++++++++++++++++++ 4 files changed, 431 insertions(+), 102 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index 83eedb3..35720a5 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -55,7 +55,7 @@ public class MLContext { /** * Logger for MLContext */ - public static Logger log = Logger.getLogger(MLContext.class); + protected static Logger log = Logger.getLogger(MLContext.class); /** * SparkSession object. 
@@ -665,7 +665,9 @@ public class MLContext { // clear local status, but do not stop sc as it // may be used or stopped externally - executionScript.clearAll(); + if (executionScript != null) { + executionScript.clearAll(); + } resetConfig(); spark = null; } @@ -693,7 +695,7 @@ public class MLContext { */ public String version() { if (info() == null) { - return "Version not available"; + return MLContextUtil.VERSION_NOT_AVAILABLE; } return info().version(); } @@ -705,7 +707,7 @@ public class MLContext { */ public String buildTime() { if (info() == null) { - return "Build time not available"; + return MLContextUtil.BUILD_TIME_NOT_AVAILABLE; } return info().buildTime(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java index 5883127..3f12ace 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -482,6 +482,9 @@ public class MLContextConversionUtil { * the matrix metadata, if available */ public static void determineMatrixFormatIfNeeded(Dataset<Row> dataFrame, MatrixMetadata matrixMetadata) { + if (matrixMetadata == null) { + return; + } MatrixFormat matrixFormat = matrixMetadata.getMatrixFormat(); if (matrixFormat != null) { return; http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java index 51d38a5..03184e3 100644 --- 
a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java @@ -91,122 +91,32 @@ import org.w3c.dom.NodeList; * */ public final class MLContextUtil { - + /** - * Get HOP DAG in dot format for a DML or PYDML Script. - * - * @param mlCtx - * MLContext object. - * @param script - * The DML or PYDML Script object to execute. - * @param lines - * Only display the hops that have begin and end line number - * equals to the given integers. - * @param performHOPRewrites - * should perform static rewrites, perform - * intra-/inter-procedural analysis to propagate size information - * into functions and apply dynamic rewrites - * @param withSubgraph - * If false, the dot graph will be created without subgraphs for - * statement blocks. - * @return hop DAG in dot format - * @throws LanguageException - * if error occurs - * @throws DMLRuntimeException - * if error occurs - * @throws HopsException - * if error occurs + * Version not available message. */ - public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, - boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException, - LanguageException { - return getHopDAG(mlCtx, script, lines, null, performHOPRewrites, withSubgraph); - } + public static final String VERSION_NOT_AVAILABLE = "Version not available"; /** - * Get HOP DAG in dot format for a DML or PYDML Script. - * - * @param mlCtx - * MLContext object. - * @param script - * The DML or PYDML Script object to execute. - * @param lines - * Only display the hops that have begin and end line number - * equals to the given integers. - * @param newConf - * Spark Configuration. 
- * @param performHOPRewrites - * should perform static rewrites, perform - * intra-/inter-procedural analysis to propagate size information - * into functions and apply dynamic rewrites - * @param withSubgraph - * If false, the dot graph will be created without subgraphs for - * statement blocks. - * @return hop DAG in dot format - * @throws LanguageException - * if error occurs - * @throws DMLRuntimeException - * if error occurs - * @throws HopsException - * if error occurs + * Build time not available message. */ - public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf, - boolean performHOPRewrites, boolean withSubgraph) throws HopsException, DMLRuntimeException, - LanguageException { - SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf(); - SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig(); - long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory(); - try { - if (newConf != null) { - systemmlConf.analyzeSparkConfiguation(newConf); - InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory")); - } - ScriptExecutor scriptExecutor = new ScriptExecutor(); - scriptExecutor.setExecutionType(mlCtx.getExecutionType()); - scriptExecutor.setGPU(mlCtx.isGPU()); - scriptExecutor.setForceGPU(mlCtx.isForceGPU()); - scriptExecutor.setInit(mlCtx.isInitBeforeExecution()); - if (mlCtx.isInitBeforeExecution()) { - mlCtx.setInitBeforeExecution(false); - } - scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable()); - - Long time = new Long((new Date()).getTime()); - if ((script.getName() == null) || (script.getName().equals(""))) { - script.setName(time.toString()); - } - - mlCtx.setExecutionScript(script); - scriptExecutor.compile(script, performHOPRewrites); - Explain.reset(); - // To deal with potential Py4J issues - lines = lines.size() == 1 && lines.get(0) == -1 ? 
new ArrayList<Integer>() : lines; - return Explain.getHopDAG(scriptExecutor.dmlProgram, lines, withSubgraph); - } catch (RuntimeException e) { - throw new MLContextException("Exception when compiling script", e); - } finally { - if (newConf != null) { - systemmlConf.analyzeSparkConfiguation(oldConf); - InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory); - } - } - } + public static final String BUILD_TIME_NOT_AVAILABLE = "Build time not available"; /** - * Basic data types supported by the MLContext API + * Basic data types supported by the MLContext API. */ @SuppressWarnings("rawtypes") public static final Class[] BASIC_DATA_TYPES = { Integer.class, Boolean.class, Double.class, String.class }; /** - * Complex data types supported by the MLContext API + * Complex data types supported by the MLContext API. */ @SuppressWarnings("rawtypes") public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, RDD.class, Dataset.class, Matrix.class, Frame.class, (new double[][] {}).getClass(), MatrixBlock.class, URL.class }; /** - * All data types supported by the MLContext API + * All data types supported by the MLContext API. */ @SuppressWarnings("rawtypes") public static final Class[] ALL_SUPPORTED_DATA_TYPES = (Class[]) ArrayUtils.addAll(BASIC_DATA_TYPES, @@ -1252,4 +1162,104 @@ public final class MLContextUtil { } } } + + /** + * Get HOP DAG in dot format for a DML or PYDML Script. + * + * @param mlCtx + * MLContext object. + * @param script + * The DML or PYDML Script object to execute. + * @param lines + * Only display the hops that have begin and end line number + * equals to the given integers. + * @param performHOPRewrites + * should perform static rewrites, perform + * intra-/inter-procedural analysis to propagate size information + * into functions and apply dynamic rewrites + * @param withSubgraph + * If false, the dot graph will be created without subgraphs for + * statement blocks. 
+ * @return hop DAG in dot format + * @throws LanguageException + * if error occurs + * @throws DMLRuntimeException + * if error occurs + * @throws HopsException + * if error occurs + */ + public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, boolean performHOPRewrites, + boolean withSubgraph) throws HopsException, DMLRuntimeException, LanguageException { + return getHopDAG(mlCtx, script, lines, null, performHOPRewrites, withSubgraph); + } + + /** + * Get HOP DAG in dot format for a DML or PYDML Script. + * + * @param mlCtx + * MLContext object. + * @param script + * The DML or PYDML Script object to execute. + * @param lines + * Only display the hops that have begin and end line number + * equals to the given integers. + * @param newConf + * Spark Configuration. + * @param performHOPRewrites + * should perform static rewrites, perform + * intra-/inter-procedural analysis to propagate size information + * into functions and apply dynamic rewrites + * @param withSubgraph + * If false, the dot graph will be created without subgraphs for + * statement blocks. 
+ * @return hop DAG in dot format + * @throws LanguageException + * if error occurs + * @throws DMLRuntimeException + * if error occurs + * @throws HopsException + * if error occurs + */ + public static String getHopDAG(MLContext mlCtx, Script script, ArrayList<Integer> lines, SparkConf newConf, + boolean performHOPRewrites, boolean withSubgraph) + throws HopsException, DMLRuntimeException, LanguageException { + SparkConf oldConf = mlCtx.getSparkSession().sparkContext().getConf(); + SparkExecutionContext.SparkClusterConfig systemmlConf = SparkExecutionContext.getSparkClusterConfig(); + long oldMaxMemory = InfrastructureAnalyzer.getLocalMaxMemory(); + try { + if (newConf != null) { + systemmlConf.analyzeSparkConfiguation(newConf); + InfrastructureAnalyzer.setLocalMaxMemory(newConf.getSizeAsBytes("spark.driver.memory")); + } + ScriptExecutor scriptExecutor = new ScriptExecutor(); + scriptExecutor.setExecutionType(mlCtx.getExecutionType()); + scriptExecutor.setGPU(mlCtx.isGPU()); + scriptExecutor.setForceGPU(mlCtx.isForceGPU()); + scriptExecutor.setInit(mlCtx.isInitBeforeExecution()); + if (mlCtx.isInitBeforeExecution()) { + mlCtx.setInitBeforeExecution(false); + } + scriptExecutor.setMaintainSymbolTable(mlCtx.isMaintainSymbolTable()); + + Long time = new Long((new Date()).getTime()); + if ((script.getName() == null) || (script.getName().equals(""))) { + script.setName(time.toString()); + } + + mlCtx.setExecutionScript(script); + scriptExecutor.compile(script, performHOPRewrites); + Explain.reset(); + // To deal with potential Py4J issues + lines = lines.size() == 1 && lines.get(0) == -1 ? 
new ArrayList<Integer>() : lines; + return Explain.getHopDAG(scriptExecutor.dmlProgram, lines, withSubgraph); + } catch (RuntimeException e) { + throw new MLContextException("Exception when compiling script", e); + } finally { + if (newConf != null) { + systemmlConf.analyzeSparkConfiguation(oldConf); + InfrastructureAnalyzer.setLocalMaxMemory(oldMaxMemory); + } + } + } + } http://git-wip-us.apache.org/repos/asf/systemml/blob/912c6550/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index b08f5b9..9e4cfac 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -42,10 +42,13 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.linalg.DenseVector; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; @@ -54,10 +57,12 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.DoubleType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.sysml.api.mlcontext.MLContextConversionUtil; import org.apache.sysml.api.mlcontext.MLContextException; +import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; 
import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.MatrixFormat; @@ -69,11 +74,14 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.util.DataConverter; import org.junit.Assert; import org.junit.Test; +import scala.Tuple1; import scala.Tuple2; import scala.Tuple3; +import scala.Tuple4; import scala.collection.Iterator; import scala.collection.JavaConversions; import scala.collection.Seq; @@ -2756,4 +2764,310 @@ public class MLContextTest extends MLContextTestBase { Assert.assertEquals(3, results.getLong("y")); } + @Test + public void testOutputDataFrameOfVectorsDML() { + System.out.println("MLContextTest - output DataFrame of vectors DML"); + + String s = "m=matrix('1 2 3 4',rows=2,cols=2);"; + Script script = dml(s).out("m"); + MLResults results = ml.execute(script); + Dataset<Row> df = results.getDataFrame("m", true); + Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN); + + // verify column types + StructType schema = sortedDF.schema(); + StructField[] fields = schema.fields(); + StructField idColumn = fields[0]; + StructField vectorColumn = fields[1]; + Assert.assertTrue(idColumn.dataType() instanceof DoubleType); + Assert.assertTrue(vectorColumn.dataType() instanceof VectorUDT); + + List<Row> list = sortedDF.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Vector v1 = (DenseVector) row1.get(1); + double[] arr1 = v1.toArray(); + Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, arr1, 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Vector v2 = (DenseVector) row2.get(1); + double[] arr2 = v2.toArray(); + Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, arr2, 0.0); + } + + @Test + public 
void testOutputDoubleArrayFromMatrixDML() { + System.out.println("MLContextTest - output double array from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + double[][] matrix = ml.execute(dml(s).out("M")).getMatrix("M").to2DDoubleArray(); + Assert.assertEquals(1.0, matrix[0][0], 0); + Assert.assertEquals(2.0, matrix[0][1], 0); + Assert.assertEquals(3.0, matrix[1][0], 0); + Assert.assertEquals(4.0, matrix[1][1], 0); + } + + @Test + public void testOutputDataFrameFromMatrixDML() { + System.out.println("MLContextTest - output DataFrame from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + Dataset<Row> df = ml.execute(script).getMatrix("M").toDF(); + Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN); + List<Row> list = sortedDF.collectAsList(); + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertEquals(1.0, row1.getDouble(1), 0.0); + Assert.assertEquals(2.0, row1.getDouble(2), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Assert.assertEquals(3.0, row2.getDouble(1), 0.0); + Assert.assertEquals(4.0, row2.getDouble(2), 0.0); + } + + @Test + public void testOutputDataFrameDoublesNoIDColumnFromMatrixDML() { + System.out.println("MLContextTest - output DataFrame of doubles with no ID column from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=1, cols=4);"; + Script script = dml(s).out("M"); + Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleNoIDColumn(); + List<Row> list = df.collectAsList(); + + Row row = list.get(0); + Assert.assertEquals(1.0, row.getDouble(0), 0.0); + Assert.assertEquals(2.0, row.getDouble(1), 0.0); + Assert.assertEquals(3.0, row.getDouble(2), 0.0); + Assert.assertEquals(4.0, row.getDouble(3), 0.0); + } + + @Test + public void testOutputDataFrameDoublesWithIDColumnFromMatrixDML() { + System.out.println("MLContextTest - output DataFrame of doubles with ID 
column from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleWithIDColumn(); + Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN); + List<Row> list = sortedDF.collectAsList(); + + Row row1 = list.get(0); + Assert.assertEquals(1.0, row1.getDouble(0), 0.0); + Assert.assertEquals(1.0, row1.getDouble(1), 0.0); + Assert.assertEquals(2.0, row1.getDouble(2), 0.0); + + Row row2 = list.get(1); + Assert.assertEquals(2.0, row2.getDouble(0), 0.0); + Assert.assertEquals(3.0, row2.getDouble(1), 0.0); + Assert.assertEquals(4.0, row2.getDouble(2), 0.0); + } + + @Test + public void testOutputDataFrameVectorsNoIDColumnFromMatrixDML() { + System.out.println("MLContextTest - output DataFrame of vectors with no ID column from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=1, cols=4);"; + Script script = dml(s).out("M"); + Dataset<Row> df = ml.execute(script).getMatrix("M").toDFVectorNoIDColumn(); + List<Row> list = df.collectAsList(); + + Row row = list.get(0); + Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, ((Vector) row.get(0)).toArray(), 0.0); + } + + @Test + public void testOutputDataFrameVectorsWithIDColumnFromMatrixDML() { + System.out.println("MLContextTest - output DataFrame of vectors with ID column from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=1, cols=4);"; + Script script = dml(s).out("M"); + Dataset<Row> df = ml.execute(script).getMatrix("M").toDFVectorWithIDColumn(); + List<Row> list = df.collectAsList(); + + Row row = list.get(0); + Assert.assertEquals(1.0, row.getDouble(0), 0.0); + Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0, 4.0 }, ((Vector) row.get(1)).toArray(), 0.0); + } + + @Test + public void testOutputJavaRDDStringCSVFromMatrixDML() { + System.out.println("MLContextTest - output Java RDD String CSV from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=1, cols=4);"; + Script 
script = dml(s).out("M"); + JavaRDD<String> javaRDDStringCSV = ml.execute(script).getMatrix("M").toJavaRDDStringCSV(); + List<String> lines = javaRDDStringCSV.collect(); + Assert.assertEquals("1.0,2.0,3.0,4.0", lines.get(0)); + } + + @Test + public void testOutputJavaRDDStringIJVFromMatrixDML() { + System.out.println("MLContextTest - output Java RDD String IJV from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + MLResults results = ml.execute(script); + JavaRDD<String> javaRDDStringIJV = results.getJavaRDDStringIJV("M"); + List<String> lines = javaRDDStringIJV.sortBy(row -> row, true, 1).collect(); + Assert.assertEquals("1 1 1.0", lines.get(0)); + Assert.assertEquals("1 2 2.0", lines.get(1)); + Assert.assertEquals("2 1 3.0", lines.get(2)); + Assert.assertEquals("2 2 4.0", lines.get(3)); + } + + @Test + public void testOutputRDDStringCSVFromMatrixDML() { + System.out.println("MLContextTest - output RDD String CSV from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=1, cols=4);"; + Script script = dml(s).out("M"); + RDD<String> rddStringCSV = ml.execute(script).getMatrix("M").toRDDStringCSV(); + Iterator<String> iterator = rddStringCSV.toLocalIterator(); + Assert.assertEquals("1.0,2.0,3.0,4.0", iterator.next()); + } + + @Test + public void testOutputRDDStringIJVFromMatrixDML() { + System.out.println("MLContextTest - output RDD String IJV from matrix DML"); + + String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; + Script script = dml(s).out("M"); + RDD<String> rddStringIJV = ml.execute(script).getMatrix("M").toRDDStringIJV(); + String[] rows = (String[]) rddStringIJV.collect(); + Arrays.sort(rows); + Assert.assertEquals("1 1 1.0", rows[0]); + Assert.assertEquals("1 2 2.0", rows[1]); + Assert.assertEquals("2 1 3.0", rows[2]); + Assert.assertEquals("2 2 4.0", rows[3]); + } + + @Test + public void testMLContextVersionMessage() { + System.out.println("MLContextTest - version message"); + + String 
version = ml.version(); + // not available until jar built + Assert.assertEquals(MLContextUtil.VERSION_NOT_AVAILABLE, version); + } + + @Test + public void testMLContextBuildTimeMessage() { + System.out.println("MLContextTest - build time message"); + + String buildTime = ml.buildTime(); + // not available until jar built + Assert.assertEquals(MLContextUtil.BUILD_TIME_NOT_AVAILABLE, buildTime); + } + + @Test + public void testMLContextCreateAndClose() { + // MLContext created by the @BeforeClass method in MLContextTestBase + // MLContext closed by the @AfterClass method in MLContextTestBase + System.out.println("MLContextTest - create MLContext and close (without script execution)"); + } + + @Test + public void testDataFrameToBinaryBlocks() { + System.out.println("MLContextTest - DataFrame to binary blocks"); + + List<String> list = new ArrayList<String>(); + list.add("1,2,3"); + list.add("4,5,6"); + list.add("7,8,9"); + JavaRDD<String> javaRddString = sc.parallelize(list); + + JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); + StructType schema = DataTypes.createStructType(fields); + Dataset<Row> dataFrame = spark.createDataFrame(javaRddRow, schema); + + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = MLContextConversionUtil + .dataFrameToMatrixBinaryBlocks(dataFrame); + Tuple2<MatrixIndexes, MatrixBlock> first = binaryBlocks.first(); + MatrixBlock mb = first._2(); + double[][] matrix = DataConverter.convertToDoubleMatrix(mb); + Assert.assertArrayEquals(new double[] { 1.0, 2.0, 3.0 }, matrix[0], 0.0); + Assert.assertArrayEquals(new double[] { 4.0, 5.0, 6.0 }, matrix[1], 0.0); + Assert.assertArrayEquals(new double[] { 7.0, 8.0, 
9.0 }, matrix[2], 0.0); + } + + @Test + public void testGetTuple1DML() { + System.out.println("MLContextTest - Get Tuple1<Matrix> DML"); + JavaRDD<String> javaRddString = sc + .parallelize(Stream.of("1,2,3", "4,5,6", "7,8,9").collect(Collectors.toList())); + JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); + List<StructField> fields = new ArrayList<StructField>(); + fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); + fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); + StructType schema = DataTypes.createStructType(fields); + Dataset<Row> df = spark.createDataFrame(javaRddRow, schema); + + Script script = dml("N=M*2").in("M", df).out("N"); + Tuple1<Matrix> tuple = ml.execute(script).getTuple("N"); + double[][] n = tuple._1().to2DDoubleArray(); + Assert.assertEquals(2.0, n[0][0], 0); + Assert.assertEquals(4.0, n[0][1], 0); + Assert.assertEquals(6.0, n[0][2], 0); + Assert.assertEquals(8.0, n[1][0], 0); + Assert.assertEquals(10.0, n[1][1], 0); + Assert.assertEquals(12.0, n[1][2], 0); + Assert.assertEquals(14.0, n[2][0], 0); + Assert.assertEquals(16.0, n[2][1], 0); + Assert.assertEquals(18.0, n[2][2], 0); + } + + @Test + public void testGetTuple2DML() { + System.out.println("MLContextTest - Get Tuple2<Matrix,Double> DML"); + + double[][] m = new double[][] { { 1, 2 }, { 3, 4 } }; + + Script script = dml("N=M*2;s=sum(N)").in("M", m).out("N", "s"); + Tuple2<Matrix, Double> tuple = ml.execute(script).getTuple("N", "s"); + double[][] n = tuple._1().to2DDoubleArray(); + double s = tuple._2(); + Assert.assertArrayEquals(new double[] { 2, 4 }, n[0], 0.0); + Assert.assertArrayEquals(new double[] { 6, 8 }, n[1], 0.0); + Assert.assertEquals(20.0, s, 0.0); + } + + @Test + public void testGetTuple3DML() { + System.out.println("MLContextTest - Get Tuple3<Long,Double,Boolean> DML"); + + Script script = 
dml("a=1+2;b=a+0.5;c=TRUE;").out("a", "b", "c"); + Tuple3<Long, Double, Boolean> tuple = ml.execute(script).getTuple("a", "b", "c"); + long a = tuple._1(); + double b = tuple._2(); + boolean c = tuple._3(); + Assert.assertEquals(3, a); + Assert.assertEquals(3.5, b, 0.0); + Assert.assertEquals(true, c); + } + + @Test + public void testGetTuple4DML() { + System.out.println("MLContextTest - Get Tuple4<Long,Double,Boolean,String> DML"); + + Script script = dml("a=1+2;b=a+0.5;c=TRUE;d=\"yes it's \"+c").out("a", "b", "c", "d"); + Tuple4<Long, Double, Boolean, String> tuple = ml.execute(script).getTuple("a", "b", "c", "d"); + long a = tuple._1(); + double b = tuple._2(); + boolean c = tuple._3(); + String d = tuple._4(); + Assert.assertEquals(3, a); + Assert.assertEquals(3.5, b, 0.0); + Assert.assertEquals(true, c); + Assert.assertEquals("yes it's TRUE", d); + } + }