This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 167bbca49c1  [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`
167bbca49c1 is described below

commit 167bbca49c1c12ccd349d4330862c136b38d4522
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Feb 13 16:22:06 2023 +0800

    [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`

    ### What changes were proposed in this pull request?
    In the Python client:
    - generate a `plan_id` for each proto plan (it is up to the client to guarantee uniqueness);
    - attach the `plan_id` to columns created by `DataFrame[col_name]` or `DataFrame.col_name`;
    - note that `F.col(col_name)` does not carry a `plan_id`.

    In the Connect planner:
    - attach the `plan_id` to `UnresolvedAttribute`s and `LogicalPlan`s via a `TreeNodeTag`.

    In the analyzer:
    - for an `UnresolvedAttribute` with a `plan_id`, search for the matching node in the plan, and resolve the attribute against the found node if possible.

    **Out of scope:**
    - resolving `self-join`;
    - adding a `DetectAmbiguousSelfJoin`-like rule for detection.

    ### Why are the changes needed?
    Bug fix. Before this PR:

    ```
    df1.join(df2, df1["value"] == df2["value"])                    <- fails: cannot resolve `value`
    df1.join(df2, df1["value"] == df2["value"]).select(df1.value)  <- fails: cannot resolve `value`
    df1.select(df2.value)                                          <- should fail, but runs as `df1.select(df1.value)` and returns incorrect results
    ```

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Added tests and enabled previously skipped tests.

    Closes #39925 from zhengruifeng/connect_plan_id.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |   3 +
 .../main/protobuf/spark/connect/relations.proto    |   3 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  22 ++-
 python/pyspark/sql/column.py                       |   5 +-
 python/pyspark/sql/connect/dataframe.py            |  15 +-
 python/pyspark/sql/connect/expressions.py          |   7 +-
 python/pyspark/sql/connect/functions.py            |  19 +-
 python/pyspark/sql/connect/plan.py                 | 213 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  54 +++---
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  20 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 200 +++++++++----------
 .../pyspark/sql/connect/proto/relations_pb2.pyi    |  15 +-
 .../sql/tests/connect/test_connect_basic.py        |  62 ++++++
 .../pyspark/sql/tests/connect/test_connect_plan.py |  20 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  16 +-
 .../catalyst/analysis/ColumnResolutionHelper.scala |  58 +++++-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  14 +-
 17 files changed, 481 insertions(+), 265 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 8682e1ee27b..1929d9cdca3 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -197,6 +197,9 @@ message Expression {
     // (Required) An identifier that will be parsed by Catalyst parser. This should follow the
     // Spark SQL identifier syntax.
     string unparsed_identifier = 1;
+
+    // (Optional) The id of corresponding connect plan.
+    optional int64 plan_id = 2;
   }

   // An unresolved function is not explicitly bound to one explicit function, but the function
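For illustration, a minimal sketch (not part of the patch; it assumes the generated Python bindings for this proto are importable as `pyspark.sql.connect.proto`) of what the new optional field carries on the wire:

```
import pyspark.sql.connect.proto as proto

# An unresolved attribute that remembers which plan produced it.
expr = proto.Expression()
expr.unresolved_attribute.unparsed_identifier = "value"
expr.unresolved_attribute.plan_id = 42

# `optional` in proto3 gives explicit presence tracking, so an untagged
# column (e.g. one built by F.col("value")) stays distinguishable:
assert expr.unresolved_attribute.HasField("plan_id")
```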
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index ea1216957d8..29fffd65c75 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -93,6 +93,9 @@ message Unknown {}
 message RelationCommon {
   // (Required) Shared relation metadata.
   string source_info = 1;
+
+  // (Optional) A per-client globally unique id for a given connect plan.
+  optional int64 plan_id = 2;
 }

 // Relation that uses a SQL query to generate the output.
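The relation side of the same handshake can be sketched analogously (again assuming the generated bindings; this mirrors what `LogicalPlan._create_proto_relation` sets in `plan.py` further below):

```
import pyspark.sql.connect.proto as proto

rel = proto.Relation()
rel.common.plan_id = 42  # the id the client allocated for this plan
rel.read.named_table.unparsed_identifier = "my_table"  # hypothetical table name
assert rel.common.HasField("plan_id")
```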
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 740d6b85964..53d494cdcb7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -66,7 +66,7 @@ class SparkConnectPlanner(val session: SparkSession) {

   // The root of the query plan is a relation and we apply the transformations to it.
   def transformRelation(rel: proto.Relation): LogicalPlan = {
-    rel.getRelTypeCase match {
+    val plan = rel.getRelTypeCase match {
       // DataFrame API
       case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
       case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
@@ -124,6 +124,11 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformRelationPlugin(rel.getExtension)
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
     }
+
+    if (rel.hasCommon && rel.getCommon.hasPlanId) {
+      plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
+    }
+    plan
   }

   private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
@@ -702,10 +707,6 @@ class SparkConnectPlanner(val session: SparkSession) {
     logical.Project(projectList = projection, child = baseRel)
   }

-  private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
-    UnresolvedAttribute.quotedString(exp.getUnresolvedAttribute.getUnparsedIdentifier)
-  }
-
   /**
    * Transforms an input protobuf expression into the Catalyst expression. This is usually not
    * called directly. Typically the planner will traverse the expressions automatically, only
@@ -720,7 +721,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     exp.getExprTypeCase match {
       case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
-        transformUnresolvedExpression(exp)
+        transformUnresolvedAttribute(exp.getUnresolvedAttribute)
       case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
         transformUnregisteredFunction(exp.getUnresolvedFunction)
          .getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
@@ -758,6 +759,15 @@ class SparkConnectPlanner(val session: SparkSession) {
     case expr => UnresolvedAlias(expr)
   }

+  private def transformUnresolvedAttribute(
+      attr: proto.Expression.UnresolvedAttribute): UnresolvedAttribute = {
+    val expr = UnresolvedAttribute.quotedString(attr.getUnparsedIdentifier)
+    if (attr.hasPlanId) {
+      expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
+    }
+    expr
+  }
+
   private def transformExpressionPlugin(extension: ProtoAny): Expression = {
     SparkConnectPluginRegistry.expressionRegistry
       // Lazily traverse the collection.
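The tagging the planner performs above can be reproduced in isolation. A runnable sketch (spark-catalyst on the classpath; `tag` is a stand-in for the `private[spark]` `LogicalPlan.PLAN_ID_TAG` defined at the end of this patch):

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val tag = TreeNodeTag[Long]("plan_id")

// Same pattern as transformUnresolvedAttribute: parse the identifier,
// then attach the client-side plan id as a tag on the tree node.
val attr = UnresolvedAttribute.quotedString("value")
attr.setTagValue(tag, 42L)
assert(attr.getTagValue(tag).contains(42L))
```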
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a790f191110..0b5f94cfaaa 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -279,7 +279,6 @@ class Column:
     __ge__ = _bin_op("geq")
     __gt__ = _bin_op("gt")

-    # TODO(SPARK-41812): DataFrame.join: ambiguous column
     _eqNullSafe_doc = """
     Equality test that is safe for null values.

@@ -315,9 +314,9 @@ class Column:
         ...     Row(value = 'bar'),
         ...     Row(value = None)
         ... ])
-        >>> df1.join(df2, df1["value"] == df2["value"]).count()  # doctest: +SKIP
+        >>> df1.join(df2, df1["value"] == df2["value"]).count()
        0
-        >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()  # doctest: +SKIP
+        >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
        1
        >>> df2 = spark.createDataFrame([
        ...     Row(id=1, value=float('NaN')),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 95e39f93dc0..667295e8667 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -50,12 +50,14 @@ from pyspark.sql.dataframe import (
 )

 from pyspark.errors import PySparkTypeError
+from pyspark.errors.exceptions.connect import SparkConnectException
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
 from pyspark.sql.connect.functions import (
+    _to_col_with_plan_id,
     _to_col,
     _invoke_function,
     col,
@@ -1284,10 +1286,12 @@ class DataFrame:
         if isinstance(item, str):
             # Check for alias
             alias = self._get_alias()
-            if alias is not None:
-                return col(alias)
-            else:
-                return col(item)
+            if self._plan is None:
+                raise SparkConnectException("Cannot analyze on empty plan.")
+            return _to_col_with_plan_id(
+                col=alias if alias is not None else item,
+                plan_id=self._plan._plan_id,
+            )
         elif isinstance(item, Column):
             return self.filter(item)
         elif isinstance(item, (list, tuple)):
@@ -1694,9 +1698,8 @@ def _test() -> None:
     del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__
     del pyspark.sql.connect.dataframe.DataFrame.repartitionByRange.__doc__

-    # TODO(SPARK-41823): ambiguous column names
+    # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns
     del pyspark.sql.connect.dataframe.DataFrame.drop.__doc__
-    del pyspark.sql.connect.dataframe.DataFrame.join.__doc__

     # TODO(SPARK-41625): Support Structured Streaming
     del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__
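Taken together, the `__getitem__` change means the user-facing repro from the PR description now works. A sketch against a Spark Connect session (assumed to be bound to `spark` here):

```
df1 = spark.createDataFrame([(1, "foo"), (2, None)], ["id", "value"])
df2 = spark.createDataFrame([("bar",), (None,), ("foo",)], ["value"])

# df1["value"] and df2["value"] carry the plan ids of df1 and df2, so the
# analyzer can tell the two sides of the join apart:
df1.join(df2, df1["value"] == df2["value"]).show()
df1.join(df2, df1["value"] == df2["value"]).select(df1.value).show()
```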
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 28b796496ec..571dd2b2f4b 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -339,11 +339,14 @@ class ColumnReference(Expression):
     treat it as an unresolved attribute. Attributes that have the same
     fully qualified name are identical"""

-    def __init__(self, unparsed_identifier: str) -> None:
+    def __init__(self, unparsed_identifier: str, plan_id: Optional[int] = None) -> None:
         super().__init__()
         assert isinstance(unparsed_identifier, str)
         self._unparsed_identifier = unparsed_identifier

+        assert plan_id is None or isinstance(plan_id, int)
+        self._plan_id = plan_id
+
     def name(self) -> str:
         """Returns the qualified name of the column reference."""
         return self._unparsed_identifier
@@ -352,6 +355,8 @@ class ColumnReference(Expression):
         """Returns the Proto representation of the expression."""
         expr = proto.Expression()
         expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
+        if self._plan_id is not None:
+            expr.unresolved_attribute.plan_id = self._plan_id
         return expr

     def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index d4984b1ba67..e5305938797 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -63,6 +63,15 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.dataframe import DataFrame


+def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
+    if col == "*":
+        return Column(UnresolvedStar(unparsed_target=None))
+    elif col.endswith(".*"):
+        return Column(UnresolvedStar(unparsed_target=col))
+    else:
+        return Column(ColumnReference(unparsed_identifier=col, plan_id=plan_id))
+
+
 def _to_col(col: "ColumnOrName") -> Column:
     assert isinstance(col, (Column, str))
     return col if isinstance(col, Column) else column(col)
@@ -202,12 +211,7 @@ def _options_to_col(options: Dict[str, Any]) -> Column:


 def col(col: str) -> Column:
-    if col == "*":
-        return Column(UnresolvedStar(unparsed_target=None))
-    elif col.endswith(".*"):
-        return Column(UnresolvedStar(unparsed_target=col))
-    else:
-        return Column(ColumnReference(unparsed_identifier=col))
+    return _to_col_with_plan_id(col=col, plan_id=None)


 col.__doc__ = pysparkfuncs.col.__doc__
@@ -2470,9 +2474,6 @@ def _test() -> None:
     del pyspark.sql.connect.functions.timestamp_seconds.__doc__
     del pyspark.sql.connect.functions.unix_timestamp.__doc__

-    # TODO(SPARK-41812): Proper column names after join
-    del pyspark.sql.connect.functions.count_distinct.__doc__
-
     # TODO(SPARK-41843): Implement SparkSession.udf
     del pyspark.sql.connect.functions.call_udf.__doc__
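The split between `col` and `_to_col_with_plan_id` preserves the behavior noted in the description: `F.col(...)` stays plan-agnostic while `df[...]` is plan-bound. A sketch (Connect session assumed as `spark`):

```
from pyspark.sql.connect import functions as F

df = spark.range(3)

c_free = F.col("id")   # ColumnReference with plan_id=None: resolved by name only
c_bound = df["id"]     # ColumnReference tagged with df's plan_id
```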
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index ced0e4008e1..d37201e4408 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -21,9 +21,11 @@ check_dependencies(__name__, __file__)
 from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 import functools
 import json
-import pyarrow as pa
+from threading import Lock
 from inspect import signature, isclass

+import pyarrow as pa
+
 from pyspark.sql.types import DataType

 import pyspark.sql.connect.proto as proto
@@ -40,13 +42,29 @@ class InputValidationError(Exception):
     pass


-class LogicalPlan(object):
+class LogicalPlan:
+
+    _lock: Lock = Lock()
+    _nextPlanId: int = 0

     INDENT = 2

     def __init__(self, child: Optional["LogicalPlan"]) -> None:
         self._child = child

+        plan_id: Optional[int] = None
+        with LogicalPlan._lock:
+            plan_id = LogicalPlan._nextPlanId
+            LogicalPlan._nextPlanId += 1
+
+        assert plan_id is not None
+        self._plan_id = plan_id
+
+    def _create_proto_relation(self) -> proto.Relation:
+        plan = proto.Relation()
+        plan.common.plan_id = self._plan_id
+        return plan
+
     def unresolved_attr(self, colName: str) -> proto.Expression:
         """Creates an unresolved attribute from a column name."""
         exp = proto.Expression()
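The id allocation above is a plain lock-protected counter; uniqueness is per client process, which is all the server-side lookup needs. The pattern in isolation (a hypothetical class, not the actual `LogicalPlan`):

```
from threading import Lock

class Plan:
    _lock = Lock()
    _next_plan_id = 0

    def __init__(self) -> None:
        # Serialize increments so concurrent plan construction in
        # multiple threads never hands out the same id twice.
        with Plan._lock:
            self._plan_id = Plan._next_plan_id
            Plan._next_plan_id += 1

a, b = Plan(), Plan()
assert a._plan_id != b._plan_id
```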
"SparkConnectClient") -> proto.Relation: assert self._child is not None - plan = proto.Relation() + plan = self._create_proto_relation() plan.offset.input.CopyFrom(self._child.plan(session)) plan.offset.offset = self.offset return plan @@ -537,7 +554,7 @@ class Deduplicate(LogicalPlan): def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None - plan = proto.Relation() + plan = self._create_proto_relation() plan.deduplicate.input.CopyFrom(self._child.plan(session)) plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys if self.column_names is not None: @@ -570,7 +587,7 @@ class Sort(LogicalPlan): def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None - plan = proto.Relation() + plan = self._create_proto_relation() plan.sort.input.CopyFrom(self._child.plan(session)) plan.sort.order.extend([self._convert_col(c, session) for c in self.columns]) plan.sort.is_global = self.is_global @@ -599,7 +616,7 @@ class Drop(LogicalPlan): def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None - plan = proto.Relation() + plan = self._create_proto_relation() plan.drop.input.CopyFrom(self._child.plan(session)) plan.drop.cols.extend([self._convert_to_expr(c, session) for c in self.columns]) return plan @@ -624,7 +641,7 @@ class Sample(LogicalPlan): def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None - plan = proto.Relation() + plan = self._create_proto_relation() plan.sample.input.CopyFrom(self._child.plan(session)) plan.sample.lower_bound = self.lower_bound plan.sample.upper_bound = self.upper_bound @@ -672,32 +689,31 @@ class Aggregate(LogicalPlan): from pyspark.sql.connect.functions import lit assert self._child is not None - - agg = proto.Relation() - - agg.aggregate.input.CopyFrom(self._child.plan(session)) - - agg.aggregate.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols]) - agg.aggregate.aggregate_expressions.extend( + plan = self._create_proto_relation() + plan.aggregate.input.CopyFrom(self._child.plan(session)) + plan.aggregate.grouping_expressions.extend( + [c.to_plan(session) for c in self._grouping_cols] + ) + plan.aggregate.aggregate_expressions.extend( [c.to_plan(session) for c in self._aggregate_cols] ) if self._group_type == "groupby": - agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY + plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY elif self._group_type == "rollup": - agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP + plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP elif self._group_type == "cube": - agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE + plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE elif self._group_type == "pivot": - agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT + plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT assert self._pivot_col is not None - agg.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session)) + plan.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session)) if self._pivot_values is not None and len(self._pivot_values) > 0: - agg.aggregate.pivot.values.extend( + plan.aggregate.pivot.values.extend( [lit(v).to_plan(session).literal for v in self._pivot_values] ) - return agg + return plan class Join(LogicalPlan): @@ -742,23 +758,23 @@ class Join(LogicalPlan): self.how = 
@@ -742,23 +758,23 @@ class Join(LogicalPlan):
         self.how = join_type

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
-        rel.join.left.CopyFrom(self.left.plan(session))
-        rel.join.right.CopyFrom(self.right.plan(session))
+        plan = self._create_proto_relation()
+        plan.join.left.CopyFrom(self.left.plan(session))
+        plan.join.right.CopyFrom(self.right.plan(session))
         if self.on is not None:
             if not isinstance(self.on, list):
                 if isinstance(self.on, str):
-                    rel.join.using_columns.append(self.on)
+                    plan.join.using_columns.append(self.on)
                 else:
-                    rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+                    plan.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
             elif len(self.on) > 0:
                 if isinstance(self.on[0], str):
-                    rel.join.using_columns.extend(cast(str, self.on))
+                    plan.join.using_columns.extend(cast(str, self.on))
                 else:
                     merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
-                    rel.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
-        rel.join.join_type = self.how
-        return rel
+                    plan.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
+        plan.join.join_type = self.how
+        return plan

     def print(self, indent: int = 0) -> str:
         i = " " * indent
""" - % rel.set_op.set_op_type + % plan.set_op.set_op_type ) - rel.set_op.is_all = self.is_all - rel.set_op.by_name = self.by_name - rel.set_op.allow_missing_columns = self.allow_missing_columns - return rel + plan.set_op.is_all = self.is_all + plan.set_op.by_name = self.by_name + plan.set_op.allow_missing_columns = self.allow_missing_columns + return plan def print(self, indent: int = 0) -> str: assert self._child is not None @@ -860,12 +876,12 @@ class Repartition(LogicalPlan): self._shuffle = shuffle def plan(self, session: "SparkConnectClient") -> proto.Relation: - rel = proto.Relation() + plan = self._create_proto_relation() if self._child is not None: - rel.repartition.input.CopyFrom(self._child.plan(session)) - rel.repartition.shuffle = self._shuffle - rel.repartition.num_partitions = self._num_partitions - return rel + plan.repartition.input.CopyFrom(self._child.plan(session)) + plan.repartition.shuffle = self._shuffle + plan.repartition.num_partitions = self._num_partitions + return plan class RepartitionByExpression(LogicalPlan): @@ -882,7 +898,7 @@ class RepartitionByExpression(LogicalPlan): self.columns = columns def plan(self, session: "SparkConnectClient") -> proto.Relation: - rel = proto.Relation() + plan = self._create_proto_relation() part_exprs = [] for c in self.columns: @@ -894,13 +910,13 @@ class RepartitionByExpression(LogicalPlan): part_exprs.append(exp) else: part_exprs.append(self.unresolved_attr(c)) - rel.repartition_by_expression.partition_exprs.extend(part_exprs) + plan.repartition_by_expression.partition_exprs.extend(part_exprs) if self._child is not None: - rel.repartition_by_expression.input.CopyFrom(self._child.plan(session)) + plan.repartition_by_expression.input.CopyFrom(self._child.plan(session)) if self.num_partitions is not None: - rel.repartition_by_expression.num_partitions = self.num_partitions - return rel + plan.repartition_by_expression.num_partitions = self.num_partitions + return plan class SubqueryAlias(LogicalPlan): @@ -911,11 +927,11 @@ class SubqueryAlias(LogicalPlan): self._alias = alias def plan(self, session: "SparkConnectClient") -> proto.Relation: - rel = proto.Relation() + plan = self._create_proto_relation() if self._child is not None: - rel.subquery_alias.input.CopyFrom(self._child.plan(session)) - rel.subquery_alias.alias = self._alias - return rel + plan.subquery_alias.input.CopyFrom(self._child.plan(session)) + plan.subquery_alias.alias = self._alias + return plan class SQL(LogicalPlan): @@ -931,14 +947,14 @@ class SQL(LogicalPlan): self._args = args def plan(self, session: "SparkConnectClient") -> proto.Relation: - rel = proto.Relation() - rel.sql.query = self._query + plan = self._create_proto_relation() + plan.sql.query = self._query if self._args is not None and len(self._args) > 0: for k, v in self._args.items(): - rel.sql.args[k] = v + plan.sql.args[k] = v - return rel + return plan class Range(LogicalPlan): @@ -956,13 +972,13 @@ class Range(LogicalPlan): self._num_partitions = num_partitions def plan(self, session: "SparkConnectClient") -> proto.Relation: - rel = proto.Relation() - rel.range.start = self._start - rel.range.end = self._end - rel.range.step = self._step + plan = self._create_proto_relation() + plan.range.start = self._start + plan.range.end = self._end + plan.range.step = self._step if self._num_partitions is not None: - rel.range.num_partitions = self._num_partitions - return rel + plan.range.num_partitions = self._num_partitions + return plan class ToSchema(LogicalPlan): @@ -972,8 +988,7 @@ class 
@@ -972,8 +988,7 @@ class ToSchema(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.to_schema.input.CopyFrom(self._child.plan(session))
         plan.to_schema.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
         return plan
@@ -986,8 +1001,7 @@ class WithColumnsRenamed(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.with_columns_renamed.input.CopyFrom(self._child.plan(session))
         for k, v in self._colsMap.items():
             plan.with_columns_renamed.rename_columns_map[k] = v
@@ -1019,8 +1033,7 @@ class Unpivot(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.unpivot.input.CopyFrom(self._child.plan(session))
         plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
         if self.values is not None:
@@ -1064,7 +1077,7 @@ class NAFill(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.fill_na.input.CopyFrom(self._child.plan(session))
         if self.cols is not None and len(self.cols) > 0:
             plan.fill_na.cols.extend(self.cols)
@@ -1086,7 +1099,7 @@ class NADrop(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.drop_na.input.CopyFrom(self._child.plan(session))
         if self.cols is not None and len(self.cols) > 0:
             plan.drop_na.cols.extend(self.cols)
@@ -1122,7 +1135,7 @@ class NAReplace(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.replace.input.CopyFrom(self._child.plan(session))
         if self.cols is not None and len(self.cols) > 0:
             plan.replace.cols.extend(self.cols)
@@ -1150,7 +1163,7 @@ class StatSummary(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.summary.input.CopyFrom(self._child.plan(session))
         plan.summary.statistics.extend(self.statistics)
         return plan
@@ -1163,7 +1176,7 @@ class StatDescribe(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.describe.input.CopyFrom(self._child.plan(session))
         plan.describe.cols.extend(self.cols)
         return plan
@@ -1177,8 +1190,7 @@ class StatCov(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.cov.input.CopyFrom(self._child.plan(session))
         plan.cov.col1 = self._col1
         plan.cov.col2 = self._col2
@@ -1200,8 +1212,7 @@ class StatApproxQuantile(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.approx_quantile.input.CopyFrom(self._child.plan(session))
         plan.approx_quantile.cols.extend(self._cols)
         plan.approx_quantile.probabilities.extend(self._probabilities)
@@ -1217,8 +1228,7 @@ class StatCrosstab(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.crosstab.input.CopyFrom(self._child.plan(session))
         plan.crosstab.col1 = self.col1
         plan.crosstab.col2 = self.col2
@@ -1238,8 +1248,7 @@ class StatFreqItems(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.freq_items.input.CopyFrom(self._child.plan(session))
         plan.freq_items.cols.extend(self._cols)
         plan.freq_items.support = self._support
@@ -1275,8 +1284,7 @@ class StatSampleBy(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.sample_by.input.CopyFrom(self._child.plan(session))
         plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
         if len(self._fractions) > 0:
@@ -1299,8 +1307,7 @@ class StatCorr(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.corr.input.CopyFrom(self._child.plan(session))
         plan.corr.col1 = self._col1
         plan.corr.col2 = self._col2
@@ -1315,8 +1322,7 @@ class ToDF(LogicalPlan):

     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.to_df.input.CopyFrom(self._child.plan(session))
         plan.to_df.column_names.extend(self._cols)
         return plan
@@ -1333,8 +1339,8 @@ class CreateView(LogicalPlan):

     def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
-
         plan = proto.Command()
+
         plan.create_dataframe_view.replace = self._replace
         plan.create_dataframe_view.is_global = self._is_global
         plan.create_dataframe_view.name = self._name
@@ -1358,6 +1364,7 @@ class WriteOperation(LogicalPlan):
     def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
         plan = proto.Command()
+
         plan.write_operation.input.CopyFrom(self._child.plan(session))
         if self.source is not None:
             plan.write_operation.source = self.source
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 92d9e6a610a..891be5ea9ea 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbc%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
@@ -300,7 +300,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 4859
+    _EXPRESSION._serialized_end = 4901
     _EXPRESSION_WINDOW._serialized_start = 1475
     _EXPRESSION_WINDOW._serialized_end = 2258
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -324,29 +324,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599
     _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697
     _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3785
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3788
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3992
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3994
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4044
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4046
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4128
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4130
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4174
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4177
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4309
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 4312
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 4499
-    _EXPRESSION_ALIAS._serialized_start = 4501
-    _EXPRESSION_ALIAS._serialized_end = 4621
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4624
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173
-    _PYTHONUDF._serialized_start = 5176
-    _PYTHONUDF._serialized_end = 5306
-    _SCALARSCALAUDF._serialized_start = 5309
-    _SCALARSCALAUDF._serialized_end = 5493
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3827
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3830
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4034
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4036
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4086
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4216
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4219
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4351
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 4354
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4541
+    _EXPRESSION_ALIAS._serialized_start = 4543
+    _EXPRESSION_ALIAS._serialized_end = 4663
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4666
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4824
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4826
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4888
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4904
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5215
+    _PYTHONUDF._serialized_start = 5218
+    _PYTHONUDF._serialized_end = 5348
+    _SCALARSCALAUDF._serialized_start = 5351
+    _SCALARSCALAUDF._serialized_end = 5535
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 934e0016c90..88b1fd8ef7e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -613,19 +613,37 @@ class Expression(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor

         UNPARSED_IDENTIFIER_FIELD_NUMBER: builtins.int
+        PLAN_ID_FIELD_NUMBER: builtins.int
         unparsed_identifier: builtins.str
         """(Required) An identifier that will be parsed by Catalyst parser. This should follow the
         Spark SQL identifier syntax.
         """
+        plan_id: builtins.int
+        """(Optional) The id of corresponding connect plan."""
         def __init__(
             self,
             *,
             unparsed_identifier: builtins.str = ...,
+            plan_id: builtins.int | None = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+        ) -> builtins.bool: ...
         def ClearField(
             self,
-            field_name: typing_extensions.Literal["unparsed_identifier", b"unparsed_identifier"],
+            field_name: typing_extensions.Literal[
+                "_plan_id",
+                b"_plan_id",
+                "plan_id",
+                b"plan_id",
+                "unparsed_identifier",
+                b"unparsed_identifier",
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+        ) -> typing_extensions.Literal["plan_id"] | None: ...

     class UnresolvedFunction(google.protobuf.message.Message):
         """An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index ece5920953e..057b96a8da9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
@@ -639,103 +639,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _UNKNOWN._serialized_start = 2464
     _UNKNOWN._serialized_end = 2473
     _RELATIONCOMMON._serialized_start = 2475
-    _RELATIONCOMMON._serialized_end = 2524
-    _SQL._serialized_start = 2527
-    _SQL._serialized_end = 2661
-    _SQL_ARGSENTRY._serialized_start = 2606
-    _SQL_ARGSENTRY._serialized_end = 2661
-    _READ._serialized_start = 2664
-    _READ._serialized_end = 3112
-    _READ_NAMEDTABLE._serialized_start = 2806
-    _READ_NAMEDTABLE._serialized_end = 2867
-    _READ_DATASOURCE._serialized_start = 2870
-    _READ_DATASOURCE._serialized_end = 3099
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3030
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3088
-    _PROJECT._serialized_start = 3114
-    _PROJECT._serialized_end = 3231
-    _FILTER._serialized_start = 3233
-    _FILTER._serialized_end = 3345
-    _JOIN._serialized_start = 3348
-    _JOIN._serialized_end = 3819
-    _JOIN_JOINTYPE._serialized_start = 3611
-    _JOIN_JOINTYPE._serialized_end = 3819
-    _SETOPERATION._serialized_start = 3822
-    _SETOPERATION._serialized_end = 4301
-    _SETOPERATION_SETOPTYPE._serialized_start = 4138
-    _SETOPERATION_SETOPTYPE._serialized_end = 4252
-    _LIMIT._serialized_start = 4303
-    _LIMIT._serialized_end = 4379
-    _OFFSET._serialized_start = 4381
-    _OFFSET._serialized_end = 4460
-    _TAIL._serialized_start = 4462
-    _TAIL._serialized_end = 4537
-    _AGGREGATE._serialized_start = 4540
-    _AGGREGATE._serialized_end = 5122
-    _AGGREGATE_PIVOT._serialized_start = 4879
-    _AGGREGATE_PIVOT._serialized_end = 4990
-    _AGGREGATE_GROUPTYPE._serialized_start = 4993
-    _AGGREGATE_GROUPTYPE._serialized_end = 5122
-    _SORT._serialized_start = 5125
-    _SORT._serialized_end = 5285
-    _DROP._serialized_start = 5287
-    _DROP._serialized_end = 5387
-    _DEDUPLICATE._serialized_start = 5390
-    _DEDUPLICATE._serialized_end = 5561
-    _LOCALRELATION._serialized_start = 5563
-    _LOCALRELATION._serialized_end = 5652
-    _SAMPLE._serialized_start = 5655
-    _SAMPLE._serialized_end = 5928
-    _RANGE._serialized_start = 5931
-    _RANGE._serialized_end = 6076
-    _SUBQUERYALIAS._serialized_start = 6078
-    _SUBQUERYALIAS._serialized_end = 6192
-    _REPARTITION._serialized_start = 6195
-    _REPARTITION._serialized_end = 6337
-    _SHOWSTRING._serialized_start = 6340
-    _SHOWSTRING._serialized_end = 6482
-    _STATSUMMARY._serialized_start = 6484
-    _STATSUMMARY._serialized_end = 6576
-    _STATDESCRIBE._serialized_start = 6578
-    _STATDESCRIBE._serialized_end = 6659
-    _STATCROSSTAB._serialized_start = 6661
-    _STATCROSSTAB._serialized_end = 6762
-    _STATCOV._serialized_start = 6764
-    _STATCOV._serialized_end = 6860
-    _STATCORR._serialized_start = 6863
-    _STATCORR._serialized_end = 7000
-    _STATAPPROXQUANTILE._serialized_start = 7003
-    _STATAPPROXQUANTILE._serialized_end = 7167
-    _STATFREQITEMS._serialized_start = 7169
-    _STATFREQITEMS._serialized_end = 7294
-    _STATSAMPLEBY._serialized_start = 7297
-    _STATSAMPLEBY._serialized_end = 7606
-    _STATSAMPLEBY_FRACTION._serialized_start = 7498
-    _STATSAMPLEBY_FRACTION._serialized_end = 7597
-    _NAFILL._serialized_start = 7609
-    _NAFILL._serialized_end = 7743
-    _NADROP._serialized_start = 7746
-    _NADROP._serialized_end = 7880
-    _NAREPLACE._serialized_start = 7883
-    _NAREPLACE._serialized_end = 8179
-    _NAREPLACE_REPLACEMENT._serialized_start = 8038
-    _NAREPLACE_REPLACEMENT._serialized_end = 8179
-    _TODF._serialized_start = 8181
-    _TODF._serialized_end = 8269
-    _WITHCOLUMNSRENAMED._serialized_start = 8272
-    _WITHCOLUMNSRENAMED._serialized_end = 8511
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8444
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8511
-    _WITHCOLUMNS._serialized_start = 8513
-    _WITHCOLUMNS._serialized_end = 8632
-    _HINT._serialized_start = 8635
-    _HINT._serialized_end = 8767
-    _UNPIVOT._serialized_start = 8770
-    _UNPIVOT._serialized_end = 9097
-    _UNPIVOT_VALUES._serialized_start = 9027
-    _UNPIVOT_VALUES._serialized_end = 9086
-    _TOSCHEMA._serialized_start = 9099
-    _TOSCHEMA._serialized_end = 9205
-    _REPARTITIONBYEXPRESSION._serialized_start = 9208
-    _REPARTITIONBYEXPRESSION._serialized_end = 9411
+    _RELATIONCOMMON._serialized_end = 2566
+    _SQL._serialized_start = 2569
+    _SQL._serialized_end = 2703
+    _SQL_ARGSENTRY._serialized_start = 2648
+    _SQL_ARGSENTRY._serialized_end = 2703
+    _READ._serialized_start = 2706
+    _READ._serialized_end = 3154
+    _READ_NAMEDTABLE._serialized_start = 2848
+    _READ_NAMEDTABLE._serialized_end = 2909
+    _READ_DATASOURCE._serialized_start = 2912
+    _READ_DATASOURCE._serialized_end = 3141
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3072
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3130
+    _PROJECT._serialized_start = 3156
+    _PROJECT._serialized_end = 3273
+    _FILTER._serialized_start = 3275
+    _FILTER._serialized_end = 3387
+    _JOIN._serialized_start = 3390
+    _JOIN._serialized_end = 3861
+    _JOIN_JOINTYPE._serialized_start = 3653
+    _JOIN_JOINTYPE._serialized_end = 3861
+    _SETOPERATION._serialized_start = 3864
+    _SETOPERATION._serialized_end = 4343
+    _SETOPERATION_SETOPTYPE._serialized_start = 4180
+    _SETOPERATION_SETOPTYPE._serialized_end = 4294
+    _LIMIT._serialized_start = 4345
+    _LIMIT._serialized_end = 4421
+    _OFFSET._serialized_start = 4423
+    _OFFSET._serialized_end = 4502
+    _TAIL._serialized_start = 4504
+    _TAIL._serialized_end = 4579
+    _AGGREGATE._serialized_start = 4582
+    _AGGREGATE._serialized_end = 5164
+    _AGGREGATE_PIVOT._serialized_start = 4921
+    _AGGREGATE_PIVOT._serialized_end = 5032
+    _AGGREGATE_GROUPTYPE._serialized_start = 5035
+    _AGGREGATE_GROUPTYPE._serialized_end = 5164
+    _SORT._serialized_start = 5167
+    _SORT._serialized_end = 5327
+    _DROP._serialized_start = 5329
+    _DROP._serialized_end = 5429
+    _DEDUPLICATE._serialized_start = 5432
+    _DEDUPLICATE._serialized_end = 5603
+    _LOCALRELATION._serialized_start = 5605
+    _LOCALRELATION._serialized_end = 5694
+    _SAMPLE._serialized_start = 5697
+    _SAMPLE._serialized_end = 5970
+    _RANGE._serialized_start = 5973
+    _RANGE._serialized_end = 6118
+    _SUBQUERYALIAS._serialized_start = 6120
+    _SUBQUERYALIAS._serialized_end = 6234
+    _REPARTITION._serialized_start = 6237
+    _REPARTITION._serialized_end = 6379
+    _SHOWSTRING._serialized_start = 6382
+    _SHOWSTRING._serialized_end = 6524
+    _STATSUMMARY._serialized_start = 6526
+    _STATSUMMARY._serialized_end = 6618
+    _STATDESCRIBE._serialized_start = 6620
+    _STATDESCRIBE._serialized_end = 6701
+    _STATCROSSTAB._serialized_start = 6703
+    _STATCROSSTAB._serialized_end = 6804
+    _STATCOV._serialized_start = 6806
+    _STATCOV._serialized_end = 6902
+    _STATCORR._serialized_start = 6905
+    _STATCORR._serialized_end = 7042
+    _STATAPPROXQUANTILE._serialized_start = 7045
+    _STATAPPROXQUANTILE._serialized_end = 7209
+    _STATFREQITEMS._serialized_start = 7211
+    _STATFREQITEMS._serialized_end = 7336
+    _STATSAMPLEBY._serialized_start = 7339
+    _STATSAMPLEBY._serialized_end = 7648
+    _STATSAMPLEBY_FRACTION._serialized_start = 7540
+    _STATSAMPLEBY_FRACTION._serialized_end = 7639
+    _NAFILL._serialized_start = 7651
+    _NAFILL._serialized_end = 7785
+    _NADROP._serialized_start = 7788
+    _NADROP._serialized_end = 7922
+    _NAREPLACE._serialized_start = 7925
+    _NAREPLACE._serialized_end = 8221
+    _NAREPLACE_REPLACEMENT._serialized_start = 8080
+    _NAREPLACE_REPLACEMENT._serialized_end = 8221
+    _TODF._serialized_start = 8223
+    _TODF._serialized_end = 8311
+    _WITHCOLUMNSRENAMED._serialized_start = 8314
+    _WITHCOLUMNSRENAMED._serialized_end = 8553
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8486
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8553
+    _WITHCOLUMNS._serialized_start = 8555
+    _WITHCOLUMNS._serialized_end = 8674
+    _HINT._serialized_start = 8677
+    _HINT._serialized_end = 8809
+    _UNPIVOT._serialized_start = 8812
+    _UNPIVOT._serialized_end = 9139
+    _UNPIVOT_VALUES._serialized_start = 9069
+    _UNPIVOT_VALUES._serialized_end = 9128
+    _TOSCHEMA._serialized_start = 9141
+    _TOSCHEMA._serialized_end = 9247
+    _REPARTITIONBYEXPRESSION._serialized_start = 9250
+    _REPARTITIONBYEXPRESSION._serialized_end = 9453
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 41962ee4062..b7cef7b299d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -478,16 +478,29 @@ class RelationCommon(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor

     SOURCE_INFO_FIELD_NUMBER: builtins.int
+    PLAN_ID_FIELD_NUMBER: builtins.int
     source_info: builtins.str
     """(Required) Shared relation metadata."""
+    plan_id: builtins.int
+    """(Optional) A per-client globally unique id for a given connect plan."""
     def __init__(
         self,
         *,
         source_info: builtins.str = ...,
+        plan_id: builtins.int | None = ...,
     ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"]
+    ) -> builtins.bool: ...
     def ClearField(
-        self, field_name: typing_extensions.Literal["source_info", b"source_info"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_plan_id", b"_plan_id", "plan_id", b"plan_id", "source_info", b"source_info"
+        ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+    ) -> typing_extensions.Literal["plan_id"] | None: ...

 global___RelationCommon = RelationCommon
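The regenerated stubs above follow from proto3 `optional` semantics: presence is tracked through a synthetic oneof named `_plan_id`. A sketch of what that buys (generated bindings assumed):

```
import pyspark.sql.connect.proto as proto

common = proto.RelationCommon()
assert not common.HasField("plan_id")  # unset by default

common.plan_id = 7
assert common.WhichOneof("_plan_id") == "plan_id"

common.ClearField("plan_id")
assert not common.HasField("plan_id")
```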
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a723163cbe8..9e9341c9a2a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -356,6 +356,68 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         )
         self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())

+    def test_join_ambiguous_cols(self):
+        # SPARK-41812: test join with ambiguous columns
+        data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
+        cdf1 = self.connect.createDataFrame(data1)
+        sdf1 = self.spark.createDataFrame(data1)
+
+        data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
+        cdf2 = self.connect.createDataFrame(data2)
+        sdf2 = self.spark.createDataFrame(data2)
+
+        cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
+        sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
+        cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
+        sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
+
+        self.assertEqual(cdf4.schema, sdf4.schema)
+        self.assertEqual(cdf4.collect(), sdf4.collect())
+
+        cdf5 = cdf1.join(
+            cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
+        )
+        sdf5 = sdf1.join(
+            sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
+        )
+
+        self.assertEqual(cdf5.schema, sdf5.schema)
+        self.assertEqual(cdf5.collect(), sdf5.collect())
+
+        cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
+        sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
+
+        self.assertEqual(cdf6.schema, sdf6.schema)
+        self.assertEqual(cdf6.collect(), sdf6.collect())
+
+        cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
+        sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
+
+        self.assertEqual(cdf7.schema, sdf7.schema)
+        self.assertEqual(cdf7.collect(), sdf7.collect())
+
+    def test_invalid_column(self):
+        # SPARK-41812: fail df1.select(df2.col)
+        data1 = [Row(a=1, b=2, c=3)]
+        cdf1 = self.connect.createDataFrame(data1)
+
+        data2 = [Row(a=2, b=0)]
+        cdf2 = self.connect.createDataFrame(data2)
+
+        with self.assertRaises(AnalysisException):
+            cdf1.select(cdf2.a).schema
+
+        with self.assertRaises(AnalysisException):
+            cdf2.withColumn("x", cdf1.a + 1).schema
+
+        with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
+            cdf3 = cdf1.select(cdf1.a)
+            cdf3.select(cdf1.b).schema
+
     def test_collect(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 1892e64f8f9..a5f691d0bef 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -85,7 +85,18 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         right_input = self.connect.readTable(table_name=self.tbl_name)
         crossJoin_plan = left_input.crossJoin(other=right_input)._plan.to_proto(self.connect)
         join_plan = left_input.join(other=right_input, how="cross")._plan.to_proto(self.connect)
-        self.assertEqual(crossJoin_plan, join_plan)
+        self.assertEqual(
+            crossJoin_plan.root.join.left.read.named_table,
+            join_plan.root.join.left.read.named_table,
+        )
+        self.assertEqual(
+            crossJoin_plan.root.join.right.read.named_table,
+            join_plan.root.join.right.read.named_table,
+        )
+        self.assertEqual(
+            crossJoin_plan.root.join.join_type,
+            join_plan.root.join.join_type,
+        )

     def test_filter(self):
         df = self.connect.readTable(table_name=self.tbl_name)
@@ -732,7 +743,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):

         self.assertIsNotNone(cp1)
         self.assertEqual(cp1, cp2)
-        self.assertEqual(cp2, cp3)
+        self.assertEqual(
+            cp2.unresolved_attribute.unparsed_identifier,
+            cp3.unresolved_attribute.unparsed_identifier,
+        )
+        self.assertTrue(cp2.unresolved_attribute.HasField("plan_id"))
+        self.assertFalse(cp3.unresolved_attribute.HasField("plan_id"))

     def test_null_literal(self):
         null_lit = lit(None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9a2648a79a5..eff8c114a97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3307,12 +3307,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       _.containsPattern(NATURAL_LIKE_JOIN), ruleId) {
       case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
           if left.resolved && right.resolved && j.duplicateResolved =>
-        commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint)
+        commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint,
+          j.getTagValue(LogicalPlan.PLAN_ID_TAG))
       case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
          if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
-        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint)
+        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint,
+          j.getTagValue(LogicalPlan.PLAN_ID_TAG))
     }
   }

@@ -3442,7 +3444,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       joinType: JoinType,
       joinNames: Seq[String],
       condition: Option[Expression],
-      hint: JoinHint): LogicalPlan = {
+      hint: JoinHint,
+      planId: Option[Long] = None): LogicalPlan = {
     import org.apache.spark.sql.catalyst.util._

     val leftKeys = joinNames.map { keyName =>
@@ -3483,9 +3486,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case _ =>
         throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
     }
+
+    val newJoin = Join(left, right, joinType, newCondition, hint)
+    // retain the plan id used in Spark Connect
+    planId.foreach(newJoin.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+
     // use Project to hide duplicated common keys
     // propagate hidden columns from nested USING/NATURAL JOINs
-    val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
+    val project = Project(projectList, newJoin)
     project.setTagValue(
       Project.hiddenOutputTag,
       hiddenList.map(_.markAsQualifiedAccessOnly()) ++
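The `commonNaturalJoinProcessing` change guards against a subtle loss: tags live on individual tree nodes, so a rule that builds a replacement node starts with no tag. A runnable sketch of the retention pattern (spark-catalyst assumed; `tag` stands in for the private `PLAN_ID_TAG`):

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val tag = TreeNodeTag[Long]("plan_id")
val original = UnresolvedRelation(Seq("t"))
original.setTagValue(tag, 7L)

val rewritten = SubqueryAlias("t2", original)  // a fresh node: no tag yet
assert(rewritten.getTagValue(tag).isEmpty)

// Re-attach, as the new Join construction above does:
original.getTagValue(tag).foreach(rewritten.setTagValue(tag, _))
assert(rewritten.getTagValue(tag).contains(7L))
```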
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 0985fe6852c..ba550bce791 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -357,8 +357,21 @@ trait ColumnResolutionHelper extends Logging {
       e: Expression,
       q: LogicalPlan,
       allowOuter: Boolean = false): Expression = {
+    val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
+      // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
+      // expression are from Spark Connect, and need to be resolved in this way:
+      // 1, extract the attached plan id from the expression (UnresolvedAttribute only for now);
+      // 2, top-down traverse the query plan to find the plan node that matches the plan id;
+      // 3, if the matching node cannot be found, fail the analysis due to an illegal reference;
+      // 4, resolve the expression with the matching node; if any error occurs here, apply the
+      //    old code path;
+      resolveExpressionByPlanId(e, q)
+    } else {
+      e
+    }
+
     resolveExpression(
-      e,
+      newE,
       resolveColumnByName = nameParts => {
         q.resolveChildren(nameParts, conf.resolver)
       },
@@ -369,4 +382,47 @@ trait ColumnResolutionHelper extends Logging {
       throws = true,
       allowOuter = allowOuter)
   }
+
+  private def resolveExpressionByPlanId(
+      e: Expression,
+      q: LogicalPlan): Expression = {
+    if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
+      return e
+    }
+
+    e match {
+      case u: UnresolvedAttribute =>
+        resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u)
+      case _ =>
+        e.mapChildren(c => resolveExpressionByPlanId(c, q))
+    }
+  }
+
+  private def resolveUnresolvedAttributeByPlanId(
+      u: UnresolvedAttribute,
+      q: LogicalPlan): Option[NamedExpression] = {
+    val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
+    if (planIdOpt.isEmpty) return None
+    val planId = planIdOpt.get
+    logDebug(s"Extract plan_id $planId from $u")
+
+    val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId))
+    if (planOpt.isEmpty) {
+      // For example:
+      //   df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
+      //   df2 = spark.createDataFrame([Row(a = 1, b = 2)])
+      //   df1.select(df2.a)   <- illegal reference df2.a
+      throw new AnalysisException(s"When resolving $u, " +
+        s"fail to find subplan with plan_id=$planId in $q")
+    }
+    val plan = planOpt.get
+
+    try {
+      plan.resolve(u.nameParts, conf.resolver)
+    } catch {
+      case e: AnalysisException =>
+        logDebug(s"Fail to resolve $u with $plan due to $e")
+        None
+    }
+  }
 }
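Step 2 of the lookup documented above is a plain top-down `find` over the tree. A runnable sketch (same stand-in `tag` as before):

```
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val tag = TreeNodeTag[Long]("plan_id")
val leaf = UnresolvedRelation(Seq("t"))
leaf.setTagValue(tag, 3L)

val root: LogicalPlan = SubqueryAlias("a", leaf)
// Mirrors resolveUnresolvedAttributeByPlanId: search for the tagged node;
// an empty result is the "illegal reference" case that fails analysis.
val found = root.find(_.getTagValue(tag).contains(3L))
assert(found.contains(leaf))
```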
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 5a7dcff3667..36187bb2d55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
-import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
 import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -157,6 +157,18 @@ abstract class LogicalPlan
   }
 }

+object LogicalPlan {
+  // A dedicated tag for Spark Connect.
+  // If an expression (only UnresolvedAttribute is supported for now) is attached with this tag,
+  // the analyzer will:
+  // 1, extract the plan id;
+  // 2, top-down traverse the query plan to find the node that was attached with the same tag,
+  //    and fail the whole analysis if it cannot be found;
+  // 3, resolve this expression with the matching node. If any error occurs, the analyzer
+  //    falls back to the old code path.
+  private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
+}
+
 /**
  * A logical plan node with no children.
  */

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org