This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 103eca60102 [SPARK-42605][CONNECT] Add TypedColumn
103eca60102 is described below

commit 103eca60102d943fd083bf029db1d0e7f22d67ff
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Feb 27 15:13:36 2023 -0400

    [SPARK-42605][CONNECT] Add TypedColumn

    ### What changes were proposed in this pull request?
    This PR adds TypedColumn to the Spark Connect Scala Client. It also adds
    one of the typed select methods to Dataset, and a typed count function.

    ### Why are the changes needed?
    API parity.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Added tests to PlanGenerationTestSuite and ClientE2ETestSuite.

    Closes #40197 from hvanhovell/SPARK-42605.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 7f64ec302420652932ff515c325ba37938f0b175)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
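In short, the intended usage looks roughly like this (a sketch only, not part
of the patch: the session wiring and the availability of implicit encoders via
`spark.implicits._` are assumptions here):

    import org.apache.spark.sql.{Dataset, SparkSession, TypedColumn}
    import org.apache.spark.sql.functions.{col, count}

    def demo(spark: SparkSession): Unit = {
      import spark.implicits._ // assumed source of the implicit Encoder[Long]

      // Column.as[U] attaches an encoder, turning a Column into a TypedColumn.
      val typed: TypedColumn[Any, Long] = col("id").as[Long]

      // The new single-column typed select returns a Dataset of the encoder's
      // type instead of a DataFrame.
      val ds: Dataset[Long] = spark.range(3).select(typed)

      // The typed count is a TypedColumn[Any, Long], so first() yields a Long.
      val rows: Long = spark.range(1000).select(count("id")).first()
    }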
---
 .../main/scala/org/apache/spark/sql/Column.scala   | 40 +++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 25 +++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     | 10 ++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  | 40 +++++++++++++++------
 .../apache/spark/sql/PlanGenerationTestSuite.scala | 10 ++++++
 .../sql/connect/client/CompatibilitySuite.scala    | 10 +++---
 .../explain-results/function_count_typed.explain   |  2 ++
 .../explain-results/select_typed_1-arg.explain     |  3 ++
 .../query-tests/queries/function_count_typed.json  | 25 +++++++++++++
 .../queries/function_count_typed.proto.bin         | Bin 0 -> 66 bytes
 .../query-tests/queries/select_typed_1-arg.json    | 39 ++++++++++++++++++++
 .../queries/select_typed_1-arg.proto.bin           | Bin 0 -> 98 bytes
 12 files changed, 189 insertions(+), 15 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
index 9af0096fc1c..c39d5c9757e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,6 +22,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering
 import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.expressions.Window
@@ -70,6 +71,17 @@ class Column private[sql] (private[sql] val expr: proto.Expression) extends Logg
 
   override def hashCode: Int = expr.hashCode()
 
+  /**
+   * Provides a type hint about the expected return value of this column. This information can be
+   * used by operations such as `select` on a [[Dataset]] to automatically convert the results
+   * into the correct JVM types.
+   * @since 3.4.0
+   */
+  def as[U: Encoder]: TypedColumn[Any, U] = {
+    val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]]
+    new TypedColumn[Any, U](expr, encoder)
+  }
+
   /**
    * Extracts a value or values from a complex type. The following types of extraction are
    * supported:
@@ -1430,3 +1442,31 @@ class ColumnName(name: String) extends Column(name) {
    */
   def struct(structType: StructType): StructField = StructField(name, structType)
 }
+
+/**
+ * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. To
+ * create a [[TypedColumn]], use the `as` function on a [[Column]].
+ *
+ * @tparam T
+ *   The input type expected for this expression. Can be `Any` if the expression is type checked
+ *   by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
+ * @tparam U
+ *   The output type of this column.
+ *
+ * @since 3.4.0
+ */
+class TypedColumn[-T, U] private[sql] (
+    expr: proto.Expression,
+    private[sql] val encoder: AgnosticEncoder[U])
+    extends Column(expr) {
+
+  /**
+   * Gives the [[TypedColumn]] a name (alias). If the current `TypedColumn` has metadata
+   * associated with it, this metadata will be propagated to the new column.
+   *
+   * @group expr_ops
+   * @since 3.4.0
+   */
+  override def name(alias: String): TypedColumn[T, U] =
+    new TypedColumn[T, U](super.name(alias).expr, encoder)
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 73de35456fc..1015d61a9c2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1012,6 +1012,31 @@ class Dataset[T] private[sql] (
     select(exprs.map(functions.expr): _*)
   }
 
+  /**
+   * Returns a new Dataset by computing the given [[Column]] expression for each element.
+   *
+   * {{{
+   *   val ds = Seq(1, 2, 3).toDS()
+   *   val newDS = ds.select(expr("value + 1").as[Int])
+   * }}}
+   *
+   * @group typedrel
+   * @since 3.4.0
+   */
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    val encoder = c1.encoder
+    val expr = if (encoder.schema == encoder.dataType) {
+      functions.inline(functions.array(c1)).expr
+    } else {
+      c1.expr
+    }
+    sparkSession.newDataset(encoder) { builder =>
+      builder.getProjectBuilder
+        .setInput(plan.getRoot)
+        .addExpressions(expr)
+    }
+  }
+
   /**
    * Filters rows using the given condition.
    * {{{
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 94882087eee..386219a699c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
 import com.google.protobuf.ByteString
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.connect.client.unsupported
 import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction}
@@ -401,6 +402,15 @@ object functions {
    */
   def count(e: Column): Column = Column.fn("count", e)
 
+  /**
+   * Aggregate function: returns the number of items in a group.
+   *
+   * @group agg_funcs
+   * @since 3.4.0
+   */
+  def count(columnName: String): TypedColumn[Any, Long] =
+    count(Column(columnName)).as(PrimitiveLongEncoder)
+
   /**
    * Aggregate function: returns the number of distinct items in a group.
    *
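A note on the new Dataset.select[U1] above: when the encoder's `schema` equals
its `dataType`, the target type is itself a struct (a tuple or a case class),
so the struct-valued expression is wrapped in `inline(array(...))`, which
explodes the one-element array of structs and surfaces the struct's fields as
top-level columns, the shape the encoder expects. A minimal sketch of that
rewrite, shown against the classic (non-Connect) Scala API with an
illustrative local session:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, expr, struct}

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    val df = spark.range(3).select(col("id"), (col("id") % 2).cast("int").as("a"))

    // Selecting the struct directly yields a single struct-typed column:
    df.select(struct(col("id"), col("a"))).printSchema()
    // one column "struct(id, a)" of type struct<id:bigint,a:int>

    // Wrapped in inline(array(...)), the fields come back as top-level
    // columns, matching what an Encoder[(Long, Int)] decodes:
    df.select(expr("inline(array(struct(id, a)))")).printSchema()
    // two columns: id bigint, a int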
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index debb314f8c3..3f00f7c9c36 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -30,7 +30,7 @@ import org.scalactic.TolerantNumerics
 import org.apache.spark.SPARK_VERSION
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
-import org.apache.spark.sql.functions.{aggregate, array, col, lit, rand, sequence, shuffle, transform, udf}
+import org.apache.spark.sql.functions.{aggregate, array, col, count, lit, rand, sequence, shuffle, struct, transform, udf}
 import org.apache.spark.sql.types._
 
 class ClientE2ETestSuite extends RemoteSparkSession {
@@ -412,16 +412,13 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     }
   }
 
-  test("Dataset collect complex type") {
-    val result = spark
-      .range(3)
-      .select(
-        (col("id") / lit(10.0d)).as("b"),
-        col("id"),
-        lit("world").as("d"),
-        (col("id") % 2).cast("int").as("a"))
-      .as[MyType]
-      .collect()
+  private val generateMyTypeColumns = Seq(
+    (col("id") / lit(10.0d)).as("b"),
+    col("id"),
+    lit("world").as("d"),
+    (col("id") % 2).cast("int").as("a"))
+
+  private def validateMyTypeResult(result: Array[MyType]): Unit = {
     result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
       assert(id == i)
       assert(a == id % 2)
@@ -429,6 +426,27 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     }
   }
 
+  test("Dataset collect complex type") {
+    val result = spark
+      .range(3)
+      .select(generateMyTypeColumns: _*)
+      .as[MyType]
+      .collect()
+    validateMyTypeResult(result)
+  }
+
+  test("Dataset typed select - simple column") {
+    val numRows = spark.range(1000).select(count("id")).first()
+    assert(numRows === 1000)
+  }
+
+  test("Dataset typed select - complex column") {
+    val ds = spark
+      .range(3)
+      .select(struct(generateMyTypeColumns: _*).as[MyType])
+    validateMyTypeResult(ds.collect())
+  }
+
   test("lambda functions") {
     // This test is mostly to validate lambda variables are properly resolved.
     val result = spark
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 67ea148cb87..52e5d892012 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{functions => fn}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 import org.apache.spark.sql.expressions.Window
@@ -287,6 +288,11 @@ class PlanGenerationTestSuite
     simple.select(fn.col("id"))
   }
 
+  test("select typed 1-arg") {
+    val encoder = ScalaReflection.encoderFor[(Long, Int)]
+    simple.select(fn.struct(fn.col("id"), fn.col("a")).as(encoder))
+  }
+
   test("limit") {
     simple.limit(10)
   }
@@ -876,6 +882,10 @@ class PlanGenerationTestSuite
     fn.count(fn.col("a"))
   }
 
+  test("function count typed") {
+    simple.select(fn.count("a"))
+  }
+
   functionTest("countDistinct") {
     fn.countDistinct("a", "g")
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
index bb480e0ee08..ccee3b550eb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -81,7 +81,8 @@ class CompatibilitySuite extends ConnectFunSuite {
       IncludeByName("org.apache.spark.sql.functions.*"),
       IncludeByName("org.apache.spark.sql.RelationalGroupedDataset.*"),
       IncludeByName("org.apache.spark.sql.SparkSession.*"),
-      IncludeByName("org.apache.spark.sql.RuntimeConfig.*"))
+      IncludeByName("org.apache.spark.sql.RuntimeConfig.*"),
+      IncludeByName("org.apache.spark.sql.TypedColumn.*"))
     val excludeRules = Seq(
       // Filter unsupported rules:
       // Note when muting errors for a method, checks on all overloading methods are also muted.
@@ -136,7 +137,6 @@ class CompatibilitySuite extends ConnectFunSuite {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.count"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
@@ -178,10 +178,12 @@ class CompatibilitySuite extends ConnectFunSuite {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.getActiveSession"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.getDefaultSession"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.range"),
 
       // RuntimeConfig
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"))
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.RuntimeConfig.this"),
+
+      // TypedColumn
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.TypedColumn.this"))
     val problems = allProblems
       .filter { p =>
         includedRules.exists(rule => rule(p))
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_typed.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_typed.explain
new file mode 100644
index 00000000000..200513a1181
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_typed.explain
@@ -0,0 +1,2 @@
+Aggregate [count(a#0) AS count(a)#0L]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/select_typed_1-arg.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/select_typed_1-arg.explain
new file mode 100644
index 00000000000..64017a5e073
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/select_typed_1-arg.explain
@@ -0,0 +1,3 @@
+Project [id#0L, a#0]
++- Generate inline(array(struct(id, id#0L, a, a#0))), false, [id#0L, a#0]
+   +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.json b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.json
new file mode 100644
index 00000000000..1c5df90b79c
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.json
@@ -0,0 +1,25 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.proto.bin
new file mode 100644
index 00000000000..44b613eb40c
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_count_typed.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json
new file mode 100644
index 00000000000..90ef62c5f41
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.json
@@ -0,0 +1,39 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "inline",
+        "arguments": [{
+          "unresolvedFunction": {
+            "functionName": "array",
+            "arguments": [{
+              "unresolvedFunction": {
+                "functionName": "struct",
+                "arguments": [{
+                  "unresolvedAttribute": {
+                    "unparsedIdentifier": "id"
+                  }
+                }, {
+                  "unresolvedAttribute": {
+                    "unparsedIdentifier": "a"
+                  }
+                }]
+              }
+            }]
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin
new file mode 100644
index 00000000000..2273a16d4e6
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/select_typed_1-arg.proto.bin differ

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org