This is an automated email from the ASF dual-hosted git repository.

chunshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-horaedb.git


The following commit(s) were added to refs/heads/main by this push:
     new c454a8e7 feat: filter out MySQL federated components' emitted 
statements (#1439)
c454a8e7 is described below

commit c454a8e772a4d9424ff03defb8bb35f31745f758
Author: chunshao.rcs <[email protected]>
AuthorDate: Tue Jan 16 14:03:19 2024 +0800

    feat: filter out MySQL federated components' emitted statements (#1439)
    
    ## Rationale
    Some MySQL syntax is not supported, such as `show variables`, and errors
    will be returned to users.
    
    ## Detailed Changes
    * Introduce `server/src/federated.rs` to filter some mysql syntax and
    return success. `server/src/federated.rs` inspired by
    
[greptime](https://github.com/GreptimeTeam/greptimedb/blob/702ea32538a99e2d163fb1fbd3e75b1ce4ec4232/src/servers/src/mysql/federated.rs).
    Supports the following SQL statements.
    ```
    select @@version_comment
    select version()
    select 1
    ```
    * Modify `scripts/license-header.txt`.
    
    ## Test Plan
    Added some tests and passed CI.
    
    ---------
    
    Co-authored-by: Yingwen <[email protected]>
---
 Cargo.lock                  |  45 ++++-
 Cargo.toml                  |   6 +-
 scripts/license-header.txt  |  32 ++--
 server/Cargo.toml           |   4 +
 server/src/federated.rs     | 417 ++++++++++++++++++++++++++++++++++++++++++++
 server/src/lib.rs           |   2 +
 server/src/mysql/service.rs |   4 +-
 server/src/mysql/worker.rs  |  54 +++---
 server/src/session.rs       | 223 +++++++++++++++++++++++
 9 files changed, 739 insertions(+), 48 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7118256d..9e29407d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2233,7 +2233,16 @@ version = "0.11.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "d07adf7be193b71cc36b193d0f5fe60b918a3a9db4dad0449f57bcfd519704a3"
 dependencies = [
- "derive_builder_macro",
+ "derive_builder_macro 0.11.2",
+]
+
+[[package]]
+name = "derive_builder"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
+dependencies = [
+ "derive_builder_macro 0.12.0",
 ]
 
 [[package]]
@@ -2248,13 +2257,35 @@ dependencies = [
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "derive_builder_core"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
+dependencies = [
+ "darling 0.14.4",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "derive_builder_macro"
 version = "0.11.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "8f0314b72bed045f3a68671b3c86328386762c93f82d98c65c3cb5e5f573dd68"
 dependencies = [
- "derive_builder_core",
+ "derive_builder_core 0.11.2",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "derive_builder_macro"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
+dependencies = [
+ "derive_builder_core 0.12.0",
  "syn 1.0.109",
 ]
 
@@ -4517,9 +4548,9 @@ dependencies = [
 
 [[package]]
 name = "once_cell"
-version = "1.17.1"
+version = "1.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
+checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
 dependencies = [
  "parking_lot_core 0.9.7",
 ]
@@ -6313,6 +6344,7 @@ name = "server"
 version = "1.2.6-alpha"
 dependencies = [
  "analytic_engine",
+ "arc-swap 1.6.0",
  "arrow 43.0.0",
  "arrow_ext",
  "async-trait",
@@ -6322,6 +6354,7 @@ dependencies = [
  "cluster",
  "common_types",
  "datafusion",
+ "derive_builder 0.12.0",
  "df_operator",
  "flate2",
  "future_ext",
@@ -6336,6 +6369,7 @@ dependencies = [
  "macros",
  "meta_client",
  "notifier",
+ "once_cell",
  "opensrv-mysql",
  "partition_table_engine",
  "paste 1.0.12",
@@ -6348,6 +6382,7 @@ dependencies = [
  "proxy",
  "query_engine",
  "query_frontend",
+ "regex",
  "remote_engine_client",
  "router",
  "runtime",
@@ -6664,7 +6699,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index";
 checksum = "0860f149718809371602b42573693e1ed2b1d0aed35fe69e04e4e4e9918d81f7"
 dependencies = [
  "async-trait",
- "derive_builder",
+ "derive_builder 0.11.2",
  "prettydiff",
  "regex",
  "thiserror",
diff --git a/Cargo.toml b/Cargo.toml
index fe02d11e..24a9b13a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -108,6 +108,7 @@ horaedb-client = "1.0.2"
 common_types = { path = "common_types" }
 datafusion = { git = "https://github.com/CeresDB/arrow-datafusion.git";, rev = 
"9c3a537e25e5ab3299922864034f67fb2f79805d" }
 datafusion-proto = { git = "https://github.com/CeresDB/arrow-datafusion.git";, 
rev = "9c3a537e25e5ab3299922864034f67fb2f79805d" }
+derive_builder = "0.12"
 df_operator = { path = "df_operator" }
 df_engine_extensions = { path = "df_engine_extensions" }
 future_ext = { path = "components/future_ext" }
@@ -135,6 +136,7 @@ meta_client = { path = "meta_client" }
 metric_ext = { path = "components/metric_ext" }
 notifier = { path = "components/notifier" }
 object_store = { path = "components/object_store" }
+once_cell = "1.18"
 panic_ext = { path = "components/panic_ext" }
 partitioned_lock = { path = "components/partitioned_lock" }
 partition_table_engine = { path = "partition_table_engine" }
@@ -152,6 +154,7 @@ proxy = { path = "proxy" }
 query_engine = { path = "query_engine" }
 query_frontend = { path = "query_frontend" }
 rand = "0.7"
+regex = "1"
 remote_engine_client = { path = "remote_engine_client" }
 reqwest = { version = "0.11", default-features = false, features = [
     "rustls-tls",
@@ -184,11 +187,10 @@ trace_metric_derive = { path = 
"components/trace_metric_derive" }
 trace_metric_derive_tests = { path = "components/trace_metric_derive_tests" }
 tonic = "0.8.1"
 tokio = { version = "1.29", features = ["full"] }
+uuid = "1.6.1"
 wal = { path = "src/wal" }
 xorfilter-rs = { git = "https://github.com/CeresDB/xorfilter";, rev = "ac8ef01" 
}
 zstd = { version = "0.12", default-features = false }
-uuid = "1.6.1"
-regex = "1"
 
 # This profile optimizes for good runtime performance.
 [profile.release]
diff --git a/scripts/license-header.txt b/scripts/license-header.txt
index d216be4d..90705e02 100644
--- a/scripts/license-header.txt
+++ b/scripts/license-header.txt
@@ -1,16 +1,16 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
\ No newline at end of file
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
\ No newline at end of file
diff --git a/server/Cargo.toml b/server/Cargo.toml
index 7dd26bd1..19b0ddd1 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -32,6 +32,7 @@ workspace = true
 
 [dependencies]
 analytic_engine = { workspace = true }
+arc-swap = "1.5"
 arrow = { workspace = true }
 arrow_ext = { workspace = true }
 async-trait = { workspace = true }
@@ -41,6 +42,7 @@ clru = { workspace = true }
 cluster = { workspace = true }
 common_types = { workspace = true }
 datafusion = { workspace = true }
+derive_builder = { workspace = true }
 df_operator = { workspace = true }
 flate2 = "1.0"
 future_ext = { workspace = true }
@@ -55,6 +57,7 @@ logger = { workspace = true }
 macros = { workspace = true }
 meta_client = { workspace = true }
 notifier = { workspace = true }
+once_cell = { workspace = true }
 opensrv-mysql = "0.1.0"
 partition_table_engine = { workspace = true }
 paste = { workspace = true }
@@ -67,6 +70,7 @@ prost = { workspace = true }
 proxy = { workspace = true }
 query_engine = { workspace = true }
 query_frontend = { workspace = true }
+regex = { workspace = true }
 remote_engine_client = { workspace = true }
 router = { workspace = true }
 runtime = { workspace = true }
diff --git a/server/src/federated.rs b/server/src/federated.rs
new file mode 100644
index 00000000..f1636445
--- /dev/null
+++ b/server/src/federated.rs
@@ -0,0 +1,417 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Forked from 
https://github.com/GreptimeTeam/greptimedb/blob/702ea32538a99e2d163fb1fbd3e75b1ce4ec4232/src/servers/src/mysql/federated.rs.
+
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Use regex to filter out some MySQL federated components' emitted 
statements.
+//! Inspired by Databend's 
"[mysql_federated.rs](https://github.com/datafuselabs/databend/blob/ac706bf65845e6895141c96c0a10bad6fdc2d367/src/query/service/src/servers/mysql/mysql_federated.rs)".
+
+use std::{collections::HashMap, env, sync::Arc};
+
+use arrow::{
+    array::StringArray,
+    datatypes::{DataType, Field, Schema},
+    record_batch::RecordBatch,
+};
+use interpreters::{interpreter::Output, RecordBatchVec};
+use once_cell::sync::Lazy;
+use regex::{bytes::RegexSet, Regex};
+
+use crate::session::SessionRef;
+
+static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT 
@@(.*))").unwrap());
+static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap());
+static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 
'lower_case_table_names'(.*))").unwrap());
+static SHOW_COLLATION_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)^(show collation where(.*))").unwrap());
+static SHOW_VARIABLES_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES(.*))").unwrap());
+
+static SELECT_VERSION_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new(r"(?i)^(SELECT VERSION\(\s*\))").unwrap());
+static SELECT_DATABASE_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new(r"(?i)^(SELECT DATABASE\(\s*\))").unwrap());
+
+// SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP());
+static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)^(SELECT TIMEDIFF\\(NOW\\(\\), 
UTC_TIMESTAMP\\(\\)\\))").unwrap());
+
+// sqlalchemy < 1.4.30
+static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
+    Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 
'sql_mode'(.*))").unwrap());
+
+static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
+    RegexSet::new([
+        // Txn.
+        "(?i)^(ROLLBACK(.*))",
+        "(?i)^(COMMIT(.*))",
+        "(?i)^(START(.*))",
+
+        // Set.
+        "(?i)^(SET NAMES(.*))",
+        "(?i)^(SET character_set_results(.*))",
+        "(?i)^(SET net_write_timeout(.*))",
+        "(?i)^(SET FOREIGN_KEY_CHECKS(.*))",
+        "(?i)^(SET AUTOCOMMIT(.*))",
+        "(?i)^(SET SQL_LOG_BIN(.*))",
+        "(?i)^(SET sql_mode(.*))",
+        "(?i)^(SET SQL_SELECT_LIMIT(.*))",
+        "(?i)^(SET @@(.*))",
+
+        "(?i)^(SHOW COLLATION)",
+        "(?i)^(SHOW CHARSET)",
+
+        // mysqlclient.
+        "(?i)^(SELECT \\$\\$)",
+
+        // mysqldump.
+        "(?i)^(SET SESSION(.*))",
+        "(?i)^(SET SQL_QUOTE_SHOW_CREATE(.*))",
+        "(?i)^(LOCK TABLES(.*))",
+        "(?i)^(UNLOCK TABLES(.*))",
+        "(?i)^(SELECT LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, 
INITIAL_SIZE, ENGINE, EXTRA FROM INFORMATION_SCHEMA.FILES(.*))",
+
+        // mydumper.
+        "(?i)^(/\\*!80003 SET(.*) \\*/)$",
+        "(?i)^(SHOW MASTER STATUS)",
+        "(?i)^(SHOW ALL SLAVES STATUS)",
+        "(?i)^(LOCK BINLOG FOR BACKUP)",
+        "(?i)^(LOCK TABLES FOR BACKUP)",
+        "(?i)^(UNLOCK BINLOG(.*))",
+        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
+
+        // DBeaver.
+        "(?i)^(SHOW WARNINGS)",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW WARNINGS)",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW PLUGINS)",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW COLLATION)",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW CHARSET)",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW ENGINES)",
+        "(?i)^(/\\* ApplicationName=(.*)SELECT @@(.*))",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW @@(.*))",
+        "(?i)^(/\\* ApplicationName=(.*)SET net_write_timeout(.*))",
+        "(?i)^(/\\* ApplicationName=(.*)SET SQL_SELECT_LIMIT(.*))",
+        "(?i)^(/\\* ApplicationName=(.*)SHOW VARIABLES(.*))",
+
+        // pt-toolkit
+        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
+
+        // mysqldump 5.7.16
+        "(?i)^(/\\*!40100 SET(.*) \\*/)$",
+        "(?i)^(/\\*!40103 SET(.*) \\*/)$",
+        "(?i)^(/\\*!40111 SET(.*) \\*/)$",
+        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
+        "(?i)^(/\\*!40014 SET(.*) \\*/)$",
+        "(?i)^(/\\*!40000 SET(.*) \\*/)$",
+    ]).unwrap()
+});
+
+static VAR_VALUES: Lazy<HashMap<&str, &str>> =
+    Lazy::new(|| HashMap::from([("version_comment", "Apache HoraeDB")]));
+
+// RecordBatchVec for select function.
+// Format:
+// |function_name|
+// |value|
+fn select_function(name: &str, value: &str) -> RecordBatchVec {
+    let schema = Schema::new(vec![Field::new(name, DataType::Utf8, false)]);
+
+    let arrow_record_batch = RecordBatch::try_new(
+        Arc::new(schema),
+        vec![Arc::new(StringArray::from(vec![value]))],
+    )
+    .unwrap();
+
+    let record_batch = arrow_record_batch.try_into().unwrap();
+
+    vec![record_batch]
+}
+
+// RecordbatchVec for show variable statement.
+// Format is:
+// | Variable_name | Value |
+// | xx            | yy    |
+fn show_variables(name: &str, value: &str) -> RecordBatchVec {
+    let schema = Schema::new(vec![
+        Field::new("Variable_name", DataType::Utf8, false),
+        Field::new("Value", DataType::Utf8, false),
+    ]);
+
+    let arrow_record_batch = RecordBatch::try_new(
+        Arc::new(schema),
+        vec![
+            Arc::new(StringArray::from(vec![name])),
+            Arc::new(StringArray::from(vec![value])),
+        ],
+    )
+    .unwrap();
+
+    let record_batch = arrow_record_batch.try_into().unwrap();
+
+    vec![record_batch]
+}
+
+fn select_variable(query: &str, _session: SessionRef) -> Option<Output> {
+    let mut fields: Vec<Field> = vec![];
+    let mut values: Vec<Arc<(dyn arrow::array::Array + 'static)>> = vec![];
+
+    // query like "SELECT @@aa, @@bb as cc, @dd..."
+    let query = query.to_lowercase();
+    let vars: Vec<&str> = query.split("@@").collect();
+    if vars.len() <= 1 {
+        return None;
+    }
+
+    // skip the first "select"
+    for var in vars.iter().skip(1) {
+        let var = var.trim_matches(|c| c == ' ' || c == ',');
+        let var_as: Vec<&str> = var
+            .split(" as ")
+            .map(|x| {
+                x.trim_matches(|c| c == ' ')
+                    .split_whitespace()
+                    .next()
+                    .unwrap_or("")
+            })
+            .collect();
+
+        // get value of variables from known sources or fallback to defaults
+        let value = VAR_VALUES
+            .get(var_as[0])
+            .map(|v| v.to_string())
+            .unwrap_or_else(|| "0".to_owned());
+
+        values.push(Arc::new(StringArray::from(vec![value])));
+
+        match var_as.len() {
+            1 => {
+                // @@aa
+                // field is '@@aa'
+                fields.push(Field::new(
+                    &format!("@@{}", var_as[0]),
+                    DataType::Utf8,
+                    true,
+                ));
+            }
+            2 => {
+                // @@bb as cc:
+                // var is 'bb'.
+                // field is 'cc'.
+                fields.push(Field::new(var_as[1], DataType::Utf8, true));
+            }
+            _ => return None,
+        }
+    }
+
+    let schema = Schema::new(fields);
+    let arrow_record_batch = RecordBatch::try_new(Arc::new(schema), 
values).unwrap();
+
+    let record_batch = arrow_record_batch.try_into().unwrap();
+
+    Some(Output::Records(vec![record_batch]))
+}
+
+fn check_select_variable(query: &str, session: SessionRef) -> Option<Output> {
+    if [&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN]
+        .iter()
+        .any(|r| r.is_match(query))
+    {
+        select_variable(query, session)
+    } else {
+        None
+    }
+}
+
+fn check_show_variables(query: &str) -> Option<Output> {
+    let record_batch_vec = if SHOW_SQL_MODE_PATTERN.is_match(query) {
+        Some(show_variables("sql_mode", "ONLY_FULL_GROUP_BY 
STRICT_TRANS_TABLES NO_ZERO_IN_DATE NO_ZERO_DATE ERROR_FOR_DIVISION_BY_ZERO 
NO_ENGINE_SUBSTITUTION"))
+    } else if SHOW_LOWER_CASE_PATTERN.is_match(query) {
+        Some(show_variables("lower_case_table_names", "0"))
+    } else if SHOW_COLLATION_PATTERN.is_match(query) || 
SHOW_VARIABLES_PATTERN.is_match(query) {
+        Some(show_variables("", ""))
+    } else {
+        None
+    };
+    record_batch_vec.map(Output::Records)
+}
+
+// TODO(sunng87): extract this to use sqlparser for more variables
+fn check_set_variables(_query: &str, _session: SessionRef) -> Option<Output> {
+    None
+}
+
+// Check for SET or others query, this is the final check of the federated
+// query.
+fn check_others(query: &str, session: SessionRef) -> Option<Output> {
+    if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
+        return Some(Output::Records(Vec::new()));
+    }
+
+    let record_batch_vec = if SELECT_VERSION_PATTERN.is_match(query) {
+        Some(select_function("version()", &get_version()))
+    } else if SELECT_DATABASE_PATTERN.is_match(query) {
+        let schema = session.schema();
+        Some(select_function("database()", &schema))
+    } else if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
+        Some(select_function(
+            "TIMEDIFF(NOW(), UTC_TIMESTAMP())",
+            "00:00:00",
+        ))
+    } else {
+        None
+    };
+    record_batch_vec.map(Output::Records)
+}
+
+// Check whether the query is a federated or driver setup command,
+// and return some faked results if there are any.
+pub(crate) fn check(query: &str, session: SessionRef) -> Option<Output> {
+    // INSERT don't need MySQL federated check. We assume the query doesn't 
contain
+    // federated or driver setup command if it starts with a 'INSERT' 
statement.
+    if query.len() > 6 && query[..6].eq_ignore_ascii_case("INSERT") {
+        return None;
+    }
+
+    // First to check the query is like "select @@variables".
+    check_select_variable(query, session.clone())
+        // Then to check "show variables like ...".
+        .or_else(|| check_show_variables(query))
+        .or_else(|| check_set_variables(query, session.clone()))
+        // Last check
+        .or_else(|| check_others(query, session))
+}
+
+// get HoraeDB's version.
+fn get_version() -> String {
+    format!("{}-horaedb", env!("CARGO_PKG_VERSION"))
+}
+#[cfg(test)]
+mod test {
+    use arrow::util::pretty;
+
+    use super::*;
+    use crate::session::{Channel, Session};
+    fn pretty_print(data: RecordBatchVec) -> String {
+        let df_batches = &data
+            .iter()
+            .map(|x| x.as_arrow_record_batch().clone())
+            .collect::<Vec<_>>();
+        let result = pretty::pretty_format_batches(df_batches).unwrap();
+
+        result.to_string()
+    }
+    #[test]
+    fn test_check() {
+        let session = Arc::new(Session::new(None, Channel::Mysql));
+        let query = "select 1";
+        let result = check(query, session.clone());
+        assert!(result.is_none());
+
+        let query = "select version";
+        let output = check(query, session.clone());
+        assert!(output.is_none());
+
+        fn test(query: &str, expected: &str) {
+            let session = Arc::new(Session::new(None, Channel::Mysql));
+            let output = check(query, session.clone());
+            match output.unwrap() {
+                Output::Records(r) => {
+                    assert_eq!(pretty_print(r), expected)
+                }
+                _ => unreachable!(),
+            }
+        }
+
+        let query = "select version()";
+        let version = env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| 
"unknown".to_string());
+        let output = check(query, session.clone());
+        match output.unwrap() {
+            Output::Records(r) => {
+                
assert!(pretty_print(r).contains(&format!("{version}-horaedb")));
+            }
+            _ => unreachable!(),
+        }
+
+        let query = "SELECT @@version_comment LIMIT 1";
+        let expected = "\
++-------------------+
+| @@version_comment |
++-------------------+
+| Apache HoraeDB    |
++-------------------+";
+        test(query, expected);
+
+        // complex variables
+        let query = "/* mysql-connector-java-8.0.17 (Revision: 
16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT  
@@session.auto_increment_increment AS auto_increment_increment, 
@@character_set_client AS character_set_client, @@character_set_connection AS 
character_set_connection, @@character_set_results AS character_set_results, 
@@character_set_server AS character_set_server, @@collation_server AS 
collation_server, @@collation_connection AS collation_connection, 
@@init_connect AS init_ [...]
+        let expected = "\
++--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+-----------------------+---------------+
+| auto_increment_increment | character_set_client | character_set_connection | 
character_set_results | character_set_server | collation_server | 
collation_connection | init_connect | interactive_timeout | license | 
lower_case_table_names | max_allowed_packet | net_write_timeout | 
performance_schema | sql_mode | transaction_isolation | wait_timeout; |
++--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+-----------------------+---------------+
+| 0                        | 0                    | 0                        | 
0                     | 0                    | 0                | 0             
       | 0            | 0                   | 0       | 0                      
| 0                  | 0                 | 0                  | 0        | 0    
                 | 0             |
++--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+-----------------------+---------------+";
+        test(query, expected);
+
+        let query = "show variables";
+        let expected = "\
++---------------+-------+
+| Variable_name | Value |
++---------------+-------+
+|               |       |
++---------------+-------+";
+        test(query, expected);
+
+        let query = "show variables like 'lower_case_table_names'";
+        let expected = "\
++------------------------+-------+
+| Variable_name          | Value |
++------------------------+-------+
+| lower_case_table_names | 0     |
++------------------------+-------+";
+        test(query, expected);
+
+        let query = "show collation";
+        let expected = "\
+++
+++"; // empty
+        test(query, expected);
+
+        let query = "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())";
+        let expected = "\
++----------------------------------+
+| TIMEDIFF(NOW(), UTC_TIMESTAMP()) |
++----------------------------------+
+| 00:00:00                         |
++----------------------------------+";
+        test(query, expected);
+    }
+}
diff --git a/server/src/lib.rs b/server/src/lib.rs
index d50af7e7..6cf4b33a 100644
--- a/server/src/lib.rs
+++ b/server/src/lib.rs
@@ -23,6 +23,7 @@
 pub mod config;
 mod consts;
 mod error_util;
+mod federated;
 mod grpc;
 mod http;
 pub mod local_tables;
@@ -30,3 +31,4 @@ mod metrics;
 mod mysql;
 mod postgresql;
 pub mod server;
+mod session;
diff --git a/server/src/mysql/service.rs b/server/src/mysql/service.rs
index b87e5c46..b5e2d2af 100644
--- a/server/src/mysql/service.rs
+++ b/server/src/mysql/service.rs
@@ -93,7 +93,7 @@ impl MysqlService {
         loop {
             tokio::select! {
                 conn_result = listener.accept() => {
-                    let (stream, _) = match conn_result {
+                    let (stream, addr) = match conn_result {
                         Ok((s, addr)) => (s, addr),
                         Err(err) => {
                             error!("Mysql Server accept new connection fail. 
err: {}", err);
@@ -104,7 +104,7 @@ impl MysqlService {
 
                     let rt = runtimes.read_runtime.clone();
                     rt.spawn(AsyncMysqlIntermediary::run_on(
-                        MysqlWorker::new(proxy,  timeout),
+                        MysqlWorker::new(proxy, addr, timeout),
                         stream,
                     ));
                 },
diff --git a/server/src/mysql/worker.rs b/server/src/mysql/worker.rs
index a704c967..b25e756b 100644
--- a/server/src/mysql/worker.rs
+++ b/server/src/mysql/worker.rs
@@ -15,23 +15,30 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{marker::PhantomData, sync::Arc, time::Duration};
+use std::{marker::PhantomData, net::SocketAddr, sync::Arc, time::Duration};
 
 use generic_error::BoxError;
 use interpreters::interpreter::Output;
 use logger::{error, info};
-use opensrv_mysql::{AsyncMysqlShim, ErrorKind, QueryResultWriter, 
StatementMetaWriter};
+use opensrv_mysql::{
+    AsyncMysqlShim, ErrorKind, InitWriter, QueryResultWriter, 
StatementMetaWriter,
+};
 use proxy::{context::RequestContext, http::sql::Request, Proxy};
 use snafu::ResultExt;
 
-use crate::mysql::{
-    error::{CreateContext, HandleSql, Result},
-    writer::MysqlQueryResultWriter,
+use crate::{
+    federated,
+    mysql::{
+        error::{CreateContext, HandleSql, Result},
+        writer::MysqlQueryResultWriter,
+    },
+    session::{parse_catalog_and_schema_from_db_string, Channel, Session, 
SessionRef},
 };
 
 pub struct MysqlWorker<W: std::io::Write + Send + Sync> {
     generic_hold: PhantomData<W>,
     proxy: Arc<Proxy>,
+    session: SessionRef,
     timeout: Option<Duration>,
 }
 
@@ -39,10 +46,11 @@ impl<W> MysqlWorker<W>
 where
     W: std::io::Write + Send + Sync,
 {
-    pub fn new(proxy: Arc<Proxy>, timeout: Option<Duration>) -> Self {
+    pub fn new(proxy: Arc<Proxy>, add: SocketAddr, timeout: Option<Duration>) 
-> Self {
         Self {
             generic_hold: PhantomData,
             proxy,
+            session: Arc::new(Session::new(Some(add), Channel::Mysql)),
             timeout,
         }
     }
@@ -102,6 +110,15 @@ where
             }
         }
     }
+
+    async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, 
W>) -> Result<()> {
+        let (catalog, schema) = 
parse_catalog_and_schema_from_db_string(database);
+
+        self.session.set_catalog(catalog.into());
+        self.session.set_schema(schema.into());
+
+        w.ok().map_err(|e| e.into())
+    }
 }
 
 impl<W> MysqlWorker<W>
@@ -109,10 +126,14 @@ where
     W: std::io::Write + Send + Sync,
 {
     async fn do_query<'a>(&'a mut self, sql: &'a str) -> Result<Output> {
-        let ctx = self.create_ctx()?;
+        if let Some(output) = federated::check(sql, self.session.clone()) {
+            return Ok(output);
+        }
+
         let req = Request {
             query: sql.to_string(),
         };
+        let ctx = self.create_ctx(self.session.clone())?;
         self.proxy
             .handle_http_sql_query(&ctx, req)
             .await
@@ -126,23 +147,10 @@ where
             })
     }
 
-    fn create_ctx(&self) -> Result<RequestContext> {
-        let default_catalog = self
-            .proxy
-            .instance()
-            .catalog_manager
-            .default_catalog_name()
-            .to_string();
-        let default_schema = self
-            .proxy
-            .instance()
-            .catalog_manager
-            .default_schema_name()
-            .to_string();
-
+    fn create_ctx(&self, session: SessionRef) -> Result<RequestContext> {
         RequestContext::builder()
-            .catalog(default_catalog)
-            .schema(default_schema)
+            .catalog(session.catalog().to_string())
+            .schema(session.schema().to_string())
             .timeout(self.timeout)
             .build()
             .context(CreateContext)
diff --git a/server/src/session.rs b/server/src/session.rs
new file mode 100644
index 00000000..42733f40
--- /dev/null
+++ b/server/src/session.rs
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Forked from 
https://github.com/GreptimeTeam/greptimedb/blob/ca4d690424b03806ea0f8bd5e491585224bbf220/src/session/src/lib.rs
+
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    fmt::{Display, Formatter},
+    net::SocketAddr,
+    sync::Arc,
+};
+
+use arc_swap::ArcSwap;
+use catalog::consts::{DEFAULT_CATALOG, DEFAULT_SCHEMA};
+use sqlparser::dialect::{Dialect, MySqlDialect, PostgreSqlDialect};
+
+/// Session for persistent connection such as MySQL, PostgreSQL etc.
+#[derive(Debug)]
+pub struct Session {
+    catalog: ArcSwap<String>,
+    schema: ArcSwap<String>,
+    conn_info: ConnInfo,
+}
+
+pub type SessionRef = Arc<Session>;
+
+impl Session {
+    pub fn new(addr: Option<SocketAddr>, channel: Channel) -> Self {
+        Session {
+            catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG.into())),
+            schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA.into())),
+            conn_info: ConnInfo::new(addr, channel),
+        }
+    }
+
+    #[inline]
+    #[allow(dead_code)]
+    pub fn conn_info(&self) -> &ConnInfo {
+        &self.conn_info
+    }
+
+    #[inline]
+    pub fn catalog(&self) -> String {
+        self.catalog.load().to_string()
+    }
+
+    #[inline]
+    pub fn schema(&self) -> String {
+        self.schema.load().to_string()
+    }
+
+    #[inline]
+    pub fn set_catalog(&self, catalog: String) {
+        self.catalog.store(Arc::new(catalog));
+    }
+
+    #[inline]
+    pub fn set_schema(&self, schema: String) {
+        self.schema.store(Arc::new(schema));
+    }
+}
+
+#[derive(Debug)]
+pub struct ConnInfo {
+    pub client_addr: Option<SocketAddr>,
+    pub channel: Channel,
+}
+
+impl Display for ConnInfo {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "{}[{}]",
+            self.channel,
+            self.client_addr
+                .map(|addr| addr.to_string())
+                .as_deref()
+                .unwrap_or("unknown client addr")
+        )
+    }
+}
+
+impl ConnInfo {
+    pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
+        Self {
+            client_addr,
+            channel,
+        }
+    }
+}
+
+#[derive(Debug, PartialEq)]
+#[allow(dead_code)]
+pub enum Channel {
+    Mysql,
+    Postgres,
+}
+
+impl Channel {
+    #[allow(dead_code)]
+    pub fn dialect(&self) -> Box<dyn Dialect + Send + Sync> {
+        match self {
+            Channel::Mysql => Box::new(MySqlDialect {}),
+            Channel::Postgres => Box::new(PostgreSqlDialect {}),
+        }
+    }
+}
+
+impl Display for Channel {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        match self {
+            Channel::Mysql => write!(f, "mysql"),
+            Channel::Postgres => write!(f, "postgres"),
+        }
+    }
+}
+
+/// Attempt to parse catalog and schema from given database name
+///
+/// The database name may come from different sources:
+///
+/// - MySQL `schema` name in MySQL protocol login request: it's optional and
+///   user
+/// and switch database using `USE` command
+/// - Postgres `database` parameter in Postgres wire protocol, required
+/// - HTTP RESTful API: the database parameter, optional
+/// - gRPC: the dbname field in header, optional but has a higher priority than
+/// original catalog/schema
+///
+/// When database name is provided, we attempt to parse catalog and schema from
+/// it. We assume the format `[<catalog>-]<schema>`:
+///
+/// - If `[<catalog>-]` part is not provided, we use whole database name as
+/// schema name
+/// - if `[<catalog>-]` is provided, we split database name with `-` and use
+/// `<catalog>` and `<schema>`.
+pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (&str, &str) {
+    let parts = db.splitn(2, '-').collect::<Vec<&str>>();
+    if parts.len() == 2 {
+        (parts[0], parts[1])
+    } else {
+        (DEFAULT_CATALOG, db)
+    }
+}
+
#[cfg(test)]
mod tests {
    // `use super::*` already brings `Session`, `Channel`, `DEFAULT_CATALOG`
    // and the parse helper into scope; the former extra
    // `use crate::session::Session;` was redundant and has been removed.
    use super::*;

