Repository: incubator-zeppelin
Updated Branches:
  refs/heads/master f04425bc9 -> 2d0fc518e


[ZEPPELIN-185] ZeppelinContext methods like z.show are not working with 
DataFrame in pyspark

(opening a new PR to have a start history)

z.show() doesn’t seem to work properly in Python – I see the same error 
below: “AttributeError: 'DataFrame' object has no attribute '_get_object_id'”
#Python/PySpark – doesn’t work
rdd = sc.parallelize(["1","2","3"])
Data = Row('first')
df = sqlContext.createDataFrame(rdd.map(lambda d: Data(d)))
print df
print df.collect()
z.show(df)
AttributeError: 'DataFrame' object has no attribute '_get_object_id'

More generally, ZeppelinContext methods do not work with Python objects, 
since Py4J would need to know how to serialize them

It turns out the error is caused by Py4J trying to auto convert the DataFrame, 
which fails since it can only do that for simple types.
Instead of getting conversion to work, the better approach is to pass along the 
inner java object instead. To do that we intercept the call on the python side 
with a wrapper object instead of letting Py4J handle it.
As per comment, adding container/dictionary methods to allow for string passing 
using ZeppelinContext

Author: Felix Cheung <[email protected]>

Closes #178 from felixcheung/zpyspark and squashes the following commits:

ddc5bb2 [Felix Cheung] small fixes to python script
df6588a [Felix Cheung] [ZEPPELIN-185] ZeppelinContext methods like z.show are 
not working with DataFrame in pyspark It turns out the error is caused by Py4J 
trying to auto convert the DataFrame, which fails since it can only do that for 
simple types. Instead of getting conversion to work, the better approach is to 
pass along the inner java object instead. To do that we intercept the call on 
the python side with a wrapper object instead of letting Py4J handle it. As per 
comment, adding container/dictionary methods to allow for string passing using 
ZeppelinContext


Project: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/commit/2d0fc518
Tree: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/tree/2d0fc518
Diff: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/diff/2d0fc518

Branch: refs/heads/master
Commit: 2d0fc518e773225c3c210baf4a92ea84e78df019
Parents: f04425b
Author: Felix Cheung <[email protected]>
Authored: Mon Aug 3 01:20:45 2015 -0700
Committer: Lee moon soo <[email protected]>
Committed: Fri Aug 7 15:43:05 2015 -0700

----------------------------------------------------------------------
 .../zeppelin/spark/SparkSqlInterpreter.java     |  2 +-
 .../apache/zeppelin/spark/ZeppelinContext.java  | 62 ++++++++--------
 .../main/resources/python/zeppelin_pyspark.py   | 76 ++++++++++++++------
 3 files changed, 90 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/2d0fc518/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
----------------------------------------------------------------------
diff --git 
a/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java 
b/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
index e60ff2b..d3bda44 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
@@ -131,7 +131,7 @@ public class SparkSqlInterpreter extends Interpreter {
 
 
     Object rdd = sqlc.sql(st);
-    String msg = ZeppelinContext.showRDD(sc, context, rdd, maxResult);
+    String msg = ZeppelinContext.showDF(sc, context, rdd, maxResult);
     return new InterpreterResult(Code.SUCCESS, msg);
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/2d0fc518/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
----------------------------------------------------------------------
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java 
b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
index 6cb94d9..0cb2f16 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/ZeppelinContext.java
@@ -277,26 +277,30 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
     }
 
     if (cls.isInstance(o)) {
-      out.print(showRDD(sc, interpreterContext, o, maxResult));
+      out.print(showDF(sc, interpreterContext, o, maxResult));
     } else {
       out.print(o.toString());
     }
   }
 
