Repository: spark Updated Branches: refs/heads/master e5bb26174 -> d451b7f43
[SPARK-21326][SPARK-21066][ML] Use TextFileFormat in LibSVMFileFormat and allow multiple input paths for determining numFeatures ## What changes were proposed in this pull request? This is related to [SPARK-19918](https://issues.apache.org/jira/browse/SPARK-19918) and [SPARK-18362](https://issues.apache.org/jira/browse/SPARK-18362). This PR proposes to use `TextFileFormat` and to allow multiple input paths (with a warning) when determining the number of features in the LibSVM data source via an extra scan. There are three points here: - The main advantage of this change should be removing file-listing bottlenecks on the driver side. - Another advantage comes from using `FileScanRDD`: for example, the `spark.sql.files.ignoreCorruptFiles` option should now apply when determining the number of features. - It lets us unify the schema inference code path across text-based data sources. This also prepares for [SPARK-21289](https://issues.apache.org/jira/browse/SPARK-21289). ## How was this patch tested? Unit tests in `LibSVMRelationSuite`. Closes #18288 Author: hyukjinkwon <gurwls...@gmail.com> Closes #18556 from HyukjinKwon/libsvm-schema.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d451b7f4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d451b7f4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d451b7f4 Branch: refs/heads/master Commit: d451b7f43d559aa1efd7ac3d1cbec5249f3a7a24 Parents: e5bb261 Author: hyukjinkwon <gurwls...@gmail.com> Authored: Fri Jul 7 12:24:03 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Jul 7 12:24:03 2017 +0800 ---------------------------------------------------------------------- .../spark/ml/source/libsvm/LibSVMRelation.scala | 26 ++++++++++---------- .../org/apache/spark/mllib/util/MLUtils.scala | 25 +++++++++++++++++-- .../ml/source/libsvm/LibSVMRelationSuite.scala | 17 ++++++++++--- 3 files changed, 49 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d451b7f4/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index f68847a..dec1183 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.spark.internal.Logging import org.apache.spark.TaskContext import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vectors, VectorUDT} @@ -66,7 +67,10 @@ private[libsvm] class LibSVMOutputWriter( /** @see [[LibSVMDataSource]] for public documentation. 
*/ // If this is moved or renamed, please update DataSource's backwardCompatibilityMap. -private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { +private[libsvm] class LibSVMFileFormat + extends TextBasedFileFormat + with DataSourceRegister + with Logging { override def shortName(): String = "libsvm" @@ -89,18 +93,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour files: Seq[FileStatus]): Option[StructType] = { val libSVMOptions = new LibSVMOptions(options) val numFeatures: Int = libSVMOptions.numFeatures.getOrElse { - // Infers number of features if the user doesn't specify (a valid) one. - val dataFiles = files.filterNot(_.getPath.getName startsWith "_") - val path = if (dataFiles.length == 1) { - dataFiles.head.getPath.toUri.toString - } else if (dataFiles.isEmpty) { - throw new IOException("No input path specified for libsvm data") - } else { - throw new IOException("Multiple input paths are not supported for libsvm data.") - } - - val sc = sparkSession.sparkContext - val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism) + require(files.nonEmpty, "No input path specified for libsvm data") + logWarning( + "'numFeatures' option not specified, determining the number of features by going " + + "though the input. 
If you know the number in advance, please specify it via " + + "'numFeatures' option to avoid the extra scan.") + + val paths = files.map(_.getPath.toUri.toString) + val parsed = MLUtils.parseLibSVMFile(sparkSession, paths) MLUtils.computeNumFeatures(parsed) } http://git-wip-us.apache.org/repos/asf/spark/blob/d451b7f4/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4fdad05..14af8b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -28,8 +28,10 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler @@ -102,6 +104,25 @@ object MLUtils extends Logging { .map(parseLibSVMRecord) } + private[spark] def parseLibSVMFile( + sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = { + val lines = sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value") + + import lines.sqlContext.implicits._ + + lines.select(trim($"value").as("line")) + .filter(not((length($"line") === 
0).or($"line".startsWith("#")))) + .as[String] + .rdd + .map(MLUtils.parseLibSVMRecord) + } + private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = { val items = line.split(' ') val label = items.head.toDouble http://git-wip-us.apache.org/repos/asf/spark/blob/d451b7f4/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e164d27..a67e49d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -35,15 +35,22 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { override def beforeAll(): Unit = { super.beforeAll() - val lines = + val lines0 = """ |1 1:1.0 3:2.0 5:3.0 |0 + """.stripMargin + val lines1 = + """ |0 2:4.0 4:5.0 6:6.0 """.stripMargin val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") - val file = new File(dir, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + val succ = new File(dir, "_SUCCESS") + val file0 = new File(dir, "part-00000") + val file1 = new File(dir, "part-00001") + Files.write("", succ, StandardCharsets.UTF_8) + Files.write(lines0, file0, StandardCharsets.UTF_8) + Files.write(lines1, file1, StandardCharsets.UTF_8) path = dir.toURI.toString } @@ -145,7 +152,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("create libsvmTable table without schema and path") { try { - val e = intercept[IOException](spark.sql("CREATE TABLE libsvmTable USING libsvm")) + val e = intercept[IllegalArgumentException] { + spark.sql("CREATE TABLE libsvmTable USING libsvm") + } assert(e.getMessage.contains("No input path specified for 
libsvm data")) } finally { spark.sql("DROP TABLE IF EXISTS libsvmTable") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org