Repository: zeppelin
Updated Branches:
  refs/heads/master 7aa94ce93 -> 0a97446a7
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java index 809e883..beebd42 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java @@ -30,6 +30,7 @@ import org.apache.commons.lang.StringUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; +import org.apache.zeppelin.interpreter.BaseZeppelinContext; import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterException; @@ -44,6 +45,8 @@ import org.apache.zeppelin.interpreter.WrappedInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteInterpreterUtils; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.interpreter.util.InterpreterOutputStream; +import org.apache.zeppelin.python.IPythonInterpreter; +import org.apache.zeppelin.python.PythonInterpreter; import org.apache.zeppelin.spark.dep.SparkDependencyContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,56 +71,23 @@ import java.util.Properties; * features compared to IPySparkInterpreter, but requires less prerequisites than * IPySparkInterpreter, only python is required. */ -public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler { - private static final Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class); - private static final int MAX_TIMEOUT_SEC = 10; +public class PySparkInterpreter extends PythonInterpreter { + + private static Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class); - private GatewayServer gatewayServer; - private DefaultExecutor executor; - // used to forward output from python process to InterpreterOutput - private InterpreterOutputStream outputStream; - private String scriptPath; - private boolean pythonscriptRunning = false; - private long pythonPid = -1; - private IPySparkInterpreter iPySparkInterpreter; private SparkInterpreter sparkInterpreter; public PySparkInterpreter(Properties property) { super(property); + this.useBuiltinPy4j = false; } @Override public void open() throws InterpreterException { - // try IPySparkInterpreter first - iPySparkInterpreter = getIPySparkInterpreter(); - if (getProperty("zeppelin.pyspark.useIPython", "true").equals("true") && - StringUtils.isEmpty( - iPySparkInterpreter.checkIPythonPrerequisite(getPythonExec(getProperties())))) { - try { - iPySparkInterpreter.open(); - LOGGER.info("IPython is available, Use IPySparkInterpreter to replace PySparkInterpreter"); - return; - } catch (Exception e) { - iPySparkInterpreter = null; - LOGGER.warn("Fail to open IPySparkInterpreter", e); - } - } + setProperty("zeppelin.python.useIPython", getProperty("zeppelin.pyspark.useIPython", "true")); - // reset iPySparkInterpreter to null as it is not available - iPySparkInterpreter = null; - LOGGER.info("IPython is not available, use the native PySparkInterpreter\n"); - // Add matplotlib display hook - InterpreterGroup intpGroup = getInterpreterGroup(); - if (intpGroup != null && 
intpGroup.getInterpreterHookRegistry() != null) { - try { - // just for unit test I believe (zjffdu) - registerHook(HookType.POST_EXEC_DEV.getName(), "__zeppelin__._displayhook()"); - } catch (InvalidHookException e) { - throw new InterpreterException(e); - } - } + // create SparkInterpreter in JVM side TODO(zjffdu) move to SparkInterpreter DepInterpreter depInterpreter = getDepInterpreter(); - // load libraries from Dependency Interpreter URL [] urls = new URL[0]; List<URL> urlList = new LinkedList<>(); @@ -159,474 +129,81 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand ClassLoader oldCl = Thread.currentThread().getContextClassLoader(); try { URLClassLoader newCl = new URLClassLoader(urls, oldCl); - LOGGER.info("urls:" + urls); - for (URL url : urls) { - LOGGER.info("url:" + url); - } Thread.currentThread().setContextClassLoader(newCl); + // create Python Process and JVM gateway + super.open(); // must create spark interpreter after ClassLoader is set, otherwise the additional jars // can not be loaded by spark repl. this.sparkInterpreter = getSparkInterpreter(); - createGatewayServerAndStartScript(); - } catch (IOException e) { - LOGGER.error("Fail to open PySparkInterpreter", e); - throw new InterpreterException("Fail to open PySparkInterpreter", e); } finally { Thread.currentThread().setContextClassLoader(oldCl); } - } - - private void createGatewayServerAndStartScript() throws IOException { - // start gateway server in JVM side - int port = RemoteInterpreterUtils.findRandomAvailablePortOnAllLocalInterfaces(); - gatewayServer = new GatewayServer(this, port); - gatewayServer.start(); - - // launch python process to connect to the gateway server in JVM side - createPythonScript(); - String pythonExec = getPythonExec(getProperties()); - LOGGER.info("PythonExec: " + pythonExec); - CommandLine cmd = CommandLine.parse(pythonExec); - cmd.addArgument(scriptPath, false); - cmd.addArgument(Integer.toString(port), false); - cmd.addArgument(Integer.toString(sparkInterpreter.getSparkVersion().toNumber()), false); - executor = new DefaultExecutor(); - outputStream = new InterpreterOutputStream(LOGGER); - PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream); - executor.setStreamHandler(streamHandler); - executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); - - Map<String, String> env = setupPySparkEnv(); - executor.execute(cmd, env, this); - pythonscriptRunning = true; - } - - private void createPythonScript() throws IOException { - FileOutputStream pysparkScriptOutput = null; - FileOutputStream zeppelinContextOutput = null; - try { - // copy zeppelin_pyspark.py - File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py"); - this.scriptPath = scriptFile.getAbsolutePath(); - pysparkScriptOutput = new FileOutputStream(scriptFile); - IOUtils.copy( - getClass().getClassLoader().getResourceAsStream("python/zeppelin_pyspark.py"), - pysparkScriptOutput); - - // copy zeppelin_context.py to the same folder of zeppelin_pyspark.py - zeppelinContextOutput = new FileOutputStream(scriptFile.getParent() + "/zeppelin_context.py"); - IOUtils.copy( - getClass().getClassLoader().getResourceAsStream("python/zeppelin_context.py"), - zeppelinContextOutput); - LOGGER.info("PySpark script {} {} is created", - scriptPath, scriptFile.getParent() + "/zeppelin_context.py"); - } finally { - if (pysparkScriptOutput != null) { - try { - pysparkScriptOutput.close(); - } catch (IOException e) { - // ignore - } - } - if (zeppelinContextOutput != null) 
{ - try { - zeppelinContextOutput.close(); - } catch (IOException e) { - // ignore - } - } - } - } - - private Map<String, String> setupPySparkEnv() throws IOException { - Map<String, String> env = EnvironmentUtils.getProcEnvironment(); - // only set PYTHONPATH in local or yarn-client mode. - // yarn-cluster will setup PYTHONPATH automatically. - SparkConf conf = null; - try { - conf = getSparkConf(); - } catch (InterpreterException e) { - throw new IOException(e); - } - if (!conf.get("spark.submit.deployMode", "client").equals("cluster")) { - if (!env.containsKey("PYTHONPATH")) { - env.put("PYTHONPATH", PythonUtils.sparkPythonPath()); - } else { - env.put("PYTHONPATH", PythonUtils.sparkPythonPath() + ":" + env.get("PYTHONPATH")); - } - } - // get additional class paths when using SPARK_SUBMIT and not using YARN-CLIENT - // also, add all packages to PYTHONPATH since there might be transitive dependencies - if (SparkInterpreter.useSparkSubmit() && - !sparkInterpreter.isYarnMode()) { - String sparkSubmitJars = conf.get("spark.jars").replace(",", ":"); - if (!StringUtils.isEmpty(sparkSubmitJars)) { - env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + sparkSubmitJars); + if (!useIPython()) { + // Initialize Spark in Python Process + try { + bootstrapInterpreter("python/zeppelin_pyspark.py"); + } catch (IOException e) { + throw new InterpreterException("Fail to bootstrap pyspark", e); } } - - // set PYSPARK_PYTHON - if (conf.contains("spark.pyspark.python")) { - env.put("PYSPARK_PYTHON", conf.get("spark.pyspark.python")); - } - LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH")); - return env; - } - - // Run python shell - // Choose python in the order of - // PYSPARK_DRIVER_PYTHON > PYSPARK_PYTHON > zeppelin.pyspark.python - public static String getPythonExec(Properties properties) { - String pythonExec = properties.getProperty("zeppelin.pyspark.python", "python"); - if (System.getenv("PYSPARK_PYTHON") != null) { - pythonExec = System.getenv("PYSPARK_PYTHON"); - } - if (System.getenv("PYSPARK_DRIVER_PYTHON") != null) { - pythonExec = System.getenv("PYSPARK_DRIVER_PYTHON"); - } - return pythonExec; } @Override public void close() throws InterpreterException { - if (iPySparkInterpreter != null) { - iPySparkInterpreter.close(); - return; + super.close(); + if (sparkInterpreter != null) { + sparkInterpreter.close(); } - executor.getWatchdog().destroyProcess(); - gatewayServer.shutdown(); } - private PythonInterpretRequest pythonInterpretRequest = null; - private Integer statementSetNotifier = new Integer(0); - private String statementOutput = null; - private boolean statementError = false; - private Integer statementFinishedNotifier = new Integer(0); - - /** - * Request send to Python Daemon - */ - public class PythonInterpretRequest { - public String statements; - public String jobGroup; - public String jobDescription; - public boolean isForCompletion; - - public PythonInterpretRequest(String statements, String jobGroup, - String jobDescription, boolean isForCompletion) { - this.statements = statements; - this.jobGroup = jobGroup; - this.jobDescription = jobDescription; - this.isForCompletion = isForCompletion; - } - - public String statements() { - return statements; - } - - public String jobGroup() { - return jobGroup; - } - - public String jobDescription() { - return jobDescription; - } - - public boolean isForCompletion() { - return isForCompletion; - } - } - - // called by Python Process - public PythonInterpretRequest getStatements() { - synchronized (statementSetNotifier) { - while 
(pythonInterpretRequest == null) { - try { - statementSetNotifier.wait(1000); - } catch (InterruptedException e) { - } - } - PythonInterpretRequest req = pythonInterpretRequest; - pythonInterpretRequest = null; - return req; - } - } - - // called by Python Process - public void setStatementsFinished(String out, boolean error) { - synchronized (statementFinishedNotifier) { - LOGGER.debug("Setting python statement output: " + out + ", error: " + error); - statementOutput = out; - statementError = error; - statementFinishedNotifier.notify(); - } - } - - private boolean pythonScriptInitialized = false; - private Integer pythonScriptInitializeNotifier = new Integer(0); - - // called by Python Process - public void onPythonScriptInitialized(long pid) { - pythonPid = pid; - synchronized (pythonScriptInitializeNotifier) { - LOGGER.debug("onPythonScriptInitialized is called"); - pythonScriptInitialized = true; - pythonScriptInitializeNotifier.notifyAll(); - } - } - - // called by Python Process - public void appendOutput(String message) throws IOException { - LOGGER.debug("Output from python process: " + message); - outputStream.getInterpreterOutput().write(message); + @Override + protected BaseZeppelinContext createZeppelinContext() { + return sparkInterpreter.getZeppelinContext(); } @Override public InterpreterResult interpret(String st, InterpreterContext context) throws InterpreterException { - if (iPySparkInterpreter != null) { - return iPySparkInterpreter.interpret(st, context); - } - - if (sparkInterpreter.isUnsupportedSparkVersion()) { - return new InterpreterResult(Code.ERROR, "Spark " - + sparkInterpreter.getSparkVersion().toString() + " is not supported"); - } sparkInterpreter.populateSparkWebUrl(context); + return super.interpret(st, context); + } - if (!pythonscriptRunning) { - return new InterpreterResult(Code.ERROR, "python process not running " - + outputStream.toString()); - } - - outputStream.setInterpreterOutput(context.out); - - synchronized (pythonScriptInitializeNotifier) { - long startTime = System.currentTimeMillis(); - while (pythonScriptInitialized == false - && pythonscriptRunning - && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) { - try { - LOGGER.info("Wait for PythonScript running"); - pythonScriptInitializeNotifier.wait(1000); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - } - - List<InterpreterResultMessage> errorMessage; - try { - context.out.flush(); - errorMessage = context.out.toInterpreterResultMessage(); - } catch (IOException e) { - throw new InterpreterException(e); - } - - - if (pythonscriptRunning == false) { - // python script failed to initialize and terminated - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "Failed to start PySpark")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - if (pythonScriptInitialized == false) { - // timeout. 
didn't get initialized message - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, "Failed to initialize PySpark")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - - //TODO(zjffdu) remove this as PySpark is supported starting from spark 1.2s - if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) { - errorMessage.add(new InterpreterResultMessage( - InterpreterResult.Type.TEXT, - "pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported")); - return new InterpreterResult(Code.ERROR, errorMessage); - } - + @Override + protected void preCallPython(InterpreterContext context) { String jobGroup = Utils.buildJobGroupId(context); String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo()); - - SparkZeppelinContext z = sparkInterpreter.getZeppelinContext(); - z.setInterpreterContext(context); - z.setGui(context.getGui()); - z.setNoteGui(context.getNoteGui()); - InterpreterContext.set(context); - - pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc, false); - statementOutput = null; - - synchronized (statementSetNotifier) { - statementSetNotifier.notify(); - } - - synchronized (statementFinishedNotifier) { - while (statementOutput == null) { - try { - statementFinishedNotifier.wait(1000); - } catch (InterruptedException e) { - } - } - } - - if (statementError) { - return new InterpreterResult(Code.ERROR, statementOutput); - } else { - try { - context.out.flush(); - } catch (IOException e) { - throw new InterpreterException(e); - } - return new InterpreterResult(Code.SUCCESS); - } - } - - public void interrupt() throws IOException, InterpreterException { - if (pythonPid > -1) { - LOGGER.info("Sending SIGINT signal to PID : " + pythonPid); - Runtime.getRuntime().exec("kill -SIGINT " + pythonPid); - } else { - LOGGER.warn("Non UNIX/Linux system, close the interpreter"); - close(); - } + callPython(new PythonInterpretRequest( + String.format("if 'sc' in locals():\n\tsc.setJobGroup('%s', '%s')", jobGroup, jobDesc), + false)); } + // Run python shell + // Choose python in the order of + // PYSPARK_DRIVER_PYTHON > PYSPARK_PYTHON > zeppelin.pyspark.python @Override - public void cancel(InterpreterContext context) throws InterpreterException { - if (iPySparkInterpreter != null) { - iPySparkInterpreter.cancel(context); - return; - } - SparkInterpreter sparkInterpreter = getSparkInterpreter(); - sparkInterpreter.cancel(context); - try { - interrupt(); - } catch (IOException e) { - LOGGER.error("Error", e); + protected String getPythonExec() { + String pythonExec = getProperty("zeppelin.pyspark.python", "python"); + if (System.getenv("PYSPARK_PYTHON") != null) { + pythonExec = System.getenv("PYSPARK_PYTHON"); } - } - - @Override - public FormType getFormType() { - return FormType.NATIVE; - } - - @Override - public int getProgress(InterpreterContext context) throws InterpreterException { - if (iPySparkInterpreter != null) { - return iPySparkInterpreter.getProgress(context); + if (System.getenv("PYSPARK_DRIVER_PYTHON") != null) { + pythonExec = System.getenv("PYSPARK_DRIVER_PYTHON"); } - SparkInterpreter sparkInterpreter = getSparkInterpreter(); - return sparkInterpreter.getProgress(context); + return pythonExec; } - @Override - public List<InterpreterCompletion> completion(String buf, int cursor, - InterpreterContext interpreterContext) - throws InterpreterException { - if (iPySparkInterpreter != null) { - return iPySparkInterpreter.completion(buf, cursor, interpreterContext); - } - if 
(buf.length() < cursor) { - cursor = buf.length(); - } - String completionString = getCompletionTargetString(buf, cursor); - String completionCommand = "completion.getCompletion('" + completionString + "')"; - LOGGER.debug("completionCommand: " + completionCommand); - - //start code for completion - if (sparkInterpreter.isUnsupportedSparkVersion() || pythonscriptRunning == false) { - return new LinkedList<>(); - } - - pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", "", true); - statementOutput = null; - - synchronized (statementSetNotifier) { - statementSetNotifier.notify(); - } - - String[] completionList = null; - synchronized (statementFinishedNotifier) { - long startTime = System.currentTimeMillis(); - while (statementOutput == null - && pythonscriptRunning) { - try { - if (System.currentTimeMillis() - startTime > MAX_TIMEOUT_SEC * 1000) { - LOGGER.error("pyspark completion didn't have response for {}sec.", MAX_TIMEOUT_SEC); - break; - } - statementFinishedNotifier.wait(1000); - } catch (InterruptedException e) { - // not working - LOGGER.info("wait drop"); - return new LinkedList<>(); - } - } - if (statementError) { - return new LinkedList<>(); - } - Gson gson = new Gson(); - completionList = gson.fromJson(statementOutput, String[].class); - } - //end code for completion - if (completionList == null) { - return new LinkedList<>(); - } - - List<InterpreterCompletion> results = new LinkedList<>(); - for (String name: completionList) { - results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY)); - LOGGER.debug("completion: " + name); - } - return results; - } - - private String getCompletionTargetString(String text, int cursor) { - String[] completionSeqCharaters = {" ", "\n", "\t"}; - int completionEndPosition = cursor; - int completionStartPosition = cursor; - int indexOfReverseSeqPostion = cursor; - - String resultCompletionText = ""; - String completionScriptText = ""; - try { - completionScriptText = text.substring(0, cursor); - } - catch (Exception e) { - LOGGER.error(e.toString()); - return null; - } - completionEndPosition = completionScriptText.length(); - - String tempReverseCompletionText = new StringBuilder(completionScriptText).reverse().toString(); - - for (String seqCharacter : completionSeqCharaters) { - indexOfReverseSeqPostion = tempReverseCompletionText.indexOf(seqCharacter); - - if (indexOfReverseSeqPostion < completionStartPosition && indexOfReverseSeqPostion > 0) { - completionStartPosition = indexOfReverseSeqPostion; - } - - } - - if (completionStartPosition == completionEndPosition) { - completionStartPosition = 0; - } - else - { - completionStartPosition = completionEndPosition - completionStartPosition; + protected IPythonInterpreter getIPythonInterpreter() { + IPySparkInterpreter iPython = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName()); + while (p instanceof WrappedInterpreter) { + p = ((WrappedInterpreter) p).getInnerInterpreter(); } - resultCompletionText = completionScriptText.substring( - completionStartPosition , completionEndPosition); - - return resultCompletionText; + iPython = (IPySparkInterpreter) p; + return iPython; } - private SparkInterpreter getSparkInterpreter() throws InterpreterException { LazyOpenInterpreter lazy = null; SparkInterpreter spark = null; @@ -646,63 +223,45 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand return spark; } - private IPySparkInterpreter getIPySparkInterpreter() { - LazyOpenInterpreter 
lazy = null; - IPySparkInterpreter iPySpark = null; - Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName()); - - while (p instanceof WrappedInterpreter) { - if (p instanceof LazyOpenInterpreter) { - lazy = (LazyOpenInterpreter) p; - } - p = ((WrappedInterpreter) p).getInnerInterpreter(); - } - iPySpark = (IPySparkInterpreter) p; - return iPySpark; - } - public SparkZeppelinContext getZeppelinContext() throws InterpreterException { - SparkInterpreter sparkIntp = getSparkInterpreter(); - if (sparkIntp != null) { - return getSparkInterpreter().getZeppelinContext(); + public SparkZeppelinContext getZeppelinContext() { + if (sparkInterpreter != null) { + return sparkInterpreter.getZeppelinContext(); } else { return null; } } - public JavaSparkContext getJavaSparkContext() throws InterpreterException { - SparkInterpreter intp = getSparkInterpreter(); - if (intp == null) { + public JavaSparkContext getJavaSparkContext() { + if (sparkInterpreter == null) { return null; } else { - return new JavaSparkContext(intp.getSparkContext()); + return new JavaSparkContext(sparkInterpreter.getSparkContext()); } } - public Object getSparkSession() throws InterpreterException { - SparkInterpreter intp = getSparkInterpreter(); - if (intp == null) { + public Object getSparkSession() { + if (sparkInterpreter == null) { return null; } else { - return intp.getSparkSession(); + return sparkInterpreter.getSparkSession(); } } - public SparkConf getSparkConf() throws InterpreterException { + public SparkConf getSparkConf() { JavaSparkContext sc = getJavaSparkContext(); if (sc == null) { return null; } else { - return getJavaSparkContext().getConf(); + return sc.getConf(); } } - public SQLContext getSQLContext() throws InterpreterException { - SparkInterpreter intp = getSparkInterpreter(); - if (intp == null) { + public SQLContext getSQLContext() { + if (sparkInterpreter == null) { return null; } else { - return intp.getSQLContext(); + return sparkInterpreter.getSQLContext(); } } @@ -718,21 +277,7 @@ public class PySparkInterpreter extends Interpreter implements ExecuteResultHand return (DepInterpreter) p; } - - @Override - public void onProcessComplete(int exitValue) { - pythonscriptRunning = false; - LOGGER.info("python process terminated. exit code " + exitValue); - } - - @Override - public void onProcessFailed(ExecuteException e) { - pythonscriptRunning = false; - LOGGER.error("python process failed", e); - } - - // Called by Python Process, used for debugging purpose - public void logPythonOutput(String message) { - LOGGER.debug("Python Process Output: " + message); + public boolean isSpark2() { + return sparkInterpreter.getSparkVersion().newerThanEquals(SparkVersion.SPARK_2_0_0); } } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py index 1352318..8fcca9b 100644 --- a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py +++ b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py @@ -15,150 +15,43 @@ # limitations under the License. 
# -import os, sys, getopt, traceback, json, re - -from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.protocol import Py4JJavaError +from py4j.java_gateway import java_import from pyspark.conf import SparkConf from pyspark.context import SparkContext -import ast -import warnings # for back compatibility from pyspark.sql import SQLContext, HiveContext, Row -class Logger(object): - def __init__(self): - pass - - def write(self, message): - intp.appendOutput(message) - - def reset(self): - pass - - def flush(self): - pass - - -class SparkVersion(object): - SPARK_1_4_0 = 10400 - SPARK_1_3_0 = 10300 - SPARK_2_0_0 = 20000 - - def __init__(self, versionNumber): - self.version = versionNumber - - def isAutoConvertEnabled(self): - return self.version >= self.SPARK_1_4_0 - - def isImportAllPackageUnderSparkSql(self): - return self.version >= self.SPARK_1_3_0 - - def isSpark2(self): - return self.version >= self.SPARK_2_0_0 - -class PySparkCompletion: - def __init__(self, interpreterObject): - self.interpreterObject = interpreterObject - - def getGlobalCompletion(self, text_value): - completions = [completion for completion in list(globals().keys()) if completion.startswith(text_value)] - return completions - - def getMethodCompletion(self, objName, methodName): - execResult = locals() - try: - exec("{} = dir({})".format("objectDefList", objName), globals(), execResult) - except: - return None - else: - objectDefList = execResult['objectDefList'] - return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)] - - def getCompletion(self, text_value): - if text_value == None: - return None - - dotPos = text_value.find(".") - if dotPos == -1: - objName = text_value - completionList = self.getGlobalCompletion(objName) - else: - objName = text_value[:dotPos] - methodName = text_value[dotPos + 1:] - completionList = self.getMethodCompletion(objName, methodName) - - if len(completionList) <= 0: - self.interpreterObject.setStatementsFinished("", False) - else: - result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList)))) - self.interpreterObject.setStatementsFinished(result, False) - -client = GatewayClient(port=int(sys.argv[1])) -sparkVersion = SparkVersion(int(sys.argv[2])) -if sparkVersion.isSpark2(): +intp = gateway.entry_point +isSpark2 = intp.isSpark2() +if isSpark2: from pyspark.sql import SparkSession -else: - from pyspark.sql import SchemaRDD - -if sparkVersion.isAutoConvertEnabled(): - gateway = JavaGateway(client, auto_convert = True) -else: - gateway = JavaGateway(client) +jsc = intp.getJavaSparkContext() java_import(gateway.jvm, "org.apache.spark.SparkEnv") java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") -intp = gateway.entry_point -output = Logger() -sys.stdout = output -sys.stderr = output - -jsc = intp.getJavaSparkContext() - -if sparkVersion.isImportAllPackageUnderSparkSql(): - java_import(gateway.jvm, "org.apache.spark.sql.*") - java_import(gateway.jvm, "org.apache.spark.sql.hive.*") -else: - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") - +java_import(gateway.jvm, 
"org.apache.spark.sql.*") +java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") -_zcUserQueryNameSpace = {} - jconf = intp.getSparkConf() conf = SparkConf(_jvm = gateway.jvm, _jconf = jconf) sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf) -_zcUserQueryNameSpace["_zsc_"] = _zsc_ -_zcUserQueryNameSpace["sc"] = sc -if sparkVersion.isSpark2(): + +if isSpark2: spark = __zSpark__ = SparkSession(sc, intp.getSparkSession()) sqlc = __zSqlc__ = __zSpark__._wrapped - _zcUserQueryNameSpace["sqlc"] = sqlc - _zcUserQueryNameSpace["__zSqlc__"] = __zSqlc__ - _zcUserQueryNameSpace["spark"] = spark - _zcUserQueryNameSpace["__zSpark__"] = __zSpark__ + else: sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext()) - _zcUserQueryNameSpace["sqlc"] = sqlc - _zcUserQueryNameSpace["__zSqlc__"] = sqlc sqlContext = __zSqlc__ -_zcUserQueryNameSpace["sqlContext"] = sqlContext - -completion = __zeppelin_completion__ = PySparkCompletion(intp) -_zcUserQueryNameSpace["completion"] = completion -_zcUserQueryNameSpace["__zeppelin_completion__"] = __zeppelin_completion__ - from zeppelin_context import PyZeppelinContext @@ -176,92 +69,4 @@ class PySparkZeppelinContext(PyZeppelinContext): super(PySparkZeppelinContext, self).show(obj) z = __zeppelin__ = PySparkZeppelinContext(intp.getZeppelinContext(), gateway) - __zeppelin__._setup_matplotlib() -_zcUserQueryNameSpace["z"] = z -_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ - -intp.onPythonScriptInitialized(os.getpid()) - -while True : - req = intp.getStatements() - try: - stmts = req.statements().split("\n") - jobGroup = req.jobGroup() - jobDesc = req.jobDescription() - isForCompletion = req.isForCompletion() - - # Get post-execute hooks - try: - global_hook = intp.getHook('post_exec_dev') - except: - global_hook = None - - try: - user_hook = __zeppelin__.getHook('post_exec') - except: - user_hook = None - - nhooks = 0 - if not isForCompletion: - for hook in (global_hook, user_hook): - if hook: - nhooks += 1 - - if stmts: - # use exec mode to compile the statements except the last statement, - # so that the last statement's evaluation will be printed to stdout - sc.setJobGroup(jobGroup, jobDesc) - code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1) - to_run_hooks = [] - if (nhooks > 0): - to_run_hooks = code.body[-nhooks:] - - to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], - [code.body[-(nhooks + 1)]]) - try: - for node in to_run_exec: - mod = ast.Module([node]) - code = compile(mod, '<stdin>', 'exec') - exec(code, _zcUserQueryNameSpace) - - for node in to_run_single: - mod = ast.Interactive([node]) - code = compile(mod, '<stdin>', 'single') - exec(code, _zcUserQueryNameSpace) - - for node in to_run_hooks: - mod = ast.Module([node]) - code = compile(mod, '<stdin>', 'exec') - exec(code, _zcUserQueryNameSpace) - - if not isForCompletion: - # only call it when it is not for code completion. 
code completion will call it in - # PySparkCompletion.getCompletion - intp.setStatementsFinished("", False) - except Py4JJavaError: - # raise it to outside try except - raise - except: - if not isForCompletion: - exception = traceback.format_exc() - m = re.search("File \"<stdin>\", line (\d+).*", exception) - if m: - line_no = int(m.group(1)) - intp.setStatementsFinished( - "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) - else: - intp.setStatementsFinished(exception, True) - else: - intp.setStatementsFinished("", False) - - except Py4JJavaError: - excInnerError = traceback.format_exc() # format_tb() does not return the inner exception - innerErrorStart = excInnerError.find("Py4JJavaError:") - if innerErrorStart > -1: - excInnerError = excInnerError[innerErrorStart:] - intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) - except: - intp.setStatementsFinished(traceback.format_exc(), True) - - output.reset() http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java index 2cc11ac..ece5235 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java @@ -27,18 +27,16 @@ import org.apache.zeppelin.interpreter.InterpreterGroup; import org.apache.zeppelin.interpreter.InterpreterOutput; import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteEventClient; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.python.IPythonInterpreterTest; import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import java.io.IOException; -import java.net.URL; +import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Properties; @@ -46,65 +44,72 @@ import java.util.Properties; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; -import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; -public class IPySparkInterpreterTest { +public class IPySparkInterpreterTest extends IPythonInterpreterTest { - private IPySparkInterpreter iPySparkInterpreter; private InterpreterGroup intpGroup; private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); - @Before - public void setup() throws InterpreterException { + @Override + protected Properties initIntpProperties() { Properties p = new Properties(); p.setProperty("spark.master", "local[4]"); p.setProperty("master", "local[4]"); p.setProperty("spark.submit.deployMode", "client"); p.setProperty("spark.app.name", "Zeppelin Test"); - p.setProperty("zeppelin.spark.useHiveContext", "true"); + p.setProperty("zeppelin.spark.useHiveContext", "false"); 
p.setProperty("zeppelin.spark.maxResult", "3"); p.setProperty("zeppelin.spark.importImplicit", "true"); + p.setProperty("zeppelin.spark.useNew", "true"); p.setProperty("zeppelin.pyspark.python", "python"); p.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath()); + p.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); + return p; + } + @Override + protected void startInterpreter(Properties properties) throws InterpreterException { intpGroup = new InterpreterGroup(); - intpGroup.put("session_1", new LinkedList<Interpreter>()); + intpGroup.put("session_1", new ArrayList<Interpreter>()); - SparkInterpreter sparkInterpreter = new SparkInterpreter(p); + LazyOpenInterpreter sparkInterpreter = new LazyOpenInterpreter( + new SparkInterpreter(properties)); intpGroup.get("session_1").add(sparkInterpreter); sparkInterpreter.setInterpreterGroup(intpGroup); - sparkInterpreter.open(); - sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient); - iPySparkInterpreter = new IPySparkInterpreter(p); - intpGroup.get("session_1").add(iPySparkInterpreter); - iPySparkInterpreter.setInterpreterGroup(intpGroup); - iPySparkInterpreter.open(); - sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient); + LazyOpenInterpreter pySparkInterpreter = + new LazyOpenInterpreter(new PySparkInterpreter(properties)); + intpGroup.get("session_1").add(pySparkInterpreter); + pySparkInterpreter.setInterpreterGroup(intpGroup); + + interpreter = new LazyOpenInterpreter(new IPySparkInterpreter(properties)); + intpGroup.get("session_1").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); + + interpreter.open(); } - @After + @Override public void tearDown() throws InterpreterException { - if (iPySparkInterpreter != null) { - iPySparkInterpreter.close(); - } + intpGroup.close(); + interpreter = null; + intpGroup = null; } @Test - public void testBasics() throws InterruptedException, IOException, InterpreterException { - // all the ipython test should pass too. 
- IPythonInterpreterTest.testInterpreter(iPySparkInterpreter); - testPySpark(iPySparkInterpreter, mockRemoteEventClient); - + public void testIPySpark() throws InterruptedException, InterpreterException, IOException { + testPySpark(interpreter, mockRemoteEventClient); } public static void testPySpark(final Interpreter interpreter, RemoteEventClient mockRemoteEventClient) throws InterpreterException, IOException, InterruptedException { + reset(mockRemoteEventClient); // rdd - InterpreterContext context = getInterpreterContext(mockRemoteEventClient); + InterpreterContext context = createInterpreterContext(mockRemoteEventClient); InterpreterResult result = interpreter.interpret("sc.version", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); @@ -112,17 +117,17 @@ public class IPySparkInterpreterTest { // spark url is sent verify(mockRemoteEventClient).onMetaInfosReceived(any(Map.class)); - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret("sc.range(1,10).sum()", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals("45", interpreterResultMessages.get(0).getData().trim()); // spark job url is sent - verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class)); +// verify(mockRemoteEventClient).onParaInfosReceived(any(String.class), any(String.class), any(Map.class)); // spark sql - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); if (!isSpark2(sparkVersion)) { result = interpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); @@ -135,7 +140,7 @@ public class IPySparkInterpreterTest { "| 2| b|\n" + "+---+---+", interpreterResultMessages.get(0).getData().trim()); - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret("z.show(df)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); @@ -155,7 +160,7 @@ public class IPySparkInterpreterTest { "| 2| b|\n" + "+---+---+", interpreterResultMessages.get(0).getData().trim()); - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret("z.show(df)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); interpreterResultMessages = context.out.toInterpreterResultMessage(); @@ -166,7 +171,7 @@ public class IPySparkInterpreterTest { } // cancel if (interpreter instanceof IPySparkInterpreter) { - final InterpreterContext context2 = getInterpreterContext(mockRemoteEventClient); + final InterpreterContext context2 = createInterpreterContext(mockRemoteEventClient); Thread thread = new Thread() { @Override @@ -196,24 +201,24 @@ public class IPySparkInterpreterTest { } // completions - List<InterpreterCompletion> completions = interpreter.completion("sc.ran", 6, getInterpreterContext(mockRemoteEventClient)); + List<InterpreterCompletion> completions = interpreter.completion("sc.ran", 6, createInterpreterContext(mockRemoteEventClient)); assertEquals(1, completions.size()); 
assertEquals("range", completions.get(0).getValue()); - completions = interpreter.completion("sc.", 3, getInterpreterContext(mockRemoteEventClient)); + completions = interpreter.completion("sc.", 3, createInterpreterContext(mockRemoteEventClient)); assertTrue(completions.size() > 0); completions.contains(new InterpreterCompletion("range", "range", "")); - completions = interpreter.completion("1+1\nsc.", 7, getInterpreterContext(mockRemoteEventClient)); + completions = interpreter.completion("1+1\nsc.", 7, createInterpreterContext(mockRemoteEventClient)); assertTrue(completions.size() > 0); completions.contains(new InterpreterCompletion("range", "range", "")); - completions = interpreter.completion("s", 1, getInterpreterContext(mockRemoteEventClient)); + completions = interpreter.completion("s", 1, createInterpreterContext(mockRemoteEventClient)); assertTrue(completions.size() > 0); completions.contains(new InterpreterCompletion("sc", "sc", "")); // pyspark streaming - context = getInterpreterContext(mockRemoteEventClient); + context = createInterpreterContext(mockRemoteEventClient); result = interpreter.interpret( "from pyspark.streaming import StreamingContext\n" + "import time\n" + @@ -239,7 +244,7 @@ public class IPySparkInterpreterTest { return sparkVersion.startsWith("'2.") || sparkVersion.startsWith("u'2."); } - private static InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) { + private static InterpreterContext createInterpreterContext(RemoteEventClient mockRemoteEventClient) { InterpreterContext context = new InterpreterContext( "noteId", "paragraphId", http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java index 068ff50..3a98653 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java @@ -127,7 +127,7 @@ public class OldSparkInterpreterTest { new LocalResourcePool("id"), new LinkedList<InterpreterContextRunner>(), new InterpreterOutput(null)) { - + @Override public RemoteEventClientWrapper getClient() { return remoteEventClientWrapper; @@ -192,7 +192,7 @@ public class OldSparkInterpreterTest { public void testEndWithComment() throws InterpreterException { assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret("val c=1\n//comment", context).code()); } - + @Test public void testCreateDataFrame() throws InterpreterException { if (getSparkVersionNumber(repl) >= 13) { http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java index e228c7e..446f183 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java @@ -17,154 +17,73 @@ package org.apache.zeppelin.spark; -import 
org.apache.zeppelin.display.AngularObjectRegistry; -import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.interpreter.*; +import com.google.common.io.Files; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteEventClient; -import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; -import org.apache.zeppelin.resource.LocalResourcePool; -import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.*; -import org.junit.rules.TemporaryFolder; -import org.junit.runners.MethodSorters; +import org.apache.zeppelin.python.PythonInterpreterTest; +import org.junit.Test; import java.io.IOException; -import java.util.HashMap; import java.util.LinkedList; -import java.util.List; -import java.util.Map; import java.util.Properties; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import static org.junit.Assert.*; -import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -@FixMethodOrder(MethodSorters.NAME_ASCENDING) -public class PySparkInterpreterTest { +public class PySparkInterpreterTest extends PythonInterpreterTest { - @ClassRule - public static TemporaryFolder tmpDir = new TemporaryFolder(); - - static SparkInterpreter sparkInterpreter; - static PySparkInterpreter pySparkInterpreter; - static InterpreterGroup intpGroup; - static InterpreterContext context; private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); - private static Properties getPySparkTestProperties() throws IOException { - Properties p = new Properties(); - p.setProperty("spark.master", "local"); - p.setProperty("spark.app.name", "Zeppelin Test"); - p.setProperty("zeppelin.spark.useHiveContext", "true"); - p.setProperty("zeppelin.spark.maxResult", "1000"); - p.setProperty("zeppelin.spark.importImplicit", "true"); - p.setProperty("zeppelin.pyspark.python", "python"); - p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath()); - p.setProperty("zeppelin.pyspark.useIPython", "false"); - p.setProperty("zeppelin.spark.test", "true"); - return p; - } - - /** - * Get spark version number as a numerical value. - * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ... 
- */ - public static int getSparkVersionNumber() { - if (sparkInterpreter == null) { - return 0; - } - - String[] split = sparkInterpreter.getSparkContext().version().split("\\."); - int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]); - return version; - } - - @BeforeClass - public static void setUp() throws Exception { + @Override + public void setUp() throws InterpreterException { + Properties properties = new Properties(); + properties.setProperty("spark.master", "local"); + properties.setProperty("spark.app.name", "Zeppelin Test"); + properties.setProperty("zeppelin.spark.useHiveContext", "false"); + properties.setProperty("zeppelin.spark.maxResult", "3"); + properties.setProperty("zeppelin.spark.importImplicit", "true"); + properties.setProperty("zeppelin.pyspark.python", "python"); + properties.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath()); + properties.setProperty("zeppelin.pyspark.useIPython", "false"); + properties.setProperty("zeppelin.spark.useNew", "true"); + properties.setProperty("zeppelin.spark.test", "true"); + properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); + + InterpreterContext.set(getInterpreterContext(mockRemoteEventClient)); + // create interpreter group intpGroup = new InterpreterGroup(); intpGroup.put("note", new LinkedList<Interpreter>()); - context = new InterpreterContext("note", "id", null, "title", "text", - new AuthenticationInfo(), - new HashMap<String, Object>(), - new GUI(), - new GUI(), - new AngularObjectRegistry(intpGroup.getId(), null), - new LocalResourcePool("id"), - new LinkedList<InterpreterContextRunner>(), - new InterpreterOutput(null)); - InterpreterContext.set(context); - - sparkInterpreter = new SparkInterpreter(getPySparkTestProperties()); + LazyOpenInterpreter sparkInterpreter = + new LazyOpenInterpreter(new SparkInterpreter(properties)); intpGroup.get("note").add(sparkInterpreter); sparkInterpreter.setInterpreterGroup(intpGroup); - sparkInterpreter.open(); - - pySparkInterpreter = new PySparkInterpreter(getPySparkTestProperties()); - intpGroup.get("note").add(pySparkInterpreter); - pySparkInterpreter.setInterpreterGroup(intpGroup); - pySparkInterpreter.open(); - } - - @AfterClass - public static void tearDown() throws InterpreterException { - pySparkInterpreter.close(); - sparkInterpreter.close(); - } - @Test - public void testBasicIntp() throws InterpreterException, InterruptedException, IOException { - IPySparkInterpreterTest.testPySpark(pySparkInterpreter, mockRemoteEventClient); - } + LazyOpenInterpreter iPySparkInterpreter = + new LazyOpenInterpreter(new IPySparkInterpreter(properties)); + intpGroup.get("note").add(iPySparkInterpreter); + iPySparkInterpreter.setInterpreterGroup(intpGroup); - @Test - public void testRedefinitionZeppelinContext() throws InterpreterException { - if (getSparkVersionNumber() > 11) { - String redefinitionCode = "z = 1\n"; - String restoreCode = "z = __zeppelin__\n"; - String validCode = "z.input(\"test\")\n"; + interpreter = new LazyOpenInterpreter(new PySparkInterpreter(properties)); + intpGroup.get("note").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(redefinitionCode, context).code()); - assertEquals(InterpreterResult.Code.ERROR, pySparkInterpreter.interpret(validCode, context).code()); - 
assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(restoreCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code()); - } + interpreter.open(); } - private class infinityPythonJob implements Runnable { - @Override - public void run() { - String code = "import time\nwhile True:\n time.sleep(1)" ; - InterpreterResult ret = null; - try { - ret = pySparkInterpreter.interpret(code, context); - } catch (InterpreterException e) { - e.printStackTrace(); - } - assertNotNull(ret); - Pattern expectedMessage = Pattern.compile("KeyboardInterrupt"); - Matcher m = expectedMessage.matcher(ret.message().toString()); - assertTrue(m.find()); - } + @Override + public void tearDown() throws InterpreterException { + intpGroup.close(); + intpGroup = null; + interpreter = null; } @Test - public void testCancelIntp() throws InterruptedException, InterpreterException { - if (getSparkVersionNumber() > 11) { - assertEquals(InterpreterResult.Code.SUCCESS, - pySparkInterpreter.interpret("a = 1\n", context).code()); - - Thread t = new Thread(new infinityPythonJob()); - t.start(); - Thread.sleep(5000); - pySparkInterpreter.cancel(context); - assertTrue(t.isAlive()); - t.join(2000); - assertFalse(t.isAlive()); - } + public void testPySpark() throws InterruptedException, InterpreterException, IOException { + IPySparkInterpreterTest.testPySpark(interpreter, mockRemoteEventClient); } + } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java index 53f29c3..8eaf1e4 100644 --- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java @@ -26,6 +26,8 @@ import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.remote.RemoteEventClient; import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import java.io.IOException; @@ -47,8 +49,8 @@ public class SparkRInterpreterTest { private SparkInterpreter sparkInterpreter; private RemoteEventClient mockRemoteEventClient = mock(RemoteEventClient.class); - @Test - public void testSparkRInterpreter() throws InterpreterException, InterruptedException { + @Before + public void setUp() throws InterpreterException { Properties properties = new Properties(); properties.setProperty("spark.master", "local"); properties.setProperty("spark.app.name", "test"); @@ -69,6 +71,16 @@ public class SparkRInterpreterTest { sparkRInterpreter.open(); sparkInterpreter.getZeppelinContext().setEventClient(mockRemoteEventClient); + } + + @After + public void tearDown() throws InterpreterException { + sparkInterpreter.close(); + } + + @Test + public void testSparkRInterpreter() throws InterpreterException, InterruptedException { + InterpreterResult result = sparkRInterpreter.interpret("1+1", getInterpreterContext()); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/test/resources/log4j.properties 
---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/resources/log4j.properties b/spark/interpreter/src/test/resources/log4j.properties index 0dc7c89..edd13e4 100644 --- a/spark/interpreter/src/test/resources/log4j.properties +++ b/spark/interpreter/src/test/resources/log4j.properties @@ -43,9 +43,9 @@ log4j.logger.DataNucleus.Datastore=ERROR # Log all JDBC parameters log4j.logger.org.hibernate.type=ALL -log4j.logger.org.apache.zeppelin.interpreter=DEBUG -log4j.logger.org.apache.zeppelin.spark=DEBUG +log4j.logger.org.apache.zeppelin.interpreter=WARN +log4j.logger.org.apache.zeppelin.spark=INFO log4j.logger.org.apache.zeppelin.python=DEBUG -log4j.logger.org.apache.spark.repl.Main=INFO +log4j.logger.org.apache.spark.repl.Main=WARN http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java ---------------------------------------------------------------------- diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java index 9f88901..4cf4b31 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterGroup.java @@ -161,4 +161,17 @@ public class InterpreterGroup { public int hashCode() { return id != null ? id.hashCode() : 0; } + + public void close() { + for (List<Interpreter> session : sessions.values()) { + for (Interpreter interpreter : session) { + try { + interpreter.close(); + } catch (InterpreterException e) { + LOGGER.warn("Fail to close interpreter: " + interpreter.getClassName(), e); + } + } + } + sessions.clear(); + } } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java ---------------------------------------------------------------------- diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java index 6198c7b..6710915 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/ZeppelinSparkClusterTest.java @@ -31,9 +31,11 @@ import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; import java.util.Arrays; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.zeppelin.conf.ZeppelinConfiguration; import org.apache.zeppelin.display.AngularObject; @@ -54,8 +56,16 @@ import org.apache.zeppelin.user.AuthenticationInfo; */ @RunWith(value = Parameterized.class) public class ZeppelinSparkClusterTest extends AbstractTestRestApi { + private static final Logger LOGGER = LoggerFactory.getLogger(ZeppelinSparkClusterTest.class); + //This is for only run setupSparkInterpreter one time for each spark version, otherwise + //each test method will run setupSparkInterpreter which will cost a long time and may cause travis + //ci timeout. 
+ //TODO(zjffdu) remove this after we upgrade it to junit 4.13 (ZEPPELIN-3341) + private static Set<String> verifiedSparkVersions = new HashSet<>(); + + private String sparkVersion; private AuthenticationInfo anonymous = new AuthenticationInfo("anonymous"); @@ -63,8 +73,11 @@ public class ZeppelinSparkClusterTest extends AbstractTestRestApi { this.sparkVersion = sparkVersion; LOGGER.info("Testing SparkVersion: " + sparkVersion); String sparkHome = SparkDownloadUtils.downloadSpark(sparkVersion); - setupSparkInterpreter(sparkHome); - verifySparkVersionNumber(); + if (!verifiedSparkVersions.contains(sparkVersion)) { + verifiedSparkVersions.add(sparkVersion); + setupSparkInterpreter(sparkHome); + verifySparkVersionNumber(); + } } @Parameterized.Parameters http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java ---------------------------------------------------------------------- diff --git a/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java b/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java index 21de851..04a87fd 100644 --- a/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java +++ b/zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/InterpreterSetting.java @@ -520,7 +520,8 @@ public class InterpreterSetting { Map<String, InterpreterProperty> iProperties = (Map<String, InterpreterProperty>) properties; for (Map.Entry<String, InterpreterProperty> entry : iProperties.entrySet()) { if (entry.getValue().getValue() != null) { - jProperties.setProperty(entry.getKey().trim(), entry.getValue().getValue().toString().trim()); + jProperties.setProperty(entry.getKey().trim(), + entry.getValue().getValue().toString().trim()); } }
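For context on the Python-side changes above: the rewritten zeppelin_pyspark.py no longer creates its own GatewayClient, completion handler, or statement-polling loop. It assumes the parent PythonInterpreter has already started the Py4J gateway and the Python process, and it only looks up the JVM-side interpreter via gateway.entry_point to wire up SparkContext, SQLContext, and the ZeppelinContext. A minimal sketch of that Py4J pattern follows; the port number and the standalone connection are illustrative assumptions, not part of this commit:

    # Sketch only: connect to a gateway the JVM side is assumed to have started already.
    from py4j.java_gateway import JavaGateway, GatewayParameters

    gateway = JavaGateway(gateway_parameters=GatewayParameters(port=25333, auto_convert=True))
    intp = gateway.entry_point  # JVM-side interpreter object (e.g. PySparkInterpreter)

    # The bootstrap script then builds the Python-side Spark objects on top of the JVM ones:
    # jsc = intp.getJavaSparkContext()
    # conf = SparkConf(_jvm=gateway.jvm, _jconf=intp.getSparkConf())
    # sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)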