Github user jaceklaskowski commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22328#discussion_r215038606
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala ---
    @@ -0,0 +1,119 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.source.image
    +
    +import java.nio.file.Paths
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.image.ImageSchema._
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.Row
    +import org.apache.spark.sql.functions.{col, substring_index}
    +
    +class ImageFileFormatSuite extends SparkFunSuite with 
MLlibTestSparkContext {
    +
    +  // Single column of images named "image"
    +  private lazy val imagePath = "../data/mllib/images/imagesWithPartitions"
    +
    +  test("image datasource count test") {
    +    val df1 = spark.read.format("image").load(imagePath)
    +    assert(df1.count === 9)
    +
    +    val df2 = spark.read.format("image").option("dropImageFailures", 
"true").load(imagePath)
    +    assert(df2.count === 8)
    +  }
    +
    +  test("image datasource test: read jpg image") {
    +    val df = spark.read.format("image").load(imagePath + 
"/cls=kittens/date=2018-02/DP153539.jpg")
    +    assert(df.count() === 1)
    +  }
    +
    +  test("image datasource test: read png image") {
    +    val df = spark.read.format("image").load(imagePath + 
"/cls=multichannel/date=2018-01/BGRA.png")
    +    assert(df.count() === 1)
    +  }
    +
    +  test("image datasource test: read non image") {
    +    val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt"
    +    val df = spark.read.format("image").option("dropImageFailures", "true")
    +      .load(filePath)
    +    assert(df.count() === 0)
    +
    +    val df2 = spark.read.format("image").option("dropImageFailures", 
"false")
    +      .load(filePath)
    +    assert(df2.count() === 1)
    +    val result = df2.head()
    +    assert(result === invalidImageRow(
    +      Paths.get(filePath).toAbsolutePath().normalize().toUri().toString))
    +  }
    +
    +  test("image datasource partition test") {
    +    val result = spark.read.format("image")
    +      .option("dropImageFailures", "true").load(imagePath)
    +      .select(substring_index(col("image.origin"), "/", -1).as("origin"), 
col("cls"), col("date"))
    +      .collect()
    +
    +    assert(Set(result: _*) === Set(
    +      Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"),
    +      Row("54893.jpg", "kittens", "2018-02"),
    +      Row("DP153539.jpg", "kittens", "2018-02"),
    +      Row("DP802813.jpg", "kittens", "2018-02"),
    +      Row("BGRA.png", "multichannel", "2018-01"),
    +      Row("BGRA_alpha_60.png", "multichannel", "2018-01"),
    +      Row("chr30.4.184.jpg", "multichannel", "2018-02"),
    +      Row("grayscale.jpg", "multichannel", "2018-02")
    +    ))
    +  }
    +
    +  // Images with the different number of channels
    +  test("readImages pixel values test") {
    +
    +    val images = spark.read.format("image").option("dropImageFailures", 
"true")
    +      .load(imagePath + "/cls=multichannel/").collect()
    +
    +    val firstBytes20Map = images.map { rrow =>
    +      val row = rrow.getAs[Row]("image")
    +      val filename = Paths.get(getOrigin(row)).getFileName().toString()
    +      val mode = getMode(row)
    +      val bytes20 = getData(row).slice(0, 20).toList
    +      filename -> Tuple2(mode, bytes20)
    --- End diff --
    
    Why is the explicit `Tuple2` required here? Wouldn't the plain tuple literal, i.e. `filename -> (mode, bytes20)`, work just as well?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to