Jefffrey commented on code in PR #18450:
URL: https://github.com/apache/datafusion/pull/18450#discussion_r2493529061


##########
datafusion/sql/src/expr/value.rs:
##########


Review Comment:
   There's some documentation here and above (in the docstring) that needs to be updated too.



##########
datafusion/sql/src/expr/value.rs:
##########
@@ -123,8 +123,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 return if param_data_types.is_empty() {
                     Ok(Expr::Placeholder(Placeholder::new_with_field(param, 
None)))
                 } else {
-                    // when PREPARE Statement, param_data_types length is 
always 0
-                    plan_err!("Invalid placeholder, not a number: {param}")
+                    // FIXME: This branch is shared by params from PREPARE and 
CREATE FUNCTION, but
+                    // only CREATE FUNCTION currently supports named params. 
For now, we rewrite
+                    // these to positional params.
+                    let named_param_pos = param_data_types
+                        .iter()
+                        .position(|v| v.name() == &param[1..]);
+                    match named_param_pos {
+                        Some(pos) => 
Ok(Expr::Placeholder(Placeholder::new_with_field(
+                            format!("${}", pos + 1),
+                            param_data_types.get(pos).cloned(),
+                        ))),
+                        None => plan_err!("Invalid placeholder: {param}"),

Review Comment:
   ```suggestion
                           None => plan_err!("Unknown placeholder: {param}"),
   ```
   
   Given that "invalid" used to mean that the placeholder wasn't a number, and we now accept named placeholders, "unknown" might be a more appropriate word for this message.



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -1112,6 +1140,115 @@ async fn create_scalar_function_from_sql_statement() -> 
Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn create_scalar_function_from_sql_statement_named_arguments() -> 
Result<()> {
+    let function_factory = Arc::new(CustomFunctionFactory::default());
+    let ctx = 
SessionContext::new().with_function_factory(function_factory.clone());
+
+    let sql = r#"
+    CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $a + $b
+    "#;
+
+    assert!(ctx.sql(sql).await.is_ok());
+
+    let result = ctx
+        .sql("select better_add(2.0, 2.0)")
+        .await?
+        .collect()
+        .await?;
+
+    assert_batches_eq!(
+        &[
+            "+-----------------------------------+",
+            "| better_add(Float64(2),Float64(2)) |",
+            "+-----------------------------------+",
+            "| 4.0                               |",
+            "+-----------------------------------+",
+        ],
+        &result
+    );
+
+    // cannot mix named and positional style
+    let bad_expression_sql = r#"
+    CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $1 $b
+    "#;
+    assert!(ctx.sql(bad_expression_sql).await.is_err());
+    Ok(())
+}
+
+#[tokio::test]
+async fn create_scalar_function_from_sql_statement_default_arguments() -> 
Result<()> {
+    let function_factory = Arc::new(CustomFunctionFactory::default());
+    let ctx = 
SessionContext::new().with_function_factory(function_factory.clone());
+
+    let sql = r#"
+    CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0)
+        RETURNS DOUBLE
+        RETURN $a + $b
+    "#;
+
+    assert!(ctx.sql(sql).await.is_ok());
+
+    // Check all function arity supported
+    let result = ctx.sql("select better_add()").await?.collect().await?;
+
+    assert_batches_eq!(
+        &[
+            "+--------------+",
+            "| better_add() |",
+            "+--------------+",
+            "| 4.0          |",
+            "+--------------+",
+        ],
+        &result
+    );
+
+    let result = ctx.sql("select better_add(2.0)").await?.collect().await?;
+
+    assert_batches_eq!(
+        &[
+            "+------------------------+",
+            "| better_add(Float64(2)) |",
+            "+------------------------+",
+            "| 4.0                    |",
+            "+------------------------+",
+        ],
+        &result
+    );
+
+    let result = ctx
+        .sql("select better_add(2.0, 2.0)")
+        .await?
+        .collect()
+        .await?;
+
+    assert_batches_eq!(
+        &[
+            "+-----------------------------------+",
+            "| better_add(Float64(2),Float64(2)) |",
+            "+-----------------------------------+",
+            "| 4.0                               |",
+            "+-----------------------------------+",
+        ],
+        &result
+    );
+
+    assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());
+
+    // non-default argument cannot follow default argument
+    let bad_expression_sql = r#"
+    CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $a $b
+    "#;
+    assert!(ctx.sql(bad_expression_sql).await.is_err());

Review Comment:
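   Ditto: we should check the specific error here as well. Note that `RETURN $a $b` is also missing an operator between the placeholders, so `is_err()` alone does not show that the statement was rejected because a non-default argument follows a default one.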



##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -1112,6 +1140,115 @@ async fn create_scalar_function_from_sql_statement() -> 
Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn create_scalar_function_from_sql_statement_named_arguments() -> 
Result<()> {
+    let function_factory = Arc::new(CustomFunctionFactory::default());
+    let ctx = 
SessionContext::new().with_function_factory(function_factory.clone());
+
+    let sql = r#"
+    CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $a + $b
+    "#;
+
+    assert!(ctx.sql(sql).await.is_ok());
+
+    let result = ctx
+        .sql("select better_add(2.0, 2.0)")
+        .await?
+        .collect()
+        .await?;
+
+    assert_batches_eq!(
+        &[
+            "+-----------------------------------+",
+            "| better_add(Float64(2),Float64(2)) |",
+            "+-----------------------------------+",
+            "| 4.0                               |",
+            "+-----------------------------------+",
+        ],
+        &result
+    );
+
+    // cannot mix named and positional style
+    let bad_expression_sql = r#"
+    CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $1 $b
+    "#;
+    assert!(ctx.sql(bad_expression_sql).await.is_err());

Review Comment:
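   We should check the specific error, not just `is_err()`. As written, `RETURN $1 $b` is also missing an operator between the two placeholders, so the statement may fail for that reason rather than because named and positional argument styles are mixed.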



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to