This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cbbc06147ec [SPARK-43992][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listFunctions
cbbc06147ec is described below

commit cbbc06147eca2f81853554db4f99b4f2f5ff8dd1
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Jun 8 17:34:25 2023 +0800

    [SPARK-43992][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listFunctions

    ### What changes were proposed in this pull request?
    Currently, the syntax `SHOW FUNCTIONS LIKE pattern` supports an optional pattern, so as to
    filter out the expected functions. But `Catalog.listFunctions` is missing this capability in
    both the Catalog API and the Connect Catalog API. In fact, the optional pattern is very useful.

    ### Why are the changes needed?
    This PR wants to add the optional pattern for `Catalog.listFunctions`.

    ### Does this PR introduce _any_ user-facing change?
    'No'. New feature.

    ### How was this patch tested?
    New test cases.

    Closes #41497 from beliefer/SPARK-43992.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../org/apache/spark/sql/catalog/Catalog.scala     | 10 +++
 .../apache/spark/sql/internal/CatalogImpl.scala    | 14 ++++
 .../scala/org/apache/spark/sql/CatalogSuite.scala  |  7 ++
 .../src/main/protobuf/spark/connect/catalog.proto  |  2 +
 .../sql/connect/planner/SparkConnectPlanner.scala  | 11 ++-
 project/MimaExcludes.scala                         | 12 +--
 python/pyspark/sql/catalog.py                      | 19 ++++-
 python/pyspark/sql/connect/catalog.py              |  6 +-
 python/pyspark/sql/connect/plan.py                 |  5 +-
 python/pyspark/sql/connect/proto/catalog_pb2.py    | 96 +++++++++++-----------
 python/pyspark/sql/connect/proto/catalog_pb2.pyi   | 33 +++++++-
 python/pyspark/sql/tests/test_catalog.py           | 19 +++++
 .../org/apache/spark/sql/catalog/Catalog.scala     | 10 +++
 .../apache/spark/sql/internal/CatalogImpl.scala    | 21 ++++-
 .../apache/spark/sql/internal/CatalogSuite.scala   | 13 +++
 15 files changed, 215 insertions(+), 63 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 0ac704e68e6..268f162cbfa 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -102,6 +102,16 @@ abstract class Catalog {
   @throws[AnalysisException]("database does not exist")
   def listFunctions(dbName: String): Dataset[Function]
 
+  /**
+   * Returns a list of functions registered in the specified database (namespace) which name match
+   * the specify pattern (the name can be qualified with catalog). This includes all built-in and
+   * temporary functions.
+   *
+   * @since 3.5.0
+   */
+  @throws[AnalysisException]("database does not exist")
+  def listFunctions(dbName: String, pattern: String): Dataset[Function]
+
   /**
    * Returns a list of columns for the given table/view or temporary view.
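As a quick illustration of the new API surface, here is a minimal PySpark usage sketch, assuming a Spark build (3.5.0 or later) that contains this commit; the database name "default" and the patterns are only examples:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# With a pattern, only functions whose names match are returned
# (e.g. to_date, to_timestamp, to_unix_timestamp).
matching = spark.catalog.listFunctions("default", pattern="to*")
print(sorted(f.name for f in matching))

# A pattern that matches nothing returns an empty list.
assert spark.catalog.listFunctions(pattern="*not_existing_func*") == []

# Without a pattern the existing behaviour is unchanged: all functions are listed.
print(len(spark.catalog.listFunctions()))
```

The same two-argument overload is added to the Scala `Catalog` and to the Connect client, as shown in the hunks above and below.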
* diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 95a3332cfc2..f287568d629 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -137,6 +137,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } } + /** + * Returns a list of functions registered in the specified database (namespace) which name match + * the specify pattern (the name can be qualified with catalog). This includes all built-in and + * temporary functions. + * + * @since 3.5.0 + */ + @throws[AnalysisException]("database does not exist") + def listFunctions(dbName: String, pattern: String): Dataset[Function] = { + sparkSession.newDataset(CatalogImpl.functionEncoder) { builder => + builder.getCatalogBuilder.getListFunctionsBuilder.setDbName(dbName).setPattern(pattern) + } + } + /** * Returns a list of columns for the given table/view or temporary view. * diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 671f6ac4051..04b3f4e639a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -211,6 +211,13 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { spark.catalog.getFunction(notExistsFunction) }.getMessage assert(message.contains("UNRESOLVED_ROUTINE")) + + val functionsWithPattern1 = spark.catalog.listFunctions(dbName, "to*").collect() + assert(functionsWithPattern1.nonEmpty) + assert(functionsWithPattern1.exists(f => f.name == "to_date")) + val functionsWithPattern2 = + spark.catalog.listFunctions(dbName, "*not_existing_func*").collect() + assert(functionsWithPattern2.isEmpty) } test("recoverPartitions") { diff --git a/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto b/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto index 97b905da7c3..5b1b90b0087 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/catalog.proto @@ -85,6 +85,8 @@ message ListTables { message ListFunctions { // (Optional) optional string db_name = 1; + // (Optional) The pattern that the function name needs to match + optional string pattern = 2; } // See `spark.catalog.listColumns` diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7e642b0bdf6..86d65eb47fb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2721,7 +2721,16 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformListFunctions(getListFunctions: proto.ListFunctions): LogicalPlan = { if (getListFunctions.hasDbName) { - session.catalog.listFunctions(getListFunctions.getDbName).logicalPlan + if (getListFunctions.hasPattern) { + session.catalog + .listFunctions(getListFunctions.getDbName, 
getListFunctions.getPattern) + .logicalPlan + } else { + session.catalog.listFunctions(getListFunctions.getDbName).logicalPlan + } + } else if (getListFunctions.hasPattern) { + val currentDatabase = session.catalog.currentDatabase + session.catalog.listFunctions(currentDatabase, getListFunctions.getPattern).logicalPlan } else { session.catalog.listFunctions().logicalPlan } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f22994ed75e..bba20534f44 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,16 +45,18 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.HadoopFSUtils$SerializableFileStatus$"), // [SPARK-43792][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listCatalogs ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listCatalogs"), + // [SPARK-43881][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listDatabases + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases"), + // [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listTables"), + // [SPARK-43992][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listFunctions + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listFunctions"), // [SPARK-43919][SQL] Extract JSON functionality out of Row ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Row.json"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Row.prettyJson"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.json"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.prettyJson"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue"), - // [SPARK-43881][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listDatabases - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listDatabases"), - // [SPARK-43961][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listTables - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listTables") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue") ) // Defulat exclude rules diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 9650affc68a..2c6ed28461f 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -439,7 +439,9 @@ class Catalog: isTemporary=jtable.isTemporary(), ) - def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: + def listFunctions( + self, dbName: Optional[str] = None, pattern: Optional[str] = None + ) -> List[Function]: """ Returns a list of functions registered in the specified database. @@ -450,6 +452,11 @@ class Catalog: dbName : str name of the database to list the functions. ``dbName`` can be qualified with catalog name. + pattern : str + The pattern that the function name needs to match. + + .. versionchanged: 3.5.0 + Adds ``pattern`` argument. Returns ------- @@ -465,10 +472,20 @@ class Catalog: -------- >>> spark.catalog.listFunctions() [Function(name=... 
+ + >>> spark.catalog.listFunctions(pattern="to_*") + [Function(name=... + + >>> spark.catalog.listFunctions(pattern="*not_existing_func*") + [] """ if dbName is None: dbName = self.currentDatabase() iter = self._jcatalog.listFunctions(dbName).toLocalIterator() + if pattern is None: + iter = self._jcatalog.listFunctions(dbName).toLocalIterator() + else: + iter = self._jcatalog.listFunctions(dbName, pattern).toLocalIterator() functions = [] while iter.hasNext(): jfunction = iter.next() diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py index 6766060a7b9..2a54a0d727a 100644 --- a/python/pyspark/sql/connect/catalog.py +++ b/python/pyspark/sql/connect/catalog.py @@ -151,8 +151,10 @@ class Catalog: getTable.__doc__ = PySparkCatalog.getTable.__doc__ - def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: - pdf = self._execute_and_fetch(plan.ListFunctions(db_name=dbName)) + def listFunctions( + self, dbName: Optional[str] = None, pattern: Optional[str] = None + ) -> List[Function]: + pdf = self._execute_and_fetch(plan.ListFunctions(db_name=dbName, pattern=pattern)) return [ Function( name=row.iloc[0], diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 95d7af90f65..fc8b37b102c 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1663,14 +1663,17 @@ class ListTables(LogicalPlan): class ListFunctions(LogicalPlan): - def __init__(self, db_name: Optional[str] = None) -> None: + def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None: super().__init__(None) self._db_name = db_name + self._pattern = pattern def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = proto.Relation(catalog=proto.Catalog(list_functions=proto.ListFunctions())) if self._db_name is not None: plan.catalog.list_functions.db_name = self._db_name + if self._pattern is not None: + plan.catalog.list_functions.pattern = self._pattern return plan diff --git a/python/pyspark/sql/connect/proto/catalog_pb2.py b/python/pyspark/sql/connect/proto/catalog_pb2.py index 920ffa32444..1680eca7314 100644 --- a/python/pyspark/sql/connect/proto/catalog_pb2.py +++ b/python/pyspark/sql/connect/proto/catalog_pb2.py @@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1bspark/connect/catalog.proto\x12\rspark.connect\x1a\x1aspark/connect/common.proto\x1a\x19spark/connect/types.proto"\xc6\x0e\n\x07\x43\x61talog\x12K\n\x10\x63urrent_database\x18\x01 \x01(\x0b\x32\x1e.spark.connect.CurrentDatabaseH\x00R\x0f\x63urrentDatabase\x12U\n\x14set_current_database\x18\x02 \x01(\x0b\x32!.spark.connect.SetCurrentDatabaseH\x00R\x12setCurrentDatabase\x12\x45\n\x0elist_databases\x18\x03 \x01(\x0b\x32\x1c.spark.connect.ListDatabasesH\x00R\rlistDatabases\x12<\n [...] + b'\n\x1bspark/connect/catalog.proto\x12\rspark.connect\x1a\x1aspark/connect/common.proto\x1a\x19spark/connect/types.proto"\xc6\x0e\n\x07\x43\x61talog\x12K\n\x10\x63urrent_database\x18\x01 \x01(\x0b\x32\x1e.spark.connect.CurrentDatabaseH\x00R\x0f\x63urrentDatabase\x12U\n\x14set_current_database\x18\x02 \x01(\x0b\x32!.spark.connect.SetCurrentDatabaseH\x00R\x12setCurrentDatabase\x12\x45\n\x0elist_databases\x18\x03 \x01(\x0b\x32\x1c.spark.connect.ListDatabasesH\x00R\rlistDatabases\x12<\n [...] 
) @@ -405,51 +405,51 @@ if _descriptor._USE_C_DESCRIPTORS == False: _LISTTABLES._serialized_start = 2092 _LISTTABLES._serialized_end = 2189 _LISTFUNCTIONS._serialized_start = 2191 - _LISTFUNCTIONS._serialized_end = 2248 - _LISTCOLUMNS._serialized_start = 2250 - _LISTCOLUMNS._serialized_end = 2336 - _GETDATABASE._serialized_start = 2338 - _GETDATABASE._serialized_end = 2376 - _GETTABLE._serialized_start = 2378 - _GETTABLE._serialized_end = 2461 - _GETFUNCTION._serialized_start = 2463 - _GETFUNCTION._serialized_end = 2555 - _DATABASEEXISTS._serialized_start = 2557 - _DATABASEEXISTS._serialized_end = 2598 - _TABLEEXISTS._serialized_start = 2600 - _TABLEEXISTS._serialized_end = 2686 - _FUNCTIONEXISTS._serialized_start = 2688 - _FUNCTIONEXISTS._serialized_end = 2783 - _CREATEEXTERNALTABLE._serialized_start = 2786 - _CREATEEXTERNALTABLE._serialized_end = 3112 - _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_start = 3023 - _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_end = 3081 - _CREATETABLE._serialized_start = 3115 - _CREATETABLE._serialized_end = 3480 - _CREATETABLE_OPTIONSENTRY._serialized_start = 3023 - _CREATETABLE_OPTIONSENTRY._serialized_end = 3081 - _DROPTEMPVIEW._serialized_start = 3482 - _DROPTEMPVIEW._serialized_end = 3525 - _DROPGLOBALTEMPVIEW._serialized_start = 3527 - _DROPGLOBALTEMPVIEW._serialized_end = 3576 - _RECOVERPARTITIONS._serialized_start = 3578 - _RECOVERPARTITIONS._serialized_end = 3628 - _ISCACHED._serialized_start = 3630 - _ISCACHED._serialized_end = 3671 - _CACHETABLE._serialized_start = 3674 - _CACHETABLE._serialized_end = 3806 - _UNCACHETABLE._serialized_start = 3808 - _UNCACHETABLE._serialized_end = 3853 - _CLEARCACHE._serialized_start = 3855 - _CLEARCACHE._serialized_end = 3867 - _REFRESHTABLE._serialized_start = 3869 - _REFRESHTABLE._serialized_end = 3914 - _REFRESHBYPATH._serialized_start = 3916 - _REFRESHBYPATH._serialized_end = 3951 - _CURRENTCATALOG._serialized_start = 3953 - _CURRENTCATALOG._serialized_end = 3969 - _SETCURRENTCATALOG._serialized_start = 3971 - _SETCURRENTCATALOG._serialized_end = 4025 - _LISTCATALOGS._serialized_start = 4027 - _LISTCATALOGS._serialized_end = 4084 + _LISTFUNCTIONS._serialized_end = 2291 + _LISTCOLUMNS._serialized_start = 2293 + _LISTCOLUMNS._serialized_end = 2379 + _GETDATABASE._serialized_start = 2381 + _GETDATABASE._serialized_end = 2419 + _GETTABLE._serialized_start = 2421 + _GETTABLE._serialized_end = 2504 + _GETFUNCTION._serialized_start = 2506 + _GETFUNCTION._serialized_end = 2598 + _DATABASEEXISTS._serialized_start = 2600 + _DATABASEEXISTS._serialized_end = 2641 + _TABLEEXISTS._serialized_start = 2643 + _TABLEEXISTS._serialized_end = 2729 + _FUNCTIONEXISTS._serialized_start = 2731 + _FUNCTIONEXISTS._serialized_end = 2826 + _CREATEEXTERNALTABLE._serialized_start = 2829 + _CREATEEXTERNALTABLE._serialized_end = 3155 + _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_start = 3066 + _CREATEEXTERNALTABLE_OPTIONSENTRY._serialized_end = 3124 + _CREATETABLE._serialized_start = 3158 + _CREATETABLE._serialized_end = 3523 + _CREATETABLE_OPTIONSENTRY._serialized_start = 3066 + _CREATETABLE_OPTIONSENTRY._serialized_end = 3124 + _DROPTEMPVIEW._serialized_start = 3525 + _DROPTEMPVIEW._serialized_end = 3568 + _DROPGLOBALTEMPVIEW._serialized_start = 3570 + _DROPGLOBALTEMPVIEW._serialized_end = 3619 + _RECOVERPARTITIONS._serialized_start = 3621 + _RECOVERPARTITIONS._serialized_end = 3671 + _ISCACHED._serialized_start = 3673 + _ISCACHED._serialized_end = 3714 + _CACHETABLE._serialized_start = 3717 + _CACHETABLE._serialized_end = 3849 + 
_UNCACHETABLE._serialized_start = 3851 + _UNCACHETABLE._serialized_end = 3896 + _CLEARCACHE._serialized_start = 3898 + _CLEARCACHE._serialized_end = 3910 + _REFRESHTABLE._serialized_start = 3912 + _REFRESHTABLE._serialized_end = 3957 + _REFRESHBYPATH._serialized_start = 3959 + _REFRESHBYPATH._serialized_end = 3994 + _CURRENTCATALOG._serialized_start = 3996 + _CURRENTCATALOG._serialized_end = 4012 + _SETCURRENTCATALOG._serialized_start = 4014 + _SETCURRENTCATALOG._serialized_end = 4068 + _LISTCATALOGS._serialized_start = 4070 + _LISTCATALOGS._serialized_end = 4127 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/catalog_pb2.pyi b/python/pyspark/sql/connect/proto/catalog_pb2.pyi index 77a924d6d51..3d14961329b 100644 --- a/python/pyspark/sql/connect/proto/catalog_pb2.pyi +++ b/python/pyspark/sql/connect/proto/catalog_pb2.pyi @@ -427,22 +427,51 @@ class ListFunctions(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DB_NAME_FIELD_NUMBER: builtins.int + PATTERN_FIELD_NUMBER: builtins.int db_name: builtins.str """(Optional)""" + pattern: builtins.str + """(Optional) The pattern that the function name needs to match""" def __init__( self, *, db_name: builtins.str | None = ..., + pattern: builtins.str | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["_db_name", b"_db_name", "db_name", b"db_name"] + self, + field_name: typing_extensions.Literal[ + "_db_name", + b"_db_name", + "_pattern", + b"_pattern", + "db_name", + b"db_name", + "pattern", + b"pattern", + ], ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["_db_name", b"_db_name", "db_name", b"db_name"] + self, + field_name: typing_extensions.Literal[ + "_db_name", + b"_db_name", + "_pattern", + b"_pattern", + "db_name", + b"db_name", + "pattern", + b"pattern", + ], ) -> None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["_db_name", b"_db_name"] ) -> typing_extensions.Literal["db_name"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_pattern", b"_pattern"] + ) -> typing_extensions.Literal["pattern"] | None: ... 
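A note on the `_pattern` entries in the generated stub above: because `pattern` is declared as a proto3 `optional` field, the generated message tracks field presence through a synthetic `_pattern` oneof, which is what lets the server-side planner's `hasPattern` check distinguish "no pattern supplied" from an empty string. A small illustrative sketch, assuming a pyspark installation that ships these Connect protos:

```python
from pyspark.sql.connect import proto

# Build the request message without a pattern: presence is tracked explicitly.
msg = proto.ListFunctions(db_name="default")
print(msg.HasField("pattern"))     # False: the optional field was never set
print(msg.WhichOneof("_pattern"))  # None

# Setting the field flips the presence bit.
msg.pattern = "to*"
print(msg.HasField("pattern"))     # True
print(msg.WhichOneof("_pattern"))  # 'pattern'
```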
global___ListFunctions = ListFunctions diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 716f0638866..cafffdc9ae8 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -257,6 +257,25 @@ class CatalogTestsMixin: self.assertTrue(functions["+"].isTemporary) self.assertEqual(functions, functionsDefault) + functionsWithPattern = dict( + (f.name, f) for f in spark.catalog.listFunctions(pattern="to*") + ) + functionsDefaultWithPattern = dict( + (f.name, f) for f in spark.catalog.listFunctions("default", "to*") + ) + self.assertTrue(len(functionsWithPattern) > 10) + self.assertFalse("+" in functionsWithPattern) + self.assertFalse("like" in functionsWithPattern) + self.assertFalse("month" in functionsWithPattern) + self.assertTrue("to_date" in functionsWithPattern) + self.assertTrue("to_timestamp" in functionsWithPattern) + self.assertTrue("to_unix_timestamp" in functionsWithPattern) + self.assertEqual(functionsWithPattern, functionsDefaultWithPattern) + functionsWithPattern = dict( + (f.name, f) for f in spark.catalog.listFunctions(pattern="*not_existing_func*") + ) + self.assertTrue(len(functionsWithPattern) == 0) + with self.function("func1", "some_db.func2"): try: spark.udf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index b8cb97e1650..93ff3059f62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -107,6 +107,16 @@ abstract class Catalog { @throws[AnalysisException]("database does not exist") def listFunctions(dbName: String): Dataset[Function] + /** + * Returns a list of functions registered in the specified database (namespace) + * which name match the specify pattern (the name can be qualified with catalog). + * This includes all built-in and temporary functions. + * + * @since 3.5.0 + */ + @throws[AnalysisException]("database does not exist") + def listFunctions(dbName: String, pattern: String): Dataset[Function] + /** * Returns a list of columns for the given table/view or temporary view. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 3c61102699e..55136442b1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -240,6 +240,22 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("database does not exist") override def listFunctions(dbName: String): Dataset[Function] = { + listFunctionsInternal(dbName, None) + } + + /** + * Returns a list of functions registered in the specified database (namespace) + * which name match the specify pattern (the name can be qualified with catalog). + * This includes all built-in and temporary functions. 
+ * + * @since 3.5.0 + */ + @throws[AnalysisException]("database does not exist") + def listFunctions(dbName: String, pattern: String): Dataset[Function] = { + listFunctionsInternal(dbName, Some(pattern)) + } + + private def listFunctionsInternal(dbName: String, pattern: Option[String]): Dataset[Function] = { val namespace = resolveNamespace(dbName) val functions = collection.mutable.ArrayBuilder.make[Function] @@ -249,7 +265,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // List built-in functions. We don't need to specify the namespace here as SHOW FUNCTIONS with // only system scope does not need to know the catalog and namespace. - val plan0 = ShowFunctions(UnresolvedNamespace(Nil), userScope = false, systemScope = true, None) + val plan0 = ShowFunctions(UnresolvedNamespace(Nil), false, true, pattern) sparkSession.sessionState.executePlan(plan0).toRdd.collect().foreach { row => // Built-in functions do not belong to any catalog or namespace. We can only look it up with // a single part name. @@ -258,8 +274,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } // List user functions. - val plan1 = ShowFunctions(UnresolvedNamespace(namespace), - userScope = true, systemScope = false, None) + val plan1 = ShowFunctions(UnresolvedNamespace(namespace), true, false, pattern) sparkSession.sessionState.executePlan(plan1).toRdd.collect().foreach { row => functions += makeFunction(parseIdent(row.getString(0))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 5ef8e35da9e..0f88baeb689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -259,12 +259,25 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf assert(funcNames1.contains("my_func1")) assert(funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) + val funcNamesWithPattern1 = + spark.catalog.listFunctions("default", "my_func*").collect().map(_.name).toSet + assert(funcNamesWithPattern1.contains("my_func1")) + assert(funcNamesWithPattern1.contains("my_func2")) + assert(!funcNamesWithPattern1.contains("my_temp_func")) dropFunction("my_func1") dropTempFunction("my_temp_func") val funcNames2 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(!funcNames2.contains("my_temp_func")) + val funcNamesWithPattern2 = + spark.catalog.listFunctions("default", "my_func*").collect().map(_.name).toSet + assert(!funcNamesWithPattern2.contains("my_func1")) + assert(funcNamesWithPattern2.contains("my_func2")) + assert(!funcNamesWithPattern2.contains("my_temp_func")) + val funcNamesWithPattern3 = + spark.catalog.listFunctions("default", "*not_existing_func*").collect().map(_.name).toSet + assert(funcNamesWithPattern3.isEmpty) } test("SPARK-39828: Catalog.listFunctions() should respect currentCatalog") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org