This is an automated email from the ASF dual-hosted git repository.
jimin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-seata-go.git
The following commit(s) were added to refs/heads/master by this push:
new eaf4b270 test: improve test coverage for pkg/datasource/sql/util (#974)
eaf4b270 is described below
commit eaf4b270871862d372d10e8229c008387e353b7f
Author: EVERFID <[email protected]>
AuthorDate: Sat Nov 8 00:14:29 2025 +0800
test: improve test coverage for pkg/datasource/sql/util (#974)
---
pkg/datasource/sql/util/convert_test.go | 473 ++++++++++++++++++++++++++
pkg/datasource/sql/util/ctxutil_test.go | 504 ++++++++++++++++++++++++++++
pkg/datasource/sql/util/lockkey_test.go | 542 ++++++++++++++++++++++++++++++
pkg/datasource/sql/util/params_test.go | 172 ++++++++++
pkg/datasource/sql/util/sql_test.go | 570 ++++++++++++++++++++++++++++++++
5 files changed, 2261 insertions(+)
diff --git a/pkg/datasource/sql/util/convert_test.go b/pkg/datasource/sql/util/convert_test.go
index 86a7bb2c..2a7d2fb0 100644
--- a/pkg/datasource/sql/util/convert_test.go
+++ b/pkg/datasource/sql/util/convert_test.go
@@ -18,9 +18,15 @@
package util
import (
+ "database/sql"
+ "errors"
+ "reflect"
+ "strconv"
"testing"
+ "time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestConvertDbVersion(t *testing.T) {
@@ -39,3 +45,470 @@ func TestConvertDbVersion(t *testing.T) {
assert.NoError(t, err3)
assert.Equal(t, v2Int, v3Int)
}
+
+// TestConvertDbVersionWithHyphen checks that a suffix after "-" (here "-log")
+// is ignored, so "5.7.30-log" encodes to the same value as "5.7.30".
+func TestConvertDbVersionWithHyphen(t *testing.T) {
+	version := "5.7.30-log"
+	v, err := ConvertDbVersion(version)
+	require.NoError(t, err)
+	assert.Greater(t, v, 0)
+
+	// Compare with the non-hyphenated version.
+	version2 := "5.7.30"
+	v2, err2 := ConvertDbVersion(version2)
+	require.NoError(t, err2)
+	assert.Equal(t, v, v2) // Should be equal as we only parse the numeric part
+}
+
+// TestConvertDbVersionInvalidFormat checks that a version with five dotted
+// parts is rejected with an "incompatible version format" error.
+func TestConvertDbVersionInvalidFormat(t *testing.T) {
+	version := "1.2.3.4.5"
+	_, err := ConvertDbVersion(version)
+	// require, not assert: err.Error() below would panic on a nil error.
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "incompatible version format")
+}
+
+// TestConvertDbVersionSinglePart pins the encoded value of a one-part version.
+func TestConvertDbVersionSinglePart(t *testing.T) {
+	version := "8"
+	v, err := ConvertDbVersion(version)
+	assert.NoError(t, err)
+	assert.Equal(t, 800, v)
+}
+
+// TestConvertDbVersionTwoParts pins the encoded value of a two-part version.
+func TestConvertDbVersionTwoParts(t *testing.T) {
+	version := "8.0"
+	v, err := ConvertDbVersion(version)
+	assert.NoError(t, err)
+	assert.Equal(t, 80000, v)
+}
+
+// convertAssignRows with a string source value.
+
+// TestConvertAssignRowsStringToString checks string -> *string.
+func TestConvertAssignRowsStringToString(t *testing.T) {
+	var got string
+	assert.NoError(t, convertAssignRows(&got, "test", nil))
+	assert.Equal(t, "test", got)
+}
+
+// TestConvertAssignRowsStringToBytes checks string -> *[]byte.
+func TestConvertAssignRowsStringToBytes(t *testing.T) {
+	var got []byte
+	assert.NoError(t, convertAssignRows(&got, "test", nil))
+	assert.Equal(t, []byte("test"), got)
+}
+
+// TestConvertAssignRowsStringToRawBytes checks string -> *sql.RawBytes.
+func TestConvertAssignRowsStringToRawBytes(t *testing.T) {
+	var got sql.RawBytes
+	assert.NoError(t, convertAssignRows(&got, "test", nil))
+	assert.Equal(t, sql.RawBytes("test"), got)
+}
+
+// TestConvertAssignRowsStringToNilPointer passes a nil *string itself (not a
+// pointer to one) as the destination and expects errNilPtr.
+func TestConvertAssignRowsStringToNilPointer(t *testing.T) {
+	var nilDest *string
+	err := convertAssignRows(nilDest, "test", nil)
+	assert.Error(t, err)
+	assert.Equal(t, errNilPtr, err)
+}
+
+// convertAssignRows with a []byte source value.
+
+// TestConvertAssignRowsBytesToString checks []byte -> *string.
+func TestConvertAssignRowsBytesToString(t *testing.T) {
+	var got string
+	assert.NoError(t, convertAssignRows(&got, []byte("test"), nil))
+	assert.Equal(t, "test", got)
+}
+
+// TestConvertAssignRowsBytesToInterface checks []byte -> *interface{}.
+func TestConvertAssignRowsBytesToInterface(t *testing.T) {
+	var got interface{}
+	assert.NoError(t, convertAssignRows(&got, []byte("test"), nil))
+	assert.Equal(t, []byte("test"), got)
+}
+
+// TestConvertAssignRowsBytesToBytes checks []byte -> *[]byte and that the
+// destination does not share storage with the source.
+func TestConvertAssignRowsBytesToBytes(t *testing.T) {
+	in := []byte("test")
+	var got []byte
+	assert.NoError(t, convertAssignRows(&got, in, nil))
+	assert.Equal(t, []byte("test"), got)
+	// A write to the source must not be visible through the destination.
+	in[0] = 'x'
+	assert.NotEqual(t, in, got)
+}
+
+// TestConvertAssignRowsBytesToRawBytes checks []byte -> *sql.RawBytes.
+func TestConvertAssignRowsBytesToRawBytes(t *testing.T) {
+	var got sql.RawBytes
+	assert.NoError(t, convertAssignRows(&got, []byte("test"), nil))
+	assert.Equal(t, sql.RawBytes("test"), got)
+}
+
+// convertAssignRows with a time.Time source value.
+
+// TestConvertAssignRowsTimeToTime checks time.Time -> *time.Time.
+func TestConvertAssignRowsTimeToTime(t *testing.T) {
+	ts := time.Now()
+	var got time.Time
+	assert.NoError(t, convertAssignRows(&got, ts, nil))
+	assert.Equal(t, ts, got)
+}
+
+// TestConvertAssignRowsTimeToString checks time.Time -> *string, expecting the
+// RFC3339Nano rendering.
+func TestConvertAssignRowsTimeToString(t *testing.T) {
+	ts := time.Now()
+	var got string
+	assert.NoError(t, convertAssignRows(&got, ts, nil))
+	assert.Equal(t, ts.Format(time.RFC3339Nano), got)
+}
+
+// TestConvertAssignRowsTimeToBytes checks time.Time -> *[]byte.
+func TestConvertAssignRowsTimeToBytes(t *testing.T) {
+	ts := time.Now()
+	var got []byte
+	assert.NoError(t, convertAssignRows(&got, ts, nil))
+	assert.Equal(t, []byte(ts.Format(time.RFC3339Nano)), got)
+}
+
+// TestConvertAssignRowsTimeToRawBytes checks time.Time -> *sql.RawBytes.
+func TestConvertAssignRowsTimeToRawBytes(t *testing.T) {
+	ts := time.Now()
+	var got sql.RawBytes
+	assert.NoError(t, convertAssignRows(&got, ts, nil))
+	want := sql.RawBytes(ts.AppendFormat(nil, time.RFC3339Nano))
+	assert.Equal(t, want, got)
+}
+
+// convertAssignRows with a nil source value.
+
+// TestConvertAssignRowsNilToInterface checks nil -> *interface{}.
+func TestConvertAssignRowsNilToInterface(t *testing.T) {
+	var got interface{}
+	assert.NoError(t, convertAssignRows(&got, nil, nil))
+	assert.Nil(t, got)
+}
+
+// TestConvertAssignRowsNilToBytes checks nil -> *[]byte.
+func TestConvertAssignRowsNilToBytes(t *testing.T) {
+	var got []byte
+	assert.NoError(t, convertAssignRows(&got, nil, nil))
+	assert.Nil(t, got)
+}
+
+// TestConvertAssignRowsNilToRawBytes checks nil -> *sql.RawBytes.
+func TestConvertAssignRowsNilToRawBytes(t *testing.T) {
+	var got sql.RawBytes
+	assert.NoError(t, convertAssignRows(&got, nil, nil))
+	assert.Nil(t, got)
+}
+
+// convertAssignRows with numeric and bool sources, plus an interface{} sink.
+
+// TestConvertAssignRowsIntToString checks int64 -> *string.
+func TestConvertAssignRowsIntToString(t *testing.T) {
+	var dest string
+	err := convertAssignRows(&dest, int64(123), nil)
+	assert.NoError(t, err)
+	assert.Equal(t, "123", dest)
+}
+
+// TestConvertAssignRowsFloatToString checks float64 -> *string. The result is
+// compared exactly: a substring match would also accept e.g. "123.450000".
+func TestConvertAssignRowsFloatToString(t *testing.T) {
+	var dest string
+	err := convertAssignRows(&dest, float64(123.45), nil)
+	assert.NoError(t, err)
+	assert.Equal(t, "123.45", dest)
+}
+
+// TestConvertAssignRowsBoolToString checks bool -> *string.
+func TestConvertAssignRowsBoolToString(t *testing.T) {
+	var dest string
+	err := convertAssignRows(&dest, true, nil)
+	assert.NoError(t, err)
+	assert.Equal(t, "true", dest)
+}
+
+// TestConvertAssignRowsIntToBytes checks int64 -> *[]byte.
+func TestConvertAssignRowsIntToBytes(t *testing.T) {
+	var dest []byte
+	err := convertAssignRows(&dest, int64(123), nil)
+	assert.NoError(t, err)
+	assert.Equal(t, []byte("123"), dest)
+}
+
+// TestConvertAssignRowsIntToBool checks int64(1) -> *bool (true).
+func TestConvertAssignRowsIntToBool(t *testing.T) {
+	var dest bool
+	err := convertAssignRows(&dest, int64(1), nil)
+	assert.NoError(t, err)
+	assert.True(t, dest)
+}
+
+// TestConvertAssignRowsToInterface checks string -> *interface{}.
+func TestConvertAssignRowsToInterface(t *testing.T) {
+	var dest interface{}
+	src := "test"
+	err := convertAssignRows(&dest, src, nil)
+	assert.NoError(t, err)
+	assert.Equal(t, "test", dest)
+}
+
+// convertAssignRows conversions that parse text into numeric destinations,
+// and NULL handling for non-nullable destinations.
+
+// TestConvertAssignRowsStringToInt checks "123" -> *int.
+func TestConvertAssignRowsStringToInt(t *testing.T) {
+	var dest int
+	err := convertAssignRows(&dest, "123", nil)
+	assert.NoError(t, err)
+	assert.Equal(t, 123, dest)
+}
+
+// TestConvertAssignRowsStringToInt64 checks "123" -> *int64.
+func TestConvertAssignRowsStringToInt64(t *testing.T) {
+	var dest int64
+	err := convertAssignRows(&dest, "123", nil)
+	assert.NoError(t, err)
+	assert.Equal(t, int64(123), dest)
+}
+
+// TestConvertAssignRowsStringToUint checks "123" -> *uint.
+func TestConvertAssignRowsStringToUint(t *testing.T) {
+	var dest uint
+	err := convertAssignRows(&dest, "123", nil)
+	assert.NoError(t, err)
+	assert.Equal(t, uint(123), dest)
+}
+
+// TestConvertAssignRowsStringToFloat64 checks "123.45" -> *float64.
+func TestConvertAssignRowsStringToFloat64(t *testing.T) {
+	var dest float64
+	err := convertAssignRows(&dest, "123.45", nil)
+	assert.NoError(t, err)
+	assert.InDelta(t, 123.45, dest, 0.0001)
+}
+
+// TestConvertAssignRowsInvalidIntConversion expects an error for non-numeric text.
+func TestConvertAssignRowsInvalidIntConversion(t *testing.T) {
+	var dest int
+	err := convertAssignRows(&dest, "invalid", nil)
+	assert.Error(t, err)
+}
+
+// TestConvertAssignRowsIntOverflow expects an error when the value does not fit int8.
+func TestConvertAssignRowsIntOverflow(t *testing.T) {
+	var dest int8
+	err := convertAssignRows(&dest, "1000", nil)
+	assert.Error(t, err)
+}
+
+// The NULL tests use require.Error: err.Error() would panic on a nil error.
+
+// TestConvertAssignRowsNullToInt expects a "converting NULL" error for *int.
+func TestConvertAssignRowsNullToInt(t *testing.T) {
+	var dest int
+	err := convertAssignRows(&dest, nil, nil)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "converting NULL")
+}
+
+// TestConvertAssignRowsNullToFloat expects a "converting NULL" error for *float64.
+func TestConvertAssignRowsNullToFloat(t *testing.T) {
+	var dest float64
+	err := convertAssignRows(&dest, nil, nil)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "converting NULL")
+}
+
+// TestConvertAssignRowsNullToString expects a "converting NULL" error for *string.
+func TestConvertAssignRowsNullToString(t *testing.T) {
+	var dest string
+	err := convertAssignRows(&dest, nil, nil)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "converting NULL")
+}
+
+// Pointer destinations.
+
+// TestConvertAssignRowsNilToPointer checks that a nil source leaves a
+// *string destination (passed as **string) nil.
+func TestConvertAssignRowsNilToPointer(t *testing.T) {
+	var got *string
+	assert.NoError(t, convertAssignRows(&got, nil, nil))
+	assert.Nil(t, got)
+}
+
+// TestConvertAssignRowsValueToPointer checks that a non-nil source makes the
+// initially-nil *string non-nil and stores the value behind it.
+func TestConvertAssignRowsValueToPointer(t *testing.T) {
+	var got *string
+	assert.NoError(t, convertAssignRows(&got, "test", nil))
+	require.NotNil(t, got)
+	assert.Equal(t, "test", *got)
+}
+
+// Named destination types, non-pointer destinations and unsupported targets.
+
+// TestConvertAssignRowsAssignableTypes checks string -> *MyString (a named
+// string type).
+func TestConvertAssignRowsAssignableTypes(t *testing.T) {
+	type MyString string
+	var dest MyString
+	err := convertAssignRows(&dest, "test", nil)
+	assert.NoError(t, err)
+	assert.Equal(t, MyString("test"), dest)
+}
+
+// TestConvertAssignRowsConvertibleTypes checks "123" -> *MyInt (a named int
+// type).
+func TestConvertAssignRowsConvertibleTypes(t *testing.T) {
+	type MyInt int
+	var dest MyInt
+	err := convertAssignRows(&dest, "123", nil)
+	assert.NoError(t, err)
+	assert.Equal(t, MyInt(123), dest)
+}
+
+// TestConvertAssignRowsNotPointer passes the destination by value and expects
+// a "destination not a pointer" error.
+func TestConvertAssignRowsNotPointer(t *testing.T) {
+	var dest string
+	err := convertAssignRows(dest, "test", nil)
+	require.Error(t, err) // err.Error() below would panic on a nil error
+	assert.Contains(t, err.Error(), "destination not a pointer")
+}
+
+// TestConvertAssignRowsUnsupportedConversion expects an "unsupported Scan"
+// error when the destination is a struct.
+func TestConvertAssignRowsUnsupportedConversion(t *testing.T) {
+	type MyStruct struct{}
+	var dest MyStruct
+	err := convertAssignRows(&dest, "test", nil)
+	require.Error(t, err) // err.Error() below would panic on a nil error
+	assert.Contains(t, err.Error(), "unsupported Scan")
+}
+
+// customScanner is a minimal Scan implementation. It is used to check that
+// convertAssignRows hands the source value to a destination's Scan method.
+type customScanner struct {
+	value string
+}
+
+// Scan stores the string form of src. A nil src leaves value unchanged.
+func (s *customScanner) Scan(src interface{}) error {
+	if src != nil {
+		s.value = asString(src)
+	}
+	return nil
+}
+
+// TestConvertAssignRowsCustomScanner checks that the Scan destination
+// receives "test".
+func TestConvertAssignRowsCustomScanner(t *testing.T) {
+	var sc customScanner
+	assert.NoError(t, convertAssignRows(&sc, "test", nil))
+	assert.Equal(t, "test", sc.value)
+}
+
+// Tests for the package-level helper functions.
+
+// TestCloneBytes checks that cloneBytes returns an independent copy.
+func TestCloneBytes(t *testing.T) {
+	src := []byte("test")
+	dest := cloneBytes(src)
+	assert.Equal(t, src, dest)
+	src[0] = 'x'
+	assert.NotEqual(t, src, dest)
+}
+
+// TestCloneBytesNil checks that a nil input yields nil, not an empty slice.
+func TestCloneBytesNil(t *testing.T) {
+	var src []byte
+	dest := cloneBytes(src)
+	assert.Nil(t, dest)
+}
+
+// TestAsStringVariousTypes checks the text asString produces for each input
+// kind. Every row is compared exactly against its expected column. The float
+// rows were previously substring-matched through a branch on the row name.
+func TestAsStringVariousTypes(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    interface{}
+		expected string
+	}{
+		{"string", "test", "test"},
+		{"[]byte", []byte("test"), "test"},
+		{"int", int(123), "123"},
+		{"int64", int64(123), "123"},
+		{"uint", uint(123), "123"},
+		{"float64", float64(123.45), "123.45"},
+		{"float32", float32(123.45), "123.45"},
+		{"bool true", true, "true"},
+		{"bool false", false, "false"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			assert.Equal(t, tt.expected, asString(tt.input))
+		})
+	}
+}
+
+// TestAsBytesVariousTypes checks the bytes asBytes appends (to a nil buffer)
+// for each supported kind, and its ok flag.
+func TestAsBytesVariousTypes(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    interface{}
+		expected string
+		ok       bool
+	}{
+		{"int", int(123), "123", true},
+		{"int64", int64(123), "123", true},
+		{"uint", uint(123), "123", true},
+		{"float64", float64(123.45), "123.45", true},
+		{"float32", float32(123.45), "123.45", true},
+		{"bool", true, "true", true},
+		{"string", "test", "test", true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, ok := asBytes(nil, reflect.ValueOf(tt.input))
+			assert.Equal(t, tt.ok, ok)
+			if ok {
+				assert.Equal(t, tt.expected, string(result))
+			}
+		})
+	}
+}
+
+// TestAsBytesUnsupportedType checks that a struct value reports ok == false.
+func TestAsBytesUnsupportedType(t *testing.T) {
+	type MyStruct struct{}
+	rv := reflect.ValueOf(MyStruct{})
+	_, ok := asBytes(nil, rv)
+	assert.False(t, ok)
+}
+
+// TestStrconvErr checks two cases: a *strconv.NumError is unwrapped to its
+// inner Err, and any other error is returned unchanged.
+func TestStrconvErr(t *testing.T) {
+	// Test with NumError
+	numErr := &strconv.NumError{
+		Func: "ParseInt",
+		Num:  "abc",
+		Err:  strconv.ErrSyntax,
+	}
+	result := strconvErr(numErr)
+	assert.Equal(t, strconv.ErrSyntax, result)
+
+	// Test with regular error
+	regularErr := errors.New("regular error")
+	result = strconvErr(regularErr)
+	assert.Equal(t, regularErr, result)
+}
+
+// TestCalculatePartValue pins calculatePartValue for three (part, size, index)
+// inputs.
+func TestCalculatePartValue(t *testing.T) {
+	assert.Equal(t, 3000000, calculatePartValue(3, 3, 0))
+	assert.Equal(t, 10000, calculatePartValue(1, 3, 1))
+	assert.Equal(t, 200, calculatePartValue(2, 3, 2))
+}
+
+// Edge-case inputs.
+
+// TestConvertAssignRowsEmptyString checks that "" is assigned into *string.
+func TestConvertAssignRowsEmptyString(t *testing.T) {
+	var got string
+	assert.NoError(t, convertAssignRows(&got, "", nil))
+	assert.Equal(t, "", got)
+}
+
+// TestConvertAssignRowsZeroInt checks "0" -> *int.
+func TestConvertAssignRowsZeroInt(t *testing.T) {
+	var got int
+	assert.NoError(t, convertAssignRows(&got, "0", nil))
+	assert.Equal(t, 0, got)
+}
+
+// TestConvertAssignRowsNegativeInt checks "-123" -> *int.
+func TestConvertAssignRowsNegativeInt(t *testing.T) {
+	var got int
+	assert.NoError(t, convertAssignRows(&got, "-123", nil))
+	assert.Equal(t, -123, got)
+}
+
+// TestConvertAssignRowsScientificNotation checks "1.23e2" -> *float64.
+func TestConvertAssignRowsScientificNotation(t *testing.T) {
+	var got float64
+	assert.NoError(t, convertAssignRows(&got, "1.23e2", nil))
+	assert.InDelta(t, 123.0, got, 0.0001)
+}
diff --git a/pkg/datasource/sql/util/ctxutil_test.go b/pkg/datasource/sql/util/ctxutil_test.go
new file mode 100644
index 00000000..6a5690c8
--- /dev/null
+++ b/pkg/datasource/sql/util/ctxutil_test.go
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package util
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+ "github.com/stretchr/testify/assert"
+ "testing"
+)
+
+// Mock implementations of the database/sql/driver interfaces used by the
+// ctxDriver* helpers. Each mock delegates to an optional function field. When
+// that field is nil, the mock returns a "not implemented" error, so a test
+// that reaches an unconfigured path gets an error back.
+
+// Compile-time checks that every mock satisfies the interface it stands in for.
+var (
+	_ driver.Conn               = (*mockConn)(nil)
+	_ driver.ConnPrepareContext = (*mockConnPrepareContext)(nil)
+	_ driver.Stmt               = (*mockStmt)(nil)
+	_ driver.StmtExecContext    = (*mockStmtExecContext)(nil)
+	_ driver.StmtQueryContext   = (*mockStmtQueryContext)(nil)
+	_ driver.Execer             = (*mockExecer)(nil)
+	_ driver.ExecerContext      = (*mockExecerContext)(nil)
+	_ driver.Queryer            = (*mockQueryer)(nil)
+	_ driver.QueryerContext     = (*mockQueryerContext)(nil)
+	_ driver.Result             = (*mockResult)(nil)
+	_ driver.Rows               = (*mockRows)(nil)
+)
+
+// mockConn is a driver.Conn whose Prepare is configurable. Begin always fails.
+type mockConn struct {
+	prepareFunc func(query string) (driver.Stmt, error)
+}
+
+func (m *mockConn) Prepare(query string) (driver.Stmt, error) {
+	if m.prepareFunc != nil {
+		return m.prepareFunc(query)
+	}
+	return nil, errors.New("not implemented")
+}
+
+func (m *mockConn) Close() error {
+	return nil
+}
+
+func (m *mockConn) Begin() (driver.Tx, error) {
+	return nil, errors.New("not implemented")
+}
+
+// mockConnPrepareContext adds a configurable PrepareContext to mockConn.
+type mockConnPrepareContext struct {
+	mockConn
+	prepareContextFunc func(ctx context.Context, query string) (driver.Stmt, error)
+}
+
+func (m *mockConnPrepareContext) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+	if m.prepareContextFunc != nil {
+		return m.prepareContextFunc(ctx, query)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockStmt is a driver.Stmt without the context-aware variants.
+type mockStmt struct {
+	closeFunc    func() error
+	numInputFunc func() int
+	execFunc     func(args []driver.Value) (driver.Result, error)
+	queryFunc    func(args []driver.Value) (driver.Rows, error)
+}
+
+func (m *mockStmt) Close() error {
+	if m.closeFunc != nil {
+		return m.closeFunc()
+	}
+	return nil
+}
+
+func (m *mockStmt) NumInput() int {
+	if m.numInputFunc != nil {
+		return m.numInputFunc()
+	}
+	return 0
+}
+
+func (m *mockStmt) Exec(args []driver.Value) (driver.Result, error) {
+	if m.execFunc != nil {
+		return m.execFunc(args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+func (m *mockStmt) Query(args []driver.Value) (driver.Rows, error) {
+	if m.queryFunc != nil {
+		return m.queryFunc(args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockStmtExecContext adds a configurable ExecContext to mockStmt.
+type mockStmtExecContext struct {
+	mockStmt
+	execContextFunc func(ctx context.Context, args []driver.NamedValue) (driver.Result, error)
+}
+
+func (m *mockStmtExecContext) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+	if m.execContextFunc != nil {
+		return m.execContextFunc(ctx, args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockStmtQueryContext adds a configurable QueryContext to mockStmt.
+type mockStmtQueryContext struct {
+	mockStmt
+	queryContextFunc func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error)
+}
+
+func (m *mockStmtQueryContext) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+	if m.queryContextFunc != nil {
+		return m.queryContextFunc(ctx, args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockExecer is a configurable driver.Execer.
+type mockExecer struct {
+	execFunc func(query string, args []driver.Value) (driver.Result, error)
+}
+
+func (m *mockExecer) Exec(query string, args []driver.Value) (driver.Result, error) {
+	if m.execFunc != nil {
+		return m.execFunc(query, args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockExecerContext is a configurable driver.ExecerContext.
+type mockExecerContext struct {
+	execContextFunc func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error)
+}
+
+func (m *mockExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+	if m.execContextFunc != nil {
+		return m.execContextFunc(ctx, query, args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockQueryer is a configurable driver.Queryer.
+type mockQueryer struct {
+	queryFunc func(query string, args []driver.Value) (driver.Rows, error)
+}
+
+func (m *mockQueryer) Query(query string, args []driver.Value) (driver.Rows, error) {
+	if m.queryFunc != nil {
+		return m.queryFunc(query, args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockQueryerContext is a configurable driver.QueryerContext.
+type mockQueryerContext struct {
+	queryContextFunc func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error)
+}
+
+func (m *mockQueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+	if m.queryContextFunc != nil {
+		return m.queryContextFunc(ctx, query, args)
+	}
+	return nil, errors.New("not implemented")
+}
+
+// mockResult is a driver.Result with fixed values and no errors.
+type mockResult struct {
+	lastInsertId int64
+	rowsAffected int64
+}
+
+func (m *mockResult) LastInsertId() (int64, error) {
+	return m.lastInsertId, nil
+}
+
+func (m *mockResult) RowsAffected() (int64, error) {
+	return m.rowsAffected, nil
+}
+
+// mockRows is a driver.Rows that reports column names and never yields a row.
+type mockRows struct {
+	columns []string
+}
+
+func (m *mockRows) Columns() []string {
+	return m.columns
+}
+
+func (m *mockRows) Close() error {
+	return nil
+}
+
+// Next always fails.
+// NOTE(review): the driver.Rows contract signals end of data with io.EOF. This
+// mock returns a different error, which is harmless only while no test
+// iterates the rows.
+func (m *mockRows) Next(dest []driver.Value) error {
+	return errors.New("no more rows")
+}
+
+// Tests for ctxDriverPrepare. Locals are named conn/stmt, not
+// mockConn/mockStmt, so they do not shadow the mock types.
+
+// TestCtxDriverPrepare_WithContext uses a connection whose plain Prepare is
+// unconfigured (it would return an error). Success therefore means
+// PrepareContext was used.
+func TestCtxDriverPrepare_WithContext(t *testing.T) {
+	stmt := &mockStmt{}
+	conn := &mockConnPrepareContext{
+		prepareContextFunc: func(ctx context.Context, query string) (driver.Stmt, error) {
+			assert.Equal(t, "SELECT * FROM test", query)
+			return stmt, nil
+		},
+	}
+	ctx := context.Background()
+	query := "SELECT * FROM test"
+
+	result, err := ctxDriverPrepare(ctx, conn, query)
+	assert.NoError(t, err)
+	assert.Same(t, stmt, result)
+}
+
+// TestCtxDriverPrepare_WithoutContext checks the fallback to Prepare.
+func TestCtxDriverPrepare_WithoutContext(t *testing.T) {
+	stmt := &mockStmt{}
+	conn := &mockConn{
+		prepareFunc: func(query string) (driver.Stmt, error) {
+			assert.Equal(t, "SELECT * FROM test", query)
+			return stmt, nil
+		},
+	}
+	ctx := context.Background()
+	query := "SELECT * FROM test"
+
+	result, err := ctxDriverPrepare(ctx, conn, query)
+	assert.NoError(t, err)
+	assert.Same(t, stmt, result)
+}
+
+// TestCtxDriverPrepare_WithCancelledContext checks the behaviour when ctx is
+// already cancelled. It expects context.Canceled and a nil statement, and
+// expects the statement Prepare produced to have been closed.
+func TestCtxDriverPrepare_WithCancelledContext(t *testing.T) {
+	closeCalled := false
+	stmt := &mockStmt{
+		closeFunc: func() error {
+			closeCalled = true
+			return nil
+		},
+	}
+	conn := &mockConn{
+		prepareFunc: func(query string) (driver.Stmt, error) {
+			return stmt, nil
+		},
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // Cancel immediately
+	query := "SELECT * FROM test"
+
+	result, err := ctxDriverPrepare(ctx, conn, query)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Nil(t, result)
+	assert.True(t, closeCalled)
+}
+
+// TestCtxDriverPrepare_PrepareError checks that a Prepare error is returned as-is.
+func TestCtxDriverPrepare_PrepareError(t *testing.T) {
+	expectedErr := errors.New("prepare error")
+	conn := &mockConn{
+		prepareFunc: func(query string) (driver.Stmt, error) {
+			return nil, expectedErr
+		},
+	}
+	ctx := context.Background()
+	query := "SELECT * FROM test"
+
+	result, err := ctxDriverPrepare(ctx, conn, query)
+	assert.ErrorIs(t, err, expectedErr)
+	assert.Nil(t, result)
+}
+
+// Tests for ctxDriverExec. Index accesses on args are guarded with assert.Len
+// so that a wrong argument count fails the test instead of panicking.
+
+// TestCtxDriverExec_WithExecerContext checks that a non-nil ExecerContext is
+// called with the query and the named values unchanged.
+func TestCtxDriverExec_WithExecerContext(t *testing.T) {
+	res := &mockResult{lastInsertId: 1, rowsAffected: 1}
+	execerCtx := &mockExecerContext{
+		execContextFunc: func(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+			assert.Equal(t, "INSERT INTO test VALUES (?)", query)
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0].Value)
+			}
+			return res, nil
+		},
+	}
+	ctx := context.Background()
+	query := "INSERT INTO test VALUES (?)"
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverExec(ctx, execerCtx, nil, query, args)
+	assert.NoError(t, err)
+	assert.Same(t, res, result)
+}
+
+// TestCtxDriverExec_WithoutExecerContext checks the fallback to Execer, which
+// receives plain driver.Values.
+func TestCtxDriverExec_WithoutExecerContext(t *testing.T) {
+	res := &mockResult{lastInsertId: 1, rowsAffected: 1}
+	execer := &mockExecer{
+		execFunc: func(query string, args []driver.Value) (driver.Result, error) {
+			assert.Equal(t, "INSERT INTO test VALUES (?)", query)
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0])
+			}
+			return res, nil
+		},
+	}
+	ctx := context.Background()
+	query := "INSERT INTO test VALUES (?)"
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverExec(ctx, nil, execer, query, args)
+	assert.NoError(t, err)
+	assert.Same(t, res, result)
+}
+
+// TestCtxDriverExec_WithCancelledContext expects context.Canceled when ctx is
+// already cancelled.
+func TestCtxDriverExec_WithCancelledContext(t *testing.T) {
+	execer := &mockExecer{}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	query := "INSERT INTO test VALUES (?)"
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverExec(ctx, nil, execer, query, args)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Nil(t, result)
+}
+
+// TestCtxDriverExec_WithNamedParameters expects an error mentioning "Named
+// Parameters" when an argument carries a Name on the Execer fallback path.
+func TestCtxDriverExec_WithNamedParameters(t *testing.T) {
+	execer := &mockExecer{}
+	ctx := context.Background()
+	query := "INSERT INTO test VALUES (?)"
+	args := []driver.NamedValue{{Name: "param1", Value: "test"}}
+
+	result, err := ctxDriverExec(ctx, nil, execer, query, args)
+	assert.Nil(t, result)
+	// Inspect the message only once an error is confirmed: err.Error() on a
+	// nil error would panic.
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "Named Parameters")
+	}
+}
+
+// Tests for CtxDriverQuery.
+
+// TestCtxDriverQuery_WithQueryerContext checks that a non-nil QueryerContext
+// is called with the query and the named values unchanged.
+func TestCtxDriverQuery_WithQueryerContext(t *testing.T) {
+	rows := &mockRows{columns: []string{"id", "name"}}
+	queryerCtx := &mockQueryerContext{
+		queryContextFunc: func(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+			assert.Equal(t, "SELECT * FROM test", query)
+			// Guard the index so a wrong count fails instead of panicking.
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0].Value)
+			}
+			return rows, nil
+		},
+	}
+	ctx := context.Background()
+	query := "SELECT * FROM test"
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := CtxDriverQuery(ctx, queryerCtx, nil, query, args)
+	assert.NoError(t, err)
+	assert.Same(t, rows, result)
+}
+
+// TestCtxDriverQuery_WithoutQueryerContext checks the fallback to Queryer,
+// which receives plain driver.Values.
+func TestCtxDriverQuery_WithoutQueryerContext(t *testing.T) {
+	rows := &mockRows{columns: []string{"id", "name"}}
+	queryer := &mockQueryer{
+		queryFunc: func(query string, args []driver.Value) (driver.Rows, error) {
+			assert.Equal(t, "SELECT * FROM test", query)
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0])
+			}
+			return rows, nil
+		},
+	}
+	ctx := context.Background()
+	query := "SELECT * FROM test"
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := CtxDriverQuery(ctx, nil, queryer, query, args)
+	assert.NoError(t, err)
+	assert.Same(t, rows, result)
+}
+
+// TestCtxDriverQuery_WithCancelledContext expects context.Canceled when ctx is
+// already cancelled.
+func TestCtxDriverQuery_WithCancelledContext(t *testing.T) {
+	queryer := &mockQueryer{}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	query := "SELECT * FROM test"
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := CtxDriverQuery(ctx, nil, queryer, query, args)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Nil(t, result)
+}
+
+// Tests for ctxDriverStmtExec.
+
+// TestCtxDriverStmtExec_WithContext checks that a statement implementing
+// ExecContext receives the named values unchanged.
+func TestCtxDriverStmtExec_WithContext(t *testing.T) {
+	res := &mockResult{lastInsertId: 1, rowsAffected: 1}
+	stmt := &mockStmtExecContext{
+		execContextFunc: func(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+			// Guard the index so a wrong count fails instead of panicking.
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0].Value)
+			}
+			return res, nil
+		},
+	}
+	ctx := context.Background()
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverStmtExec(ctx, stmt, args)
+	assert.NoError(t, err)
+	assert.Same(t, res, result)
+}
+
+// TestCtxDriverStmtExec_WithoutContext checks the fallback to Stmt.Exec,
+// which receives plain driver.Values.
+func TestCtxDriverStmtExec_WithoutContext(t *testing.T) {
+	res := &mockResult{lastInsertId: 1, rowsAffected: 1}
+	stmt := &mockStmt{
+		execFunc: func(args []driver.Value) (driver.Result, error) {
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0])
+			}
+			return res, nil
+		},
+	}
+	ctx := context.Background()
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverStmtExec(ctx, stmt, args)
+	assert.NoError(t, err)
+	assert.Same(t, res, result)
+}
+
+// TestCtxDriverStmtExec_WithCancelledContext expects context.Canceled when ctx
+// is already cancelled.
+func TestCtxDriverStmtExec_WithCancelledContext(t *testing.T) {
+	stmt := &mockStmt{}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverStmtExec(ctx, stmt, args)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Nil(t, result)
+}
+
+// Tests for ctxDriverStmtQuery.
+
+// TestCtxDriverStmtQuery_WithContext checks that a statement implementing
+// QueryContext receives the named values unchanged.
+func TestCtxDriverStmtQuery_WithContext(t *testing.T) {
+	rows := &mockRows{columns: []string{"id", "name"}}
+	stmt := &mockStmtQueryContext{
+		queryContextFunc: func(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+			// Guard the index so a wrong count fails instead of panicking.
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0].Value)
+			}
+			return rows, nil
+		},
+	}
+	ctx := context.Background()
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverStmtQuery(ctx, stmt, args)
+	assert.NoError(t, err)
+	assert.Same(t, rows, result)
+}
+
+// TestCtxDriverStmtQuery_WithoutContext checks the fallback to Stmt.Query,
+// which receives plain driver.Values.
+func TestCtxDriverStmtQuery_WithoutContext(t *testing.T) {
+	rows := &mockRows{columns: []string{"id", "name"}}
+	stmt := &mockStmt{
+		queryFunc: func(args []driver.Value) (driver.Rows, error) {
+			if assert.Len(t, args, 1) {
+				assert.Equal(t, "test", args[0])
+			}
+			return rows, nil
+		},
+	}
+	ctx := context.Background()
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverStmtQuery(ctx, stmt, args)
+	assert.NoError(t, err)
+	assert.Same(t, rows, result)
+}
+
+// TestCtxDriverStmtQuery_WithCancelledContext expects context.Canceled when
+// ctx is already cancelled.
+func TestCtxDriverStmtQuery_WithCancelledContext(t *testing.T) {
+	stmt := &mockStmt{}
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	args := []driver.NamedValue{{Ordinal: 1, Value: "test"}}
+
+	result, err := ctxDriverStmtQuery(ctx, stmt, args)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Nil(t, result)
+}
+
+// Tests for namedValueToValue.
+
+// TestNamedValueToValue_Success checks that ordinal-only arguments are
+// flattened to their values, in order.
+func TestNamedValueToValue_Success(t *testing.T) {
+	input := []driver.NamedValue{
+		{Ordinal: 1, Value: "test"},
+		{Ordinal: 2, Value: int64(123)},
+	}
+
+	result, err := namedValueToValue(input)
+	assert.NoError(t, err)
+	assert.Equal(t, []driver.Value{"test", int64(123)}, result)
+}
+
+// TestNamedValueToValue_WithNamedParameter expects an error mentioning "Named
+// Parameters" when an argument carries a Name.
+func TestNamedValueToValue_WithNamedParameter(t *testing.T) {
+	input := []driver.NamedValue{
+		{Name: "param1", Value: "test"},
+	}
+
+	result, err := namedValueToValue(input)
+	assert.Nil(t, result)
+	// err.Error() on a nil error would panic, so only inspect it when present.
+	if assert.Error(t, err) {
+		assert.Contains(t, err.Error(), "Named Parameters")
+	}
+}
+
+// TestNamedValueToValue_Empty checks that an empty input yields an empty
+// result and no error.
+func TestNamedValueToValue_Empty(t *testing.T) {
+	input := []driver.NamedValue{}
+
+	result, err := namedValueToValue(input)
+	assert.NoError(t, err)
+	assert.Empty(t, result)
+}
diff --git a/pkg/datasource/sql/util/lockkey_test.go b/pkg/datasource/sql/util/lockkey_test.go
new file mode 100644
index 00000000..a69cad9f
--- /dev/null
+++ b/pkg/datasource/sql/util/lockkey_test.go
@@ -0,0 +1,542 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package util
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "seata.apache.org/seata-go/pkg/datasource/sql/types"
+)
+
+func TestBuildLockKey_SinglePrimaryKey(t *testing.T) {
+ // Create table meta with single primary key
+ tableMeta := types.TableMeta{
+ TableName: "users",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ "name": {ColumnName: "name", DatabaseType: 12},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"id", "name"},
+ }
+
+ // Create record image
+ recordImage := &types.RecordImage{
+ TableName: "users",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(1),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "name", Value: "Alice",
KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "USERS:1", lockKey)
+}
+
+func TestBuildLockKey_CompositePrimaryKey(t *testing.T) {
+ // Create table meta with composite primary key
+ tableMeta := types.TableMeta{
+ TableName: "order_items",
+ Columns: map[string]types.ColumnMeta{
+ "order_id": {ColumnName: "order_id", DatabaseType: 4},
+ "item_id": {ColumnName: "item_id", DatabaseType: 4},
+ "quantity": {ColumnName: "quantity", DatabaseType: 4},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "order_id", DatabaseType:
4},
+ {ColumnName: "item_id", DatabaseType:
4},
+ },
+ },
+ },
+ ColumnNames: []string{"order_id", "item_id", "quantity"},
+ }
+
+ // Create record image
+ recordImage := &types.RecordImage{
+ TableName: "order_items",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "order_id", Value:
int64(100), KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "item_id", Value:
int64(5), KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "quantity", Value:
int64(3), KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "ORDER_ITEMS:100_5", lockKey)
+}
+
+func TestBuildLockKey_MultipleRows(t *testing.T) {
+ // Create table meta
+ tableMeta := types.TableMeta{
+ TableName: "products",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ "name": {ColumnName: "name", DatabaseType: 12},
+ "price": {ColumnName: "price", DatabaseType: 8},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"id", "name", "price"},
+ }
+
+ // Create record image with multiple rows
+ recordImage := &types.RecordImage{
+ TableName: "products",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(1),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "name", Value: "Product
A", KeyType: types.IndexTypeNull},
+ {ColumnName: "price", Value:
float64(19.99), KeyType: types.IndexTypeNull},
+ },
+ },
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(2),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "name", Value: "Product
B", KeyType: types.IndexTypeNull},
+ {ColumnName: "price", Value:
float64(29.99), KeyType: types.IndexTypeNull},
+ },
+ },
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(3),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "name", Value: "Product
C", KeyType: types.IndexTypeNull},
+ {ColumnName: "price", Value:
float64(39.99), KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "PRODUCTS:1,2,3", lockKey)
+}
+
+func TestBuildLockKey_EmptyRows(t *testing.T) {
+ // Create table meta
+ tableMeta := types.TableMeta{
+ TableName: "users",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ "name": {ColumnName: "name", DatabaseType: 12},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"id", "name"},
+ }
+
+ // Create record image with no rows
+ recordImage := &types.RecordImage{
+ TableName: "users",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{},
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "USERS:", lockKey)
+}
+
+func TestBuildLockKey_NilValue(t *testing.T) {
+ // Create table meta
+ tableMeta := types.TableMeta{
+ TableName: "users",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ "name": {ColumnName: "name", DatabaseType: 12},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"id", "name"},
+ }
+
+ // Create record image with nil primary key value
+ recordImage := &types.RecordImage{
+ TableName: "users",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: nil, KeyType:
types.IndexTypePrimaryKey},
+ {ColumnName: "name", Value: "Alice",
KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "USERS:", lockKey)
+}
+
+func TestBuildLockKey_StringPrimaryKey(t *testing.T) {
+ // Create table meta with string primary key
+ tableMeta := types.TableMeta{
+ TableName: "categories",
+ Columns: map[string]types.ColumnMeta{
+ "code": {ColumnName: "code", DatabaseType: 12},
+ "name": {ColumnName: "name", DatabaseType: 12},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "code", DatabaseType: 12},
+ },
+ },
+ },
+ ColumnNames: []string{"code", "name"},
+ }
+
+ // Create record image
+ recordImage := &types.RecordImage{
+ TableName: "categories",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "code", Value: "CAT001",
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "name", Value:
"Electronics", KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "CATEGORIES:CAT001", lockKey)
+}
+
+func TestBuildLockKey_CompositeKeyMultipleRows(t *testing.T) {
+ // Create table meta with composite primary key
+ tableMeta := types.TableMeta{
+ TableName: "user_roles",
+ Columns: map[string]types.ColumnMeta{
+ "user_id": {ColumnName: "user_id", DatabaseType: 4},
+ "role_id": {ColumnName: "role_id", DatabaseType: 4},
+ "granted": {ColumnName: "granted", DatabaseType: 91},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "user_id", DatabaseType:
4},
+ {ColumnName: "role_id", DatabaseType:
4},
+ },
+ },
+ },
+ ColumnNames: []string{"user_id", "role_id", "granted"},
+ }
+
+ // Create record image with multiple rows
+ recordImage := &types.RecordImage{
+ TableName: "user_roles",
+ SQLType: types.SQLTypeInsert,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "user_id", Value:
int64(10), KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "role_id", Value:
int64(1), KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "granted", Value:
"2023-01-01", KeyType: types.IndexTypeNull},
+ },
+ },
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "user_id", Value:
int64(10), KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "role_id", Value:
int64(2), KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "granted", Value:
"2023-01-02", KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "USER_ROLES:10_1,10_2", lockKey)
+}
+
+func TestBuildLockKey_LowercaseTableName(t *testing.T) {
+ // Create table meta with lowercase table name
+ tableMeta := types.TableMeta{
+ TableName: "my_table",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ "data": {ColumnName: "data", DatabaseType: 12},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"id", "data"},
+ }
+
+ // Create record image
+ recordImage := &types.RecordImage{
+ TableName: "my_table",
+ SQLType: types.SQLTypeDelete,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(999),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "data", Value: "test",
KeyType: types.IndexTypeNull},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "MY_TABLE:999", lockKey)
+}
+
+func TestBuildLockKey_MixedNilValues(t *testing.T) {
+ // Create table meta with composite key
+ tableMeta := types.TableMeta{
+ TableName: "test_table",
+ Columns: map[string]types.ColumnMeta{
+ "pk1": {ColumnName: "pk1", DatabaseType: 4},
+ "pk2": {ColumnName: "pk2", DatabaseType: 4},
+ "pk3": {ColumnName: "pk3", DatabaseType: 4},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "pk1", DatabaseType: 4},
+ {ColumnName: "pk2", DatabaseType: 4},
+ {ColumnName: "pk3", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"pk1", "pk2", "pk3"},
+ }
+
+ // Create record image with mixed nil values
+ recordImage := &types.RecordImage{
+ TableName: "test_table",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "pk1", Value: int64(1),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "pk2", Value: nil,
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "pk3", Value: int64(3),
KeyType: types.IndexTypePrimaryKey},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "TEST_TABLE:1__3", lockKey)
+}
+
+func TestBuildLockKey_DifferentDataTypes(t *testing.T) {
+ // Create table meta
+ tableMeta := types.TableMeta{
+ TableName: "mixed_types",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ "code": {ColumnName: "code", DatabaseType: 12},
+ "version": {ColumnName: "version", DatabaseType: 4},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ {ColumnName: "code", DatabaseType: 12},
+ {ColumnName: "version", DatabaseType:
4},
+ },
+ },
+ },
+ ColumnNames: []string{"id", "code", "version"},
+ }
+
+ // Create record image
+ recordImage := &types.RecordImage{
+ TableName: "mixed_types",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(100),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "code", Value: "ABC",
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "version", Value:
int32(5), KeyType: types.IndexTypePrimaryKey},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "MIXED_TYPES:100_ABC_5", lockKey)
+}
+
+func TestBuildLockKey_ColumnOrderMatters(t *testing.T) {
+ // Create table meta where column order in Columns slice differs from
ColumnNames order
+ tableMeta := types.TableMeta{
+ TableName: "ordered_table",
+ Columns: map[string]types.ColumnMeta{
+ "pk1": {ColumnName: "pk1", DatabaseType: 4},
+ "pk2": {ColumnName: "pk2", DatabaseType: 4},
+ "pk3": {ColumnName: "pk3", DatabaseType: 4},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "pk1", DatabaseType: 4},
+ {ColumnName: "pk2", DatabaseType: 4},
+ {ColumnName: "pk3", DatabaseType: 4},
+ },
+ },
+ },
+ // The order in ColumnNames determines the lock key order
+ ColumnNames: []string{"pk1", "pk2", "pk3"},
+ }
+
+ // Create record image with columns in different order
+ recordImage := &types.RecordImage{
+ TableName: "ordered_table",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "pk3", Value: int64(3),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "pk1", Value: int64(1),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "pk2", Value: int64(2),
KeyType: types.IndexTypePrimaryKey},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ // Should respect the order defined in ColumnNames: pk1, pk2, pk3
+ assert.Equal(t, "ORDERED_TABLE:1_2_3", lockKey)
+}
+
+func TestBuildLockKey_PartialNilInCompositeKey(t *testing.T) {
+ tableMeta := types.TableMeta{
+ TableName: "test_partial_nil",
+ Columns: map[string]types.ColumnMeta{
+ "pk1": {ColumnName: "pk1", DatabaseType: 4},
+ "pk2": {ColumnName: "pk2", DatabaseType: 4},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "pk1", DatabaseType: 4},
+ {ColumnName: "pk2", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"pk1", "pk2"},
+ }
+
+ recordImage := &types.RecordImage{
+ TableName: "test_partial_nil",
+ SQLType: types.SQLTypeUpdate,
+ Rows: []types.RowImage{
+ {
+ Columns: []types.ColumnImage{
+ {ColumnName: "pk1", Value: int64(100),
KeyType: types.IndexTypePrimaryKey},
+ {ColumnName: "pk2", Value: nil,
KeyType: types.IndexTypePrimaryKey},
+ },
+ },
+ },
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Equal(t, "TEST_PARTIAL_NIL:100_", lockKey)
+}
+
+func TestBuildLockKey_LargeNumberOfRows(t *testing.T) {
+ tableMeta := types.TableMeta{
+ TableName: "large_batch",
+ Columns: map[string]types.ColumnMeta{
+ "id": {ColumnName: "id", DatabaseType: 4},
+ },
+ Indexs: map[string]types.IndexMeta{
+ "PRIMARY": {
+ IType: types.IndexTypePrimaryKey,
+ Columns: []types.ColumnMeta{
+ {ColumnName: "id", DatabaseType: 4},
+ },
+ },
+ },
+ ColumnNames: []string{"id"},
+ }
+
+ rows := make([]types.RowImage, 100)
+ for i := 0; i < 100; i++ {
+ rows[i] = types.RowImage{
+ Columns: []types.ColumnImage{
+ {ColumnName: "id", Value: int64(i + 1),
KeyType: types.IndexTypePrimaryKey},
+ },
+ }
+ }
+
+ recordImage := &types.RecordImage{
+ TableName: "large_batch",
+ SQLType: types.SQLTypeInsert,
+ Rows: rows,
+ }
+
+ lockKey := BuildLockKey(recordImage, tableMeta)
+ assert.Contains(t, lockKey, "LARGE_BATCH:")
+ assert.Contains(t, lockKey, "1,2,3")
+ assert.Contains(t, lockKey, ",99,100")
+}
diff --git a/pkg/datasource/sql/util/params_test.go
b/pkg/datasource/sql/util/params_test.go
new file mode 100644
index 00000000..cb6f54ab
--- /dev/null
+++ b/pkg/datasource/sql/util/params_test.go
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package util
+
+import (
+ "database/sql/driver"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNamedValueToValue(t *testing.T) {
+ tests := []struct {
+ name string
+ input []driver.NamedValue
+ expected []driver.Value
+ }{
+ {
+ name: "empty slice",
+ input: []driver.NamedValue{},
+ expected: []driver.Value{},
+ },
+ {
+ name: "single value",
+ input: []driver.NamedValue{
+ {Name: "param1", Ordinal: 1, Value: "test"},
+ },
+ expected: []driver.Value{"test"},
+ },
+ {
+ name: "multiple values",
+ input: []driver.NamedValue{
+ {Name: "param1", Ordinal: 1, Value: "test1"},
+ {Name: "param2", Ordinal: 2, Value: int64(123)},
+ {Name: "param3", Ordinal: 3, Value: true},
+ },
+ expected: []driver.Value{"test1", int64(123), true},
+ },
+ {
+ name: "nil values",
+ input: []driver.NamedValue{
+ {Name: "param1", Ordinal: 1, Value: nil},
+ {Name: "param2", Ordinal: 2, Value: "test"},
+ },
+ expected: []driver.Value{nil, "test"},
+ },
+ {
+ name: "various types",
+ input: []driver.NamedValue{
+ {Value: int64(42)},
+ {Value: float64(3.14)},
+ {Value: []byte("bytes")},
+ {Value: true},
+ },
+ expected: []driver.Value{int64(42), float64(3.14),
[]byte("bytes"), true},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := NamedValueToValue(tt.input)
+ assert.Equal(t, tt.expected, result)
+ assert.Equal(t, len(tt.expected), len(result))
+ })
+ }
+}
+
+func TestValueToNamedValue(t *testing.T) {
+ tests := []struct {
+ name string
+ input []driver.Value
+ expected []driver.NamedValue
+ }{
+ {
+ name: "empty slice",
+ input: []driver.Value{},
+ expected: []driver.NamedValue{},
+ },
+ {
+ name: "single value",
+ input: []driver.Value{"test"},
+ expected: []driver.NamedValue{
+ {Value: "test", Ordinal: 0},
+ },
+ },
+ {
+ name: "multiple values",
+ input: []driver.Value{"test1", int64(123), true},
+ expected: []driver.NamedValue{
+ {Value: "test1", Ordinal: 0},
+ {Value: int64(123), Ordinal: 1},
+ {Value: true, Ordinal: 2},
+ },
+ },
+ {
+ name: "nil values",
+ input: []driver.Value{nil, "test", nil},
+ expected: []driver.NamedValue{
+ {Value: nil, Ordinal: 0},
+ {Value: "test", Ordinal: 1},
+ {Value: nil, Ordinal: 2},
+ },
+ },
+ {
+ name: "various types",
+ input: []driver.Value{int64(42), float64(3.14),
[]byte("bytes"), true},
+ expected: []driver.NamedValue{
+ {Value: int64(42), Ordinal: 0},
+ {Value: float64(3.14), Ordinal: 1},
+ {Value: []byte("bytes"), Ordinal: 2},
+ {Value: true, Ordinal: 3},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := ValueToNamedValue(tt.input)
+ assert.Equal(t, len(tt.expected), len(result))
+ for i := range result {
+ assert.Equal(t, tt.expected[i].Value,
result[i].Value)
+ assert.Equal(t, tt.expected[i].Ordinal,
result[i].Ordinal)
+ }
+ })
+ }
+}
+
+// Wrapping values as named values and unwrapping them again is lossless.
+func TestRoundTripConversion(t *testing.T) {
+	want := []driver.Value{"test", int64(123), true, nil}
+
+	got := NamedValueToValue(ValueToNamedValue(want))
+	assert.Equal(t, want, got)
+}
+
+// The output follows slice order, not the Ordinal field.
+func TestNamedValueToValuePreservesOrder(t *testing.T) {
+	got := NamedValueToValue([]driver.NamedValue{
+		{Name: "z", Ordinal: 2, Value: "third"},
+		{Name: "a", Ordinal: 0, Value: "first"},
+		{Name: "m", Ordinal: 1, Value: "second"},
+	})
+
+	assert.Equal(t, []driver.Value{"third", "first", "second"}, got)
+}
+
+// ValueToNamedValue numbers the values by their zero-based slice position.
+// NOTE(review): database/sql/driver documents NamedValue.Ordinal as starting
+// from one; confirm that callers really expect zero-based ordinals here.
+func TestValueToNamedValueOrdinalSequence(t *testing.T) {
+	input := []driver.Value{"a", "b", "c", "d", "e"}
+	result := ValueToNamedValue(input)
+
+	// Without this check an empty or truncated result would make the loop
+	// below pass vacuously, and a longer one would index past input.
+	if !assert.Len(t, result, len(input)) {
+		return
+	}
+	for i, nv := range result {
+		assert.Equal(t, i, nv.Ordinal)
+		assert.Equal(t, input[i], nv.Value)
+	}
+}
diff --git a/pkg/datasource/sql/util/sql_test.go
b/pkg/datasource/sql/util/sql_test.go
new file mode 100644
index 00000000..011185b0
--- /dev/null
+++ b/pkg/datasource/sql/util/sql_test.go
@@ -0,0 +1,570 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package util
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+ "io"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// mockDriverRows is an in-memory driver.Rows that serves a fixed table of
+// values and can be primed with errors for Next and Close.
+type mockDriverRows struct {
+	columns    []string         // reported by Columns
+	data       [][]driver.Value // rows returned by Next, in order
+	currentRow int              // index of the last row returned; -1 before the first
+	closed     bool             // set once Close has been called
+	closeErr   error            // returned by Close
+	nextErr    error            // when non-nil, returned by every Next call
+	mu         sync.Mutex       // guards the mutable fields above
+}
+
+// newMockDriverRows returns rows positioned before the first element of data.
+func newMockDriverRows(columns []string, data [][]driver.Value) *mockDriverRows {
+	r := &mockDriverRows{columns: columns, data: data}
+	r.currentRow = -1 // Next pre-increments, so start just before row 0.
+	return r
+}
+
+// Columns reports the configured column names.
+func (r *mockDriverRows) Columns() []string {
+	return r.columns
+}
+
+// Close marks the rows closed and returns the primed close error.
+func (r *mockDriverRows) Close() error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	r.closed = true
+	return r.closeErr
+}
+
+// Next copies the following row into dest. It returns io.EOF once the data
+// is exhausted, or the primed nextErr without advancing.
+func (r *mockDriverRows) Next(dest []driver.Value) error {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if err := r.nextErr; err != nil {
+		return err
+	}
+	r.currentRow++
+	if r.currentRow < len(r.data) {
+		copy(dest, r.data[r.currentRow])
+		return nil
+	}
+	return io.EOF
+}
+
+// mockDriverRowsWithNextResultSet extends mockDriverRows with
+// driver.RowsNextResultSet support and counts NextResultSet calls.
+type mockDriverRowsWithNextResultSet struct {
+	*mockDriverRows
+	hasNextResultSet    bool  // reported by HasNextResultSet
+	nextResultSetErr    error // when non-nil, returned by NextResultSet
+	nextResultSetCalled int   // number of successful NextResultSet calls
+}
+
+// newMockDriverRowsWithNextResultSet builds rows over data. hasNext sets the
+// initial HasNextResultSet answer.
+func newMockDriverRowsWithNextResultSet(columns []string, data [][]driver.Value, hasNext bool) *mockDriverRowsWithNextResultSet {
+	base := newMockDriverRows(columns, data)
+	return &mockDriverRowsWithNextResultSet{mockDriverRows: base, hasNextResultSet: hasNext}
+}
+
+// HasNextResultSet reports whether another result set is pending.
+func (m *mockDriverRowsWithNextResultSet) HasNextResultSet() bool {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	return m.hasNextResultSet
+}
+
+// NextResultSet returns the primed error if set. Otherwise it records the
+// call, reports that no further sets remain and rewinds the row cursor so
+// the same data is served again.
+func (m *mockDriverRowsWithNextResultSet) NextResultSet() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if err := m.nextResultSetErr; err != nil {
+		return err
+	}
+	m.nextResultSetCalled++
+	m.hasNextResultSet = false
+	m.currentRow = -1
+	return nil
+}
+
+// A freshly wrapped result is open, has no error and points at the driver rows.
+func TestNewScanRows(t *testing.T) {
+	driverRows := newMockDriverRows([]string{"id", "name"}, nil)
+	rs := NewScanRows(driverRows)
+
+	assert.NotNil(t, rs)
+	assert.Equal(t, driverRows, rs.rowsi)
+	assert.False(t, rs.closed)
+	assert.Nil(t, rs.lasterr)
+}
+
+// Next walks every row in order, exposes it through lastcols, and then
+// reports exhaustion.
+func TestScanRows_Next_WithData(t *testing.T) {
+	want := [][]driver.Value{
+		{int64(1), "Alice"},
+		{int64(2), "Bob"},
+		{int64(3), "Charlie"},
+	}
+	rs := NewScanRows(newMockDriverRows([]string{"id", "name"}, want))
+
+	for _, row := range want {
+		assert.True(t, rs.Next())
+		assert.Equal(t, row[0], rs.lastcols[0])
+		assert.Equal(t, row[1], rs.lastcols[1])
+	}
+
+	assert.False(t, rs.Next())
+}
+
+// An empty result set ends on the first Next and records io.EOF.
+func TestScanRows_Next_EmptyResult(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows([]string{"id", "name"}, [][]driver.Value{}))
+
+	assert.False(t, rs.Next())
+	assert.Equal(t, io.EOF, rs.lasterr)
+}
+
+// Test ScanRows.Next with error: a driver failure stops iteration and is
+// kept as the last error.
+func TestScanRows_Next_WithError(t *testing.T) {
+	driverErr := errors.New("database error")
+	driverRows := newMockDriverRows([]string{"id"}, nil)
+	driverRows.nextErr = driverErr
+
+	rs := NewScanRows(driverRows)
+	assert.False(t, rs.Next())
+	assert.Equal(t, driverErr, rs.lasterr)
+}
+
+// Scan assigns driver values of matching Go types directly.
+func TestScanRows_Scan_BasicTypes(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows(
+		[]string{"id", "name", "active", "score"},
+		[][]driver.Value{{int64(123), "test", true, 45.67}},
+	))
+	assert.True(t, rs.Next())
+
+	var (
+		id     int64
+		name   string
+		active bool
+		score  float64
+	)
+	assert.NoError(t, rs.Scan(&id, &name, &active, &score))
+
+	assert.Equal(t, int64(123), id)
+	assert.Equal(t, "test", name)
+	assert.True(t, active)
+	assert.InDelta(t, 45.67, score, 0.001)
+}
+
+// A nil source value is scanned without error and leaves the destination
+// string at its zero value.
+func TestScanRows_Scan_WithNilValues(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows(
+		[]string{"id", "name"},
+		[][]driver.Value{{int64(1), nil}},
+	))
+	assert.True(t, rs.Next())
+
+	var (
+		id   int64
+		name string
+	)
+	assert.NoError(t, rs.Scan(&id, &name))
+
+	assert.Equal(t, int64(1), id)
+	assert.Equal(t, "", name)
+}
+
+// Scanning before the first Next has no current row and must fail.
+func TestScanRows_Scan_WithoutNext(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows([]string{"id"}, nil))
+
+	var id int64
+	err := rs.Scan(&id)
+
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "Scan called without calling Next")
+}
+
+// Scan requires exactly one destination per column; here two columns are
+// scanned into a single destination.
+func TestScanRows_Scan_WrongArgCount(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows(
+		[]string{"id", "name"},
+		[][]driver.Value{{int64(1), "test"}},
+	))
+	assert.True(t, rs.Next())
+
+	var id int64
+	err := rs.Scan(&id)
+
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "expected 2 destination arguments")
+}
+
+// Scanning closed rows reports errRowsClosed.
+func TestScanRows_Scan_WhenClosed(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows(
+		[]string{"id", "name"},
+		[][]driver.Value{{int64(1), "test"}},
+	))
+	rs.closed = true
+
+	var (
+		id   int64
+		name string
+	)
+	err := rs.Scan(&id, &name)
+
+	assert.Error(t, err)
+	assert.Equal(t, errRowsClosed, err)
+}
+
+// Err reports nothing on fresh rows and returns lasterr once it is set.
+func TestScanRows_Err(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows([]string{"id"}, nil))
+	assert.NoError(t, rs.Err())
+
+	want := errors.New("test error")
+	rs.lasterr = want
+	assert.Equal(t, want, rs.Err())
+}
+
+func TestScanRows_Close(t *testing.T) {
+ mockRows := newMockDriverRows([]string{"id"}, nil)
+ scanRows := NewScanRows(mockRows)
+
+ releaseCalled := false
+ scanRows.releaseConn = func(error) {
+ releaseCalled = true
+ }
+
+ err := scanRows.close(nil)
+ assert.NoError(t, err)
+ assert.True(t, mockRows.closed)
+ assert.True(t, releaseCalled)
+}
+
+// Closing twice is harmless: neither call reports an error.
+func TestScanRows_Close_Idempotent(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows([]string{"id"}, nil))
+	rs.releaseConn = func(error) {}
+
+	for attempt := 0; attempt < 2; attempt++ {
+		assert.NoError(t, rs.close(nil))
+	}
+}
+
+// An error from the driver's Close is returned by close.
+func TestScanRows_Close_WithError(t *testing.T) {
+	closeErr := errors.New("close error")
+	driverRows := newMockDriverRows([]string{"id"}, nil)
+	driverRows.closeErr = closeErr
+
+	rs := NewScanRows(driverRows)
+	rs.releaseConn = func(error) {}
+
+	err := rs.close(nil)
+	assert.Error(t, err)
+	assert.Equal(t, closeErr, err)
+}
+
+// NextResultSet advances once via the driver's RowsNextResultSet support.
+func TestScanRows_NextResultSet(t *testing.T) {
+	driverRows := newMockDriverRowsWithNextResultSet([]string{"id"}, [][]driver.Value{{int64(1)}}, true)
+	rs := NewScanRows(driverRows)
+
+	advanced := rs.NextResultSet()
+
+	assert.True(t, advanced)
+	assert.Equal(t, 1, driverRows.nextResultSetCalled)
+}
+
+// Plain mockDriverRows does not implement driver.RowsNextResultSet, so
+// there is never another result set.
+func TestScanRows_NextResultSet_NoNext(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows([]string{"id"}, nil))
+
+	assert.False(t, rs.NextResultSet())
+}
+
+// A driver error from NextResultSet stops the advance and is kept as lasterr.
+func TestScanRows_NextResultSet_WithError(t *testing.T) {
+	driverErr := errors.New("next result set error")
+	driverRows := newMockDriverRowsWithNextResultSet([]string{"id"}, nil, false)
+	driverRows.nextResultSetErr = driverErr
+
+	rs := NewScanRows(driverRows)
+	assert.False(t, rs.NextResultSet())
+	assert.Equal(t, driverErr, rs.lasterr)
+}
+
+// Closed rows never advance to another result set.
+func TestScanRows_NextResultSet_WhenClosed(t *testing.T) {
+	rs := NewScanRows(newMockDriverRowsWithNextResultSet([]string{"id"}, nil, false))
+	rs.closed = true
+
+	assert.False(t, rs.NextResultSet())
+}
+
+func TestScanRows_ContextCancellation(t *testing.T) {
+ mockRows := newMockDriverRows([]string{"id"},
[][]driver.Value{{int64(1)}})
+ scanRows := NewScanRows(mockRows)
+
+ // 使用 channel 来确保清理完成
+ done := make(chan struct{})
+ scanRows.releaseConn = func(error) {
+ close(done)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ scanRows.initContextClose(ctx, nil)
+
+ // Cancel context
+ cancel()
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("Context cancellation cleanup timeout")
+ }
+
+ scanRows.closemu.RLock()
+ closed := scanRows.closed
+ scanRows.closemu.RUnlock()
+
+ assert.True(t, closed, "Rows should be closed after context
cancellation")
+}
+
+func TestScanRows_TransactionContext(t *testing.T) {
+ mockRows := newMockDriverRows([]string{"id"},
[][]driver.Value{{int64(1)}})
+ scanRows := NewScanRows(mockRows)
+
+ done := make(chan struct{})
+ scanRows.releaseConn = func(error) {
+ close(done)
+ }
+
+ ctx := context.Background()
+ txctx, txcancel := context.WithCancel(context.Background())
+ scanRows.initContextClose(ctx, txctx)
+
+ // Cancel transaction context
+ txcancel()
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("Transaction context cancellation cleanup timeout")
+ }
+
+ scanRows.closemu.RLock()
+ closed := scanRows.closed
+ scanRows.closemu.RUnlock()
+
+ assert.True(t, closed, "Rows should be closed after transaction context
cancellation")
+}
+
+func TestScanRows_BypassRowsAwaitDone(t *testing.T) {
+ originalBypass := bypassRowsAwaitDone
+
+ t.Cleanup(func() {
+ bypassRowsAwaitDone = originalBypass
+ })
+
+ bypassRowsAwaitDone = true
+
+ mockRows := newMockDriverRows([]string{"id"},
[][]driver.Value{{int64(1)}})
+ scanRows := NewScanRows(mockRows)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ scanRows.initContextClose(ctx, nil)
+
+ cancel()
+
+ scanRows.closemu.RLock()
+ closed := scanRows.closed
+ scanRows.closemu.RUnlock()
+
+ assert.False(t, closed, "Rows should NOT be closed when bypass is
enabled")
+}
+
+// withLock runs the callback exactly once.
+func TestWithLock(t *testing.T) {
+	var (
+		mu    sync.Mutex
+		calls int
+	)
+
+	withLock(&mu, func() { calls++ })
+
+	assert.Equal(t, 1, calls)
+}
+
+// A panic inside the callback propagates to the caller, and withLock must
+// still release the mutex.
+func TestWithLock_WithPanic(t *testing.T) {
+	var mu sync.Mutex
+
+	func() {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Error("Expected panic")
+			}
+		}()
+		withLock(&mu, func() {
+			panic("test panic")
+		})
+	}()
+
+	// Re-acquire the mutex from another goroutine. If withLock leaked the
+	// lock, this would block forever, so bound the wait.
+	acquired := make(chan struct{})
+	go func() {
+		mu.Lock()
+		mu.Unlock()
+		close(acquired)
+	}()
+	select {
+	case <-acquired:
+	case <-time.After(time.Second):
+		t.Fatal("mutex still held after panic in withLock")
+	}
+}
+
+// lasterrOrErrLocked prefers a stored non-EOF error over the given one.
+func TestScanRows_LasterrOrErrLocked(t *testing.T) {
+	rs := NewScanRows(newMockDriverRows([]string{"id"}, nil))
+
+	cases := []struct {
+		lasterr error
+		want    string
+	}{
+		{lasterr: nil, want: "test error"},                        // nothing stored: the argument wins
+		{lasterr: errors.New("last error"), want: "last error"}, // a stored error wins
+		{lasterr: io.EOF, want: "test error"},                     // io.EOF does not count
+	}
+	for _, c := range cases {
+		rs.lasterr = c.lasterr
+		got := rs.lasterrOrErrLocked(errors.New("test error"))
+		assert.Equal(t, c.want, got.Error())
+	}
+}
+
+// TestScanRows_Scan_WithConversion checks that string column values are
+// converted into int64, float64 and bool destinations by Scan.
+func TestScanRows_Scan_WithConversion(t *testing.T) {
+ scanRows := NewScanRows(newMockDriverRows(
+ []string{"id", "score", "active"},
+ [][]driver.Value{{"123", "45.67", "true"}},
+ ))
+
+ assert.True(t, scanRows.Next())
+
+ var (
+ id     int64
+ score  float64
+ active bool
+ )
+ assert.NoError(t, scanRows.Scan(&id, &score, &active))
+
+ assert.Equal(t, int64(123), id)
+ assert.InDelta(t, 45.67, score, 0.001)
+ assert.True(t, active)
+}
+
+// TestScanRows_MultipleResultSets reads the single row of the first result
+// set, then checks that NextResultSet advances and is forwarded to the driver
+// rows exactly once.
+func TestScanRows_MultipleResultSets(t *testing.T) {
+ // The mock reports that another result set follows the first one.
+ mockRows := newMockDriverRowsWithNextResultSet([]string{"id"}, [][]driver.Value{{int64(1)}}, true)
+ scanRows := NewScanRows(mockRows)
+
+ var id int64
+ assert.True(t, scanRows.Next())
+ assert.NoError(t, scanRows.Scan(&id))
+ assert.Equal(t, int64(1), id)
+
+ // The first result set holds only one row.
+ assert.False(t, scanRows.Next())
+
+ assert.True(t, scanRows.NextResultSet())
+ assert.Equal(t, 1, mockRows.nextResultSetCalled)
+}
+
+// TestScanRows_CloseHook checks that the function returned by rowsCloseHook
+// is invoked when the rows are closed.
+func TestScanRows_CloseHook(t *testing.T) {
+ savedHook := rowsCloseHook
+ t.Cleanup(func() { rowsCloseHook = savedHook })
+
+ // Buffered with a non-blocking send so that repeated hook invocations
+ // never block the code under test.
+ fired := make(chan struct{}, 1)
+ notify := func(*ScanRows, *error) {
+ select {
+ case fired <- struct{}{}:
+ default:
+ }
+ }
+ rowsCloseHook = func() func(*ScanRows, *error) { return notify }
+
+ scanRows := NewScanRows(newMockDriverRows([]string{"id"}, nil))
+ scanRows.releaseConn = func(error) {}
+
+ assert.NoError(t, scanRows.close(nil))
+
+ select {
+ case <-fired:
+ case <-time.After(2 * time.Second):
+ t.Fatal("hook called timeout")
+ }
+}
+
+// TestScanRows_WithCancelFunction checks that closing the rows succeeds and
+// invokes the cancel func stored on them.
+func TestScanRows_WithCancelFunction(t *testing.T) {
+ cancelCalled := false
+ mockRows := newMockDriverRows([]string{"id"}, nil)
+ scanRows := NewScanRows(mockRows)
+ scanRows.releaseConn = func(error) {}
+ scanRows.cancel = func() {
+ cancelCalled = true
+ }
+
+ // Do not discard close's error: a failing close must fail this test too.
+ assert.NoError(t, scanRows.close(nil))
+ assert.True(t, cancelCalled, "cancel should be called when rows are closed")
+}
+
+// TestScanRows_Next_HasNextResultSetFalse checks that Next returns false for
+// an empty result set when the driver reports no further result sets.
+func TestScanRows_Next_HasNextResultSetFalse(t *testing.T) {
+ scanRows := NewScanRows(newMockDriverRowsWithNextResultSet([]string{"id"}, [][]driver.Value{}, false))
+ // Stubbed in case exhausting the rows leads to them being closed.
+ scanRows.releaseConn = func(error) {}
+
+ hasRow := scanRows.Next()
+ assert.False(t, hasRow)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]