Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16383#discussion_r93768760
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala ---
    @@ -143,15 +197,96 @@ case class TypedAggregateExpression(
         }
       }
     
    -  override def toString: String = {
    -    val input = inputDeserializer match {
    -      case Some(UnresolvedDeserializer(deserializer, _)) => 
deserializer.dataType.simpleString
    -      case Some(deserializer) => deserializer.dataType.simpleString
    -      case _ => "unknown"
    +  override def withInputInfo(
    +      deser: Expression,
    +      cls: Class[_],
    +      schema: StructType): TypedAggregateExpression = {
    +    copy(inputDeserializer = Some(deser), inputClass = Some(cls), 
inputSchema = Some(schema))
    +  }
    +}
    +
    +case class ComplexTypedAggregateExpression(
    +    aggregator: Aggregator[Any, Any, Any],
    +    inputDeserializer: Option[Expression],
    +    inputClass: Option[Class[_]],
    +    inputSchema: Option[StructType],
    +    bufferSerializer: Seq[NamedExpression],
    +    bufferDeserializer: Expression,
    +    outputSerializer: Seq[Expression],
    +    dataType: DataType,
    +    nullable: Boolean,
    +    mutableAggBufferOffset: Int = 0,
    +    inputAggBufferOffset: Int = 0)
    +  extends TypedImperativeAggregate[Any] with TypedAggregateExpression with 
NonSQLExpression {
    +
    +  override def deterministic: Boolean = true
    +
    +  override def children: Seq[Expression] = inputDeserializer.toSeq
    +
    +  override lazy val resolved: Boolean = inputDeserializer.isDefined && 
childrenResolved
    +
    +  override def references: AttributeSet = 
AttributeSet(inputDeserializer.toSeq)
    +
    +  override def createAggregationBuffer(): Any = aggregator.zero
    +
    +  private lazy val inputRowToObj = 
GenerateSafeProjection.generate(inputDeserializer.get :: Nil)
    +
    +  override def update(buffer: Any, input: InternalRow): Any = {
    +    val inputObj = inputRowToObj(input).get(0, ObjectType(classOf[Any]))
    +    if (inputObj != null) {
    +      aggregator.reduce(buffer, inputObj)
    +    } else {
    +      buffer
    +    }
    +  }
    +
    +  override def merge(buffer: Any, input: Any): Any = {
    +    aggregator.merge(buffer, input)
    +  }
    +
    +  private lazy val resultObjToRow = dataType match {
    +    case _: StructType =>
    +      UnsafeProjection.create(CreateStruct(outputSerializer))
    +    case _ =>
    +      assert(outputSerializer.length == 1)
    +      UnsafeProjection.create(outputSerializer.head)
    +  }
    +
    +  override def eval(buffer: Any): Any = {
    +    val resultObj = aggregator.finish(buffer)
    +    if (resultObj == null) {
    +      null
    +    } else {
    +      resultObjToRow(InternalRow(resultObj)).get(0, dataType)
         }
    +  }
     
    -    s"$nodeName($input)"
    +  private lazy val bufferObjToRow = 
UnsafeProjection.create(bufferSerializer)
    +
    +  override def serialize(buffer: Any): Array[Byte] = {
    +    bufferObjToRow(InternalRow(buffer)).getBytes
       }
     
    -  override def nodeName: String = 
aggregator.getClass.getSimpleName.stripSuffix("$")
    +  private lazy val bufferRow = new UnsafeRow(bufferSerializer.length)
    +  private lazy val bufferRowToObject = 
GenerateSafeProjection.generate(bufferDeserializer :: Nil)
    +
    +  override def deserialize(storageFormat: Array[Byte]): Any = {
    +    bufferRow.pointTo(storageFormat, storageFormat.length)
    +    bufferRowToObject(bufferRow).get(0, ObjectType(classOf[Any]))
    +  }
    +
    +  override def withNewMutableAggBufferOffset(
    +      newMutableAggBufferOffset: Int): ComplexTypedAggregateExpression =
    +    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
    +
    +  override def withNewInputAggBufferOffset(
    +      newInputAggBufferOffset: Int): ComplexTypedAggregateExpression =
    +    copy(inputAggBufferOffset = newInputAggBufferOffset)
    +
    +  override def withInputInfo(
    +      deser: Expression,
    +      cls: Class[_],
    +      schema: StructType): TypedAggregateExpression = {
    +    copy(inputDeserializer = Some(deser), inputClass = Some(cls), 
inputSchema = Some(schema))
    --- End diff --
    
    OK, I see.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it to be, or if the feature is enabled but not
working, please contact infrastructure at infrastruct...@apache.org or file a
JIRA ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to