Hi Yin, I have written a simple UDAF that generates N samples for each group, using the reservoir sampling algorithm. In this case the input data type doesn't matter, since I am not doing any processing on the input values; I just select them at random, build an array, and return that array. Later I use explode to convert the array back into rows. Below are my UDAF and test suite.
========= UDAF ========

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.util.Random

/**
 * Created by ragrawal on 9/23/15.
 */
class ReservoirSampling(k: Int, seed: Long = Random.nextLong()) extends UserDefinedAggregateFunction {

  // One RNG per instance; creating new Random(seed) on every row would
  // draw the same index each time.
  private val rng = new Random(seed)

  // Schema you get as an input
  def inputSchema = StructType(StructField("id", StringType) :: Nil)

  // Schema of the row which is used for aggregation
  def bufferSchema = StructType(
    StructField("ids", ArrayType(StringType, containsNull = true)) ::
      StructField("count", LongType) :: Nil)

  // Returned type
  def dataType: DataType = ArrayType(StringType, containsNull = false)

  // The output depends on random draws, so the function is not deterministic
  def deterministic = false

  // zero value
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = Seq[String]()
    buffer(1) = 0L // count is LongType, so initialize with a Long
  }

  // Similar to seqOp in aggregate
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    update(buffer, input.getAs[String](0))
  }

  def update(buffer: MutableAggregationBuffer, item: String): Unit = {
    if (item != null) {
      val ids: Seq[String] = buffer.getSeq(0)
      val count = buffer.getLong(1) // count is LongType: getLong, not getInt
      if (ids.length < k) {
        // fill reservoir
        buffer(0) = ids :+ item
      } else {
        // Algorithm R: this is the (count + 1)-th item and should enter the
        // reservoir with probability k / (count + 1), so draw from [0, count].
        // This resolves the earlier TODO: it is count + 1, not count.
        val idx = rng.nextInt((count + 1).toInt)
        if (idx < k) {
          // maintain reservoir: updated() returns a new Seq, so write it back
          buffer(0) = ids.updated(idx, item)
        }
      }
      buffer(1) = count + 1
    } else {
      throw new RuntimeException("Cannot handle null strings")
    }
  }

  // Similar to combOp in aggregate
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer2.getSeq[String](0).foreach(update(buffer1, _))
  }

  // Called on exit to get return value
  def evaluate(buffer: Row) = buffer.getSeq(0)
}

====== TEST SUITE ====

import org.apache.spark.sql.functions.{col, lit}

test("basic test of reservoir sample") {
  val rsampling = new ReservoirSampling(10, 10L)
  sqlContext.udf.register("reservoir", rsampling)

  val data = sc.parallelize(1 to 1000, 1)
  val schema = new StructType(Array(
    StructField("key", StringType, nullable = false)))

  val df = sqlContext
    .createDataFrame(data.map { x: Int => Row(x.toString) }, schema)
    .select("key")
    .withColumn("g", lit(10))
    .distinct
    .groupBy("g")
    .agg(rsampling(col("key")).as("keys"))
    .explode("keys", "key") { keys: Seq[String] => keys }
    .select("g", "key")
    .show()
}

On Fri, Sep 25, 2015 at 9:35 AM, Yin Huai <yh...@databricks.com> wrote:

> Hi Ritesh,
>
> Right now, we only allow specific data types defined in the inputSchema.
> Supporting abstract types (e.g. NumericType) may make the logic of a UDAF
> more complex. It will be great to understand the use cases first: what
> kinds of input data types do you want to support, and do you need to know
> the actual argument types to determine how to process the input data?
>
> btw, for now, one possible workaround is to define multiple UDAFs for
> different input types. Then, based on the arguments that you have, you
> invoke the corresponding UDAF.
>
> Thanks,
>
> Yin
>
> On Fri, Sep 25, 2015 at 8:07 AM, Ritesh Agrawal <ragra...@netflix.com.invalid> wrote:
>
>> Hi all,
>>
>> I am trying to learn about UDAFs and implemented a simple reservoir
>> sample UDAF. It's working fine. However, I am not able to figure out what
>> DataType I should use so that it can deal with all DataTypes (simple and
>> complex).
>> For instance, currently I have defined my input schema as
>>
>> def inputSchema = StructType(StructField("id", StringType) :: Nil)
>>
>> Instead of StringType, can I use some other data type that is a
>> superclass of all the DataTypes?
>>
>> Ritesh
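Following up on Yin's workaround quoted above (one UDAF per input type): since the sampler never inspects the values, the element type can simply be a constructor parameter, and you create one instance per column type you need. Below is a minimal, untested sketch of that idea; the name TypedReservoirSampling and its elemType parameter are my own for illustration, not Spark API.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.util.Random

// Sketch: same reservoir logic as the UDAF above, but the element DataType
// is a constructor parameter, so one instance per input type realizes Yin's
// "multiple UDAFs" workaround without duplicating the class.
class TypedReservoirSampling(elemType: DataType, k: Int, seed: Long = Random.nextLong())
    extends UserDefinedAggregateFunction {

  private val rng = new Random(seed)

  def inputSchema = StructType(StructField("id", elemType) :: Nil)

  def bufferSchema = StructType(
    StructField("ids", ArrayType(elemType, containsNull = true)) ::
      StructField("count", LongType) :: Nil)

  def dataType: DataType = ArrayType(elemType, containsNull = false)

  def deterministic = false

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = Seq.empty[Any]
    buffer(1) = 0L
  }

  // Values are carried opaquely as Any; no per-type processing is needed.
  def update(buffer: MutableAggregationBuffer, input: Row) = add(buffer, input.get(0))

  private def add(buffer: MutableAggregationBuffer, item: Any): Unit = {
    val ids = buffer.getSeq[Any](0)
    val count = buffer.getLong(1)
    if (ids.length < k) {
      buffer(0) = ids :+ item
    } else {
      val idx = rng.nextInt((count + 1).toInt)
      if (idx < k) buffer(0) = ids.updated(idx, item)
    }
    buffer(1) = count + 1
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) =
    buffer2.getSeq[Any](0).foreach(add(buffer1, _))

  def evaluate(buffer: Row) = buffer.getSeq[Any](0)
}

// e.g. one instance per column type you want to sample:
// val sampleStrings = new TypedReservoirSampling(StringType, 10, 10L)
// val sampleLongs = new TypedReservoirSampling(LongType, 10, 10L)

This keeps the aggregation logic type-agnostic while still giving Spark the concrete types it requires in inputSchema, which seems to be the spirit of Yin's suggestion.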