Github user imatiach-msft commented on a diff in the pull request: https://github.com/apache/spark/pull/19439#discussion_r148029475 --- Diff: python/pyspark/ml/image.py --- @@ -0,0 +1,139 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import * +from pyspark.ml.param.shared import * +from pyspark.sql.types import * +from pyspark.sql.types import Row, _create_row +from pyspark.sql import DataFrame, SparkSession, SQLContext +import numpy as np + +undefinedImageType = "Undefined" + +imageFields = ["origin", "height", "width", "nChannels", "mode", "data"] + + +def getOcvTypes(spark=None): + """ + Returns the OpenCV type mapping supported + + :param sparkSession (SparkSession): The current spark session + :rtype dict: The OpenCV type mapping supported + + .. versionadded:: 2.3.0 + """ + spark = spark or SparkSession.builder.getOrCreate() + ctx = spark.sparkContext + return ctx._jvm.org.apache.spark.ml.image.ImageSchema.ocvTypes + + +# DataFrame with a single column of images named "image" (nullable) +def getImageSchema(spark=None): + """ + Returns the image schema + + :param spark (SparkSession): The current spark session + :rtype StructType: The image schema + + .. 
versionadded:: 2.3.0 + """ + spark = spark or SparkSession.builder.getOrCreate() + ctx = spark.sparkContext + return ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema + + +def toNDArray(image): + """ + Converts an image to a one-dimensional array. + + :param image (object): The image to be converted + :rtype array: The image as a one-dimensional array + + .. versionadded:: 2.3.0 + """ + height = image.height + width = image.width + nChannels = image.nChannels + return np.ndarray( + shape=(height, width, nChannels), + dtype=np.uint8, + buffer=image.data, + strides=(width * nChannels, nChannels, 1)) + + +def toImage(array, origin="", spark=None): + """ + Converts a one-dimensional array to a two-dimensional image. + + :param array (array): The array to convert to image + :param origin (str): Path to the image + :param spark (SparkSession): The current spark session + :rtype object: Two dimensional image + + .. versionadded:: 2.3.0 + """ + spark = spark or SparkSession.builder.getOrCreate() + if array.ndim != 3: + raise --- End diff -- @MrBago recommended this code change above. You are right, I did not realize the syntax is invalid. What specific exception should I raise?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org