This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 23e3c9b7c2f  [SPARK-41828][CONNECT][PYTHON] Make `createDataFrame` support empty dataframe
23e3c9b7c2f is described below

commit 23e3c9b7c2f08c5350992934cf660de6d2793982
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Jan 4 17:45:46 2023 +0900

    [SPARK-41828][CONNECT][PYTHON] Make `createDataFrame` support empty dataframe

    ### What changes were proposed in this pull request?
    Make `createDataFrame` support an empty dataframe:

    ```
    In [24]: spark.createDataFrame([], schema="x STRING, y INTEGER")
    Out[24]: DataFrame[x: string, y: int]
    ```

    ### Why are the changes needed?
    To be consistent with PySpark.

    ### Does this PR introduce _any_ user-facing change?
    Yes.

    ### How was this patch tested?
    Added a unit test and enabled doctests.

    Closes #39379 from zhengruifeng/connect_fix_41828.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  18 ++--
 .../sql/connect/planner/SparkConnectPlanner.scala  |  68 ++++++++-----
 python/pyspark/sql/connect/dataframe.py            |   3 -
 python/pyspark/sql/connect/plan.py                 |  34 ++++---
 python/pyspark/sql/connect/proto/relations_pb2.py  | 110 ++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  41 ++++----
 python/pyspark/sql/connect/session.py              |  32 ++++--
 .../sql/tests/connect/test_connect_basic.py        |  28 ++++++
 8 files changed, 193 insertions(+), 141 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 51981714ded..c0f22dd4576 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -328,20 +328,16 @@ message Deduplicate {
 
 // A relation that does not need to be qualified by name.
 message LocalRelation {
-  // Local collection data serialized into Arrow IPC streaming format which contains
+  // (Optional) Local collection data serialized into Arrow IPC streaming format which contains
   // the schema of the data.
-  bytes data = 1;
+  optional bytes data = 1;
 
-  // (Optional) The user provided schema.
+  // (Optional) The schema of local data.
+  // It should be either a DDL-formatted type string or a JSON string.
   //
-  // The Sever side will update the column names and data types according to this schema.
-  oneof schema {
-
-    DataType datatype = 2;
-
-    // Server will use Catalyst parser to parse this string to DataType.
-    string datatype_str = 3;
-  }
+  // The server side will update the column names and data types according to this schema.
+  // If the 'data' is not provided, then this schema will be required.
+  optional string schema = 2;
 }
 
 // Relation of type [[Sample]] that samples a fraction of the dataset.
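For illustration only (not part of the patch): with the reworked `LocalRelation` message, a client can describe an empty local relation by setting just the new `schema` field and leaving `data` unset. A minimal sketch using the generated Python bindings under `pyspark.sql.connect.proto` (import path assumed from the files touched below):

```python
from pyspark.sql.connect import proto

# Build a LocalRelation that carries no Arrow batch, only a DDL schema string.
# The server is expected to materialize an empty relation with this schema.
rel = proto.Relation()
rel.local_relation.schema = "x STRING, y INTEGER"

# Both fields are proto3 `optional`, so presence can be checked explicitly.
assert rel.local_relation.HasField("schema")
assert not rel.local_relation.HasField("data")
```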
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 754bb7ced9e..b4c882541e0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -571,47 +571,61 @@ class SparkConnectPlanner(session: SparkSession) {
     try {
       parser.parseTableSchema(sqlText)
     } catch {
-      case _: ParseException =>
+      case e: ParseException =>
         try {
           parser.parseDataType(sqlText)
         } catch {
           case _: ParseException =>
-            parser.parseDataType(s"struct<${sqlText.trim}>")
+            try {
+              parser.parseDataType(s"struct<${sqlText.trim}>")
+            } catch {
+              case _: ParseException =>
+                throw e
+            }
         }
     }
   }
 
   private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
-    val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-      Iterator(rel.getData.toByteArray),
-      TaskContext.get())
-    if (structType == null) {
-      throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+    var schema: StructType = null
+    if (rel.hasSchema) {
+      val schemaType = DataType.parseTypeWithFallback(
+        rel.getSchema,
+        parseDatatypeString,
+        fallbackParser = DataType.fromJson)
+      schema = schemaType match {
+        case s: StructType => s
+        case d => StructType(Seq(StructField("value", d)))
+      }
     }
-    val attributes = structType.toAttributes
-    val proj = UnsafeProjection.create(attributes, attributes)
-    val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
 
-    if (!rel.hasDatatype && !rel.hasDatatypeStr) {
-      return relation
-    }
+    if (rel.hasData) {
+      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
+        Iterator(rel.getData.toByteArray),
+        TaskContext.get())
+      if (structType == null) {
+        throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
+      }
+      val attributes = structType.toAttributes
+      val proj = UnsafeProjection.create(attributes, attributes)
+      val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
 
-    val schemaType = if (rel.hasDatatype) {
-      DataTypeProtoConverter.toCatalystType(rel.getDatatype)
+      if (schema == null) {
+        relation
+      } else {
+        Dataset
+          .ofRows(session, logicalPlan = relation)
+          .toDF(schema.names: _*)
+          .to(schema)
+          .logicalPlan
+      }
     } else {
-      parseDatatypeString(rel.getDatatypeStr)
-    }
-
-    val schemaStruct = schemaType match {
-      case s: StructType => s
-      case d => StructType(Seq(StructField("value", d)))
+      if (schema == null) {
+        throw InvalidPlanInput(
+          s"Schema for LocalRelation is required when the input data is not provided.")
+      }
+      LocalRelation(schema.toAttributes, data = Seq.empty)
     }
-
-    Dataset
-      .ofRows(session, logicalPlan = relation)
-      .toDF(schemaStruct.names: _*)
-      .to(schemaStruct)
-      .logicalPlan
   }
 
   private def transformReadRel(rel: proto.Read): LogicalPlan = {
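For illustration (not part of the patch): the schema handling above wraps any schema string that parses to a single non-struct type into a one-field struct named `value`, so, assuming a Spark Connect session bound to `spark`, empty DataFrames come out as:

```python
# Struct-typed DDL schema: columns are taken as given.
spark.createDataFrame([], schema="x STRING, y INTEGER")
# DataFrame[x: string, y: int]

# Atomic DDL type: wrapped into a single column named `value`.
spark.createDataFrame([], schema="STRING")
# DataFrame[value: string]
```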
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 57c9e801c22..646cc5ced9a 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1426,9 +1426,6 @@ def _test() -> None:
     # TODO(SPARK-41827): groupBy requires all cols be Column or str
     del pyspark.sql.connect.dataframe.DataFrame.groupBy.__doc__
 
-    # TODO(SPARK-41828): Implement creating empty DataFrame
-    del pyspark.sql.connect.dataframe.DataFrame.isEmpty.__doc__
-
     # TODO(SPARK-41829): Add Dataframe sort ordering
     del pyspark.sql.connect.dataframe.DataFrame.sort.__doc__
     del pyspark.sql.connect.dataframe.DataFrame.sortWithinPartitions.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 48a8fa598e7..1f4e4192fdf 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -270,30 +270,34 @@ class LocalRelation(LogicalPlan):
 
     def __init__(
         self,
-        table: "pa.Table",
-        schema: Optional[Union[DataType, str]] = None,
+        table: Optional["pa.Table"],
+        schema: Optional[str] = None,
     ) -> None:
         super().__init__(None)
-        assert table is not None and isinstance(table, pa.Table)
+
+        if table is None:
+            assert schema is not None
+        else:
+            assert isinstance(table, pa.Table)
+
+        assert schema is None or isinstance(schema, str)
+
         self._table = table
-        if schema is not None:
-            assert isinstance(schema, (DataType, str))
         self._schema = schema
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        sink = pa.BufferOutputStream()
-        with pa.ipc.new_stream(sink, self._table.schema) as writer:
-            for b in self._table.to_batches():
-                writer.write_batch(b)
-
         plan = proto.Relation()
-        plan.local_relation.data = sink.getvalue().to_pybytes()
+
+        if self._table is not None:
+            sink = pa.BufferOutputStream()
+            with pa.ipc.new_stream(sink, self._table.schema) as writer:
+                for b in self._table.to_batches():
+                    writer.write_batch(b)
+            plan.local_relation.data = sink.getvalue().to_pybytes()
+
         if self._schema is not None:
-            if isinstance(self._schema, DataType):
-                plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
-            elif isinstance(self._schema, str):
-                plan.local_relation.datatype_str = self._schema
+            plan.local_relation.schema = self._schema
         return plan
 
     def print(self, indent: int = 0) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index cf0f2eb3513..9e230c3d239 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
@@ -656,58 +656,58 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _DROP._serialized_end = 5290
     _DEDUPLICATE._serialized_start = 5293
     _DEDUPLICATE._serialized_end = 5464
-    _LOCALRELATION._serialized_start = 5467
-    _LOCALRELATION._serialized_end = 5604
-    _SAMPLE._serialized_start = 5607
-    _SAMPLE._serialized_end = 5880
-    _RANGE._serialized_start = 5883
-    _RANGE._serialized_end = 6028
-    _SUBQUERYALIAS._serialized_start = 6030
-    _SUBQUERYALIAS._serialized_end = 6144
-    _REPARTITION._serialized_start = 6147
-    _REPARTITION._serialized_end = 6289
-    _SHOWSTRING._serialized_start = 6292
-    _SHOWSTRING._serialized_end = 6434
-    _STATSUMMARY._serialized_start = 6436
-    _STATSUMMARY._serialized_end = 6528
-    _STATDESCRIBE._serialized_start = 6530
-    _STATDESCRIBE._serialized_end = 6611
-    _STATCROSSTAB._serialized_start = 6613
-    _STATCROSSTAB._serialized_end = 6714
-    _STATCOV._serialized_start = 6716
-    _STATCOV._serialized_end = 6812
-    _STATCORR._serialized_start = 6815
-    _STATCORR._serialized_end = 6952
-    _STATAPPROXQUANTILE._serialized_start = 6955
-    _STATAPPROXQUANTILE._serialized_end = 7119
-    _STATFREQITEMS._serialized_start = 7121
-    _STATFREQITEMS._serialized_end = 7246
-    _STATSAMPLEBY._serialized_start = 7249
-    _STATSAMPLEBY._serialized_end = 7558
-    _STATSAMPLEBY_FRACTION._serialized_start = 7450
-    _STATSAMPLEBY_FRACTION._serialized_end = 7549
-    _NAFILL._serialized_start = 7561
-    _NAFILL._serialized_end = 7695
-    _NADROP._serialized_start = 7698
-    _NADROP._serialized_end = 7832
-    _NAREPLACE._serialized_start = 7835
-    _NAREPLACE._serialized_end = 8131
-    _NAREPLACE_REPLACEMENT._serialized_start = 7990
-    _NAREPLACE_REPLACEMENT._serialized_end = 8131
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8133
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8247
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8250
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8509
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8442
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8509
-    _WITHCOLUMNS._serialized_start = 8512
-    _WITHCOLUMNS._serialized_end = 8643
-    _HINT._serialized_start = 8646
-    _HINT._serialized_end = 8786
-    _UNPIVOT._serialized_start = 8789
-    _UNPIVOT._serialized_end = 9035
-    _TOSCHEMA._serialized_start = 9037
-    _TOSCHEMA._serialized_end = 9143
-    _REPARTITIONBYEXPRESSION._serialized_start = 9146
-    _REPARTITIONBYEXPRESSION._serialized_end = 9349
+    _LOCALRELATION._serialized_start = 5466
+    _LOCALRELATION._serialized_end = 5555
+    _SAMPLE._serialized_start = 5558
+    _SAMPLE._serialized_end = 5831
+    _RANGE._serialized_start = 5834
+    _RANGE._serialized_end = 5979
+    _SUBQUERYALIAS._serialized_start = 5981
+    _SUBQUERYALIAS._serialized_end = 6095
+    _REPARTITION._serialized_start = 6098
+    _REPARTITION._serialized_end = 6240
+    _SHOWSTRING._serialized_start = 6243
+    _SHOWSTRING._serialized_end = 6385
+    _STATSUMMARY._serialized_start = 6387
+    _STATSUMMARY._serialized_end = 6479
+    _STATDESCRIBE._serialized_start = 6481
+    _STATDESCRIBE._serialized_end = 6562
+    _STATCROSSTAB._serialized_start = 6564
+    _STATCROSSTAB._serialized_end = 6665
+    _STATCOV._serialized_start = 6667
+    _STATCOV._serialized_end = 6763
+    _STATCORR._serialized_start = 6766
+    _STATCORR._serialized_end = 6903
+    _STATAPPROXQUANTILE._serialized_start = 6906
+    _STATAPPROXQUANTILE._serialized_end = 7070
+    _STATFREQITEMS._serialized_start = 7072
+    _STATFREQITEMS._serialized_end = 7197
+    _STATSAMPLEBY._serialized_start = 7200
+    _STATSAMPLEBY._serialized_end = 7509
+    _STATSAMPLEBY_FRACTION._serialized_start = 7401
+    _STATSAMPLEBY_FRACTION._serialized_end = 7500
+    _NAFILL._serialized_start = 7512
+    _NAFILL._serialized_end = 7646
+    _NADROP._serialized_start = 7649
+    _NADROP._serialized_end = 7783
+    _NAREPLACE._serialized_start = 7786
+    _NAREPLACE._serialized_end = 8082
+    _NAREPLACE_REPLACEMENT._serialized_start = 7941
+    _NAREPLACE_REPLACEMENT._serialized_end = 8082
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8084
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8198
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8201
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8460
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8393
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8460
+    _WITHCOLUMNS._serialized_start = 8463
+    _WITHCOLUMNS._serialized_end = 8594
+    _HINT._serialized_start = 8597
+    _HINT._serialized_end = 8737
+    _UNPIVOT._serialized_start = 8740
+    _UNPIVOT._serialized_end = 8986
+    _TOSCHEMA._serialized_start = 8988
+    _TOSCHEMA._serialized_end = 9094
+    _REPARTITIONBYEXPRESSION._serialized_start = 9097
+    _REPARTITIONBYEXPRESSION._serialized_end = 9300
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 7e63d363277..811f005d24b 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1268,45 +1268,44 @@ class LocalRelation(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     DATA_FIELD_NUMBER: builtins.int
-    DATATYPE_FIELD_NUMBER: builtins.int
-    DATATYPE_STR_FIELD_NUMBER: builtins.int
+    SCHEMA_FIELD_NUMBER: builtins.int
     data: builtins.bytes
-    """Local collection data serialized into Arrow IPC streaming format which contains
+    """(Optional) Local collection data serialized into Arrow IPC streaming format which contains
     the schema of the data.
     """
-    @property
-    def datatype(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
-    datatype_str: builtins.str
-    """Server will use Catalyst parser to parse this string to DataType."""
+    schema: builtins.str
+    """(Optional) The schema of local data.
+    It should be either a DDL-formatted type string or a JSON string.
+
+    The server side will update the column names and data types according to this schema.
+    If the 'data' is not provided, then this schema will be required.
+    """
     def __init__(
         self,
         *,
-        data: builtins.bytes = ...,
-        datatype: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
-        datatype_str: builtins.str = ...,
+        data: builtins.bytes | None = ...,
+        schema: builtins.str | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "datatype", b"datatype", "datatype_str", b"datatype_str", "schema", b"schema"
+            "_data", b"_data", "_schema", b"_schema", "data", b"data", "schema", b"schema"
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "data",
-            b"data",
-            "datatype",
-            b"datatype",
-            "datatype_str",
-            b"datatype_str",
-            "schema",
-            b"schema",
+            "_data", b"_data", "_schema", b"_schema", "data", b"data", "schema", b"schema"
         ],
     ) -> None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_data", b"_data"]
+    ) -> typing_extensions.Literal["data"] | None: ...
+    @typing.overload
     def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["schema", b"schema"]
-    ) -> typing_extensions.Literal["datatype", "datatype_str"] | None: ...
+        self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
+    ) -> typing_extensions.Literal["schema"] | None: ...
 
 global___LocalRelation = LocalRelation
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a5d778e9c0e..09ad58fa3e0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -31,6 +31,7 @@ from pyspark.sql.types import (
     Row,
     DataType,
     StructType,
+    AtomicType,
 )
 from pyspark.sql.utils import to_str
 
@@ -177,20 +178,18 @@ class SparkSession:
     def createDataFrame(
         self,
         data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
-        schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None,
+        schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
     ) -> "DataFrame":
         assert data is not None
 
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
-        if isinstance(data, Sized) and len(data) == 0:
-            raise ValueError("Input data cannot be empty")
 
         table: Optional[pa.Table] = None
-        _schema: Optional[StructType] = None
+        _schema: Optional[Union[AtomicType, StructType]] = None
         _schema_str: Optional[str] = None
         _cols: Optional[List[str]] = None
 
-        if isinstance(schema, StructType):
+        if isinstance(schema, (AtomicType, StructType)):
             _schema = schema
 
         elif isinstance(schema, str):
@@ -200,6 +199,14 @@ class SparkSession:
             # Must re-encode any unicode strings to be consistent with StructField names
             _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
 
+        if isinstance(data, Sized) and len(data) == 0:
+            if _schema is not None:
+                return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
+            elif _schema_str is not None:
+                return DataFrame.withPlan(LocalRelation(table=None, schema=_schema_str), self)
+            else:
+                raise ValueError("can not infer schema from empty dataset")
+
         if isinstance(data, pd.DataFrame):
             table = pa.Table.from_pandas(data)
 
@@ -253,8 +260,10 @@ class SparkSession:
                     _cols = ["_%s" % i for i in range(1, len(_data[0]) + 1)]
                 else:
                     _cols = ["_1"]
-            else:
+            elif isinstance(_schema, StructType):
                 _cols = _schema.names
+            else:
+                _cols = ["value"]
 
             if isinstance(_data[0], Row):
                 table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
@@ -268,19 +277,24 @@ class SparkSession:
 
         # Validate number of columns
         num_cols = table.shape[1]
-        if _schema is not None and len(_schema.fields) != num_cols:
+        if (
+            _schema is not None
+            and isinstance(_schema, StructType)
+            and len(_schema.fields) != num_cols
+        ):
             raise ValueError(
                 f"Length mismatch: Expected axis has {num_cols} elements, "
                 f"new values have {len(_schema.fields)} elements"
             )
-        elif _cols is not None and len(_cols) != num_cols:
+
+        if _cols is not None and len(_cols) != num_cols:
             raise ValueError(
                 f"Length mismatch: Expected axis has {num_cols} elements, "
                 f"new values have {len(_cols)} elements"
             )
 
         if _schema is not None:
-            return DataFrame.withPlan(LocalRelation(table, schema=_schema), self)
+            return DataFrame.withPlan(LocalRelation(table, schema=_schema.json()), self)
         elif _schema_str is not None:
             return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
         elif _cols is not None and len(_cols) > 0:
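As a usage sketch of the new client-side branch (mirroring the test added below), empty input with a schema now succeeds, while empty input without any schema raises the error introduced here:

```python
# Assumes a Spark Connect session bound to `spark`.
spark.createDataFrame([], schema="x STRING, y INTEGER")  # DataFrame[x: string, y: int]

spark.createDataFrame([])
# ValueError: can not infer schema from empty dataset
```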
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e82dc7f7f76..fe6c2c65e25 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -525,6 +525,34 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             sdf.select(SF.pmod("a", "b")).toPandas(),
         )
 
+    def test_create_empty_df(self):
+        for schema in [
+            "STRING",
+            "x STRING",
+            "x STRING, y INTEGER",
+            StringType(),
+            StructType(
+                [
+                    StructField("x", StringType(), True),
+                    StructField("y", IntegerType(), True),
+                ]
+            ),
+        ]:
+            print(schema)
+            cdf = self.connect.createDataFrame(data=[], schema=schema)
+            sdf = self.spark.createDataFrame(data=[], schema=schema)
+
+            self.assert_eq(cdf.toPandas(), sdf.toPandas())
+
+        # check error
+        with self.assertRaisesRegex(
+            ValueError,
+            "can not infer schema from empty dataset",
+        ):
+            self.connect.createDataFrame(data=[])
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org