This is an automated email from the ASF dual-hosted git repository.

iffyio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 05d7ffb1 Handle optional datatypes properly in `CREATE FUNCTION` statements (#1826)
05d7ffb1 is described below

commit 05d7ffb1d5ef6e4c4852a200e0aa08fec224aa3c
Author: Luca Cappelletti <[email protected]>
AuthorDate: Wed May 21 05:49:28 2025 +0200

    Handle optional datatypes properly in `CREATE FUNCTION` statements (#1826)
    
    Co-authored-by: Ifeanyi Ubah <[email protected]>
---
 src/parser/mod.rs           |  19 ++--
 tests/sqlparser_postgres.rs | 211 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 225 insertions(+), 5 deletions(-)

diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 992d19c4..f6a45ada 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -5273,12 +5273,21 @@ impl<'a> Parser<'a> {
         // parse: [ argname ] argtype
         let mut name = None;
         let mut data_type = self.parse_data_type()?;
-        if let DataType::Custom(n, _) = &data_type {
-            // the first token is actually a name
-            match n.0[0].clone() {
-                ObjectNamePart::Identifier(ident) => name = Some(ident),
+
+        // To check whether the first token is a name or a type, we need to
+        // peek the next token, which if it is another type keyword, then the
+        // first token is a name and not a type in itself.
+        let data_type_idx = self.get_current_index();
+        if let Some(next_data_type) = self.maybe_parse(|parser| 
parser.parse_data_type())? {
+            let token = self.token_at(data_type_idx);
+
+            // We ensure that the token is a `Word` token, and not other 
special tokens.
+            if !matches!(token.token, Token::Word(_)) {
+                return self.expected("a name or type", token.clone());
             }
-            data_type = self.parse_data_type()?;
+
+            name = Some(Ident::new(token.to_string()));
+            data_type = next_data_type;
         }
 
         let default_expr = if self.parse_keyword(Keyword::DEFAULT) || 
self.consume_token(&Token::Eq)
diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs
index 9e71883c..682c0d6c 100644
--- a/tests/sqlparser_postgres.rs
+++ b/tests/sqlparser_postgres.rs
@@ -21,6 +21,7 @@
 
 #[macro_use]
 mod test_utils;
+
 use helpers::attached_token::AttachedToken;
 use sqlparser::tokenizer::Span;
 use test_utils::*;
@@ -4105,6 +4106,216 @@ fn parse_update_in_with_subquery() {
     pg_and_generic().verified_stmt(r#"WITH "result" AS (UPDATE "Hero" SET 
"name" = 'Captain America', "number_of_movies" = "number_of_movies" + 1 WHERE 
"secret_identity" = 'Sam Wilson' RETURNING "id", "name", "secret_identity", 
"number_of_movies") SELECT * FROM "result""#);
 }
 
+#[test]
+fn parser_create_function_with_args() {
+    let sql1 = r#"CREATE OR REPLACE FUNCTION check_strings_different(str1 
VARCHAR, str2 VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+    IF str1 <> str2 THEN
+        RETURN TRUE;
+    ELSE
+        RETURN FALSE;
+    END IF;
+END;
+$$"#;
+
+    assert_eq!(
+        pg_and_generic().verified_stmt(sql1),
+        Statement::CreateFunction(CreateFunction {
+            or_alter: false,
+            or_replace: true,
+            temporary: false,
+            name: 
ObjectName::from(vec![Ident::new("check_strings_different")]),
+            args: Some(vec![
+                OperateFunctionArg::with_name(
+                    "str1",
+                    DataType::Varchar(None),
+                ),
+                OperateFunctionArg::with_name(
+                    "str2",
+                    DataType::Varchar(None),
+                ),
+            ]),
+            return_type: Some(DataType::Boolean),
+            language: Some("plpgsql".into()),
+            behavior: None,
+            called_on_null: None,
+            parallel: None,
+            function_body: 
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+                (Value::DollarQuotedString(DollarQuotedString {value: 
"\nBEGIN\n    IF str1 <> str2 THEN\n        RETURN TRUE;\n    ELSE\n        
RETURN FALSE;\n    END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+            ))),
+            if_not_exists: false,
+            using: None,
+            determinism_specifier: None,
+            options: None,
+            remote_connection: None,
+        })
+    );
+
+    let sql2 = r#"CREATE OR REPLACE FUNCTION check_not_zero(int1 INT) RETURNS 
BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+    IF int1 <> 0 THEN
+        RETURN TRUE;
+    ELSE
+        RETURN FALSE;
+    END IF;
+END;
+$$"#;
+    assert_eq!(
+        pg_and_generic().verified_stmt(sql2),
+        Statement::CreateFunction(CreateFunction {
+            or_alter: false,
+            or_replace: true,
+            temporary: false,
+            name: ObjectName::from(vec![Ident::new("check_not_zero")]),
+            args: Some(vec![
+                OperateFunctionArg::with_name(
+                    "int1",
+                    DataType::Int(None)
+                )
+            ]),
+            return_type: Some(DataType::Boolean),
+            language: Some("plpgsql".into()),
+            behavior: None,
+            called_on_null: None,
+            parallel: None,
+            function_body: 
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+                (Value::DollarQuotedString(DollarQuotedString {value: 
"\nBEGIN\n    IF int1 <> 0 THEN\n        RETURN TRUE;\n    ELSE\n        RETURN 
FALSE;\n    END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+            ))),
+            if_not_exists: false,
+            using: None,
+            determinism_specifier: None,
+            options: None,
+            remote_connection: None,
+        })
+    );
+
+    let sql3 = r#"CREATE OR REPLACE FUNCTION check_values_different(a INT, b 
INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+    IF a <> b THEN
+        RETURN TRUE;
+    ELSE
+        RETURN FALSE;
+    END IF;
+END;
+$$"#;
+    assert_eq!(
+        pg_and_generic().verified_stmt(sql3),
+        Statement::CreateFunction(CreateFunction {
+            or_alter: false,
+            or_replace: true,
+            temporary: false,
+            name: ObjectName::from(vec![Ident::new("check_values_different")]),
+            args: Some(vec![
+                OperateFunctionArg::with_name(
+                    "a",
+                    DataType::Int(None)
+                ),
+                OperateFunctionArg::with_name(
+                    "b",
+                    DataType::Int(None)
+                ),
+            ]),
+            return_type: Some(DataType::Boolean),
+            language: Some("plpgsql".into()),
+            behavior: None,
+            called_on_null: None,
+            parallel: None,
+            function_body: 
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+                (Value::DollarQuotedString(DollarQuotedString {value: 
"\nBEGIN\n    IF a <> b THEN\n        RETURN TRUE;\n    ELSE\n        RETURN 
FALSE;\n    END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+            ))),
+            if_not_exists: false,
+            using: None,
+            determinism_specifier: None,
+            options: None,
+            remote_connection: None,
+        })
+    );
+
+    let sql4 = r#"CREATE OR REPLACE FUNCTION check_values_different(int1 INT, 
int2 INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+    IF int1 <> int2 THEN
+        RETURN TRUE;
+    ELSE
+        RETURN FALSE;
+    END IF;
+END;
+$$"#;
+    assert_eq!(
+        pg_and_generic().verified_stmt(sql4),
+        Statement::CreateFunction(CreateFunction {
+            or_alter: false,
+            or_replace: true,
+            temporary: false,
+            name: ObjectName::from(vec![Ident::new("check_values_different")]),
+            args: Some(vec![
+                OperateFunctionArg::with_name(
+                    "int1",
+                    DataType::Int(None)
+                ),
+                OperateFunctionArg::with_name(
+                    "int2",
+                    DataType::Int(None)
+                ),
+            ]),
+            return_type: Some(DataType::Boolean),
+            language: Some("plpgsql".into()),
+            behavior: None,
+            called_on_null: None,
+            parallel: None,
+            function_body: 
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+                (Value::DollarQuotedString(DollarQuotedString {value: 
"\nBEGIN\n    IF int1 <> int2 THEN\n        RETURN TRUE;\n    ELSE\n        
RETURN FALSE;\n    END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+            ))),
+            if_not_exists: false,
+            using: None,
+            determinism_specifier: None,
+            options: None,
+            remote_connection: None,
+        })
+    );
+
+    let sql5 = r#"CREATE OR REPLACE FUNCTION foo(a TIMESTAMP WITH TIME ZONE, b 
VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+    BEGIN
+        RETURN TRUE;
+    END;
+    $$"#;
+    assert_eq!(
+        pg_and_generic().verified_stmt(sql5),
+        Statement::CreateFunction(CreateFunction {
+            or_alter: false,
+            or_replace: true,
+            temporary: false,
+            name: ObjectName::from(vec![Ident::new("foo")]),
+            args: Some(vec![
+                OperateFunctionArg::with_name(
+                    "a",
+                    DataType::Timestamp(None, TimezoneInfo::WithTimeZone)
+                ),
+                OperateFunctionArg::with_name("b", DataType::Varchar(None)),
+            ]),
+            return_type: Some(DataType::Boolean),
+            language: Some("plpgsql".into()),
+            behavior: None,
+            called_on_null: None,
+            parallel: None,
+            function_body: 
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+                (Value::DollarQuotedString(DollarQuotedString {
+                    value: "\n    BEGIN\n        RETURN TRUE;\n    END;\n    
".to_owned(),
+                    tag: None
+                }))
+                .with_empty_span()
+            ))),
+            if_not_exists: false,
+            using: None,
+            determinism_specifier: None,
+            options: None,
+            remote_connection: None,
+        })
+    );
+
+    let incorrect_sql = "CREATE FUNCTION add(function(struct<a,b> int64), b 
INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select 
$1 + $2;'";
+    assert!(pg().parse_sql_statements(incorrect_sql).is_err(),);
+}
+
 #[test]
 fn parse_create_function() {
     let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE 
SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to