http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/main/resources/python/zeppelin_python.py ---------------------------------------------------------------------- diff --git a/python/src/main/resources/python/zeppelin_python.py b/python/src/main/resources/python/zeppelin_python.py index 0b2d533..19fa220 100644 --- a/python/src/main/resources/python/zeppelin_python.py +++ b/python/src/main/resources/python/zeppelin_python.py @@ -15,24 +15,12 @@ # limitations under the License. # -import os, sys, getopt, traceback, json, re +import os, sys, traceback, json, re from py4j.java_gateway import java_import, JavaGateway, GatewayClient -from py4j.protocol import Py4JJavaError, Py4JNetworkError -import warnings -import ast -import traceback -import warnings -import signal -import base64 - -from io import BytesIO -try: - from StringIO import StringIO -except ImportError: - from io import StringIO +from py4j.protocol import Py4JJavaError -# for back compatibility +import ast class Logger(object): def __init__(self): @@ -47,46 +35,79 @@ class Logger(object): def flush(self): pass -def handler_stop_signals(sig, frame): - sys.exit("Got signal : " + str(sig)) +class PythonCompletion: + def __init__(self, interpreter, userNameSpace): + self.interpreter = interpreter + self.userNameSpace = userNameSpace -signal.signal(signal.SIGINT, handler_stop_signals) + def getObjectCompletion(self, text_value): + completions = [completion for completion in list(self.userNameSpace.keys()) if completion.startswith(text_value)] + builtinCompletions = [completion for completion in dir(__builtins__) if completion.startswith(text_value)] + return completions + builtinCompletions -host = "127.0.0.1" -if len(sys.argv) >= 3: - host = sys.argv[2] + def getMethodCompletion(self, objName, methodName): + execResult = locals() + try: + exec("{} = dir({})".format("objectDefList", objName), _zcUserQueryNameSpace, execResult) + except: + self.interpreter.logPythonOutput("Fail to run dir on " + objName) + self.interpreter.logPythonOutput(traceback.format_exc()) + return None + else: + objectDefList = execResult['objectDefList'] + return [completion for completion in execResult['objectDefList'] if completion.startswith(methodName)] + + def getCompletion(self, text_value): + if text_value == None: + return None + + dotPos = text_value.find(".") + if dotPos == -1: + objName = text_value + completionList = self.getObjectCompletion(objName) + else: + objName = text_value[:dotPos] + methodName = text_value[dotPos + 1:] + completionList = self.getMethodCompletion(objName, methodName) + + if completionList is None or len(completionList) <= 0: + self.interpreter.setStatementsFinished("", False) + else: + result = json.dumps(list(filter(lambda x : not re.match("^__.*", x), list(completionList)))) + self.interpreter.setStatementsFinished(result, False) + +host = sys.argv[1] +port = int(sys.argv[2]) + +client = GatewayClient(address=host, port=port) +gateway = JavaGateway(client, auto_convert = True) +intp = gateway.entry_point +# redirect stdout/stderr to java side so that PythonInterpreter can capture the python execution result +output = Logger() +sys.stdout = output +sys.stderr = output _zcUserQueryNameSpace = {} -client = GatewayClient(address=host, port=int(sys.argv[1])) - -gateway = JavaGateway(client) - -intp = gateway.entry_point -intp.onPythonScriptInitialized(os.getpid()) -java_import(gateway.jvm, "org.apache.zeppelin.display.Input") +completion = PythonCompletion(intp, _zcUserQueryNameSpace) +_zcUserQueryNameSpace["__zeppelin_completion__"] = completion +_zcUserQueryNameSpace["gateway"] = gateway from zeppelin_context import PyZeppelinContext +if intp.getZeppelinContext(): + z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway) + __zeppelin__._setup_matplotlib() + _zcUserQueryNameSpace["z"] = z + _zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ -z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext(), gateway) -__zeppelin__._setup_matplotlib() - -_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__ -_zcUserQueryNameSpace["z"] = z - -output = Logger() -sys.stdout = output -#sys.stderr = output +intp.onPythonScriptInitialized(os.getpid()) while True : req = intp.getStatements() - if req == None: - break - try: stmts = req.statements().split("\n") - final_code = [] + isForCompletion = req.isForCompletion() # Get post-execute hooks try: @@ -98,35 +119,23 @@ while True : user_hook = __zeppelin__.getHook('post_exec') except: user_hook = None - - nhooks = 0 - for hook in (global_hook, user_hook): - if hook: - nhooks += 1 - for s in stmts: - if s == None: - continue - - # skip comment - s_stripped = s.strip() - if len(s_stripped) == 0 or s_stripped.startswith("#"): - continue - - final_code.append(s) + nhooks = 0 + if not isForCompletion: + for hook in (global_hook, user_hook): + if hook: + nhooks += 1 - if final_code: + if stmts: # use exec mode to compile the statements except the last statement, # so that the last statement's evaluation will be printed to stdout - code = compile('\n'.join(final_code), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1) - + code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1) to_run_hooks = [] if (nhooks > 0): to_run_hooks = code.body[-nhooks:] to_run_exec, to_run_single = (code.body[:-(nhooks + 1)], [code.body[-(nhooks + 1)]]) - try: for node in to_run_exec: mod = ast.Module([node]) @@ -142,19 +151,37 @@ while True : mod = ast.Module([node]) code = compile(mod, '<stdin>', 'exec') exec(code, _zcUserQueryNameSpace) + + if not isForCompletion: + # only call it when it is not for code completion. code completion will call it in + # PythonCompletion.getCompletion + intp.setStatementsFinished("", False) + except Py4JJavaError: + # raise it to outside try except + raise except: - raise Exception(traceback.format_exc()) + if not isForCompletion: + # extract which line incur error from error message. e.g. + # Traceback (most recent call last): + # File "<stdin>", line 1, in <module> + # ZeroDivisionError: integer division or modulo by zero + exception = traceback.format_exc() + m = re.search("File \"<stdin>\", line (\d+).*", exception) + if m: + line_no = int(m.group(1)) + intp.setStatementsFinished( + "Fail to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True) + else: + intp.setStatementsFinished(exception, True) + else: + intp.setStatementsFinished("", False) - intp.setStatementsFinished("", False) except Py4JJavaError: excInnerError = traceback.format_exc() # format_tb() does not return the inner exception innerErrorStart = excInnerError.find("Py4JJavaError:") if innerErrorStart > -1: - excInnerError = excInnerError[innerErrorStart:] + excInnerError = excInnerError[innerErrorStart:] intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True) - except Py4JNetworkError: - # lost connection from gateway server. exit - sys.exit(1) except: intp.setStatementsFinished(traceback.format_exc(), True)
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java ---------------------------------------------------------------------- diff --git a/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java new file mode 100644 index 0000000..9bedd53 --- /dev/null +++ b/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.python; + +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.display.ui.CheckBox; +import org.apache.zeppelin.display.ui.Select; +import org.apache.zeppelin.display.ui.TextBox; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.remote.RemoteEventClient; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +public abstract class BasePythonInterpreterTest { + + protected InterpreterGroup intpGroup; + protected Interpreter interpreter; + + @Before + public abstract void setUp() throws InterpreterException; + + @After + public abstract void tearDown() throws InterpreterException; + + + @Test + public void testPythonBasics() throws InterpreterException, InterruptedException, IOException { + + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("import sys\nprint(sys.version[0])", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + Thread.sleep(100); + List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + + // single output without print + context = getInterpreterContext(); + result = interpreter.interpret("'hello world'", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("'hello world'", interpreterResultMessages.get(0).getData().trim()); + + // unicode + context = getInterpreterContext(); + result = interpreter.interpret("print(u'ä½ å¥½')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("ä½ å¥½\n", interpreterResultMessages.get(0).getData()); + + // only the last statement is printed + context = getInterpreterContext(); + result = interpreter.interpret("'hello world'\n'hello world2'", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("'hello world2'", interpreterResultMessages.get(0).getData().trim()); + + // single output + context = getInterpreterContext(); + result = interpreter.interpret("print('hello world')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("hello world\n", interpreterResultMessages.get(0).getData()); + + // multiple output + context = getInterpreterContext(); + result = interpreter.interpret("print('hello world')\nprint('hello world2')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("hello world\nhello world2\n", interpreterResultMessages.get(0).getData()); + + // assignment + context = getInterpreterContext(); + result = interpreter.interpret("abc=1",context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(0, interpreterResultMessages.size()); + + // if block + context = getInterpreterContext(); + result = interpreter.interpret("if abc > 0:\n\tprint('True')\nelse:\n\tprint('False')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("True\n", interpreterResultMessages.get(0).getData()); + + // for loop + context = getInterpreterContext(); + result = interpreter.interpret("for i in range(3):\n\tprint(i)", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("0\n1\n2\n", interpreterResultMessages.get(0).getData()); + + // syntax error + context = getInterpreterContext(); + result = interpreter.interpret("print(unknown)", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + if (interpreter instanceof IPythonInterpreter) { + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertTrue(interpreterResultMessages.get(0).getData().contains("name 'unknown' is not defined")); + } else if (interpreter instanceof PythonInterpreter) { + assertTrue(result.message().get(0).getData().contains("name 'unknown' is not defined")); + } + + // raise runtime exception + context = getInterpreterContext(); + result = interpreter.interpret("1/0", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + if (interpreter instanceof IPythonInterpreter) { + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertTrue(interpreterResultMessages.get(0).getData().contains("ZeroDivisionError")); + } else if (interpreter instanceof PythonInterpreter) { + assertTrue(result.message().get(0).getData().contains("ZeroDivisionError")); + } + + // ZEPPELIN-1133 + context = getInterpreterContext(); + result = interpreter.interpret( + "from __future__ import print_function\n" + + "def greet(name):\n" + + " print('Hello', name)\n" + + "greet('Jack')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("Hello Jack\n",interpreterResultMessages.get(0).getData()); + + // ZEPPELIN-1114 + context = getInterpreterContext(); + result = interpreter.interpret("print('there is no Error: ok')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals("there is no Error: ok\n", interpreterResultMessages.get(0).getData()); + } + + @Test + public void testCodeCompletion() throws InterpreterException, IOException, InterruptedException { + // there's no completion for 'a.' because it is not recognized by compiler for now. + InterpreterContext context = getInterpreterContext(); + String st = "a='hello'\na."; + List<InterpreterCompletion> completions = interpreter.completion(st, st.length(), context); + assertEquals(0, completions.size()); + + // define `a` first + context = getInterpreterContext(); + st = "a='hello'"; + InterpreterResult result = interpreter.interpret(st, context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // now we can get the completion for `a.` + context = getInterpreterContext(); + st = "a."; + completions = interpreter.completion(st, st.length(), context); + // it is different for python2 and python3 and may even different for different minor version + // so only verify it is larger than 20 + assertTrue(completions.size() > 20); + + context = getInterpreterContext(); + st = "a.co"; + completions = interpreter.completion(st, st.length(), context); + assertEquals(1, completions.size()); + assertEquals("count", completions.get(0).getValue()); + + // cursor is in the middle of code + context = getInterpreterContext(); + st = "a.co\b='hello"; + completions = interpreter.completion(st, 4, context); + assertEquals(1, completions.size()); + assertEquals("count", completions.get(0).getValue()); + } + + @Test + public void testZeppelinContext() throws InterpreterException, InterruptedException, IOException { + // TextBox + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("z.input(name='text_1', defaultValue='value_1')", context); + Thread.sleep(100); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertTrue(interpreterResultMessages.get(0).getData().contains("'value_1'")); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("text_1") instanceof TextBox); + TextBox textbox = (TextBox) context.getGui().getForms().get("text_1"); + assertEquals("text_1", textbox.getName()); + assertEquals("value_1", textbox.getDefaultValue()); + + // Select + context = getInterpreterContext(); + result = interpreter.interpret("z.select(name='select_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("select_1") instanceof Select); + Select select = (Select) context.getGui().getForms().get("select_1"); + assertEquals("select_1", select.getName()); + assertEquals(2, select.getOptions().length); + assertEquals("name_1", select.getOptions()[0].getDisplayName()); + assertEquals("value_1", select.getOptions()[0].getValue()); + + // CheckBox + context = getInterpreterContext(); + result = interpreter.interpret("z.checkbox(name='checkbox_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox); + CheckBox checkbox = (CheckBox) context.getGui().getForms().get("checkbox_1"); + assertEquals("checkbox_1", checkbox.getName()); + assertEquals(2, checkbox.getOptions().length); + assertEquals("name_1", checkbox.getOptions()[0].getDisplayName()); + assertEquals("value_1", checkbox.getOptions()[0].getValue()); + + // Pandas DataFrame + context = getInterpreterContext(); + result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); + assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); + + context = getInterpreterContext(); + result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(2, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); + assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); + assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType()); + assertEquals("<font color=red>Results are limited by 3.</font>\n", interpreterResultMessages.get(1).getData()); + + // z.show(matplotlib) + context = getInterpreterContext(); + result = interpreter.interpret("import matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)\nz.show(plt)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType()); + + // clear output + context = getInterpreterContext(); + result = interpreter.interpret("import time\nprint(\"Hello\")\ntime.sleep(0.5)\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context); + assertEquals("%text world\n", context.out.getCurrentOutput().toString()); + } + + @Test + public void testRedefinitionZeppelinContext() throws InterpreterException { + String redefinitionCode = "z = 1\n"; + String restoreCode = "z = __zeppelin__\n"; + String validCode = "z.input(\"test\")\n"; + + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(validCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(redefinitionCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.ERROR, interpreter.interpret(validCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(restoreCode, getInterpreterContext()).code()); + assertEquals(InterpreterResult.Code.SUCCESS, interpreter.interpret(validCode, getInterpreterContext()).code()); + } + + protected InterpreterContext getInterpreterContext() { + return new InterpreterContext( + "noteId", + "paragraphId", + "replName", + "paragraphTitle", + "paragraphText", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + null, + null, + null, + new InterpreterOutput(null)); + } + + protected InterpreterContext getInterpreterContext(RemoteEventClient mockRemoteEventClient) { + InterpreterContext context = getInterpreterContext(); + context.setClient(mockRemoteEventClient); + return context; + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java ---------------------------------------------------------------------- diff --git a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java index f016f09..9e01d06 100644 --- a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java @@ -17,288 +17,64 @@ package org.apache.zeppelin.python; -import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.display.ui.CheckBox; -import org.apache.zeppelin.display.ui.Select; -import org.apache.zeppelin.display.ui.TextBox; import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.interpreter.InterpreterGroup; -import org.apache.zeppelin.interpreter.InterpreterOutput; -import org.apache.zeppelin.interpreter.InterpreterOutputListener; import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResultMessage; -import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput; -import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; -import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.After; -import org.junit.Before; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.junit.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Properties; -import java.util.concurrent.CopyOnWriteArrayList; import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.mockito.Mockito.mock; -public class IPythonInterpreterTest { +public class IPythonInterpreterTest extends BasePythonInterpreterTest { - private static final Logger LOGGER = LoggerFactory.getLogger(IPythonInterpreterTest.class); - private IPythonInterpreter interpreter; - public void startInterpreter(Properties properties) throws InterpreterException { - interpreter = new IPythonInterpreter(properties); - InterpreterGroup mockInterpreterGroup = mock(InterpreterGroup.class); - interpreter.setInterpreterGroup(mockInterpreterGroup); - interpreter.open(); - } - - @After - public void close() throws InterpreterException { - interpreter.close(); - } - - - @Test - public void testIPython() throws IOException, InterruptedException, InterpreterException { + protected Properties initIntpProperties() { Properties properties = new Properties(); properties.setProperty("zeppelin.python.maxResult", "3"); - startInterpreter(properties); - testInterpreter(interpreter); + properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); + return properties; } - @Test - public void testGrpcFrameSize() throws InterpreterException, IOException { - Properties properties = new Properties(); - properties.setProperty("zeppelin.ipython.grpc.message_size", "200"); - startInterpreter(properties); - - // to make this test can run under both python2 and python3 - InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - - InterpreterContext context = getInterpreterContext(); - result = interpreter.interpret("print('1'*300)", context); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertTrue(interpreterResultMessages.get(0).getData().contains("Frame size 304 exceeds maximum: 200")); - - // next call continue work - result = interpreter.interpret("print(1)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + protected void startInterpreter(Properties properties) throws InterpreterException { + interpreter = new LazyOpenInterpreter(new IPythonInterpreter(properties)); + intpGroup = new InterpreterGroup(); + intpGroup.put("session_1", new ArrayList<Interpreter>()); + intpGroup.get("session_1").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); - close(); + interpreter.open(); + } - // increase framesize to make it work - properties.setProperty("zeppelin.ipython.grpc.message_size", "500"); + @Override + public void setUp() throws InterpreterException { + Properties properties = initIntpProperties(); startInterpreter(properties); - // to make this test can run under both python2 and python3 - result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - - context = getInterpreterContext(); - result = interpreter.interpret("print('1'*300)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); } - public static void testInterpreter(final Interpreter interpreter) throws IOException, InterruptedException, InterpreterException { - // to make this test can run under both python2 and python3 - InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - - - InterpreterContext context = getInterpreterContext(); - result = interpreter.interpret("import sys\nprint(sys.version[0])", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - Thread.sleep(100); - List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - boolean isPython2 = interpreterResultMessages.get(0).getData().equals("2\n"); - - // single output without print - context = getInterpreterContext(); - result = interpreter.interpret("'hello world'", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("'hello world'", interpreterResultMessages.get(0).getData()); - - // unicode - context = getInterpreterContext(); - result = interpreter.interpret("print(u'ä½ å¥½')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("ä½ å¥½\n", interpreterResultMessages.get(0).getData()); - - // only the last statement is printed - context = getInterpreterContext(); - result = interpreter.interpret("'hello world'\n'hello world2'", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("'hello world2'", interpreterResultMessages.get(0).getData()); - - // single output - context = getInterpreterContext(); - result = interpreter.interpret("print('hello world')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("hello world\n", interpreterResultMessages.get(0).getData()); - - // multiple output - context = getInterpreterContext(); - result = interpreter.interpret("print('hello world')\nprint('hello world2')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("hello world\nhello world2\n", interpreterResultMessages.get(0).getData()); - - // assignment - context = getInterpreterContext(); - result = interpreter.interpret("abc=1",context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(0, interpreterResultMessages.size()); - - // if block - context = getInterpreterContext(); - result = interpreter.interpret("if abc > 0:\n\tprint('True')\nelse:\n\tprint('False')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("True\n", interpreterResultMessages.get(0).getData()); - - // for loop - context = getInterpreterContext(); - result = interpreter.interpret("for i in range(3):\n\tprint(i)", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("0\n1\n2\n", interpreterResultMessages.get(0).getData()); - - // syntax error - context = getInterpreterContext(); - result = interpreter.interpret("print(unknown)", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertTrue(interpreterResultMessages.get(0).getData().contains("name 'unknown' is not defined")); - - // raise runtime exception - context = getInterpreterContext(); - result = interpreter.interpret("1/0", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.ERROR, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertTrue(interpreterResultMessages.get(0).getData().contains("ZeroDivisionError")); - - // ZEPPELIN-1133 - context = getInterpreterContext(); - result = interpreter.interpret("def greet(name):\n" + - " print('Hello', name)\n" + - "greet('Jack')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("Hello Jack\n",interpreterResultMessages.get(0).getData()); - - // ZEPPELIN-1114 - context = getInterpreterContext(); - result = interpreter.interpret("print('there is no Error: ok')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(1, interpreterResultMessages.size()); - assertEquals("there is no Error: ok\n", interpreterResultMessages.get(0).getData()); - - // completion - context = getInterpreterContext(); - List<InterpreterCompletion> completions = interpreter.completion("ab", 2, context); - assertEquals(2, completions.size()); - assertEquals("abc", completions.get(0).getValue()); - assertEquals("abs", completions.get(1).getValue()); - - context = getInterpreterContext(); - interpreter.interpret("import sys", context); - completions = interpreter.completion("sys.", 4, context); - assertFalse(completions.isEmpty()); - - context = getInterpreterContext(); - completions = interpreter.completion("sys.std", 7, context); - for (InterpreterCompletion completion : completions) { - System.out.println(completion.getValue()); - } - assertEquals(3, completions.size()); - assertEquals("stderr", completions.get(0).getValue()); - assertEquals("stdin", completions.get(1).getValue()); - assertEquals("stdout", completions.get(2).getValue()); - - // there's no completion for 'a.' because it is not recognized by compiler for now. - context = getInterpreterContext(); - String st = "a='hello'\na."; - completions = interpreter.completion(st, st.length(), context); - assertEquals(0, completions.size()); - - // define `a` first - context = getInterpreterContext(); - st = "a='hello'"; - result = interpreter.interpret(st, context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(0, interpreterResultMessages.size()); - - // now we can get the completion for `a.` - context = getInterpreterContext(); - st = "a."; - completions = interpreter.completion(st, st.length(), context); - // it is different for python2 and python3 and may even different for different minor version - // so only verify it is larger than 20 - assertTrue(completions.size() > 20); - - context = getInterpreterContext(); - st = "a.co"; - completions = interpreter.completion(st, st.length(), context); - assertEquals(1, completions.size()); - assertEquals("count", completions.get(0).getValue()); - - // cursor is in the middle of code - context = getInterpreterContext(); - st = "a.co\b='hello"; - completions = interpreter.completion(st, 4, context); - assertEquals(1, completions.size()); - assertEquals("count", completions.get(0).getValue()); + @Override + public void tearDown() throws InterpreterException { + intpGroup.close(); + } + @Test + public void testIPythonAdvancedFeatures() throws InterpreterException, InterruptedException, IOException { // ipython help - context = getInterpreterContext(); - result = interpreter.interpret("range?", context); + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("range?", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); + List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); assertTrue(interpreterResultMessages.get(0).getData().contains("range(stop)")); // timeit @@ -331,13 +107,16 @@ public class IPythonInterpreterTest { assertEquals(InterpreterResult.Code.ERROR, result.code()); interpreterResultMessages = context2.out.toInterpreterResultMessage(); assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt")); + } + @Test + public void testIPythonPlotting() throws InterpreterException, InterruptedException, IOException { // matplotlib - context = getInterpreterContext(); - result = interpreter.interpret("%matplotlib inline\nimport matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)", context); + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret("%matplotlib inline\nimport matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)", context); Thread.sleep(100); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); + List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); // the order of IMAGE and TEXT is not determined // check there must be one IMAGE output boolean hasImageOutput = false; @@ -411,94 +190,44 @@ public class IPythonInterpreterTest { } } assertTrue("No Image Output", hasImageOutput); + } - // ZeppelinContext + @Test + public void testGrpcFrameSize() throws InterpreterException, IOException { + tearDown(); - // TextBox - context = getInterpreterContext(); - result = interpreter.interpret("z.input(name='text_1', defaultValue='value_1')", context); - Thread.sleep(100); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertTrue(interpreterResultMessages.get(0).getData().contains("'value_1'")); - assertEquals(1, context.getGui().getForms().size()); - assertTrue(context.getGui().getForms().get("text_1") instanceof TextBox); - TextBox textbox = (TextBox) context.getGui().getForms().get("text_1"); - assertEquals("text_1", textbox.getName()); - assertEquals("value_1", textbox.getDefaultValue()); + Properties properties = initIntpProperties(); + properties.setProperty("zeppelin.ipython.grpc.message_size", "3000"); - // Select - context = getInterpreterContext(); - result = interpreter.interpret("z.select(name='select_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - assertEquals(1, context.getGui().getForms().size()); - assertTrue(context.getGui().getForms().get("select_1") instanceof Select); - Select select = (Select) context.getGui().getForms().get("select_1"); - assertEquals("select_1", select.getName()); - assertEquals(2, select.getOptions().length); - assertEquals("name_1", select.getOptions()[0].getDisplayName()); - assertEquals("value_1", select.getOptions()[0].getValue()); + startInterpreter(properties); - // CheckBox - context = getInterpreterContext(); - result = interpreter.interpret("z.checkbox(name='checkbox_1', options=[('value_1', 'name_1'), ('value_2', 'name_2')])", context); + // to make this test can run under both python2 and python3 + InterpreterResult result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - assertEquals(1, context.getGui().getForms().size()); - assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox); - CheckBox checkbox = (CheckBox) context.getGui().getForms().get("checkbox_1"); - assertEquals("checkbox_1", checkbox.getName()); - assertEquals(2, checkbox.getOptions().length); - assertEquals("name_1", checkbox.getOptions()[0].getDisplayName()); - assertEquals("value_1", checkbox.getOptions()[0].getValue()); - // Pandas DataFrame - context = getInterpreterContext(); - result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3], 'name':['a','b','c']})\nz.show(df)", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); + InterpreterContext context = getInterpreterContext(); + result = interpreter.interpret("print('1'*3000)", context); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + List<InterpreterResultMessage> interpreterResultMessages = context.out.toInterpreterResultMessage(); assertEquals(1, interpreterResultMessages.size()); - assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); - assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); + assertTrue(interpreterResultMessages.get(0).getData().contains("exceeds maximum: 3000")); - context = getInterpreterContext(); - result = interpreter.interpret("import pandas as pd\ndf = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context); + // next call continue work + result = interpreter.interpret("print(1)", context); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(2, interpreterResultMessages.size()); - assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); - assertEquals("id\tname\n1\ta\n2\tb\n3\tc\n", interpreterResultMessages.get(0).getData()); - assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType()); - assertEquals("<font color=red>Results are limited by 3.</font>\n", interpreterResultMessages.get(1).getData()); - // z.show(matplotlib) - context = getInterpreterContext(); - result = interpreter.interpret("import matplotlib.pyplot as plt\ndata=[1,1,2,3,4]\nplt.figure()\nplt.plot(data)\nz.show(plt)", context); + tearDown(); + + // increase framesize to make it work + properties.setProperty("zeppelin.ipython.grpc.message_size", "5000"); + startInterpreter(properties); + // to make this test can run under both python2 and python3 + result = interpreter.interpret("from __future__ import print_function", getInterpreterContext()); assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - interpreterResultMessages = context.out.toInterpreterResultMessage(); - assertEquals(2, interpreterResultMessages.size()); - assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType()); - assertEquals(InterpreterResult.Type.IMG, interpreterResultMessages.get(1).getType()); - // clear output context = getInterpreterContext(); - result = interpreter.interpret("import time\nprint(\"Hello\")\ntime.sleep(0.5)\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context); - assertEquals("%text world\n", context.out.getCurrentOutput().toString()); + result = interpreter.interpret("print('1'*3000)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); } - private static InterpreterContext getInterpreterContext() { - return new InterpreterContext( - "noteId", - "paragraphId", - "replName", - "paragraphTitle", - "paragraphText", - new AuthenticationInfo(), - new HashMap<String, Object>(), - new GUI(), - new GUI(), - null, - null, - null, - new InterpreterOutput(null)); - } } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java ---------------------------------------------------------------------- diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java index c750352..f1be1b9 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonCondaInterpreterTest.java @@ -39,7 +39,9 @@ public class PythonCondaInterpreterTest { @Before public void setUp() throws InterpreterException { conda = spy(new PythonCondaInterpreter(new Properties())); + when(conda.getClassName()).thenReturn(PythonCondaInterpreter.class.getName()); python = mock(PythonInterpreter.class); + when(python.getClassName()).thenReturn(PythonInterpreter.class.getName()); InterpreterGroup group = new InterpreterGroup(); group.put("note", Arrays.asList(python, conda)); @@ -79,7 +81,7 @@ public class PythonCondaInterpreterTest { conda.interpret("activate " + envname, context); verify(python, times(1)).open(); verify(python, times(1)).close(); - verify(python).setPythonCommand("/path1/bin/python"); + verify(python).setPythonExec("/path1/bin/python"); assertTrue(envname.equals(conda.getCurrentCondaEnvName())); } @@ -89,7 +91,7 @@ public class PythonCondaInterpreterTest { conda.interpret("deactivate", context); verify(python, times(1)).open(); verify(python, times(1)).close(); - verify(python).setPythonCommand("python"); + verify(python).setPythonExec("python"); assertTrue(conda.getCurrentCondaEnvName().isEmpty()); } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java ---------------------------------------------------------------------- diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java index 5634630..17f6cc1 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonDockerInterpreterTest.java @@ -17,24 +17,27 @@ package org.apache.zeppelin.python; import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; import org.apache.zeppelin.user.AuthenticationInfo; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; -import java.io.IOException; -import java.net.Inet4Address; -import java.net.UnknownHostException; +import java.io.File; import java.util.Arrays; import java.util.HashMap; import java.util.Properties; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; public class PythonDockerInterpreterTest { private PythonDockerInterpreter docker; @@ -52,7 +55,7 @@ public class PythonDockerInterpreterTest { doReturn(true).when(docker).pull(any(InterpreterOutput.class), anyString()); doReturn(python).when(docker).getPythonInterpreter(); - doReturn("/scriptpath/zeppelin_python.py").when(python).getScriptPath(); + doReturn(new File("/scriptpath")).when(python).getPythonWorkDir(); docker.open(); } @@ -64,7 +67,7 @@ public class PythonDockerInterpreterTest { verify(python, times(1)).open(); verify(python, times(1)).close(); verify(docker, times(1)).pull(any(InterpreterOutput.class), anyString()); - verify(python).setPythonCommand(Mockito.matches("docker run -i --rm -v.*")); + verify(python).setPythonExec(Mockito.matches("docker run -i --rm -v.*")); } @Test @@ -73,7 +76,7 @@ public class PythonDockerInterpreterTest { docker.interpret("deactivate", context); verify(python, times(1)).open(); verify(python, times(1)).close(); - verify(python).setPythonCommand(null); + verify(python).setPythonExec(null); } private InterpreterContext getInterpreterContext() { http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java ---------------------------------------------------------------------- diff --git a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java index c0beccb..e750dde 100644 --- a/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/PythonInterpreterTest.java @@ -17,130 +17,91 @@ package org.apache.zeppelin.python; -import static org.apache.zeppelin.python.PythonInterpreter.DEFAULT_ZEPPELIN_PYTHON; -import static org.apache.zeppelin.python.PythonInterpreter.MAX_RESULT; -import static org.apache.zeppelin.python.PythonInterpreter.ZEPPELIN_PYTHON; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import java.io.File; -import java.io.IOException; -import java.net.URISyntaxException; -import java.net.URL; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Map; -import java.util.Properties; - -import org.apache.commons.exec.environment.EnvironmentUtils; -import org.apache.zeppelin.display.AngularObjectRegistry; -import org.apache.zeppelin.display.GUI; import org.apache.zeppelin.interpreter.Interpreter; import org.apache.zeppelin.interpreter.InterpreterContext; -import org.apache.zeppelin.interpreter.InterpreterContextRunner; import org.apache.zeppelin.interpreter.InterpreterException; import org.apache.zeppelin.interpreter.InterpreterGroup; -import org.apache.zeppelin.interpreter.InterpreterOutput; -import org.apache.zeppelin.interpreter.InterpreterOutputListener; import org.apache.zeppelin.interpreter.InterpreterResult; -import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput; -import org.apache.zeppelin.resource.LocalResourcePool; -import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.After; -import org.junit.Before; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.junit.Test; -public class PythonInterpreterTest implements InterpreterOutputListener { - PythonInterpreter pythonInterpreter = null; - String cmdHistory; - private InterpreterContext context; - InterpreterOutput out; - - public static Properties getPythonTestProperties() { - Properties p = new Properties(); - p.setProperty(ZEPPELIN_PYTHON, DEFAULT_ZEPPELIN_PYTHON); - p.setProperty(MAX_RESULT, "1000"); - p.setProperty("zeppelin.python.useIPython", "false"); - return p; - } - - @Before - public void beforeTest() throws IOException, InterpreterException { - cmdHistory = ""; - - // python interpreter - pythonInterpreter = new PythonInterpreter(getPythonTestProperties()); - - // create interpreter group - InterpreterGroup group = new InterpreterGroup(); - group.put("note", new LinkedList<Interpreter>()); - group.get("note").add(pythonInterpreter); - pythonInterpreter.setInterpreterGroup(group); - - out = new InterpreterOutput(this); +import java.io.IOException; +import java.util.LinkedList; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; - context = new InterpreterContext("note", "id", null, "title", "text", - new AuthenticationInfo(), - new HashMap<String, Object>(), - new GUI(), - new GUI(), - new AngularObjectRegistry(group.getId(), null), - new LocalResourcePool("id"), - new LinkedList<InterpreterContextRunner>(), - out); - InterpreterContext.set(context); - pythonInterpreter.open(); - } +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; - @After - public void afterTest() throws IOException, InterpreterException { - pythonInterpreter.close(); - } +public class PythonInterpreterTest extends BasePythonInterpreterTest { - @Test - public void testInterpret() throws InterruptedException, IOException, InterpreterException { - InterpreterResult result = pythonInterpreter.interpret("print (\"hi\")", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - } + @Override + public void setUp() throws InterpreterException { - @Test - public void testInterpretInvalidSyntax() throws IOException, InterpreterException { - InterpreterResult result = pythonInterpreter.interpret("for x in range(0,3): print (\"hi\")\n", context); - assertEquals(InterpreterResult.Code.SUCCESS, result.code()); - assertTrue(new String(out.getOutputAt(0).toByteArray()).contains("hi\nhi\nhi")); - } + intpGroup = new InterpreterGroup(); - @Test - public void testRedefinitionZeppelinContext() throws InterpreterException { - String pyRedefinitionCode = "z = 1\n"; - String pyRestoreCode = "z = __zeppelin__\n"; - String pyValidCode = "z.input(\"test\")\n"; + Properties properties = new Properties(); + properties.setProperty("zeppelin.python.maxResult", "3"); + properties.setProperty("zeppelin.python.useIPython", "false"); + properties.setProperty("zeppelin.python.gatewayserver_address", "127.0.0.1"); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyValidCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyRedefinitionCode, context).code()); - assertEquals(InterpreterResult.Code.ERROR, pythonInterpreter.interpret(pyValidCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyRestoreCode, context).code()); - assertEquals(InterpreterResult.Code.SUCCESS, pythonInterpreter.interpret(pyValidCode, context).code()); - } + interpreter = new LazyOpenInterpreter(new PythonInterpreter(properties)); + intpGroup.put("note", new LinkedList<Interpreter>()); + intpGroup.get("note").add(interpreter); + interpreter.setInterpreterGroup(intpGroup); - @Test - public void testOutputClear() throws InterpreterException { - InterpreterResult result = pythonInterpreter.interpret("print(\"Hello\")\nz.getInterpreterContext().out().clear()\nprint(\"world\")\n", context); - assertEquals("%text world\n", out.getCurrentOutput().toString()); + InterpreterContext.set(getInterpreterContext()); + interpreter.open(); } @Override - public void onUpdateAll(InterpreterOutput out) { - + public void tearDown() throws InterpreterException { + intpGroup.close(); } @Override - public void onAppend(int index, InterpreterResultMessageOutput out, byte[] line) { - + public void testCodeCompletion() throws InterpreterException, IOException, InterruptedException { + super.testCodeCompletion(); + + //TODO(zjffdu) PythonInterpreter doesn't support this kind of code completion for now. + // completion + // InterpreterContext context = getInterpreterContext(); + // List<InterpreterCompletion> completions = interpreter.completion("ab", 2, context); + // assertEquals(2, completions.size()); + // assertEquals("abc", completions.get(0).getValue()); + // assertEquals("abs", completions.get(1).getValue()); } - @Override - public void onUpdate(int index, InterpreterResultMessageOutput out) { + private class infinityPythonJob implements Runnable { + @Override + public void run() { + String code = "import time\nwhile True:\n time.sleep(1)" ; + InterpreterResult ret = null; + try { + ret = interpreter.interpret(code, getInterpreterContext()); + } catch (InterpreterException e) { + e.printStackTrace(); + } + assertNotNull(ret); + Pattern expectedMessage = Pattern.compile("KeyboardInterrupt"); + Matcher m = expectedMessage.matcher(ret.message().toString()); + assertTrue(m.find()); + } + } + @Test + public void testCancelIntp() throws InterruptedException, InterpreterException { + assertEquals(InterpreterResult.Code.SUCCESS, + interpreter.interpret("a = 1\n", getInterpreterContext()).code()); + Thread t = new Thread(new infinityPythonJob()); + t.start(); + Thread.sleep(5000); + interpreter.cancel(getInterpreterContext()); + assertTrue(t.isAlive()); + t.join(2000); + assertFalse(t.isAlive()); } } http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/python/src/test/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/python/src/test/resources/log4j.properties b/python/src/test/resources/log4j.properties index 035c7a3..8993ff2 100644 --- a/python/src/test/resources/log4j.properties +++ b/python/src/test/resources/log4j.properties @@ -15,18 +15,13 @@ # limitations under the License. # +# Root logger option +log4j.rootLogger=INFO, stdout + # Direct log messages to stdout log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.Target=System.out log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{ABSOLUTE} %5p %c:%L - %m%n -#log4j.appender.stdout.layout.ConversionPattern= -#%5p [%t] (%F:%L) - %m%n -#%-4r [%t] %-5p %c %x - %m%n -# +log4j.appender.stdout.layout.ConversionPattern=%5p [%d] ({%t} %F[%M]:%L) - %m%n -# Root logger option -log4j.rootLogger=INFO, stdout -log4j.logger.org.apache.zeppelin.python.IPythonInterpreter=DEBUG -log4j.logger.org.apache.zeppelin.python.IPythonClient=DEBUG -log4j.logger.org.apache.zeppelin.python=DEBUG \ No newline at end of file + +log4j.logger.org.apache.zeppelin.python=DEBUG http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/pom.xml ---------------------------------------------------------------------- diff --git a/spark/interpreter/pom.xml b/spark/interpreter/pom.xml index c89cfa6..5330b1c 100644 --- a/spark/interpreter/pom.xml +++ b/spark/interpreter/pom.xml @@ -441,14 +441,14 @@ <configuration> <forkCount>1</forkCount> <reuseForks>false</reuseForks> - <argLine>-Xmx1024m -XX:MaxPermSize=256m</argLine> + <argLine>-Xmx1536m -XX:MaxPermSize=256m</argLine> <excludes> <exclude>**/SparkRInterpreterTest.java</exclude> <exclude>${pyspark.test.exclude}</exclude> <exclude>${tests.to.exclude}</exclude> </excludes> <environmentVariables> - <PYTHONPATH>${project.build.directory}/../../../interpreter/spark/pyspark/pyspark.zip:${project.build.directory}/../../../interpreter/lib/python/:${project.build.directory}/../../../interpreter/spark/pyspark/py4j-${py4j.version}-src.zip:.</PYTHONPATH> + <PYTHONPATH>${project.build.directory}/../../../interpreter/spark/pyspark/pyspark.zip:${project.build.directory}/../../../interpreter/spark/pyspark/py4j-${py4j.version}-src.zip</PYTHONPATH> <ZEPPELIN_HOME>${basedir}/../../</ZEPPELIN_HOME> </environmentVariables> </configuration> http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java index 3691156..3896cba 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/IPySparkInterpreter.java @@ -27,6 +27,7 @@ import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.LazyOpenInterpreter; import org.apache.zeppelin.interpreter.WrappedInterpreter; import org.apache.zeppelin.python.IPythonInterpreter; +import org.apache.zeppelin.python.PythonInterpreter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,8 +50,8 @@ public class IPySparkInterpreter extends IPythonInterpreter { @Override public void open() throws InterpreterException { - setProperty("zeppelin.python", - PySparkInterpreter.getPythonExec(getProperties())); + PySparkInterpreter pySparkInterpreter = getPySparkInterpreter(); + setProperty("zeppelin.python", pySparkInterpreter.getPythonExec()); sparkInterpreter = getSparkInterpreter(); SparkConf conf = sparkInterpreter.getSparkContext().getConf(); // only set PYTHONPATH in embedded, local or yarn-client mode. @@ -94,6 +95,16 @@ public class IPySparkInterpreter extends IPythonInterpreter { return spark; } + private PySparkInterpreter getPySparkInterpreter() throws InterpreterException { + PySparkInterpreter pySpark = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(PySparkInterpreter.class.getName()); + while (p instanceof WrappedInterpreter) { + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + pySpark = (PySparkInterpreter) p; + return pySpark; + } + @Override public BaseZeppelinContext buildZeppelinContext() { return sparkInterpreter.getZeppelinContext(); @@ -117,6 +128,7 @@ public class IPySparkInterpreter extends IPythonInterpreter { @Override public void close() throws InterpreterException { + LOGGER.info("Close IPySparkInterpreter"); super.close(); if (sparkInterpreter != null) { sparkInterpreter.close(); http://git-wip-us.apache.org/repos/asf/zeppelin/blob/0a97446a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java index c8efa7a..9b629f9 100644 --- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java @@ -56,7 +56,7 @@ import java.util.Properties; */ public class NewSparkInterpreter extends AbstractSparkInterpreter { - private static final Logger LOGGER = LoggerFactory.getLogger(SparkInterpreter.class); + private static final Logger LOGGER = LoggerFactory.getLogger(NewSparkInterpreter.class); private BaseSparkScalaInterpreter innerInterpreter; private Map<String, String> innerInterpreterClassMap = new HashMap<>(); @@ -177,7 +177,10 @@ public class NewSparkInterpreter extends AbstractSparkInterpreter { @Override public void close() { LOGGER.info("Close SparkInterpreter"); - innerInterpreter.close(); + if (innerInterpreter != null) { + innerInterpreter.close(); + innerInterpreter = null; + } } @Override