I'm trying to implement an algorithm on the MNIST digits that runs like so:


   - for every pair of digits (0,1), (0,2), (0,3)... assign a 0/1 label to
   the digits and build a LogisticRegression Classifier -- 45 in total
   - Fit every classifier separately and apply each one to the test set
   - Aggregate the results per record of the test set and compute a
   prediction from the 45 predictions
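
To make the pairing and aggregation steps concrete, here is a minimal
plain-Python sketch (not Spark code). It assumes a label of 0 is a vote for
the lower digit and 1 a vote for the upper digit, and that ties go to the
smallest digit; the names PAIRS and aggregate_votes are just for
illustration.

```python
from collections import Counter
from itertools import combinations

DIGITS = range(10)

# All unordered digit pairs (lower, upper): (0, 1), (0, 2), ..., (8, 9).
PAIRS = list(combinations(DIGITS, 2))


def aggregate_votes(pair_predictions):
    """Majority vote over the pairwise predictions for one record.

    pair_predictions maps (lower, upper) -> 0 or 1, where 0 is a vote for
    `lower` and 1 is a vote for `upper` (assumed label convention).
    Ties are broken in favour of the smallest digit (also an assumption).
    """
    votes = Counter()
    for (lower, upper), label in pair_predictions.items():
        votes[upper if label == 1 else lower] += 1
    best = max(votes.values())
    return min(digit for digit, count in votes.items() if count == best)
```

Each record ends up with one prediction per pair, and the digit that wins
the most pairwise contests becomes the final prediction.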

I tried implementing this with a Pipeline, composed of

   - StringIndexer
   - a custom transformer which accepts a lower-digit and upper-digit
   argument, producing the 0/1 label
   - a custom transformer to assemble the indexed strings to VectorUDT
   - LogisticRegression

fed by a list of paramMaps. It failed because LogisticRegression's fit()
method couldn't handle null labels, i.e. rows where my 0/1 transformer
found neither the lower nor the upper digit. I fixed this by extending the
LogisticRegression class and overriding fit() to filter for labels in
(0,1) first -- I didn't want to alter the transform method.
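
The labelling and filtering logic I mean is roughly the following, written
as a plain-Python sketch rather than the actual Spark transformer.
pair_label and training_rows are hypothetical names, and rows are assumed
to be (features, digit) tuples.

```python
def pair_label(digit, lower, upper):
    """Return 0 for `lower`, 1 for `upper`, and None for any other digit."""
    if digit == lower:
        return 0
    if digit == upper:
        return 1
    return None


def training_rows(rows, lower, upper):
    """Label rows for one digit pair and drop rows outside the pair.

    rows is an iterable of (features, digit) tuples (assumed layout).
    Rows whose digit is neither `lower` nor `upper` get a None label
    and are filtered out before fitting.
    """
    labelled = (
        (features, pair_label(digit, lower, upper)) for features, digit in rows
    )
    return [(features, label) for features, label in labelled if label is not None]
```

Only the fitting side drops the unlabelled rows; at prediction time every
record still goes through all 45 models.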

Now I'd like to tune these models using CrossValidator with the Pipeline
as its estimator, but whether I run fitMultiple on my paramMaps or loop
over the paramMaps myself, I get arcane Scala errors.


Is there a better way to build this procedure? Thanks!
