Repository: spark Updated Branches: refs/heads/master c3713fde8 -> a2bec6c92
[SPARK-21043][SQL] Add unionByName in Dataset ## What changes were proposed in this pull request? This PR adds `unionByName` to `Dataset`. Here is how to use it: ``` val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") df1.unionByName(df2).show // output: // +----+----+----+ // |col0|col1|col2| // +----+----+----+ // |   1|   2|   3| // |   6|   4|   5| // +----+----+----+ ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro <yamam...@apache.org> Closes #18300 from maropu/SPARK-21043-2. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a2bec6c9 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a2bec6c9 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a2bec6c9 Branch: refs/heads/master Commit: a2bec6c92a063f4a8e9ed75a9f3f06808485b6d7 Parents: c3713fd Author: Takeshi Yamamuro <yamam...@apache.org> Authored: Mon Jul 10 20:16:29 2017 -0700 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Mon Jul 10 20:16:29 2017 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/sql/Dataset.scala | 60 ++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 87 ++++++++++++++++++++ 2 files changed, 147 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a2bec6c9/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a777383..7f3ae05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython import
org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils @@ -1735,6 +1736,65 @@ class Dataset[T] private[sql]( } /** + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. + * + * This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set + * union (that does deduplication of elements), use this function followed by a [[distinct]]. + * + * The difference between this function and [[union]] is that this function + * resolves columns by name (not by position): + * + * {{{ + * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2") + * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0") + * df1.unionByName(df2).show + * + * // output: + * // +----+----+----+ + * // |col0|col1|col2| + * // +----+----+----+ + * // | 1| 2| 3| + * // | 6| 4| 5| + * // +----+----+----+ + * }}} + * + * @group typedrel + * @since 2.3.0 + */ + def unionByName(other: Dataset[T]): Dataset[T] = withSetOperator { + // Check column name duplication + val resolver = sparkSession.sessionState.analyzer.resolver + val leftOutputAttrs = logicalPlan.output + val rightOutputAttrs = other.logicalPlan.output + + SchemaUtils.checkColumnNameDuplication( + leftOutputAttrs.map(_.name), + "in the left attributes", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + SchemaUtils.checkColumnNameDuplication( + rightOutputAttrs.map(_.name), + "in the right attributes", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + // Builds a project list for `other` based on `logicalPlan` output names + val rightProjectList = leftOutputAttrs.map { lattr => + rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { + throw new AnalysisException( + s"""Cannot resolve 
column name "${lattr.name}" among """ + + s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") + } + } + + // Delegates failure checks to `CheckAnalysis` + val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) + val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan) + + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, rightChild)) + } + + /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. * http://git-wip-us.apache.org/repos/asf/spark/blob/a2bec6c9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a5a2e1c..5ae2703 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -111,6 +111,93 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("union by name") { + var df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + var df2 = Seq((3, 1, 2)).toDF("c", "a", "b") + val df3 = Seq((2, 3, 1)).toDF("b", "c", "a") + val unionDf = df1.unionByName(df2.unionByName(df3)) + checkAnswer(unionDf, + Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil + ) + + // Check if adjacent unions are combined into a single one + assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1) + + // Check failure cases + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains( + "Union can only be performed on tables with the same number of columns, " + + 
"but the first table has 2 columns and the second table has 3 columns")) + + df1 = Seq((1, 2, 3)).toDF("a", "b", "c") + df2 = Seq((4, 5, 6)).toDF("a", "c", "d") + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)""")) + } + + test("union by name - type coercion") { + var df1 = Seq((1, "a")).toDF("c0", "c1") + var df2 = Seq((3, 1L)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil) + + df1 = Seq((1, 1.0)).toDF("c0", "c1") + df2 = Seq((8L, 3.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil) + + df1 = Seq((2.0f, 7.4)).toDF("c0", "c1") + df2 = Seq(("a", 4.0)).toDF("c1", "c0") + checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil) + + df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2") + df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1") + val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0") + checkAnswer(df1.unionByName(df2.unionByName(df3)), + Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil + ) + } + + test("union by name - check case sensitivity") { + def checkCaseSensitiveTest(): Unit = { + val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef") + val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB") + checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val errMsg2 = intercept[AnalysisException] { + checkCaseSensitiveTest() + }.getMessage + assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)""")) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkCaseSensitiveTest() + } + } + + test("union by name - check name duplication") { + Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + var df1 = Seq((1, 1)).toDF(c0, c1) + var df2 = Seq((1, 
1)).toDF("c0", "c1") + var errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the left attributes:")) + df1 = Seq((1, 1)).toDF("c0", "c1") + df2 = Seq((1, 1)).toDF(c0, c1) + errMsg = intercept[AnalysisException] { + df1.unionByName(df2) + }.getMessage + assert(errMsg.contains("Found duplicate column(s) in the right attributes:")) + } + } + } + test("empty data frame") { assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(spark.emptyDataFrame.count() === 0) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org