Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/16383#discussion_r93725196 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala --- @@ -143,15 +197,96 @@ case class TypedAggregateExpression( } } - override def toString: String = { - val input = inputDeserializer match { - case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString - case Some(deserializer) => deserializer.dataType.simpleString - case _ => "unknown" + override def withInputInfo( + deser: Expression, + cls: Class[_], + schema: StructType): TypedAggregateExpression = { + copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) + } +} + +case class ComplexTypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + inputDeserializer: Option[Expression], + inputClass: Option[Class[_]], + inputSchema: Option[StructType], + bufferSerializer: Seq[NamedExpression], + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + dataType: DataType, + nullable: Boolean, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { + + override def deterministic: Boolean = true + + override def children: Seq[Expression] = inputDeserializer.toSeq + + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved + + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) + + override def createAggregationBuffer(): Any = aggregator.zero + + private lazy val inputRowToObj = GenerateSafeProjection.generate(inputDeserializer.get :: Nil) + + override def update(buffer: Any, input: InternalRow): Any = { + val inputObj = inputRowToObj(input).get(0, ObjectType(classOf[Any])) + if (inputObj != null) { + aggregator.reduce(buffer, inputObj) + } else { + buffer + } + } + + override def merge(buffer: 
Any, input: Any): Any = { + aggregator.merge(buffer, input) + } + + private lazy val resultObjToRow = dataType match { + case _: StructType => + UnsafeProjection.create(CreateStruct(outputSerializer)) + case _ => + assert(outputSerializer.length == 1) + UnsafeProjection.create(outputSerializer.head) + } + + override def eval(buffer: Any): Any = { + val resultObj = aggregator.finish(buffer) + if (resultObj == null) { + null + } else { + resultObjToRow(InternalRow(resultObj)).get(0, dataType) } + } - s"$nodeName($input)" + private lazy val bufferObjToRow = UnsafeProjection.create(bufferSerializer) + + override def serialize(buffer: Any): Array[Byte] = { + bufferObjToRow(InternalRow(buffer)).getBytes } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + private lazy val bufferRow = new UnsafeRow(bufferSerializer.length) + private lazy val bufferRowToObject = GenerateSafeProjection.generate(bufferDeserializer :: Nil) + + override def deserialize(storageFormat: Array[Byte]): Any = { + bufferRow.pointTo(storageFormat, storageFormat.length) + bufferRowToObject(bufferRow).get(0, ObjectType(classOf[Any])) + } + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int): ComplexTypedAggregateExpression = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset( + newInputAggBufferOffset: Int): ComplexTypedAggregateExpression = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withInputInfo( + deser: Expression, + cls: Class[_], + schema: StructType): TypedAggregateExpression = { + copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) --- End diff -- Where do we need to use `inputClass`? `TypedAggregateExpression` has this parameter but I don't see it used anywhere.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org