-  public static String showRDD(SparkContext sc,
+  public static String showDF(ZeppelinContext z, Object df) {
+    return showDF(z.sc, z.interpreterContext, df, z.maxResult);
+  }
+
+  public static String showDF(SparkContext sc,
       InterpreterContext interpreterContext,
-      Object rdd, int maxResult) {
+      Object df, int maxResult) {
     Object[] rows = null;
     Method take;
     String jobGroup = "zeppelin-" + interpreterContext.getParagraphId();
     sc.setJobGroup(jobGroup, "Zeppelin", false);
 
     try {
-      take = rdd.getClass().getMethod("take", int.class);
-      rows = (Object[]) take.invoke(rdd, maxResult + 1);
+      take = df.getClass().getMethod("take", int.class);
+      rows = (Object[]) take.invoke(df, maxResult + 1);
 
     } catch (NoSuchMethodException | SecurityException | IllegalAccessException
-        | IllegalArgumentException | InvocationTargetException e) {
+        | IllegalArgumentException | InvocationTargetException | 
ClassCastException e) {
       sc.clearJobGroup();
       throw new InterpreterException(e);
     }
@@ -307,8 +311,8 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
     Method queryExecution;
     QueryExecution qe;
     try {
-      queryExecution = rdd.getClass().getMethod("queryExecution");
-      qe = (QueryExecution) queryExecution.invoke(rdd);
+      queryExecution = df.getClass().getMethod("queryExecution");
+      qe = (QueryExecution) queryExecution.invoke(df);
     } catch (NoSuchMethodException | SecurityException | IllegalAccessException
         | IllegalArgumentException | InvocationTargetException e) {
       throw new InterpreterException(e);
@@ -492,7 +496,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
       return ao.get();
     }
   }
-  
+
   /**
    * Get angular object. Look up global registry
    * @param name variable name
@@ -506,8 +510,8 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
     } else {
       return ao.get();
     }
-  }  
-  
+  }
+
   /**
    * Create angular variable in local registry and bind with front end Angular 
display system.
    * If variable exists, it'll be overwritten.
@@ -517,7 +521,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularBind(String name, Object o) {
     angularBind(name, o, interpreterContext.getNoteId());
   }
-  
+
   /**
    * Create angular variable in global registry and bind with front end 
Angular display system.
    * If variable exists, it'll be overwritten.
@@ -527,7 +531,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularBindGlobal(String name, Object o) {
     angularBind(name, o, (String) null);
   }
- 
+
   /**
    * Create angular variable in local registry and bind with front end Angular 
display system.
    * If variable exists, value will be overwritten and watcher will be added.
@@ -538,7 +542,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularBind(String name, Object o, AngularObjectWatcher watcher) 
{
     angularBind(name, o, interpreterContext.getNoteId(), watcher);
   }
-  
+
   /**
    * Create angular variable in global registry and bind with front end 
Angular display system.
    * If variable exists, value will be overwritten and watcher will be added.
@@ -558,9 +562,9 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularWatch(String name, AngularObjectWatcher watcher) {
     angularWatch(name, interpreterContext.getNoteId(), watcher);
   }
-  
+
   /**
-   * Add watcher into angular variable (global registry) 
+   * Add watcher into angular variable (global registry)
    * @param name name of the variable
    * @param watcher watcher
    */
@@ -573,7 +577,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
       final scala.Function2<Object, Object, Unit> func) {
     angularWatch(name, interpreterContext.getNoteId(), func);
   }
-  
+
   public void angularWatchGlobal(String name,
       final scala.Function2<Object, Object, Unit> func) {
     angularWatch(name, null, func);
@@ -584,13 +588,13 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
       final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
     angularWatch(name, interpreterContext.getNoteId(), func);
   }
-  
+
   public void angularWatchGlobal(
       String name,
       final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
     angularWatch(name, null, func);
-  } 
-  
+  }
+
   /**
    * Remove watcher from angular variable (local)
    * @param name
@@ -599,7 +603,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularUnwatch(String name, AngularObjectWatcher watcher) {
     angularUnwatch(name, interpreterContext.getNoteId(), watcher);
   }
-  
+
   /**
    * Remove watcher from angular variable (global)
    * @param name
@@ -617,7 +621,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularUnwatch(String name) {
     angularUnwatch(name, interpreterContext.getNoteId());
   }
-  
+
   /**
    * Remove all watchers for the angular variable (global)
    * @param name
@@ -642,7 +646,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   public void angularUnbindGlobal(String name) {
     angularUnbind(name, null);
   }
-  
+
   /**
    * Create angular variable in local registry and bind with front end Angular 
display system.
    * If variable exists, it'll be overwritten.
@@ -651,14 +655,14 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
    */
   private void angularBind(String name, Object o, String noteId) {
     AngularObjectRegistry registry = 
interpreterContext.getAngularObjectRegistry();
-        
+
     if (registry.get(name, noteId) == null) {
       registry.add(name, o, noteId);
     } else {
       registry.get(name, noteId).set(o);
     }
   }
