Github user imatiach-msft commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19439#discussion_r144191899
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala 
---
    @@ -0,0 +1,229 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.image
    +
    +import java.awt.Color
    +import java.awt.color.ColorSpace
    +import java.io.ByteArrayInputStream
    +import javax.imageio.ImageIO
    +
    +import org.apache.spark.annotation.{Experimental, Since}
    +import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    +import org.apache.spark.sql.types._
    +
    +@Experimental
    +@Since("2.3.0")
    +object ImageSchema {
    +
    +  val undefinedImageType = "Undefined"
    +
    +  val ocvTypes = Map(
    +    undefinedImageType -> -1,
    +    "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC2" -> 8, "CV_8UC3" -> 16, 
"CV_8UC4" -> 24,
    +    "CV_8S" -> 1, "CV_8SC1" -> 1, "CV_8SC2" -> 9, "CV_8SC3" -> 17, 
"CV_8SC4" -> 25,
    +    "CV_16U" -> 2, "CV_16UC1" -> 2, "CV_16UC2" -> 10, "CV_16UC3" -> 18, 
"CV_16UC4" -> 26,
    +    "CV_16S" -> 3, "CV_16SC1" -> 3, "CV_16SC2" -> 11, "CV_16SC3" -> 19, 
"CV_16SC4" -> 27,
    +    "CV_32S" -> 4, "CV_32SC1" -> 4, "CV_32SC2" -> 12, "CV_32SC3" -> 20, 
"CV_32SC4" -> 28,
    +    "CV_32F" -> 5, "CV_32FC1" -> 5, "CV_32FC2" -> 13, "CV_32FC3" -> 21, 
"CV_32FC4" -> 29,
    +    "CV_64F" -> 6, "CV_64FC1" -> 6, "CV_64FC2" -> 14, "CV_64FC3" -> 22, 
"CV_64FC4" -> 30
    +  )
    +
    +  /**
    +   * Schema for the image column: Row(String, Int, Int, Int, Array[Byte])
    +   */
    +  val columnSchema = StructType(
    +    StructField("origin", StringType, true) ::
    +      StructField("height", IntegerType, false) ::
    +      StructField("width", IntegerType, false) ::
    +      StructField("nChannels", IntegerType, false) ::
    +      // OpenCV-compatible type: CV_8UC3 in most cases
    +      StructField("mode", StringType, false) ::
    +      // Bytes in OpenCV-compatible order: row-wise BGR in most cases
    +      StructField("data", BinaryType, false) :: Nil)
    +
    +  // Dataframe with a single column of images named "image" (nullable)
    +  private val imageDFSchema = StructType(StructField("image", 
columnSchema, true) :: Nil)
    +
    +  @Since("2.3.0")
    +  def getOrigin(row: Row): String = row.getString(0)
    +
    +  @Since("2.3.0")
    +  def getHeight(row: Row): Int = row.getInt(1)
    +
    +  @Since("2.3.0")
    +  def getWidth(row: Row): Int = row.getInt(2)
    +
    +  @Since("2.3.0")
    +  def getNChannels(row: Row): Int = row.getInt(3)
    +
    +  @Since("2.3.0")
    +  def getMode(row: Row): String = row.getString(4)
    +
    +  @Since("2.3.0")
    +  def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5)
    +
    +  /**
    +   * Check if the dataframe column contains images (i.e. has ImageSchema)
    +   *
    +   * @param df       Dataframe
    +   * @param column   Column name
    +   * @return         True if the given column matches the image schema
    +   */
    +  @Since("2.3.0")
    +  def isImageColumn(df: DataFrame, column: String): Boolean =
    +    df.schema(column).dataType == columnSchema
    +
    +  /**
    +   * Default values for the invalid image
    +   *
    +   * @param origin Origin of the invalid image
    +   * @return       Row with the default values
    +   */
    +  private def invalidImageRow(origin: String): Row = Row(Row(origin, -1, 
-1, -1, undefinedImageType,
    +    Array.ofDim[Byte](0)))
    +
    +  /**
    +   * Convert the compressed image (jpeg, png, etc.) into OpenCV
    +   * representation and store it in dataframe Row
    +   *
    +   * @param origin   Arbitrary string that identifies the image
    +   * @param bytes    Image bytes (for example, jpeg)
    +   * @return         Dataframe Row or None (if the decompression fails)
    +   */
    +  private[spark] def decode(origin: String, bytes: Array[Byte]): 
Option[Row] = {
    +
    +    val img = ImageIO.read(new ByteArrayInputStream(bytes))
    +
    +    if (img == null) {
    +      None
    +    } else {
    +      val is_gray = img.getColorModel.getColorSpace.getType == 
ColorSpace.TYPE_GRAY
    +      val has_alpha = img.getColorModel.hasAlpha
    +
    +      val height = img.getHeight
    +      val width = img.getWidth
    +      val (nChannels, mode) = if (is_gray) {
    +        (1, "CV_8UC1")
    +      } else if (has_alpha) {
    +        (4, "CV_8UC4")
    +      } else {
    +        (3, "CV_8UC3")
    +      }
    +
    +      assert(height * width * nChannels < 1e9, "image is too large")
    +      val decoded = Array.ofDim[Byte](height*width*nChannels)
    +
    +      // Grayscale images in Java require special handling to get the 
correct intensity
    +      if (is_gray) {
    +        var offset = 0
    +        val raster = img.getRaster
    +        for(h <- 0 until height) {
    +          for (w <- 0 until width) {
    +            decoded(offset) = raster.getSample(w, h, 0).toByte
    +            offset += 1
    +          }
    +        }
    +      }
    +      else {
    +        var offset = 0
    +        for (h <- 0 until height) {
    +          for (w <- 0 until width) {
    +            val color = new Color(img.getRGB(w, h))
    +
    +            decoded(offset) = color.getBlue.toByte
    +            decoded(offset + 1) = color.getGreen.toByte
    +            decoded(offset + 2) = color.getRed.toByte
    +            if (nChannels == 4) {
    +              decoded(offset + 3) = color.getAlpha.toByte
    +            }
    +            offset += nChannels
    +          }
    +        }
    +      }
    +
    +      // the internal "Row" is needed, because the image is a single 
dataframe column
    +      Some(Row(Row(origin, height, width, nChannels, mode, decoded)))
    +    }
    +  }
    +
    +  /**
    +   * :: Experimental ::
    +   * Read the directory of images from the local or remote source
    +   *
    +   * @param path               Path to the image directory
    +   * @param sparkSession       Spark Session
    +   * @param recursive          Recursive path search flag
    +   * @param numPartitions      Number of the dataframe partitions
    +   * @param dropImageFailures  Drop the files that are not valid images 
from the result
    +   * @param sampleRatio        Fraction of the files loaded
    +   * @return                   Dataframe with a single column "image" of 
images;
    +   *                           see ImageSchema for the details
    +   */
    +  @Since("2.3.0")
    +  def readImages(path: String,
    +                 sparkSession: SparkSession = null,
    +                 recursive: Boolean = false,
    +                 numPartitions: Int = 0,
    +                 dropImageFailures: Boolean = false,
    +                 sampleRatio: Double = 1.0): DataFrame = {
    +    require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be 
between 0 and 1")
    +
    +    val session = if (sparkSession != null) sparkSession else 
SparkSession.builder().getOrCreate
    +    val partitions =
    +      if (numPartitions > 0) {
    +        numPartitions
    +      } else {
    +        session.sparkContext.defaultParallelism
    +      }
    +
    +    val oldRecursiveFlag = 
RecursiveFlag.setRecursiveFlag(Some(recursive.toString), session)
    +    val oldPathFilter: Option[Class[_]] =
    +      if (sampleRatio < 1) {
    +        SamplePathFilter.setPathFilter(Some(classOf[SamplePathFilter]), 
sampleRatio, session)
    +      } else {
    +        None
    +      }
    +
    +    var result: DataFrame = null
    +    try {
    +      val streams = session.sparkContext.binaryFiles(path, partitions)
    +        .repartition(partitions)
    +
    +      val images = if (dropImageFailures) {
    +        streams.flatMap {
    +          case (origin, stream) => decode(origin, stream.toArray)
    +        }
    +      } else {
    +        streams.map {
    +          case (origin, stream) => decode(origin, 
stream.toArray).getOrElse(invalidImageRow(origin))
    +        }
    +      }
    +
    +      result = session.createDataFrame(images, imageDFSchema)
    +    }
    +    finally {
    +      // return Hadoop flags to the original values
    +      RecursiveFlag.setRecursiveFlag(oldRecursiveFlag, session)
    +      SamplePathFilter.unsetPathFilter(oldPathFilter, session)
    --- End diff --
    
    Yes, when the sample ratio is less than one we add a path filter to sample
    images. Afterwards we restore the previous path filter if there was one, or
    remove ours if there was none. Sorry, I'm a bit confused, though: why, and how,
    could we avoid unsetting the path filter? Or did you mean that we shouldn't be
    using a path filter?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to