http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index ea83e8f..52bb4ec 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
@@ -54,7 +54,8 @@ public class JavaSimpleParamsExample {
       new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
       new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
       new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
-    DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
+    Dataset<Row> training =
+      jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
 
     // Create a LogisticRegression instance. This instance is an Estimator.
     LogisticRegression lr = new LogisticRegression();
@@ -95,14 +96,14 @@ public class JavaSimpleParamsExample {
       new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
-    DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
+    Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
 
     // Make predictions on test documents using the Transformer.transform() method.
     // LogisticRegressionModel.transform will only use the 'features' column.
     // Note that model2.transform() outputs a 'myProbability' column instead of the usual
     // 'probability' column since we renamed the lr.probabilityCol parameter previously.
-    DataFrame results = model2.transform(test);
-    for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) {
+    Dataset<Row> results = model2.transform(test);
+    for (Row r: results.select("features", "label", "myProbability", "prediction").collectRows()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
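The pattern in this commit is mechanical and repeats in every file below: the Java type DataFrame becomes Dataset<Row>, and the Row[]-returning actions collect()/take(n) become collectRows()/takeRows(n) in the examples, while the test suites move to the java.util.List<Row> variants collectAsList()/takeAsList(n). A minimal, self-contained sketch of the post-migration usage, assuming the same 2.0-era SQLContext entry point these examples use; the class name DatasetMigrationSketch and the toy data are illustrative, not part of the commit:

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SQLContext;

    public class DatasetMigrationSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("DatasetMigrationSketch");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext jsql = new SQLContext(jsc);

        // Toy data for illustration only.
        List<LabeledPoint> points = Arrays.asList(
          new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
          new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)));

        // Was: DataFrame df = jsql.createDataFrame(...);
        Dataset<Row> df = jsql.createDataFrame(jsc.parallelize(points), LabeledPoint.class);

        // Was: Row[] rows = df.select("label", "features").collect();
        // collectAsList() returns a java.util.List<Row> instead of Row[].
        for (Row r : df.select("label", "features").collectAsList()) {
          System.out.println(r.get(0) + " -> " + r.get(1));
        }

        jsc.stop();
      }
    }

Iterating a List<Row> reads the same as iterating the old Row[], which is why most call sites in this commit only change the declared type and the action name.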
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 5473881..9bd543c 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -29,7 +29,7 @@ import org.apache.spark.ml.PipelineStage;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.feature.HashingTF;
 import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
@@ -54,7 +54,8 @@ public class JavaSimpleTextClassificationPipeline {
       new LabeledDocument(1L, "b d", 0.0),
       new LabeledDocument(2L, "spark f g h", 1.0),
       new LabeledDocument(3L, "hadoop mapreduce", 0.0));
-    DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
+    Dataset<Row> training =
+      jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
 
     // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
     Tokenizer tokenizer = new Tokenizer()
@@ -79,11 +80,11 @@ public class JavaSimpleTextClassificationPipeline {
       new Document(5L, "l m n"),
       new Document(6L, "spark hadoop spark"),
       new Document(7L, "apache hadoop"));
-    DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
+    Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
 
     // Make predictions on test documents.
-    DataFrame predictions = model.transform(test);
-    for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
+    Dataset<Row> predictions = model.transform(test);
+    for (Row r: predictions.select("id", "text", "probability", "prediction").collectRows()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
index da47566..e2dd759 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
@@ -24,7 +24,8 @@ import org.apache.spark.sql.SQLContext;
 // $example on$
 import org.apache.spark.ml.feature.StandardScaler;
 import org.apache.spark.ml.feature.StandardScalerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 // $example off$
 
 public class JavaStandardScalerExample {
@@ -34,7 +35,7 @@ public class JavaStandardScalerExample {
     SQLContext jsql = new SQLContext(jsc);
 
     // $example on$
-    DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
     StandardScaler scaler = new StandardScaler()
       .setInputCol("features")
@@ -46,9 +47,9 @@ public class JavaStandardScalerExample {
     StandardScalerModel scalerModel = scaler.fit(dataFrame);
 
     // Normalize each feature to have unit standard deviation.
-    DataFrame scaledData = scalerModel.transform(dataFrame);
+    Dataset<Row> scaledData = scalerModel.transform(dataFrame);
     scaledData.show();
     // $example off$
     jsc.stop();
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
index b6b201c..0ff3782 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.StopWordsRemover;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.types.DataTypes;
@@ -57,7 +57,7 @@ public class JavaStopWordsRemoverExample {
         "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
     });
 
-    DataFrame dataset = jsql.createDataFrame(rdd, schema);
+    Dataset<Row> dataset = jsql.createDataFrame(rdd, schema);
     remover.transform(dataset).show();
     // $example off$
     jsc.stop();

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
index 05d12c1..ceacbb4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.StringIndexer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.types.StructField;
@@ -54,13 +54,13 @@ public class JavaStringIndexerExample {
       createStructField("id", IntegerType, false),
       createStructField("category", StringType, false)
     });
-    DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
     StringIndexer indexer = new StringIndexer()
       .setInputCol("category")
       .setOutputCol("categoryIndex");
-    DataFrame indexed = indexer.fit(df).transform(df);
+    Dataset<Row> indexed = indexer.fit(df).transform(df);
     indexed.show();
     // $example off$
     jsc.stop();
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
index a41a5ec..fd1ce42 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
@@ -28,7 +28,7 @@
 import org.apache.spark.ml.feature.IDF;
 import org.apache.spark.ml.feature.IDFModel;
 import org.apache.spark.ml.feature.Tokenizer;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -54,19 +54,19 @@ public class JavaTfIdfExample {
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
     });
-    DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> sentenceData = sqlContext.createDataFrame(jrdd, schema);
     Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
-    DataFrame wordsData = tokenizer.transform(sentenceData);
+    Dataset<Row> wordsData = tokenizer.transform(sentenceData);
     int numFeatures = 20;
     HashingTF hashingTF = new HashingTF()
       .setInputCol("words")
       .setOutputCol("rawFeatures")
       .setNumFeatures(numFeatures);
-    DataFrame featurizedData = hashingTF.transform(wordsData);
+    Dataset<Row> featurizedData = hashingTF.transform(wordsData);
     IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
     IDFModel idfModel = idf.fit(featurizedData);
-    DataFrame rescaledData = idfModel.transform(featurizedData);
-    for (Row r : rescaledData.select("features", "label").take(3)) {
+    Dataset<Row> rescaledData = idfModel.transform(featurizedData);
+    for (Row r : rescaledData.select("features", "label").takeRows(3)) {
       Vector features = r.getAs(0);
       Double label = r.getDouble(1);
       System.out.println(features);

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
index 617dc3f..a2f8c43 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
@@ -27,7 +27,7 @@ import java.util.Arrays;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.RegexTokenizer;
 import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.types.DataTypes;
@@ -54,12 +54,12 @@ public class JavaTokenizerExample {
       new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
     });
 
-    DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
     Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
-    DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
-    for (Row r : wordsDataFrame.select("words", "label").take(3)) {
+    Dataset<Row> wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+    for (Row r : wordsDataFrame.select("words", "label").takeRows(3)) {
       java.util.List<String> words = r.getList(0);
       for (String word : words) System.out.print(word + " ");
       System.out.println();

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
index d433905..09bbc39 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
@@ -23,7 +23,8 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator;
 import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.ml.regression.LinearRegression;
 import org.apache.spark.ml.tuning.*;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
 /**
@@ -44,12 +45,12 @@ public class JavaTrainValidationSplitExample {
     JavaSparkContext jsc = new JavaSparkContext(conf);
     SQLContext jsql = new SQLContext(jsc);
 
-    DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
     // Prepare training and test data.
-    DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
-    DataFrame training = splits[0];
-    DataFrame test = splits[1];
+    Dataset<Row>[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+    Dataset<Row> training = splits[0];
+    Dataset<Row> test = splits[1];
 
     LinearRegression lr = new LinearRegression();

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
index 7e230b5..953ad45 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.types.*;
@@ -52,13 +52,13 @@ public class JavaVectorAssemblerExample {
     });
     Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0);
     JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
-    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+    Dataset<Row> dataset = sqlContext.createDataFrame(rdd, schema);
 
     VectorAssembler assembler = new VectorAssembler()
       .setInputCols(new String[]{"hour", "mobile", "userFeatures"})
       .setOutputCol("features");
 
-    DataFrame output = assembler.transform(dataset);
+    Dataset<Row> output = assembler.transform(dataset);
     System.out.println(output.select("features", "clicked").first());
     // $example off$
     jsc.stop();

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
index 545758e..b3b5953 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
@@ -26,7 +26,8 @@ import java.util.Map;
 
 import org.apache.spark.ml.feature.VectorIndexer;
 import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 // $example off$
 
 public class JavaVectorIndexerExample {
@@ -36,7 +37,7 @@ public class JavaVectorIndexerExample {
     SQLContext jsql = new SQLContext(jsc);
 
     // $example on$
-    DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
 
     VectorIndexer indexer = new VectorIndexer()
       .setInputCol("features")
@@ -53,9 +54,9 @@ public class JavaVectorIndexerExample {
     System.out.println();
 
     // Create new column "indexed" with categorical values transformed to indices
-    DataFrame indexedData = indexerModel.transform(data);
+    Dataset<Row> indexedData = indexerModel.transform(data);
     indexedData.show();
     // $example off$
     jsc.stop();
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
index 4d5cb04..2ae57c3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
@@ -30,7 +30,7 @@ import org.apache.spark.ml.attribute.AttributeGroup;
 import org.apache.spark.ml.attribute.NumericAttribute;
 import org.apache.spark.ml.feature.VectorSlicer;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.types.*;
@@ -55,7 +55,8 @@ public class JavaVectorSlicerExample {
       RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
     ));
 
-    DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+    Dataset<Row> dataset =
+      jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
 
     VectorSlicer vectorSlicer = new VectorSlicer()
       .setInputCol("userFeatures").setOutputCol("features");
@@ -63,7 +64,7 @@ public class JavaVectorSlicerExample {
     vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
     // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"})
 
-    DataFrame output = vectorSlicer.transform(dataset);
+    Dataset<Row> output = vectorSlicer.transform(dataset);
     System.out.println(output.select("userFeatures", "features").first());
     // $example off$

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
index a4a05af..2dce8c2 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
@@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.feature.Word2Vec;
 import org.apache.spark.ml.feature.Word2VecModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -49,7 +49,7 @@ public class JavaWord2VecExample {
     StructType schema = new StructType(new StructField[]{
       new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
     });
-    DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
+    Dataset<Row> documentDF = sqlContext.createDataFrame(jrdd, schema);
 
     // Learn a mapping from words to Vectors.
     Word2Vec word2Vec = new Word2Vec()
@@ -58,8 +58,8 @@ public class JavaWord2VecExample {
       .setVectorSize(3)
       .setMinCount(0);
     Word2VecModel model = word2Vec.fit(documentDF);
-    DataFrame result = model.transform(documentDF);
-    for (Row r : result.select("result").take(3)) {
+    Dataset<Row> result = model.transform(documentDF);
+    for (Row r : result.select("result").takeRows(3)) {
       System.out.println(r);
     }
     // $example off$

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index afee279..354a530 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
@@ -74,11 +74,12 @@ public class JavaSparkSQL {
     });
 
     // Apply a schema to an RDD of Java Beans and register it as a table.
-    DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class);
+    Dataset<Row> schemaPeople = sqlContext.createDataFrame(people, Person.class);
     schemaPeople.registerTempTable("people");
 
     // SQL can be run over RDDs that have been registered as tables.
-    DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+    Dataset<Row> teenagers =
+      sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
 
     // The results of SQL queries are DataFrames and support all the normal RDD operations.
     // The columns of a row in the result can be accessed by ordinal.
@@ -99,11 +100,11 @@ public class JavaSparkSQL {
     // Read in the parquet file created above.
     // Parquet files are self-describing so the schema is preserved.
     // The result of loading a parquet file is also a DataFrame.
-    DataFrame parquetFile = sqlContext.read().parquet("people.parquet");
+    Dataset<Row> parquetFile = sqlContext.read().parquet("people.parquet");
 
     //Parquet files can also be registered as tables and then used in SQL statements.
     parquetFile.registerTempTable("parquetFile");
-    DataFrame teenagers2 =
+    Dataset<Row> teenagers2 =
       sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
     teenagerNames = teenagers2.toJavaRDD().map(new Function<Row, String>() {
       @Override
@@ -120,7 +121,7 @@ public class JavaSparkSQL {
     // The path can be either a single text file or a directory storing text files.
     String path = "examples/src/main/resources/people.json";
     // Create a DataFrame from the file(s) pointed by path
-    DataFrame peopleFromJsonFile = sqlContext.read().json(path);
+    Dataset<Row> peopleFromJsonFile = sqlContext.read().json(path);
 
     // Because the schema of a JSON dataset is automatically inferred, to write queries,
     // it is better to take a look at what is the schema.
@@ -134,7 +135,8 @@ public class JavaSparkSQL {
     peopleFromJsonFile.registerTempTable("people");
 
     // SQL statements can be run by using the sql methods provided by sqlContext.
-    DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+    Dataset<Row> teenagers3 =
+      sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
 
     // The results of SQL queries are DataFrame and support all the normal RDD operations.
     // The columns of a row in the result can be accessed by ordinal.
@@ -151,7 +153,7 @@ public class JavaSparkSQL {
     List<String> jsonData = Arrays.asList(
           "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
     JavaRDD<String> anotherPeopleRDD = ctx.parallelize(jsonData);
-    DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd());
+    Dataset<Row> peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd());
 
     // Take a look at the schema of this new DataFrame.
     peopleFromJsonRDD.printSchema();
@@ -164,7 +166,7 @@ public class JavaSparkSQL {
 
     peopleFromJsonRDD.registerTempTable("people2");
 
-    DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2");
+    Dataset<Row> peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2");
     List<String> nameAndCity = peopleWithCity.toJavaRDD().map(new Function<Row, String>() {
       @Override
       public String call(Row row) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
index f0228f5..4b9d9ef 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
@@ -27,8 +27,9 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.VoidFunction2;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.DataFrame;
 import org.apache.spark.api.java.StorageLevels;
 import org.apache.spark.streaming.Durations;
 import org.apache.spark.streaming.Time;
@@ -92,13 +93,13 @@ public final class JavaSqlNetworkWordCount {
             return record;
           }
         });
-        DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class);
+        Dataset<Row> wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class);
 
         // Register as table
         wordsDataFrame.registerTempTable("words");
 
         // Do word count on table using SQL and print it
-        DataFrame wordCountsDataFrame =
+        Dataset<Row> wordCountsDataFrame =
             sqlContext.sql("select word, count(*) as total from words group by word");
         System.out.println("========= " + time + "=========");
         wordCountsDataFrame.show();

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 0a8c9e5..60a4a1d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml;
 
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -26,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
@@ -37,7 +38,7 @@ public class JavaPipelineSuite {
 
   private transient JavaSparkContext jsc;
   private transient SQLContext jsql;
-  private transient DataFrame dataset;
+  private transient Dataset<Row> dataset;
 
   @Before
   public void setUp() {
@@ -65,7 +66,7 @@ public class JavaPipelineSuite {
       .setStages(new PipelineStage[] {scaler, lr});
     PipelineModel model = pipeline.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+    Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
     predictions.collectAsList();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 40b9c35..0d923df 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,6 +21,8 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -30,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
 
 public class JavaDecisionTreeClassifierSuite implements Serializable {
 
@@ -57,7 +58,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
     JavaRDD<LabeledPoint> data = sc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
-    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+    Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
 
     // This tests setters. Training with various options is tested in Scala.
     DecisionTreeClassifier dt = new DecisionTreeClassifier()

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 59b6fba..f470f4a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 
 public class JavaGBTClassifierSuite implements Serializable {
 
@@ -57,7 +58,7 @@ public class JavaGBTClassifierSuite implements Serializable {
     JavaRDD<LabeledPoint> data = sc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
-    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+    Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
 
     // This tests setters. Training with various options is tested in Scala.
     GBTClassifier rf = new GBTClassifier()

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fd22eb6..536f0dc 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
 
 
 public class JavaLogisticRegressionSuite implements Serializable {
 
   private transient JavaSparkContext jsc;
   private transient SQLContext jsql;
-  private transient DataFrame dataset;
+  private transient Dataset<Row> dataset;
 
   private transient JavaRDD<LabeledPoint> datasetRDD;
   private double eps = 1e-5;
@@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
     Assert.assertEquals(lr.getLabelCol(), "label");
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+    Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
     predictions.collectAsList();
     // Check defaults
     Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable {
     // Modify model params, and check that the params worked.
     model.setThreshold(1.0);
     model.transform(dataset).registerTempTable("predAllZero");
-    DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+    Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
     for (Row r: predAllZero.collectAsList()) {
       Assert.assertEquals(0.0, r.getDouble(0), eps);
     }
     // Call transform with params, and check that the params worked.
     model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
-    DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+    Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
     boolean foundNonZero = false;
     for (Row r: predNotAllZero.collectAsList()) {
       if (r.getDouble(0) != 0.0) foundNonZero = true;
@@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
     Assert.assertEquals(2, model.numClasses());
 
     model.transform(dataset).registerTempTable("transformed");
-    DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
-    for (Row row: trans1.collect()) {
+    Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+    for (Row row: trans1.collectAsList()) {
       Vector raw = (Vector)row.get(0);
       Vector prob = (Vector)row.get(1);
       Assert.assertEquals(raw.size(), 2);
@@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
       Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
     }
 
-    DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
-    for (Row row: trans2.collect()) {
+    Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+    for (Row row: trans2.collectAsList()) {
       double pred = row.getDouble(0);
       Vector prob = (Vector)row.get(1);
       double probOfPred = prob.apply((int)pred);

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index ec6b4bf..d499d36 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -28,7 +29,7 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
@@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
 
   @Test
   public void testMLPC() {
-    DataFrame dataFrame = sqlContext.createDataFrame(
+    Dataset<Row> dataFrame = sqlContext.createDataFrame(
       jsc.parallelize(Arrays.asList(
         new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
         new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
@@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
       .setSeed(11L)
       .setMaxIter(100);
     MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
-    DataFrame result = model.transform(dataFrame);
-    Row[] predictionAndLabels = result.select("prediction", "label").collect();
+    Dataset<Row> result = model.transform(dataFrame);
+    List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
     for (Row r: predictionAndLabels) {
       Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 07936eb..45101f2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable {
     jsc = null;
   }
 
-  public void validatePrediction(DataFrame predictionAndLabels) {
-    for (Row r : predictionAndLabels.collect()) {
+  public void validatePrediction(Dataset<Row> predictionAndLabels) {
+    for (Row r : predictionAndLabels.collectAsList()) {
       double prediction = r.getAs(0);
       double label = r.getAs(1);
       assertEquals(label, prediction, 1E-5);
@@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable {
       new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
 
-    DataFrame dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
     NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
 
     NaiveBayesModel model = nb.fit(dataset);
-    DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+    Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label");
     validatePrediction(predictionAndLabels);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index cbabafe..d493a7f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
 import java.util.List;
 
+import org.apache.spark.sql.Row;
 import scala.collection.JavaConverters;
 
 import org.junit.After;
@@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.SQLContext;
 
 public class JavaOneVsRestSuite implements Serializable {
 
   private transient JavaSparkContext jsc;
   private transient SQLContext jsql;
-  private transient DataFrame dataset;
+  private transient Dataset<Row> dataset;
   private transient JavaRDD<LabeledPoint> datasetRDD;
 
   @Before
@@ -75,7 +76,7 @@ public class JavaOneVsRestSuite implements Serializable {
     Assert.assertEquals(ova.getLabelCol() , "label");
     Assert.assertEquals(ova.getPredictionCol() , "prediction");
     OneVsRestModel ovaModel = ova.fit(dataset);
-    DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+    Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
     predictions.collectAsList();
     Assert.assertEquals(ovaModel.getLabelCol(), "label");
     Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 5485fcb..9a63cef 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.ml.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 
 public class JavaRandomForestClassifierSuite implements Serializable {
 
@@ -58,7 +59,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
     JavaRDD<LabeledPoint> data = sc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
-    DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+    Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
 
     // This tests setters. Training with various options is tested in Scala.
     RandomForestClassifier rf = new RandomForestClassifier()

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index cc5a4ef..a3fcdb5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
 public class JavaKMeansSuite implements Serializable {
 
   private transient int k = 5;
   private transient JavaSparkContext sc;
-  private transient DataFrame dataset;
+  private transient Dataset<Row> dataset;
   private transient SQLContext sql;
 
   @Before
@@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
     Vector[] centers = model.clusterCenters();
     assertEquals(k, centers.length);
 
-    DataFrame transformed = model.transform(dataset);
+    Dataset<Row> transformed = model.transform(dataset);
     List<String> columns = Arrays.asList(transformed.columns());
     List<String> expectedColumns = Arrays.asList("features", "prediction");
     for (String column: expectedColumns) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index d707bde..77e3a48 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -18,6 +18,7 @@ package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -25,7 +26,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaBucketizerSuite {
     StructType schema = new StructType(new StructField[] {
       new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(
+    Dataset<Row> dataset = jsql.createDataFrame(
       Arrays.asList(
         RowFactory.create(-0.5),
         RowFactory.create(-0.3),
@@ -70,7 +71,7 @@ public class JavaBucketizerSuite {
       .setOutputCol("result")
       .setSplits(splits);
 
-    Row[] result = bucketizer.transform(dataset).select("result").collect();
+    List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
 
     for (Row r : result) {
       double index = r.getDouble(0);

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 63e5c93..ed1ad4c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -18,6 +18,7 @@ package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
 import org.junit.After;
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -56,7 +57,7 @@ public class JavaDCTSuite {
   @Test
   public void javaCompatibilityTest() {
     double[] input = new double[] {1D, 2D, 3D, 4D};
-    DataFrame dataset = jsql.createDataFrame(
+    Dataset<Row> dataset = jsql.createDataFrame(
       Arrays.asList(RowFactory.create(Vectors.dense(input))),
       new StructType(new StructField[]{
         new StructField("vec", (new VectorUDT()), false, Metadata.empty())
@@ -69,8 +70,8 @@ public class JavaDCTSuite {
       .setInputCol("vec")
       .setOutputCol("resultVec");
 
-    Row[] result = dct.transform(dataset).select("resultVec").collect();
-    Vector resultVec = result[0].getAs("resultVec");
+    List<Row> result = dct.transform(dataset).select("resultVec").collectAsList();
+    Vector resultVec = result.get(0).getAs("resultVec");
 
     Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 5932017..6e2cc7e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -65,21 +65,21 @@ public class JavaHashingTFSuite {
       new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
     });
 
-    DataFrame sentenceData = jsql.createDataFrame(data, schema);
+    Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
     Tokenizer tokenizer = new Tokenizer()
       .setInputCol("sentence")
       .setOutputCol("words");
-    DataFrame wordsData = tokenizer.transform(sentenceData);
+    Dataset<Row> wordsData = tokenizer.transform(sentenceData);
     int numFeatures = 20;
     HashingTF hashingTF = new HashingTF()
       .setInputCol("words")
       .setOutputCol("rawFeatures")
       .setNumFeatures(numFeatures);
-    DataFrame featurizedData = hashingTF.transform(wordsData);
+    Dataset<Row> featurizedData = hashingTF.transform(wordsData);
     IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
     IDFModel idfModel = idf.fit(featurizedData);
-    DataFrame rescaledData = idfModel.transform(featurizedData);
-    for (Row r : rescaledData.select("features", "label").take(3)) {
+    Dataset<Row> rescaledData = idfModel.transform(featurizedData);
+    for (Row r : rescaledData.select("features", "label").takeAsList(3)) {
       Vector features = r.getAs(0);
       Assert.assertEquals(features.size(), numFeatures);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index e17d549..5bbd963 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
 public class JavaNormalizerSuite {
@@ -53,17 +54,17 @@ public class JavaNormalizerSuite {
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
     ));
-    DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
+    Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
     Normalizer normalizer = new Normalizer()
       .setInputCol("features")
       .setOutputCol("normFeatures");
 
     // Normalize each Vector using $L^2$ norm.
-    DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
+    Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
     l2NormData.count();
 
     // Normalize each Vector using $L^\infty$ norm.
-    DataFrame lInfNormData =
+    Dataset<Row> lInfNormData =
       normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
     lInfNormData.count();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index e8f329f..1389d17 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix;
 import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
@@ -100,7 +100,7 @@ public class JavaPCASuite implements Serializable {
       }
     );
 
-    DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+    Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
     PCAModel pca = new PCA()
       .setInputCol("features")
       .setOutputCol("pca_features")

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index e22d117..6a8bb64 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -77,11 +77,11 @@ public class JavaPolynomialExpansionSuite {
       new StructField("expected", new VectorUDT(), false, Metadata.empty())
     });
 
-    DataFrame dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
 
-    Row[] pairs = polyExpansion.transform(dataset)
+    List<Row> pairs = polyExpansion.transform(dataset)
       .select("polyFeatures", "expected")
-      .collect();
+      .collectAsList();
 
     for (Row r : pairs) {
       double[] polyFeatures = ((Vector)r.get(0)).toArray();

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index ed74363..3f6fc33 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
 public class JavaStandardScalerSuite {
@@ -53,7 +54,7 @@ public class JavaStandardScalerSuite {
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
     );
-    DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+    Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
       VectorIndexerSuite.FeatureData.class);
     StandardScaler scaler = new StandardScaler()
       .setInputCol("features")
@@ -65,7 +66,7 @@ public class JavaStandardScalerSuite {
     StandardScalerModel scalerModel = scaler.fit(dataFrame);
 
     // Normalize each feature to have unit standard deviation.
-    DataFrame scaledData = scalerModel.transform(dataFrame);
+    Dataset<Row> scaledData = scalerModel.transform(dataFrame);
     scaledData.count();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 139d1d0..5812037 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -25,7 +25,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -65,7 +65,7 @@ public class JavaStopWordsRemoverSuite {
     StructType schema = new StructType(new StructField[] {
       new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
 
     remover.transform(dataset).collect();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 153a08a..431779c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -26,7 +26,7 @@ import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -58,16 +58,16 @@ public class JavaStringIndexerSuite {
     });
     List<Row> data = Arrays.asList(
       cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
-    DataFrame dataset = sqlContext.createDataFrame(data, schema);
+    Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
 
     StringIndexer indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex");
-    DataFrame output = indexer.fit(dataset).transform(dataset);
+    Dataset<Row> output = indexer.fit(dataset).transform(dataset);
 
-    Assert.assertArrayEquals(
-      new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) },
-      output.orderBy("id").select("id", "labelIndex").collect());
+    Assert.assertEquals(
+      Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)),
+      output.orderBy("id").select("id", "labelIndex").collectAsList());
   }
 
   /** An alias for RowFactory.create. */

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index c407d98..83d16cb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -18,6 +18,7 @@ package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -26,7 +27,7 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
@@ -61,11 +62,11 @@ public class JavaTokenizerSuite {
       new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
       new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
     ));
-    DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+    Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
 
-    Row[] pairs = myRegExTokenizer.transform(dataset)
+    List<Row> pairs = myRegExTokenizer.transform(dataset)
       .select("tokens", "wantedTokens")
-      .collect();
+      .collectAsList();
 
     for (Row r : pairs) {
       Assert.assertEquals(r.get(0), r.get(1));

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index f8ba84e..e45e198 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -64,11 +64,11 @@ public class JavaVectorAssemblerSuite {
     Row row = RowFactory.create(
       0, 0.0, Vectors.dense(1.0, 2.0), "a",
       Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
-    DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+    Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
     VectorAssembler assembler = new VectorAssembler()
       .setInputCols(new String[] {"x", "y", "z", "n"})
       .setOutputCol("features");
-    DataFrame output =
assembler.transform(dataset); + Dataset<Row> output = assembler.transform(dataset); Assert.assertEquals( Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), output.select("features").first().<Vector>getAs(0)); http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index bfcca62..fec6cac 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -30,7 +30,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -57,7 +58,7 @@ public class JavaVectorIndexerSuite implements Serializable { new FeatureData(Vectors.dense(1.0, 4.0)) ); SQLContext sqlContext = new SQLContext(sc); - DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -66,6 +67,6 @@ public class JavaVectorIndexerSuite implements Serializable { Assert.assertEquals(model.numFeatures(), 2); Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps(); Assert.assertEquals(categoryMaps.size(), 1); - DataFrame indexedData = model.transform(data); + Dataset<Row> indexedData = model.transform(data); } } http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 786c11c..b87605e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite { RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) ); - DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + Dataset<Row> dataset = + jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - DataFrame output = vectorSlicer.transform(dataset); + Dataset<Row> output = vectorSlicer.transform(dataset); - for (Row r : 
output.select("userFeatures", "features").take(2)) { + for (Row r : output.select("userFeatures", "features").takeRows(2)) { Vector features = r.getAs(1); Assert.assertEquals(features.size(), 2); } http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index b292b1b..7517b70 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -26,7 +26,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -53,7 +53,7 @@ public class JavaWord2VecSuite { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame( + Dataset<Row> documentDF = sqlContext.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -66,9 +66,9 @@ public class JavaWord2VecSuite { .setVectorSize(3) .setMinCount(0); Word2VecModel model = word2Vec.fit(documentDF); - DataFrame result = model.transform(documentDF); + Dataset<Row> result = model.transform(documentDF); - for (Row r: result.select("result").collect()) { + for (Row r: result.select("result").collectAsList()) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index d5c9d12..a157530 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. 
DecisionTreeRegressor dt = new DecisionTreeRegressor() http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 38d15dc..9477e8d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTRegressorSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaGBTRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); GBTRegressor rf = new GBTRegressor() .setMaxDepth(2) http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 4fb0b0d..9f81751 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -28,7 +28,8 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite .generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @Before @@ -64,7 +65,7 @@ public class JavaLinearRegressionSuite implements Serializable { assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java 
---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 31be888..a90535d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -31,7 +31,8 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestRegressorSuite implements Serializable { @@ -58,7 +59,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. RandomForestRegressor rf = new RandomForestRegressor() http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 2976b38..b8ddf90 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -31,7 +31,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; @@ -68,7 +68,7 @@ public class JavaLibSVMRelationSuite { @Test public void verifyLibSVMDF() { - DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 08eeca5..24b0097 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import 
org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; @Before public void setUp() { http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/python/pyspark/mllib/common.py ---------------------------------------------------------------------- diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 9fda1b1..6bc2b1e 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -101,7 +101,7 @@ def _java2py(sc, r, encoding="bytes"): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) - if clsName == 'DataFrame': + if clsName == 'Dataset': return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 97f28fa..d2003fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan // TODO: don't swallow original stack trace if it exists @@ -30,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, - val startPosition: Option[Int] = None) + val startPosition: Option[Int] = None, + val plan: Option[LogicalPlan] = None) extends Exception with Serializable { def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d8f755a..902644e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -50,7 +50,9 @@ object RowEncoder { inputObject: Expression, inputType: DataType): Expression = inputType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => inputObject + FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + + case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType) case udt: UserDefinedType[_] => val obj = NewInstance( @@ -137,6 +139,7 @@ object RowEncoder { private def externalDataTypeFor(dt: DataType): DataType = dt match { 
case _ if ScalaReflection.isNativeType(dt) => dt + case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -150,19 +153,23 @@ object RowEncoder { private def constructorFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val field = BoundReference(i, f.dataType, f.nullable) + val dt = f.dataType match { + case p: PythonUserDefinedType => p.sqlType + case other => other + } + val field = BoundReference(i, dt, f.nullable) If( IsNull(field), - Literal.create(null, externalDataTypeFor(f.dataType)), + Literal.create(null, externalDataTypeFor(dt)), constructorFor(field) ) } - CreateExternalRow(fields) + CreateExternalRow(fields, schema) } private def constructorFor(input: Expression): Expression = input.dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => input + FloatType | DoubleType | BinaryType | CalendarIntervalType => input case udt: UserDefinedType[_] => val obj = NewInstance( @@ -216,7 +223,7 @@ object RowEncoder { "toScalaMap", keyData :: valueData :: Nil) - case StructType(fields) => + case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), @@ -225,6 +232,6 @@ object RowEncoder { } If(IsNull(input), Literal.create(null, externalDataTypeFor(input.dataType)), - CreateExternalRow(convertedFields)) + CreateExternalRow(convertedFields, schema)) } }
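Taken together, the Java-side hunks above apply one mechanical substitution: the import of org.apache.spark.sql.DataFrame becomes org.apache.spark.sql.Dataset (plus org.apache.spark.sql.Row where it was not already imported), every DataFrame declaration becomes Dataset<Row>, and the Row[]-returning actions collect()/take(n) become either collectAsList()/takeAsList(n), which return java.util.List<Row>, or, in a couple of hunks, the Row[]-returning takeRows(n). The sketch below shows the post-migration shape as one self-contained program; it is illustrative only, and the class name MigrationSketch, the "text" column, the local master, and the sample rows are assumptions of this sketch, not part of the patch.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;   // replaces org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MigrationSketch {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local", "MigrationSketch");
    SQLContext jsql = new SQLContext(jsc);

    // A one-column schema, built the same way as in the test suites above.
    StructType schema = new StructType(new StructField[] {
      new StructField("text", DataTypes.StringType, false, Metadata.empty())
    });
    List<Row> data = Arrays.asList(RowFactory.create("a b c"), RowFactory.create("d e"));

    // createDataFrame now yields Dataset<Row> rather than DataFrame.
    Dataset<Row> dataset = jsql.createDataFrame(data, schema);

    // collect() used to return Row[]; collectAsList() returns List<Row>.
    for (Row r : dataset.select("text").collectAsList()) {
      System.out.println(r.getString(0));
    }

    jsc.stop();
  }
}

Note the knock-on effect on assertions: JavaStringIndexerSuite above replaces Assert.assertArrayEquals over a Row[] from collect() with Assert.assertEquals comparing Arrays.asList(...) against the List<Row> from collectAsList().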