Bago,

The code I wrote here does not reproduce the issue. In our case, we build an ML pipeline from a UI: a user creates the pipeline behind the scenes using drag and drop, so the stages are assembled in a particular way. I have yet to dig deeper and recreate the problem as standalone code. Meanwhile, I am sharing a similar example that I wrote; I hope to find time next week to produce the exact one.

import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class StreamingIssueCountVectorizerSplit {

    public static void main(String[] args) throws Exception {
        SparkSession sparkSession = SparkSession.builder()
            .appName("StreamingIssueCountVectorizer")
            .master("local[2]")
            .getOrCreate();
        List<Row> _trainData = Arrays.asList(
            RowFactory.create("sunny fantastic day", "Positive"),
            RowFactory.create("fantastic morning match", "Positive"),
            RowFactory.create("good morning", "Positive"),
            RowFactory.create("boring evening", "Negative"),
            RowFactory.create("tragic evening event", "Negative"),
            RowFactory.create("today is bad ", "Negative")
        );
        List<Row> _testData = Arrays.asList(
            RowFactory.create("sunny morning"),
            RowFactory.create("bad evening")
        );
        StructType schema = new StructType(new StructField[]{
            new StructField("tweet", DataTypes.StringType, false, Metadata.empty()),
            new StructField("sentiment", DataTypes.StringType, true, Metadata.empty())
        });
        StructType testSchema = new StructType(new StructField[]{
            new StructField("tweet", DataTypes.StringType, false, Metadata.empty())
        });
        Dataset<Row> trainData = sparkSession.createDataFrame(_trainData, schema);
        Dataset<Row> testData = sparkSession.createDataFrame(_testData, testSchema);
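        // Fit the label indexer up front; its labels are reused below by IndexToString.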
        StringIndexerModel labelIndexerModel = new StringIndexer()
            .setInputCol("sentiment")
            .setOutputCol("label")
            .setHandleInvalid("skip")
            .fit(trainData);
        Tokenizer tokenizer = new Tokenizer()
            .setInputCol("tweet")
            .setOutputCol("words");
        CountVectorizer countVectorizer = new CountVectorizer()
            .setInputCol(tokenizer.getOutputCol())
            .setOutputCol("features")
            .setVocabSize(3)
            .setMinDF(2)
            .setMinTF(2)
            .setBinary(true);
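        // Fit the CountVectorizer separately so the fitted model (with its
        // vocabulary) can be placed in the pipeline as a transformer.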
        Dataset<Row> words = tokenizer.transform(trainData);
        CountVectorizerModel countVectorizerModel = countVectorizer.fit(words);
        LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.001);
        IndexToString labelConverter = new IndexToString()
            .setInputCol("prediction")
            .setOutputCol("predicted")
            .setLabels(labelIndexerModel.labels());
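        // Lower minTF on the fitted model; minTF is only applied at transform
        // time, so the vocabulary fitted above is unaffected.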
        countVectorizerModel.setMinTF(1);
        Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{labelIndexerModel, tokenizer,
                countVectorizerModel, lr, labelConverter});
        ParamMap[] paramGrid = new ParamGridBuilder()
            .addGrid(lr.regParam(), new double[]{0.1, 0.01})
            .addGrid(lr.fitIntercept())
            .addGrid(lr.elasticNetParam(), new double[]{0.0, 0.5, 1.0})
            .build();
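        // Score candidate models on the validation split with multiclass F1
        // (the evaluator's default metric).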
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
        evaluator.setLabelCol("label");
        evaluator.setPredictionCol("prediction");
        TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
            .setEstimator(pipeline)
            .setEvaluator(evaluator)
            .setEstimatorParamMaps(paramGrid)
            .setTrainRatio(0.7);
        // Fit the pipeline to the training documents, then round-trip the model through disk.
        TrainValidationSplitModel trainValidationSplitModel = trainValidationSplit.fit(trainData);
        trainValidationSplitModel.write().overwrite().save("/tmp/CountSplit.model");
        TrainValidationSplitModel _loadedModel =
            TrainValidationSplitModel.load("/tmp/CountSplit.model");
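        // bestModel() is the PipelineModel that scored best during the train/validation split.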
        PipelineModel loadedModel = (PipelineModel) _loadedModel.bestModel();

        // Test on non-streaming data
        Dataset<Row> predicted = loadedModel.transform(testData);
        List<Row> rows = predicted.select("tweet", "predicted").collectAsList();
        for (Row r : rows) {
            System.out.println("[" + r.get(0) + "], prediction=" + r.get(1));
        }

        // Test on streaming data
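        // Requires a socket server on localhost:9999 (e.g. `nc -lk 9999`) feeding one tweet per line.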
        Dataset<Row> lines = sparkSession
            .readStream()
            .format("socket")
            .option("host", "localhost")
            .option("port", 9999)
            .load();
        lines = lines.withColumnRenamed("value", "tweet");
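        // Run the loaded pipeline on each micro-batch and print predictions to the console.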
        StreamingQuery query = loadedModel.transform(lines).writeStream()
            .outputMode("append")
            .format("console")
            .start();
        query.awaitTermination();
    }
}