[ https://issues.apache.org/jira/browse/SPARK-2104?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14045504#comment-14045504 ]

Reynold Xin commented on SPARK-2104:
------------------------------------

BTW, I have some old code I wrote -- you can base your changes on this:
{code}

/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * Note that the actual number of partitions created by the RangePartitioner might not be the
 * same as the `partitions` parameter, in the case where the number of sampled records is less
 * than the value of `partitions`.
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    var partitions: Int,
    @transient rdd: RDD[_ <: Product2[K,V]],
    private val ascending: Boolean = true)
  extends Partitioner {

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  var rangeBounds: Array[K] = {
    if (partitions == 1) {
      Array()
    } else {
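      // Sample at most roughly 20 keys per requested partition to estimate the key distribution.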
      val rddSize = rdd.count()
      val maxSampleSize = partitions * 20.0
      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
      if (rddSample.length == 0) {
        Array()
      } else {
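        // Take (partitions - 1) keys at evenly spaced positions in the sorted sample; these
        // become the upper bounds of the first (partitions - 1) ranges.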
        val bounds = new Array[K](partitions - 1)
        for (i <- 0 until partitions - 1) {
          val index = (rddSample.length - 1) * (i + 1) / partitions
          bounds(i) = rddSample(index)
        }
        bounds
      }
    }
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    val sfactory = SparkEnv.get.serializer
    // Treat the Java serializer with the default write action rather than going through custom
    // serialization, to avoid a separate serialization header.
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
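        // For any other serializer, write the partition count directly, then write the
        // ordering, class tag, and range bounds through the user-specified serializer.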
        out.writeInt(partitions)
        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(ordering)
          stream.writeObject(scala.reflect.classTag[K])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {

    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
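        // Mirror writeObject: read the partition count directly, then read the remaining
        // fields through the user-specified serializer and rebuild the binary search helper.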
        partitions = in.readInt()

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          ordering = ds.readObject[Ordering[K]]()
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()(classTag)
          binarySearch = CollectionsUtils.makeBinarySearch[K]
        }
    }
  }

  def numPartitions = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length < 1000) {
      // With fewer than 1000 range bounds, a naive linear search is fast enough.
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Use the binary search implementation that was chosen once, when binarySearch was initialized.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
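    // For descending sorts, mirror the index so the largest keys map to partition 0.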
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_,_] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }
}
{code}
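
For context, a minimal usage sketch of the class above -- the RDD contents, partition count, and surrounding SparkContext are made up for illustration, not part of the proposed change:
{code}
// Hypothetical pair RDD; any key type with an Ordering and ClassTag works.
val pairs = sc.parallelize(1 to 10000).map(i => (i, i.toString))

// Build the partitioner from a sample of the keys and repartition by range.
val partitioner = new RangePartitioner(4, pairs)
val partitioned = pairs.partitionBy(partitioner)

// With ascending = true, every key in partition i is <= every key in partition i + 1.
partitioned.mapPartitionsWithIndex { (idx, iter) =>
  Iterator((idx, iter.size))
}.collect()
{code}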

> RangePartitioner should use user specified serializer to serialize range bounds
> --------------------------------------------------------------------------------
>
>                 Key: SPARK-2104
>                 URL: https://issues.apache.org/jira/browse/SPARK-2104
>             Project: Spark
>          Issue Type: Bug
>            Reporter: Reynold Xin
>
> Otherwise it is pretty annoying to do a sort on types that are not Java-serializable.
> To reproduce, just set the serializer to Kryo, and run the following job:
> {code}
> class JavaNonSerializableClass extends Comparable[JavaNonSerializableClass] {
>   override def compareTo(o: JavaNonSerializableClass) = 0
> }
> sc.parallelize(Seq(new JavaNonSerializableClass, new JavaNonSerializableClass), 2)
>   .map(x => (x, x)).sortByKey()
> {code}
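> A minimal sketch of the setup implied by "set the serializer to Kryo" (the master and app
> name are arbitrary, chosen only for illustration):
> {code}
> val conf = new SparkConf()
>   .setMaster("local[2]")
>   .setAppName("SPARK-2104-repro")
>   .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
> val sc = new SparkContext(conf)
> {code}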
> Basically the partitioner will always be serialized using Java (by the task
> closure serializer). However, the rangeBounds variable in RangePartitioner
> should be serialized with the user-specified serializer.