- 
+
   /**
    * Create angular variable in local registry and bind with front end Angular 
display system.
    * If variable exists, value will be overwritten and watcher will be added.
@@ -668,7 +672,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
    */
   private void angularBind(String name, Object o, String noteId, 
AngularObjectWatcher watcher) {
     AngularObjectRegistry registry = 
interpreterContext.getAngularObjectRegistry();
-    
+
     if (registry.get(name, noteId) == null) {
       registry.add(name, o, noteId);
     } else {
@@ -678,13 +682,13 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
   }
 
   /**
-   * Add watcher into angular binding variable 
+   * Add watcher into angular binding variable
    * @param name name of the variable
    * @param watcher watcher
    */
   private void angularWatch(String name, String noteId, AngularObjectWatcher 
watcher) {
     AngularObjectRegistry registry = 
interpreterContext.getAngularObjectRegistry();
-    
+
     if (registry.get(name, noteId) != null) {
       registry.get(name, noteId).addWatcher(watcher);
     }
@@ -715,7 +719,7 @@ public class ZeppelinContext extends HashMap<String, 
Object> {
       }
     };
     angularWatch(name, noteId, w);
-  }  
+  }
 
   /**
    * Remove watcher

http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/2d0fc518/spark/src/main/resources/python/zeppelin_pyspark.py
----------------------------------------------------------------------
diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py 
b/spark/src/main/resources/python/zeppelin_pyspark.py
index 802015d..794fbc7 100644
--- a/spark/src/main/resources/python/zeppelin_pyspark.py
+++ b/spark/src/main/resources/python/zeppelin_pyspark.py
@@ -31,6 +31,58 @@ from pyspark.serializers import MarshalSerializer, 
PickleSerializer
 # for back compatibility
 from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
 
+class Logger(object):
+  def __init__(self):
+    self.out = ""
+
+  def write(self, message):
+    self.out = self.out + message
+
+  def get(self):
+    return self.out
+
+  def reset(self):
+    self.out = ""
+
+
+class PyZeppelinContext(dict):
+  def __init__(self, zc):
+    self.z = zc
+
+  def show(self, obj):
+    from pyspark.sql import DataFrame
+    if isinstance(obj, DataFrame):
+      print 
gateway.jvm.org.apache.zeppelin.spark.ZeppelinContext.showDF(self.z, obj._jdf)
+    else:
+      print str(obj)
+
+  # By implementing special methods it makes operating on it more Pythonic
+  def __setitem__(self, key, item):
+    self.z.put(key, item)
+
+  def __getitem__(self, key):
+    return self.z.get(key)
+
+  def __delitem__(self, key):
+    self.z.remove(key)
+
+  def __contains__(self, item):
+    return self.z.containsKey(item)
+
+  def add(self, key, value):
+    self.__setitem__(key, value)
+
+  def put(self, key, value):
+    self.__setitem__(key, value)
+
+  def get(self, key):
+    return self.__getitem__(key)
+
+
+output = Logger()
+sys.stdout = output
+sys.stderr = output
+
 client = GatewayClient(port=int(sys.argv[1]))
 sparkVersion = sys.argv[2]
 
@@ -71,24 +123,7 @@ sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
 sqlc = SQLContext(sc, intp.getSQLContext())
 sqlContext = sqlc
 
-z = intp.getZeppelinContext()
-
-class Logger(object):
-  def __init__(self):
-    self.out = ""
-
-  def write(self, message):
-    self.out = self.out + message
-
-  def get(self):
-    return self.out
-
-  def reset(self):
-    self.out = ""
-
-output = Logger()
-sys.stdout = output
-sys.stderr = output
+z = PyZeppelinContext(intp.getZeppelinContext())
 
 while True :
   req = intp.getStatements()
@@ -98,11 +133,12 @@ while True :
     final_code = None
 
     for s in stmts:
-      if s == None or len(s.strip()) == 0:
+      if s == None:
         continue
 
       # skip comment
-      if s.strip().startswith("#"):
+      s_stripped = s.strip()
+      if len(s_stripped) == 0 or s_stripped.startswith("#"):
         continue
 
       if final_code:

Reply via email to