r1b commented on code in PR #18450:
URL: https://github.com/apache/datafusion/pull/18450#discussion_r2496554267
##########
datafusion/core/tests/user_defined/user_defined_scalar_functions.rs:
##########
@@ -1112,6 +1140,115 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
Ok(())
}
+
+/// Verifies that a SQL-defined scalar function can reference its arguments by
+/// name (`$a`, `$b`) and that mixing named and positional (`$1`) parameter
+/// references in one function body is rejected.
+#[tokio::test]
+async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> {
+    let function_factory = Arc::new(CustomFunctionFactory::default());
+    let ctx = SessionContext::new().with_function_factory(function_factory.clone());
+
+    let sql = r#"
+    CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $a + $b
+    "#;
+
+    // Propagate the error (rather than `assert!(... .is_ok())`) so that a
+    // failure reports why the CREATE FUNCTION statement was rejected.
+    ctx.sql(sql).await?;
+
+    let result = ctx
+        .sql("select better_add(2.0, 2.0)")
+        .await?
+        .collect()
+        .await?;
+
+    // `assert_batches_eq!` compares rendered rows exactly, so the value row
+    // must be padded to the width of the header column.
+    assert_batches_eq!(
+        &[
+            "+-----------------------------------+",
+            "| better_add(Float64(2),Float64(2)) |",
+            "+-----------------------------------+",
+            "| 4.0                               |",
+            "+-----------------------------------+",
+        ],
+        &result
+    );
+
+    // Mixing positional (`$1`) and named (`$b`) parameters must be rejected.
+    // The body is otherwise valid SQL (note the `+` operator), so the error
+    // must come from the parameter-style check, not from a parse failure.
+    let bad_expression_sql = r#"
+    CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
+        RETURNS DOUBLE
+        RETURN $1 + $b
+    "#;
+    assert!(ctx.sql(bad_expression_sql).await.is_err());
+    Ok(())
+}
+
+#[tokio::test]
+async fn create_scalar_function_from_sql_statement_default_arguments() ->
Result<()> {
+ let function_factory = Arc::new(CustomFunctionFactory::default());
+ let ctx =
SessionContext::new().with_function_factory(function_factory.clone());
+
+ let sql = r#"
+ CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0)
+ RETURNS DOUBLE
+ RETURN $a + $b
+ "#;
+
+ assert!(ctx.sql(sql).await.is_ok());
+
+ // Check all function arity supported
+ let result = ctx.sql("select better_add()").await?.collect().await?;
+
+ assert_batches_eq!(
+ &[
+ "+--------------+",
+ "| better_add() |",
+ "+--------------+",
+ "| 4.0 |",
+ "+--------------+",
+ ],
+ &result
+ );
+
+ let result = ctx.sql("select better_add(2.0)").await?.collect().await?;
+
+ assert_batches_eq!(
+ &[
+ "+------------------------+",
+ "| better_add(Float64(2)) |",
+ "+------------------------+",
+ "| 4.0 |",
+ "+------------------------+",
+ ],
+ &result
+ );
+
+ let result = ctx
+ .sql("select better_add(2.0, 2.0)")
+ .await?
+ .collect()
+ .await?;
+
+ assert_batches_eq!(
+ &[
+ "+-----------------------------------+",
+ "| better_add(Float64(2),Float64(2)) |",
+ "+-----------------------------------+",
+ "| 4.0 |",
+ "+-----------------------------------+",
+ ],
+ &result
+ );
+
+ assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());
+
+ // non-default argument cannot follow default argument
+ let bad_expression_sql = r#"
+ CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
+ RETURNS DOUBLE
+ RETURN $a $b
+ "#;
+ assert!(ctx.sql(bad_expression_sql).await.is_err());
Review Comment:
Ha! This revealed a bug in both the test and the implementation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]