[ https://issues.apache.org/jira/browse/SPARK-52738?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Max Gekk reassigned SPARK-52738:
--------------------------------

    Assignee: Bruce Robbins

> Support aggregating TIME type with a UDAF when the underlying buffer is an UnsafeRow
> ------------------------------------------------------------------------------------
>
>                 Key: SPARK-52738
>                 URL: https://issues.apache.org/jira/browse/SPARK-52738
>             Project: Spark
>          Issue Type: Sub-task
>          Components: SQL
>    Affects Versions: 4.1.0
>            Reporter: Bruce Robbins
>            Assignee: Bruce Robbins
>            Priority: Major
>              Labels: pull-request-available
>
> Spark throws an error when aggregating a TIME value with a UDAF whose underlying aggregation buffer is an {{UnsafeRow}} (i.e., when {{UnsafeRow}} considers every field in the buffer schema mutable).
> Consider the following code:
> {noformat}
> import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
> import org.apache.spark.sql.functions.col
> import org.apache.spark.sql.types._
> import org.apache.spark.sql.Row
> class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
>   def inputSchema: StructType = schema
>   def bufferSchema: StructType = schema
>   def dataType: DataType = schema
>   def deterministic: Boolean = true
>   def initialize(buffer: MutableAggregationBuffer): Unit = {
>     (0 until schema.length).foreach { i =>
>       buffer.update(i, null)
>     }
>   }
>   def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     if (!input.isNullAt(0) && input.getInt(0) == 50) {
>       (0 until schema.length).foreach { i =>
>         buffer.update(i, input.get(i))
>       }
>     }
>   }
>   def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
>     if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) {
>       (0 until schema.length).foreach { i =>
>         buffer1.update(i, buffer2.get(i))
>       }
>     }
>   }
>   def evaluate(buffer: Row): Any = {
>     Row.fromSeq(buffer.toSeq)
>   }
> }
> import scala.util.Random
> import java.time.LocalTime
> val r = new Random(65676563L)
> val data = Seq.tabulate(50) { x =>
>   Row((x + 1).toInt, (x + 2).toDouble, (x + 2).toLong, LocalTime.parse("23:33:33.123").minusMinutes(x % 1300 + 1))
> }
> val schema = StructType.fromDDL("id int, col1 double, col2 bigint, col3 time")
> val rdd = spark.sparkContext.parallelize(data, 1)
> val df = spark.createDataFrame(rdd, schema)
> val udaf = new ScalaAggregateFunction(df.schema)
> val allColumns = df.schema.fields.map(f => col(f.name))
> df.groupBy().agg(udaf(allColumns: _*)).show(false)
> {noformat}
> It fails with this error:
> {noformat}
> warning: 1 deprecation (since 2.13.0); for details, enable `:setting -deprecation` or `:replay -deprecation`
> Exception in task 0.0 in stage 0.0 (TID 0)
> org.apache.spark.SparkUnsupportedOperationException: [UNSUPPORTED_CALL.WITHOUT_SUGGESTION] Cannot call the method "update" of the class "org.apache.spark.sql.catalyst.expressions.UnsafeRow".  SQLSTATE: 0A000
> {noformat}
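
The error says the generic {{update}} method was called on an {{UnsafeRow}}, which supports only typed setters. The following is a self-contained, hypothetical Scala model of that failure mode, not Spark code. The names {{TimeBufferSketch}}, {{FixedWidthRow}}, {{createSetter}} and {{writeAll}} are invented for illustration. The model assumes (a) the UDAF buffer writes each field through a per-type setter that falls back to a generic {{update}} for types without a dedicated case, and (b) a TIME value is held as a {{Long}} of nanoseconds since midnight.

```scala
// Hypothetical, simplified model: these are NOT Spark classes. It only shows
// how a per-type setter table can fall back to a generic update() that a
// fixed-width row rejects.
import java.time.LocalTime

object TimeBufferSketch {

  // Stand-ins for the field types in the repro's buffer schema.
  sealed trait FieldType
  case object IntField extends FieldType
  case object DoubleField extends FieldType
  case object LongField extends FieldType
  case object TimeField extends FieldType // assumption: TIME held as Long nanos of day

  // Fixed-width row that, like UnsafeRow, supports typed setters only.
  final class FixedWidthRow(numFields: Int) {
    private val slots = new Array[Long](numFields)
    def setInt(i: Int, v: Int): Unit = slots(i) = v.toLong
    def setLong(i: Int, v: Long): Unit = slots(i) = v
    def setDouble(i: Int, v: Double): Unit =
      slots(i) = java.lang.Double.doubleToRawLongBits(v)
    def getLong(i: Int): Long = slots(i)
    // Generic writes are rejected, mirroring the UNSUPPORTED_CALL error.
    def update(i: Int, v: Any): Unit =
      throw new UnsupportedOperationException(
        "Cannot call the method \"update\" of FixedWidthRow")
  }

  type Setter = (FixedWidthRow, Int, Any) => Unit

  // One setter per field type. With handleTime = false, TIME has no dedicated
  // case and falls through to the generic update().
  def createSetter(t: FieldType, handleTime: Boolean): Setter = t match {
    case IntField    => (row, i, v) => row.setInt(i, v.asInstanceOf[Int])
    case DoubleField => (row, i, v) => row.setDouble(i, v.asInstanceOf[Double])
    case LongField   => (row, i, v) => row.setLong(i, v.asInstanceOf[Long])
    case TimeField if handleTime =>
      (row, i, v) => row.setLong(i, v.asInstanceOf[Long])
    case _ => (row, i, v) => row.update(i, v)
  }

  // Writes one value per field (mirroring id, col1, col2, col3 in the repro);
  // returns the error message if any setter throws.
  def writeAll(handleTime: Boolean): Either[String, FixedWidthRow] = {
    val types: Seq[FieldType] = Seq(IntField, DoubleField, LongField, TimeField)
    val values: Seq[Any] =
      Seq(50, 51.0, 51L, LocalTime.parse("23:33:33.123").toNanoOfDay)
    val row = new FixedWidthRow(types.length)
    try {
      types.zip(values).zipWithIndex.foreach { case ((t, v), i) =>
        createSetter(t, handleTime)(row, i, v)
      }
      Right(row)
    } catch {
      case e: UnsupportedOperationException => Left(e.getMessage)
    }
  }

  def main(args: Array[String]): Unit = {
    println(writeAll(handleTime = false))
    println(writeAll(handleTime = true).map(_.getLong(3)))
  }
}
```

Under these assumptions, giving TIME its own typed setter (the {{handleTime = true}} path) avoids the generic call; this issue does not say whether the actual fix takes that form.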



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
