Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/18732#discussion_r142694835 --- Diff: python/pyspark/worker.py --- @@ -74,17 +74,35 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = toArrowType(return_type) - - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + if isinstance(return_type, StructType): + arrow_return_types = [to_arrow_type(field.dataType) for field in return_type] + + def fn(*a): + import pandas as pd + out = f(*a) + assert isinstance(out, pd.DataFrame), \ + 'Return value from the user function is not a pandas.DataFrame.' + assert len(out.columns) == len(arrow_return_types), \ + 'Number of columns of the returned pd.DataFrame doesn\'t match output schema. ' \ --- End diff -- Good catch. Fixed. (Btw thanks for catching these small things)
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org