This is an automated email from the ASF dual-hosted git repository. HTHou pushed a commit to branch codex/add-tls-mtls-support in repository https://gitbox.apache.org/repos/asf/iotdb-client-go.git
commit f105c2668fc76ea0aea3fb26df7fc8ba45db728e Author: HTHou <[email protected]> AuthorDate: Fri Jun 26 17:13:08 2026 +0800 Add TLS and mTLS support --- README.md | 23 ++++++++ README_ZH.md | 23 ++++++++ client/session.go | 52 ++++++++---------- client/sessionpool.go | 3 + client/tls.go | 120 ++++++++++++++++++++++++++++++++++++++++ client/tls_test.go | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 340 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index a963a8e..2045b53 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,29 @@ curl -o session_example.go -L https://github.com/apache/iotdb-client-go/raw/main go run session_example.go ``` +## TLS/mTLS + +Set `TLSConfig` on `client.Config`, `client.ClusterConfig`, or `client.PoolConfig` to enable TLS. Set `CAFile` to a PEM-encoded CA certificate used to verify the server; it is added to `TLSConfig.Config.RootCAs` when that is set, otherwise to the system root pool (when available). Set `CertFile` and `KeyFile` together when the server requires mTLS client authentication. + +```golang +config := &client.Config{ + Host: host, + Port: port, + UserName: user, + Password: password, + TLSConfig: &client.TLSConfig{ + CAFile: "/path/to/ca.pem", + CertFile: "/path/to/client.pem", + KeyFile: "/path/to/client-key.pem", + }, +} +session := client.NewSession(config) +if err := session.Open(false, 0); err != nil { + log.Fatal(err) +} +defer session.Close() +``` + ## How to Use the SessionPool SessionPool is a wrapper of a Session Set. Using SessionPool, the user do not need to consider how to reuse a session connection. 
diff --git a/README_ZH.md b/README_ZH.md index e856172..e680123 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -76,6 +76,29 @@ curl -o session_example.go -L https://github.com/apache/iotdb-client-go/raw/main go run session_example.go ``` +## TLS/mTLS + +在 `client.Config`、`client.ClusterConfig` 或 `client.PoolConfig` 上设置 `TLSConfig` 即可启用 TLS。`CAFile` 指定用于校验服务端证书的 PEM 格式 CA 证书:若设置了 `TLSConfig.Config.RootCAs`,则追加到其中,否则追加到系统根证书池(若可用)。如果服务端要求 mTLS 客户端认证,需同时设置 `CertFile` 和 `KeyFile`。 + +```golang +config := &client.Config{ + Host: host, + Port: port, + UserName: user, + Password: password, + TLSConfig: &client.TLSConfig{ + CAFile: "/path/to/ca.pem", + CertFile: "/path/to/client.pem", + KeyFile: "/path/to/client-key.pem", + }, +} +session := client.NewSession(config) +if err := session.Open(false, 0); err != nil { + log.Fatal(err) +} +defer session.Close() +``` + ## SessionPool 通过SessionPool管理session,用户不需要考虑如何重用session,当到达pool的最大值时,获取session的请求会阻塞 注意:session使用完成后需要调用PutBack方法 diff --git a/client/session.go b/client/session.go index 28b326e..7ac75f1 100644 --- a/client/session.go +++ b/client/session.go @@ -26,7 +26,6 @@ import ( "errors" "fmt" "log" - "net" "reflect" "sort" "strings" @@ -68,6 +67,7 @@ type Config struct { sqlDialect string Version Version Database string + TLSConfig *TLSConfig } type Session struct { @@ -100,13 +100,10 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err var err error - // In thrift 0.14.1, this func returns two values; in newer versions, it returns one.
- s.trans = thrift.NewTSocketConf(net.JoinHostPort(s.config.Host, s.config.Port), &thrift.TConfiguration{ - ConnectTimeout: time.Duration(connectionTimeoutInMs) * time.Millisecond, // Use 0 for no timeout - }) - // s.trans = thrift.NewTFramedTransport(s.trans) // deprecated - tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE} - s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf) + s.trans, err = newTransport(s.config.Host, s.config.Port, connectionTimeoutInMs, s.config.TLSConfig) + if err != nil { + return err + } if !s.trans.IsOpen() { err = s.trans.Open() if err != nil { @@ -154,6 +151,7 @@ type ClusterConfig struct { ConnectRetryMax int sqlDialect string Database string + TLSConfig *TLSConfig } func (s *Session) OpenCluster(enableRPCCompression bool) error { @@ -1328,24 +1326,23 @@ func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) (Session, err var err error for i := range session.endPointList { ep := session.endPointList[i] - session.trans = thrift.NewTSocketConf(net.JoinHostPort(ep.Host, ep.Port), &thrift.TConfiguration{ - ConnectTimeout: time.Duration(0), // Use 0 for no timeout - }) - // session.trans = thrift.NewTFramedTransport(session.trans) // deprecated - tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE} - session.trans = thrift.NewTFramedTransportConf(session.trans, &tmp_conf) + session.trans, err = newTransport(ep.Host, ep.Port, 0, clusterConfig.TLSConfig) + if err != nil { + log.Println(err) + continue + } if !session.trans.IsOpen() { err = session.trans.Open() if err != nil { log.Println(err) } else { session.config = getConfig(ep.Host, ep.Port, - clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect) + clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax, 
clusterConfig.Database, clusterConfig.sqlDialect, clusterConfig.TLSConfig) break } } } - if !session.trans.IsOpen() { + if session.trans == nil || !session.trans.IsOpen() { return session, fmt.Errorf("no server can connect") } return session, nil @@ -1354,18 +1351,14 @@ func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) (Session, err func (s *Session) initClusterConn(node endPoint) error { var err error - s.trans = thrift.NewTSocketConf(net.JoinHostPort(node.Host, node.Port), &thrift.TConfiguration{ - ConnectTimeout: time.Duration(0), // Use 0 for no timeout - }) - if err == nil { - // s.trans = thrift.NewTFramedTransport(s.trans) // deprecated - tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE} - s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf) - if !s.trans.IsOpen() { - err = s.trans.Open() - if err != nil { - return err - } + s.trans, err = newTransport(node.Host, node.Port, 0, s.config.TLSConfig) + if err != nil { + return err + } + if !s.trans.IsOpen() { + err = s.trans.Open() + if err != nil { + return err } } @@ -1398,7 +1391,7 @@ func (s *Session) initClusterConn(node endPoint) error { return err } -func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string) *Config { +func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string, tlsConfig *TLSConfig) *Config { return &Config{ Host: host, Port: port, @@ -1409,6 +1402,7 @@ func getConfig(host string, port string, userName string, passWord string, fetch ConnectRetryMax: connectRetryMax, sqlDialect: sqlDialect, Database: database, + TLSConfig: tlsConfig, } } diff --git a/client/sessionpool.go b/client/sessionpool.go index 757b298..c481bf4 100644 --- a/client/sessionpool.go +++ b/client/sessionpool.go @@ -50,6 +50,7 @@ type PoolConfig struct { 
TimeZone string ConnectRetryMax int Database string + TLSConfig *TLSConfig sqlDialect string } @@ -146,6 +147,7 @@ func getSessionConfig(config *PoolConfig) *Config { ConnectRetryMax: config.ConnectRetryMax, sqlDialect: config.sqlDialect, Database: config.Database, + TLSConfig: config.TLSConfig, } } @@ -159,6 +161,7 @@ func getClusterSessionConfig(config *PoolConfig) *ClusterConfig { ConnectRetryMax: config.ConnectRetryMax, sqlDialect: config.sqlDialect, Database: config.Database, + TLSConfig: config.TLSConfig, } } diff --git a/client/tls.go b/client/tls.go new file mode 100644 index 0000000..9021746 --- /dev/null +++ b/client/tls.go @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package client + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + "time" + + "github.com/apache/thrift/lib/go/thrift" +) + +// TLSConfig enables TLS for an IoTDB client connection. Set CertFile and +// KeyFile together to enable mTLS client authentication. +type TLSConfig struct { + // Config is an optional base tls.Config. It is cloned before use. + Config *tls.Config + + // CAFile is an optional PEM encoded CA certificate file used to verify the server. 
+ CAFile string + + // CertFile and KeyFile are optional PEM encoded client certificate and key files for mTLS. + CertFile string + KeyFile string +} + +func newTransport(host string, port string, connectionTimeoutInMs int, tlsConfig *TLSConfig) (thrift.TTransport, error) { + conf := &thrift.TConfiguration{ + ConnectTimeout: time.Duration(connectionTimeoutInMs) * time.Millisecond, + MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE, + } + hostPort := net.JoinHostPort(host, port) + + var base thrift.TTransport + if tlsConfig == nil { + base = thrift.NewTSocketConf(hostPort, conf) + } else { + cfg, err := buildTLSConfig(tlsConfig) + if err != nil { + return nil, err + } + conf.TLSConfig = cfg + base = thrift.NewTSSLSocketConf(hostPort, conf) + } + + return thrift.NewTFramedTransportConf(base, conf), nil +} + +func buildTLSConfig(config *TLSConfig) (*tls.Config, error) { + if config == nil { + return nil, nil + } + + tlsConfig := &tls.Config{} + if config.Config != nil { + tlsConfig = config.Config.Clone() + } + if config.CAFile != "" { + rootCAs, err := loadCertPool(tlsConfig.RootCAs, config.CAFile) + if err != nil { + return nil, err + } + tlsConfig.RootCAs = rootCAs + } + if config.CertFile != "" || config.KeyFile != "" { + if config.CertFile == "" || config.KeyFile == "" { + return nil, fmt.Errorf("both TLS CertFile and KeyFile must be set") + } + certificate, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) + if err != nil { + return nil, fmt.Errorf("load TLS client certificate/key: %w", err) + } + tlsConfig.Certificates = append(tlsConfig.Certificates, certificate) + } + + return tlsConfig, nil +} + +func loadCertPool(base *x509.CertPool, caFile string) (*x509.CertPool, error) { + rootCAs := base + if rootCAs != nil { + rootCAs = rootCAs.Clone() + } else { + systemPool, err := x509.SystemCertPool() + if err == nil { + rootCAs = systemPool + } else { + rootCAs = x509.NewCertPool() + } + } + + caCert, err := os.ReadFile(caFile) + if err != nil { + return nil, 
fmt.Errorf("read TLS CA file %q: %w", caFile, err) + } + if !rootCAs.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("append TLS CA file %q: no certificates found", caFile) + } + return rootCAs, nil +} diff --git a/client/tls_test.go b/client/tls_test.go new file mode 100644 index 0000000..5d53ad3 --- /dev/null +++ b/client/tls_test.go @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package client + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +func TestBuildTLSConfigClonesBaseConfig(t *testing.T) { + base := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + cfg, err := buildTLSConfig(&TLSConfig{ + Config: base, + }) + if err != nil { + t.Fatal(err) + } + + if cfg == base { + t.Fatal("buildTLSConfig must clone the base tls.Config") + } + if cfg.MinVersion != tls.VersionTLS12 { + t.Fatalf("MinVersion = %d, want %d", cfg.MinVersion, tls.VersionTLS12) + } +} + +func TestBuildTLSConfigLoadsFiles(t *testing.T) { + caFile, certFile, keyFile := writeTLSFiles(t) + + cfg, err := buildTLSConfig(&TLSConfig{ + CAFile: caFile, + CertFile: certFile, + KeyFile: keyFile, + }) + if err != nil { + t.Fatal(err) + } + if cfg.RootCAs == nil { + t.Fatal("RootCAs should be set") + } + if len(cfg.Certificates) != 1 { + t.Fatalf("Certificates length = %d, want 1", len(cfg.Certificates)) + } +} + +func TestBuildTLSConfigRequiresCertAndKey(t *testing.T) { + _, err := buildTLSConfig(&TLSConfig{CertFile: "client.crt"}) + if err == nil { + t.Fatal("expected error when CertFile is set without KeyFile") + } +} + +func writeTLSFiles(t *testing.T) (caFile string, certFile string, keyFile string) { + t.Helper() + + dir := t.TempDir() + now := time.Now() + caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "iotdb-client-go-test-ca"}, + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + } + caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + t.Fatal(err) + } + + clientKey, err := 
ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + clientTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "iotdb-client-go-test-client"}, + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + clientDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey) + if err != nil { + t.Fatal(err) + } + clientKeyDER, err := x509.MarshalECPrivateKey(clientKey) + if err != nil { + t.Fatal(err) + } + + caFile = filepath.Join(dir, "ca.pem") + certFile = filepath.Join(dir, "client.pem") + keyFile = filepath.Join(dir, "client-key.pem") + writePEMFile(t, caFile, "CERTIFICATE", caDER) + writePEMFile(t, certFile, "CERTIFICATE", clientDER) + writePEMFile(t, keyFile, "EC PRIVATE KEY", clientKeyDER) + return caFile, certFile, keyFile +} + +func writePEMFile(t *testing.T, filename string, blockType string, der []byte) { + t.Helper() + + file, err := os.Create(filename) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + if err := pem.Encode(file, &pem.Block{Type: blockType, Bytes: der}); err != nil { + t.Fatal(err) + } +}
