Rory,

I just sent a PR (https://github.com/avulanov/ann-benchmark/pull/1) to
bring that benchmark up to date. Hope it helps.

On Fri, Sep 11, 2015 at 6:39 AM, Rory Waite <rwa...@sdl.com> wrote:

> Hi,
>
> I’ve been trying to train the new MultilayerPerceptronClassifier in Spark
> 1.5 for the MNIST digit recognition task. I’m trying to reproduce the work
> here:
>
> https://github.com/avulanov/ann-benchmark
>
> The API has changed since this work, so I’m not sure that I’m setting up
> the task correctly.
>
> After I've trained the classifier, it classifies everything as a 1. It
> even does this for the training set. Am I doing something wrong with the
> setup? I’m not looking for state-of-the-art performance, just something
> that looks reasonable. This experiment is meant to be a quick sanity test.
>
> Here is the job:
>
> import org.apache.spark.{SparkConf, SparkContext}
> import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
> import org.apache.spark.mllib.util.MLUtils
> import org.apache.spark.sql.SQLContext
> import java.io.{FileOutputStream, ObjectOutputStream}
>
> object MNIST {
>   def main(args: Array[String]): Unit = {
>     val conf = new SparkConf().setAppName("MNIST")
>     // If the driver runs in client mode, its JVM is already up at this point,
>     // so this has no effect; pass it via spark-submit --driver-java-options.
>     conf.set("spark.driver.extraJavaOptions", "-XX:MaxPermSize=512M")
>     val sc = new SparkContext(conf)
>     val sqlContext = new SQLContext(sc)
>     import sqlContext.implicits._
>
>     val batchSize = 100
>     val numIterations = 5
>     val mlp = new MultilayerPerceptronClassifier()
>       .setLayers(Array[Int](780, 2500, 2000, 1500, 1000, 500, 10))
>       .setMaxIter(numIterations)
>       .setBlockSize(batchSize)
>
>     // repartition returns a new RDD, so its result must be kept
>     // (the original call discarded it).
>     val train = MLUtils.loadLibSVMFile(sc,
>       "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale", 780)
>       .repartition(200)
>     val df = train.toDF()
>     val model = mlp.fit(df)
>     val trainPredictions = model.transform(df)
>     trainPredictions.show(100)
>
>     val test = MLUtils.loadLibSVMFile(sc,
>       "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale.t", 780).toDF()
>     val result = model.transform(test)
>     result.show(100)
>     val predictionAndLabels = result.select("prediction", "label")
>     val evaluator = new MulticlassClassificationEvaluator()
>       .setMetricName("precision")
>     println("Precision:" + evaluator.evaluate(predictionAndLabels))
>
>     // Java-serialize the fitted model to disk.
>     val oos = new ObjectOutputStream(new FileOutputStream(
>       "/home/rwaite/mt-work/ann-benchmark/spark_out/spark_model.obj"))
>     oos.writeObject(model)
>     oos.close()
>     sc.stop()
>   }
> }
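One thing worth checking here: the layers array describes a large network that is trained for only 5 optimizer iterations. Assuming every layer is fully connected with one bias per output unit, that is roughly 12 million parameters. A network that size plausibly stays close to its random initialization after 5 iterations, which would fit every input getting the same prediction; raising maxIter, or starting from a much smaller layers array, is a cheaper sanity test. A plain-Scala sketch of the count (the object name MlpSize is just for illustration, not part of Spark):

```scala
object MlpSize {
  // Trainable parameters of a fully connected network, assuming one weight per
  // (input unit, output unit) pair plus one bias per output unit in each layer.
  def numParams(layers: Seq[Int]): Long =
    layers.zip(layers.tail).map { case (in, out) => in.toLong * out + out }.sum

  def main(args: Array[String]): Unit = {
    val layers = Seq(780, 2500, 2000, 1500, 1000, 500, 10)
    println(numParams(layers)) // 11962510
  }
}
```

The same function on a toy spec such as Seq(4, 5, 4, 3) gives 64, which makes it easy to compare candidate architectures before launching a job.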
>
>
> And here is the output:
>
> +-----+--------------------+----------+
> |label|            features|prediction|
> +-----+--------------------+----------+
> |  5.0|(780,[152,153,154...|       1.0|
> |  0.0|(780,[127,128,129...|       1.0|
> |  4.0|(780,[160,161,162...|       1.0|
> |  1.0|(780,[158,159,160...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  2.0|(780,[155,156,157...|       1.0|
> |  1.0|(780,[124,125,126...|       1.0|
> |  3.0|(780,[151,152,153...|       1.0|
> |  1.0|(780,[152,153,154...|       1.0|
> |  4.0|(780,[134,135,161...|       1.0|
> |  3.0|(780,[123,124,125...|       1.0|
> |  5.0|(780,[216,217,218...|       1.0|
> |  3.0|(780,[143,144,145...|       1.0|
> |  6.0|(780,[72,73,74,99...|       1.0|
> |  1.0|(780,[151,152,153...|       1.0|
> |  7.0|(780,[211,212,213...|       1.0|
> |  2.0|(780,[151,152,153...|       1.0|
> |  8.0|(780,[159,160,161...|       1.0|
> |  6.0|(780,[100,101,102...|       1.0|
> |  9.0|(780,[209,210,211...|       1.0|
> |  4.0|(780,[129,130,131...|       1.0|
> |  0.0|(780,[129,130,131...|       1.0|
> |  9.0|(780,[183,184,185...|       1.0|
> |  1.0|(780,[158,159,160...|       1.0|
> |  1.0|(780,[99,100,101,...|       1.0|
> |  2.0|(780,[124,125,126...|       1.0|
> |  4.0|(780,[185,186,187...|       1.0|
> |  3.0|(780,[150,151,152...|       1.0|
> |  2.0|(780,[145,146,147...|       1.0|
> |  7.0|(780,[240,241,242...|       1.0|
> |  3.0|(780,[201,202,203...|       1.0|
> |  8.0|(780,[153,154,155...|       1.0|
> |  6.0|(780,[71,72,73,74...|       1.0|
> |  9.0|(780,[210,211,212...|       1.0|
> |  0.0|(780,[154,155,156...|       1.0|
> |  5.0|(780,[188,189,190...|       1.0|
> |  6.0|(780,[98,99,100,1...|       1.0|
> |  0.0|(780,[127,128,129...|       1.0|
> |  7.0|(780,[201,202,203...|       1.0|
> |  6.0|(780,[125,126,127...|       1.0|
> |  1.0|(780,[154,155,156...|       1.0|
> |  8.0|(780,[131,132,133...|       1.0|
> |  7.0|(780,[209,210,211...|       1.0|
> |  9.0|(780,[181,182,183...|       1.0|
> |  3.0|(780,[174,175,176...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  8.0|(780,[152,153,154...|       1.0|
> |  5.0|(780,[186,187,188...|       1.0|
> |  9.0|(780,[150,151,152...|       1.0|
> |  3.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[122,123,124...|       1.0|
> |  0.0|(780,[153,154,155...|       1.0|
> |  7.0|(780,[203,204,205...|       1.0|
> |  4.0|(780,[212,213,214...|       1.0|
> |  9.0|(780,[205,206,207...|       1.0|
> |  8.0|(780,[181,182,183...|       1.0|
> |  0.0|(780,[151,152,153...|       1.0|
> |  9.0|(780,[210,211,212...|       1.0|
> |  4.0|(780,[156,157,158...|       1.0|
> |  1.0|(780,[129,130,131...|       1.0|
> |  4.0|(780,[149,159,160...|       1.0|
> |  4.0|(780,[187,188,189...|       1.0|
> |  6.0|(780,[127,128,129...|       1.0|
> |  0.0|(780,[154,155,156...|       1.0|
> |  4.0|(780,[152,153,154...|       1.0|
> |  5.0|(780,[219,220,221...|       1.0|
> |  6.0|(780,[74,75,101,1...|       1.0|
> |  1.0|(780,[150,151,152...|       1.0|
> |  0.0|(780,[124,125,126...|       1.0|
> |  0.0|(780,[152,153,154...|       1.0|
> |  1.0|(780,[97,98,99,12...|       1.0|
> |  7.0|(780,[237,238,239...|       1.0|
> |  1.0|(780,[124,125,126...|       1.0|
> |  6.0|(780,[70,71,72,73...|       1.0|
> |  3.0|(780,[149,150,151...|       1.0|
> |  0.0|(780,[154,155,156...|       1.0|
> |  2.0|(780,[124,125,126...|       1.0|
> |  1.0|(780,[156,157,158...|       1.0|
> |  1.0|(780,[127,128,129...|       1.0|
> |  7.0|(780,[213,214,215...|       1.0|
> |  9.0|(780,[123,124,125...|       1.0|
> |  0.0|(780,[153,154,155...|       1.0|
> |  2.0|(780,[94,95,96,97...|       1.0|
> |  6.0|(780,[72,73,99,10...|       1.0|
> |  7.0|(780,[199,200,201...|       1.0|
> |  8.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[171,172,173...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  0.0|(780,[122,123,124...|       1.0|
> |  4.0|(780,[189,190,191...|       1.0|
> |  6.0|(780,[73,74,75,76...|       1.0|
> |  7.0|(780,[238,239,240...|       1.0|
> |  4.0|(780,[158,159,177...|       1.0|
> |  6.0|(780,[99,100,101,...|       1.0|
> |  8.0|(780,[154,155,156...|       1.0|
> |  0.0|(780,[126,127,128...|       1.0|
> |  7.0|(780,[209,210,211...|       1.0|
> |  8.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[150,151,152...|       1.0|
> |  1.0|(780,[156,157,158...|       1.0|
> +-----+--------------------+----------+
> only showing top 100 rows
>
> +-----+--------------------+----------+
> |label|            features|prediction|
> +-----+--------------------+----------+
> |  7.0|(780,[202,203,204...|       1.0|
> |  2.0|(780,[94,95,96,97...|       1.0|
> |  1.0|(780,[128,129,130...|       1.0|
> |  0.0|(780,[124,125,126...|       1.0|
> |  4.0|(780,[150,151,159...|       1.0|
> |  1.0|(780,[156,157,158...|       1.0|
> |  4.0|(780,[149,150,151...|       1.0|
> |  9.0|(780,[179,180,181...|       1.0|
> |  5.0|(780,[129,130,131...|       1.0|
> |  9.0|(780,[209,210,211...|       1.0|
> |  0.0|(780,[123,124,125...|       1.0|
> |  6.0|(780,[94,95,96,97...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  0.0|(780,[152,153,154...|       1.0|
> |  1.0|(780,[125,126,127...|       1.0|
> |  5.0|(780,[124,125,126...|       1.0|
> |  9.0|(780,[179,180,181...|       1.0|
> |  7.0|(780,[200,201,202...|       1.0|
> |  3.0|(780,[118,119,120...|       1.0|
> |  4.0|(780,[158,159,185...|       1.0|
> |  9.0|(780,[183,184,185...|       1.0|
> |  6.0|(780,[96,97,98,99...|       1.0|
> |  6.0|(780,[93,94,95,12...|       1.0|
> |  5.0|(780,[156,157,158...|       1.0|
> |  4.0|(780,[151,152,178...|       1.0|
> |  0.0|(780,[125,126,127...|       1.0|
> |  7.0|(780,[230,234,235...|       1.0|
> |  4.0|(780,[152,153,179...|       1.0|
> |  0.0|(780,[149,150,151...|       1.0|
> |  1.0|(780,[123,124,125...|       1.0|
> |  3.0|(780,[175,176,177...|       1.0|
> |  1.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[148,149,150...|       1.0|
> |  4.0|(780,[122,123,150...|       1.0|
> |  7.0|(780,[175,176,177...|       1.0|
> |  2.0|(780,[124,125,126...|       1.0|
> |  7.0|(780,[202,203,204...|       1.0|
> |  1.0|(780,[151,152,153...|       1.0|
> |  2.0|(780,[125,126,127...|       1.0|
> |  1.0|(780,[126,127,128...|       1.0|
> |  1.0|(780,[125,126,153...|       1.0|
> |  7.0|(780,[207,208,209...|       1.0|
> |  4.0|(780,[176,177,178...|       1.0|
> |  2.0|(780,[126,127,128...|       1.0|
> |  3.0|(780,[121,122,123...|       1.0|
> |  5.0|(780,[152,153,154...|       1.0|
> |  1.0|(780,[122,123,124...|       1.0|
> |  2.0|(780,[65,66,67,68...|       1.0|
> |  4.0|(780,[177,178,179...|       1.0|
> |  4.0|(780,[147,148,157...|       1.0|
> |  6.0|(780,[100,101,102...|       1.0|
> |  3.0|(780,[172,173,174...|       1.0|
> |  5.0|(780,[163,164,165...|       1.0|
> |  5.0|(780,[126,127,128...|       1.0|
> |  6.0|(780,[93,94,95,12...|       1.0|
> |  0.0|(780,[151,152,153...|       1.0|
> |  4.0|(780,[148,149,150...|       1.0|
> |  1.0|(780,[155,156,157...|       1.0|
> |  9.0|(780,[209,210,211...|       1.0|
> |  5.0|(780,[190,191,192...|       1.0|
> |  7.0|(780,[198,199,200...|       1.0|
> |  8.0|(780,[153,154,155...|       1.0|
> |  9.0|(780,[178,179,180...|       1.0|
> |  3.0|(780,[95,96,97,98...|       1.0|
> |  7.0|(780,[200,201,202...|       1.0|
> |  4.0|(780,[156,157,184...|       1.0|
> |  6.0|(780,[67,68,69,95...|       1.0|
> |  4.0|(780,[160,161,162...|       1.0|
> |  3.0|(780,[148,149,150...|       1.0|
> |  0.0|(780,[152,153,179...|       1.0|
> |  7.0|(780,[206,207,208...|       1.0|
> |  0.0|(780,[123,124,125...|       1.0|
> |  2.0|(780,[119,120,121...|       1.0|
> |  9.0|(780,[180,181,182...|       1.0|
> |  1.0|(780,[152,153,154...|       1.0|
> |  7.0|(780,[213,214,215...|       1.0|
> |  3.0|(780,[124,125,126...|       1.0|
> |  2.0|(780,[205,206,207...|       1.0|
> |  9.0|(780,[183,184,185...|       1.0|
> |  7.0|(780,[209,210,211...|       1.0|
> |  7.0|(780,[205,206,207...|       1.0|
> |  6.0|(780,[99,100,101,...|       1.0|
> |  2.0|(780,[96,97,98,99...|       1.0|
> |  7.0|(780,[204,205,206...|       1.0|
> |  8.0|(780,[156,157,159...|       1.0|
> |  4.0|(780,[147,148,158...|       1.0|
> |  7.0|(780,[203,204,205...|       1.0|
> |  3.0|(780,[146,147,148...|       1.0|
> |  6.0|(780,[67,68,69,70...|       1.0|
> |  1.0|(780,[128,129,130...|       1.0|
> |  3.0|(780,[152,153,154...|       1.0|
> |  6.0|(780,[71,72,73,74...|       1.0|
> |  9.0|(780,[182,183,184...|       1.0|
> |  3.0|(780,[149,150,151...|       1.0|
> |  1.0|(780,[123,124,125...|       1.0|
> |  4.0|(780,[158,159,160...|       1.0|
> |  1.0|(780,[149,150,151...|       1.0|
> |  7.0|(780,[175,176,177...|       1.0|
> |  6.0|(780,[99,100,101,...|       1.0|
> |  9.0|(780,[177,178,179...|       1.0|
> +-----+--------------------+----------+
> only showing top 100 rows
>
> Precision:0.1135
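The precision figure is itself a clue. Assuming mnist.scale.t is the standard 10,000-image MNIST test split, 1135 of those images are labelled 1 (the most frequent class), so a model that always answers 1 scores exactly 1135 / 10000 = 0.1135. In other words, the network has collapsed to the majority class. A plain-Scala sketch of that majority-class baseline, using the published MNIST test class counts (MajorityBaseline is an illustrative name); on the Spark side, test.groupBy("label").count().show() should list the same counts:

```scala
object MajorityBaseline {
  // Given per-class example counts, return the most frequent class and the
  // accuracy (overall precision) of a classifier that always predicts it.
  def baseline(counts: Map[Int, Int]): (Int, Double) = {
    val (label, n) = counts.maxBy(_._2)
    (label, n.toDouble / counts.values.sum)
  }

  def main(args: Array[String]): Unit = {
    // Class counts of the standard 10,000-image MNIST test set.
    val mnistTest = Map(0 -> 980, 1 -> 1135, 2 -> 1032, 3 -> 1010, 4 -> 982,
                        5 -> 892, 6 -> 958, 7 -> 1028, 8 -> 974, 9 -> 1009)
    val (label, acc) = baseline(mnistTest)
    println(s"always predict $label -> precision $acc")
  }
}
```

Any trained model should beat this baseline by a wide margin; matching it exactly suggests training has not moved the network away from a constant output.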