This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 724a1d1a Add support for Hive's `LOAD DATA` expr (#1520)
724a1d1a is described below

commit 724a1d1aba575fb04a2df54ca8425b39ea753938
Author: wugeer <[email protected]>
AuthorDate: Fri Nov 15 22:53:31 2024 +0800

    Add support for Hive's `LOAD DATA` expr (#1520)
    
    Co-authored-by: Ifeanyi Ubah <[email protected]>
---
 src/ast/mod.rs            |  54 ++++++++++++
 src/dialect/duckdb.rs     |   5 ++
 src/dialect/generic.rs    |   4 +
 src/dialect/hive.rs       |   5 ++
 src/dialect/mod.rs        |  10 +++
 src/keywords.rs           |   1 +
 src/parser/mod.rs         |  52 ++++++++++--
 tests/sqlparser_common.rs | 203 +++++++++++++++++++++++++++++++++++++++++++++-
 8 files changed, 323 insertions(+), 11 deletions(-)

diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index b0ac6bc4..39c74215 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -3347,6 +3347,22 @@ pub enum Statement {
         channel: Ident,
         payload: Option<String>,
     },
+    /// ```sql
+    /// LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE tablename
+    /// [PARTITION (partcol1=val1, partcol2=val2 ...)]
+    /// [INPUTFORMAT 'inputformat' SERDE 'serde']
+    /// ```
+    /// Loading files into tables
+    ///
+    /// See Hive <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362036#LanguageManualDML-Loadingfilesintotables>
+    LoadData {
+        local: bool,
+        inpath: String,
+        overwrite: bool,
+        table_name: ObjectName,
+        partitioned: Option<Vec<Expr>>,
+        table_format: Option<HiveLoadDataFormat>,
+    },
 }
 
 impl fmt::Display for Statement {
@@ -3949,6 +3965,36 @@ impl fmt::Display for Statement {
                 Ok(())
             }
             Statement::CreateTable(create_table) => create_table.fmt(f),
+            Statement::LoadData {
+                local,
+                inpath,
+                overwrite,
+                table_name,
+                partitioned,
+                table_format,
+            } => {
+                write!(
+                    f,
+                    "LOAD DATA {local}INPATH '{inpath}' {overwrite}INTO TABLE {table_name}",
+                    local = if *local { "LOCAL " } else { "" },
+                    inpath = inpath,
+                    overwrite = if *overwrite { "OVERWRITE " } else { "" },
+                    table_name = table_name,
+                )?;
+                if let Some(ref parts) = &partitioned {
+                    if !parts.is_empty() {
+                        write!(f, " PARTITION ({})", display_comma_separated(parts))?;
+                    }
+                }
+                if let Some(HiveLoadDataFormat {
+                    serde,
+                    input_format,
+                }) = &table_format
+                {
+                    write!(f, " INPUTFORMAT {input_format} SERDE {serde}")?;
+                }
+                Ok(())
+            }
             Statement::CreateVirtualTable {
                 name,
                 if_not_exists,
@@ -5855,6 +5901,14 @@ pub enum HiveRowFormat {
     DELIMITED { delimiters: Vec<HiveRowDelimiter> },
 }
 
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct HiveLoadDataFormat {
+    pub serde: Expr,
+    pub input_format: Expr,
+}
+
 #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs
index e1b8db11..905b04e3 100644
--- a/src/dialect/duckdb.rs
+++ b/src/dialect/duckdb.rs
@@ -66,4 +66,9 @@ impl Dialect for DuckDbDialect {
     fn supports_explain_with_utility_options(&self) -> bool {
         true
     }
+
+    /// See DuckDB <https://duckdb.org/docs/sql/statements/load_and_install.html#load>
+    fn supports_load_extension(&self) -> bool {
+        true
+    }
 }
diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs
index 8cfac217..4998e0f4 100644
--- a/src/dialect/generic.rs
+++ b/src/dialect/generic.rs
@@ -115,4 +115,8 @@ impl Dialect for GenericDialect {
     fn supports_comment_on(&self) -> bool {
         true
     }
+
+    fn supports_load_extension(&self) -> bool {
+        true
+    }
 }
diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs
index b97bf69b..571f9b9b 100644
--- a/src/dialect/hive.rs
+++ b/src/dialect/hive.rs
@@ -56,4 +56,9 @@ impl Dialect for HiveDialect {
     fn supports_bang_not_operator(&self) -> bool {
         true
     }
+
+    /// See Hive <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362036#LanguageManualDML-Loadingfilesintotables>
+    fn supports_load_data(&self) -> bool {
+        true
+    }
 }
diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index ee3fd471..956a5898 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -620,6 +620,16 @@ pub trait Dialect: Debug + Any {
         false
     }
 
+    /// Returns true if the dialect supports the `LOAD DATA` statement
+    fn supports_load_data(&self) -> bool {
+        false
+    }
+
+    /// Returns true if the dialect supports the `LOAD extension` statement
+    fn supports_load_extension(&self) -> bool {
+        false
+    }
+
     /// Returns true if this dialect expects the `TOP` option
     /// before the `ALL`/`DISTINCT` options in a `SELECT` statement.
     fn supports_top_before_distinct(&self) -> bool {
diff --git a/src/keywords.rs b/src/keywords.rs
index 9cdc90ce..79026821 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -389,6 +389,7 @@ define_keywords!(
     INITIALLY,
     INNER,
     INOUT,
+    INPATH,
     INPUT,
     INPUTFORMAT,
     INSENSITIVE,
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index a66a627b..a583112a 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -543,10 +543,7 @@ impl<'a> Parser<'a> {
                 Keyword::INSTALL if dialect_of!(self is DuckDbDialect | GenericDialect) => {
                     self.parse_install()
                 }
-                // `LOAD` is duckdb specific https://duckdb.org/docs/extensions/overview
-                Keyword::LOAD if dialect_of!(self is DuckDbDialect | GenericDialect) => {
-                    self.parse_load()
-                }
+                Keyword::LOAD => self.parse_load(),
                 // `OPTIMIZE` is clickhouse specific https://clickhouse.tech/docs/en/sql-reference/statements/optimize/
                 Keyword::OPTIMIZE if dialect_of!(self is ClickHouseDialect | GenericDialect) => {
                     self.parse_optimize_table()
@@ -11222,6 +11219,22 @@ impl<'a> Parser<'a> {
         }
     }
 
+    pub fn parse_load_data_table_format(
+        &mut self,
+    ) -> Result<Option<HiveLoadDataFormat>, ParserError> {
+        if self.parse_keyword(Keyword::INPUTFORMAT) {
+            let input_format = self.parse_expr()?;
+            self.expect_keyword(Keyword::SERDE)?;
+            let serde = self.parse_expr()?;
+            Ok(Some(HiveLoadDataFormat {
+                input_format,
+                serde,
+            }))
+        } else {
+            Ok(None)
+        }
+    }
+
     /// Parse an UPDATE statement, returning a `Box`ed SetExpr
     ///
     /// This is used to reduce the size of the stack frames in debug builds
@@ -12224,10 +12237,35 @@ impl<'a> Parser<'a> {
         Ok(Statement::Install { extension_name })
     }
 
-    /// `LOAD [extension_name]`
+    /// Parse a SQL LOAD statement
     pub fn parse_load(&mut self) -> Result<Statement, ParserError> {
-        let extension_name = self.parse_identifier(false)?;
-        Ok(Statement::Load { extension_name })
+        if self.dialect.supports_load_extension() {
+            let extension_name = self.parse_identifier(false)?;
+            Ok(Statement::Load { extension_name })
+        } else if self.parse_keyword(Keyword::DATA) && self.dialect.supports_load_data() {
+            let local = self.parse_one_of_keywords(&[Keyword::LOCAL]).is_some();
+            self.expect_keyword(Keyword::INPATH)?;
+            let inpath = self.parse_literal_string()?;
+            let overwrite = self.parse_one_of_keywords(&[Keyword::OVERWRITE]).is_some();
+            self.expect_keyword(Keyword::INTO)?;
+            self.expect_keyword(Keyword::TABLE)?;
+            let table_name = self.parse_object_name(false)?;
+            let partitioned = self.parse_insert_partition()?;
+            let table_format = self.parse_load_data_table_format()?;
+            Ok(Statement::LoadData {
+                local,
+                inpath,
+                overwrite,
+                table_name,
+                partitioned,
+                table_format,
+            })
+        } else {
+            self.expected(
+                "`DATA` or an extension name after `LOAD`",
+                self.peek_token(),
+            )
+        }
     }
 
     /// ```sql
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index 4fdbf7d5..2ffb5f44 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -11583,13 +11583,208 @@ fn parse_notify_channel() {
             dialects.parse_sql_statements(sql).unwrap_err(),
             ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string())
         );
-        assert_eq!(
-            dialects.parse_sql_statements(sql).unwrap_err(),
-            ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string())
-        );
     }
 }
 
+#[test]
+fn parse_load_data() {
+    let dialects = all_dialects_where(|d| d.supports_load_data());
+    let only_supports_load_extension_dialects =
+        all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension());
+    let not_supports_load_dialects =
+        all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension());
+
+    let sql = "LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table";
+    match dialects.verified_stmt(sql) {
+        Statement::LoadData {
+            local,
+            inpath,
+            overwrite,
+            table_name,
+            partitioned,
+            table_format,
+        } => {
+            assert_eq!(false, local);
+            assert_eq!("/local/path/to/data.txt", inpath);
+            assert_eq!(false, overwrite);
+            assert_eq!(
+                ObjectName(vec![Ident::new("test"), Ident::new("my_table")]),
+                table_name
+            );
+            assert_eq!(None, partitioned);
+            assert_eq!(None, table_format);
+        }
+        _ => unreachable!(),
+    };
+
+    // with OVERWRITE keyword
+    let sql = "LOAD DATA INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE my_table";
+    match dialects.verified_stmt(sql) {
+        Statement::LoadData {
+            local,
+            inpath,
+            overwrite,
+            table_name,
+            partitioned,
+            table_format,
+        } => {
+            assert_eq!(false, local);
+            assert_eq!("/local/path/to/data.txt", inpath);
+            assert_eq!(true, overwrite);
+            assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name);
+            assert_eq!(None, partitioned);
+            assert_eq!(None, table_format);
+        }
+        _ => unreachable!(),
+    };
+
+    assert_eq!(
+        only_supports_load_extension_dialects
+            .parse_sql_statements(sql)
+            .unwrap_err(),
+        ParserError::ParserError("Expected: end of statement, found: INPATH".to_string())
+    );
+    assert_eq!(
+        not_supports_load_dialects
+            .parse_sql_statements(sql)
+            .unwrap_err(),
+        ParserError::ParserError(
+            "Expected: `DATA` or an extension name after `LOAD`, found: INPATH".to_string()
+        )
+    );
+
+    // with LOCAL keyword
+    let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table";
+    match dialects.verified_stmt(sql) {
+        Statement::LoadData {
+            local,
+            inpath,
+            overwrite,
+            table_name,
+            partitioned,
+            table_format,
+        } => {
+            assert_eq!(true, local);
+            assert_eq!("/local/path/to/data.txt", inpath);
+            assert_eq!(false, overwrite);
+            assert_eq!(
+                ObjectName(vec![Ident::new("test"), Ident::new("my_table")]),
+                table_name
+            );
+            assert_eq!(None, partitioned);
+            assert_eq!(None, table_format);
+        }
+        _ => unreachable!(),
+    };
+
+    assert_eq!(
+        only_supports_load_extension_dialects
+            .parse_sql_statements(sql)
+            .unwrap_err(),
+        ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string())
+    );
+    assert_eq!(
+        not_supports_load_dialects
+            .parse_sql_statements(sql)
+            .unwrap_err(),
+        ParserError::ParserError(
+            "Expected: `DATA` or an extension name after `LOAD`, found: LOCAL".to_string()
+        )
+    );
+
+    // with PARTITION  clause
+    let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE my_table PARTITION (year = 2024, month = 11)";
+    match dialects.verified_stmt(sql) {
+        Statement::LoadData {
+            local,
+            inpath,
+            overwrite,
+            table_name,
+            partitioned,
+            table_format,
+        } => {
+            assert_eq!(true, local);
+            assert_eq!("/local/path/to/data.txt", inpath);
+            assert_eq!(false, overwrite);
+            assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name);
+            assert_eq!(
+                Some(vec![
+                    Expr::BinaryOp {
+                        left: Box::new(Expr::Identifier(Ident::new("year"))),
+                        op: BinaryOperator::Eq,
+                        right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))),
+                    },
+                    Expr::BinaryOp {
+                        left: Box::new(Expr::Identifier(Ident::new("month"))),
+                        op: BinaryOperator::Eq,
+                        right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))),
+                    }
+                ]),
+                partitioned
+            );
+            assert_eq!(None, table_format);
+        }
+        _ => unreachable!(),
+    };
+
+    // with PARTITION  clause
+    let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE good.my_table PARTITION (year = 2024, month = 11) INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'";
+    match dialects.verified_stmt(sql) {
+        Statement::LoadData {
+            local,
+            inpath,
+            overwrite,
+            table_name,
+            partitioned,
+            table_format,
+        } => {
+            assert_eq!(true, local);
+            assert_eq!("/local/path/to/data.txt", inpath);
+            assert_eq!(true, overwrite);
+            assert_eq!(
+                ObjectName(vec![Ident::new("good"), Ident::new("my_table")]),
+                table_name
+            );
+            assert_eq!(
+                Some(vec![
+                    Expr::BinaryOp {
+                        left: Box::new(Expr::Identifier(Ident::new("year"))),
+                        op: BinaryOperator::Eq,
+                        right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))),
+                    },
+                    Expr::BinaryOp {
+                        left: Box::new(Expr::Identifier(Ident::new("month"))),
+                        op: BinaryOperator::Eq,
+                        right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))),
+                    }
+                ]),
+                partitioned
+            );
+            assert_eq!(
+                Some(HiveLoadDataFormat {
+                    serde: Expr::Value(Value::SingleQuotedString(
+                        "org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string()
+                    )),
+                    input_format: Expr::Value(Value::SingleQuotedString(
+                        "org.apache.hadoop.mapred.TextInputFormat".to_string()
+                    ))
+                }),
+                table_format
+            );
+        }
+        _ => unreachable!(),
+    };
+
+    // negative test case
+    let sql = "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table";
+    assert_eq!(
+        dialects.parse_sql_statements(sql).unwrap_err(),
+        ParserError::ParserError(
+            "Expected: `DATA` or an extension name after `LOAD`, found: DATA2".to_string()
+        )
+    );
+}
+
 #[test]
 fn test_select_top() {
     let dialects = all_dialects_where(|d| d.supports_top_before_distinct());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to