Wenchen,
Thank you very much for your prompt reply and pointer!
As I think it through, it makes sense that since my custom RDD is instantiated on
the driver, I should get whatever I need from the SparkContext there and assign it
to instance variables.
However, `RDD.sparkContext` and the Scala magic of class variables did not work
for me.
Here's what worked, based on your tip:
class MyDataSourceRDD(sc: SparkContext, conf: Map[String, String]) extends RDD[Row](sc, Nil) {

  // SparkConf is serializable, so it is safe to keep it as an instance variable;
  // the SparkContext itself is only touched here, while the RDD is being
  // constructed on the driver.
  val sparkConf = sc.getConf

  override def getPartitions: Array[org.apache.spark.Partition] = {
    sparkConf.getAll.foreach(println)
    val numPartitions = conf.getOrElse("spark.mydata.numpartitions", "1").toInt
    val rowsPerPartition = conf.getOrElse("spark.mydata.rowsperpartition", "3").toInt
    val partitions = 0 until numPartitions map(partition => new MyDataSourcePartition(partition, rowsPerPartition))
    partitions.toArray
  }

  override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
    val myDataSourcePartition = split.asInstanceOf[MyDataSourcePartition]
    val partitionId = myDataSourcePartition.index
    val rows = myDataSourcePartition.rowCount
    val partitionData = 1 to rows map(r => Row(s"Partition: ${partitionId}, row ${r} of ${rows}"))
    partitionData.iterator
  }
}
From: Wenchen Fan <[email protected]>
Date: Wednesday, February 28, 2018 at 12:25 PM
To: "Thakrar, Jayesh" <[email protected]>
Cc: "[email protected]" <[email protected]>
Subject: Re: SparkContext - parameter for RDD, but not serializable, why?
My understanding:
An RDD is also a driver-side object, like SparkContext; it works as a handle to
your distributed data on the cluster.
However, `RDD.compute` (which defines how to produce the data for each partition)
needs to be executed on the remote nodes. It's more convenient to make the RDD
serializable, transfer the RDD object to the remote nodes, and call `compute`
there.
So you can use SparkContext inside an RDD, as long as you don't use it in methods
that are going to be executed on remote nodes, like `RDD.compute`. And you should
call `RDD.sparkContext` instead of using `sc` directly, because using `sc` directly
turns it from a constructor parameter into a class member variable (kind of a hacky
part of Scala), which will then be serialized.
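For illustration, here is a minimal, hypothetical sketch of that distinction (the
class names SketchRDD and SketchPartition are made up and are not part of the code
in this thread):

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// Hypothetical trivial partition, just for the sketch.
class SketchPartition(val index: Int) extends Partition

class SketchRDD(sc: SparkContext) extends RDD[Row](sc, Nil) {

  // Driver-side method: using the inherited `sparkContext` accessor here is fine,
  // because RDD keeps that reference transient, so it is never shipped to the
  // executors.
  override def getPartitions: Array[Partition] = {
    val appName = sparkContext.appName   // runs on the driver only
    println(s"planning partitions for $appName")
    Array[Partition](new SketchPartition(0))
  }

  // Executor-side method: do NOT mention the constructor parameter `sc` here (or
  // anywhere else in the class body) - that quietly turns `sc` into a member
  // variable, and the non-serializable SparkContext would then be serialized
  // together with the RDD.
  override def compute(split: Partition, context: TaskContext): Iterator[Row] =
    Iterator(Row(s"row from partition ${split.index}"))
}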
On Thu, Mar 1, 2018 at 2:03 AM, Thakrar, Jayesh <[email protected]> wrote:
Hi All,
I was just toying with creating a very rudimentary RDD data source to understand
the inner workings of RDDs.
It seems that one of the constructors for RDD has a parameter of type
SparkContext, but it (apparently) exists on the driver only and is not
serializable.
Consequently, any attempt to use the SparkContext parameter inside your custom RDD
generates a runtime error about it not being serializable.
Just wondering, what is the rationale behind this?
I.e. if it is not serializable/usable, why make it a parameter?
And if it needs to be a parameter, why not make it serializable (is that even
possible)?
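To make the question concrete, here is a trimmed-down, hypothetical sketch of the
failing pattern (the class names FailingRDD and OnePartition are made up; my
actual code, with the offending line commented out, is further below):

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// Hypothetical single partition, just for the sketch.
class OnePartition(val index: Int) extends Partition

class FailingRDD(sc: SparkContext) extends RDD[Row](sc, Nil) {

  override def getPartitions: Array[Partition] = {
    // This is the kind of line that triggers the runtime error: the job fails
    // with a NotSerializableException for org.apache.spark.SparkContext as soon
    // as the SparkContext constructor parameter is used inside the RDD.
    sc.getConf.getAll.foreach(println)
    Array[Partition](new OnePartition(0))
  }

  override def compute(split: Partition, context: TaskContext): Iterator[Row] =
    Iterator(Row(s"row from partition ${split.index}"))
}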
Below is my working code where I test a custom RDD.
scala> val mydata = spark.read.format("MyDataSourceProvider").load()
mydata: org.apache.spark.sql.DataFrame = [mydataStr: string]
scala> mydata.show(10, false)
+------------------------+
|mydataStr |
+------------------------+
|Partition: 0, row 1 of 3|
|Partition: 0, row 2 of 3|
|Partition: 0, row 3 of 3|
+------------------------+
scala>
///// custom RDD
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

class MyDataSourceProvider extends DataSourceRegister
  with RelationProvider with Logging {

  override def shortName(): String = { "mydata" }

  private val myDataSchema: StructType = new StructType(Array[StructField](
    new StructField("mydataStr", StringType, false)))

  def sourceSchema(sqlContext: SQLContext,
                   schema: Option[StructType],
                   providerName: String,
                   parameters: Map[String, String]): (String, StructType) = {
    (shortName(), schema.get)
  }

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    new MyDataRelation(sqlContext, myDataSchema, parameters)
  }
}
class MyDataRelation(override val sqlContext: SQLContext,
                     override val schema: StructType,
                     params: Map[String, String])
  extends BaseRelation with TableScan with Logging {

  override def buildScan(): org.apache.spark.rdd.RDD[Row] = {
    val rdd = new MyDataSourceRDD(sqlContext.sparkContext, sqlContext.getAllConfs)
    rdd
  }

  override def needConversion = true
}
class MyDataSourceRDD(sc: SparkContext, conf: Map[String, String]) extends RDD[Row](sc, Nil) {

  override def getPartitions: Array[org.apache.spark.Partition] = {
    // sc.getConf.getAll.foreach(println) - this fails with a SparkContext not
    // serializable error. So what use is this parameter?!
    val numPartitions = conf.getOrElse("spark.mydata.numpartitions", "1").toInt
    val rowsPerPartition = conf.getOrElse("spark.mydata.rowsperpartition", "3").toInt
    val partitions = 0 until numPartitions map(partition => new MyDataSourcePartition(partition, rowsPerPartition))
    partitions.toArray
  }

  override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
    val myDataSourcePartition = split.asInstanceOf[MyDataSourcePartition]
    val partitionId = myDataSourcePartition.index
    val rows = myDataSourcePartition.rowCount
    val partitionData = 1 to rows map(r => Row(s"Partition: ${partitionId}, row ${r} of ${rows}"))
    partitionData.iterator
  }
}
class MyDataSourcePartition(partitionId: Int, rows: Int) extends Partition with Serializable {
  override def index: Int = partitionId
  def rowCount: Int = rows
}