zhangyouhua updated SPARK-6932:
-------------------------------
Comment: was deleted

(was: @Qiping Li in your design the PS client runs on the slave nodes, but where will the PS Server run or be deployed?)

> A Prototype of Parameter Server
> -------------------------------
>
>                 Key: SPARK-6932
>                 URL: https://issues.apache.org/jira/browse/SPARK-6932
>             Project: Spark
>          Issue Type: New Feature
>          Components: ML, MLlib, Spark Core
>            Reporter: Qiping Li
>
> h2. Introduction
>
> As specified in [SPARK-4590|https://issues.apache.org/jira/browse/SPARK-4590], it would be very helpful to integrate a parameter server into Spark for machine learning algorithms, especially for those with ultra-high-dimensional features. After carefully studying the design doc of [Parameter Servers|https://docs.google.com/document/d/1SX3nkmF41wFXAAIr9BgqvrHSS5mW362fJ7roBXJm06o/edit?usp=sharing] and the [Factorbird|http://stanford.edu/~rezab/papers/factorbird.pdf] paper, we propose a prototype of a Parameter Server on Spark (PS-on-Spark), with several key design concerns:
> * *User-friendly interface*
> We carefully investigated most existing parameter server systems (including [petuum|http://petuum.github.io], [parameter server|http://parameterserver.org], and [paracel|https://github.com/douban/paracel]) and designed a user-friendly interface by absorbing the essence of all these systems.
> * *Prototype of distributed array*
> IndexedRDD (see [SPARK-4590|https://issues.apache.org/jira/browse/SPARK-4590]) doesn't seem to be a good option for a distributed array, because in most cases the number of key updates per second is not very high. So we implemented a distributed HashMap to store the parameters, which can easily be extended for better performance.
> * *Minimal code change*
> Quite a lot of effort has gone into avoiding changes to Spark core. Tasks that need the parameter server are still created and scheduled by Spark's scheduler. Tasks communicate with the parameter server through a client object, over *akka* or *netty*.
>
> With all these concerns in mind, we propose the following architecture:
>
> h2. Architecture
>
> !https://cloud.githubusercontent.com/assets/1285855/7158179/f2d25cc4-e3a9-11e4-835e-89681596c478.jpg!
>
> Data is stored in an RDD and partitioned across workers. During each iteration, each worker gets parameters from the parameter server, computes new parameters based on the old parameters and the data in its partition, and finally pushes its parameter updates back to the parameter server. A worker communicates with the parameter server through a parameter server client, which is initialized in the `TaskContext` of that worker. The current implementation is based on YARN cluster mode, but it should not be a problem to port it to other modes.
>
> h3. Interface
>
> We referred to existing parameter server systems (petuum, parameter server, paracel) when designing the interface of the parameter server.
>
> *`PSClient` provides the following interface for workers to use:*
> {code}
> // get the parameter indexed by `key` from the parameter server
> def get[T](key: String): T
>
> // get multiple parameters from the parameter server
> def multiGet[T](keys: Array[String]): Array[T]
>
> // add `delta` to the parameter indexed by `key`;
> // if there are multiple `delta`s to apply to the same parameter,
> // use `reduceFunc` to reduce these `delta`s first.
> def update[T](key: String, delta: T, reduceFunc: (T, T) => T): Unit
>
> // update multiple parameters at the same time, using the same `reduceFunc`.
> def multiUpdate[T](keys: Array[String], delta: Array[T], reduceFunc: (T, T) => T): Unit
>
> // advance the clock to indicate that the current iteration is finished.
> def clock(): Unit
>
> // block until all workers have reached this line of code.
> def sync(): Unit
> {code}
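>
> To make the intended call pattern concrete, here is a small usage sketch of this interface. The key "w", the element-wise vector add, and the `computeDelta` helper are illustrative assumptions, not part of the API:
> {code}
> // A minimal sketch of one worker iteration against PSClient.
> // `computeDelta` is a hypothetical placeholder for real gradient code.
> def workerIteration(client: PSClient, iter: Int): Unit = {
>   // pull the current weights from the parameter server
>   val weights = client.get[Array[Double]]("w")
>
>   // compute this worker's update (hypothetical helper)
>   val delta: Array[Double] = computeDelta(weights, iter)
>
>   // push the update; concurrent deltas for "w" are merged element-wise
>   client.update[Array[Double]]("w", delta,
>     (d1, d2) => d1.zip(d2).map { case (a, b) => a + b })
>
>   // mark the end of this iteration
>   client.clock()
> }
> {code}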
>
> *`PSContext` provides the following functions to use on the driver:*
> {code}
> // load parameters from an existing RDD.
> def loadPSModel[T](model: RDD[(String, T)])
>
> // fetch parameters from the parameter server to construct the model.
> def fetchPSModel[T](keys: Array[String]): Array[T]
> {code}
>
> *A new function has been added to `RDD` to run parameter server tasks:*
> {code}
> // run the provided `func` on each partition of this RDD.
> // `func` can use the data of the partition (the first argument)
> // and a parameter server client (the second argument).
> // See the following Logistic Regression example.
> def runWithPS[U: ClassTag](func: (Array[T], PSClient) => U): Array[U]
> {code}
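>
> For intuition only, `runWithPS` can be thought of as a thin layer over `mapPartitions`. The following sketch assumes a hypothetical `PSClient.fromTaskContext` factory; in the actual prototype the client is initialized in the worker's `TaskContext`:
> {code}
> import scala.reflect.ClassTag
> import org.apache.spark.TaskContext
> import org.apache.spark.rdd.RDD
>
> // A rough sketch of how runWithPS might be layered on mapPartitions.
> def runWithPS[T: ClassTag, U: ClassTag](rdd: RDD[T])
>     (func: (Array[T], PSClient) => U): Array[U] = {
>   rdd.mapPartitions { iter =>
>     // hypothetical: obtain the client initialized in this task's context
>     val client = PSClient.fromTaskContext(TaskContext.get())
>     Iterator.single(func(iter.toArray, client))
>   }.collect()
> }
> {code}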
>
> h2. Example
>
> Here is an example of using our prototype to implement logistic regression:
> {code:title=LogisticRegression.scala|borderStyle=solid}
> def train(
>     sc: SparkContext,
>     input: RDD[LabeledPoint],
>     numIterations: Int,
>     stepSize: Double,
>     miniBatchFraction: Double): LogisticRegressionModel = {
>
>   // initialize weights
>   val numFeatures = input.map(_.features.size).first()
>   val initialWeights = new Array[Double](numFeatures)
>
>   // initialize the parameter server context
>   val pssc = new PSContext(sc)
>
>   // load the initial weights into the parameter server
>   val initialModelRDD = sc.parallelize(Array(("w", initialWeights)), 1)
>   pssc.loadPSModel(initialModelRDD)
>
>   // run the logistic regression algorithm on the input data
>   input.runWithPS((arr, client) => {
>     val sampler = new BernoulliSampler[LabeledPoint](miniBatchFraction)
>
>     // in each iteration, compute a delta per sampled point and update the weights
>     for (i <- 0 until numIterations) {
>       // get the current weights from the parameter server
>       val weights = Vectors.dense(client.get[Array[Double]]("w"))
>       sampler.setSeed(i + 42)
>       sampler.sample(arr.toIterator).foreach { point =>
>         // compute the delta for this point
>         val data = point.features
>         val label = point.label
>         val margin = -1.0 * dot(data, weights)
>         val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
>         val delta = Vectors.dense(new Array[Double](numFeatures))
>         axpy(-stepSize / math.sqrt(i + 1) * multiplier, data, delta)
>         // merge this delta into the weights on the parameter server
>         client.update("w", delta.toArray,
>           (d1, d2) => d1.zip(d2).map { case (a, b) => a + b })
>       }
>
>       // end of the current iteration
>       client.clock()
>     }
>   })
>
>   // fetch the final weights from the parameter server
>   val weights = Vectors.dense(pssc.fetchPSModel[Array[Double]](Array("w"))(0))
>   val intercept = 0.0
>
>   // construct the LogisticRegressionModel
>   new LogisticRegressionModel(weights, intercept).clearThreshold()
> }
> {code}
> The above code runs on the current PS-on-Spark implementation.
>
> h2. Other considerations
>
> The current implementation is just a prototype, and we will try to improve it in the following directions:
>
> h3. Consistency protocol
>
> Currently we have implemented only the BSP protocol; SSP consistency will be added soon (a sketch of the SSP staleness check follows the next subsection).
>
> h3. Model partition across servers
>
> Currently all the parameters are stored on a single server. Parameters should be partitioned across multiple servers when the parameter size gets large, and the parameter server client should route requests to the right server accordingly, as sketched below.
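>
> As a sketch of such routing, keys could be hash-partitioned across servers on the client side. `ServerRef` and the other names below are hypothetical, not part of the prototype:
> {code}
> // ServerRef is a hypothetical handle to one parameter server instance.
> trait ServerRef {
>   def get[T](key: String): T
> }
>
> // A minimal sketch of client-side key routing via hash partitioning.
> class RoutingPSClient(servers: IndexedSeq[ServerRef]) {
>   // map a key deterministically to the server that owns it
>   private def serverFor(key: String): ServerRef =
>     servers(math.abs(key.hashCode % servers.size))
>
>   // route a get request to the owning server
>   def get[T](key: String): T = serverFor(key).get[T](key)
>
>   // group keys by owning server so each server receives one batched request
>   def groupByServer(keys: Array[String]): Map[ServerRef, Array[String]] =
>     keys.groupBy(serverFor)
> }
> {code}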
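>
> Returning to the consistency protocol above: the core idea of SSP is that a worker may read parameters as long as the slowest worker is no more than a bounded number of clocks behind it. A minimal sketch of that gating check, with all names hypothetical:
> {code}
> import scala.collection.mutable
>
> // A minimal sketch of an SSP (Stale Synchronous Parallel) staleness check.
> // The server would maintain `clocks` from the workers' clock() calls.
> class SspGate(staleness: Int) {
>   private val clocks = mutable.Map.empty[Int, Int].withDefaultValue(0)
>
>   // record that a worker finished an iteration (its clock() call)
>   def advance(workerId: Int): Unit = synchronized { clocks(workerId) += 1 }
>
>   // a worker at clock `c` may read only if the slowest worker
>   // is at most `staleness` clocks behind it; otherwise it must block.
>   def mayRead(c: Int): Boolean = synchronized {
>     val slowest = if (clocks.isEmpty) 0 else clocks.values.min
>     c - slowest <= staleness
>   }
> }
> {code}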
>
> h3. Performance optimizing
>
> To get better performance, the client can cache parameters locally and apply updates through an operation log (as petuum does). There may be other ways to improve performance as well.
>
> h3. Fault Recovery
>
> When a parameter server crashes, it should be restarted on another node. The data of a parameter server should be checkpointed periodically so that it can be transferred when a server is restarted. When a task is restarted, it should not rerun finished iterations.
>
> We would like to see a parameter server integrated into Spark soon, and we hope this helps other Spark users who need one. As noted above, there is still much work to be done, so any comments are welcome.

--
This message was sent by Atlassian JIRA
(v6.3.4#6332)