[ https://issues.apache.org/jira/browse/SPARK-2104?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14045504#comment-14045504 ]
Reynold Xin commented on SPARK-2104: ------------------------------------ BTW I have some old code I wrote -- you can do your changes based on this {code} /** * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * * Note that the actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. */ class RangePartitioner[K : Ordering : ClassTag, V]( var partitions: Int, @transient rdd: RDD[_ <: Product2[K,V]], private val ascending: Boolean = true) extends Partitioner { private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions var rangeBounds: Array[K] = { if (partitions == 1) { Array() } else { val rddSize = rdd.count() val maxSampleSize = partitions * 20.0 val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted if (rddSample.length == 0) { Array() } else { val bounds = new Array[K](partitions - 1) for (i <- 0 until partitions - 1) { val index = (rddSample.length - 1) * (i + 1) / partitions bounds(i) = rddSample(index) } bounds } } } @throws(classOf[IOException]) private def writeObject(out: ObjectOutputStream): Unit = { val sfactory = SparkEnv.get.serializer // Treat java serializer with default action rather than going thru serialization, to avoid a // separate serialization header. 
sfactory match { case js: JavaSerializer => out.defaultWriteObject() case _ => out.writeInt(partitions) val ser = sfactory.newInstance() Utils.serializeViaNestedStream(out, ser) { stream => stream.writeObject(ordering) stream.writeObject(scala.reflect.classTag[K]) stream.writeObject(rangeBounds) } } } @throws(classOf[IOException]) private def readObject(in: ObjectInputStream): Unit = { val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => in.defaultReadObject() case _ => partitions = in.readInt() val ser = sfactory.newInstance() Utils.deserializeViaNestedStream(in, ser) { ds => println(ds) ordering = ds.readObject[Ordering[K]]() implicit val classTag = ds.readObject[ClassTag[Array[K]]]() rangeBounds = ds.readObject[Array[K]]()(classTag) binarySearch = CollectionsUtils.makeBinarySearch[K] } } } def numPartitions = rangeBounds.length + 1 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] def getPartition(key: Any): Int = { val k = key.asInstanceOf[K] var partition = 0 if (rangeBounds.length < 1000) { // If we have fewer than 1000 range bounds, use naive linear search while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) { partition += 1 } } else { // Determine which binary search method to use only once. 
partition = binarySearch(rangeBounds, k) // binarySearch either returns the match location or -[insertion point]-1 if (partition < 0) { partition = -partition-1 } if (partition > rangeBounds.length) { partition = rangeBounds.length } } if (ascending) { partition } else { rangeBounds.length - partition } } override def equals(other: Any): Boolean = other match { case r: RangePartitioner[_,_] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false } override def hashCode(): Int = { val prime = 31 var result = 1 var i = 0 while (i < rangeBounds.length) { result = prime * result + rangeBounds(i).hashCode i += 1 } result = prime * result + ascending.hashCode result } } {code} > RangePartitioner should use user specified serializer to serialize range > bounds > ------------------------------------------------------------------------------- > > Key: SPARK-2104 > URL: https://issues.apache.org/jira/browse/SPARK-2104 > Project: Spark > Issue Type: Bug > Reporter: Reynold Xin > > Otherwise it is pretty annoying to do a sort on types that are not java > serializable. > To reproduce, just set the serializer to Kryo, and run the following job: > {code} > class JavaNonSerializableClass extends Comparable { override def compareTo(o: > JavaNonSerializableClass) = 0 } > sc.parallelize(Seq(new JavaNonSerializableClass, new > JavaNonSerializableClass), 2).map(x => (x,x)).sortByKey() > {code} > Basically the partitioner will always be serialized using Java (by the task > closure serializer). However, the rangeBounds variable in RangePartitioner > should be serialized with the user specified serializer. -- This message was sent by Atlassian JIRA (v6.2#6252)