Github user imatiach-msft commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19439#discussion_r144191899

--- Diff: mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala ---
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.image
+
+import java.awt.Color
+import java.awt.color.ColorSpace
+import java.io.ByteArrayInputStream
+import javax.imageio.ImageIO
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types._
+
+@Experimental
+@Since("2.3.0")
+object ImageSchema {
+
+  val undefinedImageType = "Undefined"
+
+  val ocvTypes = Map(
+    undefinedImageType -> -1,
+    "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC2" -> 8, "CV_8UC3" -> 16, "CV_8UC4" -> 24,
+    "CV_8S" -> 1, "CV_8SC1" -> 1, "CV_8SC2" -> 9, "CV_8SC3" -> 17, "CV_8SC4" -> 25,
+    "CV_16U" -> 2, "CV_16UC1" -> 2, "CV_16UC2" -> 10, "CV_16UC3" -> 18, "CV_16UC4" -> 26,
+    "CV_16S" -> 3, "CV_16SC1" -> 3, "CV_16SC2" -> 11, "CV_16SC3" -> 19, "CV_16SC4" -> 27,
+    "CV_32S" -> 4, "CV_32SC1" -> 4, "CV_32SC2" -> 12, "CV_32SC3" -> 20, "CV_32SC4" -> 28,
+    "CV_32F" -> 5, "CV_32FC1" -> 5, "CV_32FC2" -> 13, "CV_32FC3" -> 21, "CV_32FC4" -> 29,
+    "CV_64F" -> 6, "CV_64FC1" -> 6, "CV_64FC2" -> 14, "CV_64FC3" -> 22, "CV_64FC4" -> 30
+  )
+
+  /**
+   * Schema for the image column: Row(String, Int, Int, Int, String, Array[Byte])
+   */
+  val columnSchema = StructType(
+    StructField("origin", StringType, true) ::
+    StructField("height", IntegerType, false) ::
+    StructField("width", IntegerType, false) ::
+    StructField("nChannels", IntegerType, false) ::
+    // OpenCV-compatible type: CV_8UC3 in most cases
+    StructField("mode", StringType, false) ::
+    // Bytes in OpenCV-compatible order: row-wise BGR in most cases
+    StructField("data", BinaryType, false) :: Nil)
+
+  // Dataframe with a single column of images named "image" (nullable)
+  private val imageDFSchema = StructType(StructField("image", columnSchema, true) :: Nil)
+
+  @Since("2.3.0")
+  def getOrigin(row: Row): String = row.getString(0)
+
+  @Since("2.3.0")
+  def getHeight(row: Row): Int = row.getInt(1)
+
+  @Since("2.3.0")
+  def getWidth(row: Row): Int = row.getInt(2)
+
+  @Since("2.3.0")
+  def getNChannels(row: Row): Int = row.getInt(3)
+
+  @Since("2.3.0")
+  def getMode(row: Row): String = row.getString(4)
+
+  @Since("2.3.0")
+  def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5)
+
+  /**
+   * Check if the dataframe column contains images (i.e. has ImageSchema)
+   *
+   * @param df Dataframe
+   * @param column Column name
+   * @return True if the given column matches the image schema
+   */
+  @Since("2.3.0")
+  def isImageColumn(df: DataFrame, column: String): Boolean =
+    df.schema(column).dataType == columnSchema
+
+  /**
+   * Default values for the invalid image
+   *
+   * @param origin Origin of the invalid image
+   * @return Row with the default values
+   */
+  private def invalidImageRow(origin: String): Row =
+    Row(Row(origin, -1, -1, -1, undefinedImageType, Array.ofDim[Byte](0)))
+
+  /**
+   * Convert the compressed image (jpeg, png, etc.) into OpenCV
+   * representation and store it in dataframe Row
+   *
+   * @param origin Arbitrary string that identifies the image
+   * @param bytes Image bytes (for example, jpeg)
+   * @return Dataframe Row or None (if the decompression fails)
+   */
+  private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = {
+
+    val img = ImageIO.read(new ByteArrayInputStream(bytes))
+
+    if (img == null) {
+      None
+    } else {
+      val is_gray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
+      val has_alpha = img.getColorModel.hasAlpha
+
+      val height = img.getHeight
+      val width = img.getWidth
+      val (nChannels, mode) = if (is_gray) {
+        (1, "CV_8UC1")
+      } else if (has_alpha) {
+        (4, "CV_8UC4")
+      } else {
+        (3, "CV_8UC3")
+      }
+
+      assert(height * width * nChannels < 1e9, "image is too large")
+      val decoded = Array.ofDim[Byte](height * width * nChannels)
+
+      // Grayscale images in Java require special handling to get the correct intensity
+      if (is_gray) {
+        var offset = 0
+        val raster = img.getRaster
+        for (h <- 0 until height) {
+          for (w <- 0 until width) {
+            decoded(offset) = raster.getSample(w, h, 0).toByte
+            offset += 1
+          }
+        }
+      } else {
+        var offset = 0
+        for (h <- 0 until height) {
+          for (w <- 0 until width) {
+            val color = new Color(img.getRGB(w, h))
+
+            decoded(offset) = color.getBlue.toByte
+            decoded(offset + 1) = color.getGreen.toByte
+            decoded(offset + 2) = color.getRed.toByte
+            if (nChannels == 4) {
+              decoded(offset + 3) = color.getAlpha.toByte
+            }
+            offset += nChannels
+          }
+        }
+      }
+
+      // the internal "Row" is needed, because the image is a single dataframe column
+      Some(Row(Row(origin, height, width, nChannels, mode, decoded)))
+    }
+  }
+
+  /**
+   * :: Experimental ::
+   * Read the directory of images from the local or remote source
+   *
+   * @param path Path to the image directory
+   * @param sparkSession Spark Session
+   * @param recursive Recursive path search flag
+   * @param numPartitions Number of the dataframe partitions
+   * @param dropImageFailures Drop the files that are not valid images from the result
+   * @param sampleRatio Fraction of the files loaded
+   * @return Dataframe with a single column "image" of images;
+   *         see ImageSchema for the details
+   */
+  @Since("2.3.0")
+  def readImages(path: String,
+                 sparkSession: SparkSession = null,
+                 recursive: Boolean = false,
+                 numPartitions: Int = 0,
+                 dropImageFailures: Boolean = false,
+                 sampleRatio: Double = 1.0): DataFrame = {
+    require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1")
+
+    val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate
+    val partitions =
+      if (numPartitions > 0) {
+        numPartitions
+      } else {
+        session.sparkContext.defaultParallelism
+      }
+
+    val oldRecursiveFlag = RecursiveFlag.setRecursiveFlag(Some(recursive.toString), session)
+    val oldPathFilter: Option[Class[_]] =
+      if (sampleRatio < 1) {
+        SamplePathFilter.setPathFilter(Some(classOf[SamplePathFilter]), sampleRatio, session)
+      } else {
+        None
+      }
+
+    var result: DataFrame = null
+    try {
+      val streams = session.sparkContext.binaryFiles(path, partitions)
+        .repartition(partitions)
+
+      val images = if (dropImageFailures) {
+        streams.flatMap {
+          case (origin, stream) => decode(origin, stream.toArray)
+        }
+      } else {
+        streams.map {
+          case (origin, stream) => decode(origin, stream.toArray).getOrElse(invalidImageRow(origin))
+        }
+      }
+
+      result = session.createDataFrame(images, imageDFSchema)
+    } finally {
+      // return Hadoop flags to the original values
+      RecursiveFlag.setRecursiveFlag(oldRecursiveFlag, session)
+      SamplePathFilter.unsetPathFilter(oldPathFilter, session)
--- End diff --

Yes, when the sample ratio is less than one we add a path filter to sample the images. Afterwards we either reset the old path filter (if one existed) or remove ours, depending on whether a filter was set before. Sorry, I'm a bit confused though: why, or how, could we avoid unsetting the path filter? Or did you mean that we shouldn't be using a path filter at all?
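To make the intent concrete, here is a minimal sketch of the save/set/restore pattern, under two assumptions: the filter class travels in the shared Hadoop configuration, and the key is the one Hadoop's `FileInputFormat` reads (`mapreduce.input.pathFilter.class`). The helper names below are illustrative only, not the actual `SamplePathFilter`/`RecursiveFlag` code in this PR:

```scala
import org.apache.hadoop.fs.PathFilter
import org.apache.spark.sql.SparkSession

// Illustrative sketch; names and the stored value type are assumptions,
// not the exact implementation in this PR.
object PathFilterSketch {
  // Key that Hadoop's FileInputFormat uses to look up its input path filter.
  private val filterKey = "mapreduce.input.pathFilter.class"

  /** Install `newFilter` and hand back whatever value was set before, if any. */
  def setPathFilter(newFilter: Class[_ <: PathFilter],
                    session: SparkSession): Option[String] = {
    val conf = session.sparkContext.hadoopConfiguration
    val old = Option(conf.get(filterKey))
    conf.setClass(filterKey, newFilter, classOf[PathFilter])
    old
  }

  /** Restore the previous value, or clear the key if none existed. */
  def unsetPathFilter(old: Option[String], session: SparkSession): Unit = {
    val conf = session.sparkContext.hadoopConfiguration
    old match {
      case Some(value) => conf.set(filterKey, value)
      case None        => conf.unset(filterKey)
    }
  }
}
```

The `finally` block in `readImages` then just calls the unset helper with the saved value, so the shared Hadoop configuration is left exactly as it was found even if decoding throws.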