WeichenXu123 commented on code in PR #37734: URL: https://github.com/apache/spark/pull/37734#discussion_r1013958443
########## python/pyspark/ml/functions.py: ########## @@ -106,6 +117,542 @@ def array_to_vector(col: Column) -> Column: return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col))) +def _batched( + data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int +) -> Iterator[pd.DataFrame]: + """Generator that splits a pandas dataframe/series into batches.""" + if isinstance(data, pd.DataFrame): + index = 0 + data_size = len(data) + while index < data_size: + yield data.iloc[index : index + batch_size] + index += batch_size + else: + # convert (tuple of) pd.Series into pd.DataFrame + if isinstance(data, pd.Series): + df = pd.concat((data,), axis=1) + else: # isinstance(data, Tuple[pd.Series]): + df = pd.concat(data, axis=1) + + index = 0 + data_size = len(df) + while index < data_size: + yield df.iloc[index : index + batch_size] + index += batch_size + + +def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool: + if isinstance(data, pd.Series): + return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list)) + elif isinstance(data, pd.DataFrame): + return any(data.dtypes == np.object_) and any( + [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]] + ) + else: + raise ValueError( + "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data)) + ) + + +def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool: + """Check if input Series/DataFrame/Tuple contains any tensor-valued columns.""" + if isinstance(data, (pd.Series, pd.DataFrame)): + return _is_tensor_col(data) + else: # isinstance(data, Tuple): + return any(_is_tensor_col(elem) for elem in data) + + +def _validate( + preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]], + num_input_rows: int, + return_type: DataType, +) -> None: + """Validate model predictions against the expected pandas_udf return_type.""" + if isinstance(return_type, StructType): + struct_rtype: StructType = return_type 
+ fieldNames = struct_rtype.names + if isinstance(preds, dict): + # dictionary of columns + predNames = list(preds.keys()) + if not all(v.shape == (num_input_rows,) for v in preds.values()): + raise ValueError("Prediction results for StructType fields must be scalars.") + elif isinstance(preds, list) and isinstance(preds[0], dict): + # rows of dictionaries + predNames = list(preds[0].keys()) + if len(preds) != num_input_rows: + raise ValueError("Prediction results must have same length as input data.") Review Comment: Q: Should we support this case? I think we should treat it as an illegal case. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org