This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new b0b15475a0a [SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client
b0b15475a0a is described below

commit b0b15475a0ac2d73b829491532747a249498c1a6
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Sun Aug 13 20:27:08 2023 +0200

    [SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client
    
    ### What changes were proposed in this pull request?
    This PR adds Dataset.explode to the Spark Connect Scala Client.
    
    ### Why are the changes needed?
    To increase compatibility with the existing Dataset API in sql/core.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds a new method to the scala client.
    
    ### How was this patch tested?
    I added a test to `UserDefinedFunctionE2ETestSuite`.
    
    Closes #42418 from hvanhovell/SPARK-44736.
    
    Lead-authored-by: Herman van Hovell <her...@databricks.com>
    Co-authored-by: itholic <haejoon....@databricks.com>
    Co-authored-by: Juliusz Sompolski <ju...@databricks.com>
    Co-authored-by: Martin Grund <martin.gr...@databricks.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Kent Yao <y...@apache.org>
    Co-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Wei Liu <wei....@databricks.com>
    Co-authored-by: Ruifeng Zheng <ruife...@apache.org>
    Co-authored-by: Gengliang Wang <gengli...@apache.org>
    Co-authored-by: Yuming Wang <yumw...@ebay.com>
    Co-authored-by: Herman van Hovell <hvanhov...@databricks.com>
    Co-authored-by: 余良 <yul...@chinaunicom.cn>
    Co-authored-by: Dongjoon Hyun <dh...@apple.com>
    Co-authored-by: Jack Chen <jack.c...@databricks.com>
    Co-authored-by: srielau <se...@rielau.com>
    Co-authored-by: zhyhimont <zhyhim...@gmail.com>
    Co-authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Co-authored-by: Dongjoon Hyun <dongj...@apache.org>
    Co-authored-by: Zhyhimont Dmitry <zhyhimon...@profitero.com>
    Co-authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Co-authored-by: yangjie01 <yangji...@baidu.com>
    Co-authored-by: Yihong He <yihong...@databricks.com>
    Co-authored-by: Rameshkrishnan Muthusamy <rameshkrishnan_muthus...@apple.com>
    Co-authored-by: Jia Fan <fanjiaemi...@qq.com>
    Co-authored-by: allisonwang-db <allison.w...@databricks.com>
    Co-authored-by: Utkarsh <utkarsh.agar...@databricks.com>
    Co-authored-by: Cheng Pan <cheng...@apache.org>
    Co-authored-by: Jason Li <jason...@databricks.com>
    Co-authored-by: Shu Wang <swa...@linkedin.com>
    Co-authored-by: Nicolas Fraison <nicolas.frai...@datadoghq.com>
    Co-authored-by: Max Gekk <max.g...@gmail.com>
    Co-authored-by: panbingkun <pbk1...@gmail.com>
    Co-authored-by: Ziqi Liu <ziqi....@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit f496cd1ee2a7e59af08e1bd3ab0579f93cc46da9)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 70 ++++++++++++++++++++++
 .../sql/UserDefinedFunctionE2ETestSuite.scala      | 60 +++++++++++++++++++
 .../CheckConnectJvmClientCompatibility.scala       |  1 -
 .../apache/spark/sql/connect/common/UdfUtils.scala |  4 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  3 +-
 5 files changed, 136 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2d72ea6bda8..28b04fb850e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,12 +21,14 @@ import java.util.{Collections, Locale}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
@@ -2728,6 +2730,74 @@ class Dataset[T] private[sql] (
     flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
   }
 
+  /**
+   * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows
+   * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the
+   * input row are implicitly joined with each row that is output by the function.
+   *
+   * Given that this is deprecated, as an alternative, you can explode columns either using
+   * `functions.explode()` or `flatMap()`. The following example uses these alternatives to count
+   * the number of books that contain a given word:
+   *
+   * {{{
+   *   case class Book(title: String, words: String)
+   *   val ds: Dataset[Book]
+   *
+   *   val allWords = ds.select($"title", explode(split($"words", " ")).as("word"))
+   *
+   *   val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
+   * }}}
+   *
+   * Using `flatMap()` this can similarly be exploded as:
+   *
+   * {{{
+   *   ds.flatMap(_.words.split(" "))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.5.0
+   */
+  @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
+  def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+    val generator = ScalarUserDefinedFunction(
+      UdfUtils.traversableOnceToSeq(f),
+      UnboundRowEncoder :: Nil,
+      ScalaReflection.encoderFor[Seq[A]])
+    select(col("*"), functions.inline(generator(struct(input: _*))))
+  }
+
+  /**
+   * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or
+   * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
+   * columns of the input row are implicitly joined with each value that is output by the
+   * function.
+   *
+   * Given that this is deprecated, as an alternative, you can explode columns either using
+   * `functions.explode()`:
+   *
+   * {{{
+   *   ds.select(explode(split($"words", " ")).as("word"))
+   * }}}
+   *
+   * or `flatMap()`:
+   *
+   * {{{
+   *   ds.flatMap(_.words.split(" "))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.5.0
+   */
+  @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
+  def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
+      f: A => TraversableOnce[B]): DataFrame = {
+    val generator = ScalarUserDefinedFunction(
+      UdfUtils.traversableOnceToSeq(f),
+      Nil,
+      ScalaReflection.encoderFor[Seq[B]])
+    select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
+  }
+
   /**
    * Applies a function `f` to all rows.
    *
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 3a931c9a6ba..d00659ac2d8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -95,6 +95,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
     rows.forEach(x => assert(x == 42))
   }
 
+  test("(deprecated) Dataset explode") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val result1 = spark
+      .range(3)
+      .filter(col("id") =!= 1L)
+      .explode(col("id") + 41, col("id") + 10) { case Row(x: Long, y: Long) =>
+        Iterator((x, x - 1), (y, y + 1))
+      }
+      .as[(Long, Long, Long)]
+      .collect()
+      .toSeq
+    assert(result1 === Seq((0L, 41L, 40L), (0L, 10L, 11L), (2L, 43L, 42L), (2L, 12L, 13L)))
+
+    val result2 = Seq((1, "a b c"), (2, "a b"), (3, "a"))
+      .toDF("number", "letters")
+      .explode('letters) { case Row(letters: String) =>
+        letters.split(' ').map(Tuple1.apply).toSeq
+      }
+      .as[(Int, String, String)]
+      .collect()
+      .toSeq
+    assert(
+      result2 === Seq(
+        (1, "a b c", "a"),
+        (1, "a b c", "b"),
+        (1, "a b c", "c"),
+        (2, "a b", "a"),
+        (2, "a b", "b"),
+        (3, "a", "a")))
+
+    val result3 = Seq("a b c", "d e")
+      .toDF("words")
+      .explode("words", "word") { word: String =>
+        word.split(' ').toSeq
+      }
+      .select(col("word"))
+      .as[String]
+      .collect()
+      .toSeq
+    assert(result3 === Seq("a", "b", "c", "d", "e"))
+
+    val result4 = Seq("a b c", "d e")
+      .toDF("words")
+      .explode("words", "word") { word: String =>
+        word.split(' ').map(s => s -> s.head.toInt).toSeq
+      }
+      .select(col("word"), col("words"))
+      .as[((String, Int), String)]
+      .collect()
+      .toSeq
+    assert(
+      result4 === Seq(
+        (("a", 97), "a b c"),
+        (("b", 98), "a b c"),
+        (("c", 99), "a b c"),
+        (("d", 100), "d e"),
+        (("e", 101), "d e")))
+  }
+
   test("Dataset typed flat map - java") {
     val rows = spark
       .range(5)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 04b162eceec..7356d4daa79 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -184,7 +184,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), // deprecated
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index 16d5823f4a4..433614a4afc 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -131,6 +131,10 @@ private[sql] object UdfUtils extends Serializable {
 
   def noOp[V, K](): V => K = _ => null.asInstanceOf[K]
 
+  def traversableOnceToSeq[A, B](f: A => TraversableOnce[B]): A => Seq[B] = { value =>
+    f(value).toSeq
+  }
+
   //  (1 to 22).foreach { i =>
   //    val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
   //    val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 49bac17a4f4..dc77c52ef46 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -508,7 +508,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     val commonUdf = rel.getFunc
     commonUdf.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
-        transformTypedMapPartitions(commonUdf, baseRel)
+        val analyzed = session.sessionState.executePlan(baseRel).analyzed
+        transformTypedMapPartitions(commonUdf, analyzed)
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         val pythonUdf = transformPythonUDF(commonUdf)
         val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to