Repository: zeppelin
Updated Branches:
  refs/heads/master 3651e5cdc -> 76e0027c2
ZEPPELIN-3652. Remove reflection in SparkInterpreter

### What is this PR for?
This is a follow-up to ZEPPELIN-3635. This PR removes some legacy code for old Spark versions, where we had to use reflection at the time.

### What type of PR is it?
[Refactoring]

### Todos
* [ ] - Task

### What is the Jira issue?
* https://issues.apache.org/jira/browse/ZEPPELIN-3652

### How should this be tested?
* CI pass

### Screenshots (if appropriate)

### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need documentation? No

Author: Jeff Zhang <zjf...@apache.org>

Closes #3095 from zjffdu/ZEPPELIN-3652 and squashes the following commits:

322f1db8f [Jeff Zhang] ZEPPELIN-3652. Remove reflection in SparkInterpreter

Project: http://git-wip-us.apache.org/repos/asf/zeppelin/repo
Commit: http://git-wip-us.apache.org/repos/asf/zeppelin/commit/76e0027c
Tree: http://git-wip-us.apache.org/repos/asf/zeppelin/tree/76e0027c
Diff: http://git-wip-us.apache.org/repos/asf/zeppelin/diff/76e0027c

Branch: refs/heads/master
Commit: 76e0027c2d43ac693c86a2123c2b589b9faaaeb3
Parents: 3651e5c
Author: Jeff Zhang <zjf...@apache.org>
Authored: Tue Jul 24 13:13:47 2018 +0800
Committer: Jeff Zhang <zjf...@apache.org>
Committed: Thu Jul 26 15:05:01 2018 +0800

----------------------------------------------------------------------
 .../zeppelin/spark/NewSparkInterpreter.java     |  10 +-
 .../zeppelin/spark/OldSparkInterpreter.java     |  13 +--
 .../zeppelin/spark/SparkSqlInterpreter.java     |  33 ++----
 .../zeppelin/spark/SparkZeppelinContext.java    | 117 +++----------------
 .../zeppelin/spark/NewSparkInterpreterTest.java |  30 ++---
 .../apache/zeppelin/spark/SparkShimsTest.java   |   5 +
 .../org/apache/zeppelin/spark/SparkShims.java   |   4 +
 spark/spark1-shims/pom.xml                      |   7 ++
 .../org/apache/zeppelin/spark/Spark1Shims.java  |  39 +++++++
 spark/spark2-shims/pom.xml                      |   7 ++
 .../org/apache/zeppelin/spark/Spark2Shims.java  |  41 +++++++
 11 files changed, 159 insertions(+), 147 deletions(-)
----------------------------------------------------------------------
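The theme of the change, visible across the diffs below: instead of reflecting on Spark classes at every call site, version-specific behavior now lives behind the abstract SparkShims API, with one concrete module per Spark major version. As orientation, a minimal hypothetical sketch of that dispatch pattern follows -- the class and its body are illustrative only, and the real SparkShims (diffed below) differs in details:

```java
// Simplified, hypothetical sketch of the shim dispatch this change adopts
// (not the actual Zeppelin source; see spark-shims/SparkShims.java below for
// the real abstract class). Version-specific behavior is compiled against the
// matching Spark API in its own module, so callers never reflect per call.
public abstract class ShimsSketch {

  // Implemented natively by the Spark 1.x and Spark 2.x modules.
  public abstract String showDataFrame(Object obj, int maxResult);

  public static ShimsSketch getInstance(String sparkVersion) throws Exception {
    // One dynamic lookup at startup replaces reflection at every call site.
    String className = sparkVersion.startsWith("1.")
        ? "org.apache.zeppelin.spark.Spark1Shims"
        : "org.apache.zeppelin.spark.Spark2Shims";
    return (ShimsSketch) Class.forName(className).newInstance();
  }
}
```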
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java
index 864cc30..9ee504a 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/NewSparkInterpreter.java
@@ -115,11 +115,6 @@ public class NewSparkInterpreter extends AbstractSparkInterpreter {
       sqlContext = this.innerInterpreter.sqlContext();
       sparkSession = this.innerInterpreter.sparkSession();
       hooks = getInterpreterGroup().getInterpreterHookRegistry();
-      z = new SparkZeppelinContext(sc, hooks,
-          Integer.parseInt(getProperty("zeppelin.spark.maxResult")));
-      this.innerInterpreter.bind("z", z.getClass().getCanonicalName(), z,
-          Lists.newArrayList("@transient"));
-
       sparkUrl = this.innerInterpreter.sparkUrl();
       String sparkUrlProp = getProperty("zeppelin.spark.uiWebUrl", "");
       if (!StringUtils.isBlank(sparkUrlProp)) {
@@ -127,6 +122,11 @@
       }
       sparkShims = SparkShims.getInstance(sc.version());
       sparkShims.setupSparkListener(sc.master(), sparkUrl, InterpreterContext.get());
+
+      z = new SparkZeppelinContext(sc, sparkShims, hooks,
+          Integer.parseInt(getProperty("zeppelin.spark.maxResult")));
+      this.innerInterpreter.bind("z", z.getClass().getCanonicalName(), z,
+          Lists.newArrayList("@transient"));
     } catch (Exception e) {
       LOGGER.error("Fail to open SparkInterpreter", ExceptionUtils.getStackTrace(e));
       throw new InterpreterException("Fail to open SparkInterpreter", e);

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/interpreter/src/main/java/org/apache/zeppelin/spark/OldSparkInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/OldSparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/OldSparkInterpreter.java
index deb29f4..0366e3b 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/OldSparkInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/OldSparkInterpreter.java
@@ -703,14 +703,15 @@ public class OldSparkInterpreter extends AbstractSparkInterpreter {
     }
 
     sparkVersion = SparkVersion.fromVersionString(sc.version());
-
     sqlc = getSQLContext();
-
     dep = getDependencyResolver();
-
     hooks = getInterpreterGroup().getInterpreterHookRegistry();
+    sparkUrl = getSparkUIUrl();
+    sparkShims = SparkShims.getInstance(sc.version());
+    sparkShims.setupSparkListener(sc.master(), sparkUrl, InterpreterContext.get());
+    numReferenceOfSparkContext.incrementAndGet();
 
-    z = new SparkZeppelinContext(sc, hooks,
+    z = new SparkZeppelinContext(sc, sparkShims, hooks,
         Integer.parseInt(getProperty("zeppelin.spark.maxResult")));
 
     interpret("@transient val _binder = new java.util.HashMap[String, Object]()");
@@ -817,10 +818,6 @@
       }
     }
 
-    sparkUrl = getSparkUIUrl();
-    sparkShims = SparkShims.getInstance(sc.version());
-    sparkShims.setupSparkListener(sc.master(), sparkUrl, InterpreterContext.get());
-    numReferenceOfSparkContext.incrementAndGet();
   }
 
   public String getSparkUIUrl() {
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
index 0fa9185..04eb844 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
@@ -30,7 +30,6 @@
 import org.apache.zeppelin.scheduler.SchedulerFactory;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.List;
 import java.util.Properties;
@@ -72,32 +71,24 @@
     SparkContext sc = sqlc.sparkContext();
     sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool"));
     sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false);
-    Object rdd = null;
+
     try {
-      // method signature of sqlc.sql() is changed
-      // from def sql(sqlText: String): SchemaRDD (1.2 and prior)
-      // to   def sql(sqlText: String): DataFrame (1.3 and later).
-      // Therefore need to use reflection to keep binary compatibility for all spark versions.
-      Method sqlMethod = sqlc.getClass().getMethod("sql", String.class);
-      String effectiveString = Boolean.parseBoolean(getProperty("zeppelin.spark.sql.interpolation")) ?
-          interpolate(st, context.getResourcePool()) : st;
-      rdd = sqlMethod.invoke(sqlc, effectiveString);
-    } catch (InvocationTargetException ite) {
+      String effectiveSQL = Boolean.parseBoolean(getProperty("zeppelin.spark.sql.interpolation")) ?
+          interpolate(st, context.getResourcePool()) : st;
+      Method method = sqlc.getClass().getMethod("sql", String.class);
+      String msg = sparkInterpreter.getZeppelinContext().showData(
+          method.invoke(sqlc, effectiveSQL));
+      sc.clearJobGroup();
+      return new InterpreterResult(Code.SUCCESS, msg);
+    } catch (Exception e) {
       if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) {
-        throw new InterpreterException(ite);
+        throw new InterpreterException(e);
       }
-      logger.error("Invocation target exception", ite);
-      String msg = ite.getTargetException().getMessage()
+      logger.error("Invocation target exception", e);
+      String msg = e.getMessage()
           + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
       return new InterpreterResult(Code.ERROR, msg);
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException e) {
-      throw new InterpreterException(e);
     }
-
-    String msg = sparkInterpreter.getZeppelinContext().showData(rdd);
-    sc.clearJobGroup();
-    return new InterpreterResult(Code.SUCCESS, msg);
   }
 
   @Override
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java
index 3161ab6..87d5b16 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java
@@ -19,24 +19,28 @@
 package org.apache.zeppelin.spark;
 
 import com.google.common.collect.Lists;
 import org.apache.spark.SparkContext;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.catalyst.expressions.Attribute;
 import org.apache.zeppelin.annotation.ZeppelinApi;
 import org.apache.zeppelin.display.AngularObjectWatcher;
-import org.apache.zeppelin.display.Input;
 import org.apache.zeppelin.display.ui.OptionInput;
-import org.apache.zeppelin.interpreter.*;
+import org.apache.zeppelin.interpreter.BaseZeppelinContext;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterHookRegistry;
 import scala.Tuple2;
 import scala.Unit;
 
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
 
-import static scala.collection.JavaConversions.asJavaCollection;
 import static scala.collection.JavaConversions.asJavaIterable;
 import static scala.collection.JavaConversions.collectionAsScalaIterable;
 
+
 /**
  * ZeppelinContext for Spark
 */
@@ -45,14 +49,16 @@ public class SparkZeppelinContext extends BaseZeppelinContext {
 
   private SparkContext sc;
   private List<Class> supportedClasses;
   private Map<String, String> interpreterClassMap;
+  private SparkShims sparkShims;
 
   public SparkZeppelinContext(
       SparkContext sc,
+      SparkShims sparkShims,
       InterpreterHookRegistry hooks,
       int maxResult) {
     super(hooks, maxResult);
     this.sc = sc;
-
+    this.sparkShims = sparkShims;
     interpreterClassMap = new HashMap();
     interpreterClassMap.put("spark", "org.apache.zeppelin.spark.SparkInterpreter");
     interpreterClassMap.put("sql", "org.apache.zeppelin.spark.SparkSqlInterpreter");
@@ -70,13 +76,8 @@ public class SparkZeppelinContext extends BaseZeppelinContext {
     } catch (ClassNotFoundException e) {
     }
 
-    try {
-      supportedClasses.add(this.getClass().forName("org.apache.spark.sql.SchemaRDD"));
-    } catch (ClassNotFoundException e) {
-    }
-
     if (supportedClasses.isEmpty()) {
-      throw new RuntimeException("Can not load Dataset/DataFrame/SchemaRDD class");
+      throw new RuntimeException("Can not load Dataset/DataFrame class");
     }
   }
@@ -91,88 +92,8 @@ public class SparkZeppelinContext extends BaseZeppelinContext {
   }
 
   @Override
-  public String showData(Object df) {
-    Object[] rows = null;
-    Method take;
-    String jobGroup = Utils.buildJobGroupId(interpreterContext);
-    sc.setJobGroup(jobGroup, "Zeppelin", false);
-
-    try {
-      // convert it to DataFrame if it is Dataset, as we will iterate all the records
-      // and assume it is type Row.
-      if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
-        Method convertToDFMethod = df.getClass().getMethod("toDF");
-        df = convertToDFMethod.invoke(df);
-      }
-      take = df.getClass().getMethod("take", int.class);
-      rows = (Object[]) take.invoke(df, maxResult + 1);
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException | ClassCastException e) {
-      sc.clearJobGroup();
-      throw new RuntimeException(e);
-    }
-
-    List<Attribute> columns = null;
-    // get field names
-    try {
-      // Use reflection because of classname returned by queryExecution changes from
-      // Spark <1.5.2 org.apache.spark.sql.SQLContext$QueryExecution
-      // Spark 1.6.0> org.apache.spark.sql.hive.HiveContext$QueryExecution
-      Object qe = df.getClass().getMethod("queryExecution").invoke(df);
-      Object a = qe.getClass().getMethod("analyzed").invoke(qe);
-      scala.collection.Seq seq = (scala.collection.Seq) a.getClass().getMethod("output").invoke(a);
-
-      columns = (List<Attribute>) scala.collection.JavaConverters.seqAsJavaListConverter(seq)
-          .asJava();
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException e) {
-      throw new RuntimeException(e);
-    }
-
-    StringBuilder msg = new StringBuilder();
-    msg.append("%table ");
-    for (Attribute col : columns) {
-      msg.append(col.name() + "\t");
-    }
-    String trim = msg.toString().trim();
-    msg = new StringBuilder(trim);
-    msg.append("\n");
-
-    // ArrayType, BinaryType, BooleanType, ByteType, DecimalType, DoubleType, DynamicType,
-    // FloatType, FractionalType, IntegerType, IntegralType, LongType, MapType, NativeType,
-    // NullType, NumericType, ShortType, StringType, StructType
-
-    try {
-      for (int r = 0; r < maxResult && r < rows.length; r++) {
-        Object row = rows[r];
-        Method isNullAt = row.getClass().getMethod("isNullAt", int.class);
-        Method apply = row.getClass().getMethod("apply", int.class);
-
-        for (int i = 0; i < columns.size(); i++) {
-          if (!(Boolean) isNullAt.invoke(row, i)) {
-            msg.append(apply.invoke(row, i).toString());
-          } else {
-            msg.append("null");
-          }
-          if (i != columns.size() - 1) {
-            msg.append("\t");
-          }
-        }
-        msg.append("\n");
-      }
-    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException e) {
-      throw new RuntimeException(e);
-    }
-
-    if (rows.length > maxResult) {
-      msg.append("\n");
-      msg.append(ResultMessages.getExceedsLimitRowsMessage(maxResult, "zeppelin.spark.maxResult"));
-    }
-    // append %text at the end, otherwise the following output will be put in table as well.
-    msg.append("\n%text ");
-    sc.clearJobGroup();
-    return msg.toString();
+  public String showData(Object obj) {
+    return sparkShims.showDataFrame(obj, maxResult);
   }
 
   @ZeppelinApi
@@ -215,7 +136,7 @@
   public Object noteSelect(String name, Object defaultValue,
-                           scala.collection.Iterable<Tuple2<Object, String>> options) {
+      scala.collection.Iterable<Tuple2<Object, String>> options) {
     return noteSelect(name, defaultValue, tuplesToParamOptions(options));
   }
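For comparison with the removed showData above: the display contract it implemented -- a leading %table block of tab-separated cells, closed by an empty %text block so later output is not parsed as table rows -- is preserved by the shim implementations further down. A small self-contained illustration of the resulting string shape, using the row values from the updated test below (the exact cell spacing is the shims' concern, not guaranteed here):

```java
public class TableRenderSketch {
  public static void main(String[] args) {
    // Hypothetical rendering of a two-column DataFrame with rows (1, "a")
    // and (2, null), mirroring the shape built by showDataFrame: a "%table"
    // header, tab-separated cells, newline-separated rows, and a trailing
    // "%text " so subsequent output leaves table mode.
    String rendered = "%table _1\t_2\n1\ta\n2\tnull\n\n%text ";
    System.out.print(rendered);
  }
}
```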
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java
index 2b17ecd..48be45b 100644
--- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java
@@ -184,31 +184,31 @@ public class NewSparkInterpreterTest {
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
 
       result = interpreter.interpret(
-          "val df = sqlContext.createDataFrame(Seq((1,\"a\"),(2,\"b\")))\n" +
+          "val df = sqlContext.createDataFrame(Seq((1,\"a\"),(2, null)))\n" +
              "df.show()", getInterpreterContext());
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
       assertTrue(output.contains(
-          "+---+---+\n" +
-          "| _1| _2|\n" +
-          "+---+---+\n" +
-          "|  1|  a|\n" +
-          "|  2|  b|\n" +
-          "+---+---+"));
+          "+---+----+\n" +
+          "| _1|  _2|\n" +
+          "+---+----+\n" +
+          "|  1|   a|\n" +
+          "|  2|null|\n" +
+          "+---+----+"));
     } else if (version.contains("String = 2.")) {
       result = interpreter.interpret("spark", getInterpreterContext());
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
 
       result = interpreter.interpret(
-          "val df = spark.createDataFrame(Seq((1,\"a\"),(2,\"b\")))\n" +
+          "val df = spark.createDataFrame(Seq((1,\"a\"),(2, null)))\n" +
              "df.show()", getInterpreterContext());
       assertEquals(InterpreterResult.Code.SUCCESS, result.code());
       assertTrue(output.contains(
-          "+---+---+\n" +
-          "| _1| _2|\n" +
-          "+---+---+\n" +
-          "|  1|  a|\n" +
-          "|  2|  b|\n" +
-          "+---+---+"));
+          "+---+----+\n" +
+          "| _1|  _2|\n" +
+          "+---+----+\n" +
+          "|  1|   a|\n" +
+          "|  2|null|\n" +
+          "+---+----+"));
     }
 
     // ZeppelinContext
     assertEquals(InterpreterResult.Code.SUCCESS, result.code());
     assertEquals(InterpreterResult.Type.TABLE, messageOutput.getType());
     messageOutput.flush();
-    assertEquals("_1\t_2\n1\ta\n2\tb\n", messageOutput.toInterpreterResultMessage().getData());
+    assertEquals("_1\t_2\n1\ta\n2\tnull\n", messageOutput.toInterpreterResultMessage().getData());
 
     context = getInterpreterContext();
     result = interpreter.interpret("z.input(\"name\", \"default_name\")", context);
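The change of the test row from (2, "b") to (2, null) is deliberate: the removed showData checked isNullAt explicitly, while the new shim code (below) passes row.get(i) straight to StringBuilder.append(Object), which renders a null reference as the literal text "null". A minimal standalone check of that Java behavior:

```java
public class NullAppendCheck {
  public static void main(String[] args) {
    StringBuilder msg = new StringBuilder();
    Object cell = null; // stands in for row.get(i) on a null-valued column
    // StringBuilder.append(Object) uses String.valueOf, so null prints as "null".
    msg.append(cell);
    System.out.println(msg); // prints: null
  }
}
```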
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkShimsTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkShimsTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkShimsTest.java
index ad0c534..ccebac3 100644
--- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkShimsTest.java
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkShimsTest.java
@@ -94,6 +94,11 @@
           public void setupSparkListener(String master,
                                          String sparkWebUrl,
                                          InterpreterContext context) {}
+
+          @Override
+          public String showDataFrame(Object obj, int maxResult) {
+            return null;
+          }
         };
     assertEquals(expected, sparkShims.supportYarn6615(version));
   }

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/spark-shims/src/main/scala/org/apache/zeppelin/spark/SparkShims.java
----------------------------------------------------------------------
diff --git a/spark/spark-shims/src/main/scala/org/apache/zeppelin/spark/SparkShims.java b/spark/spark-shims/src/main/scala/org/apache/zeppelin/spark/SparkShims.java
index 6d45b06..d308762 100644
--- a/spark/spark-shims/src/main/scala/org/apache/zeppelin/spark/SparkShims.java
+++ b/spark/spark-shims/src/main/scala/org/apache/zeppelin/spark/SparkShims.java
@@ -20,6 +20,7 @@ package org.apache.zeppelin.spark;
 import org.apache.hadoop.util.VersionInfo;
 import org.apache.hadoop.util.VersionUtil;
 import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.ResultMessages;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -85,6 +86,9 @@
                                            String sparkWebUrl,
                                            InterpreterContext context);
 
+  public abstract String showDataFrame(Object obj, int maxResult);
+
+
   protected String getNoteId(String jobgroupId) {
     int indexOf = jobgroupId.indexOf("-");
     int secondIndex = jobgroupId.indexOf("-", indexOf + 1);
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/spark1-shims/pom.xml
----------------------------------------------------------------------
diff --git a/spark/spark1-shims/pom.xml b/spark/spark1-shims/pom.xml
index 93640c6..559b8d8 100644
--- a/spark/spark1-shims/pom.xml
+++ b/spark/spark1-shims/pom.xml
@@ -55,6 +55,13 @@
     </dependency>
 
     <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+
+    <dependency>
       <groupId>org.apache.zeppelin</groupId>
       <artifactId>zeppelin-interpreter</artifactId>
       <version>${project.version}</version>

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
----------------------------------------------------------------------
diff --git a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
index 7c922aa..db0727c 100644
--- a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
+++ b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
@@ -18,10 +18,16 @@
 
 package org.apache.zeppelin.spark;
 
+import org.apache.commons.lang.StringUtils;
 import org.apache.spark.SparkContext;
 import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
 import org.apache.spark.ui.jobs.JobProgressListener;
 import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.ResultMessages;
+
+import java.util.List;
 
 public class Spark1Shims extends SparkShims {
 
@@ -36,4 +42,37 @@
       }
     });
   }
+
+  @Override
+  public String showDataFrame(Object obj, int maxResult) {
+    if (obj instanceof DataFrame) {
+      DataFrame df = (DataFrame) obj;
+      String[] columns = df.columns();
+      List<Row> rows = df.takeAsList(maxResult + 1);
+
+      StringBuilder msg = new StringBuilder();
+      msg.append("%table ");
+      msg.append(StringUtils.join(columns, "\t"));
+      msg.append("\n");
+      for (Row row : rows) {
+        for (int i = 0; i < row.size(); ++i) {
+          msg.append(row.get(i));
+          if (i != row.size() - 1) {
+            msg.append("\t");
+          }
+        }
+        msg.append("\n");
+      }
+
+      if (rows.size() > maxResult) {
+        msg.append("\n");
+        msg.append(ResultMessages.getExceedsLimitRowsMessage(maxResult, "zeppelin.spark.maxResult"));
+      }
+      // append %text at the end, otherwise the following output will be put in table as well.
+      msg.append("\n%text ");
+      return msg.toString();
+    } else {
+      return obj.toString();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/spark2-shims/pom.xml
----------------------------------------------------------------------
diff --git a/spark/spark2-shims/pom.xml b/spark/spark2-shims/pom.xml
index 000e3ab..31249a8 100644
--- a/spark/spark2-shims/pom.xml
+++ b/spark/spark2-shims/pom.xml
@@ -54,6 +54,13 @@
     </dependency>
 
     <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <scope>provided</scope>
+    </dependency>
+
+    <dependency>
       <groupId>org.apache.zeppelin</groupId>
       <artifactId>zeppelin-interpreter</artifactId>
       <version>${project.version}</version>
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/76e0027c/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
----------------------------------------------------------------------
diff --git a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
index 63bd688..177b0ac 100644
--- a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
+++ b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
@@ -18,10 +18,17 @@
 
 package org.apache.zeppelin.spark;
 
+import org.apache.commons.lang.StringUtils;
 import org.apache.spark.SparkContext;
 import org.apache.spark.scheduler.SparkListener;
 import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.ResultMessages;
+
+import java.util.List;
 
 public class Spark2Shims extends SparkShims {
 
@@ -36,4 +43,38 @@
       }
     });
   }
+
+  @Override
+  public String showDataFrame(Object obj, int maxResult) {
+    if (obj instanceof Dataset) {
+      Dataset<Row> df = ((Dataset) obj).toDF();
+      String[] columns = df.columns();
+      List<Row> rows = df.takeAsList(maxResult + 1);
+
+      StringBuilder msg = new StringBuilder();
+      msg.append("%table ");
+      msg.append(StringUtils.join(columns, "\t"));
+      msg.append("\n");
+      for (Row row : rows) {
+        for (int i = 0; i < row.size(); ++i) {
+          msg.append(row.get(i));
+          if (i != row.size() - 1) {
+            msg.append("\t");
+          }
+        }
+        msg.append("\n");
+      }
+
+      if (rows.size() > maxResult) {
+        msg.append("\n");
+        msg.append(ResultMessages.getExceedsLimitRowsMessage(maxResult, "zeppelin.spark.maxResult"));
+      }
+      // append %text at the end, otherwise the following output will be put in table as well.
+      msg.append("\n%text ");
+      return msg.toString();
+    } else {
+      return obj.toString();
+    }
+  }
+
 }
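One note on the takeAsList(maxResult + 1) call in both shim implementations above: fetching a single row beyond the display limit is a cheap truncation probe -- it avoids a full count() just to decide whether to append the exceeds-limit message. A standalone sketch of the check, with a list literal standing in for the fetched rows:

```java
import java.util.Arrays;
import java.util.List;

public class TruncationProbe {
  public static void main(String[] args) {
    int maxResult = 2;
    // Stands in for df.takeAsList(maxResult + 1): at most maxResult + 1 rows come back.
    List<String> fetched = Arrays.asList("row1", "row2", "row3");
    // If the extra row arrived, more data exists beyond the display limit.
    boolean truncated = fetched.size() > maxResult;
    System.out.println("truncated = " + truncated); // prints: truncated = true
  }
}
```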