Repository: spark Updated Branches: refs/heads/master 72adfbf94 -> 2d96d35dc
[SPARK-17946][PYSPARK] Python crossJoin API similar to Scala ## What changes were proposed in this pull request? Add a crossJoin function to the DataFrame API similar to that in Scala. Joins with no condition (cartesian products) must be specified with the crossJoin API. ## How was this patch tested? Added python tests to ensure that an AnalysisException is raised if a cartesian product is specified without crossJoin(), and that cartesian products can execute if specified via crossJoin(). (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: Srinath Shankar <srin...@databricks.com> Closes #15493 from srinathshankar/crosspython. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d96d35d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d96d35d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d96d35d Branch: refs/heads/master Commit: 2d96d35dc0fed6df249606d9ce9272c0f0109fa2 Parents: 72adfbf Author: Srinath Shankar <srin...@databricks.com> Authored: Fri Oct 14 18:24:47 2016 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Fri Oct 14 18:24:47 2016 -0700 ---------------------------------------------------------------------- python/pyspark/sql/dataframe.py | 26 ++++++++++++++++---- python/pyspark/sql/tests.py | 15 ++++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- 3 files changed, 36 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2d96d35d/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 
7606ac0..29710ac 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -651,6 +651,25 @@ class DataFrame(object): return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @ignore_unicode_prefix + @since(2.1) + def crossJoin(self, other): + """Returns the cartesian product with another :class:`DataFrame`. + + :param other: Right side of the cartesian product. + + >>> df.select("age", "name").collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df2.select("name", "height").collect() + [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85)] + >>> df.crossJoin(df2.select("height")).select("age", "name", "height").collect() + [Row(age=2, name=u'Alice', height=80), Row(age=2, name=u'Alice', height=85), + Row(age=5, name=u'Bob', height=80), Row(age=5, name=u'Bob', height=85)] + """ + + jdf = self._jdf.crossJoin(other._jdf) + return DataFrame(jdf, self.sql_ctx) + + @ignore_unicode_prefix @since(1.3) def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. 
@@ -690,14 +709,11 @@ class DataFrame(object): on = self._jseq(on) else: assert isinstance(on[0], Column), "on should be Column or list of Column" - if len(on) > 1: - on = reduce(lambda x, y: x.__and__(y), on) - else: - on = on[0] + on = reduce(lambda x, y: x.__and__(y), on) on = on._jc if on is None and how is None: - jdf = self._jdf.crossJoin(other._jdf) + jdf = self._jdf.join(other._jdf) else: if how is None: how = "inner" http://git-wip-us.apache.org/repos/asf/spark/blob/2d96d35d/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 51d5e7a..3d46b85 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1466,7 +1466,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) # no join key -- should not be a broadcast join - plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan() self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) # planner should not crash without a join @@ -1514,6 +1514,19 @@ class SQLTests(ReusedPySparkTestCase): df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + # Cartesian products require cross join syntax + def test_require_cross(self): + from pyspark.sql.functions import broadcast + + df1 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + df2 = self.spark.createDataFrame([(1, "1")], ("key", "value")) + + # joins without conditions require cross join syntax + self.assertRaises(AnalysisException, lambda: df1.join(df2).collect()) + + # works with crossJoin + self.assertEqual(1, df1.crossJoin(df2).count()) + def test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") 
http://git-wip-us.apache.org/repos/asf/spark/blob/2d96d35d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 70c9cf5..7ae3275 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -774,7 +774,7 @@ class Dataset[T] private[sql]( * @param right Right side of the join operation. * * @group untypedrel - * @since 2.0.0 + * @since 2.1.0 */ def crossJoin(right: Dataset[_]): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Cross, None) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org