Hi Yin,

I have written a simple UDAF that draws N samples for each group, using the
reservoir sampling algorithm. In this case the input data type doesn't really
matter, since I don't do any processing on the input values: I just select
some of them at random, build an array, and return that array. Later I use
explode to convert the array back into rows. Below are my UDAF and test suite.
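
Before the Spark code, to make the algorithm concrete, here is a minimal
sketch of just the reservoir sampling step on its own (plain Scala, no Spark;
the function name is only a placeholder, not part of the UDAF below):

import scala.util.Random

// Keep a uniform sample of up to k items from a stream of unknown length.
def reservoirSample[T](items: Iterator[T], k: Int, seed: Long): Seq[T] = {
  val rng = new Random(seed)
  val reservoir = scala.collection.mutable.ArrayBuffer.empty[T]
  var count = 0
  items.foreach { item =>
    if (reservoir.length < k) {
      reservoir += item                  // fill the reservoir first
    } else {
      val idx = rng.nextInt(count + 1)   // pick a slot in [0, count]
      if (idx < k) reservoir(idx) = item // replace with probability k / (count + 1)
    }
    count += 1
  }
  reservoir.toSeq
}

For example, reservoirSample(Iterator.range(0, 1000), 10, 42L) returns 10 of
the 1000 values, each retained with equal probability.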

========= UDAF ========
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.util.Random


/**
 * Created by ragrawal on 9/23/15.
 */
class ReservoirSampling(k: Int, seed: Long = Random.nextLong()) extends UserDefinedAggregateFunction {

  // Single random number generator, seeded once for the whole aggregation
  private val rng = new Random(seed)

  // Schema you get as an input
  def inputSchema = StructType(StructField("id", StringType) :: Nil)

  // Schema of the row which is used for aggregation
  def bufferSchema = StructType(
    StructField("ids", ArrayType(StringType, containsNull = true)) ::
    StructField("count", LongType) :: Nil
  )

  // Returned type
  def dataType: DataType = ArrayType(StringType, containsNull = false)

  // Sampling is driven entirely by the seed, so the function is treated as deterministic
  def deterministic = true

  // zero value
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = Seq[String]()
    buffer(1) = 0L
  }

  // Similar to seqOp in aggregate
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    update(buffer, input.getAs[String](0))
  }

  def update(buffer: MutableAggregationBuffer, item: String): Unit = {
    if (item != null) {
      val ids: Seq[String] = buffer.getSeq(0)
      val count = buffer.getLong(1)

      if (ids.length < k) { // fill reservoir
        buffer(0) = ids :+ item
      } else {
        // the current item is the (count + 1)-th seen, so draw an index in [0, count]
        val idx = rng.nextInt(count.toInt + 1)
        if (idx < k) { // maintain reservoir
          buffer(0) = ids.updated(idx, item)
        }
      }
      buffer(1) = count + 1
    } else {
      throw new RuntimeException("Cannot handle null strings")
    }
  }

  // Similar to combOp in aggregate
  // Note: re-feeding the other buffer's sample through update is only an
  // approximation of an exact reservoir merge (it ignores the other buffer's count)
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer2.getSeq[String](0).foreach(update(buffer1, _))
  }

  // Called on exit to get return value
  def evaluate(buffer: Row) = buffer.getSeq[String](0)

}

====== TEST SUITE ====
  test("basic test of reservoir sample") {

    val rsampling = new ReservoirSampling(10, 10L)
    sqlContext.udf.register("reservoir ", rsampling)


    val data = sc.parallelize(1 to 1000, 1)
    val schema = new StructType(Array(
                  StructField("key", StringType, nullable = false)
                ))

    val df = sqlContext
              .createDataFrame(data.map{x:Int => Row(x.toString)}, schema)
              .select("key")
              .withColumn("g", org.apache.spark.sql.functions.lit(10))
              .distinct
              .groupBy("g")
              .agg(rsampling(col("key")).as("keys"))
              .explode("skey", "key"){keys: Seq[String] => keys}
              .select("g", "key")
              .show()

  }




On Fri, Sep 25, 2015 at 9:35 AM, Yin Huai <yh...@databricks.com> wrote:

> Hi Ritesh,
>
> Right now, we only allow specific data types defined in the inputSchema.
> Supporting abstract types (e.g. NumericType) may cause the logic of a UDAF
> to be more complex. It would be great to understand the use cases first. What
> kinds of input data types do you want to support, and do you need to know
> the actual argument types to determine how to process the input data?
>
> btw, for now, one possible workaround is to define multiple UDAFs for
> different input types. Then, based on arguments that you have, you invoke
> the corresponding UDAF.
>
> Thanks,
>
> Yin
>
> On Fri, Sep 25, 2015 at 8:07 AM, Ritesh Agrawal <
> ragra...@netflix.com.invalid> wrote:
>
>> Hi all,
>>
>> I am trying to learn about UDAFs and implemented a simple reservoir sample
>> UDAF. It's working fine. However, I am not able to figure out what DataType
>> I should use so that it can deal with all DataTypes (simple and complex).
>> For instance currently I have defined my input schema as
>>
>>  def inputSchema = StructType(StructField("id", StringType) :: Nil )
>>
>>
>> Instead of StringType can I use some other data type that is superclass
>> of all the DataTypes ?
>>
>> Ritesh
>>
>
>
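
PS: for anyone else reading this thread, a rough sketch of the per-type
workaround Yin describes above could look like the following. Here
ReservoirSamplingLong and ReservoirSamplingDouble are hypothetical siblings
of the String version; the idea is just to pick the UDAF from the column's
data type at call time:

import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

def reservoirFor(dt: DataType, k: Int, seed: Long): UserDefinedAggregateFunction = dt match {
  case StringType => new ReservoirSampling(k, seed)       // the UDAF above
  case LongType   => new ReservoirSamplingLong(k, seed)   // hypothetical
  case DoubleType => new ReservoirSamplingDouble(k, seed) // hypothetical
  case other      => throw new IllegalArgumentException(s"Unsupported input type: $other")
}

// e.g.
// val sampler = reservoirFor(df.schema("key").dataType, 10, 10L)
// df.groupBy("g").agg(sampler(col("key")).as("keys"))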
