zhengruifeng commented on code in PR #55530:
URL: https://github.com/apache/spark/pull/55530#discussion_r3222955743
##########
python/pyspark/worker.py:
##########
@@ -2940,20 +2920,9 @@ def dataframe_iter():
parsed_offsets = extract_key_value_indexes(arg_offsets)
- # Pre-compute expected column names/types for strict result validation.
- # Cogrouped map has a strict contract: missing, extra, or type-mismatched
- # columns must raise; no silent coercion.
- if runner_conf.assign_cols_by_name:
- expected_cols_and_types = {
- col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
- }
- reorder_names = [col.name for col in return_type.fields]
- else:
- expected_cols_and_types = [
- (col.name, to_arrow_type(col.dataType, timezone="UTC"))
- for col in return_type.fields
- ]
- reorder_names = None
+ arrow_return_schema = pa.schema(
+ [(col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields]
+ )
Review Comment:
The new grouped paths in this PR (`worker.py:2651-2654` and `:2714-2717`)
thread `runner_conf.use_large_var_types` through `to_arrow_type(return_type,
...)` for their validation schema, but cogrouped still builds the schema
per-field without `prefers_large_types`. Under
`spark.sql.execution.arrow.useLargeVarTypes=true` this leaves the cogrouped
validation expecting regular `string`/`binary` while the rest of the pipeline
(and the new grouped paths) expects `large_string`/`large_binary` — a UDF
returning large variants is rejected, and a regular-string return that should
be flagged is accepted. The pre-PR `verify_arrow_result` setup had the same
omission, but since the grouped paths in this PR pick it up, aligning cogrouped
is the consistency fix.
```suggestion
arrow_return_type = to_arrow_type(
return_type, timezone="UTC",
prefers_large_types=runner_conf.use_large_var_types
)
arrow_return_schema = pa.schema(list(arrow_return_type))
```
##########
python/pyspark/sql/conversion.py:
##########
@@ -145,11 +146,26 @@ def enforce_schema(
If False, raise an error on type mismatch instead of casting.
safecheck : bool, default True
If True, use safe casting (fails on overflow/truncation).
+ reorder_by_name : bool, default True
+ If True, match columns by name and reorder to the target order; any
+ missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``. Output
Review Comment:
Heads-up: the new default `reorder_by_name=True` strictly rejects extra columns
(it raises `RESULT_COLUMN_NAMES_MISMATCH`), whereas the old `enforce_schema`
silently dropped them because it only looked up target names via
`batch.column(name)`. The one remaining caller that relies on the default,
`ArrowStreamArrowUDTFSerializer.dump_stream`
(`pyspark/sql/pandas/serializers.py:293`), therefore changes contract: a
`SQL_ARROW_UDTF` that returned the target columns plus extras used to succeed
with the extras discarded, and now raises. This is probably the right contract
(the old leniency was undocumented), but it is worth surfacing in the
"user-facing change" section of the PR description so any UDTF returning extras
can be cleaned up before upgrade.
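
To make the contract change concrete, here is a minimal pure-Python sketch of
the two behaviors, with plain dicts standing in for Arrow batches. The helper
names `select_lenient` and `select_strict` are hypothetical and not part of
PySpark, and `ValueError` only approximates the `PySparkRuntimeError` carrying
`RESULT_COLUMN_NAMES_MISMATCH` that the real code raises.

```python
def select_lenient(batch, target_names):
    """Old behavior (sketch): look up only the target names, so extras vanish."""
    return {name: batch[name] for name in target_names}


def select_strict(batch, target_names):
    """New default (sketch): any missing or extra column name is an error."""
    missing = sorted(set(target_names) - set(batch))
    extra = sorted(set(batch) - set(target_names))
    if missing or extra:
        raise ValueError(f"names mismatch: missing={missing}, extra={extra}")
    return {name: batch[name] for name in target_names}


# Hypothetical UDTF output: the two target columns plus one extra column.
udtf_output = {"id": [1, 2], "name": ["a", "b"], "debug": [0, 0]}
target = ["id", "name"]

# Lenient lookup succeeds and quietly discards "debug".
lenient_result = select_lenient(udtf_output, target)

# Strict lookup rejects the same output because of the extra column.
try:
    select_strict(udtf_output, target)
    strict_error = None
except ValueError as exc:
    strict_error = str(exc)
```

A UDTF in this situation would need to drop its extra columns before returning
in order to pass the strict check.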
##########
python/pyspark/sql/conversion.py:
##########
@@ -145,11 +146,26 @@ def enforce_schema(
If False, raise an error on type mismatch instead of casting.
safecheck : bool, default True
If True, use safe casting (fails on overflow/truncation).
+ reorder_by_name : bool, default True
+ If True, match columns by name and reorder to the target order; any
+ missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``. Output
+ columns are renamed to target names.
+ If False, match columns by position (ignore names) and preserve the
+ original column names in the output.
Returns
-------
- pa.RecordBatch
- RecordBatch with columns reordered and types coerced to match target schema.
+ pa.RecordBatch or pa.Table
+ Same container type as ``batch``, with columns matched (and possibly
+ reordered/cast) per the target schema.
+
+ Raises
+ ------
+ PySparkRuntimeError
+ ``RESULT_COLUMN_NAMES_MISMATCH`` when ``reorder_by_name=True`` and the
+ batch has missing or extra column names.
+ ``RESULT_COLUMN_TYPES_MISMATCH`` when any column's type does not match
+ the target (and either ``arrow_cast=False`` or the cast itself fails).
Review Comment:
The `Raises` section omits `RESULT_COLUMN_SCHEMA_MISMATCH`, which the
function also raises (positional mode, when `batch.num_columns !=
len(arrow_schema)`).
```suggestion
the target (and either ``arrow_cast=False`` or the cast itself fails).
``RESULT_COLUMN_SCHEMA_MISMATCH`` when ``reorder_by_name=False`` and the
batch has a different number of columns than the target schema.
```
##########
python/pyspark/sql/conversion.py:
##########
@@ -160,37 +176,68 @@ def enforce_schema(
if batch.schema.equals(arrow_schema, check_metadata=False):
return batch
- # Check if columns are in the same order (by name) as the target schema.
- # If so, use index-based access (faster than name lookup).
- batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
target_names = [field.name for field in arrow_schema]
- use_index = batch_names == target_names
- coerced_arrays = []
- for i, field in enumerate(arrow_schema):
- try:
- arr = batch.column(i) if use_index else batch.column(field.name)
- except KeyError:
- raise PySparkTypeError(
- f"Result column '{field.name}' does not exist in the output. "
- f"Expected schema: {arrow_schema}, got: {batch.schema}."
+ # Step 1: pick source columns from batch to align with target schema
+ if reorder_by_name:
+ batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
+ missing = sorted(set(target_names) - set(batch_names))
+ extra = sorted(set(batch_names) - set(target_names))
+ if missing or extra:
+ raise PySparkRuntimeError(
+ errorClass="RESULT_COLUMN_NAMES_MISMATCH",
+ messageParameters={
+ "missing": f" Missing: {', '.join(missing)}." if missing else "",
+ "extra": f" Unexpected: {', '.join(extra)}." if extra else "",
+ },
)
- if arr.type != field.type:
- if not arrow_cast:
- raise PySparkTypeError(
- f"Result type of column '{field.name}' does not match "
- f"the expected type. Expected: {field.type}, got: {arr.type}."
- )
+ source_columns = [batch.column(name) for name in target_names]
+ output_names = target_names
+ else:
+ # Positional: require exact column-count match, then take columns by
+ # index, preserving the batch's original column names.
+ if batch.num_columns != len(arrow_schema):
Review Comment:
Behavior change worth noting in the PR description: under
`assign_cols_by_name=False`, the old `verify_arrow_result` did `zip(expected,
actual)` over the column lists, and `zip` silently stops at the shorter list, so
column-count mismatches in the positional grouped/cogrouped paths slipped
through. The new strict count check is an improvement over the silent
truncation, but it is a runtime behavior change for users running with
`spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName=false` whose
UDF/UDTF returned the wrong number of columns and previously got only partial
validation without any error.
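
A small pure-Python sketch of why the `zip`-based check let count mismatches
through, and what a count check in front of it changes. The function names and
the `(name, type)` tuples are illustrative assumptions, not the actual
`verify_arrow_result` or `enforce_schema` code, and `ValueError` stands in for
the `RESULT_COLUMN_SCHEMA_MISMATCH` error.

```python
def validate_with_zip(expected, actual):
    """Old-style check (sketch): zip stops at the shorter list."""
    for (exp_name, exp_type), (_, act_type) in zip(expected, actual):
        if exp_type != act_type:
            raise TypeError(f"column {exp_name}: expected {exp_type}, got {act_type}")


def validate_strict(expected, actual):
    """New-style check (sketch): compare column counts before the types."""
    if len(expected) != len(actual):
        raise ValueError(f"expected {len(expected)} columns, got {len(actual)}")
    validate_with_zip(expected, actual)


expected = [("id", "int64"), ("value", "double")]
actual = [("id", "int64")]  # the UDF returned one column too few

# The zip-based check only compares the first pair, so it does not raise.
validate_with_zip(expected, actual)
zip_passed = True

# The strict check fails fast on the count mismatch.
try:
    validate_strict(expected, actual)
    strict_error = None
except ValueError as exc:
    strict_error = str(exc)
```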
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]