This is an automated email from the ASF dual-hosted git repository.
iffyio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 05d7ffb1 Handle optional datatypes properly in `CREATE FUNCTION`
statements (#1826)
05d7ffb1 is described below
commit 05d7ffb1d5ef6e4c4852a200e0aa08fec224aa3c
Author: Luca Cappelletti <[email protected]>
AuthorDate: Wed May 21 05:49:28 2025 +0200
Handle optional datatypes properly in `CREATE FUNCTION` statements (#1826)
Co-authored-by: Ifeanyi Ubah <[email protected]>
---
src/parser/mod.rs | 19 ++--
tests/sqlparser_postgres.rs | 211 ++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 225 insertions(+), 5 deletions(-)
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 992d19c4..f6a45ada 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -5273,12 +5273,21 @@ impl<'a> Parser<'a> {
// parse: [ argname ] argtype
let mut name = None;
let mut data_type = self.parse_data_type()?;
- if let DataType::Custom(n, _) = &data_type {
- // the first token is actually a name
- match n.0[0].clone() {
- ObjectNamePart::Identifier(ident) => name = Some(ident),
+
+ // To check whether the first token is a name or a type, we try to
+ // parse another data type right after it: if that succeeds, then the
+ // first token is a name rather than a type in itself.
+ let data_type_idx = self.get_current_index();
+ if let Some(next_data_type) = self.maybe_parse(|parser|
parser.parse_data_type())? {
+ let token = self.token_at(data_type_idx);
+
+ // Ensure that the token is a `Word` token and not some other special token.
+ if !matches!(token.token, Token::Word(_)) {
+ return self.expected("a name or type", token.clone());
}
- data_type = self.parse_data_type()?;
+
+ name = Some(Ident::new(token.to_string()));
+ data_type = next_data_type;
}
let default_expr = if self.parse_keyword(Keyword::DEFAULT) ||
self.consume_token(&Token::Eq)
diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs
index 9e71883c..682c0d6c 100644
--- a/tests/sqlparser_postgres.rs
+++ b/tests/sqlparser_postgres.rs
@@ -21,6 +21,7 @@
#[macro_use]
mod test_utils;
+
use helpers::attached_token::AttachedToken;
use sqlparser::tokenizer::Span;
use test_utils::*;
@@ -4105,6 +4106,216 @@ fn parse_update_in_with_subquery() {
pg_and_generic().verified_stmt(r#"WITH "result" AS (UPDATE "Hero" SET
"name" = 'Captain America', "number_of_movies" = "number_of_movies" + 1 WHERE
"secret_identity" = 'Sam Wilson' RETURNING "id", "name", "secret_identity",
"number_of_movies") SELECT * FROM "result""#);
}
+#[test]
+fn parser_create_function_with_args() {
+ let sql1 = r#"CREATE OR REPLACE FUNCTION check_strings_different(str1
VARCHAR, str2 VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+ IF str1 <> str2 THEN
+ RETURN TRUE;
+ ELSE
+ RETURN FALSE;
+ END IF;
+END;
+$$"#;
+
+ assert_eq!(
+ pg_and_generic().verified_stmt(sql1),
+ Statement::CreateFunction(CreateFunction {
+ or_alter: false,
+ or_replace: true,
+ temporary: false,
+ name:
ObjectName::from(vec![Ident::new("check_strings_different")]),
+ args: Some(vec![
+ OperateFunctionArg::with_name(
+ "str1",
+ DataType::Varchar(None),
+ ),
+ OperateFunctionArg::with_name(
+ "str2",
+ DataType::Varchar(None),
+ ),
+ ]),
+ return_type: Some(DataType::Boolean),
+ language: Some("plpgsql".into()),
+ behavior: None,
+ called_on_null: None,
+ parallel: None,
+ function_body:
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+ (Value::DollarQuotedString(DollarQuotedString {value:
"\nBEGIN\n IF str1 <> str2 THEN\n RETURN TRUE;\n ELSE\n
RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+ ))),
+ if_not_exists: false,
+ using: None,
+ determinism_specifier: None,
+ options: None,
+ remote_connection: None,
+ })
+ );
+
+ let sql2 = r#"CREATE OR REPLACE FUNCTION check_not_zero(int1 INT) RETURNS
BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+ IF int1 <> 0 THEN
+ RETURN TRUE;
+ ELSE
+ RETURN FALSE;
+ END IF;
+END;
+$$"#;
+ assert_eq!(
+ pg_and_generic().verified_stmt(sql2),
+ Statement::CreateFunction(CreateFunction {
+ or_alter: false,
+ or_replace: true,
+ temporary: false,
+ name: ObjectName::from(vec![Ident::new("check_not_zero")]),
+ args: Some(vec![
+ OperateFunctionArg::with_name(
+ "int1",
+ DataType::Int(None)
+ )
+ ]),
+ return_type: Some(DataType::Boolean),
+ language: Some("plpgsql".into()),
+ behavior: None,
+ called_on_null: None,
+ parallel: None,
+ function_body:
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+ (Value::DollarQuotedString(DollarQuotedString {value:
"\nBEGIN\n IF int1 <> 0 THEN\n RETURN TRUE;\n ELSE\n RETURN
FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+ ))),
+ if_not_exists: false,
+ using: None,
+ determinism_specifier: None,
+ options: None,
+ remote_connection: None,
+ })
+ );
+
+ let sql3 = r#"CREATE OR REPLACE FUNCTION check_values_different(a INT, b
INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+ IF a <> b THEN
+ RETURN TRUE;
+ ELSE
+ RETURN FALSE;
+ END IF;
+END;
+$$"#;
+ assert_eq!(
+ pg_and_generic().verified_stmt(sql3),
+ Statement::CreateFunction(CreateFunction {
+ or_alter: false,
+ or_replace: true,
+ temporary: false,
+ name: ObjectName::from(vec![Ident::new("check_values_different")]),
+ args: Some(vec![
+ OperateFunctionArg::with_name(
+ "a",
+ DataType::Int(None)
+ ),
+ OperateFunctionArg::with_name(
+ "b",
+ DataType::Int(None)
+ ),
+ ]),
+ return_type: Some(DataType::Boolean),
+ language: Some("plpgsql".into()),
+ behavior: None,
+ called_on_null: None,
+ parallel: None,
+ function_body:
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+ (Value::DollarQuotedString(DollarQuotedString {value:
"\nBEGIN\n IF a <> b THEN\n RETURN TRUE;\n ELSE\n RETURN
FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+ ))),
+ if_not_exists: false,
+ using: None,
+ determinism_specifier: None,
+ options: None,
+ remote_connection: None,
+ })
+ );
+
+ let sql4 = r#"CREATE OR REPLACE FUNCTION check_values_different(int1 INT,
int2 INT) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+BEGIN
+ IF int1 <> int2 THEN
+ RETURN TRUE;
+ ELSE
+ RETURN FALSE;
+ END IF;
+END;
+$$"#;
+ assert_eq!(
+ pg_and_generic().verified_stmt(sql4),
+ Statement::CreateFunction(CreateFunction {
+ or_alter: false,
+ or_replace: true,
+ temporary: false,
+ name: ObjectName::from(vec![Ident::new("check_values_different")]),
+ args: Some(vec![
+ OperateFunctionArg::with_name(
+ "int1",
+ DataType::Int(None)
+ ),
+ OperateFunctionArg::with_name(
+ "int2",
+ DataType::Int(None)
+ ),
+ ]),
+ return_type: Some(DataType::Boolean),
+ language: Some("plpgsql".into()),
+ behavior: None,
+ called_on_null: None,
+ parallel: None,
+ function_body:
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+ (Value::DollarQuotedString(DollarQuotedString {value:
"\nBEGIN\n IF int1 <> int2 THEN\n RETURN TRUE;\n ELSE\n
RETURN FALSE;\n END IF;\nEND;\n".to_owned(), tag: None})).with_empty_span()
+ ))),
+ if_not_exists: false,
+ using: None,
+ determinism_specifier: None,
+ options: None,
+ remote_connection: None,
+ })
+ );
+
+ let sql5 = r#"CREATE OR REPLACE FUNCTION foo(a TIMESTAMP WITH TIME ZONE, b
VARCHAR) RETURNS BOOLEAN LANGUAGE plpgsql AS $$
+ BEGIN
+ RETURN TRUE;
+ END;
+ $$"#;
+ assert_eq!(
+ pg_and_generic().verified_stmt(sql5),
+ Statement::CreateFunction(CreateFunction {
+ or_alter: false,
+ or_replace: true,
+ temporary: false,
+ name: ObjectName::from(vec![Ident::new("foo")]),
+ args: Some(vec![
+ OperateFunctionArg::with_name(
+ "a",
+ DataType::Timestamp(None, TimezoneInfo::WithTimeZone)
+ ),
+ OperateFunctionArg::with_name("b", DataType::Varchar(None)),
+ ]),
+ return_type: Some(DataType::Boolean),
+ language: Some("plpgsql".into()),
+ behavior: None,
+ called_on_null: None,
+ parallel: None,
+ function_body:
Some(CreateFunctionBody::AsBeforeOptions(Expr::Value(
+ (Value::DollarQuotedString(DollarQuotedString {
+ value: "\n BEGIN\n RETURN TRUE;\n END;\n
".to_owned(),
+ tag: None
+ }))
+ .with_empty_span()
+ ))),
+ if_not_exists: false,
+ using: None,
+ determinism_specifier: None,
+ options: None,
+ remote_connection: None,
+ })
+ );
+
+ let incorrect_sql = "CREATE FUNCTION add(function(struct<a,b> int64), b
INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select
$1 + $2;'";
+ assert!(pg().parse_sql_statements(incorrect_sql).is_err(),);
+}
+
#[test]
fn parse_create_function() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE
SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]