Repository: spark
Updated Branches:
  refs/heads/branch-1.1 6cbf83c05 -> 9c06c7230


[SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd

Currently, SchemaRDD._jschema_rdd is a Scala SchemaRDD, so the Scala API 
(coalesce(), repartition()) cannot easily be called from Python: there is no 
way to supply the implicit parameter `ord` through Py4J. The _jrdd is a 
JavaRDD, so _jschema_rdd should also be a JavaSchemaRDD.
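
For illustration, a minimal sketch of the failure mode, assuming a 
SparkContext `sc` and a SQLContext `sqlCtx` (the repro itself is not part of 
this commit):

    # Hypothetical repro: Scala compiles
    # coalesce(numPartitions, shuffle)(implicit ord) down to a three-argument
    # method, so Py4J cannot match this two-argument call on a Scala SchemaRDD
    # and raises a Py4JError -- there is no way to pass `ord` from Python.
    srdd = sqlCtx.jsonRDD(sc.parallelize(['{"a": 1}']))
    srdd._jschema_rdd.coalesce(2, False)  # fails when _jschema_rdd is Scala

    # JavaSchemaRDD.coalesce(numPartitions, shuffle) has a plain Java-friendly
    # signature, so the same call succeeds once _jschema_rdd is a
    # JavaSchemaRDD.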

In this patch, _jschema_rdd is changed to a JavaSchemaRDD, and an assert is 
added for it. If a method is missing from JavaSchemaRDD, it is called via 
_jschema_rdd.baseSchemaRDD().xxx().
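
A minimal sketch of that delegation, using javaToPython() (defined on the 
Scala SchemaRDD) as the example; it mirrors the _jrdd property change in the 
diff below:

    # javaToPython() is not exposed on JavaSchemaRDD, so reach through to the
    # underlying Scala SchemaRDD first:
    scala_srdd = srdd._jschema_rdd.baseSchemaRDD()  # Scala SchemaRDD
    jrdd = scala_srdd.javaToPython()                # JavaRDD of pickled rows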

BTW, do we need JavaSQLContext?

Author: Davies Liu <davies....@gmail.com>

Closes #2369 from davies/fix_schemardd and squashes the following commits:

abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd

(cherry picked from commit 885d1621bc06bc1f009c9707c3452eac26baf828)
Signed-off-by: Josh Rosen <joshro...@apache.org>

Conflicts:
        python/pyspark/tests.py


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9c06c723
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9c06c723
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9c06c723

Branch: refs/heads/branch-1.1
Commit: 9c06c723018d4ef96ff31eb947226a6273ed8080
Parents: 6cbf83c
Author: Davies Liu <davies....@gmail.com>
Authored: Fri Sep 12 19:05:39 2014 -0700
Committer: Josh Rosen <joshro...@apache.org>
Committed: Fri Sep 12 19:28:45 2014 -0700

----------------------------------------------------------------------
 python/pyspark/sql.py   | 38 ++++++++++++++++++--------------------
 python/pyspark/tests.py | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9c06c723/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0ff6a54..07b39c9 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1121,7 +1121,7 @@ class SQLContext:
         batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
         jrdd = self._pythonToJava(rdd._jrdd, batched)
         srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
-        return SchemaRDD(srdd, self)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
     def registerRDDAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.
@@ -1133,8 +1133,8 @@ class SQLContext:
         >>> sqlCtx.registerRDDAsTable(srdd, "table1")
         """
         if (rdd.__class__ is SchemaRDD):
-            jschema_rdd = rdd._jschema_rdd
-            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
+            srdd = rdd._jschema_rdd.baseSchemaRDD()
+            self._ssql_ctx.registerRDDAsTable(srdd, tableName)
         else:
             raise ValueError("Can only register SchemaRDD as table")
 
@@ -1150,7 +1150,7 @@ class SQLContext:
         >>> sorted(srdd.collect()) == sorted(srdd2.collect())
         True
         """
-        jschema_rdd = self._ssql_ctx.parquetFile(path)
+        jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
         return SchemaRDD(jschema_rdd, self)
 
     def jsonFile(self, path, schema=None):
@@ -1206,11 +1206,11 @@ class SQLContext:
         [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
         """
         if schema is None:
-            jschema_rdd = self._ssql_ctx.jsonFile(path)
+            srdd = self._ssql_ctx.jsonFile(path)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(str(schema))
-            jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype)
-        return SchemaRDD(jschema_rdd, self)
+            srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
     def jsonRDD(self, rdd, schema=None):
         """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
@@ -1274,11 +1274,11 @@ class SQLContext:
         keyed._bypass_serializer = True
         jrdd = keyed._jrdd.map(self._jvm.BytesToString())
         if schema is None:
-            jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
         else:
             scala_datatype = self._ssql_ctx.parseDataType(str(schema))
-            jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
-        return SchemaRDD(jschema_rdd, self)
+            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+        return SchemaRDD(srdd.toJavaSchemaRDD(), self)
 
     def sql(self, sqlQuery):
         """Return a L{SchemaRDD} representing the result of the given query.
@@ -1289,7 +1289,7 @@ class SQLContext:
         >>> srdd2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
         """
-        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
 
     def table(self, tableName):
         """Returns the specified table as a L{SchemaRDD}.
@@ -1300,7 +1300,7 @@ class SQLContext:
         >>> sorted(srdd.collect()) == sorted(srdd2.collect())
         True
         """
-        return SchemaRDD(self._ssql_ctx.table(tableName), self)
+        return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
 
     def cacheTable(self, tableName):
         """Caches the specified table in-memory."""
@@ -1352,7 +1352,7 @@ class HiveContext(SQLContext):
         warnings.warn("hiveql() is deprecated as the sql function now parses 
using HiveQL by" +
                       "default. The SQL dialect for parsing can be set using 
'spark.sql.dialect'",
                       DeprecationWarning)
-        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)
+        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
 
     def hql(self, hqlQuery):
         """
@@ -1508,6 +1508,8 @@ class SchemaRDD(RDD):
     def __init__(self, jschema_rdd, sql_ctx):
         self.sql_ctx = sql_ctx
         self._sc = sql_ctx._sc
+        clsName = jschema_rdd.getClass().getName()
+        assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
         self._jschema_rdd = jschema_rdd
 
         self.is_cached = False
@@ -1524,7 +1526,7 @@ class SchemaRDD(RDD):
         L{pyspark.rdd.RDD} super class (map, filter, etc.).
         """
         if not hasattr(self, '_lazy_jrdd'):
-            self._lazy_jrdd = self._jschema_rdd.javaToPython()
+            self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
         return self._lazy_jrdd
 
     @property
@@ -1580,7 +1582,7 @@ class SchemaRDD(RDD):
     def schema(self):
         """Returns the schema of this SchemaRDD (represented by
         a L{StructType})."""
-        return _parse_datatype_string(self._jschema_rdd.schema().toString())
+        return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
 
     def schemaString(self):
         """Returns the output schema in the tree format."""
@@ -1631,8 +1633,6 @@ class SchemaRDD(RDD):
         rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
 
         schema = self.schema()
-        import pickle
-        pickle.loads(pickle.dumps(schema))
 
         def applySchema(_, it):
             cls = _create_cls(schema)
@@ -1669,10 +1669,8 @@ class SchemaRDD(RDD):
 
     def getCheckpointFile(self):
         checkpointFile = self._jschema_rdd.getCheckpointFile()
-        if checkpointFile.isDefined():
+        if checkpointFile.isPresent():
             return checkpointFile.get()
-        else:
-            return None
 
     def coalesce(self, numPartitions, shuffle=False):
         rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
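
A note on the getCheckpointFile() hunk above (this reading is an inference, 
not stated in the commit message): the Scala SchemaRDD returns a scala.Option, 
queried with isDefined(), while JavaSchemaRDD returns a Guava Optional, 
queried with isPresent(). The explicit `return None` branch is dropped because 
a Python function that falls off the end returns None anyway; sketched as the 
new method body:

    checkpointFile = self._jschema_rdd.getCheckpointFile()  # Optional, not Option
    if checkpointFile.isPresent():
        return checkpointFile.get()
    # falling through returns None implicitly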

http://git-wip-us.apache.org/repos/asf/spark/blob/9c06c723/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1db922f..8f0a351 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -41,6 +41,8 @@ from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
+from pyspark.storagelevel import StorageLevel
+from pyspark.sql import SQLContext
 
 _have_scipy = False
 _have_numpy = False
@@ -469,6 +471,41 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
 
+class TestSQL(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.sqlCtx = SQLContext(self.sc)
+
+    def test_basic_functions(self):
+        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+        srdd = self.sqlCtx.jsonRDD(rdd)
+        srdd.count()
+        srdd.collect()
+        srdd.schemaString()
+        srdd.schema()
+
+        # cache and checkpoint
+        self.assertFalse(srdd.is_cached)
+        srdd.persist(StorageLevel.MEMORY_ONLY_SER)
+        srdd.unpersist()
+        srdd.cache()
+        self.assertTrue(srdd.is_cached)
+        self.assertFalse(srdd.isCheckpointed())
+        self.assertEqual(None, srdd.getCheckpointFile())
+
+        srdd = srdd.coalesce(2, True)
+        srdd = srdd.repartition(3)
+        srdd = srdd.distinct()
+        srdd.intersection(srdd)
+        self.assertEqual(2, srdd.count())
+
+        srdd.registerTempTable("temp")
+        srdd = self.sqlCtx.sql("select foo from temp")
+        srdd.count()
+        srdd.collect()
+
+
 class TestIO(PySparkTestCase):
 
     def test_stdout_redirection(self):

