Repository: spark
Updated Branches:
  refs/heads/branch-1.3 f9d4efa72 -> 30e7c63f8


[SPARK-6603] [PySpark] [SQL] add SQLContext.udf and deprecate inferSchema() and applySchema

This PR creates an alias for `registerFunction` as `udf.register`, to be consistent with the Scala API.

It also deprecates inferSchema() and applySchema(), showing a warning when they are used.
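
A minimal sketch of the new spelling (assuming an existing SparkContext `sc`; the
function name `strLen` is illustrative):

    from pyspark.sql import SQLContext
    from pyspark.sql.types import IntegerType

    sqlCtx = SQLContext(sc)

    # old spelling, still supported
    sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())

    # new alias added by this PR, matching Scala's sqlContext.udf.register
    sqlCtx.udf.register("strLen", lambda s: len(s), IntegerType())

    sqlCtx.sql("SELECT strLen('test')").collect()  # [Row(c0=4)]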

cc rxin

Author: Davies Liu <dav...@databricks.com>

Closes #5273 from davies/udf and squashes the following commits:

476e947 [Davies Liu] address comments
c096fdb [Davies Liu] add SQLContext.udf and deprecate inferSchema() and applySchema

(cherry picked from commit f76d2e55b1a67bf5576e1aa001a0b872b9b3895a)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30e7c63f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30e7c63f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30e7c63f

Branch: refs/heads/branch-1.3
Commit: 30e7c63f8487c498a7dba25e121b26d432dd118d
Parents: f9d4efa
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Mar 30 15:47:00 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Mar 30 15:47:06 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/context.py | 87 ++++++++++++++++++++++++++------------
 1 file changed, 60 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/30e7c63f/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 795ef0d..80939a1 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,7 +34,7 @@ try:
 except ImportError:
     has_pandas = False
 
-__all__ = ["SQLContext", "HiveContext"]
+__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
 
 
 def _monkey_patch_RDD(sqlCtx):
@@ -56,6 +56,31 @@ def _monkey_patch_RDD(sqlCtx):
     RDD.toDF = toDF
 
 
+class UDFRegistration(object):
+    """Wrapper for register UDF"""
+
+    def __init__(self, sqlCtx):
+        self.sqlCtx = sqlCtx
+
+    def register(self, name, f, returnType=StringType()):
+        """Registers a lambda function as a UDF so it can be used in SQL 
statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not given it defaults to a string and conversion will automatically
+        be done.  For any other return type, the produced object must match the specified type.
+
+        >>> sqlCtx.udf.register("stringLengthString", lambda x: len(x))
+        >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+        [Row(c0=u'4')]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), 
IntegerType())
+        >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+        [Row(c0=4)]
+        """
+        return self.sqlCtx.registerFunction(name, f, returnType)
+
+
 class SQLContext(object):
 
     """Main entry point for Spark SQL functionality.
@@ -118,6 +143,11 @@ class SQLContext(object):
         """
         return self._ssql_ctx.getConf(key, defaultValue)
 
+    @property
+    def udf(self):
+        """Wrapper for register Python function as UDF """
+        return UDFRegistration(self)
+
     def registerFunction(self, name, f, returnType=StringType()):
         """Registers a lambda function as a UDF so it can be used in SQL 
statements.
 
@@ -198,14 +228,12 @@ class SQLContext(object):
         >>> df.collect()[0]
         Row(field1=1, field2=u'row1')
         """
+        warnings.warn("inferSchema is deprecated, please use createDataFrame 
instead")
 
         if isinstance(rdd, DataFrame):
             raise TypeError("Cannot apply schema to DataFrame")
 
-        schema = self._inferSchema(rdd, samplingRatio)
-        converter = _create_converter(schema)
-        rdd = rdd.map(converter)
-        return self.applySchema(rdd, schema)
+        return self.createDataFrame(rdd, None, samplingRatio)
 
     def applySchema(self, rdd, schema):
         """
@@ -230,6 +258,7 @@ class SQLContext(object):
         >>> df.collect()
         [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
         """
+        warnings.warn("applySchema is deprecated, please use createDataFrame 
instead")
 
         if isinstance(rdd, DataFrame):
             raise TypeError("Cannot apply schema to DataFrame")
@@ -237,23 +266,7 @@ class SQLContext(object):
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType, but got %s" % schema)
 
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot be deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
-            rows = rdd.take(10)
-
-        for row in rows:
-            _verify_type(row, schema)
-
-        # convert python objects to sql data
-        converter = _python_to_sql_converter(schema)
-        rdd = rdd.map(converter)
-
-        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        return DataFrame(df, self)
+        return self.createDataFrame(rdd, schema)
 
     def createDataFrame(self, data, schema=None, samplingRatio=None):
         """
@@ -323,22 +336,42 @@ class SQLContext(object):
         if not isinstance(data, RDD):
             try:
                 # data could be list, tuple, generator ...
-                data = self._sc.parallelize(data)
+                rdd = self._sc.parallelize(data)
             except Exception:
                 raise ValueError("cannot create an RDD from type: %s" % 
type(data))
+        else:
+            rdd = data
 
         if schema is None:
-            return self.inferSchema(data, samplingRatio)
+            schema = self._inferSchema(rdd, samplingRatio)
+            converter = _create_converter(schema)
+            rdd = rdd.map(converter)
 
         if isinstance(schema, (list, tuple)):
-            first = data.first()
+            first = rdd.first()
             if not isinstance(first, (list, tuple)):
                 raise ValueError("each row in `rdd` should be list or tuple, "
                                  "but got %r" % type(first))
             row_cls = Row(*schema)
-            schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
 
-        return self.applySchema(data, schema)
+        # take the first few rows to verify schema
+        rows = rdd.take(10)
+        # Row() cannot be deserialized by Pyrolite
+        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+            rdd = rdd.map(tuple)
+            rows = rdd.take(10)
+
+        for row in rows:
+            _verify_type(row, schema)
+
+        # convert python objects to sql data
+        converter = _python_to_sql_converter(schema)
+        rdd = rdd.map(converter)
+
+        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        return DataFrame(df, self)
 
     def registerDataFrameAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.

