This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7d320d784a2 [SPARK-41077][CONNECT][PYTHON][REFACTORING] Rename `ColumnRef` to `Column` in Python client implementation 7d320d784a2 is described below commit 7d320d784a2d637fd1a8fd0798da3d2a39b4d7cd Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Fri Nov 11 11:03:04 2022 +0900 [SPARK-41077][CONNECT][PYTHON][REFACTORING] Rename `ColumnRef` to `Column` in Python client implementation ### What changes were proposed in this pull request? Connect python client uses `ColumnRef` to represent columns in API (e.g. `df.name`). Current PySpark uses the class `Column` for the same thing. In this case, we can align Connect with PySpark, which can help existing PySpark users to reuse their code for Spark Connect python client as much as possible (minimize the code change). ### Why are the changes needed? This is to help existing PySpark users to reuse their code for Spark Connect python client as much as possible (minimize the code change). ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? Existing UT Closes #38586 from amaliujia/SPARK-41077. 
Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/column.py | 12 ++++---- python/pyspark/sql/connect/dataframe.py | 34 +++++++++++----------- python/pyspark/sql/connect/function_builder.py | 4 +-- python/pyspark/sql/connect/functions.py | 6 ++-- python/pyspark/sql/connect/plan.py | 14 ++++----- python/pyspark/sql/connect/typing/__init__.pyi | 4 +-- .../connect/test_connect_column_expressions.py | 6 ++-- 7 files changed, 40 insertions(+), 40 deletions(-) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 3c9f8c3d736..417bc7097de 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -30,8 +30,8 @@ if TYPE_CHECKING: def _bin_op( name: str, doc: str = "binary function", reverse: bool = False -) -> Callable[["ColumnRef", Any], "Expression"]: - def _(self: "ColumnRef", other: Any) -> "Expression": +) -> Callable[["Column", Any], "Expression"]: + def _(self: "Column", other: Any) -> "Expression": if isinstance(other, get_args(PrimitiveType)): other = LiteralExpression(other) if not reverse: @@ -163,15 +163,15 @@ class LiteralExpression(Expression): return f"Literal({self._value})" -class ColumnRef(Expression): +class Column(Expression): """Represents a column reference. There is no guarantee that this column actually exists. In the context of this project, we refer by its name and treat it as an unresolved attribute. 
Attributes that have the same fully qualified name are identical""" @classmethod - def from_qualified_name(cls, name: str) -> "ColumnRef": - return ColumnRef(name) + def from_qualified_name(cls, name: str) -> "Column": + return Column(name) def __init__(self, name: str) -> None: super().__init__() @@ -198,7 +198,7 @@ class ColumnRef(Expression): class SortOrder(Expression): - def __init__(self, col: ColumnRef, ascending: bool = True, nullsLast: bool = True) -> None: + def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) -> None: super().__init__() self.ref = col self.ascending = ascending diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index e3116ea1250..0c19c67309d 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -31,7 +31,7 @@ import pandas import pyspark.sql.connect.plan as plan from pyspark.sql.connect.column import ( - ColumnRef, + Column, Expression, LiteralExpression, ) @@ -44,7 +44,7 @@ if TYPE_CHECKING: from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString from pyspark.sql.connect.client import RemoteSparkSession -ColumnOrName = Union[ColumnRef, str] +ColumnOrName = Union[Column, str] class GroupingFrame(object): @@ -52,9 +52,9 @@ class GroupingFrame(object): MeasuresType = Union[Sequence[Tuple["ExpressionOrString", str]], Dict[str, str]] OptMeasuresType = Optional[MeasuresType] - def __init__(self, df: "DataFrame", *grouping_cols: Union[ColumnRef, str]) -> None: + def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None: self._df = df - self._grouping_cols = [x if isinstance(x, ColumnRef) else df[x] for x in grouping_cols] + self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols] def agg(self, exprs: Optional[MeasuresType] = None) -> "DataFrame": @@ -76,18 +76,18 @@ class GroupingFrame(object): ) return res - def _map_cols_to_dict(self, fun: str, cols: 
List[Union[ColumnRef, str]]) -> Dict[str, str]: + def _map_cols_to_dict(self, fun: str, cols: List[Union[Column, str]]) -> Dict[str, str]: return {x if isinstance(x, str) else x.name(): fun for x in cols} - def min(self, *cols: Union[ColumnRef, str]) -> "DataFrame": + def min(self, *cols: Union[Column, str]) -> "DataFrame": expr = self._map_cols_to_dict("min", list(cols)) return self.agg(expr) - def max(self, *cols: Union[ColumnRef, str]) -> "DataFrame": + def max(self, *cols: Union[Column, str]) -> "DataFrame": expr = self._map_cols_to_dict("max", list(cols)) return self.agg(expr) - def sum(self, *cols: Union[ColumnRef, str]) -> "DataFrame": + def sum(self, *cols: Union[Column, str]) -> "DataFrame": expr = self._map_cols_to_dict("sum", list(cols)) return self.agg(expr) @@ -129,7 +129,7 @@ class DataFrame(object): def alias(self, alias: str) -> "DataFrame": return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session) - def approxQuantile(self, col: ColumnRef, probabilities: Any, relativeError: Any) -> "DataFrame": + def approxQuantile(self, col: Column, probabilities: Any, relativeError: Any) -> "DataFrame": ... def colRegex(self, regex: str) -> "DataFrame": @@ -206,7 +206,7 @@ class DataFrame(object): self._session, ) - def describe(self, cols: List[ColumnRef]) -> Any: + def describe(self, cols: List[Column]) -> Any: ... 
def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": @@ -250,7 +250,7 @@ class DataFrame(object): def drop(self, *cols: "ColumnOrString") -> "DataFrame": all_cols = self.columns - dropped = set([c.name() if isinstance(c, ColumnRef) else self[c].name() for c in cols]) + dropped = set([c.name() if isinstance(c, Column) else self[c].name() for c in cols]) dropped_cols = filter(lambda x: x in dropped, all_cols) return DataFrame.withPlan(plan.Project(self._plan, *dropped_cols), session=self._session) @@ -320,11 +320,11 @@ class DataFrame(object): """ return self.limit(num).collect() - # TODO: extend `on` to also be type List[ColumnRef]. + # TODO: extend `on` to also be type List[Column]. def join( self, other: "DataFrame", - on: Optional[Union[str, List[str], ColumnRef]] = None, + on: Optional[Union[str, List[str], Column]] = None, how: Optional[str] = None, ) -> "DataFrame": if self._plan is None: @@ -566,16 +566,16 @@ class DataFrame(object): p = p._child return None - def __getattr__(self, name: str) -> "ColumnRef": + def __getattr__(self, name: str) -> "Column": return self[name] - def __getitem__(self, name: str) -> "ColumnRef": + def __getitem__(self, name: str) -> "Column": # Check for alias alias = self._get_alias() if alias is not None: - return ColumnRef(alias) + return Column(alias) else: - return ColumnRef(name) + return Column(name) def _print_plan(self) -> str: if self._plan: diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py index 9c519312a4f..e116e493954 100644 --- a/python/pyspark/sql/connect/function_builder.py +++ b/python/pyspark/sql/connect/function_builder.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Optional, Any, Iterable, Union import pyspark.sql.connect.proto as proto import pyspark.sql.types from pyspark.sql.connect.column import ( - ColumnRef, + Column, Expression, ScalarFunctionExpression, ) @@ -45,7 +45,7 @@ def _build(name: str, *args: 
"ExpressionOrString") -> ScalarFunctionExpression: ------- :class:`ScalarFunctionExpression` """ - cols = [x if isinstance(x, Expression) else ColumnRef.from_qualified_name(x) for x in args] + cols = [x if isinstance(x, Expression) else Column.from_qualified_name(x) for x in args] return ScalarFunctionExpression(name, *cols) diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 880096da459..00d0a56aedb 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -14,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from pyspark.sql.connect.column import ColumnRef, LiteralExpression +from pyspark.sql.connect.column import Column, LiteralExpression from typing import Any # TODO(SPARK-40538) Add support for the missing PySpark functions. -def col(x: str) -> ColumnRef: - return ColumnRef(x) +def col(x: str) -> Column: + return Column(x) def lit(x: Any) -> LiteralExpression: diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 926119c5457..e5eed195568 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -28,7 +28,7 @@ from typing import ( import pyspark.sql.connect.proto as proto from pyspark.sql.connect.column import ( - ColumnRef, + Column, Expression, SortOrder, ) @@ -64,7 +64,7 @@ class LogicalPlan(object): if type(col) is str: return self.unresolved_attr(col) else: - return cast(ColumnRef, col).to_plan(session) + return cast(Column, col).to_plan(session) def plan(self, session: "RemoteSparkSession") -> proto.Relation: ... 
@@ -360,7 +360,7 @@ class Sort(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], - columns: List[Union[SortOrder, ColumnRef, str]], + columns: List[Union[SortOrder, Column, str]], is_global: bool, ) -> None: super().__init__(child) @@ -368,7 +368,7 @@ class Sort(LogicalPlan): self.is_global = is_global def col_to_sort_field( - self, col: Union[SortOrder, ColumnRef, str], session: "RemoteSparkSession" + self, col: Union[SortOrder, Column, str], session: "RemoteSparkSession" ) -> proto.Sort.SortField: if isinstance(col, SortOrder): sf = proto.Sort.SortField() @@ -387,7 +387,7 @@ class Sort(LogicalPlan): else: sf = proto.Sort.SortField() # Check string - if isinstance(col, ColumnRef): + if isinstance(col, Column): sf.expression.CopyFrom(col.to_plan(session)) else: sf.expression.CopyFrom(self.unresolved_attr(col)) @@ -478,7 +478,7 @@ class Aggregate(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], - grouping_cols: List[ColumnRef], + grouping_cols: List[Column], measures: OptMeasuresType, ) -> None: super().__init__(child) @@ -532,7 +532,7 @@ class Join(LogicalPlan): self, left: Optional["LogicalPlan"], right: "LogicalPlan", - on: Optional[Union[str, List[str], ColumnRef]], + on: Optional[Union[str, List[str], Column]], how: Optional[str], ) -> None: super().__init__(left) diff --git a/python/pyspark/sql/connect/typing/__init__.pyi b/python/pyspark/sql/connect/typing/__init__.pyi index d8f8e300324..6c67b561311 100644 --- a/python/pyspark/sql/connect/typing/__init__.pyi +++ b/python/pyspark/sql/connect/typing/__init__.pyi @@ -17,12 +17,12 @@ from typing_extensions import Protocol from typing import Union -from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, ColumnRef +from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column from pyspark.sql.connect.function_builder import UserDefinedFunction ExpressionOrString = Union[str, Expression] -ColumnOrString = Union[str, ColumnRef] 
+ColumnOrString = Union[str, Column] class FunctionBuilderCallable(Protocol): def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: ... diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index ca75b14bb67..59e3c97679e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -36,11 +36,11 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): df = self.connect.with_plan(p.Read("table")) c1 = df.col_name - self.assertIsInstance(c1, col.ColumnRef) + self.assertIsInstance(c1, col.Column) c2 = df["col_name"] - self.assertIsInstance(c2, col.ColumnRef) + self.assertIsInstance(c2, col.Column) c3 = fun.col("col_name") - self.assertIsInstance(c3, col.ColumnRef) + self.assertIsInstance(c3, col.Column) # All Protos should be identical cp1 = c1.to_plan(None) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org