WeichenXu123 commented on code in PR #37734:
URL: https://github.com/apache/spark/pull/37734#discussion_r1020345415


##########
python/pyspark/ml/functions.py:
##########
@@ -106,6 +117,601 @@ def array_to_vector(col: Column) -> Column:
     return 
Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
 
 
+def _batched(
+    data: pd.Series | pd.DataFrame | Tuple[pd.Series], batch_size: int
+) -> Iterator[pd.DataFrame]:
+    """Generator that splits a pandas dataframe/series into batches."""
+    if isinstance(data, pd.DataFrame):
+        df = data
+    elif isinstance(data, pd.Series):
+        df = pd.concat((data,), axis=1)
+    else:  # isinstance(data, Tuple[pd.Series]):
+        df = pd.concat(data, axis=1)
+
+    index = 0
+    data_size = len(df)
+    while index < data_size:
+        yield df.iloc[index : index + batch_size]
+        index += batch_size
+
+
+def _is_tensor_col(data: pd.Series | pd.DataFrame) -> bool:
+    if isinstance(data, pd.Series):
+        return data.dtype == np.object_ and isinstance(data.iloc[0], 
(np.ndarray, list))
+    elif isinstance(data, pd.DataFrame):
+        return any(data.dtypes == np.object_) and any(
+            [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
+        )
+    else:
+        raise ValueError(
+            "Unexpected data type: {}, expected pd.Series or 
pd.DataFrame.".format(type(data))
+        )
+
+
def _has_tensor_cols(data: pd.Series | pd.DataFrame | Tuple[pd.Series]) -> bool:
    """Report whether ``data`` contains at least one tensor-valued column.

    A single Series or DataFrame is handed straight to ``_is_tensor_col``.
    Anything else is treated as a tuple of Series; its elements are checked in
    order and the scan stops at the first tensor-valued one.
    """
    if not isinstance(data, (pd.Series, pd.DataFrame)):
        # Tuple of Series: short-circuit on the first tensor column.
        for series in data:
            if _is_tensor_col(series):
                return True
        return False
    return _is_tensor_col(data)
+
+
+def _validate_and_transform_multiple_inputs(
+    batch: pd.DataFrame, input_shapes: List[List[int] | None], num_input_cols: 
int
+) -> List[np.ndarray]:
+    multi_inputs = [batch[col].to_numpy() for col in batch.columns]
+    if input_shapes:
+        if len(input_shapes) == num_input_cols:
+            multi_inputs = [
+                np.vstack(v).reshape([-1] + input_shapes[i])  # type: ignore
+                if input_shapes[i]
+                else v
+                for i, v in enumerate(multi_inputs)
+            ]
+            if not all([len(x) == len(batch) for x in multi_inputs]):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("input_tensor_shapes must match columns")
+
+    return multi_inputs
+
+
+def _validate_and_transform_single_input(
+    batch: pd.DataFrame,
+    input_shapes: List[List[int] | None],
+    has_tensors: bool,
+    has_tuple: bool,
+) -> np.ndarray:
+    # multiple input columns for single expected input
+    if has_tensors:
+        # tensor columns
+        if len(batch.columns) == 1:
+            # one tensor column and one expected input, vstack rows
+            single_input = np.vstack(batch.iloc[:, 0])
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into 
tensors."
+            )
+    else:
+        # scalar columns
+        if len(batch.columns) == 1:
+            # single scalar column, remove extra dim
+            single_input = np.squeeze(batch.to_numpy())
+            if input_shapes and input_shapes[0] not in [None, [], [1]]:
+                raise ValueError("Invalid input_tensor_shape for scalar 
column.")
+        elif not has_tuple:
+            # columns grouped via struct/array, convert to single tensor
+            single_input = batch.to_numpy()
+            if input_shapes and input_shapes[0] != [len(batch.columns)]:
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError(
+                "Multiple input columns found, but model expected a single "
+                "input, use `struct` or `array` to combine columns into 
tensors."
+            )
+
+    # if input_tensor_shapes provided, try to reshape input
+    if input_shapes:
+        if len(input_shapes) == 1:
+            single_input = single_input.reshape([-1] + input_shapes[0])  # 
type: ignore
+            if len(single_input) != len(batch):
+                raise ValueError("Input data does not match expected shape.")
+        else:
+            raise ValueError("Multiple input_tensor_shapes found, but model 
expected one input")
+
+    return single_input
+
+
+def _validate_and_transform_prediction_result(
+    preds: np.ndarray | Mapping[str, np.ndarray] | List[Mapping[str, Any]],
+    num_input_rows: int,
+    return_type: DataType,
+) -> pd.DataFrame | pd.Series:
+    """Validate numpy-based model predictions against the expected pandas_udf 
return_type and
+    transforms the predictions into an equivalent pandas DataFrame or 
Series."""
+    if isinstance(return_type, StructType):
+        struct_rtype: StructType = return_type
+        fieldNames = struct_rtype.names
+        if isinstance(preds, dict):
+            # dictionary of columns
+            predNames = list(preds.keys())
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[field.name].shape) == 2:
+                        preds[field.name] = list(preds[field.name])
+                    else:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be 
two-dimensional."
+                        )
+                else:
+                    if len(preds[field.name].shape) != 1:
+                        raise ValueError(
+                            "Prediction results for scalar types must be 
one-dimensional."
+                        )
+                if len(preds[field.name]) != num_input_rows:
+                    raise ValueError("Prediction results must have same length 
as input data")
+
+        elif isinstance(preds, list) and isinstance(preds[0], dict):
+            # rows of dictionaries
+            predNames = list(preds[0].keys())
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as 
input data.")
+            for field in struct_rtype.fields:
+                if isinstance(field.dataType, ArrayType):
+                    if len(preds[0][field.name].shape) != 2:
+                        raise ValueError(
+                            "Prediction results for ArrayType must be 
two-dimensional."
+                        )
+                else:
+                    if (
+                        isinstance(preds[0][field.name], np.ndarray)
+                        and preds[0][field.name].shape != ()
+                    ):
+                        raise ValueError("Invalid shape for scalar prediction 
result.")
+        else:
+            raise ValueError(
+                "Prediction results for StructType must be a dictionary or "
+                "a list of dictionary, got: {}".format(type(preds))
+            )
+
+        # check column names
+        if set(predNames) != set(fieldNames):
+            raise ValueError(
+                "Prediction result columns did not match expected return_type "
+                "columns: expected {}, got: {}".format(fieldNames, predNames)
+            )
+
+        return pd.DataFrame(preds)
+    elif isinstance(return_type, ArrayType):
+        if isinstance(preds, np.ndarray):
+            if len(preds) != num_input_rows:
+                raise ValueError("Prediction results must have same length as 
input data.")
+            if len(preds.shape) != 2:
+                raise ValueError("Prediction results for ArrayType must be 
two-dimensional.")
+        else:

Review Comment:
   Ditto, please use
   ```
   ...
   elif isinstance(return_type, number_types):
     # scalar number case
     ...
   else:
     raise ValueError("Unsupported field type in return struct type.")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to