This is an automated email from the ASF dual-hosted git repository.

HTHou pushed a commit to branch codex/add-tls-mtls-support
in repository https://gitbox.apache.org/repos/asf/iotdb-client-go.git

commit f105c2668fc76ea0aea3fb26df7fc8ba45db728e
Author: HTHou <[email protected]>
AuthorDate: Fri Jun 26 17:13:08 2026 +0800

    Add TLS and mTLS support
---
 README.md             |  23 ++++++++
 README_ZH.md          |  23 ++++++++
 client/session.go     |  52 ++++++++----------
 client/sessionpool.go |   3 +
 client/tls.go         | 120 ++++++++++++++++++++++++++++++++++++++++
 client/tls_test.go    | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 340 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index a963a8e..2045b53 100644
--- a/README.md
+++ b/README.md
@@ -79,6 +79,29 @@ curl -o session_example.go -L 
https://github.com/apache/iotdb-client-go/raw/main
 go run session_example.go
 ```
 
+## TLS/mTLS
+
+Set `TLSConfig` on `client.Config`, `client.ClusterConfig`, or `client.PoolConfig` to enable TLS. `CAFile` is optional: its certificates are trusted in addition to the system roots. Set `CertFile` and `KeyFile` together when the server requires mTLS client authentication.
+
+```golang
+config := &client.Config{
+    Host:     host,
+    Port:     port,
+    UserName: user,
+    Password: password,
+    TLSConfig: &client.TLSConfig{
+        CAFile:   "/path/to/ca.pem",
+        CertFile: "/path/to/client.pem",
+        KeyFile:  "/path/to/client-key.pem",
+    },
+}
+session := client.NewSession(config)
+if err := session.Open(false, 0); err != nil {
+    log.Fatal(err)
+}
+defer session.Close()
+```
+
 ## How to Use the SessionPool
 
 SessionPool is a wrapper of a Session Set. Using SessionPool, the user do not 
need to consider how to reuse a session connection.
diff --git a/README_ZH.md b/README_ZH.md
index e856172..e680123 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -76,6 +76,29 @@ curl -o session_example.go -L 
https://github.com/apache/iotdb-client-go/raw/main
 go run session_example.go
 ```
 
+## TLS/mTLS
+
+在 `client.Config`、`client.ClusterConfig` 或 `client.PoolConfig` 上设置 `TLSConfig` 即可启用 TLS。`CAFile` 为可选项,其中的 CA 证书会在系统根证书之外被额外信任。如果服务端要求 mTLS 客户端认证,需同时设置 `CertFile` 和 `KeyFile`。
+
+```golang
+config := &client.Config{
+    Host:     host,
+    Port:     port,
+    UserName: user,
+    Password: password,
+    TLSConfig: &client.TLSConfig{
+        CAFile:   "/path/to/ca.pem",
+        CertFile: "/path/to/client.pem",
+        KeyFile:  "/path/to/client-key.pem",
+    },
+}
+session := client.NewSession(config)
+if err := session.Open(false, 0); err != nil {
+    log.Fatal(err)
+}
+defer session.Close()
+```
+
 ## SessionPool
 通过SessionPool管理session,用户不需要考虑如何重用session,当到达pool的最大值时,获取session的请求会阻塞
 注意:session使用完成后需要调用PutBack方法
diff --git a/client/session.go b/client/session.go
index 28b326e..7ac75f1 100644
--- a/client/session.go
+++ b/client/session.go
@@ -26,7 +26,6 @@ import (
        "errors"
        "fmt"
        "log"
-       "net"
        "reflect"
        "sort"
        "strings"
@@ -68,6 +67,7 @@ type Config struct {
        sqlDialect      string
        Version         Version
        Database        string
+       TLSConfig       *TLSConfig
 }
 
 type Session struct {
@@ -100,13 +100,10 @@ func (s *Session) Open(enableRPCCompression bool, 
connectionTimeoutInMs int) err
 
        var err error
 
-       // In thrift 0.14.1, this func returns two values; in newer versions, 
it returns one.
-       s.trans = thrift.NewTSocketConf(net.JoinHostPort(s.config.Host, 
s.config.Port), &thrift.TConfiguration{
-               ConnectTimeout: time.Duration(connectionTimeoutInMs) * 
time.Millisecond, // Use 0 for no timeout
-       })
-       // s.trans = thrift.NewTFramedTransport(s.trans)        // deprecated
-       tmp_conf := thrift.TConfiguration{MaxFrameSize: 
thrift.DEFAULT_MAX_FRAME_SIZE}
-       s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
+       s.trans, err = newTransport(s.config.Host, s.config.Port, 
connectionTimeoutInMs, s.config.TLSConfig)
+       if err != nil {
+               return err
+       }
        if !s.trans.IsOpen() {
                err = s.trans.Open()
                if err != nil {
@@ -154,6 +151,7 @@ type ClusterConfig struct {
        ConnectRetryMax int
        sqlDialect      string
        Database        string
+       TLSConfig       *TLSConfig
 }
 
 func (s *Session) OpenCluster(enableRPCCompression bool) error {
@@ -1328,24 +1326,23 @@ func newClusterSessionWithSqlDialect(clusterConfig 
*ClusterConfig) (Session, err
        var err error
        for i := range session.endPointList {
                ep := session.endPointList[i]
-               session.trans = thrift.NewTSocketConf(net.JoinHostPort(ep.Host, 
ep.Port), &thrift.TConfiguration{
-                       ConnectTimeout: time.Duration(0), // Use 0 for no 
timeout
-               })
-               // session.trans = thrift.NewTFramedTransport(session.trans)    
// deprecated
-               tmp_conf := thrift.TConfiguration{MaxFrameSize: 
thrift.DEFAULT_MAX_FRAME_SIZE}
-               session.trans = thrift.NewTFramedTransportConf(session.trans, 
&tmp_conf)
+               session.trans, err = newTransport(ep.Host, ep.Port, 0, 
clusterConfig.TLSConfig)
+               if err != nil {
+                       log.Println(err)
+                       continue
+               }
                if !session.trans.IsOpen() {
                        err = session.trans.Open()
                        if err != nil {
                                log.Println(err)
                        } else {
                                session.config = getConfig(ep.Host, ep.Port,
-                                       clusterConfig.UserName, 
clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, 
clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect)
+                                       clusterConfig.UserName, 
clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, 
clusterConfig.ConnectRetryMax, clusterConfig.Database, 
clusterConfig.sqlDialect, clusterConfig.TLSConfig)
                                break
                        }
                }
        }
-       if !session.trans.IsOpen() {
+       if session.trans == nil || !session.trans.IsOpen() {
                return session, fmt.Errorf("no server can connect")
        }
        return session, nil
@@ -1354,18 +1351,14 @@ func newClusterSessionWithSqlDialect(clusterConfig 
*ClusterConfig) (Session, err
 func (s *Session) initClusterConn(node endPoint) error {
        var err error
 
-       s.trans = thrift.NewTSocketConf(net.JoinHostPort(node.Host, node.Port), 
&thrift.TConfiguration{
-               ConnectTimeout: time.Duration(0), // Use 0 for no timeout
-       })
-       if err == nil {
-               // s.trans = thrift.NewTFramedTransport(s.trans)        // 
deprecated
-               tmp_conf := thrift.TConfiguration{MaxFrameSize: 
thrift.DEFAULT_MAX_FRAME_SIZE}
-               s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
-               if !s.trans.IsOpen() {
-                       err = s.trans.Open()
-                       if err != nil {
-                               return err
-                       }
+       s.trans, err = newTransport(node.Host, node.Port, 0, s.config.TLSConfig)
+       if err != nil {
+               return err
+       }
+       if !s.trans.IsOpen() {
+               err = s.trans.Open()
+               if err != nil {
+                       return err
                }
        }
 
@@ -1398,7 +1391,7 @@ func (s *Session) initClusterConn(node endPoint) error {
        return err
 }
 
-func getConfig(host string, port string, userName string, passWord string, 
fetchSize int32, timeZone string, connectRetryMax int, database string, 
sqlDialect string) *Config {
+func getConfig(host string, port string, userName string, passWord string, 
fetchSize int32, timeZone string, connectRetryMax int, database string, 
sqlDialect string, tlsConfig *TLSConfig) *Config {
        return &Config{
                Host:            host,
                Port:            port,
@@ -1409,6 +1402,7 @@ func getConfig(host string, port string, userName string, 
passWord string, fetch
                ConnectRetryMax: connectRetryMax,
                sqlDialect:      sqlDialect,
                Database:        database,
+               TLSConfig:       tlsConfig,
        }
 }
 
diff --git a/client/sessionpool.go b/client/sessionpool.go
index 757b298..c481bf4 100644
--- a/client/sessionpool.go
+++ b/client/sessionpool.go
@@ -50,6 +50,7 @@ type PoolConfig struct {
        TimeZone        string
        ConnectRetryMax int
        Database        string
+       TLSConfig       *TLSConfig
        sqlDialect      string
 }
 
@@ -146,6 +147,7 @@ func getSessionConfig(config *PoolConfig) *Config {
                ConnectRetryMax: config.ConnectRetryMax,
                sqlDialect:      config.sqlDialect,
                Database:        config.Database,
+               TLSConfig:       config.TLSConfig,
        }
 }
 
@@ -159,6 +161,7 @@ func getClusterSessionConfig(config *PoolConfig) 
*ClusterConfig {
                ConnectRetryMax: config.ConnectRetryMax,
                sqlDialect:      config.sqlDialect,
                Database:        config.Database,
+               TLSConfig:       config.TLSConfig,
        }
 }
 
diff --git a/client/tls.go b/client/tls.go
new file mode 100644
index 0000000..9021746
--- /dev/null
+++ b/client/tls.go
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package client
+
+import (
+       "crypto/tls"
+       "crypto/x509"
+       "fmt"
+       "net"
+       "os"
+       "time"
+
+       "github.com/apache/thrift/lib/go/thrift"
+)
+
+// TLSConfig enables TLS for an IoTDB client connection. A nil *TLSConfig keeps
+// the plain TCP transport; any non-nil value (even an empty one) switches to a
+// TLS socket. Set CertFile and KeyFile together to also enable mTLS client
+// authentication.
+type TLSConfig struct {
+       // Config is an optional base tls.Config. It is cloned before use, so the
+       // caller's value is never modified. Set its MinVersion explicitly to
+       // accept TLS versions older than 1.2.
+       Config *tls.Config
+
+       // CAFile is an optional PEM encoded CA certificate file used to verify the
+       // server. Its certificates are added to a copy of Config.RootCAs, or to the
+       // system certificate pool when Config.RootCAs is nil.
+       CAFile string
+
+       // CertFile and KeyFile are optional PEM encoded client certificate and key
+       // files for mTLS. They must be set together.
+       CertFile string
+       KeyFile  string
+}
+
+// newTransport builds, but does not open, the framed Thrift transport used to
+// reach host:port. A nil tlsConfig yields a plain TCP socket; otherwise a TLS
+// socket configured by buildTLSConfig is used. connectionTimeoutInMs is the
+// connect timeout in milliseconds; 0 means no timeout.
+func newTransport(host string, port string, connectionTimeoutInMs int, tlsConfig *TLSConfig) (thrift.TTransport, error) {
+       conf := &thrift.TConfiguration{
+               ConnectTimeout: time.Duration(connectionTimeoutInMs) * time.Millisecond,
+               MaxFrameSize:   thrift.DEFAULT_MAX_FRAME_SIZE,
+       }
+       hostPort := net.JoinHostPort(host, port)
+
+       var base thrift.TTransport
+       if tlsConfig == nil {
+               base = thrift.NewTSocketConf(hostPort, conf)
+       } else {
+               cfg, err := buildTLSConfig(tlsConfig)
+               if err != nil {
+                       return nil, fmt.Errorf("configure TLS for %s: %w", hostPort, err)
+               }
+               conf.TLSConfig = cfg
+               base = thrift.NewTSSLSocketConf(hostPort, conf)
+       }
+
+       return thrift.NewTFramedTransportConf(base, conf), nil
+}
+
+// buildTLSConfig converts config into a client *tls.Config, or returns
+// (nil, nil) when config is nil. The base config.Config is cloned so it is
+// never mutated, the CA file (if any) is merged into RootCAs, and the client
+// key pair (if any) is appended to Certificates.
+//
+// An unset MinVersion is raised to TLS 1.2 unless MaxVersion caps the
+// connection below that: Thrift's NewTSSLSocketConf otherwise fills an unset
+// MinVersion with TLS 1.0, weaker than the crypto/tls client default.
+func buildTLSConfig(config *TLSConfig) (*tls.Config, error) {
+       if config == nil {
+               return nil, nil
+       }
+
+       tlsConfig := &tls.Config{}
+       if config.Config != nil {
+               tlsConfig = config.Config.Clone()
+       }
+       if tlsConfig.MinVersion == 0 && (tlsConfig.MaxVersion == 0 || tlsConfig.MaxVersion >= tls.VersionTLS12) {
+               tlsConfig.MinVersion = tls.VersionTLS12
+       }
+       if config.CAFile != "" {
+               rootCAs, err := loadCertPool(tlsConfig.RootCAs, config.CAFile)
+               if err != nil {
+                       return nil, err
+               }
+               tlsConfig.RootCAs = rootCAs
+       }
+       if config.CertFile != "" || config.KeyFile != "" {
+               if config.CertFile == "" || config.KeyFile == "" {
+                       return nil, fmt.Errorf("both TLS CertFile and KeyFile must be set")
+               }
+               certificate, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
+               if err != nil {
+                       return nil, fmt.Errorf("load TLS client certificate/key: %w", err)
+               }
+               // Clone is shallow, so Certificates may still share its backing
+               // array with the base config; cap it so append always copies.
+               certs := tlsConfig.Certificates
+               tlsConfig.Certificates = append(certs[:len(certs):len(certs)], certificate)
+       }
+
+       return tlsConfig, nil
+}
+
+// loadCertPool returns a pool holding the certificates of base plus those in
+// the PEM file caFile. base itself is never modified: it is cloned, or, when
+// nil, replaced by the system pool (or an empty pool if that is unavailable).
+func loadCertPool(base *x509.CertPool, caFile string) (*x509.CertPool, error) {
+       caCert, err := os.ReadFile(caFile)
+       if err != nil {
+               return nil, fmt.Errorf("read TLS CA file %q: %w", caFile, err)
+       }
+
+       var rootCAs *x509.CertPool
+       if base != nil {
+               rootCAs = base.Clone()
+       } else if systemPool, poolErr := x509.SystemCertPool(); poolErr == nil {
+               rootCAs = systemPool
+       } else {
+               rootCAs = x509.NewCertPool()
+       }
+
+       if !rootCAs.AppendCertsFromPEM(caCert) {
+               return nil, fmt.Errorf("append TLS CA file %q: no certificates found", caFile)
+       }
+       return rootCAs, nil
+}
diff --git a/client/tls_test.go b/client/tls_test.go
new file mode 100644
index 0000000..5d53ad3
--- /dev/null
+++ b/client/tls_test.go
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package client
+
+import (
+       "crypto/ecdsa"
+       "crypto/elliptic"
+       "crypto/rand"
+       "crypto/tls"
+       "crypto/x509"
+       "crypto/x509/pkix"
+       "encoding/pem"
+       "math/big"
+       "os"
+       "path/filepath"
+       "testing"
+       "time"
+)
+
+func TestBuildTLSConfigClonesBaseConfig(t *testing.T) {
+       base := &tls.Config{
+               MinVersion: tls.VersionTLS12,
+       }
+
+       cfg, err := buildTLSConfig(&TLSConfig{
+               Config: base,
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if cfg == base {
+               t.Fatal("buildTLSConfig must clone the base tls.Config")
+       }
+       if cfg.MinVersion != tls.VersionTLS12 {
+               t.Fatalf("MinVersion = %d, want %d", cfg.MinVersion, tls.VersionTLS12)
+       }
+}
+
+func TestBuildTLSConfigDoesNotModifyBaseConfig(t *testing.T) {
+       caFile, certFile, keyFile := writeTLSFiles(t)
+       base := &tls.Config{}
+
+       _, err := buildTLSConfig(&TLSConfig{
+               Config:   base,
+               CAFile:   caFile,
+               CertFile: certFile,
+               KeyFile:  keyFile,
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+       if base.RootCAs != nil {
+               t.Fatal("buildTLSConfig must not set RootCAs on the base tls.Config")
+       }
+       if len(base.Certificates) != 0 {
+               t.Fatalf("base Certificates length = %d, want 0", len(base.Certificates))
+       }
+}
+
+func TestBuildTLSConfigLoadsFiles(t *testing.T) {
+       caFile, certFile, keyFile := writeTLSFiles(t)
+
+       cfg, err := buildTLSConfig(&TLSConfig{
+               CAFile:   caFile,
+               CertFile: certFile,
+               KeyFile:  keyFile,
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+       if cfg.RootCAs == nil {
+               t.Fatal("RootCAs should be set")
+       }
+       if len(cfg.Certificates) != 1 {
+               t.Fatalf("Certificates length = %d, want 1", len(cfg.Certificates))
+       }
+}
+
+func TestBuildTLSConfigRequiresCertAndKey(t *testing.T) {
+       for _, cfg := range []*TLSConfig{
+               {CertFile: "client.crt"},
+               {KeyFile: "client.key"},
+       } {
+               if _, err := buildTLSConfig(cfg); err == nil {
+                       t.Fatalf("expected error for CertFile=%q KeyFile=%q", cfg.CertFile, cfg.KeyFile)
+               }
+       }
+}
+
+func TestBuildTLSConfigRejectsBadCAFile(t *testing.T) {
+       dir := t.TempDir()
+       notPEM := filepath.Join(dir, "not-pem.txt")
+       if err := os.WriteFile(notPEM, []byte("not a certificate"), 0600); err != nil {
+               t.Fatal(err)
+       }
+
+       for _, caFile := range []string{filepath.Join(dir, "missing.pem"), notPEM} {
+               if _, err := buildTLSConfig(&TLSConfig{CAFile: caFile}); err == nil {
+                       t.Fatalf("expected error for CAFile %q", caFile)
+               }
+       }
+}
+
+// writeTLSFiles generates a self-signed test CA and a client certificate signed
+// by it, writes the CA certificate, client certificate and client EC private
+// key as PEM files into a per-test temporary directory, and returns their paths.
+func writeTLSFiles(t *testing.T) (caFile string, certFile string, keyFile string) {
+       t.Helper()
+
+       dir := t.TempDir()
+       now := time.Now()
+       caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+       if err != nil {
+               t.Fatal(err)
+       }
+       caTemplate := &x509.Certificate{
+               SerialNumber:          big.NewInt(1),
+               Subject:               pkix.Name{CommonName: "iotdb-client-go-test-ca"},
+               NotBefore:             now.Add(-time.Hour),
+               NotAfter:              now.Add(time.Hour),
+               KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
+               BasicConstraintsValid: true,
+               IsCA:                  true,
+       }
+       caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+       if err != nil {
+               t.Fatal(err)
+       }
+       clientTemplate := &x509.Certificate{
+               SerialNumber: big.NewInt(2),
+               Subject:      pkix.Name{CommonName: "iotdb-client-go-test-client"},
+               NotBefore:    now.Add(-time.Hour),
+               NotAfter:     now.Add(time.Hour),
+               KeyUsage:     x509.KeyUsageDigitalSignature,
+               ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
+       }
+       clientDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey)
+       if err != nil {
+               t.Fatal(err)
+       }
+       clientKeyDER, err := x509.MarshalECPrivateKey(clientKey)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       caFile = filepath.Join(dir, "ca.pem")
+       certFile = filepath.Join(dir, "client.pem")
+       keyFile = filepath.Join(dir, "client-key.pem")
+       writePEMFile(t, caFile, "CERTIFICATE", caDER)
+       writePEMFile(t, certFile, "CERTIFICATE", clientDER)
+       writePEMFile(t, keyFile, "EC PRIVATE KEY", clientKeyDER)
+       return caFile, certFile, keyFile
+}
+
+// writePEMFile writes der to filename as a single PEM block of type blockType.
+func writePEMFile(t *testing.T, filename string, blockType string, der []byte) {
+       t.Helper()
+
+       data := pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der})
+       if err := os.WriteFile(filename, data, 0600); err != nil {
+               t.Fatal(err)
+       }
+}

Reply via email to