Github user icexelloss commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18732#discussion_r142694835
  
    --- Diff: python/pyspark/worker.py ---
    @@ -74,17 +74,35 @@ def wrap_udf(f, return_type):
     
     
     def wrap_pandas_udf(f, return_type):
    -    arrow_return_type = toArrowType(return_type)
    -
    -    def verify_result_length(*a):
    -        result = f(*a)
    -        if not hasattr(result, "__len__"):
    -            raise TypeError("Return type of pandas_udf should be a Pandas.Series")
    -        if len(result) != len(a[0]):
    -            raise RuntimeError("Result vector from pandas_udf was not the required length: "
    -                               "expected %d, got %d" % (len(a[0]), len(result)))
    -        return result
    -    return lambda *a: (verify_result_length(*a), arrow_return_type)
    +    if isinstance(return_type, StructType):
    +        arrow_return_types = [to_arrow_type(field.dataType) for field in return_type]
    +
    +        def fn(*a):
    +            import pandas as pd
    +            out = f(*a)
    +            assert isinstance(out, pd.DataFrame), \
    +                'Return value from the user function is not a pandas.DataFrame.'
    +            assert len(out.columns) == len(arrow_return_types), \
    +                'Number of columns of the returned pd.DataFrame doesn\'t match output schema. ' \
    --- End diff --
    
    Good catch. Fixed. (Btw thanks for catching these small things)


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to