Rory, I just sent a PR (https://github.com/avulanov/ann-benchmark/pull/1) to bring that benchmark up to date. Hope it helps.
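
One quick sanity check on the quoted output: the reported precision of 0.1135 looks like exactly what a model that always predicts 1 would score. The sketch below assumes that mnist.scale.t is the standard 10,000-image MNIST test set with the usual per-digit counts, and that the "precision" metric of MulticlassClassificationEvaluator in Spark 1.5 is the overall fraction of correct predictions. The object and function names are made up for illustration.

```scala
object ConstantPredictorCheck {
  // Assumed per-digit example counts of the standard MNIST test set,
  // indexed by digit 0 through 9 (10,000 examples in total).
  val testLabelCounts: Vector[Int] =
    Vector(980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009)

  /** Fraction of examples a model gets right if it always predicts `digit`. */
  def constantPredictorPrecision(counts: Seq[Int], digit: Int): Double =
    counts(digit).toDouble / counts.sum

  def main(args: Array[String]): Unit = {
    val p = constantPredictorPrecision(testLabelCounts, 1)
    // 1135 correct out of 10000 matches the Precision value in the job output.
    println(s"Precision of always predicting 1: $p")
  }
}
```

If those assumptions hold, the evaluator number agrees with the two tables, which suggests the network has collapsed to a constant output rather than the evaluation step being set up incorrectly.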

On Fri, Sep 11, 2015 at 6:39 AM, Rory Waite <rwa...@sdl.com> wrote:

> Hi,
>
> I’ve been trying to train the new MultilayerPerceptronClassifier in Spark
> 1.5 for the MNIST digit recognition task. I’m trying to reproduce the work
> here:
>
> https://github.com/avulanov/ann-benchmark
>
> The API has changed since this work, so I’m not sure that I’m setting up
> the task correctly.
>
> After I've trained the classifier, it classifies everything as a 1. It
> even does this for the training set. Am I doing something wrong with the
> setup? I’m not looking for state-of-the-art performance, just something
> that looks reasonable. This experiment is meant to be a quick sanity test.
>
> Here is the job:
>
> import org.apache.log4j._
> //Logger.getRootLogger.setLevel(Level.OFF)
> import org.apache.spark.mllib.linalg.Vectors
> import org.apache.spark.mllib.regression.LabeledPoint
> import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
> import org.apache.spark.ml.Pipeline
> import org.apache.spark.ml.PipelineStage
> import org.apache.spark.mllib.util.MLUtils
> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
> import org.apache.spark.SparkContext
> import org.apache.spark.SparkContext._
> import org.apache.spark.SparkConf
> import org.apache.spark.sql.SQLContext
> import java.io.FileOutputStream
> import java.io.ObjectOutputStream
>
> object MNIST {
>   def main(args: Array[String]) {
>     val conf = new SparkConf().setAppName("MNIST")
>     conf.set("spark.driver.extraJavaOptions", "-XX:MaxPermSize=512M")
>     val sc = new SparkContext(conf)
>     val batchSize = 100
>     val numIterations = 5
>     val mlp = new MultilayerPerceptronClassifier
>     mlp.setLayers(Array[Int](780, 2500, 2000, 1500, 1000, 500, 10))
>     mlp.setMaxIter(numIterations)
>     mlp.setBlockSize(batchSize)
>     val train = MLUtils.loadLibSVMFile(sc,
>       "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale")
>     train.repartition(200)
>     val sqlContext = new SQLContext(sc)
>     import sqlContext.implicits._
>     val df = train.toDF
>     val model = mlp.fit(df)
>     val trainPredictions = model.transform(df)
>     trainPredictions.show(100)
>     val test = MLUtils.loadLibSVMFile(sc,
>       "file:///misc/home/rwaite/mt-work/ann-benchmark/mnist.scale.t", 780).toDF
>     val result = model.transform(test)
>     result.show(100)
>     val predictionAndLabels = result.select("prediction", "label")
>     val evaluator = new MulticlassClassificationEvaluator()
>       .setMetricName("precision")
>     println("Precision:" + evaluator.evaluate(predictionAndLabels))
>     val fos = new
>       FileOutputStream("/home/rwaite/mt-work/ann-benchmark/spark_out/spark_model.obj");
>     val oos = new ObjectOutputStream(fos);
>     oos.writeObject(model);
>     oos.close
>   }
> }
>
>
> And here is the output:
>
> +-----+--------------------+----------+
> |label|            features|prediction|
> +-----+--------------------+----------+
> |  5.0|(780,[152,153,154...|       1.0|
> |  0.0|(780,[127,128,129...|       1.0|
> |  4.0|(780,[160,161,162...|       1.0|
> |  1.0|(780,[158,159,160...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  2.0|(780,[155,156,157...|       1.0|
> |  1.0|(780,[124,125,126...|       1.0|
> |  3.0|(780,[151,152,153...|       1.0|
> |  1.0|(780,[152,153,154...|       1.0|
> |  4.0|(780,[134,135,161...|       1.0|
> |  3.0|(780,[123,124,125...|       1.0|
> |  5.0|(780,[216,217,218...|       1.0|
> |  3.0|(780,[143,144,145...|       1.0|
> |  6.0|(780,[72,73,74,99...|       1.0|
> |  1.0|(780,[151,152,153...|       1.0|
> |  7.0|(780,[211,212,213...|       1.0|
> |  2.0|(780,[151,152,153...|       1.0|
> |  8.0|(780,[159,160,161...|       1.0|
> |  6.0|(780,[100,101,102...|       1.0|
> |  9.0|(780,[209,210,211...|       1.0|
> |  4.0|(780,[129,130,131...|       1.0|
> |  0.0|(780,[129,130,131...|       1.0|
> |  9.0|(780,[183,184,185...|       1.0|
> |  1.0|(780,[158,159,160...|       1.0|
> |  1.0|(780,[99,100,101,...|       1.0|
> |  2.0|(780,[124,125,126...|       1.0|
> |  4.0|(780,[185,186,187...|       1.0|
> |  3.0|(780,[150,151,152...|       1.0|
> |  2.0|(780,[145,146,147...|       1.0|
> |  7.0|(780,[240,241,242...|       1.0|
> |  3.0|(780,[201,202,203...|       1.0|
> |  8.0|(780,[153,154,155...|       1.0|
> |  6.0|(780,[71,72,73,74...|       1.0|
> |  9.0|(780,[210,211,212...|       1.0|
> |  0.0|(780,[154,155,156...|       1.0|
> |  5.0|(780,[188,189,190...|       1.0|
> |  6.0|(780,[98,99,100,1...|       1.0|
> |  0.0|(780,[127,128,129...|       1.0|
> |  7.0|(780,[201,202,203...|       1.0|
> |  6.0|(780,[125,126,127...|       1.0|
> |  1.0|(780,[154,155,156...|       1.0|
> |  8.0|(780,[131,132,133...|       1.0|
> |  7.0|(780,[209,210,211...|       1.0|
> |  9.0|(780,[181,182,183...|       1.0|
> |  3.0|(780,[174,175,176...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  8.0|(780,[152,153,154...|       1.0|
> |  5.0|(780,[186,187,188...|       1.0|
> |  9.0|(780,[150,151,152...|       1.0|
> |  3.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[122,123,124...|       1.0|
> |  0.0|(780,[153,154,155...|       1.0|
> |  7.0|(780,[203,204,205...|       1.0|
> |  4.0|(780,[212,213,214...|       1.0|
> |  9.0|(780,[205,206,207...|       1.0|
> |  8.0|(780,[181,182,183...|       1.0|
> |  0.0|(780,[151,152,153...|       1.0|
> |  9.0|(780,[210,211,212...|       1.0|
> |  4.0|(780,[156,157,158...|       1.0|
> |  1.0|(780,[129,130,131...|       1.0|
> |  4.0|(780,[149,159,160...|       1.0|
> |  4.0|(780,[187,188,189...|       1.0|
> |  6.0|(780,[127,128,129...|       1.0|
> |  0.0|(780,[154,155,156...|       1.0|
> |  4.0|(780,[152,153,154...|       1.0|
> |  5.0|(780,[219,220,221...|       1.0|
> |  6.0|(780,[74,75,101,1...|       1.0|
> |  1.0|(780,[150,151,152...|       1.0|
> |  0.0|(780,[124,125,126...|       1.0|
> |  0.0|(780,[152,153,154...|       1.0|
> |  1.0|(780,[97,98,99,12...|       1.0|
> |  7.0|(780,[237,238,239...|       1.0|
> |  1.0|(780,[124,125,126...|       1.0|
> |  6.0|(780,[70,71,72,73...|       1.0|
> |  3.0|(780,[149,150,151...|       1.0|
> |  0.0|(780,[154,155,156...|       1.0|
> |  2.0|(780,[124,125,126...|       1.0|
> |  1.0|(780,[156,157,158...|       1.0|
> |  1.0|(780,[127,128,129...|       1.0|
> |  7.0|(780,[213,214,215...|       1.0|
> |  9.0|(780,[123,124,125...|       1.0|
> |  0.0|(780,[153,154,155...|       1.0|
> |  2.0|(780,[94,95,96,97...|       1.0|
> |  6.0|(780,[72,73,99,10...|       1.0|
> |  7.0|(780,[199,200,201...|       1.0|
> |  8.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[171,172,173...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  0.0|(780,[122,123,124...|       1.0|
> |  4.0|(780,[189,190,191...|       1.0|
> |  6.0|(780,[73,74,75,76...|       1.0|
> |  7.0|(780,[238,239,240...|       1.0|
> |  4.0|(780,[158,159,177...|       1.0|
> |  6.0|(780,[99,100,101,...|       1.0|
> |  8.0|(780,[154,155,156...|       1.0|
> |  0.0|(780,[126,127,128...|       1.0|
> |  7.0|(780,[209,210,211...|       1.0|
> |  8.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[150,151,152...|       1.0|
> |  1.0|(780,[156,157,158...|       1.0|
> +-----+--------------------+----------+
> only showing top 100 rows
>
> +-----+--------------------+----------+
> |label|            features|prediction|
> +-----+--------------------+----------+
> |  7.0|(780,[202,203,204...|       1.0|
> |  2.0|(780,[94,95,96,97...|       1.0|
> |  1.0|(780,[128,129,130...|       1.0|
> |  0.0|(780,[124,125,126...|       1.0|
> |  4.0|(780,[150,151,159...|       1.0|
> |  1.0|(780,[156,157,158...|       1.0|
> |  4.0|(780,[149,150,151...|       1.0|
> |  9.0|(780,[179,180,181...|       1.0|
> |  5.0|(780,[129,130,131...|       1.0|
> |  9.0|(780,[209,210,211...|       1.0|
> |  0.0|(780,[123,124,125...|       1.0|
> |  6.0|(780,[94,95,96,97...|       1.0|
> |  9.0|(780,[208,209,210...|       1.0|
> |  0.0|(780,[152,153,154...|       1.0|
> |  1.0|(780,[125,126,127...|       1.0|
> |  5.0|(780,[124,125,126...|       1.0|
> |  9.0|(780,[179,180,181...|       1.0|
> |  7.0|(780,[200,201,202...|       1.0|
> |  3.0|(780,[118,119,120...|       1.0|
> |  4.0|(780,[158,159,185...|       1.0|
> |  9.0|(780,[183,184,185...|       1.0|
> |  6.0|(780,[96,97,98,99...|       1.0|
> |  6.0|(780,[93,94,95,12...|       1.0|
> |  5.0|(780,[156,157,158...|       1.0|
> |  4.0|(780,[151,152,178...|       1.0|
> |  0.0|(780,[125,126,127...|       1.0|
> |  7.0|(780,[230,234,235...|       1.0|
> |  4.0|(780,[152,153,179...|       1.0|
> |  0.0|(780,[149,150,151...|       1.0|
> |  1.0|(780,[123,124,125...|       1.0|
> |  3.0|(780,[175,176,177...|       1.0|
> |  1.0|(780,[152,153,154...|       1.0|
> |  3.0|(780,[148,149,150...|       1.0|
> |  4.0|(780,[122,123,150...|       1.0|
> |  7.0|(780,[175,176,177...|       1.0|
> |  2.0|(780,[124,125,126...|       1.0|
> |  7.0|(780,[202,203,204...|       1.0|
> |  1.0|(780,[151,152,153...|       1.0|
> |  2.0|(780,[125,126,127...|       1.0|
> |  1.0|(780,[126,127,128...|       1.0|
> |  1.0|(780,[125,126,153...|       1.0|
> |  7.0|(780,[207,208,209...|       1.0|
> |  4.0|(780,[176,177,178...|       1.0|
> |  2.0|(780,[126,127,128...|       1.0|
> |  3.0|(780,[121,122,123...|       1.0|
> |  5.0|(780,[152,153,154...|       1.0|
> |  1.0|(780,[122,123,124...|       1.0|
> |  2.0|(780,[65,66,67,68...|       1.0|
> |  4.0|(780,[177,178,179...|       1.0|
> |  4.0|(780,[147,148,157...|       1.0|
> |  6.0|(780,[100,101,102...|       1.0|
> |  3.0|(780,[172,173,174...|       1.0|
> |  5.0|(780,[163,164,165...|       1.0|
> |  5.0|(780,[126,127,128...|       1.0|
> |  6.0|(780,[93,94,95,12...|       1.0|
> |  0.0|(780,[151,152,153...|       1.0|
> |  4.0|(780,[148,149,150...|       1.0|
> |  1.0|(780,[155,156,157...|       1.0|
> |  9.0|(780,[209,210,211...|       1.0|
> |  5.0|(780,[190,191,192...|       1.0|
> |  7.0|(780,[198,199,200...|       1.0|
> |  8.0|(780,[153,154,155...|       1.0|
> |  9.0|(780,[178,179,180...|       1.0|
> |  3.0|(780,[95,96,97,98...|       1.0|
> |  7.0|(780,[200,201,202...|       1.0|
> |  4.0|(780,[156,157,184...|       1.0|
> |  6.0|(780,[67,68,69,95...|       1.0|
> |  4.0|(780,[160,161,162...|       1.0|
> |  3.0|(780,[148,149,150...|       1.0|
> |  0.0|(780,[152,153,179...|       1.0|
> |  7.0|(780,[206,207,208...|       1.0|
> |  0.0|(780,[123,124,125...|       1.0|
> |  2.0|(780,[119,120,121...|       1.0|
> |  9.0|(780,[180,181,182...|       1.0|
> |  1.0|(780,[152,153,154...|       1.0|
> |  7.0|(780,[213,214,215...|       1.0|
> |  3.0|(780,[124,125,126...|       1.0|
> |  2.0|(780,[205,206,207...|       1.0|
> |  9.0|(780,[183,184,185...|       1.0|
> |  7.0|(780,[209,210,211...|       1.0|
> |  7.0|(780,[205,206,207...|       1.0|
> |  6.0|(780,[99,100,101,...|       1.0|
> |  2.0|(780,[96,97,98,99...|       1.0|
> |  7.0|(780,[204,205,206...|       1.0|
> |  8.0|(780,[156,157,159...|       1.0|
> |  4.0|(780,[147,148,158...|       1.0|
> |  7.0|(780,[203,204,205...|       1.0|
> |  3.0|(780,[146,147,148...|       1.0|
> |  6.0|(780,[67,68,69,70...|       1.0|
> |  1.0|(780,[128,129,130...|       1.0|
> |  3.0|(780,[152,153,154...|       1.0|
> |  6.0|(780,[71,72,73,74...|       1.0|
> |  9.0|(780,[182,183,184...|       1.0|
> |  3.0|(780,[149,150,151...|       1.0|
> |  1.0|(780,[123,124,125...|       1.0|
> |  4.0|(780,[158,159,160...|       1.0|
> |  1.0|(780,[149,150,151...|       1.0|
> |  7.0|(780,[175,176,177...|       1.0|
> |  6.0|(780,[99,100,101,...|       1.0|
> |  9.0|(780,[177,178,179...|       1.0|
> +-----+--------------------+----------+
> only showing top 100 rows
>
> Precision:0.1135