Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/19439#discussion_r144299338 --- Diff: mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala --- @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.image + +import java.awt.Color +import java.awt.color.ColorSpace +import java.io.ByteArrayInputStream +import javax.imageio.ImageIO + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types._ + +@Experimental +@Since("2.3.0") +object ImageSchema { + + val undefinedImageType = "Undefined" + + val ocvTypes = Map( + undefinedImageType -> -1, + "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC2" -> 8, "CV_8UC3" -> 16, "CV_8UC4" -> 24, + "CV_8S" -> 1, "CV_8SC1" -> 1, "CV_8SC2" -> 9, "CV_8SC3" -> 17, "CV_8SC4" -> 25, + "CV_16U" -> 2, "CV_16UC1" -> 2, "CV_16UC2" -> 10, "CV_16UC3" -> 18, "CV_16UC4" -> 26, + "CV_16S" -> 3, "CV_16SC1" -> 3, "CV_16SC2" -> 11, "CV_16SC3" -> 19, "CV_16SC4" -> 27, + "CV_32S" -> 4, "CV_32SC1" -> 4, "CV_32SC2" -> 12, "CV_32SC3" -> 20, "CV_32SC4" -> 28, + "CV_32F" -> 5, "CV_32FC1" -> 5, "CV_32FC2" -> 13, "CV_32FC3" -> 21, "CV_32FC4" -> 29, + "CV_64F" -> 6, "CV_64FC1" -> 6, "CV_64FC2" -> 14, "CV_64FC3" -> 22, "CV_64FC4" -> 30 + ) + + /** + * Schema for the image column: Row(String, Int, Int, Int, Array[Byte]) + */ + val columnSchema = StructType( + StructField("origin", StringType, true) :: + StructField("height", IntegerType, false) :: + StructField("width", IntegerType, false) :: + StructField("nChannels", IntegerType, false) :: + // OpenCV-compatible type: CV_8UC3 in most cases + StructField("mode", StringType, false) :: + // Bytes in OpenCV-compatible order: row-wise BGR in most cases + StructField("data", BinaryType, false) :: Nil) + + // DataFrame with a single column of images named "image" (nullable) + private val imageDFSchema = StructType(StructField("image", columnSchema, true) :: Nil) + + /** + * :: Experimental :: + * Gets the origin of the image + * + * @return The origin of the image + */ + def getOrigin(row: Row): String = row.getString(0) + + /** + * :: Experimental :: + * Gets the height of the image + * + * @return The height of the image + */ + def getHeight(row: Row): Int = row.getInt(1) + + /** + * :: Experimental :: + * Gets the width of the image + * + * @return The width of the image + */ + def getWidth(row: Row): Int = row.getInt(2) + + /** + * :: Experimental :: + * Gets the number of channels in the image + * + * @return The number of channels in the image + */ + def getNChannels(row: Row): Int = row.getInt(3) + + /** + * :: Experimental :: + * Gets the OpenCV representation + * + * @return The OpenCV representation + */ + def getMode(row: Row): String = row.getString(4) + + /** + * :: Experimental :: + * Gets the image data + * + * @return The image data + */ + def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5) + + /** + * :: Experimental :: + * Check if the DataFrame column contains images (i.e. has ImageSchema) + * + * @param df DataFrame + * @param column Column name + * @return True if the given column matches the image schema + */ + def isImageColumn(df: DataFrame, column: String): Boolean = + df.schema(column).dataType == columnSchema + + /** + * Default values for the invalid image + * + * @param origin Origin of the invalid image + * @return Row with the default values + */ + private def invalidImageRow(origin: String): Row = Row(Row(origin, -1, -1, -1, undefinedImageType, + Array.ofDim[Byte](0))) --- End diff -- nit: ```scala private def invalidImageRow(origin: String): Row = Row(Row(origin, -1, -1, -1, undefinedImageType, Array.ofDim[Byte](0))) ```
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org