This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ad35f35f12f  [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source
ad35f35f12f is described below

commit ad35f35f12f715c276d216d621be583a6a44111a
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Sat Feb 25 14:14:01 2023 -0400

    [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source

    ### What changes were proposed in this pull request?

    Fixes `DataFrameReader` to use the default source.

    ### Why are the changes needed?

    ```py
    spark.read.load(path)
    ```

    should work and use the default source without specifying the format.

    ### Does this PR introduce _any_ user-facing change?

    The `format` doesn't need to be specified.

    ### How was this patch tested?

    Enabled related tests.

    Closes #40166 from ueshin/issues/SPARK-42570/reader.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto     |   6 +-
 .../sql/connect/planner/SparkConnectPlanner.scala   |   7 +-
 .../connect/planner/SparkConnectPlannerSuite.scala  |  12 --
 python/pyspark/sql/connect/plan.py                  |   8 +-
 python/pyspark/sql/connect/proto/relations_pb2.py   | 186 ++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi  |  26 ++-
 python/pyspark/sql/connect/readwriter.py            |   2 +-
 .../sql/tests/connect/test_parity_readwriter.py     |  10 +-
 python/pyspark/sql/tests/test_readwriter.py         | 126 +++++++-------
 9 files changed, 193 insertions(+), 190 deletions(-)
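For context, the user-facing behavior this patch enables can be sketched as follows, mirroring the SET/RESET pattern used in the updated tests below; an existing SparkSession `spark` (classic or Spark Connect) and JSON data at `path` are assumed for illustration:

```py
# Illustrative sketch; `spark` and `path` (a directory previously written
# with df.write.json(path)) are assumed for the example.
spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
try:
    # No explicit .format(...): the reader falls back to the default source.
    df = spark.read.load(path)
    df.show()
finally:
    spark.sql("RESET spark.sql.sources.default").collect()
```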
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 4d96b6b0c7e..2221b4e3982 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -122,8 +122,10 @@ message Read {
   }
 
   message DataSource {
-    // (Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
-    string format = 1;
+    // (Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
+    //
+    // If not set, the value from SQL conf 'spark.sql.sources.default' will be used.
+    optional string format = 1;
 
     // (Optional) If not set, Spark will infer the schema.
     //
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cc43c1cace3..887379ab80d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -667,12 +667,11 @@ class SparkConnectPlanner(val session: SparkSession) {
         UnresolvedRelation(multipartIdentifier)
 
       case proto.Read.ReadTypeCase.DATA_SOURCE =>
-        if (rel.getDataSource.getFormat == "") {
-          throw InvalidPlanInput("DataSource requires a format")
-        }
         val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap)
         val reader = session.read
-        reader.format(rel.getDataSource.getFormat)
+        if (rel.getDataSource.hasFormat) {
+          reader.format(rel.getDataSource.getFormat)
+        }
         localMap.foreach { case (key, value) => reader.option(key, value) }
 
         if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 3e4a0f94ea2..83056c27729 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -332,18 +332,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     assert(res.nodeName == "Aggregate")
   }
 
-  test("Invalid DataSource") {
-    val dataSource = proto.Read.DataSource.newBuilder()
-
-    val e = intercept[InvalidPlanInput](
-      transform(
-        proto.Relation
-          .newBuilder()
-          .setRead(proto.Read.newBuilder().setDataSource(dataSource))
-          .build()))
-    assert(e.getMessage.contains("DataSource requires a format"))
-  }
-
   test("Test invalid deduplicate") {
     val deduplicate = proto.Deduplicate
       .newBuilder()
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index badbb9871ed..857cca64c6f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -255,15 +255,14 @@ class DataSource(LogicalPlan):
 
     def __init__(
         self,
-        format: str,
+        format: Optional[str] = None,
         schema: Optional[str] = None,
         options: Optional[Mapping[str, str]] = None,
         paths: Optional[List[str]] = None,
     ) -> None:
         super().__init__(None)
 
-        assert isinstance(format, str) and format != ""
-
+        assert format is None or isinstance(format, str)
         assert schema is None or isinstance(schema, str)
 
         if options is not None:
@@ -282,7 +281,8 @@ class DataSource(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = self._create_proto_relation()
-        plan.read.data_source.format = self._format
+        if self._format is not None:
+            plan.read.data_source.format = self._format
         if self._schema is not None:
             plan.read.data_source.schema = self._schema
         if self._options is not None and len(self._options) > 0:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 3afdf61e681..c6d9616e44c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
@@ -657,99 +657,99 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _SQL_ARGSENTRY._serialized_start = 2704
     _SQL_ARGSENTRY._serialized_end = 2759
     _READ._serialized_start = 2762
-    _READ._serialized_end = 3210
+    _READ._serialized_end = 3226
     _READ_NAMEDTABLE._serialized_start = 2904
     _READ_NAMEDTABLE._serialized_end = 2965
     _READ_DATASOURCE._serialized_start = 2968
-    _READ_DATASOURCE._serialized_end = 3197
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3128
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3186
-    _PROJECT._serialized_start = 3212
-    _PROJECT._serialized_end = 3329
-    _FILTER._serialized_start = 3331
-    _FILTER._serialized_end = 3443
-    _JOIN._serialized_start = 3446
-    _JOIN._serialized_end = 3917
-    _JOIN_JOINTYPE._serialized_start = 3709
-    _JOIN_JOINTYPE._serialized_end = 3917
-    _SETOPERATION._serialized_start = 3920
-    _SETOPERATION._serialized_end = 4399
-    _SETOPERATION_SETOPTYPE._serialized_start = 4236
-    _SETOPERATION_SETOPTYPE._serialized_end = 4350
-    _LIMIT._serialized_start = 4401
-    _LIMIT._serialized_end = 4477
-    _OFFSET._serialized_start = 4479
-    _OFFSET._serialized_end = 4558
-    _TAIL._serialized_start = 4560
-    _TAIL._serialized_end = 4635
-    _AGGREGATE._serialized_start = 4638
-    _AGGREGATE._serialized_end = 5220
-    _AGGREGATE_PIVOT._serialized_start = 4977
-    _AGGREGATE_PIVOT._serialized_end = 5088
-    _AGGREGATE_GROUPTYPE._serialized_start = 5091
-    _AGGREGATE_GROUPTYPE._serialized_end = 5220
-    _SORT._serialized_start = 5223
-    _SORT._serialized_end = 5383
-    _DROP._serialized_start = 5385
-    _DROP._serialized_end = 5485
-    _DEDUPLICATE._serialized_start = 5488
-    _DEDUPLICATE._serialized_end = 5659
-    _LOCALRELATION._serialized_start = 5661
-    _LOCALRELATION._serialized_end = 5750
-    _SAMPLE._serialized_start = 5753
-    _SAMPLE._serialized_end = 6026
-    _RANGE._serialized_start = 6029
-    _RANGE._serialized_end = 6174
-    _SUBQUERYALIAS._serialized_start = 6176
-    _SUBQUERYALIAS._serialized_end = 6290
-    _REPARTITION._serialized_start = 6293
-    _REPARTITION._serialized_end = 6435
-    _SHOWSTRING._serialized_start = 6438
-    _SHOWSTRING._serialized_end = 6580
-    _STATSUMMARY._serialized_start = 6582
-    _STATSUMMARY._serialized_end = 6674
-    _STATDESCRIBE._serialized_start = 6676
-    _STATDESCRIBE._serialized_end = 6757
-    _STATCROSSTAB._serialized_start = 6759
-    _STATCROSSTAB._serialized_end = 6860
-    _STATCOV._serialized_start = 6862
-    _STATCOV._serialized_end = 6958
-    _STATCORR._serialized_start = 6961
-    _STATCORR._serialized_end = 7098
-    _STATAPPROXQUANTILE._serialized_start = 7101
-    _STATAPPROXQUANTILE._serialized_end = 7265
-    _STATFREQITEMS._serialized_start = 7267
-    _STATFREQITEMS._serialized_end = 7392
-    _STATSAMPLEBY._serialized_start = 7395
-    _STATSAMPLEBY._serialized_end = 7704
-    _STATSAMPLEBY_FRACTION._serialized_start = 7596
-    _STATSAMPLEBY_FRACTION._serialized_end = 7695
-    _NAFILL._serialized_start = 7707
-    _NAFILL._serialized_end = 7841
-    _NADROP._serialized_start = 7844
-    _NADROP._serialized_end = 7978
-    _NAREPLACE._serialized_start = 7981
-    _NAREPLACE._serialized_end = 8277
-    _NAREPLACE_REPLACEMENT._serialized_start = 8136
-    _NAREPLACE_REPLACEMENT._serialized_end = 8277
-    _TODF._serialized_start = 8279
-    _TODF._serialized_end = 8367
-    _WITHCOLUMNSRENAMED._serialized_start = 8370
-    _WITHCOLUMNSRENAMED._serialized_end = 8609
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8542
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8609
-    _WITHCOLUMNS._serialized_start = 8611
-    _WITHCOLUMNS._serialized_end = 8730
-    _HINT._serialized_start = 8733
-    _HINT._serialized_end = 8865
-    _UNPIVOT._serialized_start = 8868
-    _UNPIVOT._serialized_end = 9195
-    _UNPIVOT_VALUES._serialized_start = 9125
-    _UNPIVOT_VALUES._serialized_end = 9184
-    _TOSCHEMA._serialized_start = 9197
-    _TOSCHEMA._serialized_end = 9303
-    _REPARTITIONBYEXPRESSION._serialized_start = 9306
-    _REPARTITIONBYEXPRESSION._serialized_end = 9509
-    _FRAMEMAP._serialized_start = 9511
-    _FRAMEMAP._serialized_end = 9636
+    _READ_DATASOURCE._serialized_end = 3213
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3133
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3191
+    _PROJECT._serialized_start = 3228
+    _PROJECT._serialized_end = 3345
+    _FILTER._serialized_start = 3347
+    _FILTER._serialized_end = 3459
+    _JOIN._serialized_start = 3462
+    _JOIN._serialized_end = 3933
+    _JOIN_JOINTYPE._serialized_start = 3725
+    _JOIN_JOINTYPE._serialized_end = 3933
+    _SETOPERATION._serialized_start = 3936
+    _SETOPERATION._serialized_end = 4415
+    _SETOPERATION_SETOPTYPE._serialized_start = 4252
+    _SETOPERATION_SETOPTYPE._serialized_end = 4366
+    _LIMIT._serialized_start = 4417
+    _LIMIT._serialized_end = 4493
+    _OFFSET._serialized_start = 4495
+    _OFFSET._serialized_end = 4574
+    _TAIL._serialized_start = 4576
+    _TAIL._serialized_end = 4651
+    _AGGREGATE._serialized_start = 4654
+    _AGGREGATE._serialized_end = 5236
+    _AGGREGATE_PIVOT._serialized_start = 4993
+    _AGGREGATE_PIVOT._serialized_end = 5104
+    _AGGREGATE_GROUPTYPE._serialized_start = 5107
+    _AGGREGATE_GROUPTYPE._serialized_end = 5236
+    _SORT._serialized_start = 5239
+    _SORT._serialized_end = 5399
+    _DROP._serialized_start = 5401
+    _DROP._serialized_end = 5501
+    _DEDUPLICATE._serialized_start = 5504
+    _DEDUPLICATE._serialized_end = 5675
+    _LOCALRELATION._serialized_start = 5677
+    _LOCALRELATION._serialized_end = 5766
+    _SAMPLE._serialized_start = 5769
+    _SAMPLE._serialized_end = 6042
+    _RANGE._serialized_start = 6045
+    _RANGE._serialized_end = 6190
+    _SUBQUERYALIAS._serialized_start = 6192
+    _SUBQUERYALIAS._serialized_end = 6306
+    _REPARTITION._serialized_start = 6309
+    _REPARTITION._serialized_end = 6451
+    _SHOWSTRING._serialized_start = 6454
+    _SHOWSTRING._serialized_end = 6596
+    _STATSUMMARY._serialized_start = 6598
+    _STATSUMMARY._serialized_end = 6690
+    _STATDESCRIBE._serialized_start = 6692
+    _STATDESCRIBE._serialized_end = 6773
+    _STATCROSSTAB._serialized_start = 6775
+    _STATCROSSTAB._serialized_end = 6876
+    _STATCOV._serialized_start = 6878
+    _STATCOV._serialized_end = 6974
+    _STATCORR._serialized_start = 6977
+    _STATCORR._serialized_end = 7114
+    _STATAPPROXQUANTILE._serialized_start = 7117
+    _STATAPPROXQUANTILE._serialized_end = 7281
+    _STATFREQITEMS._serialized_start = 7283
+    _STATFREQITEMS._serialized_end = 7408
+    _STATSAMPLEBY._serialized_start = 7411
+    _STATSAMPLEBY._serialized_end = 7720
+    _STATSAMPLEBY_FRACTION._serialized_start = 7612
+    _STATSAMPLEBY_FRACTION._serialized_end = 7711
+    _NAFILL._serialized_start = 7723
+    _NAFILL._serialized_end = 7857
+    _NADROP._serialized_start = 7860
+    _NADROP._serialized_end = 7994
+    _NAREPLACE._serialized_start = 7997
+    _NAREPLACE._serialized_end = 8293
+    _NAREPLACE_REPLACEMENT._serialized_start = 8152
+    _NAREPLACE_REPLACEMENT._serialized_end = 8293
+    _TODF._serialized_start = 8295
+    _TODF._serialized_end = 8383
+    _WITHCOLUMNSRENAMED._serialized_start = 8386
+    _WITHCOLUMNSRENAMED._serialized_end = 8625
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8558
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8625
+    _WITHCOLUMNS._serialized_start = 8627
+    _WITHCOLUMNS._serialized_end = 8746
+    _HINT._serialized_start = 8749
+    _HINT._serialized_end = 8881
+    _UNPIVOT._serialized_start = 8884
+    _UNPIVOT._serialized_end = 9211
+    _UNPIVOT_VALUES._serialized_start = 9141
+    _UNPIVOT_VALUES._serialized_end = 9200
+    _TOSCHEMA._serialized_start = 9213
+    _TOSCHEMA._serialized_end = 9319
+    _REPARTITIONBYEXPRESSION._serialized_start = 9322
+    _REPARTITIONBYEXPRESSION._serialized_end = 9525
+    _FRAMEMAP._serialized_start = 9527
+    _FRAMEMAP._serialized_end = 9652
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 3f3b9f4c5b0..27fd07a192e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -602,7 +602,10 @@ class Read(google.protobuf.message.Message):
         OPTIONS_FIELD_NUMBER: builtins.int
         PATHS_FIELD_NUMBER: builtins.int
         format: builtins.str
-        """(Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro."""
+        """(Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
+
+        If not set, the value from SQL conf 'spark.sql.sources.default' will be used.
+        """
         schema: builtins.str
         """(Optional) If not set, Spark will infer the schema.
@@ -624,17 +627,29 @@ class Read(google.protobuf.message.Message):
         def __init__(
             self,
             *,
-            format: builtins.str = ...,
+            format: builtins.str | None = ...,
             schema: builtins.str | None = ...,
             options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
             paths: collections.abc.Iterable[builtins.str] | None = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["_schema", b"_schema", "schema", b"schema"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_format",
+                b"_format",
+                "_schema",
+                b"_schema",
+                "format",
+                b"format",
+                "schema",
+                b"schema",
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_format",
+                b"_format",
                 "_schema",
                 b"_schema",
                 "format",
@@ -647,6 +662,11 @@ class Read(google.protobuf.message.Message):
                 b"schema",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_format", b"_format"]
+        ) -> typing_extensions.Literal["format"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
         ) -> typing_extensions.Literal["schema"] | None: ...
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 292e58b3552..9c9c79cb6eb 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -63,7 +63,7 @@ class DataFrameReader(OptionUtils):
 
     def __init__(self, client: "SparkSession"):
         self._client = client
-        self._format = ""
+        self._format: Optional[str] = None
         self._schema = ""
         self._options: Dict[str, str] = {}
diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
index bf77043ef38..2fa3f79a92f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
@@ -22,15 +22,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class ReadwriterParityTests(ReadwriterTestsMixin, ReusedConnectTestCase):
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_save_and_load(self):
-        super().test_save_and_load()
-
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_save_and_load_builder(self):
-        super().test_save_and_load_builder()
+    pass
 
 
 class ReadwriterV2ParityTests(ReadwriterV2TestsMixin, ReusedConnectTestCase):
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 7f9b5e61051..21c66284ace 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -31,75 +31,77 @@ class ReadwriterTestsMixin:
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.json(tmpPath, "overwrite")
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.save(
-            format="json",
-            mode="overwrite",
-            path=tmpPath,
-            noUse="this options will not be used in save.",
-        )
-        actual = self.spark.read.load(
-            format="json", path=tmpPath, noUse="this options will not be used in load."
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get(
-            "spark.sql.sources.default", "org.apache.spark.sql.parquet"
-        )
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        try:
+            df.write.json(tmpPath)
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            schema = StructType([StructField("value", StringType(), True)])
+            actual = self.spark.read.json(tmpPath, schema)
+            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+            df.write.json(tmpPath, "overwrite")
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            df.write.save(
+                format="json",
+                mode="overwrite",
+                path=tmpPath,
+                noUse="this options will not be used in save.",
+            )
+            actual = self.spark.read.load(
+                format="json", path=tmpPath, noUse="this options will not be used in load."
+            )
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        csvpath = os.path.join(tempfile.mkdtemp(), "data")
-        df.write.option("quote", None).format("csv").save(csvpath)
+            try:
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                actual = self.spark.read.load(path=tmpPath)
+                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            finally:
+                self.spark.sql("RESET spark.sql.sources.default").collect()
 
-        shutil.rmtree(tmpPath)
+            csvpath = os.path.join(tempfile.mkdtemp(), "data")
+            df.write.option("quote", None).format("csv").save(csvpath)
+        finally:
+            shutil.rmtree(tmpPath)
 
     def test_save_and_load_builder(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").options(noUse="this options will not be used in save.").option(
-            "noUse", "this option will not be used in save."
-        ).format("json").save(path=tmpPath)
-        actual = self.spark.read.format("json").load(
-            path=tmpPath, noUse="this options will not be used in load."
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get(
-            "spark.sql.sources.default", "org.apache.spark.sql.parquet"
-        )
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
-        shutil.rmtree(tmpPath)
+        try:
+            df.write.json(tmpPath)
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            schema = StructType([StructField("value", StringType(), True)])
+            actual = self.spark.read.json(tmpPath, schema)
+            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+            df.write.mode("overwrite").json(tmpPath)
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            df.write.mode("overwrite").options(
+                noUse="this options will not be used in save."
+            ).option("noUse", "this option will not be used in save.").format("json").save(
+                path=tmpPath
+            )
+            actual = self.spark.read.format("json").load(
+                path=tmpPath, noUse="this options will not be used in load."
+            )
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            try:
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                actual = self.spark.read.load(path=tmpPath)
+                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            finally:
+                self.spark.sql("RESET spark.sql.sources.default").collect()
+        finally:
+            shutil.rmtree(tmpPath)
 
     def test_bucketed_write(self):
         data = [

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org