Repository: spark
Updated Branches:
  refs/heads/master 1117fc35f -> e80f18dbd


[SPARK-25763][SQL][PYSPARK][TEST] Use more `@contextmanager` to ensure clean-up for each test.

## What changes were proposed in this pull request?

Currently, the tests in `SQLTests` in PySpark do not clean up after themselves properly.
This introduces and uses more `@contextmanager` helpers so that each test cleans up the
databases, tables, temp views, and functions it creates, even when it fails.
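
The helpers follow the usual `contextlib.contextmanager` pattern: yield to the test body inside `try`, and drop the objects in `finally` so clean-up runs even when an assertion fails. A condensed sketch of one of the new helpers added to `SQLTestUtils` (the diff below adds `database`, `table`, `tempView`, and `function` variants):

```python
from contextlib import contextmanager


class SQLTestUtils(object):
    # Mixed into test cases that expose a `spark` session attribute.

    @contextmanager
    def table(self, *tables):
        """Run the enclosed test body, then drop the given tables even if it fails."""
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)
```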

## How was this patch tested?

Modified tests.
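
Each touched test wraps its temporary objects in the matching helper, so the drop happens in the `finally` clause regardless of the test outcome; for example, `test_udf2` in the diff now reads roughly:

```python
def test_udf2(self):
    with self.tempView("test"):
        self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
        self.spark.createDataFrame(self.sc.parallelize([Row(a="test")])) \
            .createOrReplaceTempView("test")
        # The temp view "test" is dropped when the `with` block exits, pass or fail.
        [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
        self.assertEqual(4, res[0])
```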

Closes #22762 from ueshin/issues/SPARK-25763/cleanup_sqltests.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e80f18db
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e80f18db
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e80f18db

Branch: refs/heads/master
Commit: e80f18dbd8bc4c2aca9ba6dd487b50e95c55d2e6
Parents: 1117fc3
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Fri Oct 19 00:31:01 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Fri Oct 19 00:31:01 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 556 ++++++++++++++++++++++-----------------
 1 file changed, 318 insertions(+), 238 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e80f18db/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8065d82..82dc5a6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -225,6 +225,63 @@ class SQLTestUtils(object):
                 else:
                     self.spark.conf.set(key, old_value)
 
+    @contextmanager
+    def database(self, *databases):
+        """
+        A convenient context manager to test with some specific databases. This drops the given
+        databases if exist and sets current database to "default" when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for db in databases:
+                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
+            self.spark.catalog.setCurrentDatabase("default")
+
+    @contextmanager
+    def table(self, *tables):
+        """
+        A convenient context manager to test with some specific tables. This drops the given tables
+        if exist when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for t in tables:
+                self.spark.sql("DROP TABLE IF EXISTS %s" % t)
+
+    @contextmanager
+    def tempView(self, *views):
+        """
+        A convenient context manager to test with some specific views. This drops the given views
+        if exist when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for v in views:
+                self.spark.catalog.dropTempView(v)
+
+    @contextmanager
+    def function(self, *functions):
+        """
+        A convenient context manager to test with some specific functions. This drops the given
+        functions if exist when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for f in functions:
+                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+
 
 class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
     @classmethod
@@ -332,6 +389,7 @@ class SQLTests(ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
         ReusedSQLTestCase.setUpClass()
+        cls.spark.catalog._reset()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(cls.tempdir.name)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
@@ -347,12 +405,6 @@ class SQLTests(ReusedSQLTestCase):
         sqlContext2 = SQLContext(self.sc)
         self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
 
-    def tearDown(self):
-        super(SQLTests, self).tearDown()
-
-        # tear down test_bucketed_write state
-        self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")
-
     def test_row_should_be_read_only(self):
         row = Row(a=1, b=2)
         self.assertEqual(1, row.a)
@@ -473,11 +525,12 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(row[0], 4)
 
     def test_udf2(self):
-        self.spark.catalog.registerFunction("strlen", lambda string: 
len(string), IntegerType())
-        self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
-            .createOrReplaceTempView("test")
-        [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 
1").collect()
-        self.assertEqual(4, res[0])
+        with self.tempView("test"):
+            self.spark.catalog.registerFunction("strlen", lambda string: 
len(string), IntegerType())
+            self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
+                .createOrReplaceTempView("test")
+            [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) 
> 1").collect()
+            self.assertEqual(4, res[0])
 
     def test_udf3(self):
         two_args = self.spark.catalog.registerFunction(
@@ -666,14 +719,16 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(row[0], "bar")
 
     def test_udf_with_array_type(self):
-        d = [Row(l=list(range(3)), d={"key": list(range(5))})]
-        rdd = self.sc.parallelize(d)
-        self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
-        self.spark.catalog.registerFunction("copylist", lambda l: list(l), 
ArrayType(IntegerType()))
-        self.spark.catalog.registerFunction("maplen", lambda d: len(d), 
IntegerType())
-        [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from 
test").collect()
-        self.assertEqual(list(range(3)), l1)
-        self.assertEqual(1, l2)
+        with self.tempView("test"):
+            d = [Row(l=list(range(3)), d={"key": list(range(5))})]
+            rdd = self.sc.parallelize(d)
+            self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
+            self.spark.catalog.registerFunction(
+                "copylist", lambda l: list(l), ArrayType(IntegerType()))
+            self.spark.catalog.registerFunction("maplen", lambda d: len(d), 
IntegerType())
+            [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from 
test").collect()
+            self.assertEqual(list(range(3)), l1)
+            self.assertEqual(1, l2)
 
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
@@ -1061,10 +1116,11 @@ class SQLTests(ReusedSQLTestCase):
         self.assertTrue(df.is_cached)
         self.assertEqual(2, df.count())
 
-        df.createOrReplaceTempView("temp")
-        df = self.spark.sql("select foo from temp")
-        df.count()
-        df.collect()
+        with self.tempView("temp"):
+            df.createOrReplaceTempView("temp")
+            df = self.spark.sql("select foo from temp")
+            df.count()
+            df.collect()
 
     def test_apply_schema_to_row(self):
         df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
@@ -1137,17 +1193,21 @@ class SQLTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual([], df.rdd.map(lambda r: r.l).first())
         self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
-        df.createOrReplaceTempView("test")
-        result = self.spark.sql("SELECT l[0].a from test where d['key'].d = 
'2'")
-        self.assertEqual(1, result.head()[0])
+
+        with self.tempView("test"):
+            df.createOrReplaceTempView("test")
+            result = self.spark.sql("SELECT l[0].a from test where d['key'].d 
= '2'")
+            self.assertEqual(1, result.head()[0])
 
         df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
         self.assertEqual(df.schema, df2.schema)
         self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
         self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
-        df2.createOrReplaceTempView("test2")
-        result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = 
'2'")
-        self.assertEqual(1, result.head()[0])
+
+        with self.tempView("test2"):
+            df2.createOrReplaceTempView("test2")
+            result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d 
= '2'")
+            self.assertEqual(1, result.head()[0])
 
     def test_infer_schema_specification(self):
         from decimal import Decimal
@@ -1286,12 +1346,13 @@ class SQLTests(ReusedSQLTestCase):
              datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
         self.assertEqual(r, results.first())
 
-        df.createOrReplaceTempView("table2")
-        r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
-                           "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 
1 AS int1, " +
-                           "float1 + 1.5 as float1 FROM table2").first()
+        with self.tempView("table2"):
+            df.createOrReplaceTempView("table2")
+            r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, 
" +
+                               "short1 + 1 AS short1, short2 - 1 AS short2, 
int1 - 1 AS int1, " +
+                               "float1 + 1.5 as float1 FROM table2").first()
 
-        self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
+            self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
 
     def test_struct_in_map(self):
         d = [Row(m={Row(i=1): Row(s="")})]
@@ -1304,10 +1365,12 @@ class SQLTests(ReusedSQLTestCase):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)
         df = self.sc.parallelize([row]).toDF()
-        df.createOrReplaceTempView("test")
-        row = self.spark.sql("select l, d from test").head()
-        self.assertEqual(1, row.asDict()["l"][0].a)
-        self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+        with self.tempView("test"):
+            df.createOrReplaceTempView("test")
+            row = self.spark.sql("select l, d from test").head()
+            self.assertEqual(1, row.asDict()["l"][0].a)
+            self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
     def test_udt(self):
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
@@ -1401,18 +1464,22 @@ class SQLTests(ReusedSQLTestCase):
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
-        df.createOrReplaceTempView("labeled_point")
-        point = self.spark.sql("SELECT point FROM labeled_point").head().point
-        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM 
labeled_point").head().point
+            self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), PythonOnlyUDT)
-        df.createOrReplaceTempView("labeled_point")
-        point = self.spark.sql("SELECT point FROM labeled_point").head().point
-        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+        with self.tempView("labeled_point"):
+            df.createOrReplaceTempView("labeled_point")
+            point = self.spark.sql("SELECT point FROM 
labeled_point").head().point
+            self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_apply_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
@@ -3053,187 +3120,199 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_current_database(self):
         spark = self.spark
-        spark.catalog._reset()
-        self.assertEquals(spark.catalog.currentDatabase(), "default")
-        spark.sql("CREATE DATABASE some_db")
-        spark.catalog.setCurrentDatabase("some_db")
-        self.assertEquals(spark.catalog.currentDatabase(), "some_db")
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
+        with self.database("some_db"):
+            self.assertEquals(spark.catalog.currentDatabase(), "default")
+            spark.sql("CREATE DATABASE some_db")
+            spark.catalog.setCurrentDatabase("some_db")
+            self.assertEquals(spark.catalog.currentDatabase(), "some_db")
+            self.assertRaisesRegexp(
+                AnalysisException,
+                "does_not_exist",
+                lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
 
     def test_list_databases(self):
         spark = self.spark
-        spark.catalog._reset()
-        databases = [db.name for db in spark.catalog.listDatabases()]
-        self.assertEquals(databases, ["default"])
-        spark.sql("CREATE DATABASE some_db")
-        databases = [db.name for db in spark.catalog.listDatabases()]
-        self.assertEquals(sorted(databases), ["default", "some_db"])
+        with self.database("some_db"):
+            databases = [db.name for db in spark.catalog.listDatabases()]
+            self.assertEquals(databases, ["default"])
+            spark.sql("CREATE DATABASE some_db")
+            databases = [db.name for db in spark.catalog.listDatabases()]
+            self.assertEquals(sorted(databases), ["default", "some_db"])
 
     def test_list_tables(self):
         from pyspark.sql.catalog import Table
         spark = self.spark
-        spark.catalog._reset()
-        spark.sql("CREATE DATABASE some_db")
-        self.assertEquals(spark.catalog.listTables(), [])
-        self.assertEquals(spark.catalog.listTables("some_db"), [])
-        spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
-        spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
-        spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING 
parquet")
-        tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
-        tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
-        tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
-        self.assertEquals(tables, tablesDefault)
-        self.assertEquals(len(tables), 2)
-        self.assertEquals(len(tablesSomeDb), 2)
-        self.assertEquals(tables[0], Table(
-            name="tab1",
-            database="default",
-            description=None,
-            tableType="MANAGED",
-            isTemporary=False))
-        self.assertEquals(tables[1], Table(
-            name="temp_tab",
-            database=None,
-            description=None,
-            tableType="TEMPORARY",
-            isTemporary=True))
-        self.assertEquals(tablesSomeDb[0], Table(
-            name="tab2",
-            database="some_db",
-            description=None,
-            tableType="MANAGED",
-            isTemporary=False))
-        self.assertEquals(tablesSomeDb[1], Table(
-            name="temp_tab",
-            database=None,
-            description=None,
-            tableType="TEMPORARY",
-            isTemporary=True))
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.listTables("does_not_exist"))
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            with self.table("tab1", "some_db.tab2"):
+                with self.tempView("temp_tab"):
+                    self.assertEquals(spark.catalog.listTables(), [])
+                    self.assertEquals(spark.catalog.listTables("some_db"), [])
+                    spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
+                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
+                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
+                    tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
+                    tablesDefault = \
+                        sorted(spark.catalog.listTables("default"), key=lambda 
t: t.name)
+                    tablesSomeDb = \
+                        sorted(spark.catalog.listTables("some_db"), key=lambda 
t: t.name)
+                    self.assertEquals(tables, tablesDefault)
+                    self.assertEquals(len(tables), 2)
+                    self.assertEquals(len(tablesSomeDb), 2)
+                    self.assertEquals(tables[0], Table(
+                        name="tab1",
+                        database="default",
+                        description=None,
+                        tableType="MANAGED",
+                        isTemporary=False))
+                    self.assertEquals(tables[1], Table(
+                        name="temp_tab",
+                        database=None,
+                        description=None,
+                        tableType="TEMPORARY",
+                        isTemporary=True))
+                    self.assertEquals(tablesSomeDb[0], Table(
+                        name="tab2",
+                        database="some_db",
+                        description=None,
+                        tableType="MANAGED",
+                        isTemporary=False))
+                    self.assertEquals(tablesSomeDb[1], Table(
+                        name="temp_tab",
+                        database=None,
+                        description=None,
+                        tableType="TEMPORARY",
+                        isTemporary=True))
+                    self.assertRaisesRegexp(
+                        AnalysisException,
+                        "does_not_exist",
+                        lambda: spark.catalog.listTables("does_not_exist"))
 
     def test_list_functions(self):
         from pyspark.sql.catalog import Function
         spark = self.spark
-        spark.catalog._reset()
-        spark.sql("CREATE DATABASE some_db")
-        functions = dict((f.name, f) for f in spark.catalog.listFunctions())
-        functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
-        self.assertTrue(len(functions) > 200)
-        self.assertTrue("+" in functions)
-        self.assertTrue("like" in functions)
-        self.assertTrue("month" in functions)
-        self.assertTrue("to_date" in functions)
-        self.assertTrue("to_timestamp" in functions)
-        self.assertTrue("to_unix_timestamp" in functions)
-        self.assertTrue("current_database" in functions)
-        self.assertEquals(functions["+"], Function(
-            name="+",
-            description=None,
-            className="org.apache.spark.sql.catalyst.expressions.Add",
-            isTemporary=True))
-        self.assertEquals(functions, functionsDefault)
-        spark.catalog.registerFunction("temp_func", lambda x: str(x))
-        spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
-        spark.sql("CREATE FUNCTION some_db.func2 AS 
'org.apache.spark.data.bricks'")
-        newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
-        newFunctionsSomeDb = dict((f.name, f) for f in 
spark.catalog.listFunctions("some_db"))
-        self.assertTrue(set(functions).issubset(set(newFunctions)))
-        self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
-        self.assertTrue("temp_func" in newFunctions)
-        self.assertTrue("func1" in newFunctions)
-        self.assertTrue("func2" not in newFunctions)
-        self.assertTrue("temp_func" in newFunctionsSomeDb)
-        self.assertTrue("func1" not in newFunctionsSomeDb)
-        self.assertTrue("func2" in newFunctionsSomeDb)
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.listFunctions("does_not_exist"))
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            functions = dict((f.name, f) for f in spark.catalog.listFunctions())
+            functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
+            self.assertTrue(len(functions) > 200)
+            self.assertTrue("+" in functions)
+            self.assertTrue("like" in functions)
+            self.assertTrue("month" in functions)
+            self.assertTrue("to_date" in functions)
+            self.assertTrue("to_timestamp" in functions)
+            self.assertTrue("to_unix_timestamp" in functions)
+            self.assertTrue("current_database" in functions)
+            self.assertEquals(functions["+"], Function(
+                name="+",
+                description=None,
+                className="org.apache.spark.sql.catalyst.expressions.Add",
+                isTemporary=True))
+            self.assertEquals(functions, functionsDefault)
+
+            with self.function("func1", "some_db.func2"):
+                spark.catalog.registerFunction("temp_func", lambda x: str(x))
+                spark.sql("CREATE FUNCTION func1 AS 
'org.apache.spark.data.bricks'")
+                spark.sql("CREATE FUNCTION some_db.func2 AS 
'org.apache.spark.data.bricks'")
+                newFunctions = dict((f.name, f) for f in 
spark.catalog.listFunctions())
+                newFunctionsSomeDb = \
+                    dict((f.name, f) for f in 
spark.catalog.listFunctions("some_db"))
+                self.assertTrue(set(functions).issubset(set(newFunctions)))
+                self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
+                self.assertTrue("temp_func" in newFunctions)
+                self.assertTrue("func1" in newFunctions)
+                self.assertTrue("func2" not in newFunctions)
+                self.assertTrue("temp_func" in newFunctionsSomeDb)
+                self.assertTrue("func1" not in newFunctionsSomeDb)
+                self.assertTrue("func2" in newFunctionsSomeDb)
+                self.assertRaisesRegexp(
+                    AnalysisException,
+                    "does_not_exist",
+                    lambda: spark.catalog.listFunctions("does_not_exist"))
 
     def test_list_columns(self):
         from pyspark.sql.catalog import Column
         spark = self.spark
-        spark.catalog._reset()
-        spark.sql("CREATE DATABASE some_db")
-        spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
-        spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance 
FLOAT) USING parquet")
-        columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: 
c.name)
-        columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), 
key=lambda c: c.name)
-        self.assertEquals(columns, columnsDefault)
-        self.assertEquals(len(columns), 2)
-        self.assertEquals(columns[0], Column(
-            name="age",
-            description=None,
-            dataType="int",
-            nullable=True,
-            isPartition=False,
-            isBucket=False))
-        self.assertEquals(columns[1], Column(
-            name="name",
-            description=None,
-            dataType="string",
-            nullable=True,
-            isPartition=False,
-            isBucket=False))
-        columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
-        self.assertEquals(len(columns2), 2)
-        self.assertEquals(columns2[0], Column(
-            name="nickname",
-            description=None,
-            dataType="string",
-            nullable=True,
-            isPartition=False,
-            isBucket=False))
-        self.assertEquals(columns2[1], Column(
-            name="tolerance",
-            description=None,
-            dataType="float",
-            nullable=True,
-            isPartition=False,
-            isBucket=False))
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "tab2",
-            lambda: spark.catalog.listColumns("tab2"))
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.listColumns("does_not_exist"))
+        with self.database("some_db"):
+            spark.sql("CREATE DATABASE some_db")
+            with self.table("tab1", "some_db.tab2"):
+                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING 
parquet")
+                spark.sql(
+                    "CREATE TABLE some_db.tab2 (nickname STRING, tolerance 
FLOAT) USING parquet")
+                columns = sorted(spark.catalog.listColumns("tab1"), key=lambda 
c: c.name)
+                columnsDefault = \
+                    sorted(spark.catalog.listColumns("tab1", "default"), 
key=lambda c: c.name)
+                self.assertEquals(columns, columnsDefault)
+                self.assertEquals(len(columns), 2)
+                self.assertEquals(columns[0], Column(
+                    name="age",
+                    description=None,
+                    dataType="int",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                self.assertEquals(columns[1], Column(
+                    name="name",
+                    description=None,
+                    dataType="string",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                columns2 = \
+                    sorted(spark.catalog.listColumns("tab2", "some_db"), 
key=lambda c: c.name)
+                self.assertEquals(len(columns2), 2)
+                self.assertEquals(columns2[0], Column(
+                    name="nickname",
+                    description=None,
+                    dataType="string",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                self.assertEquals(columns2[1], Column(
+                    name="tolerance",
+                    description=None,
+                    dataType="float",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False))
+                self.assertRaisesRegexp(
+                    AnalysisException,
+                    "tab2",
+                    lambda: spark.catalog.listColumns("tab2"))
+                self.assertRaisesRegexp(
+                    AnalysisException,
+                    "does_not_exist",
+                    lambda: spark.catalog.listColumns("does_not_exist"))
 
     def test_cache(self):
         spark = self.spark
-        spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
-        spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
-        self.assertFalse(spark.catalog.isCached("tab1"))
-        self.assertFalse(spark.catalog.isCached("tab2"))
-        spark.catalog.cacheTable("tab1")
-        self.assertTrue(spark.catalog.isCached("tab1"))
-        self.assertFalse(spark.catalog.isCached("tab2"))
-        spark.catalog.cacheTable("tab2")
-        spark.catalog.uncacheTable("tab1")
-        self.assertFalse(spark.catalog.isCached("tab1"))
-        self.assertTrue(spark.catalog.isCached("tab2"))
-        spark.catalog.clearCache()
-        self.assertFalse(spark.catalog.isCached("tab1"))
-        self.assertFalse(spark.catalog.isCached("tab2"))
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.isCached("does_not_exist"))
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.cacheTable("does_not_exist"))
-        self.assertRaisesRegexp(
-            AnalysisException,
-            "does_not_exist",
-            lambda: spark.catalog.uncacheTable("does_not_exist"))
+        with self.tempView("tab1", "tab2"):
+            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
+            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
+            self.assertFalse(spark.catalog.isCached("tab1"))
+            self.assertFalse(spark.catalog.isCached("tab2"))
+            spark.catalog.cacheTable("tab1")
+            self.assertTrue(spark.catalog.isCached("tab1"))
+            self.assertFalse(spark.catalog.isCached("tab2"))
+            spark.catalog.cacheTable("tab2")
+            spark.catalog.uncacheTable("tab1")
+            self.assertFalse(spark.catalog.isCached("tab1"))
+            self.assertTrue(spark.catalog.isCached("tab2"))
+            spark.catalog.clearCache()
+            self.assertFalse(spark.catalog.isCached("tab1"))
+            self.assertFalse(spark.catalog.isCached("tab2"))
+            self.assertRaisesRegexp(
+                AnalysisException,
+                "does_not_exist",
+                lambda: spark.catalog.isCached("does_not_exist"))
+            self.assertRaisesRegexp(
+                AnalysisException,
+                "does_not_exist",
+                lambda: spark.catalog.cacheTable("does_not_exist"))
+            self.assertRaisesRegexp(
+                AnalysisException,
+                "does_not_exist",
+                lambda: spark.catalog.uncacheTable("does_not_exist"))
 
     def test_read_text_file_list(self):
         df = self.spark.read.text(['python/test_support/sql/text-test.txt',
@@ -3358,37 +3437,38 @@ class SQLTests(ReusedSQLTestCase):
             num = len([c for c in cols if c.name in names and c.isBucket])
             return num
 
-        # Test write with one bucketing column
-        df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
-        self.assertEqual(count_bucketed_cols(["x"]), 1)
-        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
-        # Test write two bucketing columns
-        df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
-        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
-        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
-        # Test write with bucket and sort
-        df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
-        self.assertEqual(count_bucketed_cols(["x"]), 1)
-        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
-        # Test write with a list of columns
-        df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
-        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
-        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
-        # Test write with bucket and sort with a list of columns
-        (df.write.bucketBy(2, "x")
-            .sortBy(["y", "z"])
-            .mode("overwrite").saveAsTable("pyspark_bucket"))
-        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
-        # Test write with bucket and sort with multiple columns
-        (df.write.bucketBy(2, "x")
-            .sortBy("y", "z")
-            .mode("overwrite").saveAsTable("pyspark_bucket"))
-        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+        with self.table("pyspark_bucket"):
+            # Test write with one bucketing column
+            df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
+            self.assertEqual(count_bucketed_cols(["x"]), 1)
+            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+            # Test write two bucketing columns
+            df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
+            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+            # Test write with bucket and sort
+            df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
+            self.assertEqual(count_bucketed_cols(["x"]), 1)
+            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+            # Test write with a list of columns
+            df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
+            self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+            # Test write with bucket and sort with a list of columns
+            (df.write.bucketBy(2, "x")
+                .sortBy(["y", "z"])
+                .mode("overwrite").saveAsTable("pyspark_bucket"))
+            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+            # Test write with bucket and sort with multiple columns
+            (df.write.bucketBy(2, "x")
+                .sortBy("y", "z")
+                .mode("overwrite").saveAsTable("pyspark_bucket"))
+            self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
 
     def _to_pandas(self):
         from datetime import datetime, date