    /// Build db name from catalog and schema string; inverse of
    /// `parse_catalog_and_schema_from_db_string` for the default catalog.
    fn build_db_string(catalog: &str, schema: &str) -> String {
        if catalog == DEFAULT_CATALOG {
            schema.to_string()
        } else {
            format!("{catalog}-{schema}")
        }
    }

    #[test]
    fn test_db_string() {
        assert_eq!("test", build_db_string(DEFAULT_CATALOG, "test"));
        assert_eq!("a0b1c2d3-test", build_db_string("a0b1c2d3", "test"));
    }

    #[test]
    fn test_parse_catalog_and_schema() {
        // No separator: whole string is the schema, default catalog applies.
        assert_eq!(
            (DEFAULT_CATALOG, "fullschema"),
            parse_catalog_and_schema_from_db_string("fullschema")
        );

        assert_eq!(
            ("catalog", "schema"),
            parse_catalog_and_schema_from_db_string("catalog-schema")
        );

        // Only the first `-` splits; the remainder stays in the schema part.
        assert_eq!(
            ("catalog", "schema1-schema2"),
            parse_catalog_and_schema_from_db_string("catalog-schema1-schema2")
        );
    }

    #[test]
    fn test_session() {
        let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql);

        // Channel and client address are recorded in the connection info.
        assert_eq!(session.conn_info().channel, Channel::Mysql);
        let client_addr = session.conn_info().client_addr.as_ref().unwrap();
        assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
        assert_eq!(client_addr.port(), 9000);

        assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
    }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to