http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java new file mode 100644 index 0000000..9fce093 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import java.io.BufferedWriter; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.net.ServerSocket; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import org.apache.commons.compress.utils.IOUtils; +import org.apache.commons.exec.CommandLine; +import org.apache.commons.exec.DefaultExecutor; +import org.apache.commons.exec.ExecuteException; +import org.apache.commons.exec.ExecuteResultHandler; +import org.apache.commons.exec.ExecuteWatchdog; +import org.apache.commons.exec.PumpStreamHandler; +import org.apache.commons.exec.environment.EnvironmentUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterPropertyBuilder; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; +import org.apache.zeppelin.interpreter.WrappedInterpreter; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import py4j.GatewayServer; + +/** + * + */ +public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler { + Logger logger = LoggerFactory.getLogger(PySparkInterpreter.class); + private GatewayServer gatewayServer; + private DefaultExecutor executor; + private int port; + private ByteArrayOutputStream outputStream; + private ByteArrayOutputStream errStream; + private BufferedWriter ins; + private PipedInputStream in; + private ByteArrayOutputStream input; + private String scriptPath; + boolean pythonscriptRunning = false; + + static { + Interpreter.register( + "pyspark", + "spark", + 
PySparkInterpreter.class.getName(), + new InterpreterPropertyBuilder() + .add("spark.home", + SparkInterpreter.getSystemDefault("SPARK_HOME", "spark.home", ""), + "Spark home path. Should be provided for pyspark") + .add("zeppelin.pyspark.python", + SparkInterpreter.getSystemDefault("PYSPARK_PYTHON", null, "python"), + "Python command to run pyspark with").build()); + } + + public PySparkInterpreter(Properties property) { + super(property); + + scriptPath = System.getProperty("java.io.tmpdir") + "/zeppelin_pyspark.py"; + } + + private String getSparkHome() { + String sparkHome = getProperty("spark.home"); + if (sparkHome == null) { + throw new InterpreterException("spark.home is undefined"); + } else { + return sparkHome; + } + } + + + private void createPythonScript() { + ClassLoader classLoader = getClass().getClassLoader(); + File out = new File(scriptPath); + + if (out.exists() && out.isDirectory()) { + throw new InterpreterException("Can't create python script " + out.getAbsolutePath()); + } + + try { + FileOutputStream outStream = new FileOutputStream(out); + IOUtils.copy( + classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), + outStream); + outStream.close(); + } catch (IOException e) { + throw new InterpreterException(e); + } + + logger.info("File {} created", scriptPath); + } + + @Override + public void open() { + // create python script + createPythonScript(); + + port = findRandomOpenPortOnAllLocalInterfaces(); + + gatewayServer = new GatewayServer(this, port); + gatewayServer.start(); + + // Run python shell + CommandLine cmd = CommandLine.parse(getProperty("zeppelin.pyspark.python")); + cmd.addArgument(scriptPath, false); + cmd.addArgument(Integer.toString(port), false); + executor = new DefaultExecutor(); + outputStream = new ByteArrayOutputStream(); + PipedOutputStream ps = new PipedOutputStream(); + in = null; + try { + in = new PipedInputStream(ps); + } catch (IOException e1) { + throw new InterpreterException(e1); + } + ins = new BufferedWriter(new OutputStreamWriter(ps)); + + input = new ByteArrayOutputStream(); + + PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in); + executor.setStreamHandler(streamHandler); + executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); + + + try { + Map env = EnvironmentUtils.getProcEnvironment(); + + String pythonPath = (String) env.get("PYTHONPATH"); + if (pythonPath == null) { + pythonPath = ""; + } else { + pythonPath += ":"; + } + + pythonPath += getSparkHome() + "/python/lib/py4j-0.8.2.1-src.zip:" + + getSparkHome() + "/python"; + + env.put("PYTHONPATH", pythonPath); + + executor.execute(cmd, env, this); + pythonscriptRunning = true; + } catch (IOException e) { + throw new InterpreterException(e); + } + + + try { + input.write("import sys, getopt\n".getBytes()); + ins.flush(); + } catch (IOException e) { + throw new InterpreterException(e); + } + } + + private int findRandomOpenPortOnAllLocalInterfaces() { + int port; + try (ServerSocket socket = new ServerSocket(0);) { + port = socket.getLocalPort(); + socket.close(); + } catch (IOException e) { + throw new InterpreterException(e); + } + return port; + } + + @Override + public void close() { + executor.getWatchdog().destroyProcess(); + gatewayServer.shutdown(); + } + + PythonInterpretRequest pythonInterpretRequest = null; + + /** + * + */ + public class PythonInterpretRequest { + public String statements; + public String jobGroup; + + public PythonInterpretRequest(String statements, String jobGroup) { + 
this.statements = statements; + this.jobGroup = jobGroup; + } + + public String statements() { + return statements; + } + + public String jobGroup() { + return jobGroup; + } + } + + Integer statementSetNotifier = new Integer(0); + + public PythonInterpretRequest getStatements() { + synchronized (statementSetNotifier) { + while (pythonInterpretRequest == null) { + try { + statementSetNotifier.wait(1000); + } catch (InterruptedException e) { + } + } + PythonInterpretRequest req = pythonInterpretRequest; + pythonInterpretRequest = null; + return req; + } + } + + String statementOutput = null; + boolean statementError = false; + Integer statementFinishedNotifier = new Integer(0); + + public void setStatementsFinished(String out, boolean error) { + synchronized (statementFinishedNotifier) { + statementOutput = out; + statementError = error; + statementFinishedNotifier.notify(); + } + + } + + boolean pythonScriptInitialized = false; + Integer pythonScriptInitializeNotifier = new Integer(0); + + public void onPythonScriptInitialized() { + synchronized (pythonScriptInitializeNotifier) { + pythonScriptInitialized = true; + pythonScriptInitializeNotifier.notifyAll(); + } + } + + @Override + public InterpreterResult interpret(String st, InterpreterContext context) { + if (!pythonscriptRunning) { + return new InterpreterResult(Code.ERROR, "python process not running" + + outputStream.toString()); + } + + outputStream.reset(); + + synchronized (pythonScriptInitializeNotifier) { + long startTime = System.currentTimeMillis(); + while (pythonScriptInitialized == false + && pythonscriptRunning + && System.currentTimeMillis() - startTime < 10 * 1000) { + try { + pythonScriptInitializeNotifier.wait(1000); + } catch (InterruptedException e) { + } + } + } + + if (pythonscriptRunning == false) { + // python script failed to initialize and terminated + return new InterpreterResult(Code.ERROR, "failed to start pyspark" + + outputStream.toString()); + } + if (pythonScriptInitialized == false) { + // timeout. 
didn't get initialized message + return new InterpreterResult(Code.ERROR, "pyspark is not responding " + + outputStream.toString()); + } + + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + if (!sparkInterpreter.getSparkContext().version().startsWith("1.2") && + !sparkInterpreter.getSparkContext().version().startsWith("1.3")) { + return new InterpreterResult(Code.ERROR, "pyspark " + + sparkInterpreter.getSparkContext().version() + " is not supported"); + } + String jobGroup = sparkInterpreter.getJobGroup(context); + ZeppelinContext z = sparkInterpreter.getZeppelinContext(); + z.setInterpreterContext(context); + z.setGui(context.getGui()); + pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup); + statementOutput = null; + + synchronized (statementSetNotifier) { + statementSetNotifier.notify(); + } + + synchronized (statementFinishedNotifier) { + while (statementOutput == null) { + try { + statementFinishedNotifier.wait(1000); + } catch (InterruptedException e) { + } + } + } + + if (statementError) { + return new InterpreterResult(Code.ERROR, statementOutput); + } else { + return new InterpreterResult(Code.SUCCESS, statementOutput); + } + } + + @Override + public void cancel(InterpreterContext context) { + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + sparkInterpreter.cancel(context); + } + + @Override + public FormType getFormType() { + return FormType.NATIVE; + } + + @Override + public int getProgress(InterpreterContext context) { + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + return sparkInterpreter.getProgress(context); + } + + @Override + public List<String> completion(String buf, int cursor) { + // not supported + return new LinkedList<String>(); + } + + private SparkInterpreter getSparkInterpreter() { + InterpreterGroup intpGroup = getInterpreterGroup(); + synchronized (intpGroup) { + for (Interpreter intp : getInterpreterGroup()){ + if (intp.getClassName().equals(SparkInterpreter.class.getName())) { + Interpreter p = intp; + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + ((LazyOpenInterpreter) p).open(); + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + return (SparkInterpreter) p; + } + } + } + return null; + } + + public ZeppelinContext getZeppelinContext() { + SparkInterpreter sparkIntp = getSparkInterpreter(); + if (sparkIntp != null) { + return getSparkInterpreter().getZeppelinContext(); + } else { + return null; + } + } + + public JavaSparkContext getJavaSparkContext() { + SparkInterpreter intp = getSparkInterpreter(); + if (intp == null) { + return null; + } else { + return new JavaSparkContext(intp.getSparkContext()); + } + } + + public SparkConf getSparkConf() { + JavaSparkContext sc = getJavaSparkContext(); + if (sc == null) { + return null; + } else { + return getJavaSparkContext().getConf(); + } + } + + public SQLContext getSQLContext() { + SparkInterpreter intp = getSparkInterpreter(); + if (intp == null) { + return null; + } else { + return intp.getSQLContext(); + } + } + + + @Override + public void onProcessComplete(int exitValue) { + pythonscriptRunning = false; + logger.info("python process terminated. exit code " + exitValue); + } + + @Override + public void onProcessFailed(ExecuteException e) { + pythonscriptRunning = false; + logger.error("python process failed", e); + } +}
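
Note: the following standalone sketch is not part of this commit. It is a minimal illustration, under assumed names, of the wait/notify handshake that PySparkInterpreter uses between interpret() (which publishes a statement and blocks) and the Py4J callbacks getStatements()/setStatementsFinished() (which the Python side of zeppelin_pyspark.py invokes to pull work and report results). Every class, method, and variable name below is illustrative only.

// HandshakeSketch.java - illustrative only, assumes nothing beyond the JDK.
public class HandshakeSketch {
  private String pendingStatement;           // set by the submitting side
  private String result;                     // set by the worker side
  private final Object submitLock = new Object();
  private final Object finishLock = new Object();

  // Front-end side: publish work, then block until a result arrives.
  public String interpret(String statement) throws InterruptedException {
    synchronized (submitLock) {
      pendingStatement = statement;
      submitLock.notify();                   // wake a worker waiting in takeStatement()
    }
    synchronized (finishLock) {
      while (result == null) {
        finishLock.wait(1000);               // bounded wait, re-check the condition
      }
      String r = result;
      result = null;
      return r;
    }
  }

  // Worker side (played by the Python process via Py4J in the real interpreter).
  public String takeStatement() throws InterruptedException {
    synchronized (submitLock) {
      while (pendingStatement == null) {
        submitLock.wait(1000);
      }
      String s = pendingStatement;
      pendingStatement = null;
      return s;
    }
  }

  // Worker reports completion; this unblocks interpret().
  public void finish(String out) {
    synchronized (finishLock) {
      result = out;
      finishLock.notify();
    }
  }

  public static void main(String[] args) throws Exception {
    HandshakeSketch h = new HandshakeSketch();
    Thread worker = new Thread(() -> {
      try {
        String st = h.takeStatement();
        h.finish("echo: " + st);             // stand-in for executing the statement
      } catch (InterruptedException ignored) {
      }
    });
    worker.start();
    System.out.println(h.interpret("print(1 + 1)"));
  }
}

In the committed code the same pattern is spread over the statementSetNotifier and statementFinishedNotifier monitors, with the additional pythonScriptInitializeNotifier wait guarding against a Python process that never comes up.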
http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java new file mode 100644 index 0000000..71c5ab5 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.PrintStream; +import java.io.PrintWriter; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import org.apache.spark.HttpServer; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.SparkEnv; +import org.apache.spark.repl.SparkCommandLine; +import org.apache.spark.repl.SparkILoop; +import org.apache.spark.repl.SparkIMain; +import org.apache.spark.repl.SparkJLineCompletion; +import org.apache.spark.scheduler.ActiveJob; +import org.apache.spark.scheduler.DAGScheduler; +import org.apache.spark.scheduler.Pool; +import org.apache.spark.scheduler.Stage; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.ui.jobs.JobProgressListener; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterPropertyBuilder; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.InterpreterUtils; +import org.apache.zeppelin.interpreter.WrappedInterpreter; +import org.apache.zeppelin.scheduler.Scheduler; +import org.apache.zeppelin.scheduler.SchedulerFactory; +import org.apache.zeppelin.spark.dep.DependencyContext; +import org.apache.zeppelin.spark.dep.DependencyResolver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import scala.Console; +import scala.Enumeration.Value; +import scala.None; +import scala.Some; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConversions; +import scala.collection.JavaConverters; +import scala.collection.mutable.HashMap; +import scala.collection.mutable.HashSet; +import 
scala.tools.nsc.Settings; +import scala.tools.nsc.interpreter.Completion.Candidates; +import scala.tools.nsc.interpreter.Completion.ScalaCompleter; +import scala.tools.nsc.settings.MutableSettings.BooleanSetting; +import scala.tools.nsc.settings.MutableSettings.PathSetting; + +/** + * Spark interpreter for Zeppelin. + * + */ +public class SparkInterpreter extends Interpreter { + Logger logger = LoggerFactory.getLogger(SparkInterpreter.class); + + static { + Interpreter.register( + "spark", + "spark", + SparkInterpreter.class.getName(), + new InterpreterPropertyBuilder() + .add("spark.app.name", "Zeppelin", "The name of spark application.") + .add("master", + getSystemDefault("MASTER", "spark.master", "local[*]"), + "Spark master uri. ex) spark://masterhost:7077") + .add("spark.executor.memory", + getSystemDefault(null, "spark.executor.memory", "512m"), + "Executor memory per worker instance. ex) 512m, 32g") + .add("spark.cores.max", + getSystemDefault(null, "spark.cores.max", ""), + "Total number of cores to use. Empty value uses all available core.") + .add("spark.yarn.jar", + getSystemDefault("SPARK_YARN_JAR", "spark.yarn.jar", ""), + "The location of the Spark jar file. If you use yarn as a cluster, " + + "we should set this value") + .add("zeppelin.spark.useHiveContext", "true", + "Use HiveContext instead of SQLContext if it is true.") + .add("args", "", "spark commandline args").build()); + + } + + private ZeppelinContext z; + private SparkILoop interpreter; + private SparkIMain intp; + private SparkContext sc; + private ByteArrayOutputStream out; + private SQLContext sqlc; + private DependencyResolver dep; + private SparkJLineCompletion completor; + + private JobProgressListener sparkListener; + + private Map<String, Object> binder; + private SparkEnv env; + + + public SparkInterpreter(Properties property) { + super(property); + out = new ByteArrayOutputStream(); + } + + public SparkInterpreter(Properties property, SparkContext sc) { + this(property); + + this.sc = sc; + env = SparkEnv.get(); + sparkListener = setupListeners(this.sc); + } + + public synchronized SparkContext getSparkContext() { + if (sc == null) { + sc = createSparkContext(); + env = SparkEnv.get(); + sparkListener = setupListeners(sc); + } + return sc; + } + + public boolean isSparkContextInitialized() { + return sc != null; + } + + private static JobProgressListener setupListeners(SparkContext context) { + JobProgressListener pl = new JobProgressListener(context.getConf()); + context.listenerBus().addListener(pl); + return pl; + } + + private boolean useHiveContext() { + return Boolean.parseBoolean(getProperty("zeppelin.spark.useHiveContext")); + } + + public SQLContext getSQLContext() { + if (sqlc == null) { + if (useHiveContext()) { + String name = "org.apache.spark.sql.hive.HiveContext"; + Constructor<?> hc; + try { + hc = getClass().getClassLoader().loadClass(name) + .getConstructor(SparkContext.class); + sqlc = (SQLContext) hc.newInstance(getSparkContext()); + } catch (NoSuchMethodException | SecurityException + | ClassNotFoundException | InstantiationException + | IllegalAccessException | IllegalArgumentException + | InvocationTargetException e) { + logger.warn("Can't create HiveContext. Fallback to SQLContext", e); + // when hive dependency is not loaded, it'll fail. + // in this case SQLContext can be used. 
+ sqlc = new SQLContext(getSparkContext()); + } + } else { + sqlc = new SQLContext(getSparkContext()); + } + } + + return sqlc; + } + + public DependencyResolver getDependencyResolver() { + if (dep == null) { + dep = new DependencyResolver(intp, sc, getProperty("zeppelin.dep.localrepo")); + } + return dep; + } + + private DepInterpreter getDepInterpreter() { + InterpreterGroup intpGroup = getInterpreterGroup(); + if (intpGroup == null) return null; + synchronized (intpGroup) { + for (Interpreter intp : intpGroup) { + if (intp.getClassName().equals(DepInterpreter.class.getName())) { + Interpreter p = intp; + while (p instanceof WrappedInterpreter) { + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + return (DepInterpreter) p; + } + } + } + return null; + } + + public SparkContext createSparkContext() { + System.err.println("------ Create new SparkContext " + getProperty("master") + " -------"); + + String execUri = System.getenv("SPARK_EXECUTOR_URI"); + String[] jars = SparkILoop.getAddedJars(); + + String classServerUri = null; + + try { // in case of spark 1.1x, spark 1.2x + Method classServer = interpreter.intp().getClass().getMethod("classServer"); + HttpServer httpServer = (HttpServer) classServer.invoke(interpreter.intp()); + classServerUri = httpServer.uri(); + } catch (NoSuchMethodException | SecurityException | IllegalAccessException + | IllegalArgumentException | InvocationTargetException e) { + // continue + } + + if (classServerUri == null) { + try { // for spark 1.3x + Method classServer = interpreter.intp().getClass().getMethod("classServerUri"); + classServerUri = (String) classServer.invoke(interpreter.intp()); + } catch (NoSuchMethodException | SecurityException | IllegalAccessException + | IllegalArgumentException | InvocationTargetException e) { + throw new InterpreterException(e); + } + } + + SparkConf conf = + new SparkConf() + .setMaster(getProperty("master")) + .setAppName(getProperty("spark.app.name")) + .setJars(jars) + .set("spark.repl.class.uri", classServerUri); + + if (execUri != null) { + conf.set("spark.executor.uri", execUri); + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")); + } + conf.set("spark.scheduler.mode", "FAIR"); + + Properties intpProperty = getProperty(); + + for (Object k : intpProperty.keySet()) { + String key = (String) k; + Object value = intpProperty.get(key); + if (!isEmptyString(value)) { + logger.debug(String.format("SparkConf: key = [%s], value = [%s]", key, value)); + conf.set(key, (String) value); + } + } + + SparkContext sparkContext = new SparkContext(conf); + return sparkContext; + } + + public static boolean isEmptyString(Object val) { + return val instanceof String && ((String) val).trim().isEmpty(); + } + + public static String getSystemDefault( + String envName, + String propertyName, + String defaultValue) { + + if (envName != null && !envName.isEmpty()) { + String envValue = System.getenv().get(envName); + if (envValue != null) { + return envValue; + } + } + + if (propertyName != null && !propertyName.isEmpty()) { + String propValue = System.getProperty(propertyName); + if (propValue != null) { + return propValue; + } + } + return defaultValue; + } + + @Override + public void open() { + URL[] urls = getClassloaderUrls(); + + // Very nice discussion about how scala compiler handle classpath + // https://groups.google.com/forum/#!topic/scala-user/MlVwo2xCCI0 + + /* + * > val env = new nsc.Settings(errLogger) > env.usejavacp.value = true > val p = new + * Interpreter(env) > 
p.setContextClassLoader > Alternatively you can set the class path through + * nsc.Settings.classpath. + * + * >> val settings = new Settings() >> settings.usejavacp.value = true >> + * settings.classpath.value += File.pathSeparator + >> System.getProperty("java.class.path") >> + * val in = new Interpreter(settings) { >> override protected def parentClassLoader = + * getClass.getClassLoader >> } >> in.setContextClassLoader() + */ + Settings settings = new Settings(); + if (getProperty("args") != null) { + String[] argsArray = getProperty("args").split(" "); + LinkedList<String> argList = new LinkedList<String>(); + for (String arg : argsArray) { + argList.add(arg); + } + + SparkCommandLine command = + new SparkCommandLine(scala.collection.JavaConversions.asScalaBuffer( + argList).toList()); + settings = command.settings(); + } + + // set classpath for scala compiler + PathSetting pathSettings = settings.classpath(); + String classpath = ""; + List<File> paths = currentClassPath(); + for (File f : paths) { + if (classpath.length() > 0) { + classpath += File.pathSeparator; + } + classpath += f.getAbsolutePath(); + } + + if (urls != null) { + for (URL u : urls) { + if (classpath.length() > 0) { + classpath += File.pathSeparator; + } + classpath += u.getFile(); + } + } + + // add dependency from DepInterpreter + DepInterpreter depInterpreter = getDepInterpreter(); + if (depInterpreter != null) { + DependencyContext depc = depInterpreter.getDependencyContext(); + if (depc != null) { + List<File> files = depc.getFiles(); + if (files != null) { + for (File f : files) { + if (classpath.length() > 0) { + classpath += File.pathSeparator; + } + classpath += f.getAbsolutePath(); + } + } + } + } + + pathSettings.v_$eq(classpath); + settings.scala$tools$nsc$settings$ScalaSettings$_setter_$classpath_$eq(pathSettings); + + + // set classloader for scala compiler + settings.explicitParentLoader_$eq(new Some<ClassLoader>(Thread.currentThread() + .getContextClassLoader())); + BooleanSetting b = (BooleanSetting) settings.usejavacp(); + b.v_$eq(true); + settings.scala$tools$nsc$settings$StandardScalaSettings$_setter_$usejavacp_$eq(b); + + PrintStream printStream = new PrintStream(out); + + /* spark interpreter */ + this.interpreter = new SparkILoop(null, new PrintWriter(out)); + interpreter.settings_$eq(settings); + + interpreter.createInterpreter(); + + intp = interpreter.intp(); + intp.setContextClassLoader(); + intp.initializeSynchronous(); + + completor = new SparkJLineCompletion(intp); + + sc = getSparkContext(); + if (sc.getPoolForName("fair").isEmpty()) { + Value schedulingMode = org.apache.spark.scheduler.SchedulingMode.FAIR(); + int minimumShare = 0; + int weight = 1; + Pool pool = new Pool("fair", schedulingMode, minimumShare, weight); + sc.taskScheduler().rootPool().addSchedulable(pool); + } + + sqlc = getSQLContext(); + + dep = getDependencyResolver(); + + z = new ZeppelinContext(sc, sqlc, null, dep, printStream); + + try { + if (sc.version().startsWith("1.1") || sc.version().startsWith("1.2")) { + Method loadFiles = this.interpreter.getClass().getMethod("loadFiles", Settings.class); + loadFiles.invoke(this.interpreter, settings); + } else if (sc.version().startsWith("1.3")) { + Method loadFiles = this.interpreter.getClass().getMethod( + "org$apache$spark$repl$SparkILoop$$loadFiles", Settings.class); + loadFiles.invoke(this.interpreter, settings); + } + } catch (NoSuchMethodException | SecurityException | IllegalAccessException + | IllegalArgumentException | InvocationTargetException e) { + throw 
new InterpreterException(e); + } + + + intp.interpret("@transient var _binder = new java.util.HashMap[String, Object]()"); + binder = (Map<String, Object>) getValue("_binder"); + binder.put("sc", sc); + binder.put("sqlc", sqlc); + binder.put("z", z); + binder.put("out", printStream); + + intp.interpret("@transient val z = " + + "_binder.get(\"z\").asInstanceOf[org.apache.zeppelin.spark.ZeppelinContext]"); + intp.interpret("@transient val sc = " + + "_binder.get(\"sc\").asInstanceOf[org.apache.spark.SparkContext]"); + intp.interpret("@transient val sqlc = " + + "_binder.get(\"sqlc\").asInstanceOf[org.apache.spark.sql.SQLContext]"); + intp.interpret("@transient val sqlContext = " + + "_binder.get(\"sqlc\").asInstanceOf[org.apache.spark.sql.SQLContext]"); + intp.interpret("import org.apache.spark.SparkContext._"); + + if (sc.version().startsWith("1.1")) { + intp.interpret("import sqlContext._"); + } else if (sc.version().startsWith("1.2")) { + intp.interpret("import sqlContext._"); + } else if (sc.version().startsWith("1.3")) { + intp.interpret("import sqlContext.implicits._"); + intp.interpret("import sqlContext.sql"); + intp.interpret("import org.apache.spark.sql.functions._"); + } + + // add jar + if (depInterpreter != null) { + DependencyContext depc = depInterpreter.getDependencyContext(); + if (depc != null) { + List<File> files = depc.getFilesDist(); + if (files != null) { + for (File f : files) { + if (f.getName().toLowerCase().endsWith(".jar")) { + sc.addJar(f.getAbsolutePath()); + logger.info("sc.addJar(" + f.getAbsolutePath() + ")"); + } else { + sc.addFile(f.getAbsolutePath()); + logger.info("sc.addFile(" + f.getAbsolutePath() + ")"); + } + } + } + } + } + } + + private List<File> currentClassPath() { + List<File> paths = classPath(Thread.currentThread().getContextClassLoader()); + String[] cps = System.getProperty("java.class.path").split(File.pathSeparator); + if (cps != null) { + for (String cp : cps) { + paths.add(new File(cp)); + } + } + return paths; + } + + private List<File> classPath(ClassLoader cl) { + List<File> paths = new LinkedList<File>(); + if (cl == null) { + return paths; + } + + if (cl instanceof URLClassLoader) { + URLClassLoader ucl = (URLClassLoader) cl; + URL[] urls = ucl.getURLs(); + if (urls != null) { + for (URL url : urls) { + paths.add(new File(url.getFile())); + } + } + } + return paths; + } + + @Override + public List<String> completion(String buf, int cursor) { + ScalaCompleter c = completor.completer(); + Candidates ret = c.complete(buf, cursor); + return scala.collection.JavaConversions.asJavaList(ret.candidates()); + } + + public Object getValue(String name) { + Object ret = intp.valueOfTerm(name); + if (ret instanceof None) { + return null; + } else if (ret instanceof Some) { + return ((Some) ret).get(); + } else { + return ret; + } + } + + String getJobGroup(InterpreterContext context){ + return "zeppelin-" + this.hashCode() + "-" + context.getParagraphId(); + } + + /** + * Interpret a single line. 
+ */ + @Override + public InterpreterResult interpret(String line, InterpreterContext context) { + z.setInterpreterContext(context); + if (line == null || line.trim().length() == 0) { + return new InterpreterResult(Code.SUCCESS); + } + return interpret(line.split("\n"), context); + } + + public InterpreterResult interpret(String[] lines, InterpreterContext context) { + synchronized (this) { + z.setGui(context.getGui()); + sc.setJobGroup(getJobGroup(context), "Zeppelin", false); + InterpreterResult r = interpretInput(lines); + sc.clearJobGroup(); + return r; + } + } + + public InterpreterResult interpretInput(String[] lines) { + SparkEnv.set(env); + + // add print("") to make sure not finishing with comment + // see https://github.com/NFLabs/zeppelin/issues/151 + String[] linesToRun = new String[lines.length + 1]; + for (int i = 0; i < lines.length; i++) { + linesToRun[i] = lines[i]; + } + linesToRun[lines.length] = "print(\"\")"; + + Console.setOut((java.io.PrintStream) binder.get("out")); + out.reset(); + Code r = null; + String incomplete = ""; + for (String s : linesToRun) { + scala.tools.nsc.interpreter.Results.Result res = null; + try { + res = intp.interpret(incomplete + s); + } catch (Exception e) { + sc.clearJobGroup(); + logger.info("Interpreter exception", e); + return new InterpreterResult(Code.ERROR, InterpreterUtils.getMostRelevantMessage(e)); + } + + r = getResultCode(res); + + if (r == Code.ERROR) { + sc.clearJobGroup(); + return new InterpreterResult(r, out.toString()); + } else if (r == Code.INCOMPLETE) { + incomplete += s + "\n"; + } else { + incomplete = ""; + } + } + + if (r == Code.INCOMPLETE) { + return new InterpreterResult(r, "Incomplete expression"); + } else { + return new InterpreterResult(r, out.toString()); + } + } + + + @Override + public void cancel(InterpreterContext context) { + sc.cancelJobGroup(getJobGroup(context)); + } + + @Override + public int getProgress(InterpreterContext context) { + String jobGroup = getJobGroup(context); + int completedTasks = 0; + int totalTasks = 0; + + DAGScheduler scheduler = sc.dagScheduler(); + if (scheduler == null) { + return 0; + } + HashSet<ActiveJob> jobs = scheduler.activeJobs(); + if (jobs == null || jobs.size() == 0) { + return 0; + } + Iterator<ActiveJob> it = jobs.iterator(); + while (it.hasNext()) { + ActiveJob job = it.next(); + String g = (String) job.properties().get("spark.jobGroup.id"); + + if (jobGroup.equals(g)) { + int[] progressInfo = null; + if (sc.version().startsWith("1.0")) { + progressInfo = getProgressFromStage_1_0x(sparkListener, job.finalStage()); + } else if (sc.version().startsWith("1.1")) { + progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage()); + } else if (sc.version().startsWith("1.2")) { + progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage()); + } else if (sc.version().startsWith("1.3")) { + progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage()); + } else { + continue; + } + totalTasks += progressInfo[0]; + completedTasks += progressInfo[1]; + } + } + + if (totalTasks == 0) { + return 0; + } + return completedTasks * 100 / totalTasks; + } + + private int[] getProgressFromStage_1_0x(JobProgressListener sparkListener, Stage stage) { + int numTasks = stage.numTasks(); + int completedTasks = 0; + + Method method; + Object completedTaskInfo = null; + try { + method = sparkListener.getClass().getMethod("stageIdToTasksComplete"); + completedTaskInfo = + JavaConversions.asJavaMap((HashMap<Object, Object>) 
method.invoke(sparkListener)).get( + stage.id()); + } catch (NoSuchMethodException | SecurityException e) { + logger.error("Error while getting progress", e); + } catch (IllegalAccessException e) { + logger.error("Error while getting progress", e); + } catch (IllegalArgumentException e) { + logger.error("Error while getting progress", e); + } catch (InvocationTargetException e) { + logger.error("Error while getting progress", e); + } + + if (completedTaskInfo != null) { + completedTasks += (int) completedTaskInfo; + } + List<Stage> parents = JavaConversions.asJavaList(stage.parents()); + if (parents != null) { + for (Stage s : parents) { + int[] p = getProgressFromStage_1_0x(sparkListener, s); + numTasks += p[0]; + completedTasks += p[1]; + } + } + + return new int[] {numTasks, completedTasks}; + } + + private int[] getProgressFromStage_1_1x(JobProgressListener sparkListener, Stage stage) { + int numTasks = stage.numTasks(); + int completedTasks = 0; + + try { + Method stageIdToData = sparkListener.getClass().getMethod("stageIdToData"); + HashMap<Tuple2<Object, Object>, Object> stageIdData = + (HashMap<Tuple2<Object, Object>, Object>) stageIdToData.invoke(sparkListener); + Class<?> stageUIDataClass = + this.getClass().forName("org.apache.spark.ui.jobs.UIData$StageUIData"); + + Method numCompletedTasks = stageUIDataClass.getMethod("numCompleteTasks"); + + Set<Tuple2<Object, Object>> keys = + JavaConverters.asJavaSetConverter(stageIdData.keySet()).asJava(); + for (Tuple2<Object, Object> k : keys) { + if (stage.id() == (int) k._1()) { + Object uiData = stageIdData.get(k).get(); + completedTasks += (int) numCompletedTasks.invoke(uiData); + } + } + } catch (Exception e) { + logger.error("Error on getting progress information", e); + } + + List<Stage> parents = JavaConversions.asJavaList(stage.parents()); + if (parents != null) { + for (Stage s : parents) { + int[] p = getProgressFromStage_1_1x(sparkListener, s); + numTasks += p[0]; + completedTasks += p[1]; + } + } + return new int[] {numTasks, completedTasks}; + } + + private Code getResultCode(scala.tools.nsc.interpreter.Results.Result r) { + if (r instanceof scala.tools.nsc.interpreter.Results.Success$) { + return Code.SUCCESS; + } else if (r instanceof scala.tools.nsc.interpreter.Results.Incomplete$) { + return Code.INCOMPLETE; + } else { + return Code.ERROR; + } + } + + @Override + public void close() { + sc.stop(); + sc = null; + + intp.close(); + } + + @Override + public FormType getFormType() { + return FormType.NATIVE; + } + + public JobProgressListener getJobProgressListener() { + return sparkListener; + } + + @Override + public Scheduler getScheduler() { + return SchedulerFactory.singleton().createOrGetFIFOScheduler( + SparkInterpreter.class.getName() + this.hashCode()); + } + + public ZeppelinContext getZeppelinContext() { + return z; + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java new file mode 100644 index 0000000..2555988 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.spark.SparkContext; +import org.apache.spark.scheduler.ActiveJob; +import org.apache.spark.scheduler.DAGScheduler; +import org.apache.spark.scheduler.Stage; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SQLContext.QueryExecution; +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.ui.jobs.JobProgressListener; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterPropertyBuilder; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterUtils; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; +import org.apache.zeppelin.interpreter.WrappedInterpreter; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.scheduler.Scheduler; +import org.apache.zeppelin.scheduler.SchedulerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConversions; +import scala.collection.JavaConverters; +import scala.collection.mutable.HashMap; +import scala.collection.mutable.HashSet; + +/** + * Spark SQL interpreter for Zeppelin. 
+ * + * @author Leemoonsoo + * + */ +public class SparkSqlInterpreter extends Interpreter { + Logger logger = LoggerFactory.getLogger(SparkSqlInterpreter.class); + AtomicInteger num = new AtomicInteger(0); + + static { + Interpreter.register( + "sql", + "spark", + SparkSqlInterpreter.class.getName(), + new InterpreterPropertyBuilder() + .add("zeppelin.spark.maxResult", "10000", "Max number of SparkSQL result to display.") + .add("zeppelin.spark.concurrentSQL", "false", + "Execute multiple SQL concurrently if set true.") + .build()); + } + + private String getJobGroup(InterpreterContext context){ + return "zeppelin-" + this.hashCode() + "-" + context.getParagraphId(); + } + + private int maxResult; + + public SparkSqlInterpreter(Properties property) { + super(property); + } + + @Override + public void open() { + this.maxResult = Integer.parseInt(getProperty("zeppelin.spark.maxResult")); + } + + private SparkInterpreter getSparkInterpreter() { + for (Interpreter intp : getInterpreterGroup()) { + if (intp.getClassName().equals(SparkInterpreter.class.getName())) { + Interpreter p = intp; + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + p.open(); + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + return (SparkInterpreter) p; + } + } + return null; + } + + public boolean concurrentSQL() { + return Boolean.parseBoolean(getProperty("zeppelin.spark.concurrentSQL")); + } + + @Override + public void close() {} + + @Override + public InterpreterResult interpret(String st, InterpreterContext context) { + SQLContext sqlc = null; + + sqlc = getSparkInterpreter().getSQLContext(); + + SparkContext sc = sqlc.sparkContext(); + if (concurrentSQL()) { + sc.setLocalProperty("spark.scheduler.pool", "fair"); + } else { + sc.setLocalProperty("spark.scheduler.pool", null); + } + + sc.setJobGroup(getJobGroup(context), "Zeppelin", false); + + // SchemaRDD - spark 1.1, 1.2, DataFrame - spark 1.3 + Object rdd; + Object[] rows = null; + try { + rdd = sqlc.sql(st); + + Method take = rdd.getClass().getMethod("take", int.class); + rows = (Object[]) take.invoke(rdd, maxResult + 1); + } catch (Exception e) { + logger.error("Error", e); + sc.clearJobGroup(); + return new InterpreterResult(Code.ERROR, InterpreterUtils.getMostRelevantMessage(e)); + } + + String msg = null; + + // get field names + Method queryExecution; + QueryExecution qe; + try { + queryExecution = rdd.getClass().getMethod("queryExecution"); + qe = (QueryExecution) queryExecution.invoke(rdd); + } catch (NoSuchMethodException | SecurityException | IllegalAccessException + | IllegalArgumentException | InvocationTargetException e) { + throw new InterpreterException(e); + } + + List<Attribute> columns = + scala.collection.JavaConverters.asJavaListConverter( + qe.analyzed().output()).asJava(); + + for (Attribute col : columns) { + if (msg == null) { + msg = col.name(); + } else { + msg += "\t" + col.name(); + } + } + + msg += "\n"; + + // ArrayType, BinaryType, BooleanType, ByteType, DecimalType, DoubleType, DynamicType, + // FloatType, FractionalType, IntegerType, IntegralType, LongType, MapType, NativeType, + // NullType, NumericType, ShortType, StringType, StructType + + try { + for (int r = 0; r < maxResult && r < rows.length; r++) { + Object row = rows[r]; + Method isNullAt = row.getClass().getMethod("isNullAt", int.class); + Method apply = row.getClass().getMethod("apply", int.class); + + for (int i = 0; i < columns.size(); i++) { + if (!(Boolean) isNullAt.invoke(row, i)) { + msg += apply.invoke(row, 
i).toString(); + } else { + msg += "null"; + } + if (i != columns.size() - 1) { + msg += "\t"; + } + } + msg += "\n"; + } + } catch (NoSuchMethodException | SecurityException | IllegalAccessException + | IllegalArgumentException | InvocationTargetException e) { + throw new InterpreterException(e); + } + + if (rows.length > maxResult) { + msg += "\n<font color=red>Results are limited by " + maxResult + ".</font>"; + } + InterpreterResult rett = new InterpreterResult(Code.SUCCESS, "%table " + msg); + sc.clearJobGroup(); + return rett; + } + + @Override + public void cancel(InterpreterContext context) { + SQLContext sqlc = getSparkInterpreter().getSQLContext(); + SparkContext sc = sqlc.sparkContext(); + + sc.cancelJobGroup(getJobGroup(context)); + } + + @Override + public FormType getFormType() { + return FormType.SIMPLE; + } + + + @Override + public int getProgress(InterpreterContext context) { + String jobGroup = getJobGroup(context); + SQLContext sqlc = getSparkInterpreter().getSQLContext(); + SparkContext sc = sqlc.sparkContext(); + JobProgressListener sparkListener = getSparkInterpreter().getJobProgressListener(); + int completedTasks = 0; + int totalTasks = 0; + + DAGScheduler scheduler = sc.dagScheduler(); + HashSet<ActiveJob> jobs = scheduler.activeJobs(); + Iterator<ActiveJob> it = jobs.iterator(); + while (it.hasNext()) { + ActiveJob job = it.next(); + String g = (String) job.properties().get("spark.jobGroup.id"); + if (jobGroup.equals(g)) { + int[] progressInfo = null; + if (sc.version().startsWith("1.0")) { + progressInfo = getProgressFromStage_1_0x(sparkListener, job.finalStage()); + } else if (sc.version().startsWith("1.1")) { + progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage()); + } else if (sc.version().startsWith("1.2")) { + progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage()); + } else if (sc.version().startsWith("1.3")) { + progressInfo = getProgressFromStage_1_1x(sparkListener, job.finalStage()); + } else { + logger.warn("Spark {} getting progress information not supported" + sc.version()); + continue; + } + totalTasks += progressInfo[0]; + completedTasks += progressInfo[1]; + } + } + + if (totalTasks == 0) { + return 0; + } + return completedTasks * 100 / totalTasks; + } + + private int[] getProgressFromStage_1_0x(JobProgressListener sparkListener, Stage stage) { + int numTasks = stage.numTasks(); + int completedTasks = 0; + + Method method; + Object completedTaskInfo = null; + try { + method = sparkListener.getClass().getMethod("stageIdToTasksComplete"); + completedTaskInfo = + JavaConversions.asJavaMap((HashMap<Object, Object>) method.invoke(sparkListener)).get( + stage.id()); + } catch (NoSuchMethodException | SecurityException e) { + logger.error("Error while getting progress", e); + } catch (IllegalAccessException e) { + logger.error("Error while getting progress", e); + } catch (IllegalArgumentException e) { + logger.error("Error while getting progress", e); + } catch (InvocationTargetException e) { + logger.error("Error while getting progress", e); + } + + if (completedTaskInfo != null) { + completedTasks += (int) completedTaskInfo; + } + List<Stage> parents = JavaConversions.asJavaList(stage.parents()); + if (parents != null) { + for (Stage s : parents) { + int[] p = getProgressFromStage_1_0x(sparkListener, s); + numTasks += p[0]; + completedTasks += p[1]; + } + } + + return new int[] {numTasks, completedTasks}; + } + + private int[] getProgressFromStage_1_1x(JobProgressListener sparkListener, Stage stage) { + int 
numTasks = stage.numTasks(); + int completedTasks = 0; + + try { + Method stageIdToData = sparkListener.getClass().getMethod("stageIdToData"); + HashMap<Tuple2<Object, Object>, Object> stageIdData = + (HashMap<Tuple2<Object, Object>, Object>) stageIdToData.invoke(sparkListener); + Class<?> stageUIDataClass = + this.getClass().forName("org.apache.spark.ui.jobs.UIData$StageUIData"); + + Method numCompletedTasks = stageUIDataClass.getMethod("numCompleteTasks"); + + Set<Tuple2<Object, Object>> keys = + JavaConverters.asJavaSetConverter(stageIdData.keySet()).asJava(); + for (Tuple2<Object, Object> k : keys) { + if (stage.id() == (int) k._1()) { + Object uiData = stageIdData.get(k).get(); + completedTasks += (int) numCompletedTasks.invoke(uiData); + } + } + } catch (Exception e) { + logger.error("Error on getting progress information", e); + } + + List<Stage> parents = JavaConversions.asJavaList(stage.parents()); + if (parents != null) { + for (Stage s : parents) { + int[] p = getProgressFromStage_1_1x(sparkListener, s); + numTasks += p[0]; + completedTasks += p[1]; + } + } + return new int[] {numTasks, completedTasks}; + } + + @Override + public Scheduler getScheduler() { + if (concurrentSQL()) { + int maxConcurrency = 10; + return SchedulerFactory.singleton().createOrGetParallelScheduler( + SparkSqlInterpreter.class.getName() + this.hashCode(), maxConcurrency); + } else { + // getSparkInterpreter() calls open() inside. + // That means if SparkInterpreter is not opened, it'll wait until SparkInterpreter open. + // In this moment UI displays 'READY' or 'FINISHED' instead of 'PENDING' or 'RUNNING'. + // It's because of scheduler is not created yet, and scheduler is created by this function. + // Therefore, we can still use getSparkInterpreter() here, but it's better and safe + // to getSparkInterpreter without opening it. + for (Interpreter intp : getInterpreterGroup()) { + if (intp.getClassName().equals(SparkInterpreter.class.getName())) { + Interpreter p = intp; + return p.getScheduler(); + } else { + continue; + } + } + throw new InterpreterException("Can't find SparkInterpreter"); + } + } + + @Override + public List<String> completion(String buf, int cursor) { + return null; + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java new file mode 100644 index 0000000..87cd188 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import static scala.collection.JavaConversions.asJavaCollection; +import static scala.collection.JavaConversions.asJavaIterable; +import static scala.collection.JavaConversions.collectionAsScalaIterable; + +import java.io.PrintStream; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; + +import org.apache.spark.SparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.hive.HiveContext; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.display.Input.ParamOption; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.spark.dep.DependencyResolver; + +import scala.Tuple2; +import scala.collection.Iterable; + +/** + * Spark context for zeppelin. + * + * @author Leemoonsoo + * + */ +public class ZeppelinContext extends HashMap<String, Object> { + private DependencyResolver dep; + private PrintStream out; + private InterpreterContext interpreterContext; + + public ZeppelinContext(SparkContext sc, SQLContext sql, + InterpreterContext interpreterContext, + DependencyResolver dep, PrintStream printStream) { + this.sc = sc; + this.sqlContext = sql; + this.interpreterContext = interpreterContext; + this.dep = dep; + this.out = printStream; + } + + public SparkContext sc; + public SQLContext sqlContext; + public HiveContext hiveContext; + private GUI gui; + + /* spark-1.3 + public SchemaRDD sql(String sql) { + return sqlContext.sql(sql); + } + */ + + /** + * Load dependency for interpreter and runtime (driver). + * And distribute them to spark cluster (sc.add()) + * + * @param artifact "group:artifact:version" or file path like "/somepath/your.jar" + * @return + * @throws Exception + */ + public Iterable<String> load(String artifact) throws Exception { + return collectionAsScalaIterable(dep.load(artifact, true)); + } + + /** + * Load dependency and it's transitive dependencies for interpreter and runtime (driver). + * And distribute them to spark cluster (sc.add()) + * + * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" + * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. + * @return + * @throws Exception + */ + public Iterable<String> load(String artifact, scala.collection.Iterable<String> excludes) + throws Exception { + return collectionAsScalaIterable( + dep.load(artifact, + asJavaCollection(excludes), + true)); + } + + /** + * Load dependency and it's transitive dependencies for interpreter and runtime (driver). + * And distribute them to spark cluster (sc.add()) + * + * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" + * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. + * @return + * @throws Exception + */ + public Iterable<String> load(String artifact, Collection<String> excludes) throws Exception { + return collectionAsScalaIterable(dep.load(artifact, excludes, true)); + } + + /** + * Load dependency for interpreter and runtime, and then add to sparkContext. + * But not adding them to spark cluster + * + * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" + * @return + * @throws Exception + */ + public Iterable<String> loadLocal(String artifact) throws Exception { + return collectionAsScalaIterable(dep.load(artifact, false)); + } + + + /** + * Load dependency and it's transitive dependencies and then add to sparkContext. 
+ * But not adding them to spark cluster + * + * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" + * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. + * @return + * @throws Exception + */ + public Iterable<String> loadLocal(String artifact, + scala.collection.Iterable<String> excludes) throws Exception { + return collectionAsScalaIterable(dep.load(artifact, + asJavaCollection(excludes), false)); + } + + /** + * Load dependency and it's transitive dependencies and then add to sparkContext. + * But not adding them to spark cluster + * + * @param artifact "groupId:artifactId:version" or file path like "/somepath/your.jar" + * @param excludes exclusion list of transitive dependency. list of "groupId:artifactId" string. + * @return + * @throws Exception + */ + public Iterable<String> loadLocal(String artifact, Collection<String> excludes) + throws Exception { + return collectionAsScalaIterable(dep.load(artifact, excludes, false)); + } + + + /** + * Add maven repository + * + * @param id id of repository ex) oss, local, snapshot + * @param url url of repository. supported protocol : file, http, https + */ + public void addRepo(String id, String url) { + addRepo(id, url, false); + } + + /** + * Add maven repository + * + * @param id id of repository + * @param url url of repository. supported protocol : file, http, https + * @param snapshot true if it is snapshot repository + */ + public void addRepo(String id, String url, boolean snapshot) { + dep.addRepo(id, url, snapshot); + } + + /** + * Remove maven repository by id + * @param id id of repository + */ + public void removeRepo(String id){ + dep.delRepo(id); + } + + /** + * Load dependency only interpreter. + * + * @param name + * @return + */ + + public Object input(String name) { + return input(name, ""); + } + + public Object input(String name, Object defaultValue) { + return gui.input(name, defaultValue); + } + + public Object select(String name, scala.collection.Iterable<Tuple2<Object, String>> options) { + return select(name, "", options); + } + + public Object select(String name, Object defaultValue, + scala.collection.Iterable<Tuple2<Object, String>> options) { + int n = options.size(); + ParamOption[] paramOptions = new ParamOption[n]; + Iterator<Tuple2<Object, String>> it = asJavaIterable(options).iterator(); + + int i = 0; + while (it.hasNext()) { + Tuple2<Object, String> valueAndDisplayValue = it.next(); + paramOptions[i++] = new ParamOption(valueAndDisplayValue._1(), valueAndDisplayValue._2()); + } + + return gui.select(name, "", paramOptions); + } + + public void setGui(GUI o) { + this.gui = o; + } + + public void run(String lines) { + /* + String intpName = Paragraph.getRequiredReplName(lines); + String scriptBody = Paragraph.getScriptBody(lines); + Interpreter intp = interpreterContext.getParagraph().getRepl(intpName); + InterpreterResult ret = intp.interpret(scriptBody, interpreterContext); + if (ret.code() == InterpreterResult.Code.SUCCESS) { + out.println("%" + ret.type().toString().toLowerCase() + " " + ret.message()); + } else if (ret.code() == InterpreterResult.Code.ERROR) { + out.println("Error: " + ret.message()); + } else if (ret.code() == InterpreterResult.Code.INCOMPLETE) { + out.println("Incomplete"); + } else { + out.println("Unknown error"); + } + */ + throw new RuntimeException("Missing implementation"); + } + + private void restartInterpreter() { + } + + public InterpreterContext getInterpreterContext() { + return 
interpreterContext; + } + + public void setInterpreterContext(InterpreterContext interpreterContext) { + this.interpreterContext = interpreterContext; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/dep/Booter.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/dep/Booter.java b/spark/src/main/java/org/apache/zeppelin/spark/dep/Booter.java new file mode 100644 index 0000000..0533804 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/dep/Booter.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark.dep; + +import java.io.File; + +import org.apache.maven.repository.internal.MavenRepositorySystemSession; +import org.sonatype.aether.RepositorySystem; +import org.sonatype.aether.RepositorySystemSession; +import org.sonatype.aether.repository.LocalRepository; +import org.sonatype.aether.repository.RemoteRepository; + +/** + * Manage mvn repository. 
+ * + * @author anthonycorbacho + * + */ +public class Booter { + public static RepositorySystem newRepositorySystem() { + return RepositorySystemFactory.newRepositorySystem(); + } + + public static RepositorySystemSession newRepositorySystemSession( + RepositorySystem system, String localRepoPath) { + MavenRepositorySystemSession session = new MavenRepositorySystemSession(); + + // find homedir + String home = System.getenv("ZEPPELIN_HOME"); + if (home == null) { + home = System.getProperty("zeppelin.home"); + } + if (home == null) { + home = ".."; + } + + String path = home + "/" + localRepoPath; + + LocalRepository localRepo = + new LocalRepository(new File(path).getAbsolutePath()); + session.setLocalRepositoryManager(system.newLocalRepositoryManager(localRepo)); + + // session.setTransferListener(new ConsoleTransferListener()); + // session.setRepositoryListener(new ConsoleRepositoryListener()); + + // uncomment to generate dirty trees + // session.setDependencyGraphTransformer( null ); + + return session; + } + + public static RemoteRepository newCentralRepository() { + return new RemoteRepository("central", "default", "http://repo1.maven.org/maven2/"); + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/dep/Dependency.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/dep/Dependency.java b/spark/src/main/java/org/apache/zeppelin/spark/dep/Dependency.java new file mode 100644 index 0000000..ca92893 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/dep/Dependency.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark.dep; + +import java.util.LinkedList; +import java.util.List; + +/** + * + */ +public class Dependency { + private String groupArtifactVersion; + private boolean local = false; + private List<String> exclusions; + + + public Dependency(String groupArtifactVersion) { + this.groupArtifactVersion = groupArtifactVersion; + exclusions = new LinkedList<String>(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Dependency)) { + return false; + } else { + return ((Dependency) o).groupArtifactVersion.equals(groupArtifactVersion); + } + } + + /** + * Don't add artifact into SparkContext (sc.addJar()) + * @return + */ + public Dependency local() { + local = true; + return this; + } + + public Dependency excludeAll() { + exclude("*"); + return this; + } + + /** + * + * @param exclusions comma or newline separated list of "groupId:ArtifactId" + * @return + */ + public Dependency exclude(String exclusions) { + for (String item : exclusions.split(",|\n")) { + this.exclusions.add(item); + } + + return this; + } + + + public String getGroupArtifactVersion() { + return groupArtifactVersion; + } + + public boolean isDist() { + return !local; + } + + public List<String> getExclusions() { + return exclusions; + } + + public boolean isLocalFsArtifact() { + int numSplits = groupArtifactVersion.split(":").length; + return !(numSplits >= 3 && numSplits <= 6); + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyContext.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyContext.java b/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyContext.java new file mode 100644 index 0000000..f0fd313 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyContext.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark.dep; + +import java.io.File; +import java.net.MalformedURLException; +import java.util.LinkedList; +import java.util.List; + +import org.sonatype.aether.RepositorySystem; +import org.sonatype.aether.RepositorySystemSession; +import org.sonatype.aether.artifact.Artifact; +import org.sonatype.aether.collection.CollectRequest; +import org.sonatype.aether.graph.DependencyFilter; +import org.sonatype.aether.repository.RemoteRepository; +import org.sonatype.aether.resolution.ArtifactResolutionException; +import org.sonatype.aether.resolution.ArtifactResult; +import org.sonatype.aether.resolution.DependencyRequest; +import org.sonatype.aether.resolution.DependencyResolutionException; +import org.sonatype.aether.util.artifact.DefaultArtifact; +import org.sonatype.aether.util.artifact.JavaScopes; +import org.sonatype.aether.util.filter.DependencyFilterUtils; +import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter; + + +/** + * + */ +public class DependencyContext { + List<Dependency> dependencies = new LinkedList<Dependency>(); + List<Repository> repositories = new LinkedList<Repository>(); + + List<File> files = new LinkedList<File>(); + List<File> filesDist = new LinkedList<File>(); + private RepositorySystem system = Booter.newRepositorySystem(); + private RepositorySystemSession session; + private RemoteRepository mavenCentral = new RemoteRepository("central", + "default", "http://repo1.maven.org/maven2/"); + private RemoteRepository mavenLocal = new RemoteRepository("local", + "default", "file://" + System.getProperty("user.home") + "/.m2/repository"); + + public DependencyContext(String localRepoPath) { + session = Booter.newRepositorySystemSession(system, localRepoPath); + } + + public Dependency load(String lib) { + Dependency dep = new Dependency(lib); + + if (dependencies.contains(dep)) { + dependencies.remove(dep); + } + dependencies.add(dep); + return dep; + } + + public Repository addRepo(String name) { + Repository rep = new Repository(name); + repositories.add(rep); + return rep; + } + + public void reset() { + dependencies = new LinkedList<Dependency>(); + repositories = new LinkedList<Repository>(); + + files = new LinkedList<File>(); + filesDist = new LinkedList<File>(); + } + + + /** + * fetch all artifacts + * @return + * @throws MalformedURLException + * @throws ArtifactResolutionException + * @throws DependencyResolutionException + */ + public List<File> fetch() throws MalformedURLException, + DependencyResolutionException, ArtifactResolutionException { + + for (Dependency dep : dependencies) { + if (!dep.isLocalFsArtifact()) { + List<ArtifactResult> artifacts = fetchArtifactWithDep(dep); + for (ArtifactResult artifact : artifacts) { + if (dep.isDist()) { + filesDist.add(artifact.getArtifact().getFile()); + } + files.add(artifact.getArtifact().getFile()); + } + } else { + if (dep.isDist()) { + filesDist.add(new File(dep.getGroupArtifactVersion())); + } + files.add(new File(dep.getGroupArtifactVersion())); + } + } + + return files; + } + + private List<ArtifactResult> fetchArtifactWithDep(Dependency dep) + throws DependencyResolutionException, ArtifactResolutionException { + Artifact artifact = new DefaultArtifact( + DependencyResolver.inferScalaVersion(dep.getGroupArtifactVersion())); + + DependencyFilter classpathFlter = DependencyFilterUtils + .classpathFilter(JavaScopes.COMPILE); + PatternExclusionsDependencyFilter exclusionFilter = new PatternExclusionsDependencyFilter( + 
DependencyResolver.inferScalaVersion(dep.getExclusions())); + + CollectRequest collectRequest = new CollectRequest(); + collectRequest.setRoot(new org.sonatype.aether.graph.Dependency(artifact, + JavaScopes.COMPILE)); + + collectRequest.addRepository(mavenCentral); + collectRequest.addRepository(mavenLocal); + for (Repository repo : repositories) { + RemoteRepository rr = new RemoteRepository(repo.getName(), "default", repo.getUrl()); + rr.setPolicy(repo.isSnapshot(), null); + collectRequest.addRepository(rr); + } + + DependencyRequest dependencyRequest = new DependencyRequest(collectRequest, + DependencyFilterUtils.andFilter(exclusionFilter, classpathFlter)); + + return system.resolveDependencies(session, dependencyRequest).getArtifactResults(); + } + + public List<File> getFiles() { + return files; + } + + public List<File> getFilesDist() { + return filesDist; + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyResolver.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyResolver.java b/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyResolver.java new file mode 100644 index 0000000..06a4022 --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/dep/DependencyResolver.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark.dep; + +import java.io.File; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URL; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.commons.lang.StringUtils; +import org.apache.spark.SparkContext; +import org.apache.spark.repl.SparkIMain; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.sonatype.aether.RepositorySystem; +import org.sonatype.aether.RepositorySystemSession; +import org.sonatype.aether.artifact.Artifact; +import org.sonatype.aether.collection.CollectRequest; +import org.sonatype.aether.graph.Dependency; +import org.sonatype.aether.graph.DependencyFilter; +import org.sonatype.aether.repository.RemoteRepository; +import org.sonatype.aether.resolution.ArtifactResult; +import org.sonatype.aether.resolution.DependencyRequest; +import org.sonatype.aether.util.artifact.DefaultArtifact; +import org.sonatype.aether.util.artifact.JavaScopes; +import org.sonatype.aether.util.filter.DependencyFilterUtils; +import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter; + +import scala.Some; +import scala.collection.IndexedSeq; +import scala.reflect.io.AbstractFile; +import scala.tools.nsc.Global; +import scala.tools.nsc.backend.JavaPlatform; +import scala.tools.nsc.util.ClassPath; +import scala.tools.nsc.util.MergedClassPath; + +/** + * Deps resolver. + * Add new dependencies from mvn repo (at runtime) to Zeppelin. + * + * @author anthonycorbacho + * + */ +public class DependencyResolver { + Logger logger = LoggerFactory.getLogger(DependencyResolver.class); + private Global global; + private SparkIMain intp; + private SparkContext sc; + private RepositorySystem system = Booter.newRepositorySystem(); + private List<RemoteRepository> repos = new LinkedList<RemoteRepository>(); + private RepositorySystemSession session; + private DependencyFilter classpathFlter = DependencyFilterUtils.classpathFilter( + JavaScopes.COMPILE, + JavaScopes.PROVIDED, + JavaScopes.RUNTIME, + JavaScopes.SYSTEM); + + private final String[] exclusions = new String[] {"org.scala-lang:scala-library", + "org.scala-lang:scala-compiler", + "org.scala-lang:scala-reflect", + "org.scala-lang:scalap", + "org.apache.zeppelin:zeppelin-zengine", + "org.apache.zeppelin:zeppelin-spark", + "org.apache.zeppelin:zeppelin-server"}; + + public DependencyResolver(SparkIMain intp, SparkContext sc, String localRepoPath) { + this.intp = intp; + this.global = intp.global(); + this.sc = sc; + session = Booter.newRepositorySystemSession(system, localRepoPath); + repos.add(Booter.newCentralRepository()); // add maven central + repos.add(new RemoteRepository("local", "default", "file://" + + System.getProperty("user.home") + "/.m2/repository")); + } + + public void addRepo(String id, String url, boolean snapshot) { + synchronized (repos) { + delRepo(id); + RemoteRepository rr = new RemoteRepository(id, "default", url); + rr.setPolicy(snapshot, null); + repos.add(rr); + } + } + + public RemoteRepository delRepo(String id) { + synchronized (repos) { + Iterator<RemoteRepository> it = repos.iterator(); + // scan every registered repository, not only the first one + while (it.hasNext()) { + RemoteRepository repo = it.next(); + if (repo.getId().equals(id)) { + it.remove(); + return repo; + } + } + } + return null; + } + + private void updateCompilerClassPath(URL[] urls) throws IllegalAccessException, + IllegalArgumentException, InvocationTargetException { + + JavaPlatform
platform = (JavaPlatform) global.platform(); + MergedClassPath<AbstractFile> newClassPath = mergeUrlsIntoClassPath(platform, urls); + + Method[] methods = platform.getClass().getMethods(); + for (Method m : methods) { + if (m.getName().endsWith("currentClassPath_$eq")) { + m.invoke(platform, new Some(newClassPath)); + break; + } + } + + // NOTE: Must use reflection until this is exposed/fixed upstream in Scala + List<String> classPaths = new LinkedList<String>(); + for (URL url : urls) { + classPaths.add(url.getPath()); + } + + // Reload all jars specified into our compiler + global.invalidateClassPathEntries(scala.collection.JavaConversions.asScalaBuffer(classPaths) + .toList()); + } + + // Until spark 1.1.x + // check https://github.com/apache/spark/commit/191d7cf2a655d032f160b9fa181730364681d0e7 + private void updateRuntimeClassPath(URL[] urls) throws SecurityException, IllegalAccessException, + IllegalArgumentException, InvocationTargetException, NoSuchMethodException { + ClassLoader cl = intp.classLoader().getParent(); + Method addURL; + addURL = cl.getClass().getDeclaredMethod("addURL", new Class[] {URL.class}); + addURL.setAccessible(true); + for (URL url : urls) { + addURL.invoke(cl, url); + } + } + + private MergedClassPath<AbstractFile> mergeUrlsIntoClassPath(JavaPlatform platform, URL[] urls) { + IndexedSeq<ClassPath<AbstractFile>> entries = + ((MergedClassPath<AbstractFile>) platform.classPath()).entries(); + List<ClassPath<AbstractFile>> cp = new LinkedList<ClassPath<AbstractFile>>(); + + for (int i = 0; i < entries.size(); i++) { + cp.add(entries.apply(i)); + } + + for (URL url : urls) { + AbstractFile file; + if ("file".equals(url.getProtocol())) { + File f = new File(url.getPath()); + if (f.isDirectory()) { + file = AbstractFile.getDirectory(scala.reflect.io.File.jfile2path(f)); + } else { + file = AbstractFile.getFile(scala.reflect.io.File.jfile2path(f)); + } + } else { + file = AbstractFile.getURL(url); + } + + ClassPath<AbstractFile> newcp = platform.classPath().context().newClassPath(file); + + // distinct + if (cp.contains(newcp) == false) { + cp.add(newcp); + } + } + + return new MergedClassPath(scala.collection.JavaConversions.asScalaBuffer(cp).toIndexedSeq(), + platform.classPath().context()); + } + + public List<String> load(String artifact, + boolean addSparkContext) throws Exception { + return load(artifact, new LinkedList<String>(), addSparkContext); + } + + public List<String> load(String artifact, Collection<String> excludes, + boolean addSparkContext) throws Exception { + if (StringUtils.isBlank(artifact)) { + // Should throw here + throw new RuntimeException("Invalid artifact to load"); + } + + // <groupId>:<artifactId>[:<extension>[:<classifier>]]:<version> + int numSplits = artifact.split(":").length; + if (numSplits >= 3 && numSplits <= 6) { + return loadFromMvn(artifact, excludes, addSparkContext); + } else { + loadFromFs(artifact, addSparkContext); + LinkedList<String> libs = new LinkedList<String>(); + libs.add(artifact); + return libs; + } + } + + private void loadFromFs(String artifact, boolean addSparkContext) throws Exception { + File jarFile = new File(artifact); + + intp.global().new Run(); + + updateRuntimeClassPath(new URL[] {jarFile.toURI().toURL()}); + updateCompilerClassPath(new URL[] {jarFile.toURI().toURL()}); + + if (addSparkContext) { + sc.addJar(jarFile.getAbsolutePath()); + } + } + + private List<String> loadFromMvn(String artifact, Collection<String> excludes, + boolean addSparkContext) throws Exception { + List<String> loadedLibs = 
new LinkedList<String>(); + Collection<String> allExclusions = new LinkedList<String>(); + allExclusions.addAll(excludes); + allExclusions.addAll(Arrays.asList(exclusions)); + + List<ArtifactResult> listOfArtifact; + listOfArtifact = getArtifactsWithDep(artifact, allExclusions); + + Iterator<ArtifactResult> it = listOfArtifact.iterator(); + while (it.hasNext()) { + Artifact a = it.next().getArtifact(); + String gav = a.getGroupId() + ":" + a.getArtifactId() + ":" + a.getVersion(); + for (String exclude : allExclusions) { + if (gav.startsWith(exclude)) { + it.remove(); + break; + } + } + } + + List<URL> newClassPathList = new LinkedList<URL>(); + List<File> files = new LinkedList<File>(); + for (ArtifactResult artifactResult : listOfArtifact) { + logger.info("Load " + artifactResult.getArtifact().getGroupId() + ":" + + artifactResult.getArtifact().getArtifactId() + ":" + + artifactResult.getArtifact().getVersion()); + newClassPathList.add(artifactResult.getArtifact().getFile().toURI().toURL()); + files.add(artifactResult.getArtifact().getFile()); + loadedLibs.add(artifactResult.getArtifact().getGroupId() + ":" + + artifactResult.getArtifact().getArtifactId() + ":" + + artifactResult.getArtifact().getVersion()); + } + + intp.global().new Run(); + updateRuntimeClassPath(newClassPathList.toArray(new URL[0])); + updateCompilerClassPath(newClassPathList.toArray(new URL[0])); + + if (addSparkContext) { + for (File f : files) { + sc.addJar(f.getAbsolutePath()); + } + } + + return loadedLibs; + } + + /** + * + * @param dependency + * @param excludes list of pattern can either be of the form groupId:artifactId + * @return + * @throws Exception + */ + public List<ArtifactResult> getArtifactsWithDep(String dependency, + Collection<String> excludes) throws Exception { + Artifact artifact = new DefaultArtifact(inferScalaVersion(dependency)); + DependencyFilter classpathFlter = DependencyFilterUtils.classpathFilter( JavaScopes.COMPILE ); + PatternExclusionsDependencyFilter exclusionFilter = + new PatternExclusionsDependencyFilter(inferScalaVersion(excludes)); + + CollectRequest collectRequest = new CollectRequest(); + collectRequest.setRoot(new Dependency(artifact, JavaScopes.COMPILE)); + + synchronized (repos) { + for (RemoteRepository repo : repos) { + collectRequest.addRepository(repo); + } + } + DependencyRequest dependencyRequest = new DependencyRequest(collectRequest, + DependencyFilterUtils.andFilter(exclusionFilter, classpathFlter)); + return system.resolveDependencies(session, dependencyRequest).getArtifactResults(); + } + + public static Collection<String> inferScalaVersion(Collection<String> artifact) { + List<String> list = new LinkedList<String>(); + for (String a : artifact) { + list.add(inferScalaVersion(a)); + } + return list; + } + + public static String inferScalaVersion(String artifact) { + int pos = artifact.indexOf(":"); + if (pos < 0 || pos + 2 >= artifact.length()) { + // failed to infer + return artifact; + } + + if (':' == artifact.charAt(pos + 1)) { + String restOfthem = ""; + String versionSep = ":"; + + String groupId = artifact.substring(0, pos); + int nextPos = artifact.indexOf(":", pos + 2); + if (nextPos < 0) { + if (artifact.charAt(artifact.length() - 1) == '*') { + nextPos = artifact.length() - 1; + versionSep = ""; + restOfthem = "*"; + } else { + versionSep = ""; + nextPos = artifact.length(); + } + } + + String artifactId = artifact.substring(pos + 2, nextPos); + if (nextPos < artifact.length()) { + if (!restOfthem.equals("*")) { + restOfthem = 
artifact.substring(nextPos + 1); + } + } + + String [] version = scala.util.Properties.versionNumberString().split("[.]"); + String scalaVersion = version[0] + "." + version[1]; + + return groupId + ":" + artifactId + "_" + scalaVersion + versionSep + restOfthem; + } else { + return artifact; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/669d408d/spark/src/main/java/org/apache/zeppelin/spark/dep/Repository.java ---------------------------------------------------------------------- diff --git a/spark/src/main/java/org/apache/zeppelin/spark/dep/Repository.java b/spark/src/main/java/org/apache/zeppelin/spark/dep/Repository.java new file mode 100644 index 0000000..49c6c9b --- /dev/null +++ b/spark/src/main/java/org/apache/zeppelin/spark/dep/Repository.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark.dep; + +/** + * + * + */ +public class Repository { + private boolean snapshot = false; + private String name; + private String url; + + public Repository(String name){ + this.name = name; + } + + public Repository url(String url) { + this.url = url; + return this; + } + + public Repository snapshot() { + snapshot = true; + return this; + } + + public boolean isSnapshot() { + return snapshot; + } + + public String getName() { + return name; + } + + public String getUrl() { + return url; + } +}
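Taken together, DependencyContext, Dependency and Repository form a small builder-style API for declaring extra artifacts before the interpreter starts: load() registers a Maven coordinate or a local jar path, exclude()/excludeAll()/local() refine it, addRepo().url().snapshot() registers an additional remote repository, and fetch() resolves everything through Aether into the configured local repository. A minimal usage sketch follows; the coordinates, repository URL and "local-repo" path are made-up examples, not part of this commit:

  DependencyContext dep = new DependencyContext("local-repo");

  // optional extra repository; Maven Central and the local ~/.m2 repository are always consulted
  dep.addRepo("myrepo").url("https://repo.example.com/maven2/").snapshot();

  // a Maven artifact, resolved with its transitive dependencies minus one excluded group
  dep.load("org.apache.commons:commons-csv:1.1")
     .exclude("commons-logging:commons-logging");

  // a jar already on the local filesystem, kept off the cluster (no sc.addJar)
  dep.load("/tmp/extra-lib.jar").local();

  List<File> resolved = dep.fetch();            // resolve/download everything
  List<File> forCluster = dep.getFilesDist();   // subset that should also be shipped via sc.addJar

Note that fetch() declares DependencyResolutionException, ArtifactResolutionException and MalformedURLException, so callers are expected to handle resolution failures themselves.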
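DependencyResolver.inferScalaVersion() implements a convention similar to sbt's %% operator: writing the coordinate with a double colon between groupId and artifactId makes the resolver append the running Scala binary version to the artifactId, while an ordinary single-colon coordinate passes through untouched. A short illustration, assuming the interpreter runs on a Scala 2.10.x build (the coordinates are arbitrary examples):

  DependencyResolver.inferScalaVersion("org.apache.commons:commons-csv:1.1");
  // -> "org.apache.commons:commons-csv:1.1"       (unchanged)

  DependencyResolver.inferScalaVersion("com.twitter::algebird-core:0.9.0");
  // -> "com.twitter:algebird-core_2.10:0.9.0"     (Scala version appended)

The suffix is derived from scala.util.Properties.versionNumberString(), so it always follows whatever Scala version the Spark interpreter itself was built with.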
