My understanding:

An RDD, like SparkContext, is a driver-side object: it works like a handle to
your distributed data on the cluster.

However, `RDD.compute` (which defines how to produce the data for each
partition) needs to be executed on the remote nodes. So it's more convenient
to make RDD serializable, transfer the RDD object to the remote nodes, and
call `compute` there.
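To make the "ship the object, call `compute` remotely" idea concrete, here is
a minimal pure-Scala sketch that needs no Spark. `ToyRDD`, `ShipDemo` and
`ship` are hypothetical names, and a plain Java-serialization round trip
stands in for what happens when a task is sent to an executor; Spark's real
task scheduling and serialization involve more machinery than this.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, ObjectStreamClass}

// Toy stand-in for an RDD (hypothetical): it holds only serializable state
// and knows how to produce the rows of one partition.
class ToyRDD(val numPartitions: Int, val rowsPerPartition: Int) extends Serializable {
  def compute(partition: Int): Iterator[String] =
    Iterator.tabulate(rowsPerPartition)(r =>
      s"Partition: $partition, row ${r + 1} of $rowsPerPartition")
}

object ShipDemo {
  // Serializes an object on the "driver" and deserializes a copy on the
  // "executor", roughly what happens when a task ships an RDD (simplified).
  def ship[T <: AnyRef](obj: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close() // close() flushes the buffered bytes
    val loader = getClass.getClassLoader
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray)) {
      // Resolve classes with this program's class loader so the copy can be
      // rebuilt even when running under a REPL or script runner.
      override protected def resolveClass(desc: ObjectStreamClass): Class[_] =
        try Class.forName(desc.getName, false, loader)
        catch { case _: ClassNotFoundException => super.resolveClass(desc) }
    }
    try in.readObject().asInstanceOf[T] finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val driverSide = new ToyRDD(numPartitions = 2, rowsPerPartition = 3)
    val executorSide = ship(driverSide) // a separate copy of the object
    println(executorSide ne driverSide) // true
    // compute runs on the deserialized copy, never on the driver's instance
    (0 until executorSide.numPartitions).foreach { p =>
      executorSide.compute(p).foreach(println)
    }
  }
}
```

Everything `compute` touches has to survive that round trip, which is why a
driver-only object must not end up in a serialized field.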

So you can use SparkContext inside an RDD, as long as you don't use it in
methods that are going to be executed on remote nodes, like `RDD.compute`.
And you should call `RDD.sparkContext` instead of using `sc` directly,
because referencing `sc` in a method body turns it from a plain constructor
parameter into a class member field (kind of a hacky part of Scala), and that
field will be serialized along with the RDD. `RDD` itself keeps its
SparkContext in a `@transient` field, so going through `sparkContext` avoids
the problem.
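Here is a minimal pure-Scala sketch of that constructor-parameter pitfall,
with no Spark dependency. `DriverOnlyContext` is a hypothetical stand-in for
SparkContext, and `Base` only mimics the way `RDD` keeps its context in a
`@transient` field; the two subclasses differ only in whether a method body
references the constructor parameter directly.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for SparkContext: a driver-only, non-serializable object.
class DriverOnlyContext { val appName = "demo" }

// Mimics RDD: the context lives in a @transient field, so Java
// serialization skips it.
abstract class Base(@transient private val ctx: DriverOnlyContext) extends Serializable {
  def context: DriverOnlyContext = ctx
}

// `c` is only passed to the superclass constructor, so Scala does not keep
// it as a field of UsesAccessor; the method goes through the accessor.
class UsesAccessor(c: DriverOnlyContext) extends Base(c) {
  def name: String = context.appName
}

// `c` is referenced inside a method body, so Scala compiles it into a
// private, non-transient field that serialization will try to write.
class UsesParam(c: DriverOnlyContext) extends Base(c) {
  def name: String = c.appName
}

object CaptureDemo {
  // True if the object can be written with Java serialization.
  def serializes(o: AnyRef): Boolean =
    try { new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(o); true }
    catch { case _: NotSerializableException => false }

  def main(args: Array[String]): Unit = {
    val ctx = new DriverOnlyContext
    println(serializes(new UsesAccessor(ctx))) // true
    println(serializes(new UsesParam(ctx)))    // false
  }
}
```

Applied to your `MyDataSourceRDD`, that should mean writing
`sparkContext.getConf.getAll.foreach(println)` in `getPartitions` (which runs
on the driver) instead of `sc.getConf.getAll.foreach(println)`.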

On Thu, Mar 1, 2018 at 2:03 AM, Thakrar, Jayesh <jthak...@conversantmedia.com> wrote:

> Hi All,
>
> I was just toying with creating a very rudimentary RDD datasource to
> understand the inner workings of RDDs.
>
> It seems that one of the constructors for RDD has a parameter of type
> SparkContext, but it (apparently) exists on the driver only and is not
> serializable.
>
> Consequently, any attempt to use the SparkContext parameter inside your
> custom RDD generates a runtime error saying it is not serializable.
>
> Just wondering, what is the rationale behind this?
>
> I.e., if it is not serializable/usable, why make it a parameter?
>
> And if it needs to be a parameter, why not make it serializable (is that
> even possible?)
>
> Below is my working code where I test a custom RDD.
>
> scala> val mydata = spark.read.format("MyDataSourceProvider").load()
> mydata: org.apache.spark.sql.DataFrame = [mydataStr: string]
>
> scala> mydata.show(10, false)
> +------------------------+
> |mydataStr               |
> +------------------------+
> |Partition: 0, row 1 of 3|
> |Partition: 0, row 2 of 3|
> |Partition: 0, row 3 of 3|
> +------------------------+
>
> scala>
>
> ///// custom RDD
>
> import org.apache.spark.internal.Logging
> import org.apache.spark.sql.{Row, SQLContext}
> import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, TableScan}
> import org.apache.spark.sql.types.{StringType, StructField, StructType}
> import org.apache.spark.{Partition, SparkContext, TaskContext}
> import org.apache.spark.rdd.RDD
>
> class MyDataSourceProvider extends DataSourceRegister
>   with RelationProvider with Logging {
>
>   override def shortName(): String = { "mydata" }
>
>   private val myDataSchema: StructType =
>     new StructType(Array[StructField](new StructField("mydataStr", StringType, false)))
>
>   def sourceSchema(sqlContext: SQLContext,
>                    schema: Option[StructType],
>                    providerName: String,
>                    parameters: Map[String, String]): (String, StructType) = {
>     (shortName(), schema.get)
>   }
>
>   override def createRelation(sqlContext: SQLContext,
>                               parameters: Map[String, String]): BaseRelation = {
>     new MyDataRelation(sqlContext, myDataSchema, parameters)
>   }
>
> }
>
> class MyDataRelation(override val sqlContext: SQLContext,
>                      override val schema: StructType,
>                      params: Map[String, String]) extends BaseRelation
>   with TableScan with Logging {
>
>   override def buildScan(): org.apache.spark.rdd.RDD[Row] = {
>     val rdd = new MyDataSourceRDD(sqlContext.sparkContext, sqlContext.getAllConfs)
>     rdd
>   }
>
>   override def needConversion = true
> }
>
> class MyDataSourceRDD(sc: SparkContext, conf: Map[String, String])
>   extends RDD[Row](sc, Nil) {
>
>   override def getPartitions: Array[org.apache.spark.Partition] = {
>     // sc.getConf.getAll.foreach(println) - this fails with SparkContext not
>     // serializable error. So what use is this parameter ?!
>     val numPartitions = conf.getOrElse("spark.mydata.numpartitions", "1").toInt
>     val rowsPerPartition = conf.getOrElse("spark.mydata.rowsperpartition", "3").toInt
>     val partitions = 0 until numPartitions map (partition =>
>       new MyDataSourcePartition(partition, rowsPerPartition))
>     partitions.toArray
>   }
>
>   override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
>     val myDataSourcePartition = split.asInstanceOf[MyDataSourcePartition]
>     val partitionId = myDataSourcePartition.index
>     val rows = myDataSourcePartition.rowCount
>     val partitionData = 1 to rows map (r =>
>       Row(s"Partition: ${partitionId}, row ${r} of ${rows}"))
>     partitionData.iterator
>   }
>
> }
>
> class MyDataSourcePartition(partitionId: Int, rows: Int) extends Partition
>   with Serializable {
>
>   override def index: Int = partitionId
>
>   def rowCount: Int = rows
> }
