This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 13fd4b32957  [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
13fd4b32957 is described below

commit 13fd4b32957bf8040ae1f7b175040ec2dff21017
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Mar 2 20:37:19 2023 +0900

    [SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests

    ### What changes were proposed in this pull request?

    This PR removes the `sql("command").collect()` workaround from the PySpark test code.

    ### Why are the changes needed?

    The `collect()` calls were previously added as a workaround, because Spark Connect required them to actually execute SQL commands. That has been fixed, so we no longer need to call `collect`.

    ### Does this PR introduce _any_ user-facing change?

    No, test-only.

    ### How was this patch tested?

    CI in this PR should test it out.

    Closes #40251 from HyukjinKwon/SPARK-41725.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 79da1ab400f25dbceec45e107e5366d084138fa8)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
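For context, the pattern this commit removes is sketched below. This is an
illustrative example, not part of the patch: it assumes a running SparkSession
bound to `spark`, and the table name `demo_tbl` is hypothetical.

    # Old workaround: Spark Connect required an action such as collect()
    # before a SQL command (CREATE/DROP/INSERT/SET ...) was actually executed.
    _ = spark.sql("CREATE TABLE demo_tbl (id INT) USING parquet").collect()

    # After SPARK-41725, spark.sql() runs commands eagerly on both the classic
    # and Connect backends, so the trailing collect() is redundant.
    _ = spark.sql("DROP TABLE demo_tbl")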
"CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'") >>> spark.catalog.getFunction("my_func1") Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ... @@ -602,11 +600,11 @@ class Catalog: Examples -------- - >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() - >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet").collect() + >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1") + >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet") >>> spark.catalog.listColumns("tblA") [Column(name='name', description=None, dataType='string', nullable=True, ... - >>> _ = spark.sql("DROP TABLE tblA").collect() + >>> _ = spark.sql("DROP TABLE tblA") """ if dbName is None: iter = self._jcatalog.listColumns(tableName).toLocalIterator() @@ -667,8 +665,8 @@ class Catalog: >>> spark.catalog.tableExists("unexisting_table") False - >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() - >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect() + >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1") + >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet") >>> spark.catalog.tableExists("tbl1") True @@ -680,13 +678,13 @@ class Catalog: True >>> spark.catalog.tableExists("tbl1", "default") True - >>> _ = spark.sql("DROP TABLE tbl1").collect() + >>> _ = spark.sql("DROP TABLE tbl1") Check if views exist: >>> spark.catalog.tableExists("view1") False - >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1").collect() + >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1") >>> spark.catalog.tableExists("view1") True @@ -698,14 +696,14 @@ class Catalog: True >>> spark.catalog.tableExists("view1", "default") True - >>> _ = spark.sql("DROP VIEW view1").collect() + >>> _ = spark.sql("DROP VIEW view1") Check if temporary views exist: - >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1").collect() + >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1") >>> spark.catalog.tableExists("view1") True - >>> df = spark.sql("DROP VIEW view1").collect() + >>> df = spark.sql("DROP VIEW view1") >>> spark.catalog.tableExists("view1") False """ @@ -806,7 +804,7 @@ class Catalog: Creating a managed table. >>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet') - >>> _ = spark.sql("DROP TABLE tbl1").collect() + >>> _ = spark.sql("DROP TABLE tbl1") Creating an external table @@ -814,7 +812,7 @@ class Catalog: >>> with tempfile.TemporaryDirectory() as d: ... _ = spark.catalog.createTable( ... 
"tbl2", schema=spark.range(1).schema, path=d, source='parquet') - >>> _ = spark.sql("DROP TABLE tbl2").collect() + >>> _ = spark.sql("DROP TABLE tbl2") """ if path is not None: options["path"] = path @@ -954,8 +952,8 @@ class Catalog: Examples -------- - >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() - >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect() + >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1") + >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet") >>> spark.catalog.cacheTable("tbl1") >>> spark.catalog.isCached("tbl1") True @@ -972,7 +970,7 @@ class Catalog: >>> spark.catalog.isCached("spark_catalog.default.tbl1") True >>> spark.catalog.uncacheTable("tbl1") - >>> _ = spark.sql("DROP TABLE tbl1").collect() + >>> _ = spark.sql("DROP TABLE tbl1") """ return self._jcatalog.isCached(tableName) @@ -994,8 +992,8 @@ class Catalog: Examples -------- - >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() - >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect() + >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1") + >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet") >>> spark.catalog.cacheTable("tbl1") Throw an analysis exception when the table does not exist. @@ -1009,7 +1007,7 @@ class Catalog: >>> spark.catalog.cacheTable("spark_catalog.default.tbl1") >>> spark.catalog.uncacheTable("tbl1") - >>> _ = spark.sql("DROP TABLE tbl1").collect() + >>> _ = spark.sql("DROP TABLE tbl1") """ self._jcatalog.cacheTable(tableName) @@ -1031,8 +1029,8 @@ class Catalog: Examples -------- - >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() - >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect() + >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1") + >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet") >>> spark.catalog.cacheTable("tbl1") >>> spark.catalog.uncacheTable("tbl1") >>> spark.catalog.isCached("tbl1") @@ -1050,7 +1048,7 @@ class Catalog: >>> spark.catalog.uncacheTable("spark_catalog.default.tbl1") >>> spark.catalog.isCached("tbl1") False - >>> _ = spark.sql("DROP TABLE tbl1").collect() + >>> _ = spark.sql("DROP TABLE tbl1") """ self._jcatalog.uncacheTable(tableName) @@ -1064,12 +1062,12 @@ class Catalog: Examples -------- - >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() - >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect() + >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1") + >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet") >>> spark.catalog.clearCache() >>> spark.catalog.isCached("tbl1") False - >>> _ = spark.sql("DROP TABLE tbl1").collect() + >>> _ = spark.sql("DROP TABLE tbl1") """ self._jcatalog.clearCache() @@ -1095,10 +1093,10 @@ class Catalog: >>> import tempfile >>> with tempfile.TemporaryDirectory() as d: - ... _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect() + ... _ = spark.sql("DROP TABLE IF EXISTS tbl1") ... _ = spark.sql( - ... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect() - ... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect() + ... "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)) + ... _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'") ... spark.catalog.cacheTable("tbl1") ... spark.table("tbl1").show() +---+ @@ -1121,7 +1119,7 @@ class Catalog: Using the fully qualified name for the table. 
         >>> spark.catalog.refreshTable("spark_catalog.default.tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.refreshTable(tableName)

@@ -1149,12 +1147,12 @@ class Catalog:

         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     spark.range(1).selectExpr(
         ...         "id as key", "id as value").write.partitionBy("key").mode("overwrite").save(d)
         ...     _ = spark.sql(
         ...         "CREATE TABLE tbl1 (key LONG, value LONG)"
-        ...         "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d)).collect()
+        ...         "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d))
         ...     spark.table("tbl1").show()
         ...     spark.catalog.recoverPartitions("tbl1")
         ...     spark.table("tbl1").show()

@@ -1167,7 +1165,7 @@ class Catalog:
         +-----+---+
         |    0|  0|
         +-----+---+
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.recoverPartitions(tableName)

@@ -1191,10 +1189,10 @@ class Catalog:

         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     _ = spark.sql(
-        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
-        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
         ...     spark.catalog.cacheTable("tbl1")
         ...     spark.table("tbl1").show()
         +---+

@@ -1214,7 +1212,7 @@ class Catalog:
         >>> spark.table("tbl1").count()
         0
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.refreshByPath(path)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 93fd938dff4..17b59311648 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -466,7 +466,7 @@ class DataFrameReader(OptionUtils):
         | 8|
         | 9|
         +---+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         return self._df(self._jreader.table(tableName))

@@ -1232,7 +1232,7 @@ class DataFrameWriter(OptionUtils):

         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
+        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]

@@ -1246,7 +1246,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
+        >>> _ = spark.sql("DROP TABLE bucketed_table")
         """
         if not isinstance(numBuckets, int):
             raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))

@@ -1296,7 +1296,7 @@ class DataFrameWriter(OptionUtils):

         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a sorted-bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table").collect()
+        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]

@@ -1311,7 +1311,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
+        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
         """
         if isinstance(col, (list, tuple)):
             if cols:

@@ -1417,7 +1417,7 @@ class DataFrameWriter(OptionUtils):

         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
         >>> df = spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
         ... )

@@ -1438,7 +1438,7 @@ class DataFrameWriter(OptionUtils):
         |140| Haejoon Lee|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")

@@ -1495,7 +1495,7 @@ class DataFrameWriter(OptionUtils):
         --------
         Creates a table from a DataFrame, and read it back.

-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]

@@ -1508,7 +1508,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         self.mode(mode).options(**options)
         if partitionBy is not None:

diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 10f3ec12c9c..ae92ce57dc8 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -24,7 +24,7 @@ class CatalogTestsMixin:
         spark = self.spark
         with self.database("some_db"):
             self.assertEqual(spark.catalog.currentDatabase(), "default")
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             spark.catalog.setCurrentDatabase("some_db")
             self.assertEqual(spark.catalog.currentDatabase(), "some_db")
             self.assertRaisesRegex(

@@ -38,7 +38,7 @@ class CatalogTestsMixin:
         with self.database("some_db"):
             databases = [db.name for db in spark.catalog.listDatabases()]
             self.assertEqual(databases, ["default"])
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             databases = [db.name for db in spark.catalog.listDatabases()]
             self.assertEqual(sorted(databases), ["default", "some_db"])

@@ -47,7 +47,7 @@ class CatalogTestsMixin:
         spark = self.spark
         with self.database("some_db"):
             self.assertFalse(spark.catalog.databaseExists("some_db"))
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             self.assertTrue(spark.catalog.databaseExists("some_db"))
             self.assertTrue(spark.catalog.databaseExists("spark_catalog.some_db"))
             self.assertFalse(spark.catalog.databaseExists("spark_catalog.some_db2"))

@@ -55,7 +55,7 @@ class CatalogTestsMixin:
     def test_get_database(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             db = spark.catalog.getDatabase("spark_catalog.some_db")
             self.assertEqual(db.name, "some_db")
             self.assertEqual(db.catalog, "spark_catalog")

@@ -65,16 +65,14 @@ class CatalogTestsMixin:

         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
                 with self.tempView("temp_tab"):
                     self.assertEqual(spark.catalog.listTables(), [])
                     self.assertEqual(spark.catalog.listTables("some_db"), [])
                     spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
-                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
-                    spark.sql(
-                        "CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet"
-                    ).collect()
+                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
+                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")

                     schema = StructType([StructField("a", IntegerType(), True)])
                     description = "this a table created via Catalog.createTable()"

@@ -187,7 +185,7 @@ class CatalogTestsMixin:
     def test_list_functions(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             functions = dict((f.name, f) for f in spark.catalog.listFunctions())
             functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
             self.assertTrue(len(functions) > 200)

@@ -215,10 +213,8 @@ class CatalogTestsMixin:
             if support_udf:
                 spark.udf.register("temp_func", lambda x: str(x))

-            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
-            spark.sql(
-                "CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'"
-            ).collect()
+            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
+            spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
             newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
             newFunctionsSomeDb = dict(
                 (f.name, f) for f in spark.catalog.listFunctions("some_db")
             )

@@ -247,7 +243,7 @@ class CatalogTestsMixin:
             self.assertFalse(spark.catalog.functionExists("default.func1"))
             self.assertFalse(spark.catalog.functionExists("spark_catalog.default.func1"))
             self.assertFalse(spark.catalog.functionExists("func1", "default"))
-            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
+            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
             self.assertTrue(spark.catalog.functionExists("func1"))
             self.assertTrue(spark.catalog.functionExists("default.func1"))
             self.assertTrue(spark.catalog.functionExists("spark_catalog.default.func1"))

@@ -256,7 +252,7 @@ class CatalogTestsMixin:
     def test_get_function(self):
         spark = self.spark
         with self.function("func1"):
-            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'").collect()
+            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
             func1 = spark.catalog.getFunction("spark_catalog.default.func1")
             self.assertTrue(func1.name == "func1")
             self.assertTrue(func1.namespace == ["default"])

@@ -269,12 +265,12 @@ class CatalogTestsMixin:

         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2"):
-                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                 spark.sql(
                     "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet"
-                ).collect()
+                )
                 columns = sorted(
                     spark.catalog.listColumns("spark_catalog.default.tab1"), key=lambda c: c.name
                 )

@@ -343,11 +339,9 @@ class CatalogTestsMixin:
     def test_table_cache(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1"):
-                spark.sql(
-                    "CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet"
-                ).collect()
+                spark.sql("CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet")
                 self.assertFalse(spark.catalog.isCached("some_db.tab1"))
                 self.assertFalse(spark.catalog.isCached("spark_catalog.some_db.tab1"))
                 spark.catalog.cacheTable("spark_catalog.some_db.tab1")

@@ -361,18 +355,16 @@ class CatalogTestsMixin:
         # SPARK-36176: testing that table_exists returns correct boolean
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2"):
                 self.assertFalse(spark.catalog.tableExists("tab1"))
                 self.assertFalse(spark.catalog.tableExists("tab2", "some_db"))
-                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                 self.assertTrue(spark.catalog.tableExists("tab1"))
                 self.assertTrue(spark.catalog.tableExists("default.tab1"))
                 self.assertTrue(spark.catalog.tableExists("spark_catalog.default.tab1"))
                 self.assertTrue(spark.catalog.tableExists("tab1", "default"))
-                spark.sql(
-                    "CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet"
-                ).collect()
+                spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
                 self.assertFalse(spark.catalog.tableExists("tab2"))
                 self.assertTrue(spark.catalog.tableExists("some_db.tab2"))
                 self.assertTrue(spark.catalog.tableExists("spark_catalog.some_db.tab2"))

@@ -381,9 +373,9 @@ class CatalogTestsMixin:
     def test_get_table(self):
         spark = self.spark
         with self.database("some_db"):
-            spark.sql("CREATE DATABASE some_db").collect()
+            spark.sql("CREATE DATABASE some_db")
             with self.table("tab1"):
-                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet").collect()
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                 self.assertEqual(spark.catalog.getTable("tab1").database, "default")
                 self.assertEqual(spark.catalog.getTable("default.tab1").catalog, "spark_catalog")
                 self.assertEqual(spark.catalog.getTable("spark_catalog.default.tab1").name, "tab1")

@@ -397,8 +389,8 @@ class CatalogTestsMixin:
             with self.table("my_tab"):
                 spark.sql(
                     "CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{}'".format(tmp_dir)
-                ).collect()
-                spark.sql("INSERT INTO my_tab SELECT 'abc'").collect()
+                )
+                spark.sql("INSERT INTO my_tab SELECT 'abc'")
                 spark.catalog.cacheTable("my_tab")
                 self.assertEqual(spark.table("my_tab").count(), 1)

diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 21c66284ace..17c158a870a 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -56,11 +56,11 @@ class ReadwriterTestsMixin:
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

         try:
-            self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+            self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
             actual = self.spark.read.load(path=tmpPath)
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
         finally:
-            self.spark.sql("RESET spark.sql.sources.default").collect()
+            self.spark.sql("RESET spark.sql.sources.default")

         csvpath = os.path.join(tempfile.mkdtemp(), "data")
         df.write.option("quote", None).format("csv").save(csvpath)

@@ -95,11 +95,11 @@ class ReadwriterTestsMixin:
             self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

             try:
-                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
                 actual = self.spark.read.load(path=tmpPath)
                 self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
             finally:
-                self.spark.sql("RESET spark.sql.sources.default").collect()
+                self.spark.sql("RESET spark.sql.sources.default")
         finally:
             shutil.rmtree(tmpPath)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 68424cad386..9db090fa810 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -379,13 +379,13 @@ class TypesTestsMixin:

     def test_negative_decimal(self):
         try:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true").collect()
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
             df = self.spark.createDataFrame([(1,), (11,)], ["value"])
             ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
             actual = list(map(lambda r: int(r.value), ret))
             self.assertEqual(actual, [0, 10])
         finally:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false").collect()
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")

     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]

diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 937ad491479..077d854b1dd 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -202,7 +202,7 @@ class SQLTestUtils:
             yield
         finally:
             for db in databases:
-                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db).collect()
+                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
                 self.spark.catalog.setCurrentDatabase("default")

     @contextmanager
@@ -217,7 +217,7 @@ class SQLTestUtils:
             yield
         finally:
             for t in tables:
-                self.spark.sql("DROP TABLE IF EXISTS %s" % t).collect()
+                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

     @contextmanager
     def tempView(self, *views):
@@ -245,7 +245,7 @@ class SQLTestUtils:
             yield
         finally:
             for f in functions:
-                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f).collect()
+                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)

     @staticmethod
     def assert_close(a, b):

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org