Hi Yin,
I have written a simple UDAF that generates N samples for each group, using the
reservoir sampling algorithm. In this case the input data type doesn't matter,
since I am not doing any processing on the input values: I just select them at
random, build an array, and return that array. Later I use explode to convert
the array back into rows.
Below are my UDAF and test suite.
========= UDAF ========
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.util.Random

/**
 * Created by ragrawal on 9/23/15.
 */
class ReservoirSampling(k: Int, seed: Long = Random.nextLong()) extends UserDefinedAggregateFunction {

  // Random number generator created once per instance, so successive calls get different draws
  private val rng = new Random(seed)

  // Schema you get as an input
  def inputSchema = StructType(StructField("id", StringType) :: Nil)

  // Schema of the row which is used for aggregation
  def bufferSchema = StructType(
    StructField("ids", ArrayType(StringType, containsNull = true)) ::
    StructField("count", LongType) :: Nil
  )

  // Returned type
  def dataType: DataType = ArrayType(StringType, containsNull = false)

  // Self-explaining
  def deterministic = true

  // zero value
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = Seq[String]()
    buffer(1) = 0L
  }

  // Similar to seqOp in aggregate
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    update(buffer, input.getAs[String](0))
  }

  def update(buffer: MutableAggregationBuffer, item: String): Unit = {
    if (item != null) {
      val ids: Seq[String] = buffer.getSeq[String](0)
      val count = buffer.getLong(1) // number of items seen so far, not counting this one
      if (ids.length < k) { // fill reservoir
        buffer(0) = ids :+ item
      } else {
        // this is the (count + 1)-th item: keep it with probability k / (count + 1)
        val idx = rng.nextInt((count + 1).toInt)
        if (idx < k) { // maintain reservoir
          buffer(0) = ids.updated(idx, item)
        }
      }
      buffer(1) = count + 1
    } else {
      throw new RuntimeException("Cannot handle null strings")
    }
  }

  // Similar to combOp in aggregate
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer2.getSeq[String](0).foreach(update(buffer1, _))
  }

  // Called on exit to get return value
  def evaluate(buffer: Row) = buffer.getSeq[String](0)
}
====== TEST SUITE ====
test("basic test of reservoir sample") {
val rsampling = new ReservoirSampling(10, 10L)
sqlContext.udf.register("reservoir ", rsampling)
val data = sc.parallelize(1 to 1000, 1)
val schema = new StructType(Array(
StructField("key", StringType, nullable = false)
))
val df = sqlContext
.createDataFrame(data.map{x:Int => Row(x.toString)}, schema)
.select("key")
.withColumn("g", org.apache.spark.sql.functions.lit(10))
.distinct
.groupBy("g")
.agg(rsampling(col("key")).as("keys"))
.explode("skey", "key"){keys: Seq[String] => keys}
.select("g", "key")
.show()
}
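Also, regarding the workaround you suggested below (defining one UDAF per input
type and invoking the one that matches the argument), here is a rough, untested
sketch of how I am thinking of dispatching on the column's DataType.
ReservoirSamplingLong is just a hypothetical copy of the class above with
LongType in inputSchema/bufferSchema/dataType and getLong in place of
getAs[String]:
====== SKETCH (untested) ====
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction

// Pick a type-specific sampler based on the column's DataType
def reservoirSampleFor(dt: DataType, k: Int, seed: Long): UserDefinedAggregateFunction = dt match {
  case StringType => new ReservoirSampling(k, seed)
  case LongType   => new ReservoirSamplingLong(k, seed) // hypothetical LongType variant
  case other      => throw new IllegalArgumentException(s"Unsupported input type: $other")
}

// usage:
//   val sampler = reservoirSampleFor(df.schema("key").dataType, 10, 10L)
//   df.groupBy("g").agg(sampler(col("key")).as("keys"))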
On Fri, Sep 25, 2015 at 9:35 AM, Yin Huai <[email protected]> wrote:
> Hi Ritesh,
>
> Right now, we only allow specific data types defined in the inputSchema.
> Supporting abstract types (e.g. NumericType) may cause the logic of a UDAF
> to be more complex. It would be great to understand the use cases first. What
> kinds of input data types do you want to support, and do you need to know
> the actual argument types to determine how to process the input data?
>
> btw, for now, one possible workaround is to define multiple UDAFs for
> different input types. Then, based on arguments that you have, you invoke
> the corresponding UDAF.
>
> Thanks,
>
> Yin
>
> On Fri, Sep 25, 2015 at 8:07 AM, Ritesh Agrawal <
> [email protected]> wrote:
>
>> Hi all,
>>
>> I am trying to learn about UDAFs and implemented a simple reservoir sample
>> UDAF. It's working fine. However, I am not able to figure out what DataType
>> I should use so that it can deal with all DataTypes (simple and complex).
>> For instance, currently I have defined my input schema as
>>
>> def inputSchema = StructType(StructField("id", StringType) :: Nil )
>>
>>
>> Instead of StringType, can I use some other data type that is a superclass
>> of all the DataTypes?
>>
>> Ritesh
>>
>
>