This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 01e4ae46e8 fix: CLI should support different sql dialects (#7263)
01e4ae46e8 is described below

commit 01e4ae46e8738a7f4b6b4029b48a4010cec9df44
Author: Jonah Gao <[email protected]>
AuthorDate: Sat Aug 12 20:03:28 2023 +0800

    fix: CLI should support different sql dialects (#7263)
    
    * fix: CLI should support different dialects
    
    * fix typo
---
 datafusion-cli/src/exec.rs   |  8 ++++++-
 datafusion-cli/src/helper.rs | 55 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs
index 4de7eb5afe..7416b73971 100644
--- a/datafusion-cli/src/exec.rs
+++ b/datafusion-cli/src/exec.rs
@@ -119,7 +119,9 @@ pub async fn exec_from_repl(
     print_options: &mut PrintOptions,
 ) -> rustyline::Result<()> {
     let mut rl = Editor::new()?;
-    rl.set_helper(Some(CliHelper::default()));
+    rl.set_helper(Some(CliHelper::new(
+        &ctx.task_ctx().session_config().options().sql_parser.dialect,
+    )));
     rl.load_history(".history").ok();
 
     let mut print_options = print_options.clone();
@@ -166,6 +168,10 @@ pub async fn exec_from_repl(
                     Ok(_) => {}
                     Err(err) => eprintln!("{err}"),
                 }
+                // dialect might have changed
+                rl.helper_mut().unwrap().set_dialect(
+                    &ctx.task_ctx().session_config().options().sql_parser.dialect,
+                );
             }
             Err(ReadlineError::Interrupted) => {
                 println!("^C");
diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs
index 981c4b5aa3..e4992122f9 100644
--- a/datafusion-cli/src/helper.rs
+++ b/datafusion-cli/src/helper.rs
@@ -20,6 +20,7 @@
 
 use datafusion::error::DataFusionError;
 use datafusion::sql::parser::{DFParser, Statement};
+use datafusion::sql::sqlparser::dialect::dialect_from_str;
 use datafusion::sql::sqlparser::parser::ParserError;
 use rustyline::completion::Completer;
 use rustyline::completion::FilenameCompleter;
@@ -34,12 +35,25 @@ use rustyline::Context;
 use rustyline::Helper;
 use rustyline::Result;
 
-#[derive(Default)]
 pub struct CliHelper {
     completer: FilenameCompleter,
+    dialect: String,
 }
 
 impl CliHelper {
+    pub fn new(dialect: &str) -> Self {
+        Self {
+            completer: FilenameCompleter::new(),
+            dialect: dialect.into(),
+        }
+    }
+
+    pub fn set_dialect(&mut self, dialect: &str) {
+        if dialect != self.dialect {
+            self.dialect = dialect.to_string();
+        }
+    }
+
     fn validate_input(&self, input: &str) -> Result<ValidationResult> {
         if let Some(sql) = input.strip_suffix(';') {
             let sql = match unescape_input(sql) {
@@ -50,7 +64,18 @@ impl CliHelper {
                     ))))
                 }
             };
-            match DFParser::parse_sql(&sql) {
+
+            let dialect = match dialect_from_str(&self.dialect) {
+                Some(dialect) => dialect,
+                None => {
+                    return Ok(ValidationResult::Invalid(Some(format!(
+                        "  🤔 Invalid dialect: {}",
+                        self.dialect
+                    ))))
+                }
+            };
+
+            match DFParser::parse_sql_with_dialect(&sql, dialect.as_ref()) {
                 Ok(statements) if statements.is_empty() => Ok(ValidationResult::Invalid(
                     Some("  🤔 You entered an empty statement".to_string()),
                 )),
@@ -68,6 +93,12 @@ impl CliHelper {
     }
 }
 
+impl Default for CliHelper {
+    fn default() -> Self {
+        Self::new("generic")
+    }
+}
+
 impl Highlighter for CliHelper {}
 
 impl Hinter for CliHelper {
@@ -220,4 +251,24 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn sql_dialect() -> Result<()> {
+        let mut validator = CliHelper::default();
+
+        // shoule be invalid in generic dialect
+        let result =
+            readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?;
+        assert!(
+            matches!(result, ValidationResult::Invalid(Some(e)) if e.contains("Invalid statement"))
+        );
+
+        // valid in postgresql dialect
+        validator.set_dialect("postgresql");
+        let result =
+            readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?;
+        assert!(matches!(result, ValidationResult::Valid(None)));
+
+        Ok(())
+    }
 }

Reply via email to