[ https://issues.apache.org/jira/browse/SPARK-29952?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Dongjoon Hyun updated SPARK-29952:
----------------------------------
    Affects Version/s:     (was: 3.0.0)
                           3.1.0

> Pandas UDFs do not support vectors as input
> -------------------------------------------
>
>                 Key: SPARK-29952
>                 URL: https://issues.apache.org/jira/browse/SPARK-29952
>             Project: Spark
>          Issue Type: Improvement
>          Components: PySpark, SQL
>    Affects Versions: 3.1.0
>            Reporter: koba
>            Priority: Minor
>
> Currently, pandas UDFs do not support columns of vectors as input, only columns of arrays.
> This means that feature columns containing DenseVector or SparseVector values, such as those generated by CountVectorizer, are not supported by pandas UDFs out of the box; the vectors have to be converted to arrays first.
> This limitation is not documented anywhere, and I had to find it out by trial and error. Below is an example.
>
> {code:java}
> from pyspark.sql.functions import udf, pandas_udf
> import pyspark.sql.functions as F
> from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT
> from pyspark.sql.types import *
> import numpy as np
>
> columns = ['features', 'id']
> vals = [
>     (DenseVector([1, 2, 1, 3]), 1),
>     (DenseVector([2, 2, 1, 3]), 2)
> ]
> sdf = spark.createDataFrame(vals, columns)
> sdf.show()
>
> +-----------------+---+
> |         features| id|
> +-----------------+---+
> |[1.0,2.0,1.0,3.0]|  1|
> |[2.0,2.0,1.0,3.0]|  2|
> +-----------------+---+
> {code}
> {code:java}
> @udf(returnType=ArrayType(FloatType()))
> def vector_to_array(v):
>     # convert a column of vectors into a column of arrays
>     a = v.values.tolist()
>     return a
>
> sdf = sdf.withColumn('features_array', vector_to_array('features'))
> sdf.show()
> sdf.dtypes
>
> +-----------------+---+--------------------+
> |         features| id|      features_array|
> +-----------------+---+--------------------+
> |[1.0,2.0,1.0,3.0]|  1|[1.0, 2.0, 1.0, 3.0]|
> |[2.0,2.0,1.0,3.0]|  2|[2.0, 2.0, 1.0, 3.0]|
> +-----------------+---+--------------------+
>
> [('features', 'vector'), ('id', 'bigint'), ('features_array', 'array<float>')]
> {code}
> {code:java}
> import pandas as pd
>
> @pandas_udf(LongType())
> def _pandas_udf(v):
>     res = []
>     for i in v:
>         res.append(i.mean())
>     return pd.Series(res)
>
> sdf.select(_pandas_udf('features_array')).show()
>
> +---------------------------+
> |_pandas_udf(features_array)|
> +---------------------------+
> |                          1|
> |                          2|
> +---------------------------+
> {code}
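A minimal workaround sketch (not part of the original report), assuming Spark 3.0 or later: the built-in pyspark.ml.functions.vector_to_array performs the same vector-to-array conversion without a hand-written UDF, so a pandas UDF can consume the resulting array<double> column. The SparkSession setup, the rebuilt toy DataFrame, and the mean_of_array name below are illustrative placeholders, not taken from the reporter's job.

{code:java}
# Sketch only: convert the vector column with the built-in
# pyspark.ml.functions.vector_to_array (available since Spark 3.0)
# instead of a hand-written UDF, then apply a pandas UDF to the
# resulting array<double> column.
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
from pyspark.ml.linalg import DenseVector
from pyspark.ml.functions import vector_to_array

# Session and toy data mirror the example above (placeholders).
spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    [(DenseVector([1, 2, 1, 3]), 1), (DenseVector([2, 2, 1, 3]), 2)],
    ['features', 'id'],
)

@pandas_udf(DoubleType())
def mean_of_array(v: pd.Series) -> pd.Series:
    # Each element is a numpy array once the vector column has been
    # converted to an array column, which Arrow can serialize.
    return v.apply(lambda arr: float(np.mean(arr)))

sdf.select(mean_of_array(vector_to_array('features')).alias('mean')).show()
{code}

This only sidesteps the conversion step; passing the VectorUDT column directly to the pandas UDF still fails, as the stack trace below shows.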
> But if I use the vector column, I get the following error.
>
> {code:java}
> sdf.select(_pandas_udf('features')).show()
>
> ---------------------------------------------------------------------------
> Py4JJavaError                             Traceback (most recent call last)
> <ipython-input-74-d93e4117f661> in <module>
>      13
>      14
> ---> 15 sdf.select(_pandas_udf('features')).show()
>
> ~/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical)
>     376         """
>     377         if isinstance(truncate, bool) and truncate:
> --> 378             print(self._jdf.showString(n, 20, vertical))
>     379         else:
>     380             print(self._jdf.showString(n, int(truncate), vertical))
>
> ~/.pyenv/versions/3.4.4/lib/python3.4/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
>    1255         answer = self.gateway_client.send_command(command)
>    1256         return_value = get_return_value(
> -> 1257             answer, self.gateway_client, self.target_id, self.name)
>    1258
>    1259         for temp_arg in temp_args:
>
> ~/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
>      61     def deco(*a, **kw):
>      62         try:
> ---> 63             return f(*a, **kw)
>      64         except py4j.protocol.Py4JJavaError as e:
>      65             s = e.java_exception.toString()
>
> ~/.pyenv/versions/3.4.4/lib/python3.4/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
>     326                 raise Py4JJavaError(
>     327                     "An error occurred while calling {0}{1}{2}.\n".
> --> 328                     format(target_id, ".", name), value)
>     329             else:
>     330                 raise Py4JError(
>
> Py4JJavaError: An error occurred while calling o2635.showString.
> : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 156.0 failed 1 times, most recent failure: Lost task 0.0 in stage 156.0 (TID 606, localhost, executor driver): java.lang.UnsupportedOperationException: Unsupported data type: struct<type:tinyint,size:int,indices:array<int>,values:array<double>>
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:56)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:92)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:116)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:115)
>     at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>     at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>     at scala.collection.Iterator$class.foreach(Iterator.scala:891)
>     at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
>     at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
>     at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
>     at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
>     at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:115)
>     at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:71)
>     at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:345)
>     at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945)
>     at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:194)
>
> Driver stacktrace:
>     at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
>     at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
>     at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
>     at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
>     at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
>     at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
>     at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
>     at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
>     at scala.Option.foreach(Option.scala:257)
>     at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
>     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
>     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
>     at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
>     at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
>     at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
>     at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
>     at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
>     at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
>     at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
>     at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
>     at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383)
>     at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
>     at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544)
>     at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364)
>     at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
>     at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
>     at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
>     at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363)
>     at org.apache.spark.sql.Dataset.head(Dataset.scala:2544)
>     at org.apache.spark.sql.Dataset.take(Dataset.scala:2758)
>     at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
>     at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
>     at sun.reflect.GeneratedMethodAccessor81.invoke(Unknown Source)
>     at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
>     at java.lang.reflect.Method.invoke(Method.java:498)
>     at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
>     at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
>     at py4j.Gateway.invoke(Gateway.java:282)
>     at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
>     at py4j.commands.CallCommand.execute(CallCommand.java:79)
>     at py4j.GatewayConnection.run(GatewayConnection.java:238)
>     at java.lang.Thread.run(Thread.java:748)
> Caused by: java.lang.UnsupportedOperationException: Unsupported data type: struct<type:tinyint,size:int,indices:array<int>,values:array<double>>
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:56)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:92)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:116)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:115)
>     at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>     at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
>     at scala.collection.Iterator$class.foreach(Iterator.scala:891)
>     at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
>     at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
>     at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
>     at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
>     at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
>     at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:115)
>     at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:71)
>     at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:345)
>     at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945)
>     at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:194)
> {code}
>
>

--
This message was sent by Atlassian Jira
(v8.3.4#803005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org