Hi,
I am trying to run the straightforward SVM example, but I am getting low
accuracy (around 50%) when I predict on the same data I used for
training. I am probably doing the prediction the wrong way. My code is
below. I would appreciate any help.
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;
import edu.illinois.biglbjava.readers.LabeledPointReader;
public class SimpleDistSVM {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("SVM Classifier Example");
SparkContext sc = new SparkContext(conf);
String inputPath=args[0];
// Read training data
JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc,
inputPath).toJavaRDD();
// Run training algorithm to build the model.
int numIterations = 3;
final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);
// Clear the default threshold.
model.clearThreshold();
// Predict points in test set and map to an RDD of 0/1 values where 0
is misclassication and 1 is correct classification
JavaRDD<Integer> classification = data.map(new Function<LabeledPoint,
Integer>() {
public Integer call(LabeledPoint p) {
int label = (int) p.label();
Double score = model.predict(p.features());
if((score >=0 && label == 1) || (score <0 && label == 0))
{
return 1; //correct classiciation
}
else
return 0;
}
}
);
// sum up all values in the rdd to get the number of correctly
classified examples
int sum=classification.reduce(new Function2<Integer, Integer,
Integer>()
{
public Integer call(Integer arg0, Integer arg1)
throws Exception {
return arg0+arg1;
}});
//compute accuracy as the percentage of the correctly classified
examples
double accuracy=((double)sum)/((double)classification.count());
System.out.println("Accuracy = " + accuracy);
}
}
);
}
}