Repository: spark
Updated Branches:
  refs/heads/master 4ae9fe091 -> d33e3d572
[SPARK-14988][PYTHON] SparkSession API follow-ups

## What changes were proposed in this pull request?

Addresses comments in #12765.

## How was this patch tested?

Python tests.

Author: Andrew Or <and...@databricks.com>

Closes #12784 from andrewor14/python-followup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d33e3d57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d33e3d57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d33e3d57

Branch: refs/heads/master
Commit: d33e3d572ed7143f151f9c96fd08407f8de340f4
Parents: 4ae9fe0
Author: Andrew Or <and...@databricks.com>
Authored: Fri Apr 29 16:41:13 2016 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Fri Apr 29 16:41:13 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/catalog.py                   | 168 +---------------
 python/pyspark/sql/conf.py                      |  58 ++----
 python/pyspark/sql/context.py                   |   8 +-
 python/pyspark/sql/session.py                   |   4 +-
 python/pyspark/sql/tests.py                     | 199 ++++++++++++++++++-
 .../scala/org/apache/spark/sql/Dataset.scala    |   2 +-
 .../org/apache/spark/sql/RuntimeConfig.scala    |  17 ++
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +-
 .../org/apache/spark/sql/SparkSession.scala     |   2 +-
 .../spark/sql/execution/command/cache.scala     |   2 +-
 .../org/apache/spark/sql/internal/SQLConf.scala |   7 +
 11 files changed, 256 insertions(+), 213 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 4f92383..9cfdd0a 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -45,45 +45,19 @@ class Catalog(object):
     @ignore_unicode_prefix
     @since(2.0)
     def currentDatabase(self):
-        """Returns the current default database in this session.
-
-        >>> spark.catalog._reset()
-        >>> spark.catalog.currentDatabase()
-        u'default'
-        """
+        """Returns the current default database in this session."""
         return self._jcatalog.currentDatabase()
 
     @ignore_unicode_prefix
     @since(2.0)
     def setCurrentDatabase(self, dbName):
-        """Sets the current default database in this session.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> spark.catalog.setCurrentDatabase("some_db")
-        >>> spark.catalog.currentDatabase()
-        u'some_db'
-        >>> spark.catalog.setCurrentDatabase("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
-        """
+        """Sets the current default database in this session."""
         return self._jcatalog.setCurrentDatabase(dbName)
 
     @ignore_unicode_prefix
     @since(2.0)
     def listDatabases(self):
-        """Returns a list of databases available across all sessions.
-
-        >>> spark.catalog._reset()
-        >>> [db.name for db in spark.catalog.listDatabases()]
-        [u'default']
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> [db.name for db in spark.catalog.listDatabases()]
-        [u'default', u'some_db']
-        """
+        """Returns a list of databases available across all sessions."""
         iter = self._jcatalog.listDatabases().toLocalIterator()
         databases = []
         while iter.hasNext():
@@ -101,31 +75,6 @@ class Catalog(object):
         If no database is specified, the current database is used.
         This includes all temporary tables.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> spark.catalog.listTables()
-        []
-        >>> spark.catalog.listTables("some_db")
-        []
-        >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_temp_tab")
-        >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
-        DataFrame[]
-        >>> spark.sql("CREATE TABLE some_db.my_tab2 (name STRING, age INT)")
-        DataFrame[]
-        >>> spark.catalog.listTables()
-        [Table(name=u'my_tab1', database=u'default', description=None, tableType=u'MANAGED',
-        isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
-        tableType=u'TEMPORARY', isTemporary=True)]
-        >>> spark.catalog.listTables("some_db")
-        [Table(name=u'my_tab2', database=u'some_db', description=None, tableType=u'MANAGED',
-        isTemporary=False), Table(name=u'my_temp_tab', database=None, description=None,
-        tableType=u'TEMPORARY', isTemporary=True)]
-        >>> spark.catalog.listTables("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
         """
         if dbName is None:
             dbName = self.currentDatabase()
@@ -148,28 +97,6 @@ class Catalog(object):
         If no database is specified, the current database is used.
         This includes all temporary functions.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE my_db")
-        DataFrame[]
-        >>> funcNames = set(f.name for f in spark.catalog.listFunctions())
-        >>> set(["+", "floor", "to_unix_timestamp", "current_database"]).issubset(funcNames)
-        True
-        >>> spark.sql("CREATE FUNCTION my_func1 AS 'org.apache.spark.whatever'")
-        DataFrame[]
-        >>> spark.sql("CREATE FUNCTION my_db.my_func2 AS 'org.apache.spark.whatever'")
-        DataFrame[]
-        >>> spark.catalog.registerFunction("temp_func", lambda x: str(x))
-        >>> newFuncNames = set(f.name for f in spark.catalog.listFunctions()) - funcNames
-        >>> newFuncNamesDb = set(f.name for f in spark.catalog.listFunctions("my_db")) - funcNames
-        >>> sorted(list(newFuncNames))
-        [u'my_func1', u'temp_func']
-        >>> sorted(list(newFuncNamesDb))
-        [u'my_func2', u'temp_func']
-        >>> spark.catalog.listFunctions("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
         """
         if dbName is None:
             dbName = self.currentDatabase()
@@ -193,26 +120,6 @@ class Catalog(object):
 
         Note: the order of arguments here is different from that of its JVM counterpart
         because Python does not support method overloading.
-
-        >>> spark.catalog._reset()
-        >>> spark.sql("CREATE DATABASE some_db")
-        DataFrame[]
-        >>> spark.sql("CREATE TABLE my_tab1 (name STRING, age INT)")
-        DataFrame[]
-        >>> spark.sql("CREATE TABLE some_db.my_tab2 (nickname STRING, tolerance FLOAT)")
-        DataFrame[]
-        >>> spark.catalog.listColumns("my_tab1")
-        [Column(name=u'name', description=None, dataType=u'string', nullable=True,
-        isPartition=False, isBucket=False), Column(name=u'age', description=None,
-        dataType=u'int', nullable=True, isPartition=False, isBucket=False)]
-        >>> spark.catalog.listColumns("my_tab2", "some_db")
-        [Column(name=u'nickname', description=None, dataType=u'string', nullable=True,
-        isPartition=False, isBucket=False), Column(name=u'tolerance', description=None,
-        dataType=u'float', nullable=True, isPartition=False, isBucket=False)]
-        >>> spark.catalog.listColumns("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        AnalysisException: ...
""" if dbName is None: dbName = self.currentDatabase() @@ -247,7 +154,7 @@ class Catalog(object): if path is not None: options["path"] = path if source is None: - source = self._sparkSession.getConf( + source = self._sparkSession.conf.get( "spark.sql.sources.default", "org.apache.spark.sql.parquet") if schema is None: df = self._jcatalog.createExternalTable(tableName, source, options) @@ -275,16 +182,16 @@ class Catalog(object): self._jcatalog.dropTempTable(tableName) @since(2.0) - def registerDataFrameAsTable(self, df, tableName): + def registerTable(self, df, tableName): """Registers the given :class:`DataFrame` as a temporary table in the catalog. >>> df = spark.createDataFrame([(2, 1), (3, 1)]) - >>> spark.catalog.registerDataFrameAsTable(df, "my_cool_table") + >>> spark.catalog.registerTable(df, "my_cool_table") >>> spark.table("my_cool_table").collect() [Row(_1=2, _2=1), Row(_1=3, _2=1)] """ if isinstance(df, DataFrame): - self._jsparkSession.registerDataFrameAsTable(df._jdf, tableName) + self._jsparkSession.registerTable(df._jdf, tableName) else: raise ValueError("Can only register DataFrame as table") @@ -321,75 +228,22 @@ class Catalog(object): @since(2.0) def isCached(self, tableName): - """Returns true if the table is currently cached in-memory. - - >>> spark.catalog._reset() - >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab") - >>> spark.catalog.isCached("my_tab") - False - >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - AnalysisException: ... - """ + """Returns true if the table is currently cached in-memory.""" return self._jcatalog.isCached(tableName) @since(2.0) def cacheTable(self, tableName): - """Caches the specified table in-memory. - - >>> spark.catalog._reset() - >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab") - >>> spark.catalog.isCached("my_tab") - False - >>> spark.catalog.cacheTable("my_tab") - >>> spark.catalog.isCached("my_tab") - True - >>> spark.catalog.cacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - AnalysisException: ... - """ + """Caches the specified table in-memory.""" self._jcatalog.cacheTable(tableName) @since(2.0) def uncacheTable(self, tableName): - """Removes the specified table from the in-memory cache. - - >>> spark.catalog._reset() - >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab") - >>> spark.catalog.cacheTable("my_tab") - >>> spark.catalog.isCached("my_tab") - True - >>> spark.catalog.uncacheTable("my_tab") - >>> spark.catalog.isCached("my_tab") - False - >>> spark.catalog.uncacheTable("does_not_exist") # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - AnalysisException: ... - """ + """Removes the specified table from the in-memory cache.""" self._jcatalog.uncacheTable(tableName) @since(2.0) def clearCache(self): - """Removes all cached tables from the in-memory cache. 
-
-        >>> spark.catalog._reset()
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab1")
-        >>> spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("my_tab2")
-        >>> spark.catalog.cacheTable("my_tab1")
-        >>> spark.catalog.cacheTable("my_tab2")
-        >>> spark.catalog.isCached("my_tab1")
-        True
-        >>> spark.catalog.isCached("my_tab2")
-        True
-        >>> spark.catalog.clearCache()
-        >>> spark.catalog.isCached("my_tab1")
-        False
-        >>> spark.catalog.isCached("my_tab2")
-        False
-        """
+        """Removes all cached tables from the in-memory cache."""
         self._jcatalog.clearCache()
 
     def _reset(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/conf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 1d9f052..7428c91 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -33,64 +33,34 @@ class RuntimeConfig(object):
     @ignore_unicode_prefix
     @since(2.0)
     def set(self, key, value):
-        """Sets the given Spark runtime configuration property.
-
-        >>> spark.conf.set("garble", "marble")
-        >>> spark.getConf("garble")
-        u'marble'
-        """
+        """Sets the given Spark runtime configuration property."""
         self._jconf.set(key, value)
 
     @ignore_unicode_prefix
     @since(2.0)
-    def get(self, key):
+    def get(self, key, default=None):
         """Returns the value of Spark runtime configuration property for the given key,
         assuming it is set.
-
-        >>> spark.setConf("bogo", "sipeo")
-        >>> spark.conf.get("bogo")
-        u'sipeo'
-        >>> spark.conf.get("definitely.not.set") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        Py4JJavaError: ...
-        """
-        return self._jconf.get(key)
-
-    @ignore_unicode_prefix
-    @since(2.0)
-    def getOption(self, key):
-        """Returns the value of Spark runtime configuration property for the given key,
-        or None if it is not set.
-
-        >>> spark.setConf("bogo", "sipeo")
-        >>> spark.conf.getOption("bogo")
-        u'sipeo'
-        >>> spark.conf.getOption("definitely.not.set") is None
-        True
         """
-        iter = self._jconf.getOption(key).iterator()
-        if iter.hasNext():
-            return iter.next()
+        self._checkType(key, "key")
+        if default is None:
+            return self._jconf.get(key)
         else:
-            return None
+            self._checkType(default, "default")
+            return self._jconf.get(key, default)
 
     @ignore_unicode_prefix
     @since(2.0)
     def unset(self, key):
-        """Resets the configuration property for the given key.
-
-        >>> spark.setConf("armado", "larmado")
-        >>> spark.getConf("armado")
-        u'larmado'
-        >>> spark.conf.unset("armado")
-        >>> spark.getConf("armado") # doctest: +IGNORE_EXCEPTION_DETAIL
-        Traceback (most recent call last):
-            ...
-        Py4JJavaError: ...
- """ + """Resets the configuration property for the given key.""" self._jconf.unset(key) + def _checkType(self, obj, identifier): + """Assert that an object is of type str.""" + if not isinstance(obj, str) and not isinstance(obj, unicode): + raise TypeError("expected %s '%s' to be a string (was '%s')" % + (identifier, obj, type(obj).__name__)) + def _test(): import os http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 94856c2..417d719 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -127,10 +127,10 @@ class SQLContext(object): >>> sqlContext.getConf("spark.sql.shuffle.partitions") u'200' - >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10") + >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10") u'10' - >>> sqlContext.setConf("spark.sql.shuffle.partitions", "50") - >>> sqlContext.getConf("spark.sql.shuffle.partitions", "10") + >>> sqlContext.setConf("spark.sql.shuffle.partitions", u"50") + >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10") u'50' """ return self.sparkSession.getConf(key, defaultValue) @@ -301,7 +301,7 @@ class SQLContext(object): >>> sqlContext.registerDataFrameAsTable(df, "table1") """ - self.sparkSession.catalog.registerDataFrameAsTable(df, tableName) + self.sparkSession.catalog.registerTable(df, tableName) @since(1.6) def dropTempTable(self, tableName): http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/session.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b3bc896..c245261 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -443,7 +443,7 @@ class SparkSession(object): :return: :class:`DataFrame` - >>> spark.catalog.registerDataFrameAsTable(df, "table1") + >>> spark.catalog.registerTable(df, "table1") >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] @@ -456,7 +456,7 @@ class SparkSession(object): :return: :class:`DataFrame` - >>> spark.catalog.registerDataFrameAsTable(df, "table1") + >>> spark.catalog.registerTable(df, "table1") >>> df2 = spark.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) True http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1d3dc15..ea98206 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6): else: import unittest -from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase @@ -199,7 +199,8 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) - cls.sqlCtx = SQLContext(cls.sc) + cls.sparkSession = SparkSession(cls.sc) + cls.sqlCtx = cls.sparkSession._wrapped cls.testData = [Row(key=i, value=str(i)) for i in range(100)] rdd = 
cls.sc.parallelize(cls.testData, 2) cls.df = rdd.toDF() @@ -1394,6 +1395,200 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(df.schema.simpleString(), "struct<value:int>") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + def test_conf(self): + spark = self.sparkSession + spark.setConf("bogo", "sipeo") + self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo") + spark.setConf("bogo", "ta") + self.assertEqual(spark.conf.get("bogo"), "ta") + self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") + self.assertEqual(spark.conf.get("not.set", "ta"), "ta") + self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set")) + spark.conf.unset("bogo") + self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + + def test_current_database(self): + spark = self.sparkSession + spark.catalog._reset() + self.assertEquals(spark.catalog.currentDatabase(), "default") + spark.sql("CREATE DATABASE some_db") + spark.catalog.setCurrentDatabase("some_db") + self.assertEquals(spark.catalog.currentDatabase(), "some_db") + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.setCurrentDatabase("does_not_exist")) + + def test_list_databases(self): + spark = self.sparkSession + spark.catalog._reset() + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(databases, ["default"]) + spark.sql("CREATE DATABASE some_db") + databases = [db.name for db in spark.catalog.listDatabases()] + self.assertEquals(sorted(databases), ["default", "some_db"]) + + def test_list_tables(self): + from pyspark.sql.catalog import Table + spark = self.sparkSession + spark.catalog._reset() + spark.sql("CREATE DATABASE some_db") + self.assertEquals(spark.catalog.listTables(), []) + self.assertEquals(spark.catalog.listTables("some_db"), []) + spark.createDataFrame([(1, 1)]).registerTempTable("temp_tab") + spark.sql("CREATE TABLE tab1 (name STRING, age INT)") + spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT)") + tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) + tablesDefault = sorted(spark.catalog.listTables("default"), key=lambda t: t.name) + tablesSomeDb = sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name) + self.assertEquals(tables, tablesDefault) + self.assertEquals(len(tables), 2) + self.assertEquals(len(tablesSomeDb), 2) + self.assertEquals(tables[0], Table( + name="tab1", + database="default", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tables[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertEquals(tablesSomeDb[0], Table( + name="tab2", + database="some_db", + description=None, + tableType="MANAGED", + isTemporary=False)) + self.assertEquals(tablesSomeDb[1], Table( + name="temp_tab", + database=None, + description=None, + tableType="TEMPORARY", + isTemporary=True)) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listTables("does_not_exist")) + + def test_list_functions(self): + from pyspark.sql.catalog import Function + spark = self.sparkSession + spark.catalog._reset() + spark.sql("CREATE DATABASE some_db") + functions = dict((f.name, f) for f in spark.catalog.listFunctions()) + functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default")) + self.assertTrue(len(functions) > 200) + self.assertTrue("+" in functions) + self.assertTrue("like" in functions) + self.assertTrue("month" in 
functions) + self.assertTrue("to_unix_timestamp" in functions) + self.assertTrue("current_database" in functions) + self.assertEquals(functions["+"], Function( + name="+", + description=None, + className="org.apache.spark.sql.catalyst.expressions.Add", + isTemporary=True)) + self.assertEquals(functions, functionsDefault) + spark.catalog.registerFunction("temp_func", lambda x: str(x)) + spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'") + spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'") + newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions()) + newFunctionsSomeDb = dict((f.name, f) for f in spark.catalog.listFunctions("some_db")) + self.assertTrue(set(functions).issubset(set(newFunctions))) + self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb))) + self.assertTrue("temp_func" in newFunctions) + self.assertTrue("func1" in newFunctions) + self.assertTrue("func2" not in newFunctions) + self.assertTrue("temp_func" in newFunctionsSomeDb) + self.assertTrue("func1" not in newFunctionsSomeDb) + self.assertTrue("func2" in newFunctionsSomeDb) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listFunctions("does_not_exist")) + + def test_list_columns(self): + from pyspark.sql.catalog import Column + spark = self.sparkSession + spark.catalog._reset() + spark.sql("CREATE DATABASE some_db") + spark.sql("CREATE TABLE tab1 (name STRING, age INT)") + spark.sql("CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT)") + columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name) + columnsDefault = sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name) + self.assertEquals(columns, columnsDefault) + self.assertEquals(len(columns), 2) + self.assertEquals(columns[0], Column( + name="age", + description=None, + dataType="int", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns[1], Column( + name="name", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + columns2 = sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name) + self.assertEquals(len(columns2), 2) + self.assertEquals(columns2[0], Column( + name="nickname", + description=None, + dataType="string", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertEquals(columns2[1], Column( + name="tolerance", + description=None, + dataType="float", + nullable=True, + isPartition=False, + isBucket=False)) + self.assertRaisesRegexp( + AnalysisException, + "tab2", + lambda: spark.catalog.listColumns("tab2")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.listColumns("does_not_exist")) + + def test_cache(self): + spark = self.sparkSession + spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1") + spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2") + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + spark.catalog.cacheTable("tab1") + self.assertTrue(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + spark.catalog.cacheTable("tab2") + spark.catalog.uncacheTable("tab1") + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertTrue(spark.catalog.isCached("tab2")) + spark.catalog.clearCache() + self.assertFalse(spark.catalog.isCached("tab1")) + self.assertFalse(spark.catalog.isCached("tab2")) + self.assertRaisesRegexp( + AnalysisException, + 
"does_not_exist", + lambda: spark.catalog.isCached("does_not_exist")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.cacheTable("does_not_exist")) + self.assertRaisesRegexp( + AnalysisException, + "does_not_exist", + lambda: spark.catalog.uncacheTable("does_not_exist")) + class HiveContextSQLTests(ReusedPySparkTestCase): http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1439d14..08be94e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2308,7 +2308,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def registerTempTable(tableName: String): Unit = { - sparkSession.registerDataFrameAsTable(toDF(), tableName) + sparkSession.registerTable(toDF(), tableName) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index bf97d72..f2e8515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -72,6 +72,15 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * * @since 2.0.0 */ + def get(key: String, default: String): String = { + sqlConf.getConfString(key, default) + } + + /** + * Returns the value of Spark runtime configuration property for the given key. + * + * @since 2.0.0 + */ def getOption(key: String): Option[String] = { try Option(get(key)) catch { case _: NoSuchElementException => None @@ -86,4 +95,12 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { def unset(key: String): Unit = { sqlConf.unsetConf(key) } + + /** + * Returns whether a particular key is set. + */ + protected[sql] def contains(key: String): Boolean = { + sqlConf.contains(key) + } + } http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1f08a61..6dfac3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -600,7 +600,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. 
    */
   private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
-    sparkSession.registerDataFrameAsTable(df, tableName)
+    sparkSession.registerTable(df, tableName)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 2814b70..11c0aaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -577,7 +577,7 @@ class SparkSession private(
    * Registers the given [[DataFrame]] as a temporary table in the catalog.
    * Temporary tables exist only during the lifetime of this instance of [[SparkSession]].
    */
-  protected[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
+  protected[sql] def registerTable(df: DataFrame, tableName: String): Unit = {
     sessionState.catalog.createTempTable(
       sessionState.sqlParser.parseTableIdentifier(tableName).table,
       df.logicalPlan,

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index ec3fada..f05401b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -30,7 +30,7 @@ case class CacheTableCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     plan.foreach { logicalPlan =>
-      sparkSession.registerDataFrameAsTable(Dataset.ofRows(sparkSession, logicalPlan), tableName)
+      sparkSession.registerTable(Dataset.ofRows(sparkSession, logicalPlan), tableName)
     }
     sparkSession.catalog.cacheTable(tableName)

http://git-wip-us.apache.org/repos/asf/spark/blob/d33e3d57/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2bfc895..7de7748 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -755,6 +755,13 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
     }.toSeq
   }
 
+  /**
+   * Return whether a given key is set in this [[SQLConf]].
+   */
+  def contains(key: String): Boolean = {
+    settings.containsKey(key)
+  }
+
   private def setConfWithCheck(key: String, value: String): Unit = {
     if (key.startsWith("spark.") && !key.startsWith("spark.sql.")) {
       logWarning(s"Attempt to set non-Spark SQL config in SQLConf: key = $key, value = $value")
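
For readers skimming the diff: after this patch, `RuntimeConfig.get` in Python takes an optional default and subsumes the removed `getOption`. A minimal sketch of the resulting behavior (assuming a running `SparkSession` bound to `spark`, as in the doctests above):

```python
# Sketch of the post-patch RuntimeConfig API (assumes an existing
# SparkSession named `spark`, matching the doctest environment).
spark.conf.set("spark.sql.shuffle.partitions", "50")

# One-argument get() still raises if the key is unset:
spark.conf.get("spark.sql.shuffle.partitions")      # u'50'

# The new two-argument form returns the default instead of raising,
# which is why the separate getOption() could be removed:
spark.conf.get("not.a.real.key", "fallback")        # u'fallback'

# _checkType() rejects non-string keys and defaults up front with a
# TypeError, rather than a less readable Py4J error from the JVM side:
spark.conf.get("spark.sql.shuffle.partitions", 10)  # raises TypeError
```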
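
The rename from `registerDataFrameAsTable` to `registerTable` is mechanical, but it touches the Python catalog, the Scala session, and several call sites. A short sketch of the renamed entry point together with the catalog calls the new unit tests exercise (same `spark` assumption as above):

```python
# Sketch of the renamed catalog entry point (assumes a SparkSession
# named `spark`). registerDataFrameAsTable(df, name) is now
# registerTable(df, name); DataFrame.registerTempTable still delegates
# to it on the Scala side, per the Dataset.scala hunk above.
df = spark.createDataFrame([(2, 1), (3, 1)])
spark.catalog.registerTable(df, "my_cool_table")

spark.catalog.currentDatabase()                  # u'default'
[t.name for t in spark.catalog.listTables()]     # [u'my_cool_table']

# The cache helpers are unchanged; only their doctests moved to tests.py:
spark.catalog.cacheTable("my_cool_table")
spark.catalog.isCached("my_cool_table")          # True
spark.catalog.clearCache()
spark.catalog.isCached("my_cool_table")          # False
```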
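
Most of the line count comes from moving scenario-style doctests out of catalog.py and conf.py into proper unit tests. As an illustration of that migration pattern only (this standalone test class is hypothetical, not part of the patch), a former doctest scenario becomes:

```python
# Hypothetical standalone version of the doctest-to-unittest migration
# pattern applied throughout tests.py in this patch (not part of the
# patch itself). Doctest scenarios become assertions against a shared
# SparkSession created once per test class.
import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException


class CatalogSmokeTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession(SparkContext("local[2]", "catalog-smoke"))

    @classmethod
    def tearDownClass(cls):
        cls.spark.sparkContext.stop()

    def test_current_database(self):
        # Was: >>> spark.catalog.currentDatabase() / u'default'
        self.assertEqual(self.spark.catalog.currentDatabase(), "default")
        # Error cases assert on the message instead of relying on
        # doctest: +IGNORE_EXCEPTION_DETAIL.
        self.assertRaisesRegexp(
            AnalysisException,
            "does_not_exist",
            lambda: self.spark.catalog.setCurrentDatabase("does_not_exist"))


if __name__ == "__main__":
    unittest.main()
```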