Repository: systemml Updated Branches: refs/heads/master 41de8dcdc -> 7907c0ea5
[SYSTEMML-1325] Harmonize Compilation Execution Pipelines and Add GPU Support to JMLC This PR adds support for compilation and execution of GPU-enabled scripts in JMLC and harmonizes the pipeline used to compile and execute DML programs across the JMLC, MLContext, and DMLScript APIs. Specifically, the following changes were made: 1. All three APIs now call ScriptExecutorUtils.compileRuntimeProgram to compile DML scripts. The original logic in MLContext and JMLC for pinning inputs and persisting outputs has been preserved. 2. All three APIs now use ScriptExecutorUtils.executeRuntimeProgram to execute the compiled program. Previously, JMLC called the Script.execute method directly. 3. jmlc.Connection.prepareScript now supports compiling a script to use GPU. Note that, following #832, the issue noted in #830 has been resolved. 4. A PreparedScript is now statically assigned a GPU context when it is compiled and instantiated. This has potential performance implications because it means that a PreparedScript must be executed on a specific GPU. However, it reduces the overhead of creating a GPU context each time a script is executed and ensures that a user cannot compile a script to use GPU and then forget to assign a GPU context when the script is run. 5. Per (3), I have added a unit test which compiles and executes a GPU-enabled script in JMLC both with and without pinned data and just asserts that no errors occur. Closes #836. 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7907c0ea Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7907c0ea Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7907c0ea Branch: refs/heads/master Commit: 7907c0ea5109e9b33465b7d7a2ac2bf0c42ab380 Parents: 41de8dc Author: Anthony Thomas <ahtho...@eng.ucsd.edu> Authored: Sun Oct 14 09:24:33 2018 -0700 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Sun Oct 14 09:25:55 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/api/DMLScript.java | 172 ++++------- .../apache/sysml/api/ScriptExecutorUtils.java | 290 ++++++++++++++++--- .../org/apache/sysml/api/jmlc/Connection.java | 158 +++++----- .../apache/sysml/api/jmlc/PreparedScript.java | 43 +-- .../sysml/api/mlcontext/ScriptExecutor.java | 100 +++---- .../apache/sysml/conf/ConfigurationManager.java | 38 ++- .../controlprogram/LocalVariableMap.java | 11 +- .../runtime/controlprogram/ProgramBlock.java | 4 +- .../controlprogram/caching/CacheableData.java | 4 +- .../context/ExecutionContext.java | 4 +- .../context/SparkExecutionContext.java | 5 +- .../gpu/context/GPUContextPool.java | 31 +- .../java/org/apache/sysml/utils/Statistics.java | 6 +- .../org/apache/sysml/test/gpu/GPUTests.java | 13 +- .../org/apache/sysml/test/gpu/JMLCTests.java | 126 ++++++++ .../jmlc/JMLCParfor2ForCompileTest.java | 7 +- 16 files changed, 661 insertions(+), 351 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/api/DMLScript.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java index 9976adc..16d8986 100644 --- a/src/main/java/org/apache/sysml/api/DMLScript.java +++ 
b/src/main/java/org/apache/sysml/api/DMLScript.java @@ -32,6 +32,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Scanner; @@ -48,6 +49,7 @@ import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.util.GenericOptionsParser; import org.apache.log4j.Level; import org.apache.log4j.Logger; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.mlcontext.ScriptType; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.ConfigurationManager; @@ -65,13 +67,14 @@ import org.apache.sysml.parser.ParserFactory; import org.apache.sysml.parser.ParserWrapper; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.io.IOUtilFunctions; import org.apache.sysml.runtime.matrix.CleanupMR; @@ -79,13 +82,9 @@ import org.apache.sysml.runtime.matrix.mapred.MRConfigurationNames; import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration; import org.apache.sysml.runtime.util.LocalFileUtils; import org.apache.sysml.runtime.util.MapReduceTool; -import org.apache.sysml.utils.Explain; import 
org.apache.sysml.utils.NativeHelper; -import org.apache.sysml.utils.Explain.ExplainCounts; import org.apache.sysml.utils.Explain.ExplainType; import org.apache.sysml.utils.Statistics; -import org.apache.sysml.yarn.DMLAppMasterUtils; -import org.apache.sysml.yarn.DMLYarnClientProxy; public class DMLScript @@ -111,10 +110,9 @@ public class DMLScript // ARC, // https://dbs.uni-leipzig.de/file/ARC.pdf // LOOP_AWARE // different policies for operations in for/while/parfor loop vs out-side the loop } - - // TODO: Anthony - public static boolean JMLC_MEM_STATISTICS = false; // whether to gather memory use stats in JMLC - + + public static RUNTIME_PLATFORM rtplatform = DMLOptions.defaultOptions.execMode; // the execution mode + // debug mode is deprecated and will be removed soon. public static boolean ENABLE_DEBUG_MODE = DMLOptions.defaultOptions.debug; // debug mode @@ -141,7 +139,7 @@ public class DMLScript public static boolean VALIDATOR_IGNORE_ISSUES = false; public static String _uuid = IDHandler.createDistributedUniqueID(); - private static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); + static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); /////////////////////////////// // public external interface @@ -195,6 +193,7 @@ public class DMLScript } } + /** * Single entry point for all public invocation alternatives (e.g., * main, executeScript, JaqlUdf etc) @@ -215,8 +214,7 @@ public class DMLScript { dmlOptions = DMLOptions.parseCLArguments(args); ConfigurationManager.setGlobalOptions(dmlOptions); - - JMLC_MEM_STATISTICS = dmlOptions.memStats; + EXPLAIN = dmlOptions.explainType; ENABLE_DEBUG_MODE = dmlOptions.debug; SCRIPT_TYPE = dmlOptions.scriptType; @@ -378,104 +376,56 @@ public class DMLScript // (core compilation and execute) //////// - /** - * The running body of DMLScript execution. 
This method should be called after execution properties have been correctly set, - * and customized parameters have been put into _argVals - * - * @param dmlScriptStr DML script string - * @param fnameOptConfig configuration file - * @param argVals map of argument values - * @param allArgs arguments - * @param scriptType type of script (DML or PyDML) - * @throws IOException if IOException occurs - */ - private static void execute(String dmlScriptStr, String fnameOptConfig, Map<String,String> argVals, String[] allArgs, ScriptType scriptType) - throws IOException - { - SCRIPT_TYPE = scriptType; - - //print basic time and environment info - printStartExecInfo( dmlScriptStr ); - - //Step 1: parse configuration files & write any configuration specific global variables - DMLConfig dmlconf = DMLConfig.readConfigurationFile(fnameOptConfig); - ConfigurationManager.setGlobalConfig(dmlconf); - CompilerConfig cconf = OptimizerUtils.constructCompilerConfig(dmlconf); - ConfigurationManager.setGlobalConfig(cconf); - LOG.debug("\nDML config: \n" + dmlconf.getConfigInfo()); - - setGlobalFlags(dmlconf); - - //Step 2: set local/remote memory if requested (for compile in AM context) - if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ - DMLAppMasterUtils.setupConfigRemoteMaxMemory(dmlconf); - } - - //Step 3: parse dml script - Statistics.startCompileTimer(); - ParserWrapper parser = ParserFactory.createParser(scriptType); - DMLProgram prog = parser.parse(DML_FILE_PATH_ANTLR_PARSER, dmlScriptStr, argVals); - - //Step 4: construct HOP DAGs (incl LVA, validate, and setup) - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - dmlt.constructHops(prog); - - //init working directories (before usage by following compilation steps) - initHadoopExecution( dmlconf ); - - //Step 5: rewrite HOP DAGs (incl IPA and memory estimates) - dmlt.rewriteHopsDAG(prog); - - //Step 6: construct lops (incl exec type and op selection) - 
dmlt.constructLops(prog); - - if (LOG.isDebugEnabled()) { - LOG.debug("\n********************** LOPS DAG *******************"); - dmlt.printLops(prog); - dmlt.resetLopsDAGVisitStatus(prog); - } - - //Step 7: generate runtime program, incl codegen - Program rtprog = dmlt.getRuntimeProgram(prog, dmlconf); - - //launch SystemML appmaster (if requested and not already in launched AM) - if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ - if( !isActiveAM() && DMLYarnClientProxy.launchDMLYarnAppmaster(dmlScriptStr, dmlconf, allArgs, rtprog) ) - return; //if AM launch unsuccessful, fall back to normal execute - if( isActiveAM() ) //in AM context (not failed AM launch) - DMLAppMasterUtils.setupProgramMappingRemoteMaxMemory(rtprog); - } - - //Step 9: prepare statistics [and optional explain output] - //count number compiled MR jobs / SP instructions - ExplainCounts counts = Explain.countDistributedOperations(rtprog); - Statistics.resetNoOfCompiledJobs( counts.numJobs ); - - //explain plan of program (hops or runtime) - if( EXPLAIN != ExplainType.NONE ) - System.out.println(Explain.display(prog, rtprog, EXPLAIN, counts)); - - Statistics.stopCompileTimer(); - - //double costs = CostEstimationWrapper.getTimeEstimate(rtprog, ExecutionContextFactory.createContext()); - //System.out.println("Estimated costs: "+costs); - - //Step 10: execute runtime program - ExecutionContext ec = null; - try { - ec = ExecutionContextFactory.createContext(rtprog); - ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, dmlconf, ConfigurationManager.isStatistics() ? ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters() : 0, null); - } - finally { - if(ec != null && ec instanceof SparkExecutionContext) - ((SparkExecutionContext) ec).close(); - LOG.info("END DML run " + getDateTime() ); - //cleanup scratch_space and all working dirs - cleanupHadoopExecution( dmlconf ); - } - } + /** + * The running body of DMLScript execution. 
This method should be called after execution properties have been correctly set, + * and customized parameters have been put into _argVals + * + * @param dmlScriptStr DML script string + * @param fnameOptConfig configuration file + * @param argVals map of argument values + * @param allArgs arguments + * @param scriptType type of script (DML or PyDML) + * @throws IOException if IOException occurs + */ + private static void execute(String dmlScriptStr, String fnameOptConfig, Map<String,String> argVals, String[] allArgs, ScriptType scriptType) + throws IOException + { + SCRIPT_TYPE = scriptType; + + //print basic time and environment info + printStartExecInfo( dmlScriptStr ); + + //Step 1: parse configuration files & write any configuration specific global variables + DMLConfig dmlconf = DMLConfig.readConfigurationFile(fnameOptConfig); + ConfigurationManager.setGlobalConfig(dmlconf); + CompilerConfig cconf = OptimizerUtils.constructCompilerConfig(dmlconf); + ConfigurationManager.setGlobalConfig(cconf); + LOG.debug("\nDML config: \n" + dmlconf.getConfigInfo()); + + setGlobalFlags(dmlconf); + Program rtprog = ScriptExecutorUtils.compileRuntimeProgram(dmlScriptStr, argVals, allArgs, + scriptType, dmlconf, SystemMLAPI.DMLScript); + List<GPUContext> gCtxs = ConfigurationManager.getDMLOptions().gpu ? GPUContextPool.getAllGPUContexts() : null; + + //double costs = CostEstimationWrapper.getTimeEstimate(rtprog, ExecutionContextFactory.createContext()); + //System.out.println("Estimated costs: "+costs); + + //Step 10: execute runtime program + ExecutionContext ec = null; + try { + ec = ScriptExecutorUtils.executeRuntimeProgram( + rtprog, dmlconf, ConfigurationManager.isStatistics() ? 
+ ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters() : 0, + new LocalVariableMap(), null, SystemMLAPI.DMLScript, gCtxs); + } + finally { + if(ec != null && ec instanceof SparkExecutionContext) + ((SparkExecutionContext) ec).close(); + LOG.info("END DML run " + getDateTime() ); + //cleanup scratch_space and all working dirs + cleanupHadoopExecution( dmlconf ); + } + } /** * Sets the global flags in DMLScript based on user provided configuration @@ -706,4 +656,4 @@ public class DMLScript throw new DMLException("Failed to run SystemML workspace cleanup.", ex); } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java index 9956518..b1b5735 100644 --- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java +++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,74 +19,271 @@ package org.apache.sysml.api; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; -import java.util.Set; +import java.util.Map; -import org.apache.sysml.api.mlcontext.ScriptExecutor; +import org.apache.sysml.api.jmlc.JMLCUtils; +import org.apache.sysml.api.mlcontext.MLContextUtil; +import org.apache.sysml.api.mlcontext.ScriptType; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.hops.codegen.SpoofCompiler; -import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.hops.rewrite.ProgramRewriter; +import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DMLTranslator; +import org.apache.sysml.parser.LanguageException; +import org.apache.sysml.parser.ParseException; +import org.apache.sysml.parser.ParserFactory; +import org.apache.sysml.parser.ParserWrapper; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; -import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.instructions.gpu.context.GPUObject; +import org.apache.sysml.runtime.util.UtilFunctions; +import 
org.apache.sysml.utils.Explain; import org.apache.sysml.utils.Statistics; +import org.apache.sysml.utils.Explain.ExplainCounts; +import org.apache.sysml.utils.Explain.ExplainType; +import org.apache.sysml.yarn.DMLAppMasterUtils; +import org.apache.sysml.yarn.DMLYarnClientProxy; +import org.apache.sysml.runtime.DMLRuntimeException; public class ScriptExecutorUtils { + public static final boolean IS_JCUDA_AVAILABLE; + static { + // Early detection of JCuda libraries avoids synchronization overhead for common JMLC scenario: + // i.e. CPU-only multi-threaded execution + boolean isJCudaAvailable = false; + try { + Class.forName("jcuda.Pointer"); + isJCudaAvailable = true; + } + catch (ClassNotFoundException e) { } + IS_JCUDA_AVAILABLE = isJCudaAvailable; + } + + public static enum SystemMLAPI { + DMLScript, + MLContext, + JMLC + } + + public static Program compileRuntimeProgram(String script, Map<String,String> nsscripts, Map<String, String> args, + String[] inputs, String[] outputs, ScriptType scriptType, DMLConfig dmlconf, SystemMLAPI api) { + return compileRuntimeProgram(script, nsscripts, args, null, null, inputs, outputs, + scriptType, dmlconf, api, true, false, false); + } + + public static Program compileRuntimeProgram(String script, Map<String, String> args, String[] allArgs, + ScriptType scriptType, DMLConfig dmlconf, SystemMLAPI api) { + return compileRuntimeProgram(script, Collections.emptyMap(), args, allArgs, null, null, null, + scriptType, dmlconf, api, true, false, false); + } + /** - * Execute the runtime program. This involves execution of the program - * blocks that make up the runtime program and may involve dynamic - * recompilation. 
- * - * @param se - * script executor - * @param statisticsMaxHeavyHitters - * maximum number of statistics to print + * Compile a runtime program + * + * @param script string representing of the DML or PyDML script + * @param nsscripts map (name, script) of the DML or PyDML namespace scripts + * @param args map of input parameters ($) and their values + * @param allArgs commandline arguments + * @param symbolTable symbol table associated with MLContext + * @param inputs string array of input variables to register + * @param outputs string array of output variables to register + * @param scriptType is this script DML or PyDML + * @param dmlconf configuration provided by the user + * @param api API used to execute the runtime program + * @param performHOPRewrites should perform hop rewrites + * @param maintainSymbolTable whether or not all values should be maintained in the symbol table after execution. + * @return compiled runtime program */ - public static void executeRuntimeProgram(ScriptExecutor se, int statisticsMaxHeavyHitters) { - Program prog = se.getRuntimeProgram(); - ExecutionContext ec = se.getExecutionContext(); - DMLConfig config = se.getConfig(); - executeRuntimeProgram(prog, ec, config, statisticsMaxHeavyHitters, se.getScript().getOutputVariables()); + public static Program compileRuntimeProgram(String script, Map<String,String> nsscripts, Map<String, String> args, String[] allArgs, + // Input/Outputs registered in MLContext and JMLC. These are set to null by DMLScript + LocalVariableMap symbolTable, String[] inputs, String[] outputs, + ScriptType scriptType, DMLConfig dmlconf, SystemMLAPI api, + // MLContext-specific flags + boolean performHOPRewrites, boolean maintainSymbolTable, + boolean init) { + DMLScript.SCRIPT_TYPE = scriptType; + + Program rtprog = null; + + if (ConfigurationManager.isGPU() && !IS_JCUDA_AVAILABLE) + throw new RuntimeException("Incorrect usage: Cannot use the GPU backend without JCuda libraries. 
Hint: Include systemml-*-extra.jar (compiled using mvn package -P distribution) into the classpath."); + else if (!ConfigurationManager.isGPU() && ConfigurationManager.isForcedGPU()) + throw new RuntimeException("Incorrect usage: Cannot force a GPU-execution without enabling GPU"); + + if(api == SystemMLAPI.JMLC) { + //check for valid names of passed arguments + String[] invalidArgs = args.keySet().stream() + .filter(k -> k==null || !k.startsWith("$")).toArray(String[]::new); + if( invalidArgs.length > 0 ) + throw new LanguageException("Invalid argument names: "+Arrays.toString(invalidArgs)); + + //check for valid names of input and output variables + String[] invalidVars = UtilFunctions.asSet(inputs, outputs).stream() + .filter(k -> k==null || k.startsWith("$")).toArray(String[]::new); + if( invalidVars.length > 0 ) + throw new LanguageException("Invalid variable names: "+Arrays.toString(invalidVars)); + } + + String dmlParserFilePath = (api == SystemMLAPI.JMLC) ? null : DMLScript.DML_FILE_PATH_ANTLR_PARSER; + + try { + //Step 1: set local/remote memory if requested (for compile in AM context) + if(api == SystemMLAPI.DMLScript && dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ + DMLAppMasterUtils.setupConfigRemoteMaxMemory(dmlconf); + } + + // Start timer (disabled for JMLC) + if(api != SystemMLAPI.JMLC) + Statistics.startCompileTimer(); + + //Step 2: parse dml script + ParserWrapper parser = ParserFactory.createParser(scriptType, nsscripts); + DMLProgram prog = parser.parse(dmlParserFilePath, script, args); + + //Step 3: construct HOP DAGs (incl LVA, validate, and setup) + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + + //init working directories (before usage by following compilation steps) + if(api != SystemMLAPI.JMLC) + if ((api == SystemMLAPI.MLContext && init) || api != SystemMLAPI.MLContext) + DMLScript.initHadoopExecution( dmlconf ); + + + //Step 4: 
rewrite HOP DAGs (incl IPA and memory estimates) + if(performHOPRewrites) + dmlt.rewriteHopsDAG(prog); + + //Step 5: Remove Persistent Read/Writes + if(api == SystemMLAPI.JMLC) { + //rewrite persistent reads/writes + RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs); + ProgramRewriter rewriter2 = new ProgramRewriter(rewrite); + rewriter2.rewriteProgramHopDAGs(prog); + } + else if(api == SystemMLAPI.MLContext) { + //rewrite persistent reads/writes + RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs, symbolTable); + ProgramRewriter rewriter2 = new ProgramRewriter(rewrite); + rewriter2.rewriteProgramHopDAGs(prog); + } + + //Step 6: construct lops (incl exec type and op selection) + dmlt.constructLops(prog); + + if(DMLScript.LOG.isDebugEnabled()) { + DMLScript.LOG.debug("\n********************** LOPS DAG *******************"); + dmlt.printLops(prog); + dmlt.resetLopsDAGVisitStatus(prog); + } + + //Step 7: generate runtime program, incl codegen + rtprog = dmlt.getRuntimeProgram(prog, dmlconf); + + // Step 8: Cleanup/post-processing + if(api == SystemMLAPI.JMLC) { + JMLCUtils.cleanupRuntimeProgram(rtprog, outputs); + } + else if(api == SystemMLAPI.DMLScript) { + //launch SystemML appmaster (if requested and not already in launched AM) + if( dmlconf.getBooleanValue(DMLConfig.YARN_APPMASTER) ){ + if( !DMLScript.isActiveAM() && DMLYarnClientProxy.launchDMLYarnAppmaster(script, dmlconf, allArgs, rtprog) ) + return null; //if AM launch unsuccessful, fall back to normal execute + if( DMLScript.isActiveAM() ) //in AM context (not failed AM launch) + DMLAppMasterUtils.setupProgramMappingRemoteMaxMemory(rtprog); + } + } + else if(api == SystemMLAPI.MLContext) { + if (maintainSymbolTable) { + MLContextUtil.deleteRemoveVariableInstructions(rtprog); + } else { + JMLCUtils.cleanupRuntimeProgram(rtprog, outputs); + } + } + + //Step 9: prepare statistics [and optional explain output] + //count 
number compiled MR jobs / SP instructions + if(api != SystemMLAPI.JMLC) { + ExplainCounts counts = Explain.countDistributedOperations(rtprog); + Statistics.resetNoOfCompiledJobs( counts.numJobs ); + //explain plan of program (hops or runtime) + if( DMLScript.EXPLAIN != ExplainType.NONE ) + System.out.println(Explain.display(prog, rtprog, DMLScript.EXPLAIN, counts)); + + Statistics.stopCompileTimer(); + } + } + catch(ParseException pe) { + // don't chain ParseException (for cleaner error output) + throw pe; + } + catch(IOException ex) { + throw new DMLException(ex); + } + catch(Exception ex) { + throw new DMLException(ex); + } + return rtprog; } /** * Execute the runtime program. This involves execution of the program * blocks that make up the runtime program and may involve dynamic * recompilation. - * + * * @param rtprog * runtime program - * @param ec - * execution context * @param dmlconf * dml configuration * @param statisticsMaxHeavyHitters * maximum number of statistics to print + * @param symbolTable + * symbol table (that were registered as input as part of MLContext) * @param outputVariables - * output variables that were registered as part of MLContext + * output variables (that were registered as output as part of MLContext) + * @param api + * API used to execute the runtime program + * @param gCtxs + * list of GPU contexts + * @return execution context */ - public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DMLConfig dmlconf, int statisticsMaxHeavyHitters, Set<String> outputVariables) { + public static ExecutionContext executeRuntimeProgram(Program rtprog, DMLConfig dmlconf, int statisticsMaxHeavyHitters, + LocalVariableMap symbolTable, HashSet<String> outputVariables, + SystemMLAPI api, List<GPUContext> gCtxs) { boolean exceptionThrown = false; - + + // Start timer Statistics.startRunTimer(); + + // Create execution context and attach registered outputs + ExecutionContext ec = 
ExecutionContextFactory.createContext(symbolTable, rtprog); + if(outputVariables != null) + ec.getVariables().setRegisteredOutputs(outputVariables); + + // Assign GPUContext to the current ExecutionContext + if(gCtxs != null) { + gCtxs.get(0).initializeThread(); + ec.setGPUContexts(gCtxs); + } + Exception finalizeException = null; try { // run execute (w/ exception handling to ensure proper shutdown) - if (ConfigurationManager.isGPU() && ec != null) { - List<GPUContext> gCtxs = GPUContextPool.reserveAllGPUContexts(); - if (gCtxs == null) { - throw new DMLRuntimeException( - "GPU : Could not create GPUContext, either no GPU or all GPUs currently in use"); - } - gCtxs.get(0).initializeThread(); - ec.setGPUContexts(gCtxs); - } rtprog.execute(ec); } catch (Throwable e) { exceptionThrown = true; @@ -116,25 +313,32 @@ public class ScriptExecutorUtils { for(GPUContext gCtx : ec.getGPUContexts()) { gCtx.clearTemporaryMemory(); } - GPUContextPool.freeAllGPUContexts(); } catch (Exception e1) { exceptionThrown = true; finalizeException = e1; // do not throw exception while cleanup } + } if( ConfigurationManager.isCodegenEnabled() ) SpoofCompiler.cleanupCodeGenerator(); - - // display statistics (incl caching stats if enabled) + + //cleanup unnecessary outputs + if (outputVariables != null) + symbolTable.removeAllNotIn(outputVariables); + + // Display statistics (disabled for JMLC) Statistics.stopRunTimer(); - (exceptionThrown ? System.err : System.out) - .println(Statistics.display(statisticsMaxHeavyHitters > 0 ? - statisticsMaxHeavyHitters : ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters())); - ConfigurationManager.resetStatistics(); + if(api != SystemMLAPI.JMLC) { + (exceptionThrown ? System.err : System.out) + .println(Statistics.display(statisticsMaxHeavyHitters > 0 ? 
+ statisticsMaxHeavyHitters : + ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters())); + } } if(finalizeException != null) { throw new DMLRuntimeException("Error occured while GPU memory cleanup.", finalizeException); } - } + return ec; + } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/api/jmlc/Connection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java b/src/main/java/org/apache/sysml/api/jmlc/Connection.java index ea0d503..7f5d7c9 100644 --- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java +++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java @@ -25,15 +25,17 @@ import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.util.Arrays; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Map; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.api.ScriptExecutorUtils; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.mlcontext.ScriptType; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.CompilerConfig.ConfigType; @@ -41,18 +43,13 @@ import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.conf.DMLOptions; import org.apache.sysml.hops.codegen.SpoofCompiler; -import org.apache.sysml.hops.rewrite.ProgramRewriter; -import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite; -import org.apache.sysml.parser.DMLProgram; -import org.apache.sysml.parser.DMLTranslator; import org.apache.sysml.parser.DataExpression; import 
org.apache.sysml.parser.LanguageException; -import org.apache.sysml.parser.ParseException; -import org.apache.sysml.parser.ParserFactory; -import org.apache.sysml.parser.ParserWrapper; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.runtime.io.FrameReader; import org.apache.sysml.runtime.io.FrameReaderFactory; import org.apache.sysml.runtime.io.IOUtilFunctions; @@ -151,6 +148,8 @@ public class Connection implements Closeable * @param dmlconfig a dml configuration. */ public Connection(DMLConfig dmlconfig) { + DMLScript.rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; + //setup basic parameters for embedded execution //(parser, compiler, and runtime parameters) CompilerConfig cconf = new CompilerConfig(); @@ -193,7 +192,25 @@ public class Connection implements Closeable /** * Prepares (precompiles) a script and registers input and output variables. 
- * + * + * @param script string representing the DML or PyDML script + * @param inputs string array of input variables to register + * @param outputs string array of output variables to register + * @param useGpu {@code true} if prepare the script with GPU support, {@code false} + * @param forceGpu {@code true} if prepare the script with forced GPU support, {@code false} + * @param gpuIndex the GPU to use to execute the given prepared script + * @return PreparedScript object representing the precompiled script + */ + public PreparedScript prepareScript( + String script, String[] inputs, String[] outputs, boolean useGpu, boolean forceGpu, int gpuIndex) { + return prepareScript( + script, Collections.emptyMap(), Collections.emptyMap(), + inputs, outputs, false, useGpu, forceGpu, gpuIndex); + } + + /** + * Prepares (precompiles) a script and registers input and output variables. + * * @param script string representing the DML or PyDML script * @param inputs string array of input variables to register * @param outputs string array of output variables to register @@ -230,67 +247,72 @@ public class Connection implements Closeable * @return PreparedScript object representing the precompiled script */ public PreparedScript prepareScript(String script, Map<String,String> nsscripts, Map<String, String> args, String[] inputs, String[] outputs, boolean parsePyDML) { - DMLScript.SCRIPT_TYPE = parsePyDML ? ScriptType.PYDML : ScriptType.DML; + return prepareScript(script, nsscripts, args, inputs, outputs, parsePyDML, false, false, -1); + } + + /** + * List of available GPU contexts: + */ + static GPUContext [] AVAILABLE_GPU_CONTEXTS; + + + /** + * Prepares (precompiles) a script, sets input parameter values, and registers input and output variables. 
+ * + * @param script string representing of the DML or PyDML script + * @param nsscripts map (name, script) of the DML or PyDML namespace scripts + * @param args map of input parameters ($) and their values + * @param inputs string array of input variables to register + * @param outputs string array of output variables to register + * @param parsePyDML {@code true} if PyDML, {@code false} if DML + * @param useGPU {@code true} if prepare the script with GPU support, {@code false} + * @param forceGPU {@code true} if prepare the script with forced GPU support, {@code false} + * @param gpuIndex the GPU to use to execute the given prepared script + * @return PreparedScript object representing the precompiled script + */ + public PreparedScript prepareScript(String script, Map<String,String> nsscripts, Map<String, String> args, String[] inputs, String[] outputs, + boolean parsePyDML, boolean useGPU, boolean forceGPU, int gpuIndex) { - // Set DML Options here: - boolean gpu = false; boolean forceGPU = false; - ConfigurationManager.setLocalOptions(new DMLOptions(args, - false, 10, false, Explain.ExplainType.NONE, RUNTIME_PLATFORM.SINGLE_NODE, gpu, forceGPU, + DMLScript.SCRIPT_TYPE = parsePyDML ? ScriptType.PYDML : ScriptType.DML; + ConfigurationManager.setLocalOptions(new DMLOptions(args, + false, 10, false, + Explain.ExplainType.NONE, RUNTIME_PLATFORM.SINGLE_NODE, useGPU, forceGPU, parsePyDML ? 
ScriptType.PYDML : ScriptType.DML, null, script)); - - //check for valid names of passed arguments - String[] invalidArgs = args.keySet().stream() - .filter(k -> k==null || !k.startsWith("$")).toArray(String[]::new); - if( invalidArgs.length > 0 ) - throw new LanguageException("Invalid argument names: "+Arrays.toString(invalidArgs)); - - //check for valid names of input and output variables - String[] invalidVars = UtilFunctions.asSet(inputs, outputs).stream() - .filter(k -> k==null || k.startsWith("$")).toArray(String[]::new); - if( invalidVars.length > 0 ) - throw new LanguageException("Invalid variable names: "+Arrays.toString(invalidVars)); - setLocalConfigs(); - - //simplified compilation chain - Program rtprog = null; - try { - //parsing - ParserWrapper parser = ParserFactory.createParser( - parsePyDML ? ScriptType.PYDML : ScriptType.DML, nsscripts); - DMLProgram prog = parser.parse(null, script, args); - - //language validate - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - - //hop construct/rewrite - dmlt.constructHops(prog); - dmlt.rewriteHopsDAG(prog); - - //rewrite persistent reads/writes - RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs); - ProgramRewriter rewriter2 = new ProgramRewriter(rewrite); - rewriter2.rewriteProgramHopDAGs(prog); - - //lop construct and runtime prog generation - dmlt.constructLops(prog); - rtprog = dmlt.getRuntimeProgram(prog, _dmlconf); - - //final cleanup runtime prog - JMLCUtils.cleanupRuntimeProgram(rtprog, outputs); - } - catch(ParseException pe) { - // don't chain ParseException (for cleaner error output) - throw pe; - } - catch(Exception ex) { - throw new DMLException(ex); + + List<GPUContext> _gpuCtx = new ArrayList<>(); + if (useGPU) { + if (AVAILABLE_GPU_CONTEXTS == null) { + synchronized (Connection.class) { + if (AVAILABLE_GPU_CONTEXTS == null) { + // Initialize the GPUs if not already + String 
oldAvailableGpus = GPUContextPool.AVAILABLE_GPUS; + GPUContextPool.AVAILABLE_GPUS = "-1"; // use all the GPUs in JMLC mode + List<GPUContext> availableCtx = GPUContextPool.getAllGPUContexts(); + AVAILABLE_GPU_CONTEXTS = availableCtx.toArray(new GPUContext[availableCtx.size()]); + GPUContextPool.AVAILABLE_GPUS = oldAvailableGpus; + } + } + } + if (AVAILABLE_GPU_CONTEXTS.length == 0) + throw new DMLRuntimeException("No GPU Context in available"); + else if (gpuIndex < 0 || gpuIndex >= AVAILABLE_GPU_CONTEXTS.length) + throw new DMLRuntimeException("Cannot use the GPU " + gpuIndex + + ". Valid values: [0, " + (AVAILABLE_GPU_CONTEXTS.length - 1) + "]"); + // For simplicity of the API, the initial version statically associates a GPU to the prepared script. + // We can revisit this assumption if it turns out to be the overhead. + _gpuCtx.add(AVAILABLE_GPU_CONTEXTS[gpuIndex]); } - - //return newly create precompiled script - return new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); + + Program rtprog = ScriptExecutorUtils.compileRuntimeProgram(script, nsscripts, args, inputs, outputs, + parsePyDML ? 
ScriptType.PYDML : ScriptType.DML, _dmlconf, SystemMLAPI.JMLC); + + + //return newly create precompiled script + PreparedScript ret = new PreparedScript(rtprog, inputs, outputs, _dmlconf, _cconf); + + if (useGPU) ret._gpuCtx = _gpuCtx; + return ret; } /** @@ -924,4 +946,4 @@ public class Connection implements Closeable ConfigurationManager.setLocalConfig(_dmlconf); ConfigurationManager.setLocalConfig(_cconf); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java index d5955f4..5224097 100644 --- a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java +++ b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java @@ -30,6 +30,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.ConfigurableAPI; import org.apache.sysml.api.DMLException; +import org.apache.sysml.api.ScriptExecutorUtils; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.ConfigurationManager; @@ -53,6 +55,7 @@ import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.MetaDataFormat; import org.apache.sysml.runtime.matrix.data.FrameBlock; @@ -80,8 +83,10 @@ public class PreparedScript implements ConfigurableAPI private final LocalVariableMap _vars; 
private final DMLConfig _dmlconf; private final CompilerConfig _cconf; + private boolean _isStatisticsEnabled = false; - + private boolean _gatherMemStats = false; + private PreparedScript(PreparedScript that) { //shallow copy, except for a separate symbol table //and related meta data of reused inputs @@ -141,7 +146,7 @@ public class PreparedScript implements ConfigurableAPI */ public void gatherMemStats(boolean stats) { this._isStatisticsEnabled = this._isStatisticsEnabled || ConfigurationManager.isStatistics(); - DMLScript.JMLC_MEM_STATISTICS = stats; + this._gatherMemStats = stats; } @Override @@ -433,7 +438,12 @@ public class PreparedScript implements ConfigurableAPI public void clearParameters() { _vars.removeAll(); } - + + /** + * GPU Context to use for execution + */ + List<GPUContext> _gpuCtx = null; + /** * Executes the prepared script over the bound inputs, creating the * result variables according to bound and registered outputs. @@ -443,20 +453,20 @@ public class PreparedScript implements ConfigurableAPI public ResultVariables executeScript() { //add reused variables _vars.putAll(_inVarReuse); - + //set thread-local configurations ConfigurationManager.setLocalConfig(_dmlconf); ConfigurationManager.setLocalConfig(_cconf); - + ConfigurationManager.setStatistics(_isStatisticsEnabled); + ConfigurationManager.setJMLCMemStats(_gatherMemStats); + ConfigurationManager.setFinegrainedStatistics(_gatherMemStats); + //create and populate execution context - ExecutionContext ec = ExecutionContextFactory.createContext(_vars, _prog); - - //core execute runtime program - _prog.execute(ec); - - //cleanup unnecessary outputs - _vars.removeAllNotIn(_outVarnames); - + ScriptExecutorUtils.executeRuntimeProgram( + _prog, _dmlconf, ConfigurationManager.isStatistics() ? 
+ ConfigurationManager.getDMLOptions().getStatisticsMaxHeavyHitters() : 0, + _vars, _outVarnames, SystemMLAPI.JMLC, _gpuCtx); + //construct results ResultVariables rvars = new ResultVariables(); for( String ovar : _outVarnames ) { @@ -464,10 +474,9 @@ public class PreparedScript implements ConfigurableAPI if( tmpVar != null ) rvars.addResult(ovar, tmpVar); } - - //clear thread-local configurations + + // clear prior thread local configurations (for subsequent run) ConfigurationManager.clearLocalConfigs(); - ConfigurationManager.resetStatistics(); return rvars; @@ -551,4 +560,4 @@ public class PreparedScript implements ConfigurableAPI public Object clone() { return clone(true); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index 135e1cd..06778b3 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -20,7 +20,9 @@ package org.apache.sysml.api.mlcontext; import java.io.IOException; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -28,6 +30,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.ScriptExecutorUtils; +import org.apache.sysml.api.ScriptExecutorUtils.SystemMLAPI; import org.apache.sysml.api.jmlc.JMLCUtils; import org.apache.sysml.api.mlcontext.MLContext.ExecutionType; import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; @@ -50,7 +53,8 @@ import org.apache.sysml.runtime.DMLRuntimeException; import 
org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysml.utils.Explain; import org.apache.sysml.utils.Explain.ExplainCounts; import org.apache.sysml.utils.Explain.ExplainType; @@ -113,6 +117,7 @@ public class ScriptExecutor { protected ExecutionType executionType; protected int statisticsMaxHeavyHitters = 10; protected boolean maintainSymbolTable = false; + protected List<GPUContext> gCtxs = null; /** * ScriptExecutor constructor. @@ -208,20 +213,6 @@ public class ScriptExecutor { } /** - * Create an execution context and set its variables to be the symbol table - * of the script. - */ - protected void createAndInitializeExecutionContext() { - executionContext = ExecutionContextFactory.createContext(runtimeProgram); - LocalVariableMap symbolTable = script.getSymbolTable(); - if (symbolTable != null) - executionContext.setVariables(symbolTable); - //attach registered outputs (for dynamic recompile) - executionContext.getVariables().setRegisteredOutputs( - new HashSet<String>(script.getOutputVariables())); - } - - /** * Set the global flags (for example: statistics, gpu, etc). 
*/ protected void setGlobalFlags() { @@ -291,27 +282,23 @@ public class ScriptExecutor { */ public void compile(Script script, boolean performHOPRewrites) { - // main steps in script execution setup(script); - if (statistics) { - Statistics.startCompileTimer(); - } - parseScript(); - liveVariableAnalysis(); - validateScript(); - constructHops(); - if(performHOPRewrites) - rewriteHops(); - rewritePersistentReadsAndWrites(); - constructLops(); - generateRuntimeProgram(); - showExplanation(); - countCompiledMRJobsAndSparkInstructions(); - initializeCachingAndScratchSpace(); - cleanupRuntimeProgram(); - if (statistics) { - Statistics.stopCompileTimer(); + + LocalVariableMap symbolTable = script.getSymbolTable(); + String[] inputs = null; String[] outputs = null; + if (symbolTable != null) { + inputs = (script.getInputVariables() == null) ? new String[0] + : script.getInputVariables().toArray(new String[0]); + outputs = (script.getOutputVariables() == null) ? new String[0] + : script.getOutputVariables().toArray(new String[0]); } + + Map<String, String> args = MLContextUtil + .convertInputParametersForParser(script.getInputParameters(), script.getScriptType()); + runtimeProgram = ScriptExecutorUtils.compileRuntimeProgram(script.getScriptExecutionString(), Collections.emptyMap(), + args, null, symbolTable, inputs, outputs, script.getScriptType(), config, SystemMLAPI.MLContext, + performHOPRewrites, isMaintainSymbolTable(), init); + gCtxs = ConfigurationManager.isGPU() ? 
GPUContextPool.getAllGPUContexts() : null; } @@ -321,8 +308,6 @@ public class ScriptExecutor { * * <ol> * <li>{@link #compile(Script)}</li> - * <li>{@link #createAndInitializeExecutionContext()}</li> - * <li>{@link #executeRuntimeProgram()}</li> * <li>{@link #cleanupAfterExecution()}</li> * </ol> * @@ -352,8 +337,11 @@ public class ScriptExecutor { compile(script); try { - createAndInitializeExecutionContext(); - executeRuntimeProgram(); + executionContext = ScriptExecutorUtils.executeRuntimeProgram(getRuntimeProgram(), getConfig(), + statistics ? statisticsMaxHeavyHitters : 0, script.getSymbolTable(), + new HashSet<String>(getScript().getOutputVariables()), SystemMLAPI.MLContext, gCtxs); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred while executing runtime program", e); } finally { cleanupAfterExecution(); } @@ -376,8 +364,17 @@ public class ScriptExecutor { */ protected void setup(Script script) { this.script = script; - checkScriptHasTypeAndString(); + if (script == null) { + throw new MLContextException("Script is null"); + } else if (script.getScriptType() == null) { + throw new MLContextException("ScriptType (DML or PYDML) needs to be specified"); + } else if (script.getScriptString() == null) { + throw new MLContextException("Script string is null"); + } else if (StringUtils.isBlank(script.getScriptString())) { + throw new MLContextException("Script string is blank"); + } script.setScriptExecutor(this); + // Set global variable indicating the script type DMLScript.SCRIPT_TYPE = script.getScriptType(); setGlobalFlags(); @@ -385,6 +382,7 @@ public class ScriptExecutor { Statistics.resetNoOfExecutedJobs(); if (statistics) Statistics.reset(); + DMLScript.EXPLAIN = (explainLevel != null) ? explainLevel.getExplainType() : ExplainType.NONE; } /** @@ -429,19 +427,6 @@ public class ScriptExecutor { } /** - * Execute the runtime program. 
This involves execution of the program - * blocks that make up the runtime program and may involve dynamic - * recompilation. - */ - protected void executeRuntimeProgram() { - try { - ScriptExecutorUtils.executeRuntimeProgram(this, statistics ? statisticsMaxHeavyHitters : 0); - } catch (DMLRuntimeException e) { - throw new MLContextException("Exception occurred while executing runtime program", e); - } - } - - /** * Check security, create scratch space, cleanup working directories, * initialize caching, and reset statistics. */ @@ -467,12 +452,9 @@ public class ScriptExecutor { protected void parseScript() { try { ParserWrapper parser = ParserFactory.createParser(script.getScriptType()); - Map<String, Object> inputParameters = script.getInputParameters(); - Map<String, String> inputParametersStringMaps = MLContextUtil - .convertInputParametersForParser(inputParameters, script.getScriptType()); - - String scriptExecutionString = script.getScriptExecutionString(); - dmlProgram = parser.parse(null, scriptExecutionString, inputParametersStringMaps); + Map<String, String> args = MLContextUtil + .convertInputParametersForParser(script.getInputParameters(), script.getScriptType()); + dmlProgram = parser.parse(null, script.getScriptExecutionString(), args); } catch (ParseException e) { throw new MLContextException("Exception occurred while parsing script", e); } @@ -735,4 +717,4 @@ public class ScriptExecutor { ConfigurationManager.getDMLOptions().setExecutionMode(executionType.getRuntimePlatform()); this.executionType = executionType; } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/conf/ConfigurationManager.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/conf/ConfigurationManager.java b/src/main/java/org/apache/sysml/conf/ConfigurationManager.java index 96c3885..08d929e 100644 --- 
a/src/main/java/org/apache/sysml/conf/ConfigurationManager.java +++ b/src/main/java/org/apache/sysml/conf/ConfigurationManager.java @@ -155,7 +155,7 @@ public class ConfigurationManager /** * Sets the current thread-local dml configuration to the given options. * - * @param conf the configuration + * @param opts the configuration */ public static void setLocalOptions( DMLOptions opts ) { _dmlOptions = opts; @@ -275,6 +275,7 @@ public class ConfigurationManager // _dmlconf.getBooleanValue(DMLConfig.EXTRA_FINEGRAINED_STATS); private static boolean STATISTICS = false; private static boolean FINEGRAINED_STATISTICS = false; + private static boolean JMLC_MEM_STATISTICS = false; /** * @return true if statistics is enabled @@ -289,7 +290,12 @@ public class ConfigurationManager public static boolean isFinegrainedStatistics() { return FINEGRAINED_STATISTICS; } - + + /** + * @return true if JMLC memory statistics are enabled + */ + public static boolean isJMLCMemStatistics() { return JMLC_MEM_STATISTICS; } + /** * Whether or not statistics about the DML/PYDML program should be output to * standard output. @@ -301,7 +307,31 @@ public class ConfigurationManager public static void setStatistics(boolean enabled) { STATISTICS = enabled; } - + + /** + * Whether or not detailed statistics about program memory use should be output + * to standard output when running under JMLC + * + * @param enabled + * {@code true} if statistics should be output, {@code false} + * otherwise + */ + public static void setJMLCMemStats(boolean enabled) { + JMLC_MEM_STATISTICS = enabled; + } + + + /** + * Whether or not finegrained statistics should be enabled + * + * @param enabled + * {@code true} if statistics should be output, {@code false} + * otherwise + */ + public static void setFinegrainedStatistics(boolean enabled) { + FINEGRAINED_STATISTICS = enabled; + } + /** * Reset the statistics flag. 
*/ @@ -342,4 +372,4 @@ public class ConfigurationManager return null; } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java index 6b241c1..220f3e5 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java @@ -115,11 +115,8 @@ public class LocalVariableMap implements Cloneable } public boolean hasReferences( Data d ) { - //perf: avoid java streams here for reduced overhead in rmvar - for( Data o : localMap.values() ) - if( o instanceof ListObject ? ((ListObject)o).getData().contains(d) : o == d ) - return true; - return false; + return localMap.values().stream().anyMatch(e -> (e instanceof ListObject) ? 
+ ((ListObject)e).getData().contains(d) : e == d); } public void setRegisteredOutputs(HashSet<String> outputs) { @@ -143,7 +140,7 @@ public class LocalVariableMap implements Cloneable if( !dict.containsKey(hash) && e.getValue() instanceof CacheableData ) { dict.put(hash, e.getValue()); double size = ((CacheableData<?>) e.getValue()).getDataSize(); - if (DMLScript.JMLC_MEM_STATISTICS && ConfigurationManager.isFinegrainedStatistics()) + if (ConfigurationManager.isJMLCMemStatistics() && ConfigurationManager.isFinegrainedStatistics()) Statistics.maintainCPHeavyHittersMem(e.getKey(), size); total += size; } @@ -203,4 +200,4 @@ public class LocalVariableMap implements Cloneable public Object clone() { return new LocalVariableMap(this); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java index feb1234..814a085 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java @@ -259,7 +259,7 @@ public class ProgramBlock implements ParseInfo Statistics.maintainCPHeavyHitters( tmp.getExtendedOpcode(), System.nanoTime()-t0); } - if (DMLScript.JMLC_MEM_STATISTICS && ConfigurationManager.isFinegrainedStatistics()) + if (ConfigurationManager.isJMLCMemStatistics() && ConfigurationManager.isFinegrainedStatistics()) ec.getVariables().getPinnedDataSize(); // optional trace information (instruction and runtime) @@ -418,4 +418,4 @@ public class ProgramBlock implements ParseInfo _filename = parseInfo.getFilename(); } -} +} \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java index 15dd23e..1d40d72 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java @@ -500,7 +500,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data if( ConfigurationManager.isStatistics() ){ long t1 = System.nanoTime(); CacheStatistics.incrementAcquireMTime(t1-t0); - if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) Statistics.addCPMemObject(System.identityHashCode(this), getDataSize()); } @@ -1339,4 +1339,4 @@ public abstract class CacheableData<T extends CacheBlock> extends Data return str.toString(); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index f310d76..3e8636b 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -644,7 +644,7 @@ public class ExecutionContext { } public void cleanupCacheableData(CacheableData<?> mo) { - if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) Statistics.removeCPMemObject(System.identityHashCode(mo)); //early abort 
w/o scan of symbol table if no cleanup required boolean fileExists = (mo.isHDFSFileExists() && mo.getFileName() != null); @@ -786,4 +786,4 @@ public class ExecutionContext { _dbState.prevPC.setProgramBlockNumber(_dbState.getPC().getProgramBlockNumber()); _dbState.prevPC.setLineNumber(currInst.getLineNum()); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 879133b..33a61db 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -42,6 +42,7 @@ import org.apache.spark.storage.RDDInfo; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; +import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContextUtil; @@ -1100,7 +1101,7 @@ public class SparkExecutionContext extends ExecutionContext //and hence is transparently used by rmvar instructions and other users. The //core difference is the lineage-based cleanup of RDD and broadcast variables. 
- if (DMLScript.JMLC_MEM_STATISTICS) + if (ConfigurationManager.isJMLCMemStatistics()) Statistics.removeCPMemObject(System.identityHashCode(mo)); if( !mo.isCleanupEnabled() ) @@ -1608,4 +1609,4 @@ public class SparkExecutionContext extends ExecutionContext _rdds.clear(); } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java index 242127a..cfb8d1e 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContextPool.java @@ -28,6 +28,7 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysml.conf.DMLConfig; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.utils.GPUStatistics; @@ -71,9 +72,9 @@ public class GPUContextPool { static List<GPUContext> pool = new LinkedList<>(); /** - * Whether the pool of GPUs is reserved or not + * Used to throw an error in case of incorrect usage */ - static boolean reserved = false; + private static String oldAvailableGpus; /** * Static initialization of the number of devices @@ -99,6 +100,7 @@ public class GPUContextPool { try { ArrayList<Integer> listOfGPUs = parseListString(AVAILABLE_GPUS, deviceCount); + oldAvailableGpus = AVAILABLE_GPUS; // Initialize the list of devices & the pool of GPUContexts for (int i : listOfGPUs) { @@ -202,17 +204,17 @@ public class GPUContextPool { } /** - * Reserves and gets an initialized list of GPUContexts + * Gets an initialized list of GPUContexts * * @return null if no GPUContexts in pool, otherwise a 
valid list of GPUContext */ - public static synchronized List<GPUContext> reserveAllGPUContexts() { - if (reserved) - throw new DMLRuntimeException("Trying to re-reserve GPUs"); + public static synchronized List<GPUContext> getAllGPUContexts() { if (!initialized) initializeGPU(); - reserved = true; - LOG.trace("GPU : Reserved all GPUs"); + if(!oldAvailableGpus.equals(AVAILABLE_GPUS)) { + LOG.warn("GPUContextPool was already initialized with " + DMLConfig.AVAILABLE_GPUS + "=" + oldAvailableGpus + + ". Cannot reinitialize it with " + DMLConfig.AVAILABLE_GPUS + "=" + AVAILABLE_GPUS); + } return pool; } @@ -250,17 +252,6 @@ public class GPUContextPool { } /** - * Unreserves all GPUContexts - */ - public static synchronized void freeAllGPUContexts() { - if (!reserved) - throw new DMLRuntimeException("Trying to free unreserved GPUs"); - reserved = false; - LOG.trace("GPU : Unreserved all GPUs"); - - } - - /** * Gets the initial GPU memory budget. This is the minimum of the * available memories across all the GPUs on the machine(s) * @return minimum available memory @@ -275,4 +266,4 @@ public class GPUContextPool { throw new RuntimeException(e); } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index ca55564..656de32 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -975,7 +975,7 @@ public class Statistics sb.append("Cache hits (Mem, WB, FS, HDFS):\t" + CacheStatistics.displayHits() + ".\n"); sb.append("Cache writes (WB, FS, HDFS):\t" + CacheStatistics.displayWrites() + ".\n"); sb.append("Cache times (ACQr/m, RLS, EXP):\t" + CacheStatistics.displayTime() + " sec.\n"); - if (DMLScript.JMLC_MEM_STATISTICS) + 
if (ConfigurationManager.isJMLCMemStatistics()) sb.append("Max size of live objects:\t" + byteCountToDisplaySize(getSizeofPinnedObjects()) + " (" + getNumPinnedObjects() + " total objects)" + "\n"); sb.append("HOP DAGs recompiled (PRED, SB):\t" + getHopRecompiledPredDAGs() + "/" + getHopRecompiledSBDAGs() + ".\n"); sb.append("HOP DAGs recompile time:\t" + String.format("%.3f", ((double)getHopRecompileTime())/1000000000) + " sec.\n"); @@ -1029,10 +1029,10 @@ public class Statistics sb.append("Total JVM GC time:\t\t" + ((double)getJVMgcTime())/1000 + " sec.\n"); LibMatrixDNN.appendStatistics(sb); sb.append("Heavy hitter instructions:\n" + getHeavyHitters(maxHeavyHitters)); - if (DMLScript.JMLC_MEM_STATISTICS && ConfigurationManager.isFinegrainedStatistics()) + if (ConfigurationManager.isJMLCMemStatistics() && ConfigurationManager.isFinegrainedStatistics()) sb.append("Heavy hitter objects:\n" + getCPHeavyHittersMem(maxHeavyHitters)); } return sb.toString(); } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/test/java/org/apache/sysml/test/gpu/GPUTests.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java index e1ae1ae..d911d4f 100644 --- a/src/test/java/org/apache/sysml/test/gpu/GPUTests.java +++ b/src/test/java/org/apache/sysml/test/gpu/GPUTests.java @@ -56,9 +56,9 @@ public abstract class GPUTests extends AutomatedTestBase { protected static SparkSession spark; protected final double DOUBLE_PRECISION_THRESHOLD = 1e-9; // for relative error private static final boolean PRINT_MAT_ERROR = false; - - // We will use this flag until lower precision is supported on CP. - private final static String FLOATING_POINT_PRECISION = "double"; + + // We will use this flag until lower precision is supported on CP. 
+ private final static String FLOATING_POINT_PRECISION = "double"; protected final double SINGLE_PRECISION_THRESHOLD = 1e-3; // for relative error @@ -106,7 +106,7 @@ public abstract class GPUTests extends AutomatedTestBase { int freeCount = GPUContextPool.getAvailableCount(); Assert.assertTrue("All GPUContexts have not been returned to the GPUContextPool", count == freeCount); - List<GPUContext> gCtxs = GPUContextPool.reserveAllGPUContexts(); + List<GPUContext> gCtxs = GPUContextPool.getAllGPUContexts(); for (GPUContext gCtx : gCtxs) { gCtx.initializeThread(); try { @@ -116,9 +116,6 @@ public abstract class GPUTests extends AutomatedTestBase { throw e; } } - GPUContextPool.freeAllGPUContexts(); - - } catch (DMLRuntimeException e) { // Ignore } @@ -421,4 +418,4 @@ public abstract class GPUTests extends AutomatedTestBase { Assert.fail("Invalid types for comparison"); } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java b/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java new file mode 100644 index 0000000..d77d794 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/gpu/JMLCTests.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.gpu; + +import java.util.Random; +import org.junit.Test; +import org.junit.Assert; +import org.apache.sysml.api.jmlc.Connection; +import org.apache.sysml.api.jmlc.PreparedScript; + + +public class JMLCTests extends GPUTests { + + static class ScriptContainer { + String dml; + String[] inputVarNames; + } + + @Test + public void testJMLC() { + try { + Connection conn = new Connection(); + + int numMatrices = 10; + int matrixNumRows = 100; + int numScriptInvocations = 10; + + ScriptContainer SC = generateDMLScript(numMatrices); + + PreparedScript script = conn.prepareScript( + SC.dml, SC.inputVarNames, new String[] { "Z" }, true, true, 0); + + // execute the script without pinning input matrices between invocations + executeDMLScript(script, numScriptInvocations, matrixNumRows, numMatrices, false); + + // execute the script while pinning input matrices between invocations + executeDMLScript(script, numScriptInvocations, matrixNumRows, numMatrices, true); + } catch (Exception e) { + Assert.fail("An unexpected exception occurred: " + e.getMessage()); + } + } + + // Generates a simple synthetic DML script which multiplies a sequence of square matrices. + // I.e. Z = X %*% W1 %*% W2 %*% W3 ... + // numMatrices determines the number of matrices in the sequences. 
The size of the matrices can be set + // in executeDMLScript + static ScriptContainer generateDMLScript(int numMatrices) { + ScriptContainer SC = new ScriptContainer(); + String[] inputVarNames = new String[numMatrices + 1]; + inputVarNames[0] = "x"; + + StringBuilder dml = new StringBuilder("x = read(\"/tmp/X.mtx\", rows=-1, cols=-1)\n"); + for (int ix=0; ix<numMatrices; ix++) + { + String name = "W" + ix; + inputVarNames[ix+1] = name; + dml.append(name + " = read(\"/tmp/" + name + ".mtx\", rows=-1, cols=-1)\n"); + } + + dml.append("Z = x %*% W0\n"); + for (int ix=1; ix<numMatrices; ix++) + { + dml.append("Z = Z %*% W" + ix + "\n"); + } + + dml.append("while (-1 > 1)\n print(as.scalar(Z[1,1]))\n"); + + SC.dml = dml.toString(); + SC.inputVarNames = inputVarNames; + + return SC; + } + + // Executes a PreparedScript generated by generateDMLScript. The parameter n determines the + // number of times the script is invoked. The parameter rows controls the shape of the matrices. + // Set this parameter larger to use more memory. The parameter numMatrices must be set to the same value as + // in generateDMLScript. The parameter pinWeights controls whether weight matrices should be + // pinned in memory between script invocations. 
+ static void executeDMLScript(PreparedScript script, int n, int rows, int numMatrices, boolean pinWeights) { + for (int ix=0; ix<numMatrices; ix++) + script.setMatrix("W" + ix, randomMatrix(rows, rows, 0.0,1.0, 1.0), pinWeights); + + for (int ix=0; ix<n; ix++) + { + script.setMatrix("x", randomMatrix(rows, rows, 0.0, 1.0, 1.0), false); + script.executeScript(); + if (!pinWeights) + for (int iy=0; iy<numMatrices; iy++) + script.setMatrix( + "W" + iy, randomMatrix(rows, rows, 0.0,1.0, 1.0), false); + } + } + + static double[][] randomMatrix( + int rows, int cols, double min, double max, double sparsity) { + double[][] matrix = new double[rows][cols]; + Random random = new Random(System.currentTimeMillis()); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + if (random.nextDouble() > sparsity) { + continue; + } + matrix[i][j] = (random.nextDouble() * (max - min) + min); + } + } + return matrix; + } + +} http://git-wip-us.apache.org/repos/asf/systemml/blob/7907c0ea/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java index 2f2022e..25cc037 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCParfor2ForCompileTest.java @@ -63,11 +63,12 @@ public class JMLCParfor2ForCompileTest extends AutomatedTestBase PreparedScript pscript = conn.prepareScript( script, new String[]{}, new String[]{}, false); - ConfigurationManager.setStatistics(true); + pscript.setStatistics(true); pscript.executeScript(); conn.close(); + //check for existing or non-existing parfor - Assert.assertTrue(Statistics.getParforOptCount()==(par?1:0)); + 
Assert.assertTrue("INCORRECT PARFOR COUNT", Statistics.getParforOptCount()==(par?1:0)); } catch(Exception ex) { Assert.fail("JMLC parfor test failed: "+ex.getMessage()); @@ -75,4 +76,4 @@ public class JMLCParfor2ForCompileTest extends AutomatedTestBase ConfigurationManager.resetStatistics(); } } -} +} \ No newline at end of file