This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 780aeec610a  [SPARK-43146][CONNECT][PYTHON] Implement eager evaluation for __repr__ and _repr_html_
780aeec610a is described below

commit 780aeec610a08c084e2ee9b24467437cbd4a915b
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Apr 19 10:23:36 2023 +0900

    [SPARK-43146][CONNECT][PYTHON] Implement eager evaluation for __repr__ and _repr_html_

    ### What changes were proposed in this pull request?

    Implements eager evaluation for `DataFrame.__repr__` and `DataFrame._repr_html_`.

    ### Why are the changes needed?

    When `spark.sql.repl.eagerEval.enabled` is `True`, DataFrames should eagerly evaluate and show the results.

    ```py
    >>> spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
    >>> spark.range(3)
    +---+
    | id|
    +---+
    |  0|
    |  1|
    |  2|
    +---+
    ```

    ### Does this PR introduce _any_ user-facing change?

    The eager evaluation will be available.

    ### How was this patch tested?

    Enabled the related test.

    Closes #40800 from ueshin/issues/SPARK-43146/eager_repr.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  15 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  10 +
 python/pyspark/sql/connect/dataframe.py            |  50 +++-
 python/pyspark/sql/connect/plan.py                 |  15 ++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 262 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  47 ++++
 python/pyspark/sql/dataframe.py                    |  27 +--
 .../sql/tests/connect/test_connect_basic.py        |   1 -
 .../sql/tests/connect/test_parity_dataframe.py     |   5 -
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  38 +++
 10 files changed, 312 insertions(+), 158 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index f49fca079b0..57bdf57c9cb 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -67,6 +67,7 @@ message Relation {
     CoGroupMap co_group_map = 32;
     WithWatermark with_watermark = 33;
     ApplyInPandasWithState apply_in_pandas_with_state = 34;
+    HtmlString html_string = 35;

     // NA functions
     NAFill fill_na = 90;
@@ -457,6 +458,20 @@ message ShowString {
   bool vertical = 4;
 }

+// Compose the string representing rows for output.
+// It will invoke 'Dataset.htmlString' to compute the results.
+message HtmlString {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) Number of rows to show.
+  int32 num_rows = 2;
+
+  // (Required) If set to more than 0, truncates strings to
+  // `truncate` characters and all cells will be aligned right.
+  int32 truncate = 3;
+}
+
 // Computes specified statistics for numeric and string columns.
 // It will invoke 'Dataset.summary' (same as 'StatFunctions.summary')
 // to compute the results.
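For readers following the proto change above: a minimal sketch, not part of the patch, of how a client composes the new `HtmlString` relation with the generated `relations_pb2` bindings (the empty `child` is a placeholder for a real plan, which `plan.py` below builds from the DataFrame):

```py
from pyspark.sql.connect.proto import relations_pb2

# Placeholder child; a real client fills this in from the DataFrame's plan.
child = relations_pb2.Relation()

rel = relations_pb2.Relation()
rel.html_string.input.CopyFrom(child)  # (Required) the input relation
rel.html_string.num_rows = 20          # number of rows to render
rel.html_string.truncate = 20          # cell truncation width

assert rel.WhichOneof("rel_type") == "html_string"
```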
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f2b03ca05d4..5f39fcd17f7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -81,6 +81,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     val plan = rel.getRelTypeCase match {
       // DataFrame API
       case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
+      case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString)
       case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
       case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
       case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
@@ -225,6 +226,15 @@ class SparkConnectPlanner(val session: SparkSession) {
       data = Tuple1.apply(showString) :: Nil)
   }

+  private def transformHtmlString(rel: proto.HtmlString): LogicalPlan = {
+    val htmlString = Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .htmlString(rel.getNumRows, rel.getTruncate)
+    LocalRelation.fromProduct(
+      output = AttributeReference("html_string", StringType, false)() :: Nil,
+      data = Tuple1.apply(htmlString) :: Nil)
+  }
+
   private def transformSql(sql: proto.SQL): LogicalPlan = {
     val args = sql.getArgsMap
     val parser = session.sessionState.sqlParser

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 9106ddc2fc8..6dd230920b5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -96,10 +96,57 @@ class DataFrame:
         self._schema = schema
         self._plan: Optional[plan.LogicalPlan] = None
         self._session: "SparkSession" = session
+        # Check whether _repr_html is supported or not, we use it to avoid calling RPC twice
+        # by __repr__ and _repr_html_ while eager evaluation opens.
+        self._support_repr_html = False

     def __repr__(self) -> str:
+        if not self._support_repr_html:
+            (
+                repl_eager_eval_enabled,
+                repl_eager_eval_max_num_rows,
+                repl_eager_eval_truncate,
+            ) = self._session._get_configs(
+                "spark.sql.repl.eagerEval.enabled",
+                "spark.sql.repl.eagerEval.maxNumRows",
+                "spark.sql.repl.eagerEval.truncate",
+            )
+            if repl_eager_eval_enabled == "true":
+                return self._show_string(
+                    n=int(cast(str, repl_eager_eval_max_num_rows)),
+                    truncate=int(cast(str, repl_eager_eval_truncate)),
+                    vertical=False,
+                )
         return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

+    def _repr_html_(self) -> Optional[str]:
+        if not self._support_repr_html:
+            self._support_repr_html = True
+        (
+            repl_eager_eval_enabled,
+            repl_eager_eval_max_num_rows,
+            repl_eager_eval_truncate,
+        ) = self._session._get_configs(
+            "spark.sql.repl.eagerEval.enabled",
+            "spark.sql.repl.eagerEval.maxNumRows",
+            "spark.sql.repl.eagerEval.truncate",
+        )
+        if repl_eager_eval_enabled == "true":
+            pdf = DataFrame.withPlan(
+                plan.HtmlString(
+                    child=self._plan,
+                    num_rows=int(cast(str, repl_eager_eval_max_num_rows)),
+                    truncate=int(cast(str, repl_eager_eval_truncate)),
+                ),
+                session=self._session,
+            ).toPandas()
+            assert pdf is not None
+            return pdf["html_string"][0]
+        else:
+            return None
+
+    _repr_html_.__doc__ = PySparkDataFrame._repr_html_.__doc__
+
     @property
     def write(self) -> "DataFrameWriter":
         assert self._plan is not None
@@ -1827,9 +1874,6 @@ class DataFrame:
     def toJSON(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("toJSON() is not implemented.")

-    def _repr_html_(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("_repr_html_() is not implemented.")
-
     def sameSemantics(self, other: "DataFrame") -> bool:
         assert self._plan is not None
         assert other._plan is not None

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 586668864cc..c3b81cf80f0 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -395,6 +395,21 @@ class ShowString(LogicalPlan):
         return plan


+class HtmlString(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], num_rows: int, truncate: int) -> None:
+        super().__init__(child)
+        self.num_rows = num_rows
+        self.truncate = truncate
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.html_string.input.CopyFrom(self._child.plan(session))
+        plan.html_string.num_rows = self.num_rows
+        plan.html_string.truncate = self.truncate
+        return plan
+
+
 class Project(LogicalPlan):
     """Logical plan object for a projection.
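The new client path can be exercised interactively. A rough sketch, assuming a Spark Connect session bound to `spark` and default eager-eval settings; the HTML shape follows `Dataset.htmlString` in the Scala change at the end of this diff:

```py
>>> spark.conf.set("spark.sql.repl.eagerEval.enabled", True)
>>> df = spark.range(3)
>>> df._repr_html_()  # one RPC; notebook frontends render this as a table
"<table border='1'>\n<tr><th>id</th></tr>\n<tr><td>0</td></tr>\n<tr><td>1</td></tr>\n<tr><td>2</td></tr>\n</table>\n"
```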
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 8229dc18afe..6a4226185e7 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal

 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xdb\x15\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x99\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )

@@ -68,6 +68,7 @@ _RANGE = DESCRIPTOR.message_types_by_name["Range"]
 _SUBQUERYALIAS = DESCRIPTOR.message_types_by_name["SubqueryAlias"]
 _REPARTITION = DESCRIPTOR.message_types_by_name["Repartition"]
 _SHOWSTRING = DESCRIPTOR.message_types_by_name["ShowString"]
+_HTMLSTRING = DESCRIPTOR.message_types_by_name["HtmlString"]
 _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
 _STATDESCRIBE = DESCRIPTOR.message_types_by_name["StatDescribe"]
 _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
@@ -406,6 +407,17 @@ ShowString = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(ShowString)

+HtmlString = _reflection.GeneratedProtocolMessageType(
+    "HtmlString",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _HTMLSTRING,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.HtmlString)
+    },
+)
+_sym_db.RegisterMessage(HtmlString)
+
 StatSummary = _reflection.GeneratedProtocolMessageType(
     "StatSummary",
     (_message.Message,),
@@ -746,127 +758,129 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _PARSE_OPTIONSENTRY._options = None
     _PARSE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2944
-    _UNKNOWN._serialized_start = 2946
-    _UNKNOWN._serialized_end = 2955
-    _RELATIONCOMMON._serialized_start = 2957
-    _RELATIONCOMMON._serialized_end = 3048
-    _SQL._serialized_start = 3051
-    _SQL._serialized_end = 3220
-    _SQL_ARGSENTRY._serialized_start = 3130
-    _SQL_ARGSENTRY._serialized_end = 3220
-    _READ._serialized_start = 3223
-    _READ._serialized_end = 3886
-    _READ_NAMEDTABLE._serialized_start = 3401
-    _READ_NAMEDTABLE._serialized_end = 3593
-    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3535
-    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 3593
-    _READ_DATASOURCE._serialized_start = 3596
-    _READ_DATASOURCE._serialized_end = 3873
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3535
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3593
-    _PROJECT._serialized_start = 3888
-    _PROJECT._serialized_end = 4005
-    _FILTER._serialized_start = 4007
-    _FILTER._serialized_end = 4119
-    _JOIN._serialized_start = 4122
-    _JOIN._serialized_end = 4593
-    _JOIN_JOINTYPE._serialized_start = 4385
-    _JOIN_JOINTYPE._serialized_end = 4593
-    _SETOPERATION._serialized_start = 4596
-    _SETOPERATION._serialized_end = 5075
-    _SETOPERATION_SETOPTYPE._serialized_start = 4912
-    _SETOPERATION_SETOPTYPE._serialized_end = 5026
-    _LIMIT._serialized_start = 5077
-    _LIMIT._serialized_end = 5153
-    _OFFSET._serialized_start = 5155
-    _OFFSET._serialized_end = 5234
-    _TAIL._serialized_start = 5236
-    _TAIL._serialized_end = 5311
-    _AGGREGATE._serialized_start = 5314
-    _AGGREGATE._serialized_end = 5896
-    _AGGREGATE_PIVOT._serialized_start = 5653
-    _AGGREGATE_PIVOT._serialized_end = 5764
-    _AGGREGATE_GROUPTYPE._serialized_start = 5767
-    _AGGREGATE_GROUPTYPE._serialized_end = 5896
-    _SORT._serialized_start = 5899
-    _SORT._serialized_end = 6059
-    _DROP._serialized_start = 6062
-    _DROP._serialized_end = 6203
-    _DEDUPLICATE._serialized_start = 6206
-    _DEDUPLICATE._serialized_end = 6377
-    _LOCALRELATION._serialized_start = 6379
-    _LOCALRELATION._serialized_end = 6468
-    _SAMPLE._serialized_start = 6471
-    _SAMPLE._serialized_end = 6744
-    _RANGE._serialized_start = 6747
-    _RANGE._serialized_end = 6892
-    _SUBQUERYALIAS._serialized_start = 6894
-    _SUBQUERYALIAS._serialized_end = 7008
-    _REPARTITION._serialized_start = 7011
-    _REPARTITION._serialized_end = 7153
-    _SHOWSTRING._serialized_start = 7156
-    _SHOWSTRING._serialized_end = 7298
-    _STATSUMMARY._serialized_start = 7300
-    _STATSUMMARY._serialized_end = 7392
-    _STATDESCRIBE._serialized_start = 7394
-    _STATDESCRIBE._serialized_end = 7475
-    _STATCROSSTAB._serialized_start = 7477
-    _STATCROSSTAB._serialized_end = 7578
-    _STATCOV._serialized_start = 7580
-    _STATCOV._serialized_end = 7676
-    _STATCORR._serialized_start = 7679
-    _STATCORR._serialized_end = 7816
-    _STATAPPROXQUANTILE._serialized_start = 7819
-    _STATAPPROXQUANTILE._serialized_end = 7983
-    _STATFREQITEMS._serialized_start = 7985
-    _STATFREQITEMS._serialized_end = 8110
-    _STATSAMPLEBY._serialized_start = 8113
-    _STATSAMPLEBY._serialized_end = 8422
-    _STATSAMPLEBY_FRACTION._serialized_start = 8314
-    _STATSAMPLEBY_FRACTION._serialized_end = 8413
-    _NAFILL._serialized_start = 8425
-    _NAFILL._serialized_end = 8559
-    _NADROP._serialized_start = 8562
-    _NADROP._serialized_end = 8696
-    _NAREPLACE._serialized_start = 8699
-    _NAREPLACE._serialized_end = 8995
-    _NAREPLACE_REPLACEMENT._serialized_start = 8854
-    _NAREPLACE_REPLACEMENT._serialized_end = 8995
-    _TODF._serialized_start = 8997
-    _TODF._serialized_end = 9085
-    _WITHCOLUMNSRENAMED._serialized_start = 9088
-    _WITHCOLUMNSRENAMED._serialized_end = 9327
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9260
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9327
-    _WITHCOLUMNS._serialized_start = 9329
-    _WITHCOLUMNS._serialized_end = 9448
-    _WITHWATERMARK._serialized_start = 9451
-    _WITHWATERMARK._serialized_end = 9585
-    _HINT._serialized_start = 9588
-    _HINT._serialized_end = 9720
-    _UNPIVOT._serialized_start = 9723
-    _UNPIVOT._serialized_end = 10050
-    _UNPIVOT_VALUES._serialized_start = 9980
-    _UNPIVOT_VALUES._serialized_end = 10039
-    _TOSCHEMA._serialized_start = 10052
-    _TOSCHEMA._serialized_end = 10158
-    _REPARTITIONBYEXPRESSION._serialized_start = 10161
-    _REPARTITIONBYEXPRESSION._serialized_end = 10364
-    _MAPPARTITIONS._serialized_start = 10367
-    _MAPPARTITIONS._serialized_end = 10548
-    _GROUPMAP._serialized_start = 10551
-    _GROUPMAP._serialized_end = 10754
-    _COGROUPMAP._serialized_start = 10757
-    _COGROUPMAP._serialized_end = 11109
-    _APPLYINPANDASWITHSTATE._serialized_start = 11112
-    _APPLYINPANDASWITHSTATE._serialized_end = 11469
-    _COLLECTMETRICS._serialized_start = 11472
-    _COLLECTMETRICS._serialized_end = 11608
-    _PARSE._serialized_start = 11611
-    _PARSE._serialized_end = 11999
-    _PARSE_OPTIONSENTRY._serialized_start = 3535
-    _PARSE_OPTIONSENTRY._serialized_end = 3593
-    _PARSE_PARSEFORMAT._serialized_start = 11900
-    _PARSE_PARSEFORMAT._serialized_end = 11988
+    _RELATION._serialized_end = 3006
+    _UNKNOWN._serialized_start = 3008
+    _UNKNOWN._serialized_end = 3017
+    _RELATIONCOMMON._serialized_start = 3019
+    _RELATIONCOMMON._serialized_end = 3110
+    _SQL._serialized_start = 3113
+    _SQL._serialized_end = 3282
+    _SQL_ARGSENTRY._serialized_start = 3192
+    _SQL_ARGSENTRY._serialized_end = 3282
+    _READ._serialized_start = 3285
+    _READ._serialized_end = 3948
+    _READ_NAMEDTABLE._serialized_start = 3463
+    _READ_NAMEDTABLE._serialized_end = 3655
+    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3597
+    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 3655
+    _READ_DATASOURCE._serialized_start = 3658
+    _READ_DATASOURCE._serialized_end = 3935
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3597
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3655
+    _PROJECT._serialized_start = 3950
+    _PROJECT._serialized_end = 4067
+    _FILTER._serialized_start = 4069
+    _FILTER._serialized_end = 4181
+    _JOIN._serialized_start = 4184
+    _JOIN._serialized_end = 4655
+    _JOIN_JOINTYPE._serialized_start = 4447
+    _JOIN_JOINTYPE._serialized_end = 4655
+    _SETOPERATION._serialized_start = 4658
+    _SETOPERATION._serialized_end = 5137
+    _SETOPERATION_SETOPTYPE._serialized_start = 4974
+    _SETOPERATION_SETOPTYPE._serialized_end = 5088
+    _LIMIT._serialized_start = 5139
+    _LIMIT._serialized_end = 5215
+    _OFFSET._serialized_start = 5217
+    _OFFSET._serialized_end = 5296
+    _TAIL._serialized_start = 5298
+    _TAIL._serialized_end = 5373
+    _AGGREGATE._serialized_start = 5376
+    _AGGREGATE._serialized_end = 5958
+    _AGGREGATE_PIVOT._serialized_start = 5715
+    _AGGREGATE_PIVOT._serialized_end = 5826
+    _AGGREGATE_GROUPTYPE._serialized_start = 5829
+    _AGGREGATE_GROUPTYPE._serialized_end = 5958
+    _SORT._serialized_start = 5961
+    _SORT._serialized_end = 6121
+    _DROP._serialized_start = 6124
+    _DROP._serialized_end = 6265
+    _DEDUPLICATE._serialized_start = 6268
+    _DEDUPLICATE._serialized_end = 6439
+    _LOCALRELATION._serialized_start = 6441
+    _LOCALRELATION._serialized_end = 6530
+    _SAMPLE._serialized_start = 6533
+    _SAMPLE._serialized_end = 6806
+    _RANGE._serialized_start = 6809
+    _RANGE._serialized_end = 6954
+    _SUBQUERYALIAS._serialized_start = 6956
+    _SUBQUERYALIAS._serialized_end = 7070
+    _REPARTITION._serialized_start = 7073
+    _REPARTITION._serialized_end = 7215
+    _SHOWSTRING._serialized_start = 7218
+    _SHOWSTRING._serialized_end = 7360
+    _HTMLSTRING._serialized_start = 7362
+    _HTMLSTRING._serialized_end = 7476
+    _STATSUMMARY._serialized_start = 7478
+    _STATSUMMARY._serialized_end = 7570
+    _STATDESCRIBE._serialized_start = 7572
+    _STATDESCRIBE._serialized_end = 7653
+    _STATCROSSTAB._serialized_start = 7655
+    _STATCROSSTAB._serialized_end = 7756
+    _STATCOV._serialized_start = 7758
+    _STATCOV._serialized_end = 7854
+    _STATCORR._serialized_start = 7857
+    _STATCORR._serialized_end = 7994
+    _STATAPPROXQUANTILE._serialized_start = 7997
+    _STATAPPROXQUANTILE._serialized_end = 8161
+    _STATFREQITEMS._serialized_start = 8163
+    _STATFREQITEMS._serialized_end = 8288
+    _STATSAMPLEBY._serialized_start = 8291
+    _STATSAMPLEBY._serialized_end = 8600
+    _STATSAMPLEBY_FRACTION._serialized_start = 8492
+    _STATSAMPLEBY_FRACTION._serialized_end = 8591
+    _NAFILL._serialized_start = 8603
+    _NAFILL._serialized_end = 8737
+    _NADROP._serialized_start = 8740
+    _NADROP._serialized_end = 8874
+    _NAREPLACE._serialized_start = 8877
+    _NAREPLACE._serialized_end = 9173
+    _NAREPLACE_REPLACEMENT._serialized_start = 9032
+    _NAREPLACE_REPLACEMENT._serialized_end = 9173
+    _TODF._serialized_start = 9175
+    _TODF._serialized_end = 9263
+    _WITHCOLUMNSRENAMED._serialized_start = 9266
+    _WITHCOLUMNSRENAMED._serialized_end = 9505
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9438
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9505
+    _WITHCOLUMNS._serialized_start = 9507
+    _WITHCOLUMNS._serialized_end = 9626
+    _WITHWATERMARK._serialized_start = 9629
+    _WITHWATERMARK._serialized_end = 9763
+    _HINT._serialized_start = 9766
+    _HINT._serialized_end = 9898
+    _UNPIVOT._serialized_start = 9901
+    _UNPIVOT._serialized_end = 10228
+    _UNPIVOT_VALUES._serialized_start = 10158
+    _UNPIVOT_VALUES._serialized_end = 10217
+    _TOSCHEMA._serialized_start = 10230
+    _TOSCHEMA._serialized_end = 10336
+    _REPARTITIONBYEXPRESSION._serialized_start = 10339
+    _REPARTITIONBYEXPRESSION._serialized_end = 10542
+    _MAPPARTITIONS._serialized_start = 10545
+    _MAPPARTITIONS._serialized_end = 10726
+    _GROUPMAP._serialized_start = 10729
+    _GROUPMAP._serialized_end = 10932
+    _COGROUPMAP._serialized_start = 10935
+    _COGROUPMAP._serialized_end = 11287
+    _APPLYINPANDASWITHSTATE._serialized_start = 11290
+    _APPLYINPANDASWITHSTATE._serialized_end = 11647
+    _COLLECTMETRICS._serialized_start = 11650
+    _COLLECTMETRICS._serialized_end = 11786
+    _PARSE._serialized_start = 11789
+    _PARSE._serialized_end = 12177
+    _PARSE_OPTIONSENTRY._serialized_start = 3597
+    _PARSE_OPTIONSENTRY._serialized_end = 3655
+    _PARSE_PARSEFORMAT._serialized_start = 12078
+    _PARSE_PARSEFORMAT._serialized_end = 12166
 # @@protoc_insertion_point(module_scope)

diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 14552d1f127..b847378d78b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -96,6 +96,7 @@ class Relation(google.protobuf.message.Message):
     CO_GROUP_MAP_FIELD_NUMBER: builtins.int
     WITH_WATERMARK_FIELD_NUMBER: builtins.int
     APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
+    HTML_STRING_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -179,6 +180,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def apply_in_pandas_with_state(self) -> global___ApplyInPandasWithState: ...
     @property
+    def html_string(self) -> global___HtmlString: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -249,6 +252,7 @@ class Relation(google.protobuf.message.Message):
         co_group_map: global___CoGroupMap | None = ...,
         with_watermark: global___WithWatermark | None = ...,
         apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
+        html_string: global___HtmlString | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -307,6 +311,8 @@ class Relation(google.protobuf.message.Message):
             b"group_map",
             "hint",
             b"hint",
+            "html_string",
+            b"html_string",
             "join",
             b"join",
             "limit",
@@ -410,6 +416,8 @@ class Relation(google.protobuf.message.Message):
             b"group_map",
             "hint",
             b"hint",
+            "html_string",
+            b"html_string",
             "join",
             b"join",
             "limit",
@@ -506,6 +514,7 @@ class Relation(google.protobuf.message.Message):
         "co_group_map",
         "with_watermark",
         "apply_in_pandas_with_state",
+        "html_string",
         "fill_na",
         "drop_na",
         "replace",
@@ -1813,6 +1822,44 @@ class ShowString(google.protobuf.message.Message):

 global___ShowString = ShowString

+class HtmlString(google.protobuf.message.Message):
+    """Compose the string representing rows for output.
+    It will invoke 'Dataset.htmlString' to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    NUM_ROWS_FIELD_NUMBER: builtins.int
+    TRUNCATE_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    num_rows: builtins.int
+    """(Required) Number of rows to show."""
+    truncate: builtins.int
+    """(Required) If set to more than 0, truncates strings to
+    `truncate` characters and all cells will be aligned right.
+    """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        num_rows: builtins.int = ...,
+        truncate: builtins.int = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "input", b"input", "num_rows", b"num_rows", "truncate", b"truncate"
+        ],
+    ) -> None: ...
+
+global___HtmlString = HtmlString
+
 class StatSummary(google.protobuf.message.Message):
     """Computes specified statistics for numeric and string columns.
    It will invoke 'Dataset.summary' (same as 'StatFunctions.summary')

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0014695aa71..01e5e2d77a6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,7 +22,6 @@ import random
 import warnings
 from collections.abc import Iterable
 from functools import reduce
-from html import escape as html_escape
 from typing import (
     Any,
     Callable,
@@ -936,32 +935,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if not self._support_repr_html:
             self._support_repr_html = True
         if self.sparkSession._jconf.isReplEagerEvalEnabled():
-            max_num_rows = max(self.sparkSession._jconf.replEagerEvalMaxNumRows(), 0)
-            sock_info = self._jdf.getRowsToPython(
-                max_num_rows,
+            return self._jdf.htmlString(
+                self.sparkSession._jconf.replEagerEvalMaxNumRows(),
                 self.sparkSession._jconf.replEagerEvalTruncate(),
             )
-            rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
-            head = rows[0]
-            row_data = rows[1:]
-            has_more_data = len(row_data) > max_num_rows
-            row_data = row_data[:max_num_rows]
-
-            html = "<table border='1'>\n"
-            # generate table head
-            html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: html_escape(x), head))
-            # generate table rows
-            for row in row_data:
-                html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
-                    map(lambda x: html_escape(x), row)
-                )
-            html += "</table>\n"
-            if has_more_data:
-                html += "only showing top %d %s\n" % (
-                    max_num_rows,
-                    "row" if max_num_rows == 1 else "rows",
-                )
-            return html
         else:
             return None

diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index b1b3a94accf..2c1b6342924 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2923,7 +2923,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "foreachPartition",
             "checkpoint",
             "localCheckpoint",
-            "_repr_html_",
         ):
             with self.assertRaises(NotImplementedError):
                 getattr(df, f)()

diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 72a97a2a65c..a74afc4d504 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -50,11 +50,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_repartitionByRange_dataframe(self):
         super().test_repartitionByRange_dataframe()

-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_repr_behaviors(self):
-        super().test_repr_behaviors()
-
     @unittest.skip("Spark Connect does not SparkContext but the tests depend on them.")
     def test_same_semantics_error(self):
         super().test_same_semantics_error()

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 584ce19c77a..be37fdae025 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal

 import org.apache.commons.lang3.StringUtils
+import org.apache.commons.text.StringEscapeUtils

 import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable}
@@ -402,6 +403,43 @@ class Dataset[T] private[sql](
     sb.toString()
   }

+  /**
+   * Compose the HTML representing rows for output
+   *
+   * @param _numRows Number of rows to show
+   * @param truncate If set to more than 0, truncates strings to `truncate` characters and
+   *                 all cells will be aligned right.
+   */
+  private[sql] def htmlString(
+      _numRows: Int,
+      truncate: Int = 20): String = {
+    val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)
+    // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
+    val tmpRows = getRows(numRows, truncate)
+
+    val hasMoreData = tmpRows.length - 1 > numRows
+    val rows = tmpRows.take(numRows + 1)
+
+    val sb = new StringBuilder
+
+    sb.append("<table border='1'>\n")
+
+    sb.append(rows.head.map(StringEscapeUtils.escapeHtml4)
+      .mkString("<tr><th>", "</th><th>", "</th></tr>\n"))
+    rows.tail.foreach { row =>
+      sb.append(row.map(StringEscapeUtils.escapeHtml4)
+        .mkString("<tr><td>", "</td><td>", "</td></tr>\n"))
+    }
+
+    sb.append("</table>\n")
+
+    if (hasMoreData) {
+      sb.append(s"only showing top $numRows ${if (numRows == 1) "row" else "rows"}\n")
+    }
+
+    sb.toString()
+  }
+
   override def toString: String = {
     try {
       val builder = new StringBuilder

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org