This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 73947a5f Add support for PostgreSQL `UNLISTEN` syntax and Add support for Postgres `LOAD extension` expr (#1531)
73947a5f is described below

commit 73947a5f021128cfccd47293ca65aa5c4e83f598
Author: wugeer <[email protected]>
AuthorDate: Wed Nov 20 05:14:28 2024 +0800

    Add support for PostgreSQL `UNLISTEN` syntax and Add support for Postgres `LOAD extension` expr (#1531)
    
    Co-authored-by: Ifeanyi Ubah <[email protected]>
---
 src/ast/mod.rs            | 11 +++++++
 src/dialect/mod.rs        |  9 ++----
 src/dialect/postgresql.rs | 12 +++++---
 src/keywords.rs           |  1 +
 src/parser/mod.rs         | 22 ++++++++++++--
 tests/sqlparser_common.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++---
 tests/sqlparser_duckdb.rs | 14 ---------
 7 files changed, 113 insertions(+), 33 deletions(-)

diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 89e70bdd..9185c9df 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -3340,6 +3340,13 @@ pub enum Statement {
     /// See Postgres <https://www.postgresql.org/docs/current/sql-listen.html>
     LISTEN { channel: Ident },
     /// ```sql
+    /// UNLISTEN
+    /// ```
+    /// stop listening for a notification
+    ///
+    /// See Postgres <https://www.postgresql.org/docs/current/sql-unlisten.html>
+    UNLISTEN { channel: Ident },
+    /// ```sql
     /// NOTIFY channel [ , payload ]
     /// ```
     /// send a notification event together with an optional “payload” string to channel
@@ -4948,6 +4955,10 @@ impl fmt::Display for Statement {
                 write!(f, "LISTEN {channel}")?;
                 Ok(())
             }
+            Statement::UNLISTEN { channel } => {
+                write!(f, "UNLISTEN {channel}")?;
+                Ok(())
+            }
             Statement::NOTIFY { channel, payload } => {
                 write!(f, "NOTIFY {channel}")?;
                 if let Some(payload) = payload {
diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index 39ea98c6..985cad74 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -633,13 +633,8 @@ pub trait Dialect: Debug + Any {
         false
     }
 
-    /// Returns true if the dialect supports the `LISTEN` statement
-    fn supports_listen(&self) -> bool {
-        false
-    }
-
-    /// Returns true if the dialect supports the `NOTIFY` statement
-    fn supports_notify(&self) -> bool {
+    /// Returns true if the dialect supports the `LISTEN`, `UNLISTEN` and `NOTIFY` statements
+    fn supports_listen_notify(&self) -> bool {
         false
     }
 
diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs
index 5af1ab85..559586e3 100644
--- a/src/dialect/postgresql.rs
+++ b/src/dialect/postgresql.rs
@@ -191,12 +191,9 @@ impl Dialect for PostgreSqlDialect {
     }
 
     /// see <https://www.postgresql.org/docs/current/sql-listen.html>
-    fn supports_listen(&self) -> bool {
-        true
-    }
-
+    /// see <https://www.postgresql.org/docs/current/sql-unlisten.html>
     /// see <https://www.postgresql.org/docs/current/sql-notify.html>
-    fn supports_notify(&self) -> bool {
+    fn supports_listen_notify(&self) -> bool {
         true
     }
 
@@ -209,6 +206,11 @@ impl Dialect for PostgreSqlDialect {
     fn supports_comment_on(&self) -> bool {
         true
     }
+
+    /// See <https://www.postgresql.org/docs/current/sql-load.html>
+    fn supports_load_extension(&self) -> bool {
+        true
+    }
 }
 
 pub fn parse_create(parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
diff --git a/src/keywords.rs b/src/keywords.rs
index 29115a0d..fc2a2927 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -799,6 +799,7 @@ define_keywords!(
     UNION,
     UNIQUE,
     UNKNOWN,
+    UNLISTEN,
     UNLOAD,
     UNLOCK,
     UNLOGGED,
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 35ad9580..35c763e9 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -532,10 +532,11 @@ impl<'a> Parser<'a> {
                 Keyword::EXECUTE | Keyword::EXEC => self.parse_execute(),
                 Keyword::PREPARE => self.parse_prepare(),
                 Keyword::MERGE => self.parse_merge(),
-                // `LISTEN` and `NOTIFY` are Postgres-specific
+                // `LISTEN`, `UNLISTEN` and `NOTIFY` are Postgres-specific
                 // syntaxes. They are used for Postgres statement.
-                Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(),
-                Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(),
+                Keyword::LISTEN if self.dialect.supports_listen_notify() => self.parse_listen(),
+                Keyword::UNLISTEN if self.dialect.supports_listen_notify() => self.parse_unlisten(),
+                Keyword::NOTIFY if self.dialect.supports_listen_notify() => self.parse_notify(),
                 // `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html
                 Keyword::PRAGMA => self.parse_pragma(),
                 Keyword::UNLOAD => self.parse_unload(),
@@ -999,6 +1000,21 @@ impl<'a> Parser<'a> {
         Ok(Statement::LISTEN { channel })
     }
 
+    pub fn parse_unlisten(&mut self) -> Result<Statement, ParserError> {
+        let channel = if self.consume_token(&Token::Mul) {
+            Ident::new(Expr::Wildcard.to_string())
+        } else {
+            match self.parse_identifier(false) {
+                Ok(expr) => expr,
+                _ => {
+                    self.prev_token();
+                    return self.expected("wildcard or identifier", self.peek_token());
+                }
+            }
+        };
+        Ok(Statement::UNLISTEN { channel })
+    }
+
     pub fn parse_notify(&mut self) -> Result<Statement, ParserError> {
         let channel = self.parse_identifier(false)?;
         let payload = if self.consume_token(&Token::Comma) {
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index ecdca6b1..3d9ba5da 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -11595,7 +11595,7 @@ fn test_show_dbs_schemas_tables_views() {
 
 #[test]
 fn parse_listen_channel() {
-    let dialects = all_dialects_where(|d| d.supports_listen());
+    let dialects = all_dialects_where(|d| d.supports_listen_notify());
 
     match dialects.verified_stmt("LISTEN test1") {
         Statement::LISTEN { channel } => {
@@ -11609,7 +11609,7 @@ fn parse_listen_channel() {
         ParserError::ParserError("Expected: identifier, found: *".to_string())
     );
 
-    let dialects = all_dialects_where(|d| !d.supports_listen());
+    let dialects = all_dialects_where(|d| !d.supports_listen_notify());
 
     assert_eq!(
         dialects.parse_sql_statements("LISTEN test1").unwrap_err(),
@@ -11617,9 +11617,40 @@ fn parse_listen_channel() {
     );
 }
 
+#[test]
+fn parse_unlisten_channel() {
+    let dialects = all_dialects_where(|d| d.supports_listen_notify());
+
+    match dialects.verified_stmt("UNLISTEN test1") {
+        Statement::UNLISTEN { channel } => {
+            assert_eq!(Ident::new("test1"), channel);
+        }
+        _ => unreachable!(),
+    };
+
+    match dialects.verified_stmt("UNLISTEN *") {
+        Statement::UNLISTEN { channel } => {
+            assert_eq!(Ident::new("*"), channel);
+        }
+        _ => unreachable!(),
+    };
+
+    assert_eq!(
+        dialects.parse_sql_statements("UNLISTEN +").unwrap_err(),
+        ParserError::ParserError("Expected: wildcard or identifier, found: +".to_string())
+    );
+
+    let dialects = all_dialects_where(|d| !d.supports_listen_notify());
+
+    assert_eq!(
+        dialects.parse_sql_statements("UNLISTEN test1").unwrap_err(),
+        ParserError::ParserError("Expected: an SQL statement, found: UNLISTEN".to_string())
+    );
+}
+
 #[test]
 fn parse_notify_channel() {
-    let dialects = all_dialects_where(|d| d.supports_notify());
+    let dialects = all_dialects_where(|d| d.supports_listen_notify());
 
     match dialects.verified_stmt("NOTIFY test1") {
         Statement::NOTIFY { channel, payload } => {
@@ -11655,7 +11686,7 @@ fn parse_notify_channel() {
         "NOTIFY test1",
         "NOTIFY test1, 'this is a test notification'",
     ];
-    let dialects = all_dialects_where(|d| !d.supports_notify());
+    let dialects = all_dialects_where(|d| !d.supports_listen_notify());
 
     for &sql in &sql_statements {
         assert_eq!(
@@ -11864,6 +11895,44 @@ fn parse_load_data() {
     );
 }
 
+#[test]
+fn test_load_extension() {
+    let dialects = all_dialects_where(|d| d.supports_load_extension());
+    let not_supports_load_extension_dialects = all_dialects_where(|d| !d.supports_load_extension());
+    let sql = "LOAD my_extension";
+
+    match dialects.verified_stmt(sql) {
+        Statement::Load { extension_name } => {
+            assert_eq!(Ident::new("my_extension"), extension_name);
+        }
+        _ => unreachable!(),
+    };
+
+    assert_eq!(
+        not_supports_load_extension_dialects
+            .parse_sql_statements(sql)
+            .unwrap_err(),
+        ParserError::ParserError(
+            "Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string()
+        )
+    );
+
+    let sql = "LOAD 'filename'";
+
+    match dialects.verified_stmt(sql) {
+        Statement::Load { extension_name } => {
+            assert_eq!(
+                Ident {
+                    value: "filename".to_string(),
+                    quote_style: Some('\'')
+                },
+                extension_name
+            );
+        }
+        _ => unreachable!(),
+    };
+}
+
 #[test]
 fn test_select_top() {
     let dialects = all_dialects_where(|d| d.supports_top_before_distinct());
diff --git a/tests/sqlparser_duckdb.rs b/tests/sqlparser_duckdb.rs
index d68f3771..a2db5c28 100644
--- a/tests/sqlparser_duckdb.rs
+++ b/tests/sqlparser_duckdb.rs
@@ -359,20 +359,6 @@ fn test_duckdb_install() {
     );
 }
 
-#[test]
-fn test_duckdb_load_extension() {
-    let stmt = duckdb().verified_stmt("LOAD my_extension");
-    assert_eq!(
-        Statement::Load {
-            extension_name: Ident {
-                value: "my_extension".to_string(),
-                quote_style: None
-            }
-        },
-        stmt
-    );
-}
-
 #[test]
 fn test_duckdb_struct_literal() {
     //struct literal syntax https://duckdb.org/docs/sql/data_types/struct#creating-structs


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to