
shijinkui commented on SPARK-6932:

Hi, @Xiangrui Meng. I have an idea: each training task could keep a persistent 
connection to the parameter server, using Akka Streams or Netty directly. What 
do you think?

> A Prototype of Parameter Server
> -------------------------------
>                 Key: SPARK-6932
>                 URL: https://issues.apache.org/jira/browse/SPARK-6932
>             Project: Spark
>          Issue Type: New Feature
>          Components: ML, MLlib, Spark Core
>            Reporter: Qiping Li
> h2. Introduction
> As specified in 
> [SPARK-4590|https://issues.apache.org/jira/browse/SPARK-4590], it would be 
> very helpful to integrate a parameter server into Spark for machine learning 
> algorithms, especially those with ultra-high-dimensional features. 
> After carefully studying the design doc of [Parameter 
> Servers|https://docs.google.com/document/d/1SX3nkmF41wFXAAIr9BgqvrHSS5mW362fJ7roBXJm06o/edit?usp=sharing] 
> and the [Factorbird|http://stanford.edu/~rezab/papers/factorbird.pdf] paper, 
> we propose a prototype of a parameter server on Spark (PS-on-Spark), with 
> several key design concerns:
> * *User friendly interface*
>       We carefully investigated most existing parameter server systems 
> (including [petuum|http://petuum.github.io], [parameter 
> server|http://parameterserver.org], and 
> [paracel|https://github.com/douban/paracel]) and designed a user-friendly 
> interface that draws on the best ideas from each of them. 
> * *Prototype of distributed array*
>     IndexedRDD (see 
> [SPARK-4590|https://issues.apache.org/jira/browse/SPARK-4590]) doesn't seem 
> to be a good option for a distributed array, because in most cases the number 
> of key updates per second it achieves is not very high. 
>     So we implemented a distributed HashMap to store the parameters, which can 
> be easily extended to get better performance.
> * *Minimal code change*
>       Considerable effort went into avoiding changes to Spark core. Tasks 
> that need the parameter server are still created and scheduled by Spark's 
> scheduler. Tasks communicate with the parameter server through a client 
> object, over *akka* or *netty*.
> With these concerns in mind, we propose the following architecture:
> h2. Architecture
> !https://cloud.githubusercontent.com/assets/1285855/7158179/f2d25cc4-e3a9-11e4-835e-89681596c478.jpg!
> Data is stored in an RDD and partitioned across workers. During each 
> iteration, each worker gets parameters from the parameter server, then 
> computes new parameters based on the old parameters and the data in its 
> partition. Finally, each worker pushes its updates to the parameter server. 
> A worker communicates with the parameter server through a parameter server 
> client, which is initialized in the `TaskContext` of that worker.
> The current implementation is based on YARN cluster mode, 
> but it should not be a problem to port it to other modes. 
> h3. Interface
> We referred to existing parameter server systems (petuum, parameter server, 
> paracel) when designing the interface of the parameter server. 
> *`PSClient` provides the following interface for workers to use:*
> {code}
> //  get parameter indexed by key from parameter server
> def get[T](key: String): T
> // get multiple parameters from parameter server
> def multiGet[T](keys: Array[String]): Array[T]
> // add `delta` to the parameter indexed by `key`;
> // if there are multiple `delta`s for the same parameter,
> // use `reduceFunc` to reduce these `delta`s first.
> def update[T](key: String, delta: T, reduceFunc: (T, T) => T): Unit
> // update multiple parameters at the same time, using the same `reduceFunc`.
> def multiUpdate[T](keys: Array[String], delta: Array[T], 
>     reduceFunc: (T, T) => T): Unit
> // advance clock to indicate that current iteration is finished.
> def clock(): Unit
> // block until all workers have reached this line of code.
> def sync(): Unit
> {code}
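> As a rough illustration of the `update`/`clock` semantics described in the 
> comments above, here is a minimal single-process sketch. It is not the 
> prototype's code: `LocalPSClient`, its fixed value type `T`, and the choice 
> to buffer deltas until `clock()` and then fold them into the stored value 
> with the same `reduceFunc` are all assumptions made for illustration.
> {code}
> import scala.collection.mutable
> 
> // Hypothetical in-memory stand-in for a parameter server client.
> class LocalPSClient[T](store: mutable.Map[String, T]) {
>   // Pending (merged delta, reduce function) per key for the current iteration.
>   private val pending = mutable.Map.empty[String, (T, (T, T) => T)]
> 
>   def get(key: String): T = store(key)
> 
>   // Merge `delta` with any earlier delta for the same key using `reduceFunc`.
>   def update(key: String, delta: T, reduceFunc: (T, T) => T): Unit = {
>     val merged = pending.get(key) match {
>       case Some((prev, _)) => reduceFunc(prev, delta)
>       case None            => delta
>     }
>     pending(key) = (merged, reduceFunc)
>   }
> 
>   // Assumption: at the end of an iteration the merged delta is folded
>   // into the stored value with the same reduce function.
>   def clock(): Unit = {
>     for ((key, (delta, f)) <- pending) {
>       store(key) = store.get(key).map(old => f(old, delta)).getOrElse(delta)
>     }
>     pending.clear()
>   }
> }
> 
> object LocalPSClientDemo {
>   def main(args: Array[String]): Unit = {
>     val store = mutable.Map("w" -> 1.0)
>     val client = new LocalPSClient[Double](store)
>     client.update("w", 0.5, _ + _)
>     client.update("w", 0.25, _ + _)
>     println(client.get("w")) // still the old value: deltas are only buffered
>     client.clock()
>     println(client.get("w")) // old value plus both deltas
>   }
> }
> {code}
> Because `T` is fixed on the class in this sketch, the compiler can infer the 
> parameter types of `_ + _`. With the generic `update[T]` signature above, 
> callers may need to pass the type argument explicitly.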
> *`PSContext` provides the following functions for use on the driver:*
> {code}
> // load parameters from an existing RDD of (key, value) pairs.
> def loadPSModel[T](model: RDD[(String, T)]): Unit
> // fetch parameters from parameter server to construct model.
> def fetchPSModel[T](keys: Array[String]): Array[T]
> {code} 
> *A new function has been added to `RDD` to run parameter server tasks:*
> {code}
> // run the provided `func` on each partition of this RDD. 
> // This function can use the data of this partition (the first argument) 
> // and a parameter server client (the second argument). 
> // See the following Logistic Regression for an example.
> def runWithPS[U: ClassTag](func: (Array[T], PSClient) => U): Array[U]
> {code}
> h2. Example
> Here is an example of using our prototype to implement logistic regression:
> {code:title=LogisticRegression.scala|borderStyle=solid}
> // Imports assumed for this snippet. PSContext and PSClient come from the 
> // prototype (package not stated here); `dot` and `axpy` are assumed to be 
> // BLAS-style helpers such as those in Spark's internal mllib BLAS object.
> import org.apache.spark.SparkContext
> import org.apache.spark.mllib.classification.LogisticRegressionModel
> import org.apache.spark.mllib.linalg.Vectors
> import org.apache.spark.mllib.regression.LabeledPoint
> import org.apache.spark.rdd.RDD
> import org.apache.spark.util.random.BernoulliSampler
> def train(
>     sc: SparkContext,
>     input: RDD[LabeledPoint],
>     numIterations: Int,
>     stepSize: Double,
>     miniBatchFraction: Double): LogisticRegressionModel = {
>     // initialize weights
>     val numFeatures = input.map(_.features.size).first()
>     val initialWeights = new Array[Double](numFeatures)
>     // initialize parameter server context
>     val pssc = new PSContext(sc)
>     // load initialized weights into parameter server
>     val initialModelRDD = sc.parallelize(Array(("w", initialWeights)), 1)
>     pssc.loadPSModel(initialModelRDD)
>     // run logistic regression algorithm on input data   
>     input.runWithPS((arr, client) => {
>       val sampler = new BernoulliSampler[LabeledPoint](miniBatchFraction)
>       // for each iteration, compute delta and update weights
>       for (i <- 0 until numIterations) {
>         // get weights from parameter server
>         val weights = Vectors.dense(client.get[Array[Double]]("w"))
>         sampler.setSeed(i + 42)
>         // for each sample point, compute delta and update weights
>         sampler.sample(arr.toIterator).foreach { point =>
>           // compute delta
>           val data = point.features
>           val label = point.label
>           val margin = -1.0 * dot(data, weights)
>           val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
>           val delta = Vectors.dense(new Array[Double](numFeatures))
>           axpy((-1) * stepSize / math.sqrt(i + 1) * multiplier, data, delta)
>           // update weights
>           client.update[Array[Double]]("w", delta.toArray, (d1, d2) => {
>             d1.zip(d2).map { case (a, b) => a + b }
>           })
>         }
>         // end of current iteration
>         client.clock()
>       }
>     })
>     // fetch weights from parameter server
>     val weights = 
>       Vectors.dense(pssc.fetchPSModel[Array[Double]](Array("w"))(0))
>     val intercept = 0.0
>     // construct LogisticRegressionModel
>     new LogisticRegressionModel(weights, intercept).clearThreshold()
> }
> {code}
> The code above runs on the current PS-on-Spark implementation.
> h2. Other considerations
> The current implementation is just a prototype and we will try to improve it 
> in the following directions: 
> h3. Consistency protocol
> Currently only the BSP (bulk synchronous parallel) protocol is implemented. 
> SSP (stale synchronous parallel) consistency will be added soon.
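> For readers unfamiliar with the two protocols, the difference can be 
> reduced to one check on worker clocks: a worker may run ahead of the 
> slowest worker by at most a fixed staleness bound, and a bound of zero 
> gives BSP behavior. The sketch below is only an illustration of that rule; 
> `Staleness` and `mayProceed` are hypothetical names, not part of the 
> prototype.
> {code}
> // Hypothetical helper, not part of the prototype.
> object Staleness {
>   // `clocks` maps a worker id to the number of iterations it has finished,
>   // i.e. how many times it has called clock().
>   def mayProceed(clocks: Map[Int, Int], worker: Int, staleness: Int): Boolean = {
>     require(staleness >= 0, "staleness must be non-negative")
>     val slowest = clocks.values.min
>     // The worker may be ahead of the slowest worker by at most `staleness`.
>     clocks(worker) - slowest <= staleness
>   }
> }
> 
> object StalenessDemo {
>   def main(args: Array[String]): Unit = {
>     val clocks = Map(0 -> 3, 1 -> 2, 2 -> 1)
>     println(Staleness.mayProceed(clocks, worker = 0, staleness = 0)) // BSP
>     println(Staleness.mayProceed(clocks, worker = 0, staleness = 2)) // SSP
>   }
> }
> {code}
> In this example worker 0 is two iterations ahead of worker 2, so it has to 
> wait under BSP but may continue under SSP with a staleness bound of 2.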
> h3. Model partition across servers
> Currently all the parameters are stored on a single server. Parameters should 
> be partitioned across multiple servers when the model gets large, and the 
> parameter server client should route each request to the appropriate server. 
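> One simple way the client could route requests is to hash each key to a 
> server index. The sketch below assumes a fixed number of servers and plain 
> hash partitioning; `KeyRouter` and its methods are hypothetical names, and 
> the prototype may well choose a different scheme (for example range 
> partitioning).
> {code}
> // Hypothetical client-side router, not part of the prototype.
> object KeyRouter {
>   // Map a parameter key to a server index in [0, numServers).
>   def serverFor(key: String, numServers: Int): Int = {
>     require(numServers > 0, "numServers must be positive")
>     val mod = key.hashCode % numServers
>     // hashCode can be negative, so shift negative remainders into range.
>     if (mod < 0) mod + numServers else mod
>   }
> 
>   // Group the keys of a multiGet/multiUpdate by destination server,
>   // so that each server receives a single request.
>   def groupByServer(keys: Seq[String], numServers: Int): Map[Int, Seq[String]] =
>     keys.groupBy(key => serverFor(key, numServers))
> }
> {code}
> A client would then send one request per entry of `groupByServer(keys, n)` 
> and reassemble the results in the original key order.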
> h3. Performance optimizing
> To get better performance, the client can cache parameters and buffer 
> updates in an operation log (as petuum does). There may be other ways 
> to improve performance as well.
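> The caching idea could look roughly like the sketch below: reads are served 
> from a local cache, updates are accumulated in a local operation log, and 
> both are flushed when the iteration ends. This is only an illustration for 
> `Double`-valued parameters with additive updates. `CachingClient`, `fetch` 
> and `push` are hypothetical names, and the sketch does not claim to mirror 
> petuum's actual design.
> {code}
> import scala.collection.mutable
> 
> // Hypothetical caching client, not part of the prototype.
> // `fetch` reads one parameter from the server; `push` sends merged deltas.
> class CachingClient(fetch: String => Double, push: Map[String, Double] => Unit) {
>   private val cache = mutable.Map.empty[String, Double]
>   private val oplog = mutable.Map.empty[String, Double]
> 
>   // Serve reads from the cache and include this client's pending updates.
>   def get(key: String): Double =
>     cache.getOrElseUpdate(key, fetch(key)) + oplog.getOrElse(key, 0.0)
> 
>   // Record the update locally instead of contacting the server.
>   def update(key: String, delta: Double): Unit =
>     oplog(key) = oplog.getOrElse(key, 0.0) + delta
> 
>   // End of iteration: send all buffered deltas in one request and drop
>   // cached values so that the next read sees other workers' updates.
>   def clock(): Unit = {
>     if (oplog.nonEmpty) push(oplog.toMap)
>     oplog.clear()
>     cache.clear()
>   }
> }
> 
> object CachingClientDemo {
>   def main(args: Array[String]): Unit = {
>     val server = mutable.Map("w" -> 1.0)
>     var fetches = 0
>     val client = new CachingClient(
>       key => { fetches += 1; server(key) },
>       deltas => deltas.foreach { case (k, d) => server(k) = server(k) + d })
>     client.update("w", 0.5)
>     client.get("w")
>     client.get("w") // second read is served from the cache
>     client.clock()
>     println(s"fetches = $fetches, server value = ${server("w")}")
>   }
> }
> {code}
> The trade-off is fewer round trips per iteration in exchange for reading 
> values that may be stale until the next `clock()`.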
> h3. Fault Recovery
> When a parameter server crashes, it should be restarted on another node. The 
> data of a parameter server should be checkpointed periodically so that it can 
> be transferred when the server is restarted. When a task is restarted, it 
> should not rerun finished iterations. 
> We would like to see a parameter server integrated into Spark soon and hope 
> this helps other Spark users who need one. As noted above, 
> there is still much work to be done, so any comments are welcome.

