Repository: spark
Updated Branches:
  refs/heads/branch-2.2 a3a5fcfef -> d8bd213f1


[SPARK-20584][PYSPARK][SQL] Python generic hint support

## What changes were proposed in this pull request?

Adds `hint` method to PySpark `DataFrame`.
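For context, a minimal usage sketch (the `SparkSession` setup and column names here are illustrative, not part of the patch):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df1 = spark.range(100).toDF("id")
df2 = spark.range(100).toDF("id")

# Name a hint on the build side before joining; "broadcast" asks the
# planner to broadcast df2 rather than shuffle it.
joined = df1.join(df2.hint("broadcast"), "id")

# Optional parameters can be passed either as varargs or as a single
# list; both forms are treated identically (the values below are dummies).
df2.hint("broadcast", "foo", "bar")
df2.hint("broadcast", ["foo", "bar"])
```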

## How was this patch tested?

Unit tests, doctests.

Author: zero323 <zero...@users.noreply.github.com>

Closes #17850 from zero323/SPARK-20584.

(cherry picked from commit 02bbe73118a39e2fb378aa2002449367a92f6d67)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8bd213f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8bd213f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8bd213f

Branch: refs/heads/branch-2.2
Commit: d8bd213f13279664d50ffa57c1814d0b16fc5d23
Parents: a3a5fcf
Author: zero323 <zero...@users.noreply.github.com>
Authored: Wed May 3 19:15:28 2017 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed May 3 19:15:42 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py | 29 +++++++++++++++++++++++++++++
 python/pyspark/sql/tests.py     | 16 ++++++++++++++++
 2 files changed, 45 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d8bd213f/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f567cc4..d62ba96 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -371,6 +371,35 @@ class DataFrame(object):
         jdf = self._jdf.withWatermark(eventTime, delayThreshold)
         return DataFrame(jdf, self.sql_ctx)
 
+    @since(2.2)
+    def hint(self, name, *parameters):
+        """Specifies some hint on the current DataFrame.
+
+        :param name: A name of the hint.
+        :param parameters: Optional parameters.
+        :return: :class:`DataFrame`
+
+        >>> df.join(df2.hint("broadcast"), "name").show()
+        +----+---+------+
+        |name|age|height|
+        +----+---+------+
+        | Bob|  5|    85|
+        +----+---+------+
+        """
+        if len(parameters) == 1 and isinstance(parameters[0], list):
+            parameters = parameters[0]
+
+        if not isinstance(name, str):
+            raise TypeError("name should be provided as str, got 
{0}".format(type(name)))
+
+        for p in parameters:
+            if not isinstance(p, str):
+                raise TypeError(
+                    "all parameters should be str, got {0} of type 
{1}".format(p, type(p)))
+
+        jdf = self._jdf.hint(name, self._jseq(parameters))
+        return DataFrame(jdf, self.sql_ctx)
+
     @since(1.3)
     def count(self):
         """Returns the number of rows in this :class:`DataFrame`.

http://git-wip-us.apache.org/repos/asf/spark/blob/d8bd213f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cd92148..2aa2d23 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1906,6 +1906,22 @@ class SQLTests(ReusedPySparkTestCase):
         # planner should not crash without a join
         broadcast(df1)._jdf.queryExecution().executedPlan()
 
+    def test_generic_hints(self):
+        from pyspark.sql import DataFrame
+
+        df1 = self.spark.range(10e10).toDF("id")
+        df2 = self.spark.range(10e10).toDF("id")
+
+        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
+
+        # Dummy rules
+        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
+
+        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
+        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
+
     def test_toDF_with_schema_string(self):
         data = [Row(key=i, value=str(i)) for i in range(100)]
         rdd = self.sc.parallelize(data, 5)

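As a usage note: outside of tests, the effect of the hint can also be inspected from plain Python with `DataFrame.explain()` rather than reaching into `_jdf` as above (a sketch, reusing the test's `df1`/`df2`; the exact plan text varies by Spark version):

```python
df1.join(df2.hint("broadcast"), "id").explain()
# With the hint applied, the printed physical plan should contain
# a BroadcastHashJoin node instead of a SortMergeJoin.
```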

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
