GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22610#discussion_r222173904
  
    --- Diff: python/pyspark/worker.py ---
    @@ -84,13 +84,36 @@ def wrap_scalar_pandas_udf(f, return_type):
         arrow_return_type = to_arrow_type(return_type)
     
         def verify_result_length(*a):
    +        import pyarrow as pa
             result = f(*a)
             if not hasattr(result, "__len__"):
                 raise TypeError("Return type of the user-defined function 
should be "
                                 "Pandas.Series, but is 
{}".format(type(result)))
             if len(result) != len(a[0]):
                 raise RuntimeError("Result vector from pandas_udf was not the 
required length: "
                                    "expected %d, got %d" % (len(a[0]), 
len(result)))
    +
    +        # Ensure return type of Pandas.Series matches the arrow return 
type of the user-defined
    +        # function. Otherwise, we may produce incorrect serialized data.
    +        # Note: for timestamp type, we only need to ensure both types are 
timestamp because the
    +        # serializer will do conversion.
    +        try:
    +            arrow_type_of_result = pa.from_numpy_dtype(result.dtype)
    +            both_are_timestamp = 
pa.types.is_timestamp(arrow_type_of_result) and \
    +                pa.types.is_timestamp(arrow_return_type)
    +            if not both_are_timestamp and arrow_return_type != 
arrow_type_of_result:
    +                print("WARN: Arrow type %s of return Pandas.Series of the 
user-defined function's "
    +                      "dtype %s doesn't match the arrow type %s "
    +                      "of defined return type %s" % (arrow_type_of_result, 
result.dtype,
    +                                                     arrow_return_type, 
return_type),
    +                      file=sys.stderr)
    +        except:
    +            print("WARN: Can't infer arrow type of Pandas.Series's dtype: 
%s, which might not "
    +                  "match the arrow type %s of defined return type %s" % 
(result.dtype,
    +                                                                         
arrow_return_type,
    +                                                                         
return_type),
    --- End diff --
    
    OK, thanks. :-)


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to