First of all, I wanted to say that I am very happy to see
org.apache.spark.sql.expressions.Aggregator; it is a neat API, especially
when compared to the UDAF/AggregateFunction stuff.

Its doc comment says: "A base class for user-defined aggregations, which
can be used in [[DataFrame]] and [[Dataset]]".
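For anyone new to this API, the contract is just the four methods used in the attempt further down: zero creates an empty buffer, reduce folds one input into a buffer, merge combines partial buffers from different partitions, and finish turns the final buffer into the result. The sketch below models that contract in plain Scala with no Spark dependency. MiniAggregator and runPartitioned are hypothetical names for illustration only, not Spark classes, and "partitions" are simulated with ordinary Seqs.

```scala
// Plain-Scala model of the zero/reduce/merge/finish contract.
// MiniAggregator is hypothetical; it only mirrors the shape of
// org.apache.spark.sql.expressions.Aggregator, not its implementation.
trait MiniAggregator[-I, B, O] {
  def zero: B                  // empty buffer
  def reduce(b: B, a: I): B    // fold one input into a buffer
  def merge(b1: B, b2: B): B   // combine two partial buffers
  def finish(b: B): O          // turn the final buffer into the result
}

object MiniAggregatorDemo {
  // Hypothetical helper simulating distributed execution: reduce each
  // partition separately, then merge the partial buffers and finish.
  def runPartitioned[I, B, O](agg: MiniAggregator[I, B, O], partitions: Seq[Seq[I]]): O = {
    val partials = partitions.map(_.foldLeft(agg.zero)(agg.reduce))
    agg.finish(partials.foldLeft(agg.zero)(agg.merge))
  }

  val simpleSum = new MiniAggregator[Int, Int, Int] {
    def zero: Int = 0
    def reduce(b: Int, a: Int): Int = b + a
    def merge(b1: Int, b2: Int): Int = b1 + b2
    def finish(b: Int): Int = b
  }

  def main(args: Array[String]): Unit = {
    // The values 1 to 3 split across two simulated partitions.
    println(runPartitioned(simpleSum, Seq(Seq(1, 2), Seq(3)))) // 6
  }
}
```

Because merge must give the same answer regardless of how inputs were split, the result here is the same as summing 1 to 3 in a single partition.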

It works well with Dataset/GroupedDataset, but I am having no luck using it
with DataFrame/GroupedData. Does anyone have an example of how to use it with
a DataFrame?

In particular, I would like to use it with this method in GroupedData:
  def agg(expr: Column, exprs: Column*): DataFrame

Clearly it should be possible, since GroupedDataset uses that very same
method to do the work:
  private def agg(exprs: Column*): DataFrame =
    groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)

The trick seems to be the wrapping in withEncoder, which is private. I
tried to do something like it myself, spending my usual daily 30 minutes
of getting around private restrictions in Spark on this, but I had no luck,
since it relies on more private stuff in TypedColumn. Also, this Column/Catalyst
stuff makes me instantly sleepy, so I didn't try too hard.
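One way to picture what the wrapping has to provide: a typed aggregator only knows how to consume values of its input type, while a DataFrame hands it untyped rows, so something must say how to turn each row into that input type. In the Dataset path, the private withEncoder step presumably supplies that binding. The sketch below is only a plain-Scala analogy of the idea, not Spark's implementation: Row is modelled as a Map[String, Any], and MiniAggregator, bindInput (a contramap on the input type) and groupAndAggregate are hypothetical helpers. It mimics df.groupBy("k").agg(simpleSum) on the same three rows.

```scala
// Analogy only: a typed aggregator cannot consume untyped rows until it is
// told how to extract its input from each row. Row is modelled as a Map.
trait MiniAggregator[-I, B, O] {
  def zero: B
  def reduce(b: B, a: I): B
  def merge(b1: B, b2: B): B
  def finish(b: B): O
}

object BindInputDemo {
  type Row = Map[String, Any]

  // Hypothetical helper: adapt an aggregator over I into one over R by
  // supplying an extraction function (a contramap on the input type).
  def bindInput[R, I, B, O](agg: MiniAggregator[I, B, O], extract: R => I): MiniAggregator[R, B, O] =
    new MiniAggregator[R, B, O] {
      def zero: B = agg.zero
      def reduce(b: B, r: R): B = agg.reduce(b, extract(r))
      def merge(b1: B, b2: B): B = agg.merge(b1, b2)
      def finish(b: B): O = agg.finish(b)
    }

  // Hypothetical helper: group rows by a key column and aggregate each group
  // in a single simulated partition (so merge is never called here).
  def groupAndAggregate[B, O](rows: Seq[Row], key: String, agg: MiniAggregator[Row, B, O]): Map[Any, O] =
    rows.groupBy(_(key)).map { case (k, group) =>
      k -> agg.finish(group.foldLeft(agg.zero)(agg.reduce))
    }

  val simpleSum = new MiniAggregator[Int, Int, Int] {
    def zero: Int = 0
    def reduce(b: Int, a: Int): Int = b + a
    def merge(b1: Int, b2: Int): Int = b1 + b2
    def finish(b: Int): Int = b
  }

  def main(args: Array[String]): Unit = {
    // Same data as the DataFrame attempt: (k, v) = (i, i) for i in 1 to 3.
    val rows: Seq[Row] = (1 to 3).map(i => Map[String, Any]("k" -> i, "v" -> i))
    // Bind the typed Int aggregator to rows by reading column "v".
    val rowSum = bindInput[Row, Int, Int, Int](simpleSum, _("v").asInstanceOf[Int])
    println(groupAndAggregate(rows, "k", rowSum).toSeq.sortBy(_._1.asInstanceOf[Int]))
  }
}
```

The unbound simpleSum has no way to read a Row at all; only the bound rowSum can be handed to the grouping step, which is the gap the DataFrame attempt below seems to run into.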

Anyhow, here is my attempt at using it with a DataFrame:

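// Spark 1.6; SqlAggregator is assumed to be an import alias for Aggregator
import org.apache.spark.sql.expressions.{Aggregator => SqlAggregator}
import sqlContext.implicits._  // toDF and the Int encoders that toColumn needs

val simpleSum = new SqlAggregator[Int, Int, Int] {
  def zero: Int = 0                     // The initial value.
  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total.
  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
  def finish(b: Int) = b                // Return the final result.
}.toColumn

val df = sc.makeRDD(1 to 3).map(i => (i, i)).toDF("k", "v")
df.groupBy("k").agg(simpleSum).show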

And the resulting error:
org.apache.spark.sql.AnalysisException: unresolved operator 'Aggregate [k#104], [k#104,($anon$3(),mode=Complete,isDistinct=false) AS sum#106];
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:38)
  at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:46)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:241)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:50)
  at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:122)
  at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:50)
  at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:46)
  at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:34)
  at org.apache.spark.sql.DataFrame.<init>(DataFrame.scala:130)
  at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:49)

Best,
Koert